#!/usr/bin/env python3

import platform
import shutil
import sys
from pathlib import Path
from pprint import pprint
from subprocess import run
from tempfile import TemporaryDirectory

import click

# "major.minor" of the running interpreter; used as the default value of the
# --python-version option.
PYTHON_VERSION = "{}.{}".format(*platform.python_version_tuple()[:2])

# Source text of the ``VBaseModel`` base class.  ``DirProcessor.write_file``
# writes this string verbatim into the generated module, directly after the
# "from pydantic import" line, when ``define_vbasemodel`` is set -- so the
# exact whitespace/indentation of the literal matters.
# NOTE(review): the snippet references ``BaseModel`` and ``Extra``; presumably
# the generated pydantic import line provides both -- confirm against the
# generator output.
VBASEMODEL = """
class VBaseModel(BaseModel):
    class Config:
        extra = Extra.forbid
        json_encoders = {
            "UUIDModel": lambda b: str(b.__root__)
        }
"""


@click.command(context_settings=dict(ignore_unknown_options=True))
@click.option(
    "-r",
    "--repo-dir",
    type=click.Path(file_okay=False, writable=True, path_type=Path),
    envvar="REPO_ROOT",
    show_envvar=True,
    default=Path.cwd(),
    help="repository root dir",
)
@click.option(
    "-m",
    "--module-name",
    type=click.Path(file_okay=False, writable=True, path_type=Path),
    envvar="MODULE",
    show_envvar=True,
    help="module directory",
)
@click.option(
    "-u",
    "--url",
    type=str,
    envvar="OMG_URL",
    show_envvar=True,
    help="url of openapi json spec",
)
@click.option(
    "-i",
    "--input",
    "_input",
    type=click.Path(
        dir_okay=False, readable=True, exists=True, path_type=Path
    ),
    envvar="OMG_INPUT",
    show_envvar=True,
    help="openapi json input file",
)
@click.option(
    "-o",
    "--output",
    "_output",
    type=click.Path(writable=True, path_type=Path),
    default="model.py",
    help="output directory or source filename",
)
@click.option(
    "-b",
    "--backup-dir",
    type=click.Path(file_okay=False, writable=True, path_type=Path),
    help="backup directory for generator extracted files",
)
@click.option(
    "-v",
    "--python-version",
    type=str,
    default=PYTHON_VERSION,
    help="python version for generate options",
)
@click.option("-C", "--confirm-command", is_flag=True, help="confirm command")
@click.option("-f", "--force", is_flag=True, help="bypass confirmation")
@click.option(
    "-c",
    "--generate-command",
    type=str,
    envvar="OMG_CMD",
    show_envvar=True,
    default="datamodel-codegen",
    help="code generation command with arguments",
)
@click.argument("generate-args", nargs=-1, type=click.UNPROCESSED)
def generate(
    repo_dir,
    module_name,
    backup_dir,
    url,
    _input,
    _output,
    python_version,
    force,
    confirm_command,
    generate_command,
    generate_args,
):
    """OMG openapi model generator

    postprocessor for datamodel-codegen

    pass extra datamodel-codegen options as -O=VALUE or --OPTION=VALUE

    \b
    https://github.com/pydantic/pydantic
    \b
    https://github.com/koxudaxi/datamodel-code-generator/

    """
    # NOTE: the docstring above is the click --help text and is user-facing.
    #
    # Flow: resolve the target module directory, decide whether the output is
    # a single file or a package directory, run the generator into a temporary
    # directory, optionally back up its raw output, then hand the generated
    # sources to FileProcessor / DirProcessor.
    #
    # Raises FileNotFoundError when the module directory is missing, TypeError
    # when an existing output target is neither file nor directory,
    # ValueError when the generator wrote no __init__.py, and
    # subprocess.CalledProcessError when the generator command fails.
    import shlex  # local import: only used to split/quote the command line

    repo_dir = repo_dir or Path.cwd()
    module_name = module_name or str(repo_dir.name).replace("-", "_")
    module_directory = repo_dir / module_name
    if not module_directory.is_dir():
        raise FileNotFoundError(f"{module_directory=}")

    if backup_dir and backup_dir.is_dir():
        # ``force`` is an is_flag option (False when unset, never None), so
        # the prompt must be guarded with ``not force``.
        if not force:
            click.confirm(
                f"Overwrite backup dir {str(backup_dir)}", abort=True
            )
        shutil.rmtree(str(backup_dir))

    # Decide the output mode *before* prompting so the prompt can name it.
    _output = module_directory / _output
    target_exists = _output.exists()
    if _output.is_file():
        output_mode = "file"
    elif _output.is_dir():
        output_mode = "directory"
    elif target_exists:
        raise TypeError("existing output target must be a file or directory")
    elif _output.suffix == ".py":
        output_mode = "file"
    else:
        output_mode = "directory"

    if target_exists:
        if not force:
            click.confirm(
                f"Overwrite {output_mode} {str(_output)}", abort=True
            )
        if output_mode == "directory":
            # DirProcessor recreates the directory; an existing file is
            # simply overwritten by FileProcessor.
            shutil.rmtree(str(_output))

    with TemporaryDirectory() as temp_dir:
        workdir = Path(temp_dir)
        # Option name -> value; a value of None marks a bare flag.
        args = {
            "--target-python-version": python_version,
            "--output": str(workdir),
        }
        if url:
            args["--url"] = url
        if _input:
            # Keep every argv element a str so the command can be echoed.
            args["--input"] = str(_input)

        for arg in generate_args:
            key, sep, value = arg.partition("=")
            # "--opt=value" overrides/sets a value; "--flag" is passed bare
            # instead of being followed by an empty-string argument.
            args[key] = value if sep else None

        # The option is documented as "command with arguments", so split it
        # into a proper argv prefix.
        cmd = shlex.split(generate_command)
        for key, value in args.items():
            cmd.append(key)
            if value is not None:
                cmd.append(value)

        if confirm_command:
            click.echo(shlex.join(cmd))
            click.confirm("confirm", abort=True)

        run(cmd, check=True)

        # Sorted so the generated output does not depend on directory order.
        sources = sorted(
            source_file.name for source_file in workdir.iterdir()
        )
        if not sources:
            click.echo("Generator produced no output", err=True)
            click.get_current_context().exit(-1)

        if backup_dir:
            shutil.copytree(workdir, backup_dir)

        # The processors expect __init__.py first; list.remove raises
        # ValueError when the generator did not produce one.
        sources.remove("__init__.py")
        sources.insert(0, "__init__.py")

        if output_mode == "directory":
            processor = DirProcessor(workdir, sources, _output)
        elif output_mode == "file":
            processor = FileProcessor(workdir, sources, _output)
        else:
            raise RuntimeError(f"unknown {output_mode=}")

        processor.run()


