# Auteur : Clément de la Salle
# Agrégation de physique, ENS de Lyon, 2019-2020

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation
from time import time
from matplotlib.widgets import Slider, Button, RadioButtons

def f(x, R) :
    """Even-parity quantisation function of the finite square well.

    Returns ``sqrt(R**2 - x**2) - x * tan(x)`` (element-wise for arrays).
    Its zeros on ``[0, R]`` are the usual even-state condition
    ``z tan z = sqrt(z0**2 - z**2)`` with reduced wave number ``x`` and
    reduced well strength ``R``; ``tracer`` locates them for the finite well.
    """
    racine = np.sqrt(R ** 2 - x ** 2)
    return racine - x * np.tan(x)

def g(x, R) :
    """Odd-parity quantisation function of the finite square well.

    Returns ``sqrt(R**2 - x**2) + x / tan(x)`` (element-wise for arrays).
    Its zeros on ``[0, R]`` are the usual odd-state condition
    ``-z cot z = sqrt(z0**2 - z**2)``. At ``x == 0`` the second term is
    ``0 / 0`` and evaluates to nan.
    """
    cotangente_ponderee = x / np.tan(x)
    return np.sqrt(R ** 2 - x ** 2) + cotangente_ponderee


def change_puits(label) :
    """RadioButtons callback: switch the type of potential well.

    ``label`` is one of 'Infini', 'Fini' or 'Harmonique'. The matching index
    (0, 1 or 2) is stored in the global ``puits``, only the slider axes that
    are relevant to that well are made visible, then the well is redrawn.
    Any other label is ignored.
    """
    global puits

    # label -> (well index, visibility of the axa0 / axV0 / axk slider axes)
    reglages = {
        'Infini' : (0, (True, False, False)),
        'Fini' : (1, (True, True, False)),
        'Harmonique' : (2, (False, False, True)),
    }
    if label not in reglages :
        return

    puits, visibilites = reglages[label]
    for axe, visible in zip((axa0, axV0, axk), visibilites) :
        axe.set_visible(visible)

    tracer()
    plt.draw()
        

def change_a0(event) :
    """Callback of the ``a`` slider (well width).

    ``event`` is the value handed over by ``Slider.on_changed``. It is
    unused: the width is re-read from ``sa0``, stored in the global ``a0``,
    and the figure is redrawn.
    """
    global a0

    largeur = sa0.val
    a0 = largeur
    tracer()

def change_V0(event) :
    """Callback of the ``V_0`` slider (finite-well depth parameter).

    ``event`` is the value handed over by ``Slider.on_changed``. It is
    unused: the value is re-read from ``sV0``, stored in the global ``V0``,
    and the figure is redrawn.
    """
    global V0

    profondeur = sV0.val
    V0 = profondeur
    tracer()
    
def change_k(event) :
    """Callback of the ``k`` slider (stiffness of the harmonic potential).

    ``event`` is the value handed over by ``Slider.on_changed``. It is
    unused: the value is re-read from ``sk``, stored in the global ``k``,
    and the figure is redrawn.
    """
    global k

    raideur = sk.val
    k = raideur
    tracer()

def change_Lx(event) :
    """Callback of the ``L_x`` slider (horizontal extent of the plot).

    Stores the slider value in the global ``Lx``, re-centres the x axis on
    ``[-Lx/2, Lx/2]`` and redraws the well. ``event`` (the value passed by
    ``Slider.on_changed``) is unused.
    """
    global Lx

    Lx = sLx.val
    demi_largeur = Lx / 2
    ax.set_xlim(-demi_largeur, demi_largeur)
    tracer()

def change_Ly(event) :
    """Callback of the ``L_y`` slider (top of the energy axis).

    Stores the slider value in the global ``Ly``, sets the y axis to
    ``[-0.05 * Ly, Ly]`` (a small margin below zero) and redraws the well.
    ``event`` (the value passed by ``Slider.on_changed``) is unused.
    """
    global Ly

    Ly = sLy.val
    bas, haut = -.05 * Ly, Ly
    ax.set_ylim(bas, haut)
    tracer()

