from typing import List, Tuple, Optional, Dict
from dataclasses import dataclass, field
import math


from aldepyde.data import get_store

@dataclass
class Atom():
    """A single atom record: identifying labels, a 3-D coordinate and metadata.

    The coordinate is stored once in ``coord``; ``location``, ``x``, ``y``
    and ``z`` are read-only views onto it.
    """
    serial: str
    name: str
    coord: Tuple[float, float, float]
    occupancy: float
    b_factor: float
    is_het: bool
    element: Optional[str]
    charge: Optional[str]
    altloc: Optional[str]
    pseudo: bool = False

    @property
    def location(self):
        """Alias for ``coord``."""
        return self.coord

    @property
    def x(self):
        """First component of ``coord``."""
        return self.coord[0]

    @property
    def y(self):
        """Second component of ``coord``."""
        return self.coord[1]

    @property
    def z(self):
        """Third component of ``coord``."""
        return self.coord[2]

    def distance_to(self, other:'Atom') -> float:
        """Return the Euclidean distance between this atom and *other*.

        Only the first three components of each atom's ``location`` are used.
        """
        here = self.location
        there = other.location
        dx = here[0] - there[0]
        dy = here[1] - there[1]
        dz = here[2] - there[2]
        return math.sqrt(dx**2 + dy**2 + dz**2)


@dataclass
class Residue():
    """A residue: a name, an identifier and the atoms that belong to it.

    ``str()`` and ``repr()`` both render as ``"<name>:<id>"``.
    """
    name: str
    id: str
    ins_code: Optional[str]
    atoms: List[Atom] = field(default_factory=list)
    # Recomputed unconditionally in __post_init__, so a value passed to the
    # constructor is overwritten.
    residue_type: str|None = None

    def __post_init__(self):
        """Set ``residue_type`` to the result of ``residue_ops.residue_type(self)``."""
        # Function-scope import, like every other use of residue_ops in this
        # class (presumably to avoid a circular import -- confirm).
        from aldepyde.Ops.residue_ops import residue_type
        self.residue_type = residue_type(self)

    def add_atom(self, atom: Atom) -> None:
        """Append *atom* to this residue's atom list."""
        self.atoms.append(atom)

    def dump_atom_labels(self, lower=False) -> List[str]:
        """Return the atom names in stored order, lower-cased when *lower* is true."""
        if lower:
            return [a.name.lower() for a in self.atoms]
        return [a.name for a in self.atoms]

    def get_atoms_by_label(self, labels:str|tuple) -> Tuple[Atom,...]:
        """Return the atoms whose name matches one of *labels*, ignoring case.

        Args:
            labels: Either a whitespace-separated string of atom names or an
                iterable of atom names.

        Returns:
            The matching atoms, in the order they are stored on the residue.
        """
        if isinstance(labels, str):
            labels = labels.split()
        # Atom names are compared lower-cased, so the requested labels must be
        # folded the same way; otherwise "CA" could never match an atom
        # named "CA".
        wanted = {label.lower() for label in labels}
        return tuple(atom for atom in self.atoms if atom.name.lower() in wanted)

    def _effective_type(self, as_residue_type):
        """Return *as_residue_type* unless it is None, else ``self.residue_type``."""
        if as_residue_type is not None:
            return as_residue_type
        return self.residue_type

    @staticmethod
    def _placeholder_atom(coord) -> Atom:
        """Wrap *coord* in the dummy atom returned by the ``as_atom=True`` paths."""
        # NOTE(review): ``pseudo`` is left at its default (False); confirm
        # whether these placeholder atoms should set it.
        return Atom("-1", "Dum", coord, occupancy=0, b_factor=0, is_het=True,
                    element=None, charge=None, altloc=None)

    def centroid(self, as_atom=False, as_residue_type=None) -> Tuple[float, float, float]|Atom:
        """Return the result of ``residue_ops.centroid`` for this residue.

        Args:
            as_atom: When true, return the result wrapped in a dummy ``Atom``
                (serial "-1", name "Dum") instead of the bare value.
            as_residue_type: Residue type passed to ``residue_ops.centroid``;
                ``self.residue_type`` is used when this is None.

        Note that the keyword order differs from ``center_of_mass``; pass
        these arguments by name.
        """
        from aldepyde.Ops.residue_ops import centroid
        point = centroid(self, self._effective_type(as_residue_type))
        return self._placeholder_atom(point) if as_atom else point

    def center_of_mass(self, as_residue_type=None, as_atom=False):
        """Return the result of ``residue_ops.center_of_mass`` for this residue.

        Args:
            as_residue_type: Residue type passed to
                ``residue_ops.center_of_mass``; ``self.residue_type`` is used
                when this is None.
            as_atom: When true, return the result wrapped in a dummy ``Atom``
                (serial "-1", name "Dum") instead of the bare value.

        Note that the keyword order differs from ``centroid``; pass these
        arguments by name.
        """
        from aldepyde.Ops.residue_ops import center_of_mass
        point = center_of_mass(self, self._effective_type(as_residue_type))
        return self._placeholder_atom(point) if as_atom else point

    def yield_atoms(self):
        """Yield this residue's atoms in stored order."""
        yield from self.atoms

    def dump_atoms(self):
        """Return a new list containing this residue's atoms in stored order."""
        return list(self.yield_atoms())

    def __str__(self):
        return f"{self.name}:{self.id}"

    def __repr__(self):
        return f"{self.name}:{self.id}"

