import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button, RadioButtons

# Sample points (x_i, y_i) to be fitted by a straight line y = a*x + b.
xi = [0, 1, 2]
yi = [0, 1, 1]
n = len(xi)  # number of sample points

# Display switches (1 = on, 0 = off):
#   courbes        -> panel plotting the cost as a function of a and of b
#   nuage          -> panel plotting the points, the line and the residuals
#   vue_adaptative -> x-range of the cost panel follows the slider values
courbes, nuage, vue_adaptative = 1, 1, 1

def vrais_coeffs(xi, yi):
    """Return the least-squares coefficients ``(a, b)`` of the line y = a*x + b.

    The slope is computed from centred sums,
    a = sum((x - x_barre) * (y - y_barre)) / sum((x - x_barre)**2),
    which is algebraically equal to
    (mean(x*y) - x_barre*y_barre) / (mean(x**2) - x_barre**2)
    but does not lose precision to cancellation when the abscissas are large
    compared with their spread.

    Raises:
        ValueError: if ``xi`` and ``yi`` have different lengths, or if fewer
            than two distinct abscissas are given (the slope is then undefined).
    """
    n = len(xi)
    if len(yi) != n:
        raise ValueError(f"xi and yi must have the same length (got {n} and {len(yi)})")
    # Empty input, a single point, or identical abscissas: no unique line.
    if n == 0 or all(x == xi[0] for x in xi):
        raise ValueError("at least two distinct abscissas are required to fit a line")
    x_barre = sum(xi)/n
    y_barre = sum(yi)/n
    sxx = sum((x - x_barre)**2 for x in xi)
    sxy = sum((x - x_barre)*(y - y_barre) for x, y in zip(xi, yi))
    a = sxy/sxx
    b = y_barre - a*x_barre
    return a, b

def cout(xi, yi, a, b):
    """Return the sum of squared residuals of the line y = a*x + b.

    Each residual is ``yi[k] - a*xi[k] - b``; an empty ``xi`` gives 0.
    """
    total = 0
    for k, x in enumerate(xi):
        residu = yi[k] - a*x - b
        total += residu**2
    return total

def f(a, b, xi, yi):
    """Cost of the line y = a*x + b, with the coefficients given first.

    Sums ``(a*x + b - y)**2`` over the points; used to plot the cost as a
    function of one coefficient while the other one is held fixed.
    """
    carres = (
        (a*abscisse + b - yi[rang])**2
        for rang, abscisse in enumerate(xi)
    )
    return sum(carres)

def calculs_limx(a, b, larg_min=0.5):
    """Return ``(xmin, xmax, marge)``, a horizontal window containing a and b.

    Wide case: the window spans from min(a, b) to max(a, b), extended on each
    side by 20 % of |b - a|, and ``marge`` is that 20 % of |b - a|.
    Narrow case (that window would be narrower than ``larg_min``): a window of
    width ``larg_min`` centred on (a + b) / 2, with ``marge`` equal to 20 % of
    the window width.
    """
    ecart = abs(b - a)
    if not 1.4*ecart < larg_min:
        pad = .2*ecart
        return min(a, b) - pad, max(a, b) + pad, pad
    centre = (a + b)/2
    demi = larg_min/2
    bas, haut = centre - demi, centre + demi
    return bas, haut, .2*(haut - bas)

def dfa(a, b, xi=xi, yi=yi):
    """Partial derivative with respect to ``a`` of the cost ``f(a, b, xi, yi)``.

    d/da sum((a*x + b - y)**2) = sum(2*x*(a*x + b - y)).
    The defaults are the module-level ``xi`` / ``yi`` lists, bound when this
    function is defined.
    """
    derivee = 0
    for rang, x in enumerate(xi):
        derivee += 2*x*(a*x + b - yi[rang])
    return derivee

