from importlib.resources import files
from typing import Literal
import random
import json
from aldepyde.data.data_exceptions import AmbiguousTokenException, DataAccessException, BadCategoryException

# Package whose bundled resources contain the chemistry.json file loaded by DataStore.
_PACKAGE = 'aldepyde.data'

#TODO Clean up your exceptions. This is a tad excessive as is
class DataStore():
    def __init__(self):
        path = files(_PACKAGE) / f"chemistry.json"
        self._raw = json.loads(path.read_text(encoding='utf-8'))

    def get_raw(self):
        return self._raw

    def get(self, *args:str, default=None):
        current_level = self.get_raw()
        try:
            for arg in args:
                current_level = current_level[arg]
            return current_level
        except KeyError as e:
            return default

    def require(self, *args:str):
        current_level = self._raw
        attempted = []
        try:
            for arg in args:
                attempted.append(arg)
                current_level = current_level[arg]
            return current_level
        except KeyError as e:
            raise DataAccessException(f"\n\nInvalid data access path: {' | '.join(attempted)}\nValid "
                                      f"keys at this level are:\n {', '.join(list(current_level.keys()))}")

    def has(self, *args:str) -> bool:
        current_level = self._raw
        try:
            for arg in args:
                current_level = current_level[arg]
            return True
        except KeyError as e:
            return False

    def is_value(self, *args:str) -> bool:
        current_level = self._raw
        attempted = []
        try:
            for arg in args:
                attempted.append(arg)
                current_level = current_level[arg]
            return not isinstance(current_level, dict)
        except KeyError as e:
            raise DataAccessException(f"\n\nInvalid data access path: {' | '.join(attempted)}\nValid "
                                      f"keys at this level are:\n {', '.join(list(current_level.keys()))}")

    def is_key(self, *args:str) -> bool:
        return not self.is_value(*args)

    def keys(self, *args:str, strict=True) -> list | None:
        current_level = self._raw
        attempted = []
        try:
            for arg in args:
                attempted.append(arg)
                current_level = current_level[arg]
            if isinstance(current_level, dict):
                return list(current_level.keys())
            else:
                return None
        except KeyError as e:
            raise DataAccessException(f"\n\nInvalid data access path: {' | '.join(attempted)}\nValid "
                                      f"keys at this level are:\n {', '.join(list(current_level.keys()))}")

    def all_residues(self, res_type) -> list[str]:
        category = self.categorize(res_type)
        return list(self.require(category).keys())


    canon_aa = 'amino_acid'
    canon_dna = 'dna'
    canon_rna = 'rna'
    canon_ele = 'element'
    categories = {
        'aa' : canon_aa,
        'a' : canon_aa,
        'amino acid' : canon_aa,
        'amino_acid' : canon_aa,
        'dna' : canon_dna,
        'd' : canon_dna,
        'rna' : canon_rna,
        'r' : canon_rna,
        'element' : canon_ele,
        'atom' : canon_ele,
        'e' : canon_ele
    }

    def _category_exception_grouper(self):
        d = {}
        for k, v in self.categories.items():
            d[v] = d.get(v, []) + [k]
        ret_str = ""
        for k, v in d.items():
            ret_str += f"\t{v} -> {k}\n"
        return ret_str


    def categorize(self, res_type:str) -> str:
        try:
            normalized = self.categories[res_type.lower().strip()]
            return normalized
        except KeyError:
            raise BadCategoryException(f"\n\nUnresolvable residue type {res_type}\nAvailable residue types:\n{self._category_exception_grouper()}")


    def normalize_residue(self, name:str, res_type:str='aa') -> str | None:
        if res_type is None:
            return None
        normalized = self.require('map', self.categorize(res_type), name.lower().strip())
        if isinstance(normalized, list):
            raise AmbiguousTokenException(f'\n\nUnable to resolve normalize name: {name} to a single residue\n{name} -> {normalized}\n'
                                          f'Use resolve_residues() when ambiguous tokens are expected.')
        return normalized


    def resolve_residues(self, name:str, res_type:str='aa')->tuple:
        normalized = self.require('map', self.categorize(res_type), name.lower().strip())
        if isinstance(normalized, list):
            return tuple(normalized)
        return (normalized,)

    def get_entry(self, name, res_type='aa', strict=True, resolve_ambiguous:Literal['none', 'first', 'last', 'random']="none") -> dict|list[dict]|None:
        category = self.categorize(res_type)
        if strict:
            name = self.normalize_residue(name, category)
            return self.require(category, name)
        else:
            name = self.resolve_residues(name, category)
            if len(name) == 1:
                return self.get(category, name[0])
            elif len(name) == 0:
                return None
            if resolve_ambiguous == 'none':
                return self.get(category, name[0])
            elif resolve_ambiguous == 'first':
                return self.get(category, name[0])
            elif resolve_ambiguous == 'last':
                return self.get(category, name[-1])
            elif resolve_ambiguous == 'random':
                return self.get(category, name[random.randint(0, len(name) - 1)])
            else:
                raise ValueError(f"Invalid value {resolve_ambiguous} for parameter resolve_ambiguous=['none', 'first', 'last', 'random']")

    def get_entries(self, name, res_type='aa') -> tuple[dict]:
        category = self.categorize(res_type)
        name = self.resolve_residues(name, category)
        ret_lst = []
        for n in name:
            ret_lst.append(self.get(category, n))
        return tuple(ret_lst)


    def _map(self, category):
        return self.require('map', self.categories[category])

    def is_aa(self, name:str) -> bool:
        return name.lower() in self.require('map', 'amino_acid')

    def is_dna(self, name:str) -> bool:
        return name.lower() in self.require('map', 'dna')

    def is_rna(self, name:str) -> bool:
        return name.lower() in self.require('map', 'rna')

    def is_element(self, name:str) -> bool:
        return name.lower() in self.require('map', 'element')

    # What res_types does a token exist in?
    def possible_categories(self, token):
        return [key for key in self.require('map') if token.lower() in self.require('map')[key].keys()]

    def get_all_possible_entries(self, name, strict=True) -> tuple:
        entries = list()
        for res_type in self.possible_categories(name):
            current_name = self.normalize_residue(name, res_type=res_type)
            if strict:
                entries.append(self.require(res_type, current_name))
            else:
                entries.append(self.get(res_type, current_name))
        return tuple(entries)

    def one_letter(self, res, res_type='aa', strict=True):
        return self.get_entry(name=res, res_type=res_type, strict=strict)['1code']

    def dna_pdb_code(self, res, strict=True):
        return self.get_entry(res, self.canon_dna, strict=strict)['pdb_code']

    def three_letter(self, res, strict=True):
        return self.get_entry(res,self.canon_aa, strict=strict)['3code']

    def atom_list(self, res, res_type='aa', strict=True):
        return self.get_entry(name=res, res_type=res_type, strict=strict)['atoms']

    def centroid_atoms(self, res, res_type='aa', strict=True) -> list:
        return self.get_entry(name=res, res_type=res_type, strict=strict)['centroid']

    def formula(self, res, res_type='aa', strict=True):
        return self.get_entry(name=res, res_type=res_type, strict=strict)['formula']

    # TODO Handle unnatural elements
    def atomic_mass(self, name):
        return self.get_entry(name=name, res_type=self.canon_ele, strict=True)['molar mass']

    def atomic_number(self, name):
        return self.get_entry(name=name, res_type=self.canon_ele,strict=True)['number']

    def citations(self):
        return self.require("Sources")

    def thanks(self):
        return self.require("Special Thanks and Acknowledgment")


# Module-level cache for the shared DataStore instance; None until get_store() creates it.
_STORE = None

def get_store():
    """Return the shared DataStore, creating it on the first call.

    The instance is kept in the module-level ``_STORE`` variable, so later
    calls return the same object without constructing a new one.
    """
    global _STORE
    if _STORE is not None:
        return _STORE
    _STORE = DataStore()
    return _STORE

def del_store() -> None:
    """Discard the cached DataStore by resetting the module-level ``_STORE`` to None."""
    global _STORE
    _STORE = None