from abc import ABC, abstractmethod
from typing import Iterable, List

from .polymer_specs import PolymerClassifier
from .polymer_constraints import PolymerConstraints
import random
import os
from aldepyde.distribution.distribution import Distribution, load
from aldepyde.distribution.distributions.standards import distribution_head
import json
import math
from itertools import combinations

class ResidueGenerator(ABC):
    def __init__(self, classifier:PolymerClassifier):
        """Bind the generator to a polymer classifier.

        Args:
            classifier: Object stored as ``self.classifier``. Its ``alphabet``
                attribute is read by the random-generation and mutation
                methods of this class as the default residue pool.
        """
        self.classifier = classifier

    def _generate(self):
        """Placeholder hook; this base implementation does nothing and returns ``None``."""
        pass

    # Just pure, random generation. Nothing fancy
    def random(self, length:int, n:int=1) -> list:
        alphabet = list(self.classifier.alphabet)
        sequences = []
        for i in range(n):
            sequences.append("".join(random.choices(alphabet, k=length)))
        return sequences

    def _list_remove(self, l, element) -> list:
        new_list = []
        for val in l:
            if val != element:
                new_list.append(val)
        return new_list


    # Actual magic
    def _gosper(self, n, k):
        if k < 0 or k > n:
            return
        x = (1 << k) - 1
        limit = 1 << n
        while x < limit:
            yield format(x, f"0{n}b")
            c = x & -x
            r = x + c
            x = (((r ^ x) >> 2) // c) | r


    def possible_mutations(self, sequence:str|Iterable[str], n=1, mutable:Iterable|None=None, alphabet:Iterable|None=None, strict:bool=False):
        r_alphabet = list(self.classifier.alphabet) if alphabet is None else alphabet
        mutable_sequence = sequence if mutable is None else "".join([x for x in sequence if x in mutable])
        total = 0
        for binary in self._gosper(len(mutable_sequence), n):
            if not binary.count('1') == n:
                continue
            subproduct = 1
            for index in [i for i, digit in enumerate(binary) if digit == '1']:
                subproduct *= len(r_alphabet) if mutable_sequence[index] not in r_alphabet else len(r_alphabet) - 1
            total += subproduct
        return total


    def random_mutation(self, sequence:str, k:int=1, n:int=1, mutable:Iterable|None=None, alphabet:Iterable|None=None, strict:bool=False) -> List[str]:
        r_alphabet = list(self.classifier.alphabet) if alphabet is None else alphabet
        # Assume an iterable
        rets = set()
        allowed_indeces = range(0, len(sequence)) if mutable is None else [x for x, val in enumerate(sequence) if
                                                                           val in mutable]
        max_mutations = self.possible_mutations(sequence, n=n, mutable=mutable, alphabet=alphabet)
        target_number = min([k, max_mutations])
        while len(rets) < target_number:
            s = list(sequence)
            if strict:
                k_n = n
            else:
                k_n = n if n <= len(allowed_indeces) else len(allowed_indeces)
            for index in  random.sample(allowed_indeces, k_n):
                new_res = random.choice(self._list_remove(r_alphabet, s[index]))
                s[index] = new_res
            rets.add("".join(s))
        return list(rets)

    def random_from_distribution(self, distribution: os.PathLike|str|dict[str, float]|Distribution, length:int|tuple[int,int], n:int=1, exclude:tuple=()) -> list:
        if isinstance(distribution, str) and os.path.isfile(distribution): # Given a file
            with open(distribution) as fp:
                data = json.load(fp)
                if distribution_head not in data.keys():
                    raise ValueError(f"Key '{distribution_head}' must be in the top level of your provided file")
            usable_distribution = Distribution(None, data[distribution_head])
        elif isinstance(distribution, str) and not os.path.isfile(distribution): # Given a preload
            usable_distribution = load(distribution)
        elif isinstance(distribution, Distribution): # Given a distribution
            usable_distribution = distribution
        elif isinstance(distribution, dict): # Given a dictionary
            usable_distribution = Distribution(None, distribution)
        else:
            raise ValueError('Invalid distribution')
        for r in exclude:
            usable_distribution.eliminate_entry(r)
        if not isinstance(length, int):
            size = lambda: random.randint(length[0], length[1])
        else:
            size = lambda: length
        usable_distribution.normalize_map()
        sequences = []
        for _ in range(n):
            sequence = random.choices(population=list(usable_distribution.frequency_map.keys()), weights=list(usable_distribution.frequency_map.values()), k=size())
            sequences.append("".join(sequence))
        return sequences

    def random_mutation_from_distribution(self, sequences:str|Iterable[str], distribution: os.PathLike|str|dict[str, float]|Distribution, k:int=1, exclude:tuple=()) -> list:
        if isinstance(sequences, str):
            all_values = [sequences]
        else:
            all_values = sequences

        if isinstance(distribution, str) and os.path.isfile(distribution): # Given a file
            with open(distribution) as fp:
                data = json.load(fp)
                if distribution_head not in data.keys():
                    raise ValueError(f"Key '{distribution_head}' must be in the top level of your provided file")
            usable_distribution = Distribution(None, data[distribution_head])
        elif isinstance(distribution, str) and not os.path.isfile(distribution): # Given a preload
            usable_distribution = load(distribution)
        elif isinstance(distribution, Distribution): # Given a distribution
            usable_distribution = distribution
        elif isinstance(distribution, dict): # Given a dictionary
            usable_distribution = Distribution(None, distribution)
        else:
            raise ValueError('Invalid distribution')
        for r in exclude:
            usable_distribution.eliminate_entry(r)
        usable_distribution.normalize_map()

        rets = []
        for sequence in all_values:
            s = list(sequence)
            for index in  random.sample(range(0, len(sequence)), k):
                while True:
                    selection = random.choices(population=list(usable_distribution.frequency_map.keys()),
                                   weights=list(usable_distribution.frequency_map.values()), k=1)[0]
                    if selection != s[index]:
                        s[index] = selection
                        break
            rets.append("".join(s))
        return rets


    @abstractmethod
    def generate(self, length:int, k:int, constraints:PolymerConstraints) -> List[str]:
        expected_per_class = constraints.expected_per_class(length)
        class_distinctions = constraints.classes
        results = []
        for _ in range(k):
            new_sequence = []
            for key in expected_per_class:
                new_sequence += random.choices(tuple(class_distinctions[key]), k=expected_per_class[key])
            random.shuffle(new_sequence)
            results.append("".join(new_sequence))
        return results
        # for _ in range(k):
        #     for key in expected:

