import pylab as plt
import numpy as np
import matplotlib as mpl
from mpl_toolkits.mplot3d import Axes3D
# Render the $...$ labels with matplotlib's built-in mathtext, so no external
# LaTeX installation is required.
mpl.rcParams.update({'text.usetex': False})

# Plots the usual ('classical') atomic orbitals in 3D.

## Main computation
def rayon(theta, phi, type):
    """Return |Y|^2, the squared real spherical harmonic of orbital *type*.

    The value is used as the radius of the plotted surface in the direction
    (theta, phi): theta is the polar angle measured from +z and phi the
    azimuth measured from +x in the xy-plane.

    Parameters
    ----------
    theta, phi : float or numpy.ndarray
        Polar and azimuthal angles in radians (broadcast together by numpy).
    type : str
        One of 's', 'px', 'py', 'pz', 'dz2', 'dxz', 'dyz', 'dx2y2', 'dxy'.

    Returns
    -------
    float or numpy.ndarray
        The squared, normalized angular function; its integral over the
        unit sphere equals 1.  For 's' the result is a scalar constant.

    Raises
    ------
    ValueError
        If *type* is not one of the supported orbital names.
    """
    # Normalization constants of the *real* spherical harmonics.  The real
    # combinations for m != 0 (e.g. p_x = (Y_1^-1 - Y_1^1) / sqrt(2)) are
    # sqrt(2) larger than |Y_l^m|.  Hence all three p orbitals share
    # sqrt(3/(4 pi)), and the off-axis d orbitals use sqrt(15/(4 pi)) or
    # sqrt(15/(16 pi)).
    c_s = 1 / (2 * np.sqrt(np.pi))             # sqrt(1/(4 pi))
    c_p = np.sqrt(3) / (2 * np.sqrt(np.pi))    # sqrt(3/(4 pi))
    c_dz2 = np.sqrt(5) / (4 * np.sqrt(np.pi))  # sqrt(5/(16 pi))
    c_d1 = np.sqrt(15) / (2 * np.sqrt(np.pi))  # sqrt(15/(4 pi)): dxz, dyz
    c_d2 = np.sqrt(15) / (4 * np.sqrt(np.pi))  # sqrt(15/(16 pi)): dx2y2, dxy

    # Lambdas so that only the requested harmonic is actually evaluated.
    angular = {
        's': lambda: c_s,
        'pz': lambda: c_p * np.cos(theta),
        'px': lambda: c_p * np.sin(theta) * np.cos(phi),
        'py': lambda: c_p * np.sin(theta) * np.sin(phi),
        'dz2': lambda: c_dz2 * (3 * np.cos(theta) ** 2 - 1),
        'dxz': lambda: c_d1 * np.sin(theta) * np.cos(theta) * np.cos(phi),
        'dyz': lambda: c_d1 * np.sin(theta) * np.cos(theta) * np.sin(phi),
        'dx2y2': lambda: c_d2 * np.sin(theta) ** 2 * np.cos(2 * phi),
        'dxy': lambda: c_d2 * np.sin(theta) ** 2 * np.sin(2 * phi),
    }
    try:
        harmonic = angular[type]
    except KeyError:
        raise ValueError(
            f"unknown orbital type {type!r}; expected one of {sorted(angular)}"
        ) from None
    return harmonic() ** 2

def orbitale(type, *, show=True):
    """Draw the 3-D surface r = rayon(theta, phi, type) for one orbital.

    Each surface point lies at distance ``rayon(theta, phi, type)`` from the
    origin in the direction (theta, phi), so the plot shows the angular
    shape of the orbital.

    Parameters
    ----------
    type : str
        Orbital name: 's', 'px', 'py', 'pz', 'dz2', 'dxz', 'dyz', 'dx2y2'
        or 'dxy'.
    show : bool, optional
        Call ``plt.show()`` once the figure is built (default True).  Pass
        False to keep the figure for saving or further editing.

    Returns
    -------
    matplotlib.figure.Figure
        The figure that was drawn.

    Raises
    ------
    ValueError
        If *type* is not a supported orbital name.
    """
    titles = {
        's': r'$s$',
        'pz': r'$p_z$',
        'px': r'$p_x$',
        'py': r'$p_y$',
        'dz2': r'$d_{z^2}$',
        'dxz': r'$d_{xz}$',
        'dyz': r'$d_{yz}$',
        'dx2y2': r'$d_{x^2-y^2}$',
        'dxy': r'$d_{xy}$',
    }
    # Validate before opening a figure, so a typo fails with a clear message.
    if type not in titles:
        raise ValueError(
            f"unknown orbital type {type!r}; expected one of {sorted(titles)}"
        )

    fig = plt.figure()
    # Figure.gca() stopped accepting keyword arguments in Matplotlib 3.6.
    # add_subplot(..., projection='3d') works on old and new versions.
    ax = fig.add_subplot(111, projection='3d')

    theta, phi = np.mgrid[0:np.pi:100j, 0:2*np.pi:100j]
    r = rayon(theta, phi, type)  # evaluate once, reuse for X, Y and Z
    X = r * np.sin(theta) * np.cos(phi)
    Y = r * np.sin(theta) * np.sin(phi)
    Z = r * np.cos(theta)

    ax.plot_surface(X, Y, Z, alpha=0.5)

    # Same symmetric limits on every axis, large enough for the whole surface.
    a = max(np.max(np.abs(X)), np.max(np.abs(Y)), np.max(np.abs(Z)))
    ax.set_xlim(-a, a)
    ax.set_ylim(-a, a)
    ax.set_zlim(-a, a)
    # A cubic box keeps one unit the same length on all three axes, so the
    # orbital is not squashed.  set_box_aspect only exists in Matplotlib >= 3.3.
    if hasattr(ax, 'set_box_aspect'):
        ax.set_box_aspect((1, 1, 1))

    ax.set_xlabel(r'$x$', size=20)
    ax.set_ylabel(r'$y$', size=20)
    ax.set_zlabel(r'$z$', size=20)
    ax.set_title('Orbitale ' + titles[type], size=20)

    if show:
        plt.show()
    return fig

## Display
# Only open the plot window when the file is run as a script, not on import.
if __name__ == '__main__':
    orbitale('dz2')