import numpy as np
import matplotlib.pyplot as plt
from scipy.integrate import odeint


def param(n):
    """Draw the polar curve rho(theta) = 1 + cos(n*theta) + sin(n*theta)**2.

    theta spans [-pi, pi] with 1000 samples. The curve is added to the
    current matplotlib figure with legend label str(n), and the axes are
    forced to an equal aspect ratio so the shape is not distorted.
    """
    thetas = np.linspace(-np.pi, np.pi, 1000)

    def radius(theta):
        # polar radius for the current petal count n
        return 1 + np.cos(n * theta) + np.sin(n * theta) ** 2

    xs = []
    ys = []
    for theta in thetas:
        xs.append(radius(theta) * np.cos(theta))
        ys.append(radius(theta) * np.sin(theta))

    plt.plot(xs, ys, label=str(n))
    plt.axis('equal')
    plt.legend()
    

##plt.figure('Param')
##for i in range(3,4):
##    param(i)
##plt.show()


a = 1
# The default argument freezes a = 1 at definition time.  The module later
# rebinds the global `a` to the SIR recovery rate (0.44); a lambda reading the
# global would silently change this ODE (and break agreement with alphaexact).
def f(alpha, t, a=a):
    """Right-hand side of the scalar test ODE alpha' = -alpha + a*exp(-t).

    With a = 1 and alpha(0) = 0 its exact solution is t*exp(-t), i.e.
    alphaexact.  The argument order (alpha, t) matches Euler and odeint.
    """
    return -alpha + a * np.exp(-t)

##def Euler(f, alpha0, t0, dt, n):
##    Lalpha = [alpha0]
##    Lt = [t0]
##    alpha = alpha0
##    t = t0
##    tmax = t0 + n*dt
##    while Lt[-1] < tmax: # Pas de while car on connait le nb d'itération
##        alpha = alpha + dt*f(alpha,t)
##        t = t+dt
##        Lt.append(t)
##        Lalpha.append(alpha)
##    return Lt,Lalpha

def Euler(f, alpha0, t0, dt, n):
    """Integrate alpha' = f(alpha, t) with n explicit (forward) Euler steps.

    Starting from (t0, alpha0), each step applies
    alpha_{k+1} = alpha_k + dt * f(alpha_k, t_k) and t_{k+1} = t_k + dt.

    Returns:
        (times, values): two lists of length n + 1, beginning with t0 and alpha0.
    """
    times, values = [t0], [alpha0]
    for _ in range(n):
        t_cur, alpha_cur = times[-1], values[-1]
        slope = f(alpha_cur, t_cur)
        values.append(alpha_cur + dt * slope)
        times.append(t_cur + dt)
    return times, values

def alphaexact(t):
    """Exact solution t*exp(-t) of alpha' = -alpha + exp(-t), alpha(0) = 0."""
    return t * np.exp(-t)

##for i in range(3,-0,-1):
##    X,Y = Euler(f,0,0,10**(-i),10**(i+1))
##    plt.plot(X,Y,label = 'Euler : pas = 10^-' + str(i))
##
##Y2 = odeint(f,0,X)
##plt.plot(X,Y2,label = 'odeint')
##
##
##Yr = [alphaexact(t) for t in X]
##plt.plot(X,Yr,'*',label = 'solution')
##
##plt.legend()
##plt.show()

# Initial conditions of the epidemic model: susceptible and infected counts.
S0 = 762
I0 = 1
# Coefficients of f2: r multiplies S*I (transmission), a multiplies I
# (removal).  Time unit presumably days, given the 1/24/60 step used below.
r = 2.18 * 10**(-3)
# NOTE: this rebinds the module-level `a` defined earlier (a = 1).
a = 0.44


def f2(S, I, t):
    """Right-hand side of the SIR system S' = -r*S*I, I' = r*S*I - a*I.

    t is accepted for interface compatibility with Euler2 but is unused.
    Returns the tuple (dS/dt, dI/dt).
    """
    infections = r * S * I
    return (-infections, infections - a * I)