def dfb(a, b, xi=xi, yi=yi):
    """Partial derivative with respect to ``b`` of the cost ``f(a, b, xi, yi)``.

    d/db sum((a*x + b - y)**2) = sum(2*(a*x + b - y)).
    The defaults are the module-level ``xi`` / ``yi`` lists, bound when this
    function is defined.
    """
    derivee = 0
    for rang, x in enumerate(xi):
        derivee += 2*(a*x + b - yi[rang])
    return derivee

def tangente(a, x0, y0):
    """Return the affine function of slope ``a`` whose graph passes through (x0, y0)."""
    def droite(x):
        return a*(x - x0) + y0
    return droite

# Least-squares coefficients, printed on the console as a reference value.
a_b, b_b = vrais_coeffs(xi, yi)
print(a_b, b_b)

# Starting coefficients for the sliders: the line y = 0.
a = b = 0

# Ordinates of the least-squares line at each abscissa
# (not referenced elsewhere in the visible code).
images_b = [a_b*abscisse + b_b for abscisse in xi]


# Figure layout: with both views enabled, the cost curves (ax2) go on the
# left and the point cloud (ax) on the right; otherwise a single panel.
if courbes and nuage:
    fig, axs = plt.subplots(1, 2)
    ax2, ax = axs
elif courbes:
    fig, ax2 = plt.subplots()
elif nuage:
    fig, ax = plt.subplots()
else:
    # Without this guard `fig` stays undefined and the script only fails
    # later, with a NameError, when a slider is moved.
    raise ValueError("at least one of 'courbes' or 'nuage' must be enabled")

# Free the bottom of the figure for the two slider tracks.
fig.subplots_adjust(bottom=0.25)

if nuage:
    # Candidate line y = a*x + b, evaluated at each abscissa.
    l, = ax.plot(xi, [a*x+b for x in xi], lw=2)

    # One dotted red vertical segment per point, from the point (x, y) to the
    # line at (x, a*x + b): the residuals whose squares are summed by cout().
    segments = []
    for x, y in zip(xi, yi):
        li, = ax.plot([x, x], [y, a*x + b], lw=2, ls=":", marker=".", color="red")
        segments.append(li)

    p = ax.scatter(xi, yi)
    # Same number format as the title written by update(), so the title does
    # not change shape on the first slider move.
    ax.set_title(f"coût : {1.0*cout(xi,yi,a,b):.5}")

if courbes:
    # Horizontal range of the cost curves: follows (a, b) when vue_adaptative
    # is set, otherwise a fixed window.
    if vue_adaptative:
        xmin, xmax, marge = calculs_limx(a, b)
    else:
        xmin, xmax, marge = -1, 2, 0.3
    nb_points = 100
    antecedents = [xmin + i*(xmax-xmin)/(nb_points-1) for i in range(nb_points)]
    # Cost at the current (a, b); shared by the markers and both tangents.
    cout_ab = f(a, b, xi, yi)
    # Partial cost curves: C_b(a) varies a with b fixed, C_a(b) varies b with a fixed.
    c_a, = ax2.plot(antecedents, [f(x, b, xi, yi) for x in antecedents], label="$C_b(a)$")
    c_b, = ax2.plot(antecedents, [f(a, x, xi, yi) for x in antecedents], label="$C_a(b)$")
    # Vertical markers from 0 up to the current cost, at abscissa a and at b.
    pc_a, = ax2.plot([a, a], [0, cout_ab], marker=".")
    pc_b, = ax2.plot([b, b], [0, cout_ab], marker=".")
    # Tangent to each partial curve at the current point, drawn over +/- marge.
    tang_a = tangente(dfa(a, b, xi, yi), a, cout_ab)
    tg_a, = ax2.plot([a-marge, a+marge], [tang_a(a-marge), tang_a(a+marge)], '--', color="black")
    tang_b = tangente(dfb(a, b, xi, yi), b, cout_ab)
    tg_b, = ax2.plot([b-marge, b+marge], [tang_b(b-marge), tang_b(b+marge)], '--', color="black")
    ax2.legend()
    if not nuage:
        ax2.set_title(f"coût : {1.0*cout(xi,yi,a,b):.5}")
