import struct
from typing import cast
from typing import Any
from typing import Callable
from typing import ClassVar
from typing import Literal
from typing import TypeVar
from typing import Union
from typing import TYPE_CHECKING

import pydantic
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives.asymmetric.ec import generate_private_key
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicKey
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePublicNumbers
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurve
from cryptography.hazmat.primitives.asymmetric.ec import ECDSA
from cryptography.hazmat.primitives.asymmetric.ec import ECDH
from cryptography.hazmat.primitives.asymmetric.ec import SECP256K1
from cryptography.hazmat.primitives.asymmetric.ec import SECP256R1
from cryptography.hazmat.primitives.asymmetric.ec import SECP384R1
from cryptography.hazmat.primitives.asymmetric.ec import SECP521R1
from cryptography.hazmat.primitives.asymmetric.utils import encode_dss_signature
from cryptography.hazmat.primitives.hashes import SHA256
from cryptography.hazmat.primitives.kdf.concatkdf import ConcatKDFHash
from libcanonical.types import AwaitableBool
from libcanonical.types import AwaitableBytes
from libcanonical.types import Base64
from libcanonical.utils.encoding import b64decode
from libcanonical.utils.encoding import b64decode_int
from libcanonical.utils.encoding import b64encode_int
from libcanonical.utils.encoding import bytes_to_number

from aegisx.ext.jose.types import EncryptionResult
from aegisx.ext.jose.types import JSONWebAlgorithm
from ._jsonwebkeybase import JSONWebKeyBase
from ._symmetricencryptionkey import SymmetricEncryptionKey
if TYPE_CHECKING:
    from ._jsonwebkey import JSONWebKey
    from ._jsonwebkeyellipticcurveprivate import JSONWebKeyEllipticCurvePrivate

# Generic result type of the callback passed to ``exchange()``.
R = TypeVar("R")


