import math
from abc import ABC, abstractmethod
from dataclasses import dataclass


@dataclass
class FrequencyClass:
    """Record pairing a class name with its members and a frequency value."""

    # Identifier of the class.
    class_name: str
    # Members belonging to the class, stored as a single string.
    class_members: str
    # Frequency value associated with the class.
    frequency: float

# Define constraints when generating a polymer
class PolymerConstraints(ABC):
    """Class-ratio and charge constraints used when generating a polymer.

    ``classes`` maps a class name to a collection of its members, and
    ``constraint_ratios`` maps class names to a relative weight.  A class
    without members always normalizes to a ratio of 0.
    """

    def __init__(self, classes: dict = None, charge: int = 0, charge_tolerance: int = 0, ratio_tolerance: int = 0):
        """Store the constraints and start with equal ratios for all classes.

        Args:
            classes: Mapping of class name to its members.  When ``None``,
                the result of ``get_default_classes()`` is used.
            charge: Stored as ``self.charge``.
            charge_tolerance: Stored as ``self.charge_tolerance``.
            ratio_tolerance: Stored as ``self.ratio_tolerance``.
        """
        # Passing None used to crash in new_constraint_ratios(); fall back to
        # the subclass-provided defaults instead.
        self.classes = self.get_default_classes() if classes is None else classes
        self.constraint_ratios = self.new_constraint_ratios()
        self.charge = charge
        self.charge_tolerance = charge_tolerance
        self.ratio_tolerance = ratio_tolerance

    @abstractmethod
    def get_default_classes(self):
        """Return the class table to use when none is passed to ``__init__``."""

    def new_constraint_ratios(self) -> dict:
        """Return a normalized table giving every class the same weight."""
        return self.normalize_table({name: 1 for name in self.classes})

    def set_constraint_classes(self, table: dict):
        """Replace the class table; ``constraint_ratios`` is left untouched."""
        self.classes = table

    def normalize_table(self, table: dict) -> dict:
        """Return a copy of ``table`` scaled so populated classes sum to 1.

        Classes without members, and keys of ``table`` that are not present
        in ``self.classes``, get a ratio of 0.  When the populated classes
        carry no weight at all there is nothing to scale, so an all-zero
        table is returned instead of dividing by zero.

        Raises:
            KeyError: If a populated class is missing from ``table``.
        """
        populated = [name for name, members in self.classes.items() if len(members) > 0]
        total = sum(table[name] for name in populated)
        if total == 0:
            return {name: 0 for name in table}
        populated_names = set(populated)
        return {
            name: (weight / total) if name in populated_names else 0
            for name, weight in table.items()
        }

    def normalize(self):
        """Normalize ``constraint_ratios`` in place."""
        self.constraint_ratios = self.normalize_table(self.constraint_ratios)

    def set_constraint_ratios(self, table: dict):
        """Replace ``constraint_ratios`` with the normalized ``table``."""
        self.constraint_ratios = self.normalize_table(table)

    def set_class_ratio(self, class_name: str, value: float, normalize: bool = False):
        """Set one class weight, optionally re-normalizing the whole table."""
        self.constraint_ratios[class_name] = value
        if normalize:
            self.normalize()

    def hard_set(self, class_name: str, percentage: float):
        """Pin ``class_name`` to ``percentage`` and rescale the other classes.

        The remaining classes keep their relative proportions and share the
        remaining ``1 - percentage``.

        Raises:
            ValueError: If ``percentage`` is outside the range 0 to 1.
        """
        if percentage < 0 or percentage > 1:
            raise ValueError("Percentage must be between 0 and 1.0")
        # Zero the target first so the other classes normalize among themselves.
        self.constraint_ratios[class_name] = 0
        self.normalize()
        remainder = 1 - percentage
        self.constraint_ratios = {
            name: ratio * remainder for name, ratio in self.constraint_ratios.items()
        }
        self.constraint_ratios[class_name] = percentage

    def combine(self, class_1: str, class_2: str, new_class: str, ratio: float = -1):
        """Merge two classes into ``new_class`` holding the union of members.

        Args:
            class_1: Name of the first class to merge.
            class_2: Name of the second class to merge.
            new_class: Name of the merged class.
            ratio: Weight stored for the merged class.  The default ``-1``
                means "no weight given": every ratio is reset to 1 via
                ``wipe()``.

        Raises:
            KeyError: If either class is unknown or both names are the same.
        """
        # Validate before mutating so a bad name cannot leave the tables
        # half-updated.
        for name in (class_1, class_2):
            if name not in self.classes or name not in self.constraint_ratios:
                raise KeyError(name)
        if class_1 == class_2:
            raise KeyError(class_2)

        members_1 = self.classes.pop(class_1)
        members_2 = self.classes.pop(class_2)
        self.constraint_ratios.pop(class_1)
        self.constraint_ratios.pop(class_2)
        self.constraint_ratios[new_class] = ratio
        self.classes[new_class] = frozenset().union(members_1, members_2)
        if ratio == -1:
            self.wipe()

    def alphabet(self) -> set:
        """Return the union of the members of every class."""
        return set().union(*self.classes.values())

    # TODO: get this working

    # def remove(self, name:str):
    #     self.classes.pop(name)
    #     self.constraint_ratios = self.normalize_table(self.constraint_ratios)

    def negate(self, name: str):
        """Set the ratio of ``name`` to 0 and re-normalize the rest."""
        self.constraint_ratios[name] = 0
        self.normalize()

    def wipe(self):
        """Reset every ratio to 1 (the table is not normalized afterwards)."""
        self.constraint_ratios = {name: 1 for name in self.constraint_ratios}

    def clear_all(self):
        """Set every ratio to 0."""
        self.constraint_ratios = {name: 0 for name in self.constraint_ratios}

    def expected_per_class(self, length):
        """Split ``length`` residues across classes according to the ratios.

        Each class first receives the floor of its share; the residues still
        unassigned are handed out one at a time, starting with the classes
        that have the largest fractional share.

        Returns:
            dict mapping class name to an integer count; the counts sum to
            ``length``.

        Raises:
            ValueError: If ``length`` is positive but every ratio is zero,
                so there is no class to assign residues to.
        """
        ratios = self.normalize_table(self.constraint_ratios)
        if length > 0 and not any(ratios.values()):
            raise ValueError("Cannot distribute residues: every class ratio is zero")

        # Largest fractional part first; ties keep table order (stable sort).
        shares = sorted(
            ([name, length * ratio] for name, ratio in ratios.items()),
            key=lambda share: share[1] % 1,
            reverse=True,
        )
        counts = [[name, math.floor(amount)] for name, amount in shares]
        assigned = sum(count for _, count in counts)
        index = 0
        while assigned < length:
            counts[index][1] += 1
            assigned += 1
            index += 1
        return {name: count for name, count in counts}