if nuage:
    # Equal scaling on both axes of the point-cloud panel. Target `ax`
    # explicitly instead of whichever axes pyplot considers current.
    ax.axis('equal')

# Background colour of the two slider tracks.
axcolor = 'lightgoldenrodyellow'

# Slider tracks in figure coordinates [left, bottom, width, height], inside the
# bottom margin freed by the subplots_adjust(bottom=0.25) call above.
# Despite their names, axfreq holds the `a` slider and axamp the `b` slider.
axfreq = plt.axes([0.1, 0.15, 0.75, 0.03], facecolor=axcolor)
axamp = plt.axes([0.1, 0.1, 0.75, 0.03], facecolor=axcolor)

# Both sliders range over [-1, 2] and start at the initial coefficients.
sl_a = Slider(ax=axfreq, label='a', valmin=-1, valmax=2, valinit=a)
sl_b = Slider(ax=axamp, label='b', valmin=-1, valmax=2, valinit=b)

def _maj_nuage(a, b):
    """Move the candidate line and the residual segments of the point-cloud panel."""
    images = [a*x+b for x in xi]
    l.set_ydata(images)
    # Each segment joins a data point (x, y) to its image (x, a*x + b).
    for seg, y, yb in zip(segments, yi, images):
        seg.set_ydata([y, yb])
    # Rescale and retitle once, after every artist has been moved.
    ax.relim()
    ax.autoscale()
    ax.set_title(f"coût : {1.0*cout(xi,yi,a,b):.5}")


def _maj_courbes(a, b):
    """Redraw the partial cost curves, the markers and the tangents at (a, b)."""
    if vue_adaptative:
        xmin, xmax, marge = calculs_limx(a, b)
    else:
        xmin, xmax, marge = -1, 2, 0.3
    antecedents = [xmin + i*(xmax-xmin)/(nb_points-1) for i in range(nb_points)]
    # Cost at the current (a, b); computed once, used by markers and tangents.
    cout_ab = f(a, b, xi, yi)
    c_a.set_data(antecedents, [f(x, b, xi, yi) for x in antecedents])
    c_b.set_data(antecedents, [f(a, x, xi, yi) for x in antecedents])
    pc_a.set_data([a, a], [0, cout_ab])
    pc_b.set_data([b, b], [0, cout_ab])
    tang_a = tangente(dfa(a, b, xi, yi), a, cout_ab)
    tg_a.set_data([a-marge, a+marge], [tang_a(a-marge), tang_a(a+marge)])
    tang_b = tangente(dfb(a, b, xi, yi), b, cout_ab)
    tg_b.set_data([b-marge, b+marge], [tang_b(b-marge), tang_b(b+marge)])
    # Recompute the data limits only once every line, tangents included, holds
    # its new data; otherwise the view is fitted to the previous tangents.
    ax2.relim()
    ax2.autoscale()
    if not nuage:
        ax2.set_title(f"coût : {1.0*cout(xi,yi,a,b):.5}")


def update(val):
    """Redraw every enabled panel for the current slider values.

    Registered as the ``on_changed`` callback of both sliders. ``val`` (the
    value of the slider that moved) is ignored: both coefficients are re-read
    from ``sl_a`` and ``sl_b`` so the other slider's position is used too.
    """
    a = sl_a.val
    b = sl_b.val
    if nuage:
        _maj_nuage(a, b)
    if courbes:
        _maj_courbes(a, b)

    fig.canvas.draw_idle()

# Both sliders share the same callback: moving either one calls update().
for _curseur in (sl_a, sl_b):
    _curseur.on_changed(update)
del _curseur

# Open the figure window and start matplotlib's event loop.
plt.show()