# Auteur : Clément de la Salle
# Agrégation de physique, ENS de Lyon, 2019-2020

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.image as im
import pylab as plb
from matplotlib.widgets import Slider, Button, RadioButtons
import matplotlib.patches as pc

def ColorToGrey(image) :
    """Convert an image returned by ``matplotlib.image.imread`` to grey levels.

    Colour images (``H x W x 3`` RGB or ``H x W x 4`` RGBA) are reduced to
    0.299 R + 0.587 G + 0.114 B (ITU-R BT.601 luma weights); an alpha channel,
    if present, is ignored. A 2-D image is already grey-level and is returned
    as a float array.

    Returns a 2-D float array with the same height and width as ``image``.
    """
    image = np.asarray(image)
    if image.ndim == 2 :
        # Already grey: image[:, :, :3] would raise IndexError on a 2-D array.
        return image.astype(float)
    # These weights sum to 1, so white stays white. The previous blue weight
    # 0.144 was a typo for 0.114 and made the weights sum to 1.03.
    return np.dot(image[:, :, :3], [0.299, 0.587, 0.114])


def FFT(image_grise) :
    """Return the centred 2-D discrete Fourier transform of ``image_grise``.

    ``fftshift`` moves the zero-frequency coefficient to the centre of the
    array (index ``(H // 2, W // 2)``), so the spectrum is displayed centred
    in the 'Plan de Fourier' axes. The result is a complex array with the
    same shape as the input.
    """
    # numpy.fft is called directly: the pylab namespace is discouraged by
    # matplotlib and only re-exports these same NumPy functions.
    return np.fft.fftshift(np.fft.fft2(image_grise))

def FFTinv(fft_masque) :
    """Inverse of ``FFT``: undo the centring shift, then apply the inverse 2-D DFT.

    Only the real part of the result is returned. For the unmasked spectrum
    of a real image, the imaginary part is round-off noise. After masking,
    any imaginary part is simply discarded.
    """
    # numpy.fft is called directly: the pylab namespace is discouraged by
    # matplotlib and only re-exports these same NumPy functions.
    return np.real(np.fft.ifft2(np.fft.ifftshift(fft_masque)))

def dessin(event) :
    """Mouse-click callback that builds the Fourier-plane mask from two clicks.

    The first click inside the 'Plan de Fourier' axes is stored in the global
    ``click``. The second click completes the shape selected with the shape
    radio button (global ``forme_filtre``):

    * 'Cercle'    -- centre at the first click, radius = distance to the second;
    * 'Rectangle' -- the two clicks are opposite corners.

    The global ``masque`` is then set to True where the spectrum is blocked:
    inside the shape when ``plein == 'Plein'``, and exactly outside it
    otherwise. The figure is redrawn through ``tracer()``. A click outside any
    axes, or in any other axes, cancels a pending first click.
    """
    global click, masque

    if event.inaxes is None or event.inaxes.get_title() != 'Plan de Fourier' :
        # Clicks anywhere but the Fourier plane reset the pending anchor.
        click = []

    elif not click :
        click = [event.xdata, event.ydata]

    else :
        x0, y0 = click
        x1, y1 = event.xdata, event.ydata

        if forme_filtre == 'Cercle' :
            rayon = np.hypot(x0 - x1, y0 - y1)
            d = np.hypot(quadrillage[:, :, 0] - x0, quadrillage[:, :, 1] - y0)
            interieur = d < rayon
        elif forme_filtre == 'Rectangle' :
            # (u, v): pixel coordinates relative to the rectangle's lower-left corner.
            u = quadrillage[:, :, 0] - min(x0, x1)
            v = quadrillage[:, :, 1] - min(y0, y1)
            interieur = (u > 0) & (u < abs(x0 - x1)) & (v > 0) & (v < abs(y0 - y1))
        else :
            # Unknown shape: keep the current mask unchanged.
            interieur = None

        if interieur is not None :
            # 'Creux' is the exact complement of 'Plein' for every shape.
            # Toggling the mode with the radio button inverts the mask, so
            # drawing directly in 'Creux' must give the same result.
            masque = interieur if plein == 'Plein' else ~interieur

        click = []
        tracer()

    plt.draw()
    

def change_forme_filtre(label) :
    """Radio-button callback: record the shape used for the next drawn mask.

    ``label`` is the text of the clicked button ('Cercle' or 'Rectangle'). It
    is stored in the module-level ``forme_filtre``, which ``dessin`` reads.
    Nothing is redrawn here; the shape only applies to the next two clicks.
    """
    global forme_filtre
    forme_filtre = label

