#!/usr/bin/python3
import numpy
from multiprocessing.pool import Pool
import pickle
from inspect import getsource
from datetime import datetime
from lmfit import Parameters, minimize
from fepydas.libs.fit_functions import BiExponential, BiExponentialTail, Exponential, Gaussian, Lorentzian, Linear, TriExponential
from fepydas.datatypes.Data import Transformation
from scipy.signal import convolve

class Fit:
  """Least-squares fit of a model function with lmfit.

  ``self.function`` is the model; it is evaluated as
  ``function(x, **params.valuesdict())``. Subclasses choose the model and
  provide ``initializeAuto`` (and often ``initialize``) to create the
  starting ``Parameters``. Single fits store their outcome in
  ``self.result``; batch fits store a dict of results in ``self.results``.
  """

  def __init__(self, function):
    # Model callable: function(x, **parameter_values) -> model values.
    self.function = function

  def saveBinary(self, filename):
    """Pickle this fit object (model, parameters and any results) to *filename*."""
    # The context manager closes the file even when pickling fails.
    with open(filename, "bw") as f:
      pickle.dump(self, f)

  def toTransformation(self):
    """Return a ``Transformation`` of the model with the fitted parameter values."""
    return Transformation(self.function, self.result.params.valuesdict())

  def initializeParameters(self, parameters: "Parameters"):
    """Store the starting parameters used by subsequent fits."""
    self.parameters = parameters

  def residual(self, params, x, data=None, eps=None):
    """Return ``data - model(x)``, multiplied by *eps* when it is given."""
    parvals = params.valuesdict()
    diff = data - self.function(x, **parvals)
    if eps is not None:
      return diff * eps
    return diff

  def residualLog(self, params, x, data=None, eps=None):
    """Return ``log(data) - log(model(x))``, multiplied by *eps* when given.

    The model is floored at 1e-9 before taking the logarithm so zero or
    negative model values stay finite. The data is not clamped, so
    non-positive data still gives -inf/nan.
    """
    parvals = params.valuesdict()
    # Use a scalar floor: numpy.full_like(vals, 1E-9) took the dtype of vals,
    # so an integer-valued model got a floor of 0 and log(0) = -inf.
    vals = numpy.maximum(self.function(x, **parvals), 1E-9)
    diff = numpy.log(data) - numpy.log(vals)
    if eps is not None:
      return diff * eps
    return diff

  def convolve(self, signal, ref):
    """Convolve *signal* with *ref*, keeping the first ``len(signal)`` samples.

    The bare name ``convolve`` below is the module-level
    ``scipy.signal.convolve``, not this method.
    """
    # A "full" convolution truncated to the signal length stays aligned
    # with the start of *signal*.
    return convolve(signal, ref, mode="full")[:len(signal)]

  def convolutedResidual(self, params, x, data=None, irf=None):
    """Return ``data - convolve(model(x), irf)``."""
    parvals = params.valuesdict()
    return data - self.convolve(self.function(x, **parvals), irf)

  def fit(self, x, y, eps=None, nan_policy="raise", log=False):
    """Fit the model to (x, y), starting from ``self.parameters``.

    Args:
      x, y: independent variable and data.
      eps: optional weights multiplying the residual.
      nan_policy: forwarded to ``lmfit.minimize``.
      log: if true, minimise the difference of logarithms
        (``residualLog``) instead of the plain residual.

    Returns:
      The lmfit result, which is also stored in ``self.result``.
    """
    objective = self.residualLog if log else self.residual
    self.result = minimize(objective, self.parameters, args=(x, y, eps), method="leastsq", nan_policy=nan_policy)
    return self.result

  def fitSpectrum(self, spectrum):
    """Fit ``spectrum.data.values`` against ``spectrum.axis.values``."""
    return self.fit(spectrum.axis.values, spectrum.data.values)

  def batchFit(self, x, y):
    """Fit every row ``y[i]`` against the shared *x* in parallel.

    The results are stored in ``self.results[i]``.
    """
    print("BatchFit", x.shape, y.shape)
    args = {i: [x, y[i]] for i in range(y.shape[0])}
    self.executeBatchFit(self.fit, args)

  def convolutedFit(self, x, y, irf):
    """Fit ``convolve(model(x), irf)`` to *y* using lmfit's default method."""
    self.result = minimize(self.convolutedResidual, self.parameters, args=(x, y, irf))
    return self.result

  def convolutedBatchFit(self, x, y, irf):
    """Run ``convolutedFit`` on every row ``y[i]`` in parallel."""
    args = {i: [x, y[i], irf] for i in range(y.shape[0])}
    self.executeBatchFit(self.convolutedFit, args)

  def executeBatchFit(self, func, args, timeout=1000):
    """Call ``func(*args[key])`` for every key in a process pool.

    *func* and the argument lists are pickled and sent to the worker
    processes. Results are stored in ``self.results``, keyed like *args*.
    *timeout* is the number of seconds to wait for each job (default 1000,
    the previously hard-coded value). ``multiprocessing.TimeoutError`` is
    raised if a job takes longer.
    """
    # The context manager terminates the workers on exit. The pool used to
    # be left open, leaking worker processes on every batch fit.
    with Pool() as pool:
      jobs = {key: pool.apply_async(func, jobArgs) for key, jobArgs in args.items()}
      self.results = {key: job.get(timeout) for key, job in jobs.items()}

  def batchEvaluate(self, x):
    """Evaluate the model at *x* for every batch result, one row per result.

    Assumes the result keys are the row indices 0..n-1, as produced by
    ``batchFit`` and ``convolutedBatchFit``.
    """
    data = numpy.empty((len(self.results), len(x)))
    for key, result in self.results.items():
      data[key, :] = self.function(x, **result.params.valuesdict())
    return data

  def batchEvaluateConvolution(self, x, irf):
    """Like ``batchEvaluate`` but with every row convolved with *irf*."""
    data = self.batchEvaluate(x)
    for i in range(data.shape[0]):
      data[i, :] = self.convolve(data[i, :], irf)
    return data

  def evaluate(self, x):
    """Evaluate the model at *x* with the fitted parameters (``self.result``)."""
    parvals = self.result.params.valuesdict()
    return self.function(x, **parvals)

  def evaluateInput(self, x):
    """Evaluate the model at *x* with the starting parameters (``self.parameters``)."""
    parvals = self.parameters.valuesdict()
    return self.function(x, **parvals)

  def evaluateConvolution(self, x, irf):
    """Evaluate the fitted model at *x* and convolve it with *irf*."""
    return self.convolve(self.evaluate(x), irf)

  def startReport(self, filename):
    """Open *filename* for writing and write the report header.

    The header contains the class name, the current time, the model name
    and source, and every input parameter with its start value and
    ``[min:max]`` bounds. The open file is returned, and the caller must
    close it.
    """
    f = open(filename, "w")
    try:
      f.write("Generated by Fit: {0}\n".format(self.__class__.__name__))
      f.write("Time: {0}\n".format(datetime.now()))
      f.write("Model: {0}\n".format(self.function.__name__))
      # Model source without its first line and without the empty element
      # that follows the final newline.
      f.write("{0}\n".format("\n".join(getsource(self.function).split("\n")[1:-1])))
      f.write("Input Parameters: \n")
      for p in self.parameters.keys():
        f.write("  {0}:\t{1}\t[{2}:{3}]\n" .format(p,self.parameters[p].value,self.parameters[p].min,self.parameters[p].max))
      f.write("---\n")
    except BaseException:
      # Do not leak the handle if the header cannot be written
      # (e.g. getsource fails for a built-in model).
      f.close()
      raise
    return f

  def saveReport(self, filename):
    """Write a report of ``self.result``: values, stderr and fit statistics."""
    with self.startReport(filename) as f:
      f.write("Parameter\tValue\tError\n")
      for p in self.parameters.keys():
        f.write("{0}\t{1}\t{2}\n".format(p,self.result.params[p].value,self.result.params[p].stderr))
      f.write("\nnfev\t{0}\tchisqr\t{1}\tredchi\t{2}\n".format(self.result.nfev,self.result.chisqr,self.result.redchi))

  def batchSaveReport(self, filename, labels):
    """Write a tab-separated report with one line per batch result.

    ``labels[key]`` names the dataset for result ``key``.
    """
    with self.startReport(filename) as f:
      params = self.parameters.keys()
      f.write("Dataset")
      for p in params:
        f.write("\t{0}\t{0}Error".format(p))
      f.write("\tnfev\tchisqr\tredchi")
      f.write("\n")
      for k, res in self.results.items():
        f.write("{0}".format(labels[k]))
        for p in params:
          f.write("\t{0}\t{1}".format(res.params[p].value,res.params[p].stderr))
        f.write("\t{0}\t{1}\t{2}".format(res.nfev,res.chisqr,res.redchi))
        f.write("\n")

  def initializeAutoFromSpectrum(self, spectrum):
    """Call the subclass's ``initializeAuto`` with the spectrum's axis and data."""
    self.initializeAuto(spectrum.axis.values, spectrum.data.values)

