import matplotlib.pyplot as plt
import random
import math
import timeit

def sigmoid(x):
    return 1/(1+math.exp(-x))

def relu(x):
    return max(0, x)

# tailles :
# nb_entrees, nb_neurones_couche1, nb_neurones_couche2, ..., sorties

class Reseau:
    def __init__(self, tailles_couches, fonctions=None):
        self.tailles_couches = tailles_couches
        if fonctions is None:
            self.fonctions = [None] * self.len(tailles_couches)
        else:
            self.fonctions = fonctions
        self.nb_couches_internes = len(self.tailles_couches)-2
        self.poids = dict()
        for c in range(len(self.tailles_couches)-1):
            for i in range(tailles_couches[c]):
                for j in range(self.tailles_couches[c+1]):
                    self.poids[(c, i, j)] = random.random()
        self.biais = dict()
        for c in range(1, len(self.tailles_couches)):
            for i in range(self.tailles_couches[c]):
                self.biais[(c, i)] = random.random()
        
    def calcule(self, entree, delta_poids=None, delta_biais=None):
        couche_actuelle = entree
        for c in range(len(self.tailles_couches)-1):
            #print(couche_actuelle)
            prochaine_couche = []
            for j in range(self.tailles_couches[c+1]):
                val = self.biais[(c+1, j)]
                if delta_biais is not None and (c+1, j) in delta_biais:
                    val += delta_biais[(c+1, j)]
                for i in range(self.tailles_couches[c]):
                    poids = self.poids[(c, i, j)]
                    if delta_poids is not None and (c, i, j) in delta_poids:
                        poids += delta_poids[(c, i, j)]
                    val += poids * couche_actuelle[i]
                if self.fonctions[c+1] is not None:
                    val = self.fonctions[c+1](val) # on applique la fonction d'activation
                prochaine_couche.append(val)
            couche_actuelle = prochaine_couche
        return prochaine_couche
    
    def calcule_liste(self, entrees):
        return [self.calcule(entree) for entree in entrees]
    
    def cout(self, entrees, attendus, delta_poids=None, delta_biais=None):
        resultats = 0
        n = len(entrees)
        n_sorties = len(attendus[0])
        for i in range(n):
            res = 0
            obtenus = self.calcule(entrees[i], delta_poids, delta_biais) 
            resultats += sum((obtenus[j] - attendus[i][j])**2 for j in range(n_sorties)) / n_sorties
        return resultats / n
    
    def derivee_poids(self, entrees, attendus, couche, i, j, cout, h=0.000001):
        cout2 = self.cout(entrees, attendus, {(couche, i, j): h}, None)
        return (cout2 - cout) / h
    
    def derivee_biais(self, entrees, attendus, couche, i, cout, h=0.000001):
        cout2 = self.cout(entrees, attendus, None, {(couche, i): h})
        return (cout2 - cout) / h
    
    def calcule_derivee(self, entrees, attendus, cout):
        variation_poids = dict()
        for (c, i, j) in self.poids:
            variation_poids[(c, i, j)] = self.derivee_poids(entrees, attendus, c, i, j, cout)
        variation_biais = dict()
        for (c, i) in self.biais:
            variation_biais[(c, i)] = self.derivee_biais(entrees, attendus, c, i, cout)
        return variation_poids, variation_biais
    
    def apprentissage(self, entrees, attendus, pas=2, seuil=0.00001, nmax=10000, objectif=0.0001):
        n = 0 # compteur d'étapes
        cout = self.cout(entrees, attendus)
        while pas > seuil and n <= nmax and cout > objectif:
            var_poids, var_biais = self.calcule_derivee(entrees, attendus, cout)
            cout2 = cout + 1
            pas = 2 * pas
            while cout2 > cout:
                pas = pas / 2
                var_poids2 = {p: -var_poids[p] * pas for p in var_poids}
                var_biais2 = {b: -var_biais[b] * pas for b in var_biais}
                cout2 = self.cout(entrees, attendus, var_poids2, var_biais2)
                #if cout2 > cout:
                #    print(f"on divise au bout de {n} etapes, un pas de {pas} et un cout de {cout}")
            cout = cout2
            pas *= 1.05
            for p in var_poids:
                self.poids[p] += var_poids2[p]
            for b in var_biais:
                self.biais[b] += var_biais2[b]
            n += 1
            if n % 1000 == 0:
                print(f"Étape {n}, cout {cout}, pas {pas}")

# v entre 0 et 1
def couleur_bool(v):
    v = min(1, max(0, v))
    return (1-v, v, .5*v)