def change_plein_creux(label) :
    """Radio-button callback switching between 'Plein' and 'Creux' filters.

    'Plein' blocks the inside of the drawn shape and 'Creux' blocks its
    outside. Switching mode inverts the global ``masque`` (True = blocked),
    stores the new mode in the global ``plein``, and redraws via ``tracer()``.

    matplotlib's ``RadioButtons`` notifies its observers on every click,
    including a click on the button that is already selected. Such a click
    must not invert the mask, or the filter would no longer match the mode
    shown by the button.
    """
    global plein, masque

    if label == plein :
        return

    plein = label
    masque ^= True
    tracer()
    

def tracer() :
    """Refresh the Fourier-plane and screen images from the current mask.

    Reads the module-level ``TF_image`` (centred spectrum of the object) and
    ``masque`` (True where the spectrum is blocked). Blocked coefficients are
    displayed with magnitude 1, i.e. 0 on the log10 scale of the Fourier
    plane. The screen shows the inverse transform of the filtered spectrum.
    """
    # Filtered spectrum: zero wherever the mask blocks the light.
    spectre = np.where(masque, 0, TF_image)

    # Display version: every zero coefficient of the filtered spectrum reads 1,
    # which keeps np.log10 finite there.
    affichage = np.where(spectre == 0, 1, TF_image)
    plan_Fourrier.set_data(np.log10(np.abs(affichage)))

    ecran.set_data(FFTinv(spectre))
    plt.draw()


# --- Main script ------------------------------------------------------------
# Object: grey-level version of the picture, and its centred spectrum.
image = ColorToGrey(im.imread('grille.png'))
TF_image = FFT(image)

# All three images share this extent, so mouse coordinates in the Fourier
# plane can be compared directly with the pixel grid below.
etendue = (-100, 100, -100, 100)

# Pixel coordinate grid, consistent with imshow's default 'upper' origin:
# rows go from +100 (top) to -100 (bottom), columns from -100 to +100.
x, y = np.mgrid[100:-100:complex(0, image.shape[0]), -100:100:complex(0, image.shape[1])]
# quadrillage[..., 0] is the horizontal coordinate (event.xdata) and
# quadrillage[..., 1] the vertical one (event.ydata) of each pixel.
quadrillage = np.stack((y, x), axis = -1)

TF_filtre = TF_image.copy()
# Mask convention: True where the spectrum is blocked. The np.bool alias was
# removed in NumPy 1.24; the builtin bool is the correct dtype.
masque = np.ones(TF_image.shape, dtype = bool)

fig = plt.figure()
ax1 = fig.add_subplot(131, aspect = 'equal', title = 'Objet')
ax2 = fig.add_subplot(132, aspect = 'equal', title = 'Plan de Fourier')
ax3 = fig.add_subplot(133, aspect = 'equal', title = 'Image sur l\'écran')
objet = ax1.imshow(image, cmap = 'gray', extent = etendue)
plan_Fourrier = ax2.imshow(np.log10(np.abs(TF_image)), cmap = 'gray', extent = etendue)
# Start with the unfiltered screen image; tracer() updates it with set_data().
ecran = ax3.imshow(FFTinv(TF_image), cmap = 'gray', extent = etendue)

# Shape selector ('Cercle' / 'Rectangle').
axe_radio = fig.add_axes([.38, .1, 0.1, .1])
# NOTE(review): the purpose of this xlim is not clear from here (presumably a
# workaround for how older matplotlib drew the radio circles) - confirm before removing.
axe_radio.set_xlim(-1000, -999)
bouton_radio = RadioButtons(axe_radio, ['Cercle', 'Rectangle'], active = 0)
forme_filtre = 'Cercle'
bouton_radio.on_clicked(change_forme_filtre)

# Filter mode selector ('Plein' blocks inside the shape, 'Creux' outside).
axe_radio2 = fig.add_axes([.52, .1, 0.1, .1])
axe_radio2.set_xlim(-1000, -999)
bouton_radio2 = RadioButtons(axe_radio2, ['Plein', 'Creux'], active = 0)
plein = 'Plein'
bouton_radio2.on_clicked(change_plein_creux)

# Pending first click of a shape being drawn (empty when none).
click = []
fig.canvas.mpl_connect('button_press_event', dessin)

mng = plt.get_current_fig_manager()
# Uncomment the next line to maximise the window with the 'TkAgg' backend.
#mng.window.state('zoomed')
plt.show()