#!/usr/bin/env python3

# SPDX-FileCopyrightText: © 2022 Decompollaborate
# SPDX-License-Identifier: MIT

from __future__ import annotations

from typing import Callable

from ... import common


class SymbolBase(common.ElementBase):
    """A single symbol disassembled out of a section.

    Each instance is constructed from a run of `words` and is bound to a
    `common.ContextSymbol` (`self.contextSym`) holding its name, type and
    other metadata. `renameBasedOnType`, `isRdata`, `countExtraPadding`,
    `getPrevAlignDirective` and `getPostAlignDirective` have trivial default
    implementations here that subclasses may override.
    """

    def __init__(self, context: common.Context, vromStart: int, vromEnd: int, inFileOffset: int, vram: int, words: list[int], sectionType: common.FileSectionType, segmentVromStart: int, overlayCategory: str|None):
        super().__init__(context, vromStart, vromEnd, inFileOffset, vram, "", words, sectionType, segmentVromStart, overlayCategory)

        # Extra text appended to the first directive line emitted for word `i`.
        self.endOfLineComment: list[str] = []
        self.contextSym: common.ContextSymbol

        # Symbols coming from relocatable elf files are pre-registered by offset;
        # prefer them over creating an autogenerated symbol at `vram`.
        # NOTE(review): this lookup is keyed by `vromStart`, while
        # `getSymbolAtVramOrOffset` / `getLabel` look offset symbols up by
        # `inFileOffset` — confirm both keys agree for relocatable inputs.
        offsetSym = context.offsetSymbols[self.sectionType].get(self.vromStart, None)
        if offsetSym is not None:
            self.contextSym = offsetSym
        else:
            self.contextSym = self.addSymbol(self.vram, sectionType=self.sectionType, isAutogenerated=True)
        self.contextSym.vromAddress = self.vromStart
        self.contextSym.isDefined = True
        self.contextSym.sectionType = self.sectionType


    def getName(self) -> str:
        """Return the name of the backing context symbol."""
        return self.contextSym.getName()

    def setNameIfUnset(self, name: str) -> None:
        """Forward to `ContextSymbol.setNameIfUnset` on the backing symbol."""
        self.contextSym.setNameIfUnset(name)

    def setNameGetCallback(self, callback: Callable[[common.ContextSymbol], str]) -> None:
        """Forward to `ContextSymbol.setNameGetCallback` on the backing symbol."""
        self.contextSym.setNameGetCallback(callback)

    def setNameGetCallbackIfUnset(self, callback: Callable[[common.ContextSymbol], str]) -> None:
        """Forward to `ContextSymbol.setNameGetCallbackIfUnset` on the backing symbol."""
        self.contextSym.setNameGetCallbackIfUnset(callback)


    def generateAsmLineComment(self, localOffset: int, wordValue: int|None = None) -> str:
        """Build the `/* OFFSET VRAM [WORD ]*/` comment for a directive line.

        Args:
            localOffset: byte offset from the start of this symbol.
            wordValue: if given, also printed (converted by
                `Utils.wordToCurrenEndian`) as 8 hex digits.

        Returns:
            The comment text, or "" when `GlobalConfig.ASM_COMMENT` is disabled.
        """
        if not common.GlobalConfig.ASM_COMMENT:
            return ""

        fileOffset = localOffset + self.inFileOffset + self.commentOffset
        offsetHex = f"{fileOffset:0{common.GlobalConfig.ASM_COMMENT_OFFSET_WIDTH}X}"

        vramHex = f"{self.getVramOffset(localOffset):08X}"

        wordValueHex = ""
        if wordValue is not None:
            wordValueHex = f"{common.Utils.wordToCurrenEndian(wordValue):08X} "

        return f"/* {offsetHex} {vramHex} {wordValueHex}*/"

    def getSymbolAtVramOrOffset(self, localOffset: int) -> common.ContextSymbol|None:
        """Find a symbol placed `localOffset` bytes into this symbol.

        An offset-keyed symbol (looked up by `inFileOffset + localOffset`) takes
        precedence; otherwise the symbol at exactly that vram is returned (no
        symbol+addend matching). Returns None if neither exists.
        """
        contextSym = self.context.getOffsetSymbol(self.inFileOffset + localOffset, self.sectionType)
        if contextSym is not None:
            return contextSym

        currentVram = self.getVramOffset(localOffset)
        return self.getSymbol(currentVram, tryPlusOffset=False)

    def getLabel(self) -> str:
        """Return the label text emitted at the start of this symbol."""
        if self.contextSym is not None:
            return self.getLabelFromSymbol(self.contextSym)

        # Defensive fallback; `__init__` always assigns `contextSym`.
        offsetSym = self.context.getOffsetSymbol(self.inFileOffset, self.sectionType)
        return self.getLabelFromSymbol(offsetSym)

    def getLabelAtOffset(self, localOffset: int) -> str:
        """Return the label text for a symbol found inside this one, if any.

        Returns "" when there is no symbol at `localOffset`. Otherwise returns a
        line break followed by the symbol's label line (and a `name:` line when
        `GlobalConfig.ASM_DATA_SYM_AS_LABEL` is set). If the symbol's label is
        empty, only the line break is returned.
        """
        contextSym = self.getSymbolAtVramOrOffset(localOffset)
        if contextSym is None:
            return ""

        # Possible symbols in the middle
        label = common.GlobalConfig.LINE_ENDS
        symLabel = contextSym.getSymbolLabel()
        if symLabel:
            label += symLabel + common.GlobalConfig.LINE_ENDS
            if common.GlobalConfig.ASM_DATA_SYM_AS_LABEL:
                label += f"{contextSym.getName()}:" + common.GlobalConfig.LINE_ENDS
        return label


    def isRdata(self) -> bool:
        """Checks if the current symbol is .rdata"""
        return False


    def renameBasedOnType(self) -> None:
        """Hook to rename the symbol from its type; the base class does nothing."""
        pass


    def analyze(self) -> None:
        """Register the symbols that live inside this symbol's range.

        After `renameBasedOnType`, every non-zero offset inside the symbol is
        checked (stepping by the symbol's element size: 1 for bytes, 2 for
        shorts, otherwise 4). Each symbol found there is marked as defined in
        this section at the matching vrom address and, if it has no type yet,
        inherits the type of this symbol. Nothing is registered for .bss.
        """
        self.renameBasedOnType()

        byteStep = 4
        if self.contextSym.isByte():
            byteStep = 1
        elif self.contextSym.isShort():
            byteStep = 2

        if self.sectionType == common.FileSectionType.Bss:
            return

        # byteStep always divides 4, so this visits every element boundary of
        # every word; offset 0 is this symbol itself and is skipped.
        for localOffset in range(byteStep, 4 * self.sizew, byteStep):
            contextSym = self.getSymbolAtVramOrOffset(localOffset)
            if contextSym is None:
                continue
            contextSym.vromAddress = self.getVromOffset(localOffset)
            contextSym.isDefined = True
            contextSym.sectionType = self.sectionType
            if contextSym.hasNoType():
                # Propagate the owning symbol's type (previously a no-op
                # self-assignment `contextSym.type = contextSym.type`).
                contextSym.type = self.contextSym.type


    def _getNthWordAsSubwords(self, i: int, subwordSize: int, dotType: str) -> tuple[str, int]:
        """Render word `i` as `4 // subwordSize` directives of `subwordSize` bytes.

        Sub-values are taken in file order, honouring `GlobalConfig.ENDIAN`.
        Intermediate labels are emitted before every sub-value except the very
        first one of the symbol, and `endOfLineComment[i]` (if present) goes on
        the first line only.
        """
        output = ""
        localOffset = 4*i
        w = self.words[i]

        bitWidth = subwordSize * 8
        mask = (1 << bitWidth) - 1
        hexDigits = subwordSize * 2

        for j in range(0, 4, subwordSize):
            label = ""
            if j != 0 or i != 0:
                label = self.getLabelAtOffset(localOffset + j)

            shiftValue = j * 8
            if common.GlobalConfig.ENDIAN == common.InputEndian.BIG:
                # On big endian the first byte in the file is the most significant one.
                shiftValue = (32 - bitWidth) - shiftValue
            subVal = (w & (mask << shiftValue)) >> shiftValue
            value = f"0x{subVal:0{hexDigits}X}"

            comment = self.generateAsmLineComment(localOffset+j)
            output += f"{label}{comment} {dotType} {value}"
            if j == 0 and i < len(self.endOfLineComment):
                output += self.endOfLineComment[i]
            output += common.GlobalConfig.LINE_ENDS

        return output, 0

    def getNthWordAsBytes(self, i: int) -> tuple[str, int]:
        """Render word `i` as four `.byte` directives; returns (text, words to skip)."""
        return self._getNthWordAsSubwords(i, 1, ".byte")

    def getNthWordAsShorts(self, i: int) -> tuple[str, int]:
        """Render word `i` as two `.short` directives; returns (text, words to skip)."""
        return self._getNthWordAsSubwords(i, 2, ".short")

    def getNthWordAsWords(self, i: int, canReferenceSymbolsWithAddends: bool=False, canReferenceConstants: bool=False) -> tuple[str, int]:
        """Render word `i` as a single `.word` directive.

        The value printed is chosen as follows:
          * if the context has relocation symbols for this section, the
            relocation at this offset (if any) names the value; without one the
            raw hex value is kept and no other lookup is attempted;
          * otherwise a symbol reference to the word's value (symbol+addend only
            when `canReferenceSymbolsWithAddends`);
          * otherwise a named constant, when `canReferenceConstants`;
          * otherwise the raw value as `0xXXXXXXXX`.

        Returns:
            (text, number of extra words consumed) — always 0 extra here.
        """
        output = ""
        localOffset = 4*i
        w = self.words[i]

        dotType = ".word"

        label = ""
        if i != 0:
            label = self.getLabelAtOffset(localOffset)

        value = f"0x{w:08X}"

        if self.context.relocSymbols[self.sectionType]:
            # .elf relocated symbol
            possibleReference = self.context.getRelocSymbol(self.inFileOffset + localOffset, self.sectionType)
            if possibleReference is not None:
                value = possibleReference.getNamePlusOffset(w)
        else:
            # This word could be a reference to a symbol
            symbolRef = self.getSymbol(w, tryPlusOffset=canReferenceSymbolsWithAddends)
            if symbolRef is not None:
                value = symbolRef.getSymbolPlusOffset(w)
            elif canReferenceConstants:
                constant = self.getConstant(w)
                if constant is not None:
                    value = constant.getName()

        comment = self.generateAsmLineComment(localOffset)
        output += f"{label}{comment} {dotType} {value}"
        if i < len(self.endOfLineComment):
            output += self.endOfLineComment[i]
        output += common.GlobalConfig.LINE_ENDS

        return output, 0

    def getNthWord(self, i: int, canReferenceSymbolsWithAddends: bool=False, canReferenceConstants: bool=False) -> tuple[str, int]:
        """Render word `i` using the directive matching the symbol's element type.

        Returns (text, number of following words already consumed by this call).
        """
        if self.contextSym.isByte():
            return self.getNthWordAsBytes(i)
        if self.contextSym.isShort():
            return self.getNthWordAsShorts(i)
        return self.getNthWordAsWords(i, canReferenceSymbolsWithAddends=canReferenceSymbolsWithAddends, canReferenceConstants=canReferenceConstants)


    def countExtraPadding(self) -> int:
        """Returns how many extra word paddings this symbol has"""
        return 0

    def getPrevAlignDirective(self, i: int=0) -> str:
        """Directive emitted before word `i`; the base class emits nothing."""
        return ""

    def getPostAlignDirective(self, i: int=0) -> str:
        """Directive emitted after word `i`; the base class emits nothing."""
        return ""

    def disassembleAsData(self) -> str:
        """Render the whole symbol as assembler data directives.

        Emits the align directive for word 0 and the symbol label (plus a
        `name:` line when `GlobalConfig.ASM_DATA_SYM_AS_LABEL` is set), then each
        word via `getNthWord`, wrapped in the per-word align hooks. Words that
        `getNthWord` reports as already consumed are skipped.
        """
        parts: list[str] = [self.getPrevAlignDirective(0), self.getLabel()]
        if common.GlobalConfig.ASM_DATA_SYM_AS_LABEL:
            parts.append(f"{self.getName()}:" + common.GlobalConfig.LINE_ENDS)

        canReferenceSymbolsWithAddends = self.canUseAddendsOnData()
        canReferenceConstants = self.canUseConstantsOnData()

        i = 0
        while i < self.sizew:
            data, skip = self.getNthWord(i, canReferenceSymbolsWithAddends, canReferenceConstants)
            # The align directive for word 0 was already emitted before the label.
            if i != 0:
                parts.append(self.getPrevAlignDirective(i))
            parts.append(data)
            parts.append(self.getPostAlignDirective(i))
            i += skip + 1
        return "".join(parts)

    def disassemble(self) -> str:
        """Render this symbol; the base class always renders it as data."""
        return self.disassembleAsData()
