# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/06_sleep.ipynb.

# %% auto 0
__all__ = ['TwoProcessModel', 'sleep_midpoint', 'cluster_sleep_periods_scipy']

# %% ../nbs/06_sleep.ipynb 2
from nbdev import show_doc

# %% ../nbs/06_sleep.ipynb 3
import numpy as np
import pandas as pd
import scipy as sp
from numba import jit
from scipy import interpolate
import torch 
import json
import pytz
import datetime
import copy
import scipy as sp
from scipy.integrate import solve_ivp
from math import *
import pylab as plt

from scipy.optimize import minimize
from .utils import *
from .models import * 
from .lights import *
from .plots import *

# %% ../nbs/06_sleep.ipynb 5
class TwoProcessModel:
    """Homeostatic sleep pressure driven by step counts and circadian phase.

    The homeostat relaxes towards ``mu_s = 1`` with time constant ``tau_w``
    (18.2 h) while awake and decays towards 0 with time constant ``tau_s``
    (4.2 h) while asleep. The wake status switches with hysteresis between two
    thresholds shifted by ``0.10 * cos(phase)`` (see ``check_wake_status``).
    Any epoch whose step count exceeds ``steps_wake_threshold`` is integrated
    with the awake branch regardless of the hysteretic status.

    Note: ``self.awake`` is carried over between calls; ``__call__`` does not
    reset it.
    """

    def __init__(self, steps_wake_threshold: float = 10.0):
        # Step counts above this value force the awake branch of the ODE.
        self.steps_wake_threshold = steps_wake_threshold
        # Current hysteretic wake status; the model starts awake.
        self.awake = True

    @staticmethod
    def check_wake_status(awake: bool,  # current wake status
                          h: float,  # homeostat value
                          phase: float  # circadian phase
                          ) -> bool:
        """Return the updated wake status for homeostat value ``h``.

        Falls asleep (False) when ``h`` is above the upper threshold. Wakes
        (True) when ``h`` is at or below the lower threshold. Otherwise
        returns ``awake`` unchanged, which gives hysteresis.
        """
        c = np.cos(phase)
        H_minus = 0.17  # base lower threshold
        H_plus = 0.6  # base upper threshold
        homeostat_a = 0.10  # circadian modulation amplitude of both thresholds

        upper = H_plus + homeostat_a * c
        lower = H_minus + homeostat_a * c
        if h > upper:
            return False
        if h <= lower:
            return True
        return awake

    def dhomeostat(self,
                   homeostat: np.ndarray,  # homeostat state; only element 0 is used
                   steps: float,  # steps value
                   phase: float,  # circadian phase (kept for interface compatibility; unused here)
                   ) -> np.ndarray:
        """Right-hand side of the homeostat ODE; returns ``np.array([dH])``.

        Uses the awake branch ``(mu_s - h) / tau_w`` when ``steps`` exceeds
        ``steps_wake_threshold`` or ``self.awake`` is True. Otherwise uses the
        sleep branch ``-h / tau_s``.

        This method has no side effects. The wake status is updated once per
        integration step in ``step_rk4``, so all four RK4 stages integrate the
        same vector field.
        """
        h = homeostat[0]
        tau_s = 4.2  # hours
        tau_w = 18.2  # hours
        mu_s = 1

        step_awake = (steps > self.steps_wake_threshold) or self.awake
        dH = (mu_s - h) / tau_w if step_awake else -h / tau_s
        return np.array([dH])

    # implement a RK4 solver
    def step_rk4(self,
                 state: np.ndarray,
                 steps: float,
                 phase: float,
                 dt: float = 0.10):
        """Advance ``state`` by one classic RK4 step of size ``dt``.

        ``steps`` and ``phase`` are held constant over the step. After the
        step, ``self.awake`` is updated from the new homeostat value via
        ``check_wake_status``.
        """
        half = dt / 2.0
        k1 = self.dhomeostat(state, steps, phase)
        k2 = self.dhomeostat(state + k1 * half, steps, phase)
        k3 = self.dhomeostat(state + k2 * half, steps, phase)
        k4 = self.dhomeostat(state + k3 * dt, steps, phase)
        new_state = state + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)
        # Decide wake/sleep once per step, after integrating. Updating inside
        # each RK4 stage would switch vector fields mid-step.
        self.awake = TwoProcessModel.check_wake_status(self.awake, new_state[0], phase)
        return new_state

    def __call__(self,
                 ts: np.ndarray,
                 phase: np.ndarray,
                 steps: np.ndarray,
                 initial_value: np.ndarray = np.array([0.50])  # initial value for the homeostat
                 ) -> "DynamicalTrajectory":
        """Integrate the homeostat over the time grid ``ts``.

        The step from ``ts[idx-1]`` to ``ts[idx]`` uses ``steps[idx]`` and
        ``phase[idx]``. Its step size is the actual spacing of ``ts``,
        presumably in hours to match the time constants in ``dhomeostat``.

        Returns a ``DynamicalTrajectory`` built from ``ts`` and a solution
        array of shape ``(len(initial_value), len(ts))``.
        """
        current_state = np.asarray(initial_value, dtype=float)
        sol = np.zeros((len(current_state), len(ts)))
        sol[:, 0] = current_state
        for idx in range(1, len(ts)):
            dt = ts[idx] - ts[idx - 1]
            current_state = self.step_rk4(current_state, steps[idx], phase[idx], dt=dt)
            sol[:, idx] = current_state
        return DynamicalTrajectory(ts, sol)