def Euler2(f, S0, I0, t0, dt, n):
    """Integrate the planar system (S', I') = f(S, I, t) with n forward Euler steps.

    Returns:
        (Lt, LS, LI): lists of length n + 1 holding the times, the S values
        and the I values, starting from (t0, S0, I0).
    """
    state = (t0, S0, I0)
    history = [state]
    for _ in range(n):
        t, s, i = state
        ds, di = f(s, i, t)
        state = (t + dt, s + dt * ds, i + di * dt)
        history.append(state)
    times, susceptible, infected = (list(column) for column in zip(*history))
    return (times, susceptible, infected)


def F(X,t):
    """Vector right-hand side of the SIR system: (S', I') = (-r*S*I, r*S*I - a*I).

    X is any two-element state (S, I): a (2, 1) column vector, a flat
    array of shape (2,) such as the one scipy's odeint passes, or a
    sequence [S, I].  The derivative is returned as a float array with the
    same shape as X.  t is unused (autonomous system).  Reads the
    module-level coefficients r and a.
    """
    # np.ravel accepts every supported layout; the original 2-D indexing
    # X[0,0] failed on lists and on odeint's flat state.
    S, I = np.ravel(X)
    dS = -r*S*I
    dI = r*S*I - a*I
    return np.reshape(np.array([dS, dI], dtype=float), np.shape(X))

def EulerVect(F,X0,t0,dt,n):
    """Integrate X' = F(X, t) for a two-component state with n forward Euler steps.

    The state is kept as a float array in the shape of X0 (e.g. the (2, 1)
    column vector built at module level, or a flat [S, I]), and F is called
    with it in that same shape.  F must return two derivative components;
    they are reshaped to the state's shape.  Each step applies
    X_{k+1} = X_k + dt * F(X_k, t_k).

    Returns:
        (Lt, LS, LI): lists of length n + 1 with the times and the first (S)
        and second (I) state components as floats.
    """
    # Copy into a float array; the original unpacked X0 and handed F a plain
    # list [S, I], which crashed any F using 2-D indexing.
    X = np.array(X0, dtype=float)
    t = t0
    Lt, LS, LI = [t], [float(X.flat[0])], [float(X.flat[1])]
    for _ in range(n):
        dX = np.asarray(F(X, t), dtype=float).reshape(X.shape)
        X = X + dt*dX
        t = t + dt
        Lt.append(t)
        LS.append(float(X.flat[0]))
        LI.append(float(X.flat[1]))
    return Lt,LS,LI

# Scalar Euler integration of the SIR model: 30 days with one-minute steps
# (assuming the model's time unit is the day).
X,Y,Z = Euler2(f2,S0,I0,0,1/24/60,24*30*60)
plt.figure()
plt.plot(X,Y,label = 'Susceptibles')
plt.plot(X,Z,label = 'Infecté.e.s')
plt.legend()

# Same integration with the vector form F and a (2, 1) column state.
X0 = np.array([[S0],[I0]])
T,LS,LI = EulerVect(F,X0,0,1/24/60,24*30*60)
plt.figure()
plt.plot(T,LS,label = 'Susceptibles')
plt.plot(T,LI,label = 'Infectés')
plt.legend()
plt.show()

# Reference solution from odeint on the same time grid.  odeint requires a
# one-dimensional initial condition and a right-hand side returning a flat
# array, while F is written for the (2, 1) column X0: adapt in both directions.
LX = odeint(lambda y, t: np.ravel(F(np.reshape(y, (2, 1)), t)), np.ravel(X0), T)
plt.figure()
# LX has one column per state component: column 0 is S, column 1 is I.
plt.plot(T, LX[:, 0], label='Susceptibles')
plt.plot(T, LX[:, 1], label='Infectés')
plt.legend()
plt.show()























