import numpy as np
import scipy.optimize as spo
import matplotlib.pyplot as plt
from math import *
from matplotlib.widgets import Button
from time import time
from matplotlib.animation import FuncAnimation

# Critical temperature of the 2-D Ising model, expressed in the temperature
# units used by Ising.evol.
# Onsager (J = k_B = 1): T_c = 2 / ln(1 + sqrt(2)) = 2 / asinh(1) ~ 2.269.
# Ising.evol accepts an uphill flip with probability exp(H / T), while the energy
# change of that flip is -2H. The physical Metropolis factor would therefore be
# exp(2H / T_phys), so T there is T_phys / 2, giving T_c = 1 / asinh(1) ~ 1.135.
Tc = 1 / asinh(1)

class Ising :
    """Interactive 2-D Ising model on a periodic Ny x Nx lattice, evolved with Metropolis.

    The window shows:
      - the spin lattice (left);
      - the magnetisation per spin versus the number of attempted single-spin
        flips (top right, with a red zero line);
      - buttons to change T and B, run the dynamics, reset the lattice and
        toggle animation mode.

    For the two-level buttons, a click on the right half of the button applies
    the large step and a click on the left half applies the small one.

    Main attributes:
        etat      -- (Ny, Nx) array of spins, each +1 or -1
        T, B      -- temperature and external field used by the Metropolis step
        Mhist     -- magnetisation per spin, one value per recorded point
        temps     -- number of attempted flips at each recorded point
        hist      -- stack of recorded lattice configurations (replayed by the animation)
        bool_anim -- when True, evol records every step and replays it as an animation
    """

    def __init__(self, Nx = 100, Ny = 100, T = 1 * Tc, B = 0) :
        """Create a random lattice, build the figure and its widgets, then show it.

        Nx, Ny -- lattice width and height
        T      -- initial temperature (same units as Tc)
        B      -- initial external field
        """
        self.etat = 2 * np.random.randint(0, 2, (Ny, Nx)) - 1
        self.T = T
        self.B = B
        self.Mhist = [self.etat.mean()]
        self.temps = [0]
        self.hist = np.array([self.etat])
        self.bool_anim = False
        self.ani = None    # FuncAnimation of the last animated run, if any
        self.a = 0         # negative index of the next frame to replay (see evol)

        self.fig = plt.figure()

        # Lattice view.
        self.ax = self.fig.add_subplot(121)
        self.ax.get_xaxis().set_visible(False)
        self.ax.get_yaxis().set_visible(False)
        self.im = self.ax.imshow(self.etat)

        # Magnetisation curve and zero line.
        self.courbes = self.fig.add_subplot(222)
        self.line, = self.courbes.plot(self.temps, self.Mhist)
        self.line2, = self.courbes.plot([0, self.temps[-1]], [0, 0], color = 'red')

        # Status labels are created once and updated with set_text afterwards:
        # recent matplotlib versions do not allow assigning to Axes.texts.
        self.texte_B = self.ax.text(-.25, 0.08, '', transform = self.ax.transAxes, fontsize = 24)
        self.texte_T = self.ax.text(1.2, 0.08, '', transform = self.ax.transAxes, fontsize = 24)
        self.texte_anim = self.ax.text(2.2, 0.08, '', transform = self.ax.transAxes, fontsize = 24)
        self._maj_textes()

        # matplotlib widgets must stay referenced to remain responsive, so keep them on self.
        self.boutons = []
        for rect, libelle, action in (
                ([0.7, 0.1, 0.15, 0.075], '+ temps       +++ temps', self.evol),
                ([0.55, 0.25, 0.1, 0.075], '+ T       +++ T', self.Tplus),
                ([0.55, 0.1, 0.1, 0.075], '-T       --- T', self.Tmoins),
                ([0.745, 0.25, 0.05, 0.075], 'RESET', self.reset),
                ([0.875, 0.25, 0.07, 0.075], 'Animer :', self.animer),
                ([0.015, 0.25, 0.1, 0.075], '+ B       +++ B', self.Bplus),
                ([0.015, 0.1, 0.1, 0.075], '-B       --- B', self.Bmoins),
            ) :
            bouton = Button(self.fig.add_axes(rect), libelle)
            bouton.on_clicked(action)
            self.boutons.append(bouton)

        # Maximising the window is purely cosmetic. window.state('zoomed') is a
        # Tk-backend call that may not be supported everywhere, so a failure here
        # must not prevent the window from being shown.
        try :
            plt.get_current_fig_manager().window.state('zoomed')
        except Exception :
            pass
        plt.show()

    # ------------------------------------------------------------------ helpers

    @staticmethod
    def _pas(event, petit, grand) :
        """Return `grand` for a click on the right half of a button, `petit` otherwise."""
        return grand if event.xdata > 0.5 else petit

    def _maj_textes(self) :
        """Refresh the B, T and animation-status labels and request a redraw."""
        self.texte_B.set_text(r'$B = $' + str(round(self.B, 1)))
        self.texte_T.set_text(r'$T = $' + str(round(self.T / Tc, 1)) + r'$ \times Tc$')
        self.texte_anim.set_text('On' if self.bool_anim else 'Off')
        self.fig.canvas.draw_idle()

    def _ajuster_courbes(self) :
        """Fit the magnetisation axes to the whole history (0 always visible) and stretch the zero line."""
        m_min, m_max = min(self.Mhist), max(self.Mhist)
        marge = 0.1 * (m_max - m_min)
        self.courbes.set_xlim(0, self.temps[-1])
        self.courbes.set_ylim(min(m_min - marge, - marge), max(m_max + marge, marge))
        self.line2.set_data([0, self.temps[-1]], [0, 0])

    def _arreter_animation(self) :
        """Stop the replay started by a previous evol call, if it is still running."""
        ani = getattr(self, 'ani', None)
        # A finished FuncAnimation may have already dropped its event source.
        source = getattr(ani, 'event_source', None)
        if source is not None :
            source.stop()

    # ---------------------------------------------------------- button handlers

    def Bplus(self, event) :
        """Increase the field B by 1 (right half of the button) or 0.1 (left half)."""
        self.B += self._pas(event, 0.1, 1)
        self._maj_textes()

    def Bmoins(self, event) :
        """Decrease the field B by 1 (right half of the button) or 0.1 (left half)."""
        self.B -= self._pas(event, 0.1, 1)
        self._maj_textes()

    def Tplus(self, event) :
        """Increase T by 1 * Tc (right half of the button) or 0.1 * Tc (left half)."""
        self.T += self._pas(event, 0.1, 1) * Tc
        self._maj_textes()

    def Tmoins(self, event) :
        """Decrease T by 1 * Tc (right half) or 0.1 * Tc (left half), never going below 0."""
        # A negative T would make exp(H / T) exceed 1 for uphill flips, so every flip
        # would be accepted: clamp to the zero-temperature limit instead.
        self.T = max(self.T - self._pas(event, 0.1, 1) * Tc, 0.0)
        self._maj_textes()

    def reset(self, event) :
        """Draw a new random lattice and clear the history and the curve.

        Animation mode is switched off; T and B are kept.
        """
        # A running replay would otherwise index into the history that is being replaced.
        self._arreter_animation()

        Ny, Nx = self.etat.shape
        self.etat = 2 * np.random.randint(0, 2, (Ny, Nx)) - 1
        self.Mhist = [self.etat.mean()]
        self.temps = [0]
        self.hist = np.array([self.etat])

        # Reuse the existing image instead of stacking a new one on every reset.
        self.im.set_data(self.etat)
        self.courbes.cla()
        self.line, = self.courbes.plot(self.temps, self.Mhist)
        self.line2, = self.courbes.plot([0, self.temps[-1]], [0, 0], color = 'red')

        self.bool_anim = False
        self._maj_textes()

    def evol(self, event, init = True) :
        """Run a batch of Metropolis single-spin-flip steps and update the display.

        A click on the right half of the button runs 5 * Nx * Ny attempted flips;
        a click on the left half runs int(0.5 * Nx * Ny).

        Without animation, only the final state is recorded and displayed. In
        animation mode, every step is appended to hist / Mhist / temps and then
        replayed by a FuncAnimation through iteration_animation.

        `init` is ignored. It is kept only so the signature stays unchanged; the
        number of steps always comes from the click position.
        """
        self._arreter_animation()

        Ny, Nx = self.etat.shape
        n_pas = 5 * Nx * Ny if event.xdata > 0.5 else int(0.5 * Nx * Ny)
        t_ini = self.temps[-1]
        etat = self.etat
        images = []    # per-step snapshots, only filled in animation mode

        # Metropolis algorithm with periodic boundary conditions.
        for t in range(n_pas) :

            # Pick a spin at random.
            i, j = np.random.randint(Ny), np.random.randint(Nx)

            # Negative indices already wrap around in NumPy; the modulo handles the upper edge.
            voisins = etat[(i + 1) % Ny, j] + etat[i - 1, j] + etat[i, (j + 1) % Nx] + etat[i, j - 1]
            H = - etat[i, j] * (voisins + self.B)

            # Flipping etat[i, j] changes E = -sum_<kl> s_k s_l - B sum_k s_k by dE = -2H.
            # Flips that do not raise the energy (H >= 0) are always accepted; others are
            # accepted with probability exp(H / T). The physical Metropolis factor is
            # exp(-dE / T_phys) = exp(2H / T_phys), so T here is T_phys / 2 (J = k_B = 1).
            # T <= 0 is treated as the zero-temperature limit (no uphill flips).
            if H >= 0 or (self.T > 0 and np.random.uniform() < exp(H / self.T)) :
                etat[i, j] *= -1

            if self.bool_anim :
                images.append(etat.copy())
                self.Mhist.append(etat.mean())
                self.temps.append(t_ini + t + 1)

        if self.bool_anim :
            # Concatenate once at the end rather than once per step (which was quadratic).
            if images :
                self.hist = np.concatenate((self.hist, np.array(images)))
        else :
            self.Mhist.append(etat.mean())
            self.temps.append(t_ini + n_pas)
            self.hist = np.concatenate((self.hist, [etat]))
            self.im.set_data(etat)
            self.line.set_data(self.temps, self.Mhist)

        self._ajuster_courbes()

        if self.bool_anim and n_pas > 0 :
            # Replay starts n_pas entries before the end of the history;
            # iteration_animation advances self.a by one per call.
            self.a = - n_pas
            # Without init_func, FuncAnimation also calls iteration_animation once for
            # its initial draw, hence n_pas - 1 regular frames (at least 1).
            self.ani = FuncAnimation(self.fig, self.iteration_animation, max(n_pas - 1, 1),
                                     interval = 1, repeat = False, blit = True)

        # Redraw now; this draw event is also what starts a newly created FuncAnimation.
        self.fig.canvas.draw_idle()

    def iteration_animation(self, k) :
        """FuncAnimation callback: show the next recorded step of the current run.

        self.a is a negative index into hist / temps / Mhist. evol sets it to
        -n_pas, and each call advances it by one. `k` (the frame number) is not used.
        Returns the modified artists, as required for blitting.
        """
        # Clamp so that any callback beyond the planned frames keeps showing the last
        # frame instead of wrapping around (index 0 and above) to older history.
        indice = min(self.a, -1)
        # Plot the curve up to and including the point of the displayed frame.
        fin = len(self.temps) + indice + 1
        self.line.set_data(self.temps[:fin], self.Mhist[:fin])
        self.im.set_data(self.hist[indice])
        self.a += 1
        return self.im, self.line

    def animer(self, event) :
        """Toggle animation mode and refresh the On/Off label."""
        self.bool_anim = not self.bool_anim
        self._maj_textes()

# Build the interactive window with a 100 x 100 lattice, the default T (1 * Tc) and B = 0.
# The constructor ends with plt.show(), which normally blocks until the window is closed.
Is = Ising(Nx = 100, Ny = 100, B = 0)