# %% ../nbs/06_sleep.ipynb 7
def sleep_midpoint(timetotal: np.ndarray, 
                   Wake: np.ndarray, 
                   durations=True):
    """
        Find the midpoints (and optionally the durations) of sleep bouts.

        A sample counts as awake when ``Wake[k] > 0.5`` and as asleep when
        ``Wake[k] <= 0.5``. A bout starts at the first asleep sample and ends
        at the next awake sample, or at ``timetotal[-1]`` if the record ends
        asleep.

        NaN samples after the first never trigger a transition, so they keep
        the preceding status. A leading NaN counts as asleep.

        Args:
            timetotal: sample times, aligned with ``Wake``.
            Wake: per-sample wake indicator (values above 0.5 mean awake).
            durations: if True, also return the bout durations.

        Returns:
            ``(midpoints, durations)`` arrays if ``durations`` is True,
            otherwise only the midpoints array. The arrays are empty when no
            sleep is found.
    """
    if len(Wake) == 0:
        empty = np.array([])
        return (empty, empty.copy()) if durations else empty

    sleep_start = []
    sleep_end = []
    awake = Wake[0] > 0.50
    if not awake:
        # Record starts asleep: the first bout begins at the first sample.
        sleep_start.append(timetotal[0])

    for k in range(1, len(Wake)):
        if Wake[k] > 0.50 and not awake:
            awake = True
            sleep_end.append(timetotal[k])
        elif Wake[k] <= 0.50 and awake:
            awake = False
            sleep_start.append(timetotal[k])

    if not awake:
        # Record ends asleep (even if the last sample is NaN): close the open bout.
        sleep_end.append(timetotal[-1])

    starts = np.array(sleep_start)
    ends = np.array(sleep_end)
    sleep_durations = ends - starts
    sleep_midpoints = starts + sleep_durations / 2

    if durations:
        return sleep_midpoints, sleep_durations
    return sleep_midpoints



# %% ../nbs/06_sleep.ipynb 8
def cluster_sleep_periods_scipy(wake_data: np.ndarray, 
                                epsilon: float,
                                makeplot: bool = False,
                                max_sleep_clusters=None, 
                                min_sleep_clusters=None):
    """
        Smooth a binary sleep/wake prediction into contiguous periods.

        Given a binary vector ``wake_data`` with a per-epoch sleep/wake
        prediction (NaN entries are treated as 0.5) and a regularization
        penalty ``epsilon``, solve with SLSQP, for ``w`` in [0, 1]^n:

            min_w  sum(w * (1 - d)) + sum((1 - w) * d) + epsilon * sum(diff(w)**2)

        This is a preprocessing step. It removes erroneous short sleep (and
        wake) periods, which could otherwise distort calculations such as the
        sleep midpoint for the day.

        If given, ``max_sleep_clusters`` and ``min_sleep_clusters`` bound
        ``sum(diff(w)**2)`` from above and below. For a binary ``w`` this sum
        equals the number of sleep/wake transitions.

        The relaxed solution is passed through ``convert_binary`` (defined
        elsewhere in the package) and returned.

        Signature:
        cluster_sleep_periods_scipy(wake_data: np.ndarray, epsilon: float, makeplot: bool=False,
                                    max_sleep_clusters=None, min_sleep_clusters=None)
    """
    # np.nan_to_num returns a new array, and its second positional argument is
    # ``copy``. The fill value must therefore be passed as ``nan=``.
    wake_data = np.nan_to_num(np.asarray(wake_data, dtype=float), nan=0.50)
    n = len(wake_data)

    def squared_jumps(w):
        # Sum of squared differences between consecutive epochs.
        return np.sum((w[1:] - w[:-1]) ** 2)

    def objective(w):
        return (np.sum(w * (1 - wake_data)) + np.sum((1 - w) * wake_data)
                + epsilon * squared_jumps(w))

    # Use ``is None`` so that an explicit 0 is honoured rather than replaced.
    max_clusters = n if max_sleep_clusters is None else max_sleep_clusters
    min_clusters = 0 if min_sleep_clusters is None else min_sleep_clusters

    def constraint1(x):
        return max_clusters - squared_jumps(x)  # geq 0

    def constraint2(x):
        return squared_jumps(x) - min_clusters  # geq 0

    print(
        f"The max clusters are {max_clusters} and the min clusters are {min_clusters}")

    all_bnds = [(0.0, 1.0)] * n

    # Only impose the bounds the caller asked for. With no bounds requested,
    # the optimization is identical to the unconstrained problem.
    all_cons = []
    if max_sleep_clusters is not None:
        all_cons.append({'type': 'ineq', 'fun': constraint1})
    if min_sleep_clusters is not None:
        all_cons.append({'type': 'ineq', 'fun': constraint2})

    sol = minimize(objective, wake_data, method='SLSQP', bounds=all_bnds,
                   constraints=all_cons)

    if makeplot:
        epochs = range(n)
        # Jitter the raw predictions vertically so overlapping points stay visible.
        plt.scatter(epochs, wake_data + 0.1 * np.random.randn(n), label="", color="blue")
        plt.plot(epochs, convert_binary(sol.x), lw=2.0, label="", color="red")
        plt.show()

    print(
        f"The max clusters are {max_clusters} takes value {constraint1(sol.x)}>=0.0")
    print(
        f"The min clusters are {min_clusters} and takes the value {constraint2(sol.x)}>=0.0")
    return convert_binary(sol.x)