class LorentzianFit(Fit):
  """Fit that uses the ``Lorentzian`` model function."""

  def __init__(self):
    super().__init__(function=Lorentzian)

class SpectralLine(LorentzianFit):
  """Lorentzian fit of a single line with automatically guessed start values."""

  def __init__(self):
    super().__init__()

  def initializeAuto(self, x, y):
    """Guess bg, I, x0 and fwhm from the sampled line (x, y).

    bg is the mean of the two end samples and I is the maximum above bg.
    x0 is the x position of the maximum. fwhm is the x distance between
    the outermost samples lying above bg + I/2.
    """
    background = (y[0] + y[-1]) / 2
    height = numpy.max(y) - background
    centre = x[numpy.argmax(y)]
    aboveHalf = numpy.where(y > background + height / 2)[0]
    width = numpy.abs(x[aboveHalf[-1]] - x[aboveHalf[0]])
    start = Parameters()
    for name, value in (("bg", background), ("I", height), ("x0", centre), ("fwhm", width)):
      start.add(name, value=value)
    self.initializeParameters(start)

class GaussianFit(Fit):
  """Fit that uses the ``Gaussian`` model (parameters bg, I, x0, fwhm)."""

  def __init__(self):
    super().__init__(Gaussian)

  def initializeAuto(self, x, y):
    """Guess start values for a single peak sampled at (x, y).

    bg is the mean of the two end samples and I is the maximum above bg.
    x0 is the x position of the maximum. fwhm is the x span of the samples
    lying above bg + I/2, plus x0/1000.
    """
    background = (y[0] + y[-1]) / 2
    height = numpy.max(y) - background
    centre = x[numpy.argmax(y)]
    aboveHalf = numpy.where(y > background + height / 2)[0]
    span = numpy.abs(x[aboveHalf[-1]] - x[aboveHalf[0]])
    start = Parameters()
    # NOTE(review): the x0/1000 offset presumably keeps the start width
    # non-zero when only one sample is above half maximum; it assumes x0 > 0.
    for name, value in (("bg", background), ("I", height), ("x0", centre), ("fwhm", span + centre / 1000)):
      start.add(name, value=value)
    self.initializeParameters(start)

