from typing import Literal, Tuple

from aldepyde.data import get_store, DataStore
from aldepyde.biomolecule import Residue, Atom

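# Determine whether a residue is DNA, RNA, or an amino acid:
# 1. Check the atom list for telling atoms (an O2' marks RNA; C/N/O backbone atoms mark an amino acid)
# 2. Otherwise, fall back on the length of the residue name (PDB/mmCIF naming convention)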
def residue_type(residue:Residue) -> Literal['amino_acid', 'dna', 'rna'] | None:
    """Classify a residue as an amino acid, a DNA nucleotide, or an RNA nucleotide.

    Atom labels (lower-cased via ``residue.dump_atom_labels(lower=True)``) are
    checked first:

    * any ``o2'`` atom (the 2' oxygen of a ribose sugar) -> ``'rna'``
    * otherwise any of ``c``, ``n``, ``o`` (peptide backbone atoms) -> ``'amino_acid'``

    No atom rule identifies DNA. If neither rule fires, the length of the residue
    name decides, following PDB/mmCIF naming: 3 characters -> ``'amino_acid'``,
    2 -> ``'dna'`` (e.g. ``DA``), 1 -> ``'rna'`` (e.g. ``A``). Any other
    3-character name (ligands, modified residues, ...) is therefore also reported
    as ``'amino_acid'``.

    :param residue: residue to classify.
    :return: ``'amino_acid'``, ``'dna'`` or ``'rna'``. Returns ``None`` when the
        atoms are inconclusive and the name is not 1-3 characters long.
    """
    # O2' marks a ribose, so the RNA check must run before the amino-acid check
    telling_rna = frozenset({"o2'"})
    telling_aa = frozenset({"c", "n", "o"})

    res_atoms = residue.dump_atom_labels(lower=True)
    if any(label in telling_rna for label in res_atoms):
        return 'rna'
    if any(label in telling_aa for label in res_atoms):
        return 'amino_acid'

    # This is technically robust by PDB/mmCIF standards.
    # NOTE(review): assumes residue.name carries no padding whitespace; a
    # fixed-width name such as '  A' would count as 3 characters.
    name_length = len(residue.name)
    if name_length == 3:
        return 'amino_acid'
    if name_length == 2:
        return 'dna'
    if name_length == 1:
        return 'rna'
    # Unrecognised name length: report "unknown" explicitly instead of falling off the end
    return None

def centroid(residue:Residue, res_type) -> Tuple[float, float, float] | None:
    """Return the unweighted geometric centre of a residue's centroid atoms.

    The labels to average are looked up with
    ``get_store().centroid_atoms(residue.name, res_type=res_type)`` and then
    resolved on the residue via ``residue.get_atoms_by_label``.

    :param residue: residue whose atoms are averaged.
    :param res_type: residue category forwarded to the data store (presumably a
        value produced by ``residue_type``; confirm against the store).
    :return: ``(x, y, z)`` mean coordinates, or ``None`` when no matching atoms
        were found.
    """
    wanted_labels = get_store().centroid_atoms(residue.name, res_type=res_type)
    members = residue.get_atoms_by_label(wanted_labels)
    n_members = len(members)
    if n_members == 0:
        # Empty selection: there is no centre to report
        return None
    return (
        sum(atom.x for atom in members) / n_members,
        sum(atom.y for atom in members) / n_members,
        sum(atom.z for atom in members) / n_members,
    )

# TODO: Implementing this properly will be surprisingly painful, since there is apparently no PDB
# standard for it. For now, only use it when the element was reliably read from a file or provided explicitly.
def derive_element(atom:Atom) -> str|None:
    """Return the element symbol stored on *atom*, normalised to title case.

    Only ``atom.element`` is consulted; nothing is inferred from the atom label.
    Surrounding whitespace is stripped first, because fixed-width columns can
    leave padding such as ``" C"``. Examples: ``"FE"`` -> ``"Fe"``, ``" C"`` -> ``"C"``.

    :param atom: atom whose ``element`` attribute is read.
    :return: the normalised symbol, or ``None`` when ``atom.element`` is ``None``
        or blank (previously ``None`` raised ``AttributeError`` despite the
        ``str|None`` annotation).
    """
    element = atom.element
    if element is None:
        return None
    element = element.strip()
    if not element:
        return None
    return element.title()

def get_element_mass(symbol:str) -> float:
    """Look up the atomic mass of an element symbol in the shared data store.

    :param symbol: element symbol, passed unchanged to the store's
        ``atomic_mass``. The casing the store expects is not visible here;
        ``derive_element`` produces title case.
    :return: whatever ``get_store().atomic_mass(symbol)`` returns.
    """
    store = get_store()
    mass = store.atomic_mass(symbol)
    return mass


def center_of_mass(residue:Residue, res_type) -> Tuple[float, float, float] | None:
    """Return the mass-weighted centre of a residue's centroid atoms.

    Uses the same atom selection as ``centroid``: labels from
    ``get_store().centroid_atoms(residue.name, res_type=res_type)``, resolved with
    ``residue.get_atoms_by_label``. Each atom is weighted by
    ``get_element_mass(derive_element(atom))``.

    :param residue: residue whose atoms are averaged.
    :param res_type: residue category forwarded to the data store (presumably a
        value produced by ``residue_type``; confirm against the store).
    :return: ``(x, y, z)`` mass-weighted mean coordinates, or ``None`` when no
        atoms were selected or their total mass is zero.
    """
    labels = get_store().centroid_atoms(residue.name, res_type=res_type)
    atoms = residue.get_atoms_by_label(labels)
    # Resolve each atom's mass exactly once
    weighted = [(atom, get_element_mass(derive_element(atom))) for atom in atoms]

    # Compute the denominator once, instead of once per coordinate as before
    total_mass = sum(mass for _, mass in weighted)
    if total_mass == 0:
        # Empty selection (or massless atoms): dividing would raise ZeroDivisionError
        return None

    x = sum(atom.x * mass for atom, mass in weighted) / total_mass
    y = sum(atom.y * mass for atom, mass in weighted) / total_mass
    z = sum(atom.z * mass for atom, mass in weighted) / total_mass
    return (x, y, z)