def change_masse(event) :
    """Callback of the ``m`` slider (particle mass).

    ``event`` is the value handed over by ``Slider.on_changed``. It is
    unused: the mass is re-read from ``sm``, stored in the global ``m``,
    and the figure is redrawn.
    """
    global m

    masse = sm.val
    m = masse
    tracer()

def _effacer_niveaux() :
    """Remove every energy-level segment currently drawn on ``ax``.

    The segments are detached from the axes with ``Artist.remove`` rather
    than merely hidden. Otherwise they would pile up in ``ax.lines`` at
    every slider move, growing memory use and slowing down each redraw.
    """
    global niveaux

    for niveau in niveaux :
        niveau.remove()
    niveaux = []


def _dessiner_niveaux(energies, demi_largeurs) :
    """Replace the drawn energy levels by one orange segment per level.

    Parameters
    ----------
    energies : array-like
        Energy of each level, used as the y coordinate of its segment.
    demi_largeurs : array-like
        Half-length of each segment: level ``i`` is drawn from
        ``-demi_largeurs[i]`` to ``demi_largeurs[i]``.
    """
    global niveaux

    _effacer_niveaux()
    niveaux = [ax.plot([-d, d], [e, e], color = (1, .5, 0))[0]
               for e, d in zip(energies, demi_largeurs)]


def tracer() :
    """Redraw the potential of the selected well and its energy levels.

    Reads the module-level parameters (``puits``, ``a0``, ``V0``, ``k``,
    ``m``, ``Lx``, ``Ly``). It updates the potential curve ``line`` and
    replaces the level segments kept in ``niveaux``:

    * ``puits == 0``: infinite well of width ``a0``.
      Levels are ``E_n = n**2 / (m * a0**2)`` for ``n >= 1``, kept below ``Ly``.
    * ``puits == 1``: finite well of width ``a0`` and drawn depth
      ``V1 = 4 * V0 / pi**2``. Levels come from the zeros of ``f`` (even
      states) and ``g`` (odd states) on ``[0, R]``, with
      ``R = a0 * sqrt(m * V0)``.
    * ``puits == 2``: harmonic potential ``k * x**2``.
      Levels are ``E_n = (n + 1/2) * sqrt(k / m)``, kept below ``Ly``. Each
      segment spans the classical turning points ``+-sqrt(E_n / k)``.

    Degenerate parameters (zero mass, width, depth or stiffness) produce no
    levels instead of dividing by zero. Any other value of ``puits`` leaves
    the figure untouched.
    """
    if puits == 0 :
        line.set_data([-a0 / 2, -a0 / 2, a0 / 2, a0 / 2], [Ly, 0, 0, Ly])
        if m > 0 and a0 > 0 :
            E = np.arange(1, 1000) ** 2 / (m * a0 ** 2)
            E = E[E < Ly]
        else :
            E = np.empty(0)
        _dessiner_niveaux(E, np.full(E.shape, a0 / 2))

    elif puits == 1 :
        V1 = 4 / np.pi ** 2 * V0
        line.set_data([-Lx / 2, -a0 / 2, -a0 / 2, a0 / 2, a0 / 2, Lx / 2],
                      [V1, V1, 0, 0, V1, V1])
        R = a0 * np.sqrt(m * V0)
        if R > 0 :
            x = np.linspace(0, R, 1000)
            # g(0) is 0/0 -> nan. That sample is dropped by the |y| < R
            # filter below, so the numpy warning is just noise.
            with np.errstate(divide = 'ignore', invalid = 'ignore') :
                y1 = f(x, R)
                y2 = g(x, R)
                # A sign change is either a real zero or a jump across a pole
                # of tan / cot. Next to a pole |y| is large, so keeping only
                # |y| < R discards those spurious crossings.
                idx1 = np.flatnonzero(np.diff(np.sign(y1)) != 0)
                idx2 = np.flatnonzero(np.diff(np.sign(y2)) != 0)
                idx1 = idx1[np.abs(y1[idx1]) < R]
                idx2 = idx2[np.abs(y2[idx2]) < R]
            idx = np.concatenate((idx1, idx2))
            # Since x <= R, every level found here satisfies E <= V1, so no
            # extra "bound state" filter is needed.
            E = x[idx] ** 2 / (m * a0 ** 2 * np.pi ** 2) * 4
        else :
            E = np.empty(0)
        _dessiner_niveaux(E, np.full(E.shape, a0 / 2))

    elif puits == 2 :
        X = np.linspace(-Lx/2, Lx/2, 1000)
        line.set_data(X, k * X ** 2)
        # NOTE(review): with V = k * x**2, a level spacing of sqrt(k / m)
        # corresponds to hbar**2 = 1/2. The square-well formulas above
        # correspond to hbar**2 = 2 / pi**2, so harmonic energies are not on
        # the same scale as the other wells. Confirm this is intended.
        if k > 0 and m > 0 :
            E = (np.arange(1000) + 1/2) * np.sqrt(k / m)
            E = E[E < Ly]
            demi_largeurs = np.sqrt(E / k)
        else :
            E = demi_largeurs = np.empty(0)
        _dessiner_niveaux(E, demi_largeurs)
    