class JSONWebKeyEllipticCurvePublic(
    JSONWebKeyBase[
        Literal['EC'],
        Literal['sign', 'verify', 'encrypt', 'decrypt', 'wrapKey', 'unwrapKey']
    ]
):
    model_config = {'extra': 'forbid'}
    thumbprint_claims = ["crv", "kty", "x", "y"]
    curves: ClassVar[dict[str, type[EllipticCurve]]] = {
        'P-256': SECP256R1,
        'P-256K': SECP256K1,
        'P-384': SECP384R1,
        'P-521': SECP521R1,
    }

    crv: Literal['P-256', 'P-256K', 'P-384', 'P-521'] = pydantic.Field(
        default=...,
        title="Curve",
        description=(
            "The `crv` (curve) parameter identifies the "
            "cryptographic curve used with the key."
        )
    )

    x: str = pydantic.Field(
        default=...,
        title="X coordinate",
        description=(
            "The `x` (x coordinate) parameter contains the x "
            "coordinate for the Elliptic Curve point. It is "
            "represented as the base64url encoding of the octet "
            "string representation of the coordinate, as defined "
            "in Section 2.3.5 of SEC1. The length of this octet "
            "string MUST be the full size of a coordinate for "
            "the curve specified in the `crv` parameter. For "
            "example, if the value of `crv` is `P-521`, the octet "
            "string must be 66 octets long."
        )
    )

    y: str = pydantic.Field(
        default=...,
        title="Y coordinate",
        description=(
            "The `y` (y coordinate) parameter contains the y "
            "coordinate for the Elliptic Curve point. It is "
            "represented as the base64url encoding of the octet "
            "string representation of the coordinate, as defined "
            "in Section 2.3.5 of SEC1. The length of this octet "
            "string MUST be the full size of a coordinate for "
            "the curve specified in the `crv` parameter. For "
            "example, if the value of `crv` is `P-521`, the "
            "octet string must be 66 octets long."
        )
    )

    @property
    def public_numbers(self) -> EllipticCurvePublicNumbers:
        return EllipticCurvePublicNumbers(
            curve=self.get_curve(self.crv),
            x=b64decode_int(self.x),
            y=b64decode_int(self.y)
        )

    @property
    def public_key(self):
        return self.public_numbers.public_key()

    @classmethod
    def get_curve(cls, crv: str):
        return cls.curves[crv]()

    @classmethod
    def supports_algorithm(cls, alg: JSONWebAlgorithm) -> bool:
        return alg.config.kty == 'EC'

    def epk(self) -> tuple['JSONWebKeyEllipticCurvePublic', EllipticCurvePrivateKey]:
        private = generate_private_key(self.public_numbers.curve)
        n = private.private_numbers()
        return (
            JSONWebKeyEllipticCurvePublic.model_validate({
                **self.model_dump(include={'kty', 'crv'}),
                'x': b64encode_int(n.public_numbers.x),
                'y': b64encode_int(n.public_numbers.y)
            }),
            private
        )

    def exchange(self, f: Callable[[Any], R]) -> R:
        return f(self.public_key)

    def get_public_key(self):
        return JSONWebKeyEllipticCurvePublic.model_validate(self.model_dump())

    def is_asymmetric(self) -> bool:
        return True

    def derive(
        self,
        alg: JSONWebAlgorithm,
        enc: JSONWebAlgorithm,
        private: Union['JSONWebKey', EllipticCurvePrivateKey],
        public: Union['JSONWebKey', EllipticCurvePublicKey],
        apu: bytes,
        apv: bytes,
        ct: EncryptionResult | None = None
    ) -> bytes:
        if not isinstance(private, EllipticCurvePrivateKey):
            private = cast('JSONWebKeyEllipticCurvePrivate', private.root).private_key
        if not isinstance(public, EllipticCurvePublicKey):
            public = cast('JSONWebKeyEllipticCurvePublic', public.root).public_key
        length = enc.length if alg.is_direct() else alg.length
        cipher = enc.cipher if alg.is_direct() else alg.cipher
        if not length:
            raise ValueError(
                f"Unable to determine key size from management algorithm {alg} "
                f"and content encryption algorithm {enc}."
            )
        # The AlgorithmID value is of the form Datalen || Data, where Data is a
        # variable-length string of zero or more octets, and Datalen is a fixed-length,
        # big-endian 32-bit counter that indicates the length (in octets) of Data.
        # In the Direct Key Agreement case, Data is set to the octets of the ASCII
        # representation of the "enc" Header Parameter value.  In the Key Agreement
        # with Key Wrapping case, Data is set to the octets of the ASCII representation
        # of the "alg" (algorithm) Header Parameter value.
        algorithm_id = enc if alg.is_direct() else alg
        otherinfo = struct.pack('>I', len(algorithm_id))
        otherinfo += str.encode(algorithm_id, 'utf-8')

        # PartyUInfo
        apu = b64decode(apu) if apu else b''
        otherinfo += struct.pack('>I', len(apu))
        otherinfo += apu

        # PartyVInfo
        apv = b64decode(apv) if apv else b''
        otherinfo += struct.pack('>I', len(apv))
        otherinfo += apv

        # SuppPubInfo
        otherinfo += struct.pack('>I', length)

        # Shared Key generation
        if isinstance(private, EllipticCurvePrivateKey): # type: ignore
            shared_key = private.exchange(ECDH(), public)
        else:
            # X25519/X448
            raise NotImplementedError

        # TODO: abstract this
        keysize = length // 8
        if cipher == 'AES+CBC':
            # In CBC mode, the derived key must be twice the length
            # of the algorithm as the first half is used as the MAC
            # key.
            keysize *= 2

        # RFC 7518 4.6.2: Key derivation is performed using the Concat KDF,
        # as defined in Section 5.8.1 of [NIST.800-56A], where the Digest
        # Method is SHA-256.
        ckdf = ConcatKDFHash(algorithm=SHA256(),
            length=keysize,
            otherinfo=otherinfo,
        )
        k = AwaitableBytes(ckdf.derive(shared_key))
        return k

    def derive_cek(
        self,
        alg: JSONWebAlgorithm,
        enc: JSONWebAlgorithm,
        private: Union['JSONWebKey', EllipticCurvePrivateKey],
        public: Union['JSONWebKey', EllipticCurvePublicKey],
        apu: bytes,
        apv: bytes,
        ct: EncryptionResult | None = None
    ) -> SymmetricEncryptionKey:
        # TODO: ugly, refactor
        if not isinstance(public, EllipticCurvePublicKey):
            if isinstance(public.root, JSONWebKeyEllipticCurvePublic):
                public = public.root.public_key
        if not isinstance(public, EllipticCurvePublicKey):
            raise TypeError(
                "A key can only be derived using an elliptic curve "
                "public key."
            )
        shared = self.derive(alg, enc, private, public, apu, apv, ct)
        return SymmetricEncryptionKey(
            alg=alg.wrap if not alg.is_direct() else enc,
            kty='oct',
            k=Base64(shared)
        )

    def encrypt(
        self,
        pt: bytes,
        aad: bytes | None,
        alg: JSONWebAlgorithm
    ) -> EncryptionResult:
        if not alg.length:
            raise ValueError(f"Algorithm {alg} does not specify key size.")
        if not alg.wrap:
            raise ValueError(f"Algorithm {alg} does not specify a wrapping algorithm.")
        public, private = self.epk()
        shared = self.derive_cek(alg, alg.wrap, private, self.public_key, b'', b'', None)
        result: EncryptionResult
        match alg.mode:
            case 'KEY_AGREEMENT_WITH_KEY_WRAPPING':
                result = EncryptionResult.model_validate({
                    'alg': alg,
                    'ct': bytes(shared.encrypt(pt)),
                    'epk': public.model_dump()
                })
            case _:
                raise NotImplementedError(f"Unsupported algorithm: {alg}")
        return result

    def verify(
        self,
        signature: bytes,
        message: bytes
    ) -> AwaitableBool:
        assert self.alg is not None
        n = (self.public_key.curve.key_size + 7) // 8
        try:
            self.public_key.verify(
                signature=encode_dss_signature(
                    bytes_to_number(signature[:n]),
                    bytes_to_number(signature[n:]),
                ),
                data=message,
                signature_algorithm=ECDSA(self.get_hash(self.alg))
            )
            return AwaitableBool(True)
        except InvalidSignature:
            return AwaitableBool(False)