#
# This file is part of LiteX.
#
# Copyright (c) 2021 Franck Jullien <franck.jullien@collshade.fr>
# SPDX-License-Identifier: BSD-2-Clause

import os
import csv
import re
import datetime

from xml.dom import expatbuilder
import xml.etree.ElementTree as et

from litex.build import tools

# XML namespace prefixes used with ElementTree's find() when querying the
# Efinity periphery design file (<build>.peri.xml).
namespaces = dict(
    efxpt = "http://www.efinixinc.com/peri_design_db",
    xi    = "http://www.w3.org/2001/XInclude",
)

# Interface Writer Block ---------------------------------------------------------------------------

class InterfaceWriterBlock(dict):
    """Base class for custom blocks rendered into the Efinity Python API script.

    Being a dict, an instance can sit in ``InterfaceWriter.blocks`` next to the
    plain-dict block descriptions. ``InterfaceWriter.generate()`` calls
    ``generate()`` on such instances and appends the returned text to the script,
    so subclasses must override it.
    """

    def generate(self):
        """Return the script text for this block; subclasses must override this."""
        raise NotImplementedError

class InterfaceWriterXMLBlock(dict):
    """Base class for custom blocks that edit the periphery XML (``.peri.xml``) directly.

    Being a dict, an instance can sit in ``InterfaceWriter.xml_blocks`` next to
    plain-dict descriptions. ``InterfaceWriter.generate_xml_blocks()`` calls
    ``generate(root, namespaces)`` on such instances, where *root* is the parsed
    periphery XML root element and *namespaces* the prefix map for ``find()``.
    """

    def generate(self, root=None, namespaces=None):
        """Modify *root* in place; subclasses must override this.

        The base signature matches the call made by ``generate_xml_blocks()``, so
        a subclass that forgets to override it gets ``NotImplementedError`` rather
        than a ``TypeError`` about unexpected arguments. The defaults keep the
        historical zero-argument form callable.
        """
        raise NotImplementedError("InterfaceWriterXMLBlock subclasses must implement generate(root, namespaces)")

# Interface Writer  --------------------------------------------------------------------------------

