import scipy as sc
import scipy.sparse as sp
import numpy as np
import pandas as pd
import os
import re
import matplotlib.pyplot as plt

def timereport(path, file):
    """Plot the per-generation compute-time report ``path + "report/" + file``.

    The CSV is read with ``pandas.read_csv`` defaults, so its first line is
    consumed as the header row; every remaining column is drawn as one line
    against the row index. ``path`` is concatenated directly and therefore
    should end with a path separator.
    """
    report_file = path + "report/" + file
    timings = pd.read_csv(report_file).to_numpy()

    plt.figure()
    plt.plot(timings)
    plt.xlabel("Generation")
    plt.ylabel("Compute time")
    plt.show()

def basic_plotter(num, c, mutlim, path, select=0):
    """Plot mutation histograms and the population-size curve of one run.

    Snapshots are loaded from ``path + "_" + str(i) + ".npz"`` for
    ``i = 0 .. num-1``; ``path`` is a file-name prefix such as
    ``".../Test1/test"``, not a directory.

    Parameters
    ----------
    num : int
        Number of consecutive snapshots to load.
    c : int
        A histogram is drawn for every ``c``-th snapshot (``i % c == 0``).
        ``0`` disables the histograms instead of raising ZeroDivisionError.
    mutlim : tuple
        x-axis limits passed to ``plt.xlim`` for every histogram.
    path : str
        Snapshot file-name prefix.
    select : int or bool, optional
        When truthy, histogram the sum of columns 1 and 2 instead of
        column 1 alone.

    After the loop, a final figure plots the row count of each snapshot
    (labelled "Population size") against the snapshot index.
    """
    generations = list(range(num))
    sizes = []
    for i in generations:
        snapshot = sp.load_npz(path + "_" + str(i) + ".npz")
        # Use the public ``shape`` attribute, not the private ``_shape``.
        sizes.append(snapshot.shape[0])

        if not c or i % c:
            continue

        # load_npz returns whatever format was saved; CSC always supports
        # column slicing (COO, for instance, does not in older SciPy).
        cols = snapshot.tocsc()
        counts = cols[:, 1].toarray()
        if select:
            counts = counts + cols[:, 2].toarray()

        plt.figure()
        plt.hist(counts)
        plt.xlim(mutlim)
        plt.xlabel("Mutation number")
        plt.ylabel("Cells number")
        plt.show()

    plt.figure()
    plt.plot(generations, sizes)
    plt.xlabel("Generation")
    plt.ylabel("Population size")
    plt.show()
  
def npztotxt(path_in, fname):
    """Convert one sparse ``.npz`` snapshot into a labelled CSV file.

    The matrix at ``path_in + fname`` is densified, its columns are labelled
    and it is written without the index to
    ``path_in + "/CSV/" + <fname without ".npz"> + ".csv"``.

    Parameters
    ----------
    path_in : str
        Directory prefix, concatenated directly with ``fname``, so it should
        end with a path separator.
    fname : str
        Snapshot file name, normally ending in ``.npz``.

    Raises
    ------
    ValueError
        From pandas when the matrix has neither 3 columns (labelled fitness /
        positive / negative mutation number) nor 2 columns (fitness /
        mutation number).
    """
    a = sp.load_npz(path_in + fname)
    df = pd.DataFrame(a.toarray())
    if a.shape[1] == 3:
        df.columns = ["fitness", "positive mutation number", "negative mutation number"]
    else:
        df.columns = ["fitness", "mutation number"]

    # str.rstrip(".npz") would strip ANY trailing '.', 'n', 'p', 'z'
    # characters ("run.npz" -> "ru"); remove only the exact suffix.
    suffix = ".npz"
    stem = fname[:-len(suffix)] if fname.endswith(suffix) else fname

    out_dir = path_in + "/CSV/"
    os.makedirs(out_dir, exist_ok=True)
    df.to_csv(out_dir + stem + ".csv", index=False)
 
def _snapshot_row(fname, gen_step):
    """Return the table row for snapshot ``fname``: int(generation / gen_step)."""
    # The generation is the number right before ".npz", matching names built
    # as ``<prefix>_<generation>.npz`` (e.g. "pos_n_normal_25.0.npz").
    match = re.search(r'(\d+(?:\.\d+)?)\.npz$', fname)
    if match is None:
        raise ValueError("cannot parse a generation number from %r" % fname)
    return int(float(match.group(1)) / gen_step)