class DirProcessor:
    """Rewrite generator output into a package directory.

    The generated ``__init__.py`` becomes ``model.py`` (with ``VBaseModel``
    injected), every other generated module is rewritten next to it, and a new
    ``__init__.py`` re-exporting the collected class names is written.

    Args:
        workdir: directory (``pathlib.Path``) holding the generated files.
        sources: generated file names; the first entry is treated as the
            generated ``__init__.py``.
        outdir: package directory (``pathlib.Path``) to create.
    """

    def __init__(self, workdir, sources, outdir):
        self.workdir = workdir
        self.sources = sources
        self.outdir = outdir

    def run(self):
        """Create ``outdir`` and write all processed modules into it.

        Raises:
            FileExistsError: if ``outdir`` already exists (``Path.mkdir``).
            IndexError: if ``sources`` is empty.
            RuntimeError: if a source name does not end in ``.py``.
        """
        self.outdir.mkdir()
        # Slice instead of pop() so the caller's list is left untouched.
        init_file = self.sources[0]
        module_sources = self.sources[1:]

        classes, _r, _f = self.write_file(
            self.workdir / init_file,
            self.outdir / "model.py",
            filter_import=True,
            filter_webhook_prefix=True,
            define_vbasemodel=True,
            filter_basemodel=True,
        )

        # Classes already defined in model.py (plus a few fixed names) are
        # stripped from the other modules.
        model_classes = [c for c in classes if c.endswith("Model")]
        model_classes.extend(["UUID", "IWebhookUnParsed", "ITinyPayload"])

        imports = {"model": classes}
        remap_classes = {}
        filtered_classes = {}
        for source in module_sources:
            # Validate the name before producing any output for it.
            if not source.endswith(".py"):
                raise RuntimeError(f"unexpected filename: {source=}")
            classes, remaps, filtered = self.write_file(
                self.workdir / source,
                self.outdir / source,
                filter_import=True,
                filter_models=model_classes,
                filter_basemodel=True,
            )
            module = source[:-3]
            imports[module] = classes
            remap_classes[module] = remaps
            if filtered:
                filtered_classes[module] = filtered

        self.write_init(self.outdir / "__init__.py", imports)

        # Classes filtered out of webhookTypes are appended to model.py; the
        # module may be absent or have had nothing filtered.
        self.add_classes(
            self.outdir / "model.py",
            filtered_classes.get("webhookTypes", {}),
        )

    def add_classes(self, outfile, classes):
        """Append the source lines of ``classes`` to ``outfile``.

        Args:
            outfile: path of the file to append to.
            classes: mapping of class name -> list of source lines.
        """
        with outfile.open("a") as ofp:
            for class_lines in classes.values():
                ofp.writelines(class_lines)

    def write_init(self, outfile, classes):
        """Write a package ``__init__`` importing and exporting all classes.

        Args:
            outfile: path of the ``__init__.py`` to write.
            classes: mapping of module name -> list of class names; modules
                with no names are skipped.
        """
        with outfile.open("w") as ofp:
            ofp.write("# autogenerated models\n\n")
            all_classes = []
            for module, names in classes.items():
                if names:
                    class_list = ", ".join(names)
                    all_classes.extend(names)
                    ofp.write(f"from .{module} import {class_list}\n")
            ofp.write("\n")
            ofp.write("__all__ = [\n")
            ofp.write(", ".join([f"'{c}'" for c in all_classes]))
            ofp.write("\n]\n")

    def write_file(
        self,
        infile,
        outfile,
        *,
        filter_import=False,
        filter_webhook_prefix=False,
        filter_models=(),
        filter_basemodel=False,
        define_vbasemodel=False,
    ):
        """Copy ``infile`` to ``outfile`` line by line, applying rewrites.

        Args:
            infile: generated source file to read.
            outfile: destination file (overwritten).
            filter_import: rewrite import lines via ``rewrite_import``.
            filter_webhook_prefix: strip ``webhookTypes.`` qualifiers.
            filter_models: class names whose definitions are withheld from
                ``outfile`` and returned in ``filtered`` instead.
            filter_basemodel: rebase ``(BaseModel)`` classes on ``VBaseModel``
                and drop each ``class Config:`` line plus the line after it.
            define_vbasemodel: write ``VBASEMODEL`` after the pydantic import.

        Returns:
            ``(classes, remaps, filtered)``: names of classes written, list of
            ``(original, renamed)`` pairs, and a mapping of filtered class
            name -> its source lines.
        """
        # NOTE(review): imports are matched one physical line at a time;
        # parenthesised multi-line imports are presumably not expected here.
        classes = []
        remaps = []
        filtered = {}
        filter_class = None  # name of the class currently being withheld
        last = "."
        skip_next_line = False
        with infile.open("r") as ifp, outfile.open("w") as ofp:
            for line in ifp:
                if skip_next_line:
                    skip_next_line = False
                    continue
                if define_vbasemodel and "from pydantic import" in line:
                    ofp.write(line)
                    ofp.write(VBASEMODEL)
                    continue
                if filter_basemodel:
                    line = line.replace("(BaseModel)", "(VBaseModel)")
                    if "class Config:" in line:
                        skip_next_line = True
                        continue
                if filter_webhook_prefix:
                    line = line.replace("webhookTypes.", "")
                if filter_import:
                    line = self.rewrite_import(line)

                if filter_class is not None:
                    filtered[filter_class].append(line)
                    # Two consecutive blank lines end the withheld class.
                    if last.strip() == "" and line.strip() == "":
                        filter_class = None
                    last = line
                    continue

                if line.startswith("class "):
                    class_name = self.class_name(line)
                    if (
                        infile.name == "Partial_streamsTypes.py"
                        and class_name == "StreamsModelCreate"
                    ):
                        class_name = "StreamsModelUpdate"
                        line = line.replace("StreamsModelCreate", class_name)
                    if class_name.endswith("1"):
                        # Drop the trailing "1" from the class name only (the
                        # first occurrence in the line is the definition).
                        orig_name = class_name
                        class_name = orig_name[:-1]
                        line = line.replace(orig_name, class_name, 1)
                        remaps.append((orig_name, class_name))
                        print(
                            f"{outfile.name}: remap {orig_name} to {class_name}"
                        )
                    if class_name in filter_models:
                        print(f"{outfile.name}: filtering {class_name}")
                        filter_class = class_name
                        filtered[filter_class] = [line]
                    else:
                        classes.append(class_name)
                if filter_class is None:
                    ofp.write(line)
                last = line

        return classes, remaps, filtered

    def class_name(self, line):
        """Return the class name from a ``class Name(...):`` source line.

        Also handles a class without bases (``class Name:``).
        """
        declaration = line.strip().split()[1]
        name, _, _ = declaration.partition("(")
        return name.rstrip(":")

    def rewrite_import(self, line):
        """Rewrite one import line for the package layout.

        ``from . import webhookTypes`` is blanked, the pydantic import gains a
        ``VBaseModel`` import, and other ``from . import`` lines are redirected
        to ``.model`` with ``webhookTypes``/``UUIDModel`` removed.  Any other
        line is returned unchanged.
        """
        if "from . import webhookTypes" in line:
            line = "\n"
        elif "from pydantic import" in line:
            line = line + "from .model import VBaseModel\n"
        elif "from . import" in line:
            words = line.replace("from . import", "from .model import").split()
            names = {word.strip(",") for word in words[3:]}
            names -= {"webhookTypes", "UUIDModel", ""}
            if names:
                # Sorted: set iteration order varies between runs, which made
                # the generated import line non-reproducible.
                line = "from .model import " + ",".join(sorted(names)) + "\n"
            else:
                line = "\n"
        return line