class LimitedGaussianFit(GaussianFit):
  """Gaussian fit whose start values are guessed inside a window of the data."""

  def __init__(self):
    super().__init__()

  def initializeAutoLimited(self, x, y, center, range, thresh=1):
    """Crop (x, y) to about [center - range, center + range] and guess start values.

    Returns ``(False, False)`` when the data is rejected. That happens when
    the maximum of the full *y* occurs more than once, when the window
    spans fewer than 15 samples, or when the cropped *y* never exceeds
    *thresh*. Otherwise the start parameters are set from the cropped data,
    which is returned as ``(x, y)``. The upper window index is excluded
    from the crop.
    """
    rejected = (False, False)
    # A repeated (ambiguous) maximum over the whole input is rejected.
    if numpy.count_nonzero(y == numpy.max(y)) > 1:
      return rejected
    # Sample indices closest to the two window edges.
    edges = (
      numpy.abs(x - (center - range)).argmin(),
      numpy.abs(x - (center + range)).argmin(),
    )
    if numpy.abs(edges[1] - edges[0]) < 15:
      return rejected
    # The edges can come out reversed (e.g. for a descending x), so order them.
    start, stop = min(edges), max(edges)
    xWindow = x[start:stop]
    yWindow = y[start:stop]
    if numpy.amax(yWindow) <= thresh:
      return rejected
    self.initializeAuto(xWindow, yWindow)
    return xWindow, yWindow

class LinearFit(Fit):
  """Fit that uses the ``Linear`` model function."""

  def __init__(self):
    super().__init__(function=Linear)

