# coding: utf-8

import pandas as pd
import numpy as np
from .likelihood import perform_mle
from .hessian import evaluate_precision
from .fit_tools import read_a2z, plot_from_param, a2z_to_pkb
from .bayesian import wrap_explore_distribution 
from os import path

def update_with_centiles (df_a2z, input_centiles) :
  '''
  Write the summary of a sampled posterior back into the free rows of df_a2z.

  Only the rows of ``df_a2z`` whose column 6 equals 0 (free parameters) are
  modified; the DataFrame is updated in place and also returned.

  :param df_a2z: a2z parameters. Column 2 holds the parameter type, column 4 the
    parameter value, column 5 its uncertainty and column 6 the fixed/free flag.
  :type df_a2z: pandas DataFrame

  :param input_centiles: 2D array with one column per free row of ``df_a2z`` plus
    one trailing column which is discarded here. Row 1 is used as the fitted value,
    rows 0 and 2 as the lower and upper bounds. The array is copied, never modified.
  :type input_centiles: ndarray

  :return: the updated ``df_a2z``.
  :rtype: pandas DataFrame
  '''
  # Private copy of the centiles, without the trailing column.
  summary = np.copy (input_centiles)
  n_kept = summary[1,:].size - 1
  summary = summary[:, :n_kept]

  # Free rows of the frame, and among them the parameters sampled as logarithms.
  free = df_a2z[6]==0
  is_logged = df_a2z.loc[free, 2].isin (['height', 'width']).to_numpy ()

  # Heights and widths were explored in logarithm: bring them back to linear scale.
  summary[:, is_logged] = np.exp (summary[:, is_logged])

  lower, central, upper = summary[0,:], summary[1,:], summary[2,:]

  # Fitted value, and the largest of the two one-sided intervals as uncertainty.
  df_a2z.loc[free, 4] = central
  df_a2z.loc[free, 5] = np.maximum (central - lower, upper - central)

  return df_a2z

def _reset_fit_flags (df_a2z, fit_splittings) :
  '''
  Fix every parameter of df_a2z (column 6 set to 1), then free again the global
  splitting (column 0 == 'a' and column 2 == 'split') if fit_splittings is ``True``.
  The DataFrame is modified in place.
  '''
  df_a2z.loc[:,6] = 1
  if fit_splittings :
    df_a2z.loc[(df_a2z[2]=='split')&(df_a2z[0]=='a'), 6] = 0

def _free_modes (df_a2z, modes) :
  '''
  Free (column 6 set to 0) every row of df_a2z matching one of the (order, degree)
  couples listed in modes. The DataFrame is modified in place.
  '''
  for order, degree in modes :
    df_a2z.loc[(df_a2z[0]==str(order))&(df_a2z[1]==degree), 6] = 0

def _sample_free_parameters (df_a2z, freq, psd, back, n, gap_divisor, suffix,
                             store_chains, mcmcDir, sampler_kwargs) :
  '''
  Run one sampling step over the parameters currently free in df_a2z and
  return the updated DataFrame.

  The fitting window spans the free mode frequencies, widened on each side by
  the frequency span divided by gap_divisor. When store_chains is ``True``, the
  chain filename is ``mcmc_sampler_order_<n on two digits><suffix>.h5`` inside mcmcDir.
  '''
  frequencies = df_a2z.loc[(df_a2z[6]==0)&(df_a2z[2]=='freq'), 4].to_numpy ()
  if frequencies.size == 0 :
    # np.amin and np.amax raise ValueError on an empty array: there is no
    # window to define, so this step is skipped instead of aborting the whole run.
    print ('No free mode frequency found for this step, nothing to fit.')
    return df_a2z

  #automatic determination of low and up bound for the window over which the
  #fit is realised.
  low_freq = np.amin (frequencies)
  up_freq = np.amax (frequencies)
  margin = (up_freq - low_freq) / gap_divisor
  low_bound = low_freq - margin
  up_bound = up_freq + margin

  if store_chains :
    #designing the filename of the hdf5 file that will be used to store the mcmc chain.
    filename = 'mcmc_sampler_order_' + str(n).zfill (2) + suffix + '.h5'
    filename = path.join (mcmcDir, filename)
    print ('Chain will be saved at:', filename)
  else :
    filename = None

  centiles = wrap_explore_distribution (df_a2z, freq, psd, back,
                                        low_bound_freq=low_bound, up_bound_freq=up_bound,
                                        filename=filename, **sampler_kwargs)
  if centiles is None :
    if filename is None :
      # Without a filename the historical message cannot be built
      # (None + str raises TypeError), so a generic one is printed instead.
      print ('No sampling has been performed, proceeding to next step.')
    else :
      print (filename + ' already exists, no sampling has been performed, proceeding to next step.')
  else :
    df_a2z = update_with_centiles (df_a2z, centiles)
    print ('Ensemble sampling achieved')

  return df_a2z