def afficher_fonction_booleenne(reseau, nb_points, entrees=None):
    if entrees is None:
        xi = [random.random() for _ in range(nb_points)]
        yi = [random.random() for _ in range(nb_points)]
        entrees = zip(xi, yi)
    else:
        xi = [entree[0] for entree in entrees]
        yi = [entree[1] for entree in entrees]
        nb_points = len(xi)
    valeurs = reseau.calcule_liste(entrees)
    couleurs = [couleur_bool(v[0]) for v in valeurs]
    symbs = ['.' if v[0] > .5 else 'x' for v in valeurs]
    fig, ax = plt.subplots()
    xi0 = [xi[i] for i in range(nb_points) if valeurs[i][0] <= .5]
    xi1 = [xi[i] for i in range(nb_points) if valeurs[i][0] > .5]
    yi0 = [yi[i] for i in range(nb_points) if valeurs[i][0] <= .5]
    yi1 = [yi[i] for i in range(nb_points) if valeurs[i][0] > .5]
    couls0 = [couleur_bool(valeurs[i][0]) for i in range(nb_points) if valeurs[i][0] <= .5]
    couls1 = [couleur_bool(valeurs[i][0]) for i in range(nb_points) if valeurs[i][0] > .5]
    ax.scatter(xi0, yi0, color=couls0, marker='x') # Nuage de points
    ax.scatter(xi1, yi1, color=couls1, marker='o')
    #ax.set_title(f"coût : {C(a, b, xi,yi)}")
    plt.show()

def afficher_vraie_fonction_booleenne(f, nb_points):
    xi = [random.random() for _ in range(nb_points)]
    yi = [random.random() for _ in range(nb_points)]
    valeurs = [f(xi[i], yi[i]) for i in range(nb_points)]
    couleurs = [couleur_bool(v) for v in valeurs]
    symbs = ['.' if v > .5 else 'x' for v in valeurs]
    fig, ax = plt.subplots()
    xi0 = [xi[i] for i in range(nb_points) if valeurs[i] <= .5]
    xi1 = [xi[i] for i in range(nb_points) if valeurs[i] > .5]
    yi0 = [yi[i] for i in range(nb_points) if valeurs[i] <= .5]
    yi1 = [yi[i] for i in range(nb_points) if valeurs[i] > .5]
    couls0 = [couleur_bool(valeurs[i]) for i in range(nb_points) if valeurs[i] <= .5]
    couls1 = [couleur_bool(valeurs[i]) for i in range(nb_points) if valeurs[i] > .5]
    ax.scatter(xi0, yi0, color=couls0, marker='x') # Nuage de points
    ax.scatter(xi1, yi1, color=couls1, marker='o')
    #ax.set_title(f"coût : {C(a, b, xi,yi)}")
    plt.show()

"""
reseau_abs = Reseau([1, 2, 1], [None, relu, None])

entrees = [[i] for i in range(-10, 10)]
sorties = [[abs(*entrees[i])] for i in range(len(entrees))]

reseau_abs.apprentissage(entrees, sorties)
"""

def vrai(a):
    return a > 0.5

def ou(a, b):
    return int(vrai(a) or vrai(b))

def et(a, b):
    return int(vrai(a) and vrai(b))

def xor(a, b):
    return int(vrai(a) ^ vrai(b))

def cercle(a, b):
    return int(((a-.5)**2+(b-.5)**2)<.25)


# Vous pouvez changer le nombre de neurones et de couches. Il faut 2 entrées et 1 sortie.
reseau = Reseau([2, 4, 1], [None, sigmoid, sigmoid])

# Vous pouvez changer la fonction à simuler
fonction_cible = xor

# Pour voir ce qui est attendu
#afficher_vraie_fonction_booleenne(fonction_cible, 10000)

cout = 1
vraies_entrees = [[0, 0], [0, 1], [1, 0], [1, 1]]
vraies_sorties = [[fonction_cible(*vraies_entrees[i])] for i in range(len(vraies_entrees))]
n = 7 # nombre de points par ligne et par colonne pour l'apprentissage
entrees = [[i/n, j/n] for i in range(n+1) for j in range(n+1)]
sorties = [[fonction_cible(*entrees[i])] for i in range(len(entrees))]
while cout > 0.05:
    reseau.apprentissage(entrees, sorties, pas=5, objectif=0.01)
    cout = reseau.cout(entrees, sorties)
    print(f"fin de l'apprentissage avec un cout de {cout}")
    cout = 0

n = 20 # nombre de points par ligne et par colonne pour l'affichage
entrees = [[i/n, j/n] for i in range(n+1) for j in range(n+1)]
afficher_fonction_booleenne(reseau, 100, entrees)