import numpy as np
import matplotlib.pyplot as plt

# Two identical harmonic oscillators (mass M, own spring K) joined by a
# coupling spring of stiffness k, solved analytically via normal modes:
#   sigma = psi1 + psi2 oscillates at ws = sqrt(K / M)          (symmetric mode)
#   delta = psi1 - psi2 oscillates at wa = sqrt((K + 2k) / M)   (antisymmetric mode)
# The displacements are recovered as psi1 = (sigma + delta) / 2 and
# psi2 = (sigma - delta) / 2.
K = 1    # stiffness of each oscillator's own spring
M = 1    # mass of each oscillator
k = 0.1  # stiffness of the coupling spring

ws = np.sqrt(K / M)               # symmetric-mode angular frequency
wa = np.sqrt((K + 2 * k) / M)     # antisymmetric-mode angular frequency

# Initial displacements of oscillators 1 and 2.
psi10 = 0
psi20 = -1

# Initial mode amplitudes derived from the initial displacements.
sigma0 = psi10 + psi20
delta0 = psi10 - psi20

# Mode phases. With both at 0, each mode is a pure cosine, so both
# oscillators start at rest.
phis = 0
phia = 0

T = np.linspace(0, 100, 3000)  # sampled time axis used for plotting


def sigma(t):
    """Return the symmetric-mode coordinate psi1 + psi2 at time(s) *t*."""
    return sigma0 * np.cos(ws * t + phis)


def dsigma(t):
    """Return the time derivative of :func:`sigma` at time(s) *t*."""
    return -ws * sigma0 * np.sin(ws * t + phis)


def delta(t):
    """Return the antisymmetric-mode coordinate psi1 - psi2 at time(s) *t*."""
    return delta0 * np.cos(wa * t + phia)


def ddelta(t):
    """Return the time derivative of :func:`delta` at time(s) *t*."""
    return -wa * delta0 * np.sin(wa * t + phia)


def psi1(t):
    """Return the displacement of oscillator 1 at time(s) *t*."""
    return (sigma(t) + delta(t)) / 2


def dpsi1(t):
    """Return the velocity of oscillator 1 at time(s) *t*."""
    return (dsigma(t) + ddelta(t)) / 2


def psi2(t):
    """Return the displacement of oscillator 2 at time(s) *t*."""
    return (sigma(t) - delta(t)) / 2


def dpsi2(t):
    """Return the velocity of oscillator 2 at time(s) *t*."""
    return (dsigma(t) - ddelta(t)) / 2


def _coupling_energy(t):
    """Return the potential energy stored in the coupling spring at time(s) *t*."""
    return 0.5 * k * (psi2(t) - psi1(t)) ** 2


def Em1(t):
    """Return oscillator 1's mechanical energy, including the whole coupling-spring energy."""
    return 0.5 * M * dpsi1(t) ** 2 + 0.5 * K * psi1(t) ** 2 + _coupling_energy(t)


def Em2(t):
    """Return oscillator 2's mechanical energy, including the whole coupling-spring energy."""
    return 0.5 * M * dpsi2(t) ** 2 + 0.5 * K * psi2(t) ** 2 + _coupling_energy(t)


def Em(t):
    """Return the total mechanical energy of the coupled system at time(s) *t*."""
    # Em1 and Em2 each include the coupling-spring energy, so one copy is
    # removed to avoid counting it twice.
    return Em1(t) + Em2(t) - _coupling_energy(t)

# Two stacked panels sharing the time axis: displacements on top,
# mechanical energies below.
fig, (ax1, ax2) = plt.subplots(2, 1, sharex=True)

ax1.plot(T, psi1(T), label="Psi1", color="blue")
ax1.plot(T, psi2(T), label="Psi2", color="red")

# Every curve is drawn on an explicit Axes. plt.plot would silently
# target whichever Axes pyplot currently considers "current".
ax2.plot(T, Em1(T), label="Em1", color="blue")
ax2.plot(T, Em2(T), label="Em2", color="red")
ax2.plot(T, Em(T), label="Em", color="green")  # total energy

ax2.set_xlabel("Temps")
ax1.set_ylabel("Positions")
ax2.set_ylabel("Energies mécaniques")

ax1.legend()
ax2.legend()
plt.show()