import pylab as plt
import numpy as np
import numpy.linalg as alg
import matplotlib as mpl
from matplotlib.widgets import Slider, Button

mpl.rcParams['text.usetex'] = False #True pour mettre du beau LateX mais ça fait laguer ou False c'est moche mais ça lag pas
mpl.rcParams['axes.titlesize'] = 15
mpl.rcParams['axes.labelsize'] = 30
mpl.rcParams['lines.linewidth'] = 3
mpl.rcParams['lines.markersize'] = 8
mpl.rcParams['xtick.labelsize'] = 20
mpl.rcParams['ytick.labelsize'] = 20
mpl.rcParams['legend.fontsize'] = 25

#Résolution de l'équation de Schrödinger pour une particule d'énergie E<V_0 dans un puit fini de largeur L et de hauteur V_0
#A revérifier parce que la résolution ne dépend pas de E, bizarre mais c'est peut être normal

##Le gros du calcul

# The parametrized function to be plotted
def to_plot(X,V0,E,L,n) : 
    #Le n du slider correspond aux nombre de 'bosses' de la fonction densité de probabilté 
    n=n-1 #On calcule alors le vrai nombre n associé à cette fonction
    #Paramètres physiques de valeurs non réelles pour ne pas s'embêter avec les puissances de 10
    hbar = 1
    m = 1
    
    if n%2 == 0 :
        #Résolution des vecteurs d'ondes et conditions aux limites pour les fonctions d'ondes paires
        def f(X) :
            #équation fonctionnelle à résoudre pur avoir le vecteur d'onde associé à n
            return abs(np.cos(X/2)) - X*hbar*(2*m*V0)**(-0.5)/L
        #on définit les bornes pour la résolution de l'équation fonctionnelle
        a = n*np.pi
        b = (n+1)*np.pi
        #résolution (pas trop opti mais ça marche vraiment bien)
        while (b-a)>1e-3 :
            m = (a+b)/2
            if f(a)*f(m) <= 0 :
                b = m
            else :
                a = m
        #Vecteurs d'ondes
        k=(a+b)/(2*L) #dans le puit
        alpha = k*np.tan(k*L/2) #à l'extérieur du puit
        #Conditions aux limites
        A2 = (np.cos(k*L/2)**2/alpha + L/2 + np.sin(k*L)/(2*k))**(-0.5)
        A1 = A2 * np.exp(alpha*L/2) * np.cos(k*L/2)
        #Calcul de la densité de proba dans le puit
        x2 = X[np.where(np.logical_and(X>=-L/2,X<=L/2))]
        y2 = np.abs(A2*np.cos(k*x2))**2
        
    else : 
        #Même chose pour les fonctions d'onde impaires
        def f(X) :
            return abs(np.sin(X/2)) - X*hbar*(2*m*V0)**(-0.5)/L
        a = n*np.pi
        b = (n+1)*np.pi
        while (b-a)>1e-3 :
            m = (a+b)/2
            if f(a)*f(m) <= 0 :
                b = m
            else :
                a = m
        k=(a+b)/(2*L)
        alpha = -k/(np.tan(k*L/2))
        A2 = (np.cos(k*L/2)**2/alpha + L/2 + np.cos(k*L)/(2*k))**(-0.5)
        A1 = A2 * np.exp(alpha*L/2) * np.sin(k*L/2)
        x2 = X[np.where(np.logical_and(X>=-L/2,X<=L/2))]
        y2 = np.abs(A2*np.sin(k*x2))**2
    
    #Portions de la densité de proba en dehors du puit
    x1 = X[np.where(X<-L/2)]
    y1 = np.abs(A1*np.exp(alpha*x1))**2
    x3 = X[np.where(X>L/2)]
    y3 = np.abs(A1*np.exp(-alpha*x3))**2
            
    return np.concatenate((y1,y2,y3))
    
X = np.linspace(-5,5,1000)

##Le plot avec sliders

# Define initial parameters
init_V0 = 20
init_E = 10
init_L = 4
init_n = 1

# Create the figure and the line that we will manipulate
fig, ax = plt.subplots()
y=to_plot(X, init_V0, init_E,init_L,init_n)
line, = plt.plot(X, y)
line2 = plt.axvline(init_L/2,linestyle='--',color='black')
line3 = plt.axvline(-init_L/2,linestyle='--',color='black')
ax.set_xlabel(r'$x$')
ax.set_ylabel('Densité de probabilité '+r'$|\phi(x)|^2$')
ax.set_xlim(-5,5)
ax.set_ylim(0,np.max(y)+0.1)

# Adjust the main plot to make room for the sliders
#plt.subplots_adjust(left=0.25, bottom=0.25)
plt.subplots_adjust(left=0.1,bottom=0.32,top=0.95)

# Make horizontal sliders to control the parameters.
ax_V0 = plt.axes([0.1, 0.07, 0.8, 0.03])
V0_slider = Slider(
    ax=ax_V0,
    label=r'$V_0$',
    valmin=0.1,
    valmax=30,
    valinit=init_V0,
)
V0_slider.label.set_size(20)
V0_slider.valtext.set_fontsize(20)

ax_E = plt.axes([0.1, 0.1, 0.8, 0.03])
E_slider = Slider(
    ax=ax_E,
    label=r'$E$',
    valmin=0.1,
    valmax=30,
    valinit=init_E,
)
E_slider.label.set_size(20)
E_slider.valtext.set_fontsize(20)

ax_L = plt.axes([0.1, 0.13, 0.8, 0.03])
L_slider = Slider(
    ax=ax_L,
    label=r'$L$',
    valmin=0.001,
    valmax=9,
    valinit=init_L,
)
L_slider.label.set_size(20)
L_slider.valtext.set_fontsize(20)

ax_n = plt.axes([0.1, 0.16, 0.8, 0.03])
n_slider = Slider(
    ax=ax_n,
    label=r'$n$',
    valmin=1,
    valmax=10,
    valinit=init_n,
    valstep=1,
)
n_slider.label.set_size(20)
n_slider.valtext.set_fontsize(20)

# The function to be called anytime a slider's value changes
def update(val):
    y = to_plot(X, V0_slider.val, E_slider.val, L_slider.val,n_slider.val )
    line.set_ydata(y)
    ax.set_ylim(0,np.max(y)+0.1)
    line2.set_xdata(L_slider.val/2)
    line3.set_xdata(-L_slider.val/2)
    fig.canvas.draw_idle()


# register the update function with each slider
V0_slider.on_changed(update)
E_slider.on_changed(update)
L_slider.on_changed(update)
n_slider.on_changed(update)

# Create a `matplotlib.widgets.Button` to reset the sliders to initial values.
resetax = plt.axes([0.8, 0.025, 0.1, 0.04])
button = Button(resetax, 'Reset', hovercolor='0.975')

def reset(event):
    V0_slider.reset()
    E_slider.reset()
    L_slider.reset()
    n_slider.reset()
button.on_clicked(reset)

plt.show()


