#!/usr/bin/env python3

# SPDX-FileCopyrightText: © 2022 Decompollaborate
# SPDX-License-Identifier: MIT

from __future__ import annotations

import sys
from typing import TextIO
from pathlib import Path

from .. import common

from . import symbols


class FileBase(common.ElementBase):
    """Base class for one section of a disassembled file.

    Holds the symbols found in the section (``symbolList``) and provides the
    shared helpers for hashing, comparing, reporting analysis results and
    writing the disassembly (and optionally the raw words) to disk.
    """

    def __init__(self, context: common.Context, vromStart: int, vromEnd: int, vram: int, filename: str, words: list[int], sectionType: common.FileSectionType, segmentVromStart: int, overlayCategory: str|None):
        # The literal 0 fills ElementBase's 4th positional argument
        # (presumably inFileOffset -- TODO confirm against ElementBase).
        super().__init__(context, vromStart, vromEnd, 0, vram, filename, words, sectionType, segmentVromStart, overlayCategory)

        self.symbolList: list[symbols.SymbolBase] = []

        self.pointersOffsets: set[int] = set()

        self.isHandwritten: bool = False

        # Offsets where a new file boundary was detected; reported by printNewFileBoundaries.
        self.fileBoundaries: list[int] = []

        # Addresses of symbols in this section.
        self.symbolsVRams: set[int] = set()


    def setCommentOffset(self, commentOffset: int):
        """Set the comment offset on this file and propagate it to every symbol."""
        self.commentOffset = commentOffset
        for sym in self.symbolList:
            sym.setCommentOffset(self.commentOffset)

    def getAsmPrelude(self) -> str:
        """Return the assembler prelude placed at the top of a generated `.s` file.

        It includes ``macro.inc``, sets the assembler directives, and opens this
        file's section with a 16-byte alignment. Every line, including the blank
        ones, is terminated by ``GlobalConfig.LINE_ENDS``.
        """
        lines = [
            ".include \"macro.inc\"",
            "",
            "# assembler directives",
            ".set noat      # allow manual use of $at",
            ".set noreorder # don't insert nops after branches",
            ".set gp=64     # allow use of 64-bit general purpose registers",
            "",
            f".section {self.sectionType.toSectionName()}",
            "",
            ".balign 16",
        ]
        lineEnd = common.GlobalConfig.LINE_ENDS
        return "".join(line + lineEnd for line in lines)

    def _wordsAsBytes(self) -> bytearray:
        """Serialize ``self.words`` into a new bytearray (4 bytes per word) via common.Utils.wordsToBytes."""
        buffer = bytearray(4 * len(self.words))
        common.Utils.wordsToBytes(self.words, buffer)
        return buffer

    def getHash(self) -> str:
        """Return the common.Utils.getStrHash digest of this file's raw words."""
        return common.Utils.getStrHash(self._wordsAsBytes())


    def checkAndCreateFirstSymbol(self) -> None:
        "Check if the very start of the file has a symbol and create it if it doesn't exist yet"

        currentVram = self.getVramOffset(0)
        vrom = self.getVromOffset(0)
        # Only an exact match at the start counts, hence tryPlusOffset=False.
        if self.getSymbol(currentVram, tryPlusOffset=False) is None:
            self.addSymbol(currentVram, sectionType=self.sectionType, isAutogenerated=True, symbolVrom=vrom)


    def printNewFileBoundaries(self):
        """Print a table describing the file boundaries detected in this section.

        Does nothing unless ``GlobalConfig.PRINT_NEW_FILE_BOUNDARIES`` is set and
        at least one boundary was recorded. Each row shows the boundary start
        (shifted by ``commentOffset``), its size, its vram (0 when the file has
        no vram) and how many symbols start inside it.
        """
        if not common.GlobalConfig.PRINT_NEW_FILE_BOUNDARIES:
            return
        if not self.fileBoundaries:
            return

        print(f"File {self.name}")
        print(f"Section: {self.sectionType.toStr()}")
        print(f"Found {len(self.symbolList)} symbols.")
        print(f"Found {len(self.fileBoundaries)} file boundaries.")

        print("\t offset, size,     vram, symbols")

        # The end of the section closes the last boundary.
        # NOTE(review): this end value includes inFileOffset while symbol offsets
        # below are made relative to it; confirm both use the same base.
        boundaries = [*self.fileBoundaries, self.sizew*4 + self.inFileOffset]

        for start, end in zip(boundaries, boundaries[1:]):
            symbolsInBoundary = sum(
                1 for sym in self.symbolList
                if start <= sym.inFileOffset - self.inFileOffset < end
            )
            fileVram = 0 if self.vram is None else start + self.vram
            print(f"\t {start+self.commentOffset:06X}, {end-start:04X}, {fileVram:08X}, {symbolsInBoundary:7}")

        print()

    def printAnalyzisResults(self):
        """Print the results of analysing this file (currently only the file boundaries)."""
        self.printNewFileBoundaries()


    def compareToFile(self, other_file: FileBase) -> dict:
        """Compare this file's words against ``other_file``.

        Returns a dict with:
          - ``equal``: whether both files have the same hash,
          - ``hash_one`` / ``hash_two``: the hash of each file,
          - ``size_one`` / ``size_two``: the size of each file in bytes,
          - ``diff_bytes`` / ``diff_words``: how many bytes / words differ within
            the common prefix of both files (both 0 when the hashes match).
            Words past the end of the shorter file are not counted.
        """
        hash_one = self.getHash()
        hash_two = other_file.getHash()
        equal = hash_one == hash_two

        diff_bytes = 0
        diff_words = 0

        if not equal:
            for i in range(min(self.sizew, other_file.sizew)):
                # Set bits in `xored` are exactly the bits that differ between both words.
                xored = self.words[i] ^ other_file.words[i]
                if xored == 0:
                    continue
                diff_words += 1
                # Only the low 4 bytes of each word are inspected.
                diff_bytes += sum(1 for j in range(4) if (xored >> (j * 8)) & 0xFF)

        return {
            "equal": equal,
            "hash_one": hash_one,
            "hash_two": hash_two,
            "size_one": self.sizew * 4,
            "size_two": other_file.sizew * 4,
            "diff_bytes": diff_bytes,
            "diff_words": diff_words,
        }

    def blankOutDifferences(self, other: FileBase) -> bool:
        """Hook for blanking out differences against ``other``; returns whether anything changed.

        Subclasses may override it. The base implementation changes nothing and
        returns False, regardless of ``GlobalConfig.REMOVE_POINTERS``.
        """
        return False

    def removePointers(self) -> bool:
        """Hook for removing pointers from this file; returns whether anything changed.

        Subclasses may override it. The base implementation changes nothing and
        returns False, regardless of ``GlobalConfig.REMOVE_POINTERS``.
        """
        return False


    def disassemble(self) -> str:
        """Return every symbol's disassembly, in order, separated by ``GlobalConfig.LINE_ENDS``.

        No separator is appended after the last symbol.
        """
        return common.GlobalConfig.LINE_ENDS.join(sym.disassemble() for sym in self.symbolList)

    def disassembleToFile(self, f: TextIO):
        """Write the disassembly to ``f``, preceded by the prelude when ``GlobalConfig.ASM_USE_PRELUDE`` is set."""
        if common.GlobalConfig.ASM_USE_PRELUDE:
            f.write(self.getAsmPrelude())
            f.write(common.GlobalConfig.LINE_ENDS)
        f.write(self.disassemble())


    def saveToFile(self, filepath: str):
        """Write this file's disassembly.

        Does nothing when the file has no symbols. When ``filepath`` is ``"-"``
        the disassembly is written to stdout. Otherwise it goes to
        ``filepath + sectionType.toStr() + ".s"``. In that case, if
        ``GlobalConfig.WRITE_BINARY`` is set and the file has words, the raw
        words are also written to ``filepath + sectionType.toStr()``.
        """
        if not self.symbolList:
            return

        if filepath == "-":
            self.disassembleToFile(sys.stdout)
            return

        basePath = filepath + self.sectionType.toStr()
        if common.GlobalConfig.WRITE_BINARY and self.sizew > 0:
            common.Utils.writeBytearrayToFile(Path(basePath), self._wordsAsBytes())
        with open(basePath + ".s", "w") as f:
            self.disassembleToFile(f)


def createEmptyFile() -> FileBase:
    """Build a placeholder FileBase.

    The placeholder has a fresh Context, no words, zeroed vrom/vram values,
    an empty filename, an Unknown section type and no overlay category.
    """
    return FileBase(
        context=common.Context(),
        vromStart=0,
        vromEnd=0,
        vram=0,
        filename="",
        words=[],
        sectionType=common.FileSectionType.Unknown,
        segmentVromStart=0,
        overlayCategory=None,
    )
