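'''
========================================
Gradient descent on a least-squares cost
========================================

Plots, as a 3D wireframe, the cost f(a, b) = sum_i (a*x_i + b - y_i)**2 of
fitting the line y = a*x + b to a few sample points. The plot window is the
region visited by 100 steps of gradient descent started at (0, 0), plus a
small margin.

Two sliders choose a point (a, b). Two cost curves through that point are
drawn and updated live: one as a function of a with b fixed, and one as a
function of b with a fixed.

Adapted from the matplotlib "3D surface (color map)" example. That example's
coolwarm surface plot, LinearLocator and custom z tick formatting are still
present but commented out.
'''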

from mpl_toolkits.mplot3d import Axes3D
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider, Button, RadioButtons
from matplotlib import colormaps
from matplotlib.ticker import LinearLocator, FormatStrFormatter
import numpy as np

# Sample points (x_i, y_i) to which the line y = a*x + b is fitted.
# Alternative data set: xi = [0, 1], yi = [0, 2]
xi = [0, 1, 2]
yi = [0, 1, 1]
n = len(xi)  # number of samples

# Colour map; in this file it is only referenced by the disabled
# path-plotting code held in a string further down.
cmap = colormaps['plasma']

# A single 3D axes, with room left at the bottom of the figure for two sliders.
fig = plt.figure()
ax = fig.add_subplot(projection='3d')
fig.subplots_adjust(bottom=0.25)

def f(a, b, xs=None, ys=None):
    """Return the least-squares cost of the line y = a*x + b on the samples.

    Computes sum((a*x + b - y)**2) over the pairs of ``xs`` and ``ys``.
    ``a`` and ``b`` may be scalars or NumPy arrays; with arrays the cost is
    evaluated elementwise.

    xs, ys: sample abscissas and ordinates. They default to the module-level
        ``xi`` and ``yi``, so existing ``f(a, b)`` calls behave as before.
    """
    if xs is None:
        xs = xi
    if ys is None:
        ys = yi
    # Pair samples directly instead of indexing through a separately
    # maintained length, so the data alone decides how many terms there are.
    return sum((a * x + b - y) ** 2 for x, y in zip(xs, ys))

def dfa(a, b, xs=None, ys=None):
    """Return the partial derivative of the cost ``f`` with respect to ``a``.

    d/da sum((a*x + b - y)**2) = sum(2*x*(a*x + b - y)), taken over the
    pairs of ``xs`` and ``ys``. ``a`` and ``b`` may be scalars or NumPy
    arrays (evaluated elementwise).

    xs, ys: sample abscissas and ordinates. They default to the module-level
        ``xi`` and ``yi``, so existing ``dfa(a, b)`` calls behave as before.
    """
    if xs is None:
        xs = xi
    if ys is None:
        ys = yi
    return sum(2 * x * (a * x + b - y) for x, y in zip(xs, ys))

def dfb(a, b, xs=None, ys=None):
    """Return the partial derivative of the cost ``f`` with respect to ``b``.

    d/db sum((a*x + b - y)**2) = sum(2*(a*x + b - y)), taken over the pairs
    of ``xs`` and ``ys``. ``a`` and ``b`` may be scalars or NumPy arrays
    (evaluated elementwise).

    xs, ys: sample abscissas and ordinates. They default to the module-level
        ``xi`` and ``yi``, so existing ``dfb(a, b)`` calls behave as before.
    """
    if xs is None:
        xs = xi
    if ys is None:
        ys = yi
    return sum(2 * (a * x + b - y) for x, y in zip(xs, ys))

# Gradient-descent starting point and step size (learning rate).
a, b = 0, 0
scale = 0.1

# Initial slider positions. They are copied here because the descent loop
# below overwrites a and b.
a_s, b_s = a, b

# Two horizontal slider axes under the 3D plot (axes names inherited from
# the matplotlib slider demo).
axcolor = 'lightgoldenrodyellow'
axfreq = plt.axes([0.1, 0.15, 0.75, 0.03], facecolor=axcolor)  # hosts slider "a"
axamp = plt.axes([0.1, 0.1, 0.75, 0.03], facecolor=axcolor)    # hosts slider "b"

sl_a = Slider(axfreq, 'a', -1, 2, valinit=a_s)
sl_b = Slider(axamp, 'b', -1, 2, valinit=b_s)

# Descent trajectory, starting point included: parameters (A, B and the
# pairs in valeurs) and the cost at each iterate (couts).
A, B = [a], [b]
valeurs = [(a, b)]
couts = [f(a, b)]

for i in range(100):
    # Simultaneous update: both partial derivatives use the previous (a, b).
    a, b = a - scale * dfa(a, b), b - scale * dfb(a, b)
    A.append(a)
    B.append(b)
    valeurs.append((a, b))
    couts.append(f(a, b))



# Cost range along the descent path; in this file only the disabled
# path-plotting code (held in a string below) reads it.
cmax, cmin = max(couts), min(couts)

# Plot window: bounding box of the descent path plus a margin on each side.
# A fixed window, e.g. [-0.5, 1.5] x [-0.5, 1.5], can be used instead.
marge = .1
xmin, xmax = min(A) - marge, max(A) + marge
ymin, ymax = min(B) - marge, max(B) + marge

# Cost surface sampled on a 30x30 grid and drawn as a sparse wireframe.
X, Y = np.meshgrid(np.linspace(xmin, xmax, 30), np.linspace(ymin, ymax, 30))
Z = f(X, Y)
# Shaded-surface alternative:
# surf = ax.plot_surface(X, Y, Z, cmap=colormaps['coolwarm'], linewidth=0, antialiased=False)
wire = ax.plot_wireframe(X, Y, Z, rstride=10, cstride=10, color='darkmagenta')

# Half-width of the local windows used by the disabled path-plotting code.
delta = 0.1

# Initial cross-sections through the slider start point (a_s, b_s):
# the cost as a function of a, drawn in the plane y = b_s, and
# the cost as a function of b, drawn in the plane x = a_s.
X1 = np.linspace(xmin, xmax, 10)
X2 = np.linspace(ymin, ymax, 10)
Ya = f(X1, b_s)
Yb = f(a_s, X2)
[l_a] = ax.plot(X1, Ya, zs=b_s, zdir='y')
[l_b] = ax.plot(X2, Yb, zs=a_s, zdir='x')

"""
for i,(a,b) in enumerate(valeurs):
    #Ya=f(X1,b)
    #Yb=f(a,X2)
    if i%1==0:
        X1=np.linspace(a-delta,a+delta, 10)
        X2=np.linspace(b-delta,b+delta, 10)
        #X1=np.linspace(xmin,xmax, 10)
        #X2=np.linspace(ymin,ymax, 10)
        Ya=f(X1,b)
        Yb=f(a,X2)
        coul=cmap((cmax-couts[i])/(cmax-cmin))
        #ax.plot(X1, Ya, zs=b, zdir='y', color=coul)
        #ax.plot(X2, Yb, zs=a, zdir='x', color=coul)
        #ax.plot([X1[0],X1[0],X1[-1],X1[-1]],[Ya[0],0,0,Ya[-1]],zs=b, zdir='y', color=coul)
        #ax.plot([X2[0],X2[0],X2[-1],X2[-1]],[Yb[0],0,0,Yb[-1]],zs=a, zdir='x', color=coul)
        ax.plot([a,a],[b,b],[couts[i],0],color=coul)

for i in range(len(A)-1):
    a1=A[i]
    b1=B[i]
    c1=couts[i]
    a2=A[i+1]
    b2=B[i+1]
    c2=couts[i+1]
    if (a2-a1)**2+(b2-b1)**2+(c2-c1)**2>0.25**2:
        Al=np.linspace(a1,a2, 10)
        Bl=np.linspace(b1,b2, 10)
        Cl=f(Al,Bl)
        ax.plot(Al,Bl,Cl,color=(0,0,0))
    else:
        ax.plot([a1,a2],[b1,b2],[c1,c2],color=(0,0,0))

#ax.plot(A,B,couts,color=(0,0,0))
ax.plot(A,B,0,color=(0.5,0.5,0.5))
ax.scatter(A,B,couts,color=(0,0,0))
"""

def update(val):
    """Slider callback: move both cost cross-sections to the sliders' (a, b).

    ``val`` is the value passed by whichever slider fired. It is ignored
    because both slider values are read directly, so either slider can
    trigger a full refresh.
    """
    a_cur, b_cur = sl_a.val, sl_b.val
    grid_a = np.linspace(xmin, xmax, 10)
    grid_b = np.linspace(ymin, ymax, 10)
    # Cost versus a with b held at b_cur, drawn in the plane y = b_cur.
    l_a.set_data_3d(grid_a, [b_cur] * len(grid_a), f(grid_a, b_cur))
    # Cost versus b with a held at a_cur, drawn in the plane x = a_cur.
    l_b.set_data_3d([a_cur] * len(grid_b), grid_b, f(a_cur, grid_b))
    fig.canvas.draw_idle()

# Both sliders share one redraw callback, which receives the new slider
# value whenever either slider moves.
sl_b.on_changed(update)
sl_a.on_changed(update)

# Unused leftovers from the matplotlib "3D surface" example, kept for
# experimentation. This file imports `colormaps`, not `cm`, hence
# colormaps['autumn'] below.
#cset = ax.contour(X, Y, Z, 50, cmap=colormaps['autumn'])
#ax.clabel(cset, fontsize=9, inline=1)
# Fixed z range and custom z tick locator/formatter:
#ax.set_zlim(-1.01, 100.01)
#ax.zaxis.set_major_locator(LinearLocator(50))
#ax.zaxis.set_major_formatter(FormatStrFormatter('%.02f'))
# Colour bar mapping cost to colour (requires the commented-out `surf` plot):
#fig.colorbar(surf, shrink=0.5, aspect=5)

plt.show()
