"""
Contains the File class and related functions.
"""

# The docstring above must be the module's first statement to become __doc__.
__author__ = "Vini Salazar"
__license__ = "MIT"
__maintainer__ = "Vini Salazar"
__url__ = "https://github.com/vinisalazar/bioprov"
__version__ = "0.1.6"
import numpy as np
from dataclasses import dataclass
from pathlib import Path
from Bio import SeqIO, AlignIO
from bioprov.utils import get_size, Warnings, serializer
from prov.model import ProvEntity


class File:
    """
    Class for holding file and file information.

    Path-derived attributes (name, basename, directory, extension) and the sizes
    are computed once at construction; ``exists`` is re-checked on every access.
    """

    def __init__(self, path, tag=None, document=None, attributes=None):
        """
        :param path: A UNIX-like file path. Stored as an absolute pathlib.Path.
        :param tag: optional tag describing the file. Defaults to the file stem.
        :param document: prov.model.ProvDocument
        :param attributes: Miscellaneous attributes. Defaults to a new empty dict.
        """
        self.path = Path(path).absolute()
        self.name = self.path.stem
        self.basename = self.path.name
        self.directory = self.path.parent
        self.extension = self.path.suffix
        self.tag = self.name if tag is None else tag
        # A fresh dict per instance, so instances never share one mutable default.
        self.attributes = {} if attributes is None else attributes
        self._exists = self.path.exists()
        self.size = get_size(self.path)
        self.raw_size = get_size(self.path, convert=False)

        # Provenance attributes
        self._document = document
        self._entity = None

    def __repr__(self):
        return str(self.path)

    def __str__(self):
        return self.__repr__()

    @property
    def exists(self):
        """Whether the file exists right now (checked on the filesystem on every access)."""
        return self.path.exists()

    @exists.setter
    def exists(self, value):
        # Only stored in ``_exists``; the getter above ignores it and re-checks the path.
        self._exists = value

    @property
    def size(self):
        """Size as computed by bioprov.utils.get_size at construction (or as set)."""
        return self._size

    @size.setter
    def size(self, value):
        self._size = value

    @property
    def document(self):
        """The prov document associated with this file (may be None)."""
        return self._document

    @document.setter
    def document(self, document):
        self._document = document

    @property
    def entity(self):
        """
        The ProvEntity for this file, identified as ``files:<basename>``.

        Created on first access and cached, so later accesses return the same object.
        """
        if self._entity is None:
            self._entity = ProvEntity(
                self._document, identifier="files:{}".format(self.basename)
            )
        return self._entity

    def serializer(self):
        # Resolves to the module-level bioprov.utils.serializer, not this method.
        return serializer(self)