# ---------------------------------------------------------------------------
# Initial parameters. Energies are in reduced units in which the infinite-well
# levels are n**2 / (m * a0**2) (see tracer).
# ---------------------------------------------------------------------------
a0, a1 = 1, 1       # a0: well width; a1 is not used in the visible code
m = 10              # particle mass
V0 = 1              # finite-well depth parameter (the drawn depth is 4 * V0 / pi**2)
puits = 0           # selected well: 0 = 'Infini', 1 = 'Fini', 2 = 'Harmonique'
Lx, Ly = 2, 5       # plot extent: x in [-Lx/2, Lx/2], energy axis up to Ly
k = Lx              # stiffness of the harmonic potential, drawn as k * x**2


# Main axes: the left subplot, widened by 0.2 (in figure fraction), still
# leaving the right-hand side of the figure for the widgets.
fig = plt.figure()
ax = fig.add_subplot(121, xlim = (-Lx / 2, Lx / 2), ylim = (-.05 * Ly, Ly))
position = list(ax.get_position().bounds)
position[2] += 0.2
ax.set_position(position)

# `line` draws the potential; `niveaux` holds the Line2D objects of the energy
# levels and is rebuilt by tracer(). Both must exist before the first draw.
line, = ax.plot([], [], linewidth = 5)
niveaux = []
tracer()

# Well selector. The label order matches the indices stored in `puits`.
axpuits = fig.add_axes([.78, .5, 0.15, .1])
boutton_puits = RadioButtons(axpuits, ['Infini', 'Fini', 'Harmonique'], active = 0)
boutton_puits.on_clicked(change_puits)

# Parameter sliders. Each callback copies the slider value into the matching
# module-level variable and redraws.
# NOTE(review): every slider starts at 0, so degenerate values (zero width,
# mass, depth, stiffness or plot extent) can be selected. Consider a small
# positive minimum.
axa0 = fig.add_axes([.77, .85, .15, .03])
sa0 = Slider(axa0, r'$a$', 0, Lx, valinit = a0)
sa0.on_changed(change_a0)

axV0 = fig.add_axes([.77, .75, .15, .03])
sV0 = Slider(axV0, r'$V_0$', 0, 5 * Ly, valinit = V0)
sV0.on_changed(change_V0)

axk = fig.add_axes([.77, .65, .15, .03])
sk = Slider(axk, r'$k$', 0, 10 * Lx, valinit = k)
sk.on_changed(change_k)

axLx = fig.add_axes([.77, .2, .15, .03])
sLx = Slider(axLx, r'$L_x$', 0, 10, valinit = Lx)
sLx.on_changed(change_Lx)

axLy = fig.add_axes([.77, .1, .15, .03])
sLy = Slider(axLy, r'$L_y$', 0, 10, valinit = Ly)
sLy.on_changed(change_Ly)

axm = fig.add_axes([.77, .3, .15, .03])
sm = Slider(axm, r'$m$', 0, 100, valinit = m)
sm.on_changed(change_masse)

# Start with only the width slider visible, matching the 'Infini' well
# selected above (active = 0, puits = 0).
axa0.set_visible(True)
axV0.set_visible(False)
axk.set_visible(False)

# mng = plt.get_current_fig_manager()       # You can uncomment these two lines if you use the 'TkAgg' backend
# mng.window.state('zoomed')
plt.show()