@dataclass
class Chain():
    """A chain: an identifier plus its residues, stored in a dict keyed by residue id."""
    id: str
    residues: Dict[str, Residue] = field(default_factory=dict)

    def add_residue(self, residue: Residue) -> None:
        """Store *residue* under ``residue.id``, replacing any residue already stored under that id."""
        # NOTE(review): the key is ``residue.id`` alone; if ids do not already
        # embed the insertion code, residues differing only by ``ins_code``
        # would overwrite each other -- confirm against the parser.
        self.residues[residue.id] = residue

    def get_residue(self, id) -> Residue:
        """Return the residue stored under *id*.

        If *id* itself is not a key, ``str(id)`` is tried, so an integer can
        be used to look up a residue stored under its string form.

        Raises:
            KeyError: If neither *id* nor ``str(id)`` is a key.
        """
        try:
            return self.residues[id]
        except KeyError:
            return self.residues[str(id)]

    def yield_atoms(self):
        """Yield every atom of every residue, in residue insertion order."""
        for residue in self.residues.values():
            yield from residue.yield_atoms()

    def dump_atoms(self) -> List[Atom]:
        """Return a list of every atom yielded by ``yield_atoms``."""
        return list(self.yield_atoms())

    def __str__(self):
        # str() keeps this valid even when the id is not a str (dataclasses do
        # not enforce annotations); __str__ must return a str.
        return str(self.id)

    def __repr__(self):
        # Same guard as __str__: __repr__ raising TypeError would break
        # debugging and printing of containers.
        return str(self.id)

@dataclass
class Model:
    """A model: an identifier plus its chains, stored in a dict keyed by chain id."""
    id: str
    chains: Dict[str, Chain] = field(default_factory=dict)

    def add_chain(self, chain: Chain) -> None:
        """Store *chain* under ``chain.id``, replacing any chain already stored under that id."""
        self.chains[chain.id] = chain

    def get_chain(self, id) -> Chain:
        """Return the chain stored under *id*.

        If *id* itself is not a key, ``str(id)`` is tried as a fallback.

        Raises:
            KeyError: If neither *id* nor ``str(id)`` is a key.
        """
        try:
            return self.chains[id]
        except KeyError:
            return self.chains[str(id)]

    def yield_atoms(self):
        """Yield every atom of every chain, in chain insertion order."""
        for chain in self.chains.values():
            yield from chain.yield_atoms()

    def dump_atoms(self) -> List[Atom]:
        """Return a list of every atom yielded by ``yield_atoms``."""
        return list(self.yield_atoms())

    def __str__(self):
        # str() keeps this valid even when the id is not a str (dataclasses do
        # not enforce annotations); __str__ must return a str.
        return str(self.id)

    def __repr__(self):
        # Same guard as __str__: __repr__ raising TypeError would break
        # debugging and printing of containers.
        return str(self.id)

@dataclass
class Structure:
    """A structure: an identifier plus its models, stored in a dict keyed by model id."""
    id: str = "Unnamed Structure"
    # title: str = ""
    models: Dict[str, Model] = field(default_factory=dict)

    def add_model(self, model: Model) -> None:
        """Store *model* under ``model.id``, replacing any model already stored under that id."""
        self.models[model.id] = model

    def get_model(self, id) -> Model:
        """Return the model stored under *id*.

        If *id* itself is not a key, ``str(id)`` is tried as a fallback.

        Raises:
            KeyError: If neither *id* nor ``str(id)`` is a key.
        """
        try:
            return self.models[id]
        except KeyError:
            return self.models[str(id)]

    def yield_atoms(self):
        """Yield every atom of every model, in model insertion order."""
        for model in self.models.values():
            yield from model.yield_atoms()

    def dump_atoms(self) -> List[Atom]:
        """Return a list of every atom yielded by ``yield_atoms``."""
        return list(self.yield_atoms())

    def __str__(self):
        # str() keeps this valid even when the id is not a str (dataclasses do
        # not enforce annotations); __str__ must return a str.
        return str(self.id)

    def __repr__(self):
        # Same guard as __str__: __repr__ raising TypeError would break
        # debugging and printing of containers.
        return str(self.id)