import torch
import numpy as np
import matplotlib.pyplot as plt


def residual_vertical(xy, s):
    """Return the sum of squared vertical residuals from a line to a set of points.

    xy: tensor of shape (2, N); row 0 holds the x coordinates, row 1 the y coordinates.
    s:  tensor of shape (2, 1) describing the line as [[slope], [intercept]].

    Returns a 1-element tensor (shape (1,)), so ``r.backward()`` can be called on it.
    """
    x = xy[0, :].unsqueeze(0)  # (1, N)
    y = xy[1, :].unsqueeze(0)  # (1, N)
    # Homogeneous coordinates: each row is [x_i, 1], so xh @ s = slope * x_i + intercept.
    # ones_like keeps the dtype/device of the input points.
    xh = torch.cat((x, torch.ones_like(x)), 0).transpose(0, 1)  # (N, 2)
    # Reduce over the point axis with torch (not the builtin sum) so the result is
    # always a tensor, even when there are no points.
    return ((xh.mm(s) - y.transpose(0, 1)) ** 2).sum(0)

def residual_perpendicular(xy, s):
    """Return the sum of squared perpendicular residuals from a line to a set of points.

    xy: tensor of shape (2, N); row 0 holds the x coordinates, row 1 the y coordinates.
    s:  tensor describing the line y = slope * x + intercept as [[slope], [intercept]].

    Each point is orthogonally projected onto the line. The function accumulates the
    squared Euclidean distance between each point and its projection. Returns a
    1-element tensor (shape (1,)), so ``r.backward()`` can be called on it.
    """
    slope = s[0]
    intercept = s[1]
    x = xy[0, :]
    y = xy[1, :]
    # The line passes through (0, intercept) with direction (1, slope). The orthogonal
    # projection of (x, y) onto it is (0, intercept) + t * (1, slope), where
    #   t = ((x, y - intercept) . (1, slope)) / |(1, slope)|^2
    # Note the divisor is the *squared* norm, and the offset uses the intercept.
    t = (x + slope * (y - intercept)) / (1 + slope * slope)
    proj_x = t
    proj_y = slope * t + intercept
    sq_dist = (proj_x - x) ** 2 + (proj_y - y) ** 2
    # Keep the historical 1-element (shape (1,)) return shape.
    return sq_dist.sum().reshape(1)


# Six sample points, stored column-wise: row 0 = x coordinates, row 1 = y coordinates.
points = [[1, -1], [2, 2], [3, 1], [4, 4], [5, 3], [6, 6]]
xy = torch.tensor(points, dtype=torch.float32).T
x = xy[0:1, :]
y = xy[1:2, :]
# Homogeneous version of the x's: one row [x_i, 1] per point.
xh = torch.cat((x, torch.ones([1, xy.shape[1]])), 0).T

# Solve for slope + intercept by ordinary least squares (i.e. minimizing vertical residuals).
lstsq_result = np.linalg.lstsq(xh.numpy(), y.T.numpy(), rcond=None)
line_params = lstsq_result[0]
print('The line parameters found to minimize least squares of vertical residuals slope={} intercept={}'.format(line_params[0], line_params[1]))

# Now use autograd to look at the derivative of the sum of squared vertical residuals w.r.t. s.
s = torch.tensor(line_params, requires_grad=True)


# Evaluate at the least-squares solution. residual_vertical expects the (2, N) points
# `xy`, not the homogeneous (N, 2) matrix `xh`.
r = residual_vertical(xy, s)
print('sum of squared vertical residuals = {}'.format(r.detach()))

r.backward()  # do backprop

# The gradient of r with respect to s should be (numerically) zero at the optimum.
print('gradient of the vertical residual w.r.t. s = {}'.format(s.grad))

# and, pedantically, if we take a step in this (0) direction the residual stays unchanged
rate = 0.01
with torch.no_grad():  # update the leaf tensor without recording the op in the graph
    s -= rate * s.grad
s.grad.zero_()
r = residual_vertical(xy, s)
print('sum of squared vertical residuals after a step in (0) gradient direction = {}'.format(r.detach()))


#
#  Now consider the perpendicular residuals:
#
# Start from a clean gradient so this backward pass is not mixed with earlier ones.
s.grad = None
r = residual_perpendicular(xy, s)

print('sum of squared perpendicular residuals = {}'.format(r.detach()))


r.backward()  # do backprop

# The gradient of r w.r.t. s is not zero here.
print('gradient of the perpendicular residual w.r.t. s = {}'.format(s.grad))
# let's take a step against the gradient (gradient descent):

rate = 0.01
with torch.no_grad():  # update the leaf tensor without recording the op in the graph
    s -= rate * s.grad


# and compute the residual for the new slope/intercept

r = residual_perpendicular(xy, s)
print('sum of squared perpendicular residuals after one step in gradient direction= {}'.format(r.detach()))


# now let's try doing this 100 times (very naive gradient descent optimization)
for _ in range(100):
    # Gradients accumulate across backward() calls, so clear them before each pass.
    if s.grad is not None:
        s.grad.zero_()
    r = residual_perpendicular(xy, s)
    r.backward()
    with torch.no_grad():  # plain parameter update, not part of the autograd graph
        s -= rate * s.grad

# and check how it works now
r = residual_perpendicular(xy, s)
print('sum of squared perpendicular residuals after 100 more steps in gradient direction = {}'.format(r.detach()))
print('And the new line parameters  slope={} intercept={}'.format(s[0].detach(), s[1].detach()))