class SeqFile(File):
    """
    Class for holding sequence file and sequence information. Inherits from File.

    This class supports records parsed with the BioPython SeqIO module (or AlignIO,
    when ``parser="align"``).
    """

    seqfile_formats = (
        "fasta",
        "clustal",
        "fastq",
        "fastq-sanger",
        "fastq-solexa",
        "fastq-illumina",
        "genbank",
        "gb",
        "nexus",
        "stockholm",
        "swiss",
        "tab",
        "qual",
    )

    def __init__(
        self,
        path,
        tag=None,
        format="fasta",
        parser="seq",
        document=None,
        import_records=False,
        calculate_seqstats=False,
    ):
        """
        :param path: A UNIX-like file path.
        :param tag: optional tag describing the file.
        :param format: Format to be parsed by SeqIO.parse(). Case-insensitive; must be
                       one of SeqFile.seqfile_formats.
        :param parser: Bio parser to be used. Can be 'seq' (default) to be parsed by SeqIO or 'align'
                     to be parsed with AlignIO.
        :param document: prov.model.ProvDocument.
        :param import_records: Whether to import sequence data as Bio objects.
                               Implies calculate_seqstats.
        :param calculate_seqstats: Whether to calculate SeqStats. Records are imported
                                   first if necessary.
        """
        format_l = format.lower()
        # Validate the normalised value so that e.g. "FASTA" is accepted.
        assert format_l in SeqFile.seqfile_formats, Warnings()["choices"](
            format, "format", SeqFile.seqfile_formats
        )
        super().__init__(path, tag, document)
        self.format = format_l
        self.records = None
        self._generator = None
        self._seqstats = None
        self._parser = parser
        # Annotations only (no values assigned here); _calculate_seqstats sets them.
        self.number_seqs: int
        self.total_bps: float
        self.mean_bp: float
        self.min_bp: int
        self.max_bp: int
        self.N50: float
        self.GC: float

        if self.exists:
            self._seqrecordgenerator()
        else:
            import_records = False
            calculate_seqstats = False

        if import_records:
            self.import_records()
            calculate_seqstats = True

        if calculate_seqstats:
            # _calculate_seqstats reads self.records itself (importing them if needed).
            # Passing records positionally would bind them to ``calculate_gc``.
            self._seqstats = self._calculate_seqstats()

    def _seqrecordgenerator(self):
        """
        Create a fresh record generator for this file's path, format and parser.
        """
        self._generator = seqrecordgenerator(
            self.path, format=self.format, parser=self._parser
        )

    @property
    def generator(self):
        """Record generator for the file, created on first access if missing."""
        if self._generator is None:
            self._seqrecordgenerator()
        return self._generator

    @generator.setter
    def generator(self, value):
        self._generator = value

    @property
    def seqstats(self):
        """SeqStats for the records; calculated (importing records if needed) on first access."""
        if self._seqstats is None:
            self._seqstats = self._calculate_seqstats()
        return self._seqstats

    @seqstats.setter
    def seqstats(self, value):
        self._seqstats = value

    def import_records(self):
        """
        Parse the whole file into ``self.records`` using SeqIO.to_dict.
        """
        assert self.exists, "Cannot import, file does not exist."
        # Always start from a new generator. The old one may be None (the file was
        # created after __init__) or already consumed by an earlier import, which
        # would silently give an empty dict.
        self._seqrecordgenerator()
        self.records = SeqIO.to_dict(self._generator)

    def serializer(self):
        """
        Serialize a copy of this object's attributes with bioprov.utils.serializer.

        Dict-valued records are reduced to a list of their descriptions; the
        instance itself is left untouched.
        """
        # Shallow copy: editing self.__dict__ directly would overwrite self.records
        # and delete attributes from the live object.
        serial_out = dict(self.__dict__)
        records = serial_out.get("records")
        if records is not None:
            if isinstance(records, dict):
                serial_out["records"] = [record.description for record in records.values()]
            else:
                serial_out["records"] = str(records)

        for key in ("namespace", "ProvEntity"):
            serial_out.pop(key, None)
        # Resolves to the module-level bioprov.utils.serializer, not this method.
        return serializer(serial_out)

    def _calculate_seqstats(
        self, calculate_gc=True, megabases=False, percentage=False, decimals=5,
    ):
        """
        Calculate sequence statistics from ``self.records`` and store them.

        Records are imported first if they have not been. The resulting SeqStats is
        stored in ``self._seqstats``, and each of its fields is also set as an
        attribute on this instance (e.g. ``self.N50``).

        :param calculate_gc: Whether to calculate GC content. Disabled if the first
                             record contains any of the letters treated as amino acids.
        :param megabases: Whether to convert the total number of base pairs to megabases.
        :param percentage: Whether to convert GC content to percentage (value * 100)
        :param decimals: Number of decimals to round the mean length and GC content to.
        :return: SeqStats instance.
        :raises ValueError: if there are no records.
        """
        if self.records is None:
            self.import_records()
        assert isinstance(self.records, dict), Warnings()["incorrect_type"](
            self.records, dict
        )
        if not self.records:
            raise ValueError("Cannot calculate SeqStats: no sequence records found.")

        # Any of these letters in the first record disables the GC calculation.
        # NOTE(review): several (e.g. N, R, Y, K, M, S, W) are also IUPAC nucleotide
        # ambiguity codes, so DNA containing them may be classed as protein.
        aminoacids = "LMFWKQESPVIYHRND"
        lengths, gc_count = [], 0

        for ix, record in enumerate(self.records.values()):
            # Only the first record is inspected for amino acids.
            if ix == 0 and any(residue in aminoacids for residue in str(record.seq)):
                calculate_gc = False

            # Length of the sequence (number of base pairs).
            lengths.append(len(record.seq))

            if calculate_gc:
                upper_seq = record.seq.upper()
                gc_count += upper_seq.count("G") + upper_seq.count("C")

        bp_array = np.array(lengths)
        number_seqs = len(bp_array)
        total_bps = bp_array.sum()
        mean_bp = round(bp_array.mean(), decimals)
        N50 = calculate_N50(bp_array)
        min_bp = bp_array.min()
        max_bp = bp_array.max()

        if calculate_gc:
            GC = round(gc_count / total_bps, decimals)
            if percentage:
                GC *= 100
        else:
            GC = np.nan

        if megabases:
            total_bps /= 10e5  # 10e5 == 1e6 base pairs per megabase

        self._seqstats = SeqStats(
            number_seqs, total_bps, mean_bp, min_bp, max_bp, N50, GC
        )

        # Mirror every statistic as an attribute on the instance.
        for k, value in self._seqstats.__dict__.items():
            setattr(self, k, value)

        return self._seqstats