def _column1_stats(file_path):
    """Return (mean, std, median) of column 1 of the sparse matrix at file_path."""
    values = sp.load_npz(file_path).toarray()[:, 1]
    return np.mean(values), np.std(values), np.median(values)


def stoalfafit(path, mi, ma, step, *, n_dirs=11, gen_step=50):
    """Summarise a parameter sweep into ``path + 'data.csv'``.

    For each sub-folder ``path + str(i)`` (i = 1 .. n_dirs), every ``.npz``
    snapshot is loaded and the mean, standard deviation and median of its
    column 1 are stored in row ``int(generation / gen_step)``. The per-folder
    (n_snapshots, 3) tables are placed side by side and written (with the
    index) to ``path + 'data.csv'``. All three columns of a folder share one
    ``"%.4f"`` header label: ``mi`` for folder 1, ``(i-1)*step`` otherwise.

    Parameters
    ----------
    path : str
        Root prefix, concatenated directly with the folder number and with
        ``'data.csv'``, so it should end with a path separator.
    mi : float
        Label used for folder 1.
    ma : float
        Not used by this function; kept for interface compatibility.
    step : float
        Label spacing for folders 2 .. n_dirs.
    n_dirs : int, optional
        Number of numbered sub-folders to read (default 11).
    gen_step : float, optional
        Generation spacing between consecutive snapshots (default 50).

    Raises
    ------
    ValueError
        If a snapshot name has no number before ``.npz``, if a snapshot's row
        falls outside its folder's table (gaps in the generation sequence),
        or (from NumPy) if folders hold different numbers of snapshots.
    FileNotFoundError
        If one of the numbered sub-folders is missing.
    """
    columns = []
    blocks = []
    for i in range(1, n_dirs + 1):
        label = "%.4f" % (mi if i == 1 else (i - 1) * step)
        columns.extend([label] * 3)

        folder = path + str(i)
        files = [f for f in os.listdir(folder) if f.endswith('.npz')]
        stats = np.zeros((len(files), 3))
        for fname in files:
            row = _snapshot_row(fname, gen_step)
            if row >= len(files):
                raise ValueError(
                    "snapshot %r maps to row %d but folder %r holds only %d snapshots"
                    % (fname, row, folder, len(files)))
            stats[row] = _column1_stats(folder + '/' + fname)
        blocks.append(stats)

    # Concatenate once at the end. The old ``dt == []`` test compared an
    # ndarray with a list, which errors on current NumPy from folder 2 on.
    dt = np.hstack(blocks) if blocks else np.empty((0, 0))
    df = pd.DataFrame(dt)
    df.columns = columns
    df.to_csv(path + 'data.csv')
 
if __name__ == "__main__":
    # Aggregate the alpha-sweep snapshots below this root into data.csv.
    # The trailing '/' matters: stoalfafit concatenates paths directly.
    sweep_root = 'E:/Simulations/alfa test 2/'
    stoalfafit(sweep_root, mi=0.001, ma=0.05, step=0.005)

    # --- Earlier ad-hoc invocations, kept for reference ---
    # t = [float(x) for x in range(0,5351,25)]
    # fname = "pos_n_normal"
    # # file = "pos_report_3.txt"
    # path = "E:/Simulations/Stationary Wave/Normal positive stationary/"
    # # # timereport(path, file)
    # # basic_plotter(t, 10, (0,2000), path + fname)
    # # # npztotxt("E:/Simulations/Pos Neg Evolution/x5 Test/Test1/", "%5Ctest_1349.npz")
    # for i in t:
    #     name = fname + "_" + str(i) + ".npz"
    #     npztotxt(path, name)
    # alfa = 10
    # x = np.array([x/100 for x in range(100000)])
    # a = x > 1
    # b = 1 - np.exp(-alfa * x)
    # c = np.arctan(alfa*x) / (np.pi/2)
    # d = -1 / (alfa*x + 1) + 1
    # plt.plot(x, a)
    # plt.plot(x, b)
    # plt.plot(x, c)
    # plt.plot(x, d)
    