class InterfaceWriter:
    def __init__(self, efinity_path):
        """Create an empty writer bound to the Efinity installation at *efinity_path*.

        ``filename`` and ``platform`` stay unset until ``set_build_params()`` is called.
        """
        self.efinity_path = efinity_path
        # Block descriptions rendered into the API script by generate().
        # Entries rendered into the .peri.xml by generate_xml_blocks().
        # (tag, name, [(attribute, value), ...]) patches applied by fix_xml_values().
        self.blocks, self.xml_blocks, self.fix_xml = [], [], []
        self.filename = ""
        self.platform = None

    def set_build_params(self, platform, build_name):
        """Record the target *platform* and the *build_name* used as the output file stem."""
        self.filename, self.platform = build_name, platform

    def fix_xml_values(self):
        """Patch attributes in ``<filename>.peri.xml`` and rewrite the file pretty-printed.

        Each ``self.fix_xml`` entry is a ``(tag, name, values)`` triple. Every element
        whose (namespaced) tag contains *tag* as a substring and whose "name" attribute
        equals *name* gets each ``(attribute, value)`` pair of *values* set on it.
        """
        et.register_namespace("efxpt", "http://www.efinixinc.com/peri_design_db")
        peri_xml = self.filename + ".peri.xml"
        tree     = et.parse(peri_xml)
        root     = tree.getroot()

        for tag, name, values in self.fix_xml:
            for element in root.iter():
                if tag in element.tag and element.get("name") == name:
                    for attribute, value in values:
                        element.set(attribute, value)

        reparsed = expatbuilder.parseString(et.tostring(root, "utf-8"), False)
        pretty   = reparsed.toprettyxml(indent="    ")

        # toprettyxml() adds its own indentation on top of the whitespace text nodes
        # presumably already present in the parsed file, which leaves whitespace-only
        # lines behind; drop them.
        # Join with "\n", not os.linesep: tools.write_to_file is expected to write in
        # text mode, which already maps "\n" to the platform line ending (os.linesep
        # would yield "\r\r\n" on Windows). NOTE(review): confirm against litex.build.tools.
        pretty = "\n".join(line for line in pretty.splitlines() if line.strip())

        tools.write_to_file(peri_xml, pretty)

    def generate_xml_blocks(self):
        """Add the XML-level blocks to ``<filename>.peri.xml`` and rewrite it pretty-printed.

        ``InterfaceWriterXMLBlock`` instances in ``self.xml_blocks`` edit the tree
        themselves. Dict blocks of type "LVDS" or "DRAM" are handed to
        ``add_lvds_xml`` / ``add_dram_xml``; other types are ignored. When the platform
        has ``iobank_info``, the I/O bank standards are applied as well.
        """
        et.register_namespace("efxpt", "http://www.efinixinc.com/peri_design_db")
        peri_xml = self.filename + ".peri.xml"
        root     = et.parse(peri_xml).getroot()

        for block in self.xml_blocks:
            if isinstance(block, InterfaceWriterXMLBlock):
                block.generate(root, namespaces)
            elif block["type"] == "LVDS":
                self.add_lvds_xml(root, block)
            elif block["type"] == "DRAM":
                self.add_dram_xml(root, block)

        if self.platform.iobank_info:
            self.add_iobank_info_xml(root, self.platform.iobank_info)

        reparsed = expatbuilder.parseString(et.tostring(root, "utf-8"), False)
        pretty   = reparsed.toprettyxml(indent="    ")

        # toprettyxml() adds its own indentation on top of the whitespace text nodes
        # presumably already present in the parsed file, which leaves whitespace-only
        # lines behind; drop them.
        # Join with "\n", not os.linesep: tools.write_to_file is expected to write in
        # text mode, which already maps "\n" to the platform line ending (os.linesep
        # would yield "\r\r\n" on Windows). NOTE(review): confirm against litex.build.tools.
        pretty = "\n".join(line for line in pretty.splitlines() if line.strip())

        tools.write_to_file(peri_xml, pretty)

    def header(self, build_name, partnumber):
        """Return the preamble of the generated Efinity Python API script.

        The preamble:
        - points the EFX*_HOME environment variables and ``sys.path`` at
          ``self.efinity_path``;
        - imports ``DesignAPI``/``DeviceAPI`` with verbose output enabled;
        - creates design *build_name* for device *partnumber* in ``./../gateware``,
          overwriting any existing design.
        """
        banner   = "# Autogenerated by LiteX / git: " + tools.get_litex_git_revision()
        template = """
import os
import sys
import pprint

home = "{0}"

os.environ["EFXPT_HOME"]  = home + "/pt"
os.environ["EFXPGM_HOME"] = home + "/pgm"
os.environ["EFXDBG_HOME"] = home + "/debugger"
os.environ["EFXIPM_HOME"] = home + "/ipm"

sys.path.append(home + "/pt/bin")
sys.path.append(home + "/lib/python3.8/site-packages")

from api_service.design import DesignAPI
from api_service.device import DeviceAPI

is_verbose = {1}

design = DesignAPI(is_verbose)
device = DeviceAPI(is_verbose)

design.create("{2}", "{3}", "./../gateware", overwrite=True)

"""
        script = banner + template
        return script.format(self.efinity_path, "True", build_name, partnumber)

    def get_block(self, name):
        """Return the first entry of ``self.blocks`` whose "name" equals *name*, else None."""
        return next((candidate for candidate in self.blocks if candidate["name"] == name), None)

    def generate_mipi_tx(self, block, verbose=True):
        """Return the API commands that create a MIPI TX lane block.

        Reads "name", "mode", "props" (a mapping emitted as set_property calls) and
        "ressource" (the hardware resource to assign) from *block*. *verbose* is unused.
        """
        name = block["name"]
        lane = "MIPI_TX_LANE"
        lines = [
            f"# ---------- MIPI TX {name} ---------",
            f'design.create_block("{name}","{lane}", mode="{block["mode"]}")',
        ]
        lines.extend(
            f'design.set_property("{name}","{key}","{value}","{lane}")'
            for key, value in block["props"].items()
        )
        lines.append(f'design.assign_resource("{name}","{block["ressource"]}","{lane}")')
        lines.append(f"# ---------- END MIPI TX {name} ---------\n")
        return "\n".join(lines) + "\n"

    def generate_mipi_rx(self, block, verbose=True):
        """Return the API commands that create a MIPI RX lane block.

        Reads "name", "mode", optional "conn_type" (added to create_block when
        present), "props" (a mapping emitted as set_property calls) and "ressource"
        from *block*. *verbose* is unused.
        """
        name  = block["name"]
        lane  = "MIPI_RX_LANE"
        extra = f', conn_type="{block["conn_type"]}"' if "conn_type" in block else ""
        lines = [
            f"# ---------- MIPI RX {name} ---------",
            f'design.create_block("{name}","{lane}", mode="{block["mode"]}"{extra})',
        ]
        lines.extend(
            f'design.set_property("{name}","{key}","{value}","{lane}")'
            for key, value in block["props"].items()
        )
        lines.append(f'design.assign_resource("{name}","{block["ressource"]}","{lane}")')
        lines.append(f"# ---------- END MIPI RX {name} ---------\n")
        return "\n".join(lines) + "\n"

    def generate_gpio(self, block, verbose=True):
        """Return the Efinity API commands that create and configure one GPIO block.

        *block* always provides "name", "mode" and "properties". "properties" is
        iterated as (property, value) pairs and emitted as extra set_property calls
        when truthy. Depending on "mode" it also reads:

        - INOUT / INPUT / OUTPUT: "location" (list of package pins) and "size" for
          multi-bit buses, plus the optional register keys ("out_reg"/"out_clk_pin"/
          "out_delay", "out_clk_inv", "in_reg"/"in_clk_pin"/"in_delay", ...).
          OUTPUT also reads the optional "drive_strength".
        - INPUT_CLK / MIPI_CLKIN / OUTPUT_CLK: "location" as a single package pin.

        Unknown modes produce a single "# TODO" comment line. *verbose* is unused.
        """
        name = block["name"]
        mode = block["mode"]
        prop = block["properties"]

        if mode == "INOUT":
            cmd  = self._gpio_create_and_assign(name, "inout", block)
            cmd += self._gpio_out_reg_props(name, block)
            cmd += self._gpio_in_reg_props(name, block)
            if "in_clk_inv" in block:
                cmd += f'design.set_property("{name}","IS_INCLK_INVERTED","{block["in_clk_inv"]}")\n'
            if "oe_reg" in block:
                cmd += f'design.set_property("{name}","OE_REG","{block["oe_reg"]}")\n'
            if "oe_clk_pin" in block:
                cmd += f'design.set_property("{name}","OE_CLK_PIN","{block["oe_clk_pin"]}")\n'
            return cmd + self._gpio_extra_props(name, prop) + "\n"

        if mode == "INPUT":
            cmd  = self._gpio_create_and_assign(name, "input", block)
            cmd += self._gpio_in_reg_props(name, block)
            return cmd + self._gpio_extra_props(name, prop) + "\n"

        if mode == "OUTPUT":
            # Buses must be created as output GPIOs too (previously emitted create_input_gpio).
            cmd  = self._gpio_create_and_assign(name, "output", block)
            cmd += self._gpio_out_reg_props(name, block)
            if "drive_strength" in block:
                # Use the requested drive strength (previously hard-coded to "4").
                cmd += f'design.set_property("{name}","DRIVE_STRENGTH","{block["drive_strength"]}")\n'
            return cmd + self._gpio_extra_props(name, prop) + "\n"

        if mode == "INPUT_CLK":
            cmd  = f'design.create_input_clock_gpio("{name}")\n'
            cmd += f'design.set_property("{name}","IN_PIN","{name}")\n'
            cmd += f'design.assign_pkg_pin("{name}","{block["location"]}")\n\n'
            return cmd + self._gpio_extra_props(name, prop) + "\n"

        if mode == "MIPI_CLKIN":
            cmd  = f'design.create_mipi_input_clock_gpio("{name}")\n'
            cmd += f'design.assign_pkg_pin("{name}","{block["location"]}")\n\n'
            return cmd

        if mode == "OUTPUT_CLK":
            cmd  = f'design.create_clockout_gpio("{name}")\n'
            cmd += f'design.set_property("{name}","OUT_CLK_PIN","{name}")\n'
            cmd += f'design.assign_pkg_pin("{name}","{block["location"]}")\n\n'
            return cmd + self._gpio_extra_props(name, prop) + "\n"

        return "# TODO: " + str(block) + "\n"

    @staticmethod
    def _gpio_create_and_assign(name, kind, block):
        """Create a GPIO of *kind* ("inout"/"input"/"output") and assign its package pin(s)."""
        location = block["location"]
        if len(location) == 1:
            return (f'design.create_{kind}_gpio("{name}")\n'
                    f'design.assign_pkg_pin("{name}","{location[0]}")\n')
        # Multi-bit bus: declare [size-1:0] and assign one pin per bit.
        cmd = f'design.create_{kind}_gpio("{name}",{block["size"] - 1},0)\n'
        for bit, pad in enumerate(location):
            cmd += f'design.assign_pkg_pin("{name}[{bit}]","{pad}")\n'
        return cmd

    @staticmethod
    def _gpio_out_reg_props(name, block):
        """Output register/delay properties and the output-clock inversion properties."""
        cmd = ""
        if "out_reg" in block:
            cmd += f'design.set_property("{name}","OUT_REG","{block["out_reg"]}")\n'
            cmd += f'design.set_property("{name}","OUT_CLK_PIN","{block["out_clk_pin"]}")\n'
            if "out_delay" in block:
                cmd += f'design.set_property("{name}","OUTDELAY","{block["out_delay"]}")\n'
        if "out_clk_inv" in block:
            cmd += f'design.set_property("{name}","IS_OUTCLK_INVERTED","{block["out_clk_inv"]}")\n'
            cmd += f'design.set_property("{name}","OE_CLK_PIN_INV","{block["out_clk_inv"]}")\n'
        return cmd

    @staticmethod
    def _gpio_in_reg_props(name, block):
        """Input register and input delay properties."""
        cmd = ""
        if "in_reg" in block:
            cmd += f'design.set_property("{name}","IN_REG","{block["in_reg"]}")\n'
            cmd += f'design.set_property("{name}","IN_CLK_PIN","{block["in_clk_pin"]}")\n'
            if "in_delay" in block:
                cmd += f'design.set_property("{name}","INDELAY","{block["in_delay"]}")\n'
        return cmd

    @staticmethod
    def _gpio_extra_props(name, prop):
        """Free-form (property, value) pairs from block["properties"], if any."""
        if not prop:
            return ""
        return "".join(f'design.set_property("{name}","{p}","{val}")\n' for p, val in prop)

    def generate_pll(self, block, partnumber, verbose=True):
        """Return the Efinity API commands that create and configure one PLL.

        Reads from *block*:
        - "name".
        - "input_freq", and the frequency of each "clk_out" entry. Both are divided
          by 1e6 for the script, so they are presumably given in Hz.
        - "input_clock": "EXTERNAL" selects an external reference, which also needs
          "input_clock_name" and "clock_no". Anything else uses the core signal
          "input_signal".
        - "resource", "locked", and "rstn" ("" means no reset pin).
        - "version".
        - optional "extra": raw script text appended verbatim.

        Each "clk_out" entry is indexed as (pin name, frequency, phase).

        For an external reference, parts whose number starts with "T4"/"T8" name
        the reference by pad ("input_clock_pad"); other parts name it by source
        ("input_clock").

        With *verbose*, the generated script also prints the PLL configuration. It
        exits with "PLL ERROR" when a computed output frequency differs from the
        requested one.
        """
        name = block["name"]
        out  = [
            f"# ---------- PLL {name} ---------\n",
            f'design.create_block("{name}", block_type="PLL")\n',
        ]
        ref_mhz = block["input_freq"] / 1e6
        out.append(f'pll_config = {{ "REFCLK_FREQ":"{ref_mhz}" }}\n')
        out.append(f'design.set_property("{name}", pll_config, block_type="PLL")\n\n')

        # Reference clock: external pad/source, or a clock coming from the core fabric.
        if block["input_clock"] != "EXTERNAL":
            pll_res  = block["resource"]
            core_pin = block["input_signal"]
            out.append(f'design.gen_pll_ref_clock("{name}", pll_res="{pll_res}", refclk_name="{core_pin}", refclk_src="CORE")\n')
            out.append(f'design.set_property("{name}", "CORE_CLK_PIN", "{core_pin}", block_type="PLL")\n\n')
        else:
            pll_res = block["resource"]
            # PLL V1 (T4/T8) is configured with a different reference-clock argument.
            if partnumber[:2] in ("T4", "T8"):
                ref_arg = f'refclk_res="{block["input_clock_pad"]}"'
            else:
                ref_arg = f'refclk_src="{block["input_clock"]}"'
            out.append(
                f'design.gen_pll_ref_clock("{name}", pll_res="{pll_res}", {ref_arg}, '
                f'refclk_name="{block["input_clock_name"]}", ext_refclk_no="{block["clock_no"]}")\n\n'
            )

        out.append(f'design.set_property("{name}","LOCKED_PIN","{block["locked"]}", block_type="PLL")\n')
        rstn = block["rstn"]
        if rstn != "":
            out.append(f'design.set_property("{name}","RSTN_PIN","{rstn}", block_type="PLL")\n\n')

        clocks = block["clk_out"]

        # Output pins; CLKOUT0 is enabled by default, the others are enabled explicitly.
        for idx, clk in enumerate(clocks):
            enable = f'"CLKOUT{idx}_EN":"1", ' if idx > 0 else ""
            out.append(f'pll_config = {{ {enable}"CLKOUT{idx}_PIN":"{clk[0]}" }}\n')
            out.append(f'design.set_property("{name}", pll_config, block_type="PLL")\n\n')

        # Phases: "V1_V2" PLLs take the raw value; others take it integer-divided by 45.
        for idx, clk in enumerate(clocks):
            if block["version"] == "V1_V2":
                out.append(f'design.set_property("{name}","CLKOUT{idx}_PHASE","{clk[2]}","PLL")\n')
            else:
                out.append(f'design.set_property("{name}","CLKOUT{idx}_PHASE_SETTING","{clk[2] // 45}","PLL")\n')

        out.append("target_freq = {\n")
        out.extend(f'    "CLKOUT{idx}_FREQ": "{clk[1] / 1e6}",\n' for idx, clk in enumerate(clocks))
        out.append("}\n")
        out.append(f'calc_result = design.auto_calc_pll_clock("{name}", target_freq)\n')

        if "extra" in block:
            out.append(block["extra"])
            out.append("\n")

        if verbose:
            out.append(f'print("#### {name} ####")\n')
            out.append(f'clksrc_info = design.trace_ref_clock("{name}", block_type="PLL")\n')
            out.append('pprint.pprint(clksrc_info)\n')
            out.append('clock_source_prop = ["REFCLK_SOURCE", "CORE_CLK_PIN", "EXT_CLK", "REFCLK_FREQ", "RESOURCE"]\n')
            for idx, _ in enumerate(clocks):
                out.append(f'clock_source_prop += ["CLKOUT{idx}_FREQ", "CLKOUT{idx}_PHASE", "CLKOUT{idx}_EN"]\n')
            out.append(f'prop_map = design.get_property("{name}", clock_source_prop, block_type="PLL")\n')
            out.append('pprint.pprint(prop_map)\n')

            # Generated self-check: abort if the solver could not hit a requested frequency.
            for idx, clk in enumerate(clocks):
                mhz = clk[1] / 1e6
                out.append(f'\nfreq = float(prop_map["CLKOUT{idx}_FREQ"])\n')
                out.append(f'if freq != {mhz}:\n')
                out.append(f'    print("ERROR: CLKOUT{idx} configured for {mhz}MHz is {{}}MHz".format(freq))\n')
                out.append('    exit("PLL ERROR")\n')

        out.append(f"\n#---------- END PLL {name} ---------\n\n")
        return "".join(out)

    def generate_jtag(self, block, verbose=True):
        """Return the API commands that instantiate a user JTAG block.

        *block* provides:
        - "id": selects the JTAG_USER<id> resource.
        - "pins": an object with one attribute per JTAG signal. The connected
          signal name is taken from the first element of each attribute's last
          ``backtrace`` entry.

        The Efinity block is always named "jtag_soc"; block["name"] is read but not
        used. *verbose* is unused.
        """
        block_name = block["name"]  # not used in the output; kept for parity with other blocks
        jtag_id    = block["id"]
        pins       = block["pins"]

        signals = ("CAPTURE", "DRCK", "RESET", "RUNTEST", "SEL", "SHIFT",
                   "TCK", "TDI", "TMS", "UPDATE", "TDO")

        def core_signal(signal):
            return getattr(pins, signal).backtrace[-1][0]

        # One '    "KEY"<pad> : "value"' entry per signal, comma-separated (no trailing comma).
        config_entries = ",\n".join(
            '    {:<9} : "{}"'.format(f'"{signal}"', core_signal(signal)) for signal in signals
        )

        return "\n".join([
            f"# ---------- JTAG {jtag_id} ---------",
            'jtag = design.create_block("jtag_soc", block_type="JTAG")',
            f'design.assign_resource(jtag, "JTAG_USER{jtag_id}", "JTAG")',
            "jtag_config = {",
            config_entries,
            "}",
            'design.set_property("jtag_soc", jtag_config, block_type="JTAG")',
            f"# ---------- END JTAG {jtag_id} ---------\n",
        ])

    def generate(self, partnumber):
        """Return the script body for every entry of ``self.blocks``.

        ``InterfaceWriterBlock`` instances render themselves. Dict blocks are
        dispatched on "type" (PLL, GPIO, MIPI_TX_LANE, MIPI_RX_LANE, JTAG); any
        other type contributes nothing. *partnumber* is forwarded to the PLL generator.
        """
        renderers = {
            "PLL"          : lambda blk: self.generate_pll(blk, partnumber),
            "GPIO"         : self.generate_gpio,
            "MIPI_TX_LANE" : self.generate_mipi_tx,
            "MIPI_RX_LANE" : self.generate_mipi_rx,
            "JTAG"         : self.generate_jtag,
        }
        chunks = []
        for blk in self.blocks:
            if isinstance(blk, InterfaceWriterBlock):
                chunks.append(blk.generate())
                continue
            render = renderers.get(blk["type"])
            if render is not None:
                chunks.append(render(blk))
        return "".join(chunks)

    def footer(self):
        """Return the script trailer that generates the design (with bitstream) and saves it."""
        return "\n".join([
            "",
            "# Check design, generate constraints and reports",
            "design.generate(enable_bitstream=True)",
            "# Save the configured periphery design",
            "design.save()",
        ])


    def add_lvds_xml(self, root, params):
        """Append an LVDS block to the ``efxpt:lvds_info`` section of the periphery XML *root*.

        *params* supplies:
        - "name".
        - "mode": "OUTPUT" selects a TX block, anything else RX.
        - "location": the first entry is resolved to a pad name through
          ``self.platform.parser``.
        - "fast_clk", "slow_clk" and "serialisation".
        """
        lvds_info = root.find("efxpt:lvds_info", namespaces)
        ops_type, ltx_mode = ("tx", "out") if params["mode"] == "OUTPUT" else ("rx", "in")

        # Resolve the package pin to its pad, then drop the P/N polarity so the name
        # designates the pair (e.g. "...TXP..." / "...TXN..." -> "...TX...").
        pad = self.platform.parser.get_pad_name_from_pin(params["location"][0])
        for polarity_name, pair_name in (("TXP", "TX"), ("TXN", "TX"), ("RXP", "RX"), ("RXN", "RX")):
            pad = pad.replace(polarity_name, pair_name)
        # Some pad names carry an extra trailing "_<id>" identifier; strip it.
        # TODO: do a better parser
        if pad.count("_") == 2:
            pad = pad.rsplit("_", 1)[0]

        lvds = et.SubElement(lvds_info, "efxpt:lvds",
            name     = params["name"],
            lvds_def = pad,
            ops_type = ops_type,
        )

        # NOTE(review): an ltx_info child is written for RX blocks too — confirm whether
        # Efinity expects a different element when ops_type is "rx".
        et.SubElement(lvds, "efxpt:ltx_info",
            pll_instance    = "",
            fast_clock_name = f'{params["fast_clk"]}',
            slow_clock_name = f'{params["slow_clk"]}',
            reset_name      = "",
            out_bname       = f'{params["name"]}',
            oe_name         = "",
            clock_div       = "1",
            mode            = ltx_mode,
            serialization   = f'{params["serialisation"]}',
            reduced_swing   = "false",
            load            = "3",
        )

    def add_iobank_info_xml(self, root, iobank_info):
        """Set the "iostd" attribute of the listed I/O banks under device_info/iobank_info.

        *iobank_info* is iterated as (bank name, I/O standard) pairs. Banks absent
        from the XML are ignored, and when a bank is listed twice the last pair wins.
        """
        device    = root.find("efxpt:device_info", namespaces)
        bank_info = device.find("efxpt:iobank_info", namespaces)
        standards = dict(iobank_info)
        if not standards:
            return
        for bank in bank_info:
            bank_name = bank.get("name")
            if bank_name in standards:
                bank.set("iostd", standards[bank_name])