import yaml
from yaml.scanner import ScannerError

import numpy as np

from qary.chat.dialog import generate_domain_filepaths
from qary.constants import DOMAINS
from qary.spacy_language_model import nlp, UNKNOWN_WORDVEC

import logging
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Emitted once at import time so the configured FAQ domains show up in debug logs.
log.debug(f'ETL FAQS DOMAIN {DOMAINS}')


def normalize_docvectors(docvectors):
    """ Convert a table (2D matrix) of row-vectors into a table of normalized row-vectors

    Each row is divided by its Euclidean (L2) norm. Rows whose norm is zero are
    reported with a warning and returned unchanged (all zeros) instead of being
    divided by zero.

    Args:
        docvectors: 2D array-like whose rows are the vectors to normalize.
            An empty sequence is accepted and yields an empty (0, 0) array.

    Returns:
        np.ndarray with the same shape as the (2D) input.

    >>> vecs = normalize_docvectors([[1, 2, 3], [4, 5, 6], [0, 0, 0], [-1, 0, +2]])
    >>> vecs.shape
    (4, 3)
    >>> np.linalg.norm(vecs, axis=1).round()
    array([1., 1., 0., 1.])
    >>> normalize_docvectors([]).shape
    (0, 0)
    """
    docvectors = np.array(docvectors)
    if docvectors.ndim == 1 and docvectors.size == 0:
        # np.array([]) is 1-D, which has no axis=1 to take row norms over;
        # treat "no vectors" as an empty 2D table so the code below still works.
        docvectors = docvectors.reshape(0, 0)
    log.info(f'docvectors.shape: {docvectors.shape}')
    norms = np.linalg.norm(docvectors, axis=1)
    iszero = norms <= 0
    log.info(f'norms.shape: {norms.shape}')
    # Repeat each row norm across all columns so the divisor has the same shape
    # as docvectors. Multiplying by np.ones also promotes the divisor to numpy's
    # default float dtype.
    norms_reshaped = norms.reshape(-1, 1).dot(np.ones((1, docvectors.shape[1])))
    log.info(f'norms_reshaped.shape: {norms_reshaped.shape}')
    if np.any(iszero):
        log.warning(
            f'Some doc vectors are zero like this first one: \n'
            f'docvectors[{iszero},:] = {docvectors[iszero,:]}'
        )
    # Divide zero rows by 1 (leaving them as zeros) to avoid 0/0 -> nan.
    norms_reshaped[iszero, :] = 1
    normalized_docvectors = docvectors / norms_reshaped
    log.info(f'normalized_docvectors.shape: {normalized_docvectors.shape}')
    assert normalized_docvectors.shape == docvectors.shape
    return normalized_docvectors


def load(domains=DOMAINS):
    """ Load FAQ question/answer pairs from the faq*.yml files for `domains` and vectorize the questions

    Each YAML file is read with `yaml.safe_load` and iterated as a sequence of
    dicts holding a question under key 'Q' (or 'q') and an answer under key 'A'
    (or 'a'). Files that PyYAML cannot parse are logged and skipped. Pairs whose
    question or answer is missing or blank are dropped from the result.

    Args:
        domains: forwarded to `generate_domain_filepaths` to select which FAQ
            files are loaded (default: DOMAINS).

    Returns:
        dict with three aligned entries:
            questions: np.ndarray of question texts
            answers: np.ndarray of the corresponding answers
            question_vectors: np.ndarray of the `nlp(question).vector` rows after
                `normalize_docvectors` (one row per kept question)

    >>> g = load()
    >>> len(g['questions']) == len(g['answers']) > 30
    True
    """
    questions, answers, question_vectors = [], [], []

    # Honor the caller's `domains` instead of always loading the module default.
    for filepath in generate_domain_filepaths(domains=domains, prefix='faq', data_subdir='faq'):
        with open(filepath) as filepointer:
            log.info(f"loading: {filepath.name}\n    with file pointer: {filepointer}")
            try:
                qa_list = yaml.safe_load(filepointer)
            except yaml.YAMLError as e:
                # YAMLError is the base class of ScannerError, ParserError, etc.,
                # so any malformed file is skipped rather than aborting the load.
                log.error(f"{e}\n    yaml.safe_load unable to read {filepointer.name}")
                continue
        # An empty YAML document loads as None; treat it as "no pairs".
        for i, qa_dict in enumerate(qa_list or []):
            if qa_dict is None:
                log.warning(f'Found None instead of dict in FAQ pair {filepointer.name}[{i}].')
                qa_dict = {}
            questions.append(qa_dict.get('Q', qa_dict.get('q', '')))
            answers.append(qa_dict.get('A', qa_dict.get('a', '')))
            try:
                question_vectors.append(list(nlp(questions[-1] or '').vector))
            except TypeError:
                # Fall back to the placeholder vector when the question cannot be vectorized.
                question_vectors.append(list(UNKNOWN_WORDVEC))
                continue
            assert len(UNKNOWN_WORDVEC) == len(question_vectors[-1])
        log.debug(f"Loaded {len(question_vectors)} for file {filepath}")

    log.debug(f'len(question_vectors): {len(question_vectors)}')

    questions = np.array(questions)
    log.debug(f'len(questions): {len(questions)}')
    answers = np.array(answers)
    log.debug(f'len(answers): {len(answers)}')
    # Keep only pairs where both the question and the answer are non-blank.
    # dtype=bool keeps the mask a valid boolean index even when it is empty.
    mask = np.array(
        [(bool(q) and bool(a) and len(str(q).strip()) > 0 and len(str(a).strip()) > 0)
         for q, a in zip(questions, answers)],
        dtype=bool)

    # With no pairs loaded there is no 2D table to normalize.
    if len(question_vectors):
        question_vectors = normalize_docvectors(question_vectors)

    # This should be a Kendra/gensim/annoy class (with methods like .find_similar)
    return dict(
        questions=questions[mask],
        answers=answers[mask],
        question_vectors=np.array([qv for qv, m in zip(question_vectors, mask) if m])
    )