class FileProcessor:
    """Merge all generated modules into one source file.

    Every source file is read at construction time.  ``from X import ...``
    lines are collected and de-duplicated per module, package-relative
    ``from . import`` lines and ``#`` comment lines are dropped (except the
    comments of ``__init__.py``, which are kept as header), and all remaining
    lines are concatenated in source order.

    Args:
        workdir: directory (``pathlib.Path``) holding the generated files.
        sources: generated file names, in the order they should be merged.
        outfile: path of the single output file.
    """

    def __init__(self, workdir, sources, outfile):
        self.workdir = workdir
        self.outfile = outfile
        self.headers = ["# autogenerated pydantic models\n", "\n"]
        self.sources = []  # non-import, non-comment lines in output order
        self.imports = {}  # module -> imported names, insertion-ordered
        for source in sources:
            self.read_file_lines(workdir / source)

    def read_file_lines(self, path):
        """Sort the lines of ``path`` into headers, imports and sources.

        Raises:
            RuntimeError: if a ``from`` line is not of the form
                ``from MODULE import NAME[, NAME...]``.
        """
        # NOTE(review): imports are parsed one physical line at a time;
        # parenthesised multi-line imports are presumably not expected here.
        with path.open("r") as ifp:
            # 1-based so the number in the error message matches editors.
            for lineno, line in enumerate(ifp, start=1):
                if line.startswith("#"):
                    # Only the generated __init__.py contributes header
                    # comments (the original compared against the
                    # non-existent "__index__.py", so none were ever kept).
                    if path.name == "__init__.py":
                        self.headers.append(line)
                elif line.startswith("from "):
                    if line.startswith("from . "):
                        continue
                    words = line.strip().split()
                    # Validate before indexing or touching self.imports.
                    if (
                        len(words) < 4
                        or words[0] != "from"
                        or words[2] != "import"
                    ):
                        raise RuntimeError(
                            f"Parse failure: {path.name}:{lineno} {repr(line)}"
                        )
                    names = self.imports.setdefault(words[1], [])
                    for word in words[3:]:
                        word = word.strip().strip(",")
                        if word not in names:
                            names.append(word)
                else:
                    self.sources.append(line)

    def run(self):
        """Write the merged module to ``outfile`` (overwriting it)."""
        with self.outfile.open("w") as ofp:
            ofp.writelines(self.output_lines())

    def output_lines(self):
        """Yield the output lines: headers, merged imports, then sources."""
        yield from self.headers
        yield "\n"

        for module, names in self.imports.items():
            yield f"from {module} import {', '.join(names)}\n"

        yield from self.sources


# Script entry point: run the click command and exit with its result.
if __name__ == "__main__":
    raise SystemExit(generate())
