import itertools
from contextlib import ExitStack
from typing import Dict, List, Optional, Tuple

import pysam
import shortuuid

from ..chemistry import Chemistry
from ..progress import progress
from .Fastq import (
    Fastq,
    FastqError,
)
from .Read import (
    Read,
    ReadError,
    Quality,
)


def fastq_to_bam(
    fastq_path: str,
    bam_path: str,
    name: Optional[str] = None,
    n_threads: int = 1,
    show_progress: bool = False,
) -> str:
    """Write every record of a single FASTQ into an unmapped BAM.

    Each FASTQ record becomes one BAM record with flag 4 (unmapped), the
    record's name, sequence and qualities, and an ``RG`` tag pointing at the
    single read group declared in the BAM header.

    Args:
        fastq_path: Path to the input FASTQ
        bam_path: Path to the output BAM
        name: Name for this set of reads. Defaults to None. If not provided
            (or empty), a random string is generated by calling
            :func:`shortuuid.uuid`. This value is used as the read group ID in
            the header and as the RG tag of every read in the BAM.
        n_threads: Number of threads passed to :class:`pysam.AlignmentFile`.
            Defaults to 1.
        show_progress: Whether to display a progress bar. Defaults to False.

    Returns:
        Path to BAM (``bam_path``, unchanged)
    """
    read_group = name if name else shortuuid.uuid()
    header_dict = {
        # NOTE(review): the SAM spec defines @HD VN as the SAM *format*
        # version; here the samtools version string is used — confirm intended.
        'HD': {'VN': pysam.version.__samtools_version__, 'SO': 'unsorted'},
        'RG': [{'ID': read_group}],
    }
    header = pysam.AlignmentHeader.from_dict(header_dict)

    def to_unmapped(read):
        # The sequence must be assigned before the qualities, because pysam
        # resets qualities whenever the sequence is set.
        segment = pysam.AlignedSegment(header)
        segment.query_name = read.name
        segment.query_sequence = read.sequence
        segment.query_qualities = read.qualities.values
        segment.flag = 4  # unmapped
        segment.tags = [('RG', read_group)]
        return segment

    with Fastq(fastq_path, 'r') as fastq:
        with pysam.AlignmentFile(
            bam_path, 'wb', header=header, threads=n_threads
        ) as bam:
            records = progress(
                fastq, desc='Writing BAM', disable=not show_progress
            )
            for read in records:
                bam.write(to_unmapped(read))

    return bam_path


def fastqs_to_bam_with_chemistry(
    fastq_paths: List[str],
    chemistry: Chemistry,
    tag_map: Dict[str, Tuple[str, str]],
    bam_path: str,
    name: Optional[str] = None,
    sequence_key: str = 'cdna',
    n_threads: int = 1,
    show_progress: bool = False,
) -> str:
    """Convert FASTQs to an unmapped BAM according to the provided
    :class:`ngs_tools.chemistry.Chemistry` instance.

    Note that any split features (i.e. split barcode where barcode is in multiple
    positions) are concatenated.

    Args:
        fastq_paths: List of FASTQ paths. The order must match that of the
            chemistry.
        chemistry: :class:`ngs_tools.chemistry.Chemistry` instance to use to parse the reads.
        tag_map: Mapping of parser names to their corresponding BAM tags.
            The keys are the parser names, and the values must be a tuple of
            ``(sequence BAM tag, quality BAM tag)``, where the former is the
            tag that will be used for the nucleotide sequence, and the latter is
            the tag that will be used for the quality scores. The ``RG`` tag
            is reserved for the read group and may not be used.
        bam_path: Path to the output BAM
        name: Name for this set of reads. Defaults to None. If not provided,
            a random string is generated by calling :func:`shortuuid.uuid`. This
            value is added as the read group (RG tag) for all the reads in the BAM.
        sequence_key: Parser key to use as the actual alignment sequence.
            Defaults to ``cdna``.
        n_threads: Number of threads to use. Defaults to 1.
        show_progress: Whether to display a progress bar. Defaults to False.

    Returns:
        Path to BAM

    Raises:
        FastqError: If the number of FASTQs provided does not match the number
            required for the specified chemistry, if ``tag_map`` or
            ``sequence_key`` refer to parsers that do not exist for the
            chemistry, if the tag map contains duplicate BAM tags or the
            reserved ``RG`` tag, or if the FASTQs do not all contain the same
            number of reads (in which case the BAM is left partially written).
    """
    if len(fastq_paths) != chemistry.n:
        raise FastqError(
            f'Chemistry `{chemistry}` requires {chemistry.n} FASTQs, but '
            f'{len(fastq_paths)} were provided.'
        )
    keys = set(tag_map.keys())
    keys.add(sequence_key)
    unknown_keys = keys - set(chemistry.parsers.keys())
    if unknown_keys:
        raise FastqError(
            f'Unknown keys in `tag_map` or `sequence_key`: {unknown_keys}'
        )
    all_tags = [tag for tags in tag_map.values() for tag in tags]
    if len(set(all_tags)) != len(all_tags):
        raise FastqError('Tag map contains duplicate BAM tags.')
    # Every record already gets an RG tag holding the read group; allowing the
    # tag map to reuse it would emit the same tag twice on a single record.
    if 'RG' in all_tags:
        raise FastqError(
            'Tag map must not use the `RG` tag, which is reserved for the '
            'read group.'
        )

    rg = name or shortuuid.uuid()
    header = pysam.AlignmentHeader.from_dict({
        'HD': {
            'VN': pysam.version.__samtools_version__,
            'SO': 'unsorted'
        },
        'RG': [{
            'ID': rg
        }],
    })
    with pysam.AlignmentFile(bam_path, 'wb', header=header,
                             threads=n_threads) as f:
        # ExitStack closes every FASTQ opened so far, even if opening a later
        # one fails; FASTQs are closed before the BAM, as before.
        with ExitStack() as stack:
            fastqs = [
                stack.enter_context(Fastq(fastq_path, 'r'))
                for fastq_path in fastq_paths
            ]

            # zip() would stop silently at the shortest FASTQ and drop the
            # remaining reads; zip_longest pads with None so that a mismatch
            # in read counts is detected instead.
            for reads in progress(itertools.zip_longest(*fastqs),
                                  desc='Writing BAM',
                                  disable=not show_progress):
                if any(read is None for read in reads):
                    raise FastqError(
                        'FASTQs do not contain the same number of reads.'
                    )

                parsed = chemistry.parse_reads(reads, concatenate=True)
                read_sequence, read_quality = parsed[sequence_key]
                read_quality = pysam.qualitystring_to_array(read_quality)
                tags = [('RG', rg)]
                for parser_name, (sequence_tag, quality_tag) in tag_map.items():
                    sequence, quality = parsed[parser_name]
                    tags.append((sequence_tag, sequence))
                    tags.append((quality_tag, quality))

                al = pysam.AlignedSegment(header)
                al.query_name = reads[0].name
                # Sequence must be set before qualities (pysam resets
                # qualities when the sequence changes).
                al.query_sequence = read_sequence
                al.query_qualities = read_quality
                al.flag = 4  # unmapped
                al.tags = tags
                f.write(al)

    return bam_path
