"""This module provides the Cocke-Younger-Kasami algorithm for parsing

"""

from typing import Sequence, List, Tuple, Union, Set

from formgram.grammars.transformations.context_free import to_chomsky_normal_form
from formgram.grammars.types import GrammarDict, Symbol


def cyk_recognize(grammar: GrammarDict, word: Sequence[Symbol]) -> bool:
    """Check if the given word can be generated by given context free grammar

    The decision is read from the top cell of the CYK parse table, which
    covers the whole word.

    The empty word has no table cells at all, so it is handled separately
    instead of failing with an ``IndexError`` on the empty table.

    :param grammar: The context free grammar to check against; its
        ``"starting_symbol"`` entry is the symbol looked up in the table
    :param word: The sequence of symbols to be recognized
    :return: True if the provided word is generated by provided grammar
    """
    if len(word) == 0:
        # The CYK table is empty for the empty word, so the table lookup below
        # cannot be used. Decide by looking for an epsilon rule of the start symbol.
        # NOTE(review): assumes to_chomsky_normal_form keeps an epsilon rule as a
        # production with an empty right-hand side and returns a grammar with a
        # "starting_symbol" entry -- confirm against the transformation.
        cnf_grammar = to_chomsky_normal_form(grammar)
        start = cnf_grammar["starting_symbol"]
        return any(
            len(left_side) == 1 and left_side[0] == start and len(right_side) == 0
            for left_side, right_side in cnf_grammar["productions"]
        )

    parse_table, _ = create_cyk_tables(grammar, word)
    # The last row consists of exactly one cell spanning the complete word
    return grammar["starting_symbol"] in parse_table[-1][0]


def create_cyk_forest(grammar: GrammarDict, word: Sequence[Symbol]) -> Union[List[dict], None]:
    """Create a list of all possible parse trees

    The individual parse trees are dictionaries with keys
    * name: which is the terminal or nonterminal symbol which is used or respectively produced
    * children: a Tuple of tree dictionaries

    This is compatible with the anytree package.

    Whether the word is recognized is decided on the parse table, not on the
    backpointer table: backpointers are only stored for spans of at least two
    symbols, so the backpointer cell of a one symbol word is always empty.

    :param grammar: The context free grammar used for parsing; its
        ``"starting_symbol"`` entry becomes the root of every tree
    :param word: The sequence of symbols to be parsed
    :return: a list of parse tree dict recognizing the given word,
        None if list is empty
    """
    if len(word) == 0:
        # No table cell exists for the empty word, hence no tree can be rebuilt
        return None

    parse_table, backpointer_table = create_cyk_tables(grammar, word)
    starting_symbol = grammar["starting_symbol"]
    top_row = len(word) - 1
    if starting_symbol not in parse_table[top_row][0]:
        return None

    forest = create_forest_by_recursion(
        backpointer_table, starting_symbol, word, row=top_row, col=0
    )
    # Keep the documented contract: an empty forest is reported as None
    return forest or None


def create_forest_by_recursion(backpointer_table: List[List[Set[Tuple[Symbol, Symbol, Symbol, int]]]],
                               current: Symbol,
                               word: Sequence[Symbol],
                               row: int,
                               col: int) -> List[dict]:
    """Reconstruct all subtrees from the backpointer table starting at the specified cell with the specified nonterminal

    Every tree is a dictionary with the keys ``"name"`` (the symbol of the node)
    and ``"children"`` (a tuple of tree dictionaries, empty for terminals).

    In row 0 the caller is trusted that ``current`` produces ``word[col]``;
    this function only sees the backpointer table and cannot verify it.

    :param backpointer_table: Triangle shaped table whose cells hold tuples
        ``(nonterminal, left_child, right_child, split_index)``
    :param current: The nonterminal which is the root of all returned trees
    :param word: The word being parsed
    :param row: Row of the starting cell; row 0 holds the single symbol spans
    :param col: Column of the starting cell, i.e. index of the first covered symbol
    :return: The forest of trees as given by the backpointer table, calculated
        from the given starting coordinates
    """
    if row < 0:
        return []
    if row == 0:
        terminal_leaf = {"name": word[col], "children": ()}
        return [{"name": current, "children": (terminal_leaf,)}]

    forest = []
    for nonterminal, left_child, right_child, split_index in backpointer_table[row][col]:
        if nonterminal != current:
            continue
        # Descend with the symbols of the production's right-hand side, not with
        # `current`: each child cell has to be expanded for the child's nonterminal.
        left_forest = create_forest_by_recursion(
            backpointer_table, left_child, word, split_index, col
        )
        if not left_forest:
            continue
        # The right forest does not depend on the chosen left subtree, so it is
        # computed once per backpointer entry.
        right_forest = create_forest_by_recursion(
            backpointer_table, right_child, word, row - split_index - 1, col + split_index + 1
        )
        forest.extend(
            {"name": current, "children": (left_subtree, right_subtree)}
            for left_subtree in left_forest
            for right_subtree in right_forest
        )
    return forest


def create_cyk_tables(
    grammar: GrammarDict, word: Sequence
) -> Tuple[List[List[set]], List[List[set]]]:
    """Build the CYK parse table and the matching backpointer table for a word

    The grammar is converted to Chomsky normal form first.

    Both tables are triangle shaped lists of lists: row ``r`` has
    ``len(word) - r`` cells, so row 0 has one cell per symbol of the word and
    the last row has exactly one cell. The cell at ``[r][c]`` stands for the
    part of the word which starts at index ``c`` and is ``r + 1`` symbols long.

    A parse table cell is the set of nonterminals which can produce that part
    of the word. A backpointer table cell is a set of tuples
    ``(nonterminal, first_child, second_child, split_index)``, i.e. the three
    symbols of the production used and the row of the cell holding the first
    child. Row 0 of the backpointer table always stays empty.

    :param grammar: The context free grammar, is converted to Chomsky normal form
    :param word: The sequence of symbols to fill the tables for
    :return: CYK parse table and a table containing backpointers
    """
    size = len(word)
    cnf_grammar = to_chomsky_normal_form(grammar)

    def empty_triangle():
        """Create a triangle shaped table of fresh empty sets"""
        return [[set() for _ in range(size - level)] for level in range(size)]

    parse_table = empty_triangle()
    backpointer_table = empty_triangle()

    # Sort the productions by the length of their right-hand side
    unit_rules = set()
    binary_rules = set()
    for rule in cnf_grammar["productions"]:
        right_side_length = len(rule[1])
        if right_side_length == 1:
            unit_rules.add(rule)
        elif right_side_length == 2:
            binary_rules.add(rule)

    # Row 0: every nonterminal which directly produces the symbol at that position
    for position, symbol in enumerate(word):
        producers = set()
        for (head,), (terminal,) in unit_rules:
            if terminal == symbol:
                producers.add(head)
        parse_table[0][position] = producers

    # Higher rows: try every way to cut the span into two shorter spans and
    # combine the nonterminals already known for both halves
    for level in range(1, size):
        for start in range(size - level):
            found_nonterminals = parse_table[level][start]
            found_backpointers = backpointer_table[level][start]
            for split in range(level):
                first_half = parse_table[split][start]
                second_half = parse_table[level - split - 1][start + split + 1]
                for (head,), (first_child, second_child) in binary_rules:
                    if first_child not in first_half or second_child not in second_half:
                        continue
                    found_nonterminals.add(head)
                    found_backpointers.add((head, first_child, second_child, split))

    return parse_table, backpointer_table
