#!/usr/bin/env python

"""
This script tests generated classes using the Kafka Avro serde (serializer/deserializer).
"""
import argparse
import importlib
import sys
from dataclasses import asdict
from enum import Enum
from pathlib import Path

from avro.schema import Parse, Schema
from confluent_kafka.avro.serializer import TopicRecordNameStrategy  # pylint: disable=no-name-in-module

from pyavro_gen.codewriters.namespace import ClassItem
from pyavro_gen.codewriters.utils import j
from pyavro_gen.equality_test import are_almost_equal
from pyavro_gen.mock_schema_registry_client import MockSchemaRegistryClient

__author__ = "Nicola Bova"
__copyright__ = "Copyright 2019, Jaumo GmbH"
__email__ = "nicola.bova@jaumo.com"


def do_test_generated_classes(  # pylint: disable=R0914
        module: str,
        domain_namespace: str,
        number_of_cycles: int = 10,
        verbose: bool = True
) -> None:
    """
    Test generated classes. For each generated class that corresponds to an Avro record,
    this function performs a serialization/deserialization cycle and checks if input and output
    are equal.

    The parent directory of ``module`` is temporarily prepended to ``sys.path`` so that the
    generated package and its ``<name>_test`` companion package can be imported from an
    arbitrary location. ``clean_path()`` removes it again whether the run succeeds or fails.

    :param module: The path of the package of the generated classes
    :param domain_namespace: The namespace of Avro "root" classes that can be ser/deserialised;
        classes whose namespace does not contain this string (and Enum classes) are skipped.
    :param number_of_cycles: How many serde cycles to perform per class
    :param verbose: To show any output or not
    :raises AssertionError: if a deserialised instance does not match the serialised one
    """
    # Local import: only needed to make the failure diagnostics safe.
    from dataclasses import is_dataclass

    # To be able to load a package from an arbitrary location, we add its parent directory
    # to sys.path; the ``finally`` below guarantees it is removed again even on errors.
    module_path = Path(module).absolute()
    sys.path.insert(0, str(module_path.parent))

    try:
        module_name = module_path.name
        test_module_name = module_name + '_test'

        classes_module = importlib.import_module(module_name)
        test_module = importlib.import_module(test_module_name)

        # reload useful for bulk unittests (the same module names get re-imported)
        importlib.reload(classes_module)
        test_module = importlib.reload(test_module)
        test_module.testing_classes = importlib.reload(  # type: ignore
            test_module.testing_classes)  # type: ignore

        schema_registry_client = MockSchemaRegistryClient()  # type: ignore

        value_serializer = classes_module.TypedAvroSerializer(  # type: ignore
            schema_registry_client,
            is_key=False,
            subject_strategy=TopicRecordNameStrategy
        )

        value_deserializer = classes_module.TypedAvroDeserializer(  # type: ignore
            schema_registry_client,
            is_key=False
        )

        for test_class_item in test_module.testing_classes.test_classes:  # type: ignore
            class_item: ClassItem = test_class_item
            full_name = j(class_item.namespace, '.', class_item.name)

            if verbose:
                print('testing', full_name)

            mod = __import__(class_item.namespace, fromlist=[class_item.name])
            klass = getattr(mod, class_item.name)

            if domain_namespace not in class_item.namespace or issubclass(klass, Enum):
                # Other classes cannot be directly serialised by Avro so we skip them
                if verbose:
                    print('\tskipping')
                continue

            # we only check records with a full namespace
            avro_schema: Schema = Parse(klass._schema)  # pylint: disable=W0212
            if avro_schema.fullname.startswith('.'):
                if verbose:
                    print('\tskipping (no namespace)')
                continue

            # The factory is only needed for classes that are actually round-tripped.
            test_mod = __import__(class_item.namespace_test, fromlist=[class_item.factory_name])
            klass_factory = getattr(test_mod, class_item.factory_name)

            for i in range(number_of_cycles):
                instance = klass_factory()
                if verbose:
                    print(i, '\t', instance)

                serialized = value_serializer('mytopic', instance)
                deserialized = value_deserializer('mytopic', serialized)

                try:
                    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
                    if not are_almost_equal(
                            instance, deserialized, max_abs_ratio_diff=1.000001, max_abs_diff=-1):
                        raise AssertionError(full_name + ': serde round-trip mismatch')
                except AssertionError:
                    print("expected:", instance)
                    print("actual:  ", deserialized)
                    # asdict() raises TypeError on non-dataclass values, which would hide the
                    # real failure, so only dump the dict when it is a dataclass instance.
                    if is_dataclass(deserialized) and not isinstance(deserialized, type):
                        print("des dict:", asdict(deserialized))
                    raise
    finally:
        clean_path()


def clean_path() -> None:
    """
    Undo the ``sys.path`` insertion made by do_test_generated_classes().

    The first entry of ``sys.path`` is dropped unconditionally; which entry sits there is
    not checked.
    """
    del sys.path[0]


if __name__ == '__main__':
    PARSER = argparse.ArgumentParser(
        description='A test for Avro typed classes generated by pyavro-gen.')

    PARSER.add_argument('-m', '--module', dest='module', type=str, required=True,
                        help='path of the package containing the generated classes')
    # '--number_of_cycles' kept for backward compatibility; the hyphenated alias matches
    # the style of '--domain-namespace'.
    PARSER.add_argument('-n', '--number_of_cycles', '--number-of-cycles', dest='ncycles',
                        type=int, default=10,
                        help='serde cycles to perform per class (default: 10)')
    PARSER.add_argument('-d', '--domain-namespace', dest='domain_namespace', type=str,
                        required=True,
                        help='namespace of the Avro "root" records to round-trip')
    # Output is verbose by default. '-v' is kept so existing invocations still work, and
    # '-q/--quiet' is added because with only a store_true flag defaulting to True the
    # output could never be turned off.
    PARSER.add_argument('-v', '--verbose', dest='verbose', action='store_true',
                        help='print progress output (default)')
    PARSER.add_argument('-q', '--quiet', dest='verbose', action='store_false',
                        help='suppress progress output')
    PARSER.set_defaults(verbose=True)

    ARGS = PARSER.parse_args()
    do_test_generated_classes(ARGS.module, ARGS.domain_namespace, ARGS.ncycles, ARGS.verbose)