class CalibrationFit(LinearFit):
  """Linear fit used for calibration, started from a = 1 and b = 0."""

  def __init__(self):
    super().__init__()

  def initializeAuto(self):
    """Set the start parameters a = 1 and b = 0 (no bounds)."""
    start = Parameters()
    for name, value in (("a", 1), ("b", 0)):
      start.add(name, value=value)
    self.initializeParameters(start)

class AutomaticCalibration(CalibrationFit):
  """Linear calibration built from automatically detected spectral lines.

  Peaks are located with ``spectrum.identifyPeaks``. Each peak's index
  range is fitted with a ``SpectralLine`` to obtain its centre ``x0``. The
  fitted centres are then fitted linearly against the matching
  *references*, skipping references that are NaN.
  """

  def __init__(self, spectrum, references, threshold=10, width=10):
    """Detect and fit the peaks of *spectrum*, then run the calibration fit.

    Args:
      spectrum: provides ``identifyPeaks(threshold=..., width=...)``
        (whose items are used as ``peak[0]:peak[1]`` index ranges), plus
        ``axis.values`` and ``data.values``.
      references: reference positions, paired by order with the detected
        peaks. Use NaN for peaks that should be left out.
      threshold, width: forwarded to ``spectrum.identifyPeaks``.

    Raises:
      IndexError: if a non-NaN reference has no corresponding detected peak.
    """
    super().__init__()
    peaks = spectrum.identifyPeaks(threshold=threshold, width=width)
    axis = spectrum.axis.values
    data = spectrum.data.values
    lineFit = SpectralLine()
    peakVals = []
    for peak in peaks:
      x, y = axis[peak[0]:peak[1]], data[peak[0]:peak[1]]
      lineFit.initializeAuto(x, y)
      lineFit.fit(x, y)
      peakVals.append(lineFit.result.params["x0"].value)
    references = numpy.array(references, dtype=numpy.float64)
    peakVals = numpy.array(peakVals, dtype=numpy.float64)
    # Positions of the references that take part in the calibration.
    idx = numpy.where(~numpy.isnan(references))[0]
    print("Calibration with ", references, peakVals, idx)
    # idx is ascending, so idx[-1] is the largest peak index required.
    if idx.size and idx[-1] >= len(peakVals):
      raise IndexError(
        "Calibration needs at least {0} detected peaks, found {1}".format(idx[-1] + 1, len(peakVals)))
    self.initializeAuto()
    self.fit(peakVals[idx], references[idx])
  

class ExponentialFit(Fit):
  """Fit that uses the ``Exponential`` model (bg, I, x0, tau, rise)."""

  def __init__(self):
    super().__init__(Exponential)

  @staticmethod
  def _buildParameters(starts):
    """Create Parameters from ordered (name, value) pairs.

    bg and x0 are left unbounded. Every other parameter gets min=0.
    """
    params = Parameters()
    for name, value in starts:
      if name in ("bg", "x0"):
        params.add(name, value=value)
      else:
        params.add(name, value=value, min=0)
    return params

  def initialize(self, bg, I, x0, tau, rise):
    """Start from explicit values. I, tau and rise are bounded below by 0."""
    self.initializeParameters(self._buildParameters([
      ("bg", bg),
      ("I", I),
      ("x0", x0),
      ("tau", tau),
      ("rise", rise),
    ]))

  def initializeAuto(self, x, y):
    """Guess start values from the data.

    bg = y[0] and I = max(y) - y[0]. x0 is the x position of the maximum.
    tau = 1 and rise = 0.001.
    """
    baseline = y[0]
    self.initializeParameters(self._buildParameters([
      ("bg", baseline),
      ("I", numpy.max(y) - baseline),
      ("x0", x[numpy.argmax(y)]),
      ("tau", 1),
      ("rise", 0.001),
    ]))