def peakbagging (a2z_file, freq, psd, back=None, wdw=None, spectro=True,
                 fit_splittings=False, nsteps_mcmc=1000, show_corner=False,
                 store_chains=False, mcmcDir='.', order_to_fit=None, parallelise=False, progress=False,
                 strategy='order', fit_02=True, fit_13=True, nwalkers=64, normalise=False, instr='geometric', show=True,
                 filename_summary=None, show_prior=False, coeff_discard=50, use_sinc=False, asym_profile='korzennik') :
  '''
  Read an a2z input file with a prior guess of the modes to fit and return the result
  of the peakbagging protocol processed over those modes.

  Orders are fitted one after the other, in increasing order. The lowest order found in the
  a2z file is never used as a fitting step by itself (its modes are only reached as order ``n-1``
  of the following step). With the ``order`` strategy, each step frees the degrees 0 and 1 of
  order ``n`` together with the degrees 2, 3 and 4 of order ``n-1`` (and the order-wide
  parameters of order ``n``), and the fitting window extends a third of the free frequency span
  on each side. With the ``pair`` strategy, the pair (l=0 of order ``n``, l=2 of order ``n-1``)
  and the pair (l=1 of order ``n``, l=3 of order ``n-1``) are fitted successively, and the window
  extends one full frequency span on each side. A step for which no mode frequency is free
  is skipped.

  :param a2z_file: name of the file to read the parameters.
  :type a2z_file: string

  :param freq: frequency vector
  :type freq: ndarray

  :param psd: real power vector
  :type psd: ndarray

  :param back: activity background vector that will be used to complete the model to fit. Optional, default ``None``.
    Must have the same length as freq and psd.
  :type back: ndarray.

  :param wdw: observation window (0 and 1 array of the same length as the original timeseries)
    to analyse in order to predict the sidelobe pattern. Optional, default ``None``.
  :type wdw: ndarray.

  :param spectro: forwarded to ``plot_from_param`` for the prior and summary plots. Optional, default ``True``.
  :type spectro: bool

  :param fit_splittings: if set to ``True``, the global param *split* in the a2z input will be adjusted. Optional, default to ``False``.
    It is not necessary to change this parameter if a different splitting value is fitted on each order.
  :type fit_splittings: bool

  :param nsteps_mcmc: number of steps to process into each MCMC exploration.
  :type nsteps_mcmc: int

  :param show_corner: if set to ``True`` and using MCMC, show the corner plot for the posterior distribution sampling.
  :type show_corner: bool

  :param store_chains: if set to ``True``, each MCMC sampler will be stored as an hdf5 file. Filename will be autogenerated
    from the order (and the pair of degrees with the ``pair`` strategy). Optional, default ``False``
  :type store_chains: bool

  :param mcmcDir: directory where to save the MCMC sampler. Optional, default ``.``.
  :type mcmcDir: string

  :param order_to_fit: list of order to fit if the input a2z file contains order that are supposed not to be fitted.
    Optional, default ``None``.
  :type order_to_fit: array-like

  :param parallelise: If set to ``True``, use Python multiprocessing tool to parallelise process.
    Optional, default ``False``.
  :type parallelise: bool

  :param progress: forwarded to the sampling routine. Optional, default ``False``.
  :type progress: bool

  :param strategy: strategy to use for the fit, ``order`` or ``pair``. Optional, default ``order``. If 'pair' is used, a2z input must contain
    individual heights, widths and splittings for each degree. Any other value leaves the parameters unfitted.
  :type strategy: str

  :param fit_02: if strategy is set to ``pair``, l=0 and 2 modes will only be fitted if this parameter is set to ``True``.
    Optional, default ``True``.
  :type fit_02: bool

  :param fit_13: if strategy is set to ``pair``, l=1 and 3 modes will only be fitted if this parameter is set to ``True``.
    Optional, default ``True``.
  :type fit_13: bool

  :param nwalkers: number of walkers in the MCMC process.
  :type nwalkers: int

  :param normalise: forwarded to the sampling routine. Optional, default ``False``.
  :type normalise: bool

  :param instr: instrument to consider (amplitude ratio inside degrees depend on geometry
    AND instrument and should be adapted). Possible argument : ``geometric``, ``kepler``, ``golf``, ``virgo``.
    Optional, default ``geometric``.
  :type instr: str

  :param show: forwarded as ``show`` to ``plot_from_param`` for the summary plot of the fitted model.
    Optional, default ``True``.
  :type show: bool

  :param filename_summary: forwarded as ``filename`` to ``plot_from_param`` for the summary plot.
    Optional, default ``None``.
  :type filename_summary: str

  :param show_prior: forwarded as ``show`` to ``plot_from_param`` for the plot of the prior guess.
    Optional, default ``False``.
  :type show_prior: bool

  :param coeff_discard: coeff used to compute the number of values to discard at the beginning
    of each MCMC : total amount of sampled values will be divided by coeff_discard. Optional, default 50.
  :type coeff_discard: int

  :param use_sinc: if set to ``True``, mode profiles will be computed using cardinal sinus and not Lorentzians.
    No asymmetry term will be used if it is the case. Optional, default ``False``.
  :type use_sinc: bool

  :param asym_profile: depending on the chosen argument, asymmetric profiles will be computed following Korzennik 2005 (``korzennik``)
    or Nigam & Kosovichev 1998 (``nigam-kosovichev``).
  :type asym_profile: str

  :return: a2z fitted modes parameters as a DataFrame
  :rtype: pandas DataFrame
  '''

  df_a2z = read_a2z (a2z_file)

  # show prior
  param_prior = a2z_to_pkb (df_a2z)
  plot_from_param (param_prior, freq, psd, back=back, wdw=wdw, smoothing=10, spectro=spectro, correct_width=1.,
                   show=show_prior, instr=instr)

  # by default fix all parameters, unless fit_splittings is set to True
  _reset_fit_flags (df_a2z, fit_splittings)

  #sort by ascending order
  df_a2z = df_a2z.sort_values (by=0)
  #extract a list of order
  list_order = df_a2z.loc[df_a2z[0]!='a', 0].to_numpy ()
  list_order = list_order.astype (np.int_)
  order = np.unique (list_order)
  order = order[order > np.amin (order)] #TODO if there is no l=2 or l=3 this line will ignore the first order...

  if order_to_fit is None :
    order_to_fit = order
  orders = np.intersect1d (order, order_to_fit)

  print ('Orders to fit')
  print (orders)

  # Arguments shared by every call to the sampling routine.
  sampler_kwargs = dict (wdw=wdw, nsteps=nsteps_mcmc, show_corner=show_corner, parallelise=parallelise,
                         progress=progress, nwalkers=nwalkers, normalise=normalise, instr=instr,
                         coeff_discard=coeff_discard, use_sinc=use_sinc, asym_profile=asym_profile)

  for n in orders :
    if strategy=='order' :
      print ('Fitting on order', n)
      _free_modes (df_a2z, [(n, '0'), (n, '1'), (n-1, '2'), (n-1, '3'), (n-1, '4'), (n, 'a')])
      # window: a third of the free frequency span on each side
      df_a2z = _sample_free_parameters (df_a2z, freq, psd, back, n, 3., '',
                                        store_chains, mcmcDir, sampler_kwargs)
      # Fixing again all parameters
      _reset_fit_flags (df_a2z, fit_splittings)

    if strategy=='pair' :
      print ('Fitting on order', n)

      if fit_02 :
        print ('Fitting degrees 0 and 2')
        _free_modes (df_a2z, [(n, '0'), (n-1, '2')])
        # window: one full free frequency span on each side
        df_a2z = _sample_free_parameters (df_a2z, freq, psd, back, n, 1., '_degrees_02',
                                          store_chains, mcmcDir, sampler_kwargs)
        # Fixing again all parameters
        _reset_fit_flags (df_a2z, fit_splittings)

      if fit_13 :
        print ('Fitting degrees 1 and 3')
        _free_modes (df_a2z, [(n, '1'), (n-1, '3')])
        # window: one full free frequency span on each side
        df_a2z = _sample_free_parameters (df_a2z, freq, psd, back, n, 1., '_degrees_13',
                                          store_chains, mcmcDir, sampler_kwargs)
        # Fixing again all parameters
        _reset_fit_flags (df_a2z, fit_splittings)

  # show result
  param_result = a2z_to_pkb (df_a2z)
  plot_from_param (param_result, freq, psd, back=back, wdw=wdw, smoothing=10, spectro=spectro, correct_width=1.,
                   show=show, filename=filename_summary, instr=instr)

  return df_a2z







