#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import matplotlib.pyplot as plt
import numpy as np
import widgets
import scipy.constants as constants
from matplotlib import rc
from numpy.random import random
from math import floor

# Simulation size: Tmax recorded positions (time indices 0 .. Tmax-1) for
# each of the N simulated walkers.
Tmax = 501
N = 100

# Interactive controls handed to widgets.make_param_widgets further down:
# a single slider selecting which time index is displayed.
parameters = dict(
    Temps=widgets.FloatSlider(value=0, description='Temps', min=0, max=Tmax - 1),
)

# Discrete time axis 0, 1, ..., Tmax-1.
T = np.arange(Tmax)

def diffusion(X0, Y0, Tmax=Tmax, step=1.0):
    """Simulate a single 2-D random walk with fixed-length steps.

    At every time step the walker moves by ``step`` in a direction whose
    angle is drawn uniformly from [0, 2*pi) with ``numpy.random.random``.

    Parameters
    ----------
    X0, Y0 : float
        Starting position of the walker.
    Tmax : int
        Number of recorded positions, including the starting one, so
        ``Tmax - 1`` steps are taken. Defaults to the module-level ``Tmax``.
    step : float
        Length of every step; 1.0 gives the unit-step walk.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(Tmax, 2)``; row ``i`` is the ``(x, y)``
        position after ``i`` steps (row 0 is ``(X0, Y0)``).

    Raises
    ------
    ValueError
        If ``Tmax`` is smaller than 1 (no room for the starting position).
    """
    if Tmax < 1:
        raise ValueError("Tmax must be at least 1, got {}".format(Tmax))
    # Draw every step direction in one call instead of one per loop iteration.
    angles = random(Tmax - 1) * 2 * np.pi
    steps = step * np.column_stack((np.cos(angles), np.sin(angles)))
    # Position i = start point + sum of the first i steps.
    return np.cumsum(np.vstack(([X0, Y0], steps)), axis=0)
    
# Run N independent walks of Tmax positions each; every walker starts at a
# point drawn uniformly from [0, 1) x [0, 1).
# Collection has shape (N, Tmax, 2): walker index, time step, (x, y).
Collection = np.array([diffusion(random(), random(), Tmax) for _ in range(N)])

# Root-mean-square distance to the origin across walkers, one value per time step.
Meandistance = np.sqrt(np.mean((Collection ** 2).sum(axis=2), axis=0))
# Mean (x, y) position across walkers at each time step, shape (Tmax, 2).
Meanposition = Collection.mean(axis=0)
def plot_data(Temps):
    """Redraw every animated artist for the time frame chosen on the slider.

    Parameters
    ----------
    Temps : float
        Slider value; its floor is used as the time index into the simulated
        trajectories (``Collection``, ``Meanposition``, ``Meandistance``).
    """
    timestamp = floor(Temps)
    # Histories end at timestamp + 1 so they include the current frame and
    # reach the "Middle" marker; slicing with timestamp alone stops one step
    # short and leaves them empty at Temps == 0.
    upto = timestamp + 1

    lines["Coll"].set_data(Collection[:, timestamp, 0], Collection[:, timestamp, 1])
    # Line2D.set_data expects sequences: passing two bare scalars is
    # deprecated since Matplotlib 3.7 and rejected by later releases.
    mean_x, mean_y = Meanposition[timestamp]
    lines["Middle"].set_data([mean_x], [mean_y])
    lines["Eloignement moyen"].set_data(T[:upto], Meandistance[:upto])
    lines["Middlefollow"].set_data(Meanposition[:upto, 0], Meanposition[:upto, 1])
    fig.canvas.draw_idle()
    
# Figure with two panels: walker positions on the left (ax) and the RMS
# distance to the origin over time on the right (ax2).
fig = plt.figure(figsize=(12,6))
fig.suptitle("Marche aléatoire pour {} entités".format(N))

ax = fig.add_axes([0.05, 0.3, 0.6, 0.6])
ax2 = fig.add_axes([0.7, 0.3, 0.23, 0.6])
# Reference curve sqrt(t + 1) to compare against the simulated RMS distance:
# `diffusion` takes independent unit-length steps, so the mean squared
# displacement grows linearly with t.
ax2.plot(np.sqrt(T+1), label = "f(Temps) = Temps$^{1/2}$")
ax.axis("equal")

# Artists that plot_data updates; they start empty.
lines = {}
lines['Coll'], = ax.plot([], [],"o", color='red', label = "Positions des réalisations")
lines["Middle"], = ax.plot([], [],"o", color='blue', label = "Position moyenne des réalisations")
lines["Middlefollow"], = ax.plot([], [], color='blue', label = "Historique de la position moyenne")
lines["Eloignement moyen"], = ax2.plot([], [], color='red', label = "Distance moyenne au centre")

# Symmetric limits around the origin that contain every position of every
# walker at every time step.
xtop = np.max(np.abs(Collection[:, :, 0]))
ytop = np.max(np.abs(Collection[:, :, 1]))

ax.set_xlim(-xtop, xtop)
ax.set_ylim(-ytop, ytop)

# The y range must fit both the simulated curve and the sqrt reference curve
# (whose maximum is sqrt(Tmax)); limiting it to max(Meandistance) alone clips
# the reference whenever the sampled RMS distance ends below sqrt(Tmax).
# The 5% margin keeps the curves off the top edge.
ax2.set_xlim(0,Tmax)
ax2.set_ylim(0, 1.05 * max(np.max(Meandistance), np.sqrt(Tmax)))

ax.set_xlabel('X')
ax.set_ylabel('Y')

ax2.set_xlabel('Temps (s)')
ax2.set_ylabel('Distance (u)')
ax2.legend(loc = "lower right")
ax.legend(loc = "lower right")
# plot_data is handed to the project widgets helper together with the slider
# definitions (presumably registered as the slider callback — see widgets).
param_widgets = widgets.make_param_widgets(parameters, plot_data, slider_box=[0.15, 0.07, 0.4, 0.15])
#choose_widget = widgets.make_choose_plot(lines, box=[0.015, 0.25, 0.2, 0.15])
reset_button = widgets.make_reset_button(param_widgets)

if __name__ == '__main__':
    plt.show()