class BiExponentialFit(Fit):
  """Fit that uses the ``BiExponential`` model (bg, I_1, I_2, x0, tau_1, tau_2, rise)."""

  def __init__(self):
    super().__init__(BiExponential)

  @staticmethod
  def _buildParameters(starts):
    """Create Parameters from ordered (name, value) pairs.

    bg and x0 are left unbounded. Every other parameter gets min=0.
    """
    params = Parameters()
    for name, value in starts:
      if name in ("bg", "x0"):
        params.add(name, value=value)
      else:
        params.add(name, value=value, min=0)
    return params

  def initializeAuto(self, x, y):
    """Guess start values from the data.

    bg = y[0]. The height max(y) - y[0] is split equally between I_1 and
    I_2. x0 is the x position of the maximum. tau_1 = 5, tau_2 = 50 and
    rise = 0.001.
    """
    baseline = y[0]
    share = (numpy.max(y) - baseline) / 2
    self.initializeParameters(self._buildParameters([
      ("bg", baseline),
      ("I_1", share),
      ("I_2", share),
      ("x0", x[numpy.argmax(y)]),
      ("tau_1", 5),
      ("tau_2", 50),
      ("rise", 0.001),
    ]))

  def initialize(self,I_1, I_2, tau_1, tau_2, x0=0, rise= 0.001):
    """Start from explicit values with bg = 0."""
    self.initializeParameters(self._buildParameters([
      ("bg", 0),
      ("I_1", I_1),
      ("I_2", I_2),
      ("x0", x0),
      ("tau_1", tau_1),
      ("tau_2", tau_2),
      ("rise", rise),
    ]))

class TriExponentialFit(Fit):
  """Fit that uses the ``TriExponential`` model (bg, I_1..I_3, x0, tau_1..tau_3, rise)."""

  def __init__(self):
    super().__init__(TriExponential)

  @staticmethod
  def _buildParameters(starts):
    """Create Parameters from ordered (name, value) pairs.

    bg and x0 are left unbounded. Every other parameter gets min=0.
    """
    params = Parameters()
    for name, value in starts:
      if name in ("bg", "x0"):
        params.add(name, value=value)
      else:
        params.add(name, value=value, min=0)
    return params

  def initializeAuto(self, x, y):
    """Guess start values from the data.

    bg = y[0]. The height max(y) - y[0] is split equally between I_1, I_2
    and I_3. x0 is the x position of the maximum. The taus start at 5, 50
    and 500, and rise = 0.001.
    """
    baseline = y[0]
    share = (numpy.max(y) - baseline) / 3
    self.initializeParameters(self._buildParameters([
      ("bg", baseline),
      ("I_1", share),
      ("I_2", share),
      ("I_3", share),
      ("x0", x[numpy.argmax(y)]),
      ("tau_1", 5),
      ("tau_2", 50),
      ("tau_3", 500),
      ("rise", 0.001),
    ]))

  def initialize(self,I_1, I_2, I_3, tau_1, tau_2, tau_3,x0=0, rise= 0.001):
    """Start from explicit values with bg = 0."""
    self.initializeParameters(self._buildParameters([
      ("bg", 0),
      ("I_1", I_1),
      ("I_2", I_2),
      ("I_3", I_3),
      ("x0", x0),
      ("tau_1", tau_1),
      ("tau_2", tau_2),
      ("tau_3", tau_3),
      ("rise", rise),
    ]))


    
class BiExponentialTailFit(Fit):
  """Fit that uses the ``BiExponentialTail`` model (bg, I_1, I_2, tau_1, tau_2)."""

  def __init__(self):
    super().__init__(BiExponentialTail)

  @staticmethod
  def _buildParameters(starts):
    """Create Parameters from ordered (name, value) pairs.

    bg is left unbounded. Every other parameter gets min=0.
    """
    params = Parameters()
    for name, value in starts:
      if name == "bg":
        params.add(name, value=value)
      else:
        params.add(name, value=value, min=0)
    return params

  def initializeAuto(self,x,y):
    """Guess start values from the data.

    bg is the last sample y[-1]. The height max(y) - y[-1] is split equally
    between I_1 and I_2. tau_1 = 5 and tau_2 = 50.
    """
    tail = y[-1]
    share = (numpy.max(y) - tail) / 2
    self.initializeParameters(self._buildParameters([
      ("bg", tail),
      ("I_1", share),
      ("I_2", share),
      ("tau_1", 5),
      ("tau_2", 50),
    ]))

  def initialize(self, I_1, I_2, tau_1, tau_2):
    """Start from explicit values with bg = 0."""
    self.initializeParameters(self._buildParameters([
      ("bg", 0),
      ("I_1", I_1),
      ("I_2", I_2),
      ("tau_1", tau_1),
      ("tau_2", tau_2),
    ]))