@dataclass
class SeqStats:
    """
    Dataclass to describe sequence statistics.

    Built by SeqFile._calculate_seqstats; field order matches its positional
    constructor call.
    """

    number_seqs: int  # number of records
    total_bps: float  # total length; an integer count, or megabases (float) when requested
    mean_bp: float  # mean record length, rounded
    min_bp: int  # shortest record length
    max_bp: int  # longest record length
    N50: float  # calculate_N50 returns the mean of two lengths (a float) for even totals
    GC: float  # GC fraction (or percentage); NaN when GC was not calculated


def calculate_N50(array):
    """
    Calculate N50 from an array of contig lengths.
    https://github.com/vikas0633/python/blob/master/N50.py

    Based on the Broad Institute definition:
    https://www.broad.harvard.edu/crd/wiki/index.php/N50

    The N50 is the median of the base-weighted length distribution. Imagine each
    contig of length L contributing L copies of the value L to a sorted list; the
    N50 is that list's median, or the mean of its two middle values when the total
    length is even. It is computed from cumulative sums rather than by building
    that list, which would need one element per base pair.

    The input is not modified.

    :param array: list or 1-D array of contig lengths
    :return: N50 value (a float when the total length is even)
    :raises ValueError: if the lengths do not sum to a positive total (e.g. empty input)
    """
    lengths = np.sort(np.asarray(array))  # sorted copy; the caller's array is untouched
    cumulative = np.cumsum(lengths)
    total = cumulative[-1] if cumulative.size else 0
    if total <= 0:
        raise ValueError("Cannot calculate N50: total length of contigs is zero.")

    def length_at(position):
        # Length of the contig that covers 0-based base `position` in the expanded
        # list. Contigs of length 0 cover no positions and are skipped by side="right".
        return lengths[np.searchsorted(cumulative, position, side="right")]

    middle = total // 2
    if total % 2 == 0:
        return (length_at(middle) + length_at(middle - 1)) / 2
    return length_at(middle)


def seqrecordgenerator(path, format, parser="seq"):
    """
    Create a record iterator for a sequence or alignment file.

    :param path: Path to file.
    :param format: format to pass to SeqIO.parse() / AlignIO.parse(). String
                   formats are lower-cased first, so matching is case-insensitive.
    :param parser: Whether to import records with SeqIO ('seq', default) or
                   AlignIO ('align'). Case-insensitive.
    :return: A generator of SeqRecords ('seq') or of alignments ('align').
    """
    parser_l = parser.lower()
    available_parsers = ("seq", "align")
    # Validate the normalised value so that e.g. "SEQ" is accepted.
    # NOTE(review): SeqFile.__init__ calls Warnings()["choices"] as (value, name, choices),
    # while this call uses (value, choices, name); confirm the expected order.
    assert parser_l in available_parsers, Warnings()["choices"](
        parser, available_parsers, "parser"
    )
    parsers = {"seq": SeqIO.parse, "align": AlignIO.parse}
    # Bio parsers expect lower-case format names.
    format_l = format.lower() if isinstance(format, str) else format
    # Errors from the parser (e.g. FileNotFoundError) propagate to the caller unchanged.
    return parsers[parser_l](path, format=format_l)
