"""Declares :class:`Application`."""
import asyncio
import logging
import os
import random
import urllib.parse


import fastapi
import unimatrix.runtime
from fastapi import Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import PlainTextResponse
from fastapi.responses import JSONResponse
from ioc.exc import UnsatisfiedDependency
from starlette.middleware.trustedhost import TrustedHostMiddleware
from unimatrix.conf import settings
from unimatrix.ext import crypto
from unimatrix.ext import jose
from unimatrix.ext.model.exc import CanonicalException
from unimatrix.ext.model.exc import FeatureNotSupported

from .exceptions import UpstreamServiceNotAvailable
from .exceptions import UpstreamConnectionFailure
from .healthcheck import live as liveness_handler
from .healthcheck import ready as readyness_handler
from .metadata import APIMetadataService
from .models import APIMetadata
from .models import Identification
from .models import IDToken
from .trustedidentityproviders import TrustedIdentityProviders


class Application(fastapi.FastAPI):
    """Provides the ASGI interface to handle requests.

    Extends :class:`fastapi.FastAPI` with the defaults of the Unimatrix
    Framework:

    - exception handlers that render canonical error responses;
    - liveness/readiness health checks and an API metadata endpoint under
      ``/.well-known``;
    - trusted-host and CORS middleware configured from
      :mod:`unimatrix.conf.settings`;
    - an ``/.well-known/identify`` endpoint that verifies an OpenID ID token
      and signs an access token for this API.

    In addition to the keyword arguments accepted by
    :class:`fastapi.FastAPI`, the constructor accepts:

    - ``allowed_hosts``: hosts accepted by the trusted-host middleware;
      falls back to ``settings.HTTP_ALLOWED_HOSTS``.
    - ``audience`` / ``issuer``: claims used for signed access tokens; see
      :meth:`get_audience` and :meth:`get_issuer`.
    - ``openid``: the :class:`TrustedIdentityProviders` registry used to
      verify ID tokens; a new instance is created when omitted.
    - ``openid_providers``: mapping of provider URL to keyword arguments for
      :meth:`add_openid_provider`; defaults to ``settings.OPENID_PROVIDERS``.
    - ``enable_debug_endpoints``: when true, mount the ``/debug`` routes.
    """
    # Fallback for the CORS ``max_age`` when :meth:`enable_cors` receives
    # none.
    cors_max_age: int = 600
    logger: logging.Logger = logging.getLogger('unimatrix.ext.webapi')

    def __init__(self, *args, **kwargs):
        kwargs.setdefault('openid_providers',
            getattr(settings, 'OPENID_PROVIDERS', {})
        )
        kwargs.setdefault('redoc_url', getattr(settings, 'REDOC_URL', '/docs'))
        kwargs.setdefault('docs_url', getattr(settings, 'DOCS_URL', '/ui'))
        kwargs.setdefault('openapi_url',
            getattr(settings, 'OPENAPI_URL', '/openapi.json')
        )
        # FastAPI's ``root_path`` is a string; use an empty root path rather
        # than ``None`` when the environment variable is not set.
        kwargs.setdefault('root_path', os.getenv('HTTP_MOUNT_PATH', ''))

        # Configure the default exception handlers for the errors specified by
        # the Unimatrix Framework. Work on a copy so that a mapping supplied
        # by the caller is not mutated (and an explicit None is tolerated).
        # The framework handlers take precedence for these exception types.
        exception_handlers = dict(kwargs.get('exception_handlers') or {})
        exception_handlers.update({
            CanonicalException: self.canonical_exception,
            ConnectionError: self.canonical_exception,
            UnsatisfiedDependency: self.canonical_exception
        })
        kwargs['exception_handlers'] = exception_handlers

        # Remove the additional variables that we added to prevent them from
        # being passed to the fastapi.FastAPI.
        allowed_hosts = kwargs.pop('allowed_hosts', None)
        self.audience = kwargs.pop('audience', None)
        self.issuer = kwargs.pop('issuer', None)
        openid = kwargs.pop('openid', None)
        # Only build the default registry when none was supplied.
        self.openid = (
            openid if openid is not None else TrustedIdentityProviders()
        )

        # Check if debug endpoints are enabled
        enable_debug_endpoints = kwargs.pop('enable_debug_endpoints', False)

        # Add the trusted OpenID identity providers.
        openid_providers = kwargs.pop('openid_providers') or {}
        for url, params in openid_providers.items(): # pragma: no cover
            self.add_openid_provider(url, **params)

        super().__init__(*args, **kwargs)

        # Add standard health-check routes. The initial use case here was
        # Kubernetes.
        self.add_api_route(
            '/.well-known/health/live',
            liveness_handler,
            name='live',
            status_code=204,
            tags=['Health'],
            methods=['GET'],
            response_description = "The service is live.",
            responses={
                '503': {'description': "The service is not live."},
            }
        )
        self.add_api_route(
            '/.well-known/health/ready',
            readyness_handler,
            name='ready',
            tags=['Health'],
            methods=['GET'],
            status_code=204,
            response_description = "The service is ready.",
            responses={
                '503': {'description': "The service is not ready."},
            }
        )

        # Ensure that the Unimatrix startup and teardown functions are invoked
        # when spawning a new ASGI application.

        @self.on_event('startup')
        async def on_startup(): # pylint: disable=unused-variable
            await unimatrix.runtime.on('boot') # pragma: no cover
            await self.openid.on_setup() # pragma: no cover

        @self.on_event('shutdown')
        async def on_shutdown(): # pylint: disable=unused-variable
            await unimatrix.runtime.on('shutdown') # pragma: no cover

        # Add mandatory middleware to the application.
        # NOTE(review): when neither ``allowed_hosts`` nor
        # HTTP_ALLOWED_HOSTS is set, an empty list is passed; Starlette's
        # TrustedHostMiddleware is expected to reject every request in that
        # case. Confirm HTTP_ALLOWED_HOSTS is always configured.
        self.add_middleware(
            TrustedHostMiddleware,
            allowed_hosts=(
                allowed_hosts or getattr(settings, 'HTTP_ALLOWED_HOSTS', [])
            )
        )

        # Enable CORS based on the environment variables and/or settings
        # module.
        self.enable_cors(
            allow_origins=settings.HTTP_CORS_ALLOW_ORIGINS,
            allow_credentials=settings.HTTP_CORS_ALLOW_CREDENTIALS,
            allow_methods=settings.HTTP_CORS_ALLOW_METHODS,
            allow_headers=settings.HTTP_CORS_ALLOW_HEADERS,
            expose_headers=settings.HTTP_CORS_EXPOSE_HEADERS,
            max_age=settings.HTTP_CORS_TTL
        )

        # Add debug handlers if the debug endpoints are enabled.
        if enable_debug_endpoints:
            debug = fastapi.APIRouter()

            @debug.get('/sleep') # pragma: no cover
            async def sleep(seconds: float = None):
                """Sleep for `seconds`, or a random 0.0-0.5 seconds when
                omitted. For development purposes only.
                """
                # Compare with None so that an explicit 0 is honoured.
                if seconds is None:
                    seconds = random.randint(0, 5) / 10 # nosec
                await asyncio.sleep(seconds)

            @debug.post('/token', response_class=PlainTextResponse)
            async def create_bearer_token(dto: dict) -> str:
                """Create a JWT with the claims provided in the request body.
                For development purposes only.
                """
                jwt = await jose.jwt(dto, signer=crypto.get_signer())
                return bytes.decode(bytes(jwt))

            self.include_router(debug, prefix='/debug', tags=['Debug'])

        # Expose metadata describing this API, as assembled by
        # APIMetadataService.
        @self.get('/.well-known/self', tags=['Metadata'], response_model=APIMetadata)
        async def metadata(request: Request):
            svc = APIMetadataService()
            return APIMetadata(**await svc.get(self, request))

        # Add the .well-known/identify endpoint.
        @self.post('/.well-known/identify', tags=['OAuth 2.0'])
        async def identify(request: Request, dto: Identification):
            """Identify the request using an OpenID-compliant identity
            token. Issue an access token that is trusted by the other
            endpoints of this API.

            The identity must at least contain the following OpenID
            standard claims:

            - `iss`
            - `aud`
            - `sub`

            If any of the claims is missing or invalid, or the ID token is
            malformed, then the endpoint responds with a `422` status code.
            """
            claims = await self.openid.verify(dto.id_token)
            token = IDToken(**claims)
            return await token.sign(
                aud=self.get_audience(request),
                iss=self.get_issuer(request)
            )

    async def canonical_exception(self, request, exception):
        """Render `exception` as a JSON error response.

        Registered by the constructor as the handler for
        :class:`CanonicalException`, :class:`ConnectionError` and
        :class:`UnsatisfiedDependency`. Non-canonical exceptions are first
        translated to a canonical counterpart:

        - :class:`ConnectionRefusedError` becomes
          :class:`UpstreamServiceNotAvailable`;
        - :class:`UnsatisfiedDependency` becomes :class:`FeatureNotSupported`;
        - any other non-canonical :class:`ConnectionError` (broken pipe,
          reset, aborted, or a plain ``ConnectionError``) becomes
          :class:`UpstreamConnectionFailure`.

        The response uses the canonical exception's ``http_status_code`` and
        ``as_dict()`` body. Exceptions with a status code of 500 or higher
        are also logged through :attr:`logger`.

        Args:
            request: the request during which `exception` was raised.
            exception: the exception to render.

        Returns:
            A :class:`JSONResponse` describing the error.

        Raises:
            NotImplementedError: `exception` (after translation) is not a
                :class:`CanonicalException`.
        """
        if isinstance(exception, ConnectionRefusedError):
            exception = UpstreamServiceNotAvailable()
        elif isinstance(exception, UnsatisfiedDependency):
            exception = FeatureNotSupported()
        elif (isinstance(exception, ConnectionError)
                and not isinstance(exception, CanonicalException)):
            # The handler is registered for ConnectionError as a whole, so
            # every subclass, including a bare ConnectionError, must be
            # translated here instead of falling through to
            # NotImplementedError.
            exception = UpstreamConnectionFailure()

        if not isinstance(exception, CanonicalException):
            raise NotImplementedError
        if exception.http_status_code >= 500:
            exception.log(self.logger.exception)
        return JSONResponse(
            status_code=exception.http_status_code,
            content=exception.as_dict()
        )

    def enable_cors(self,
        allow_origins: list = None,
        allow_credentials: bool = False,
        allow_methods: list = None,
        allow_headers: list = None,
        expose_headers: list = None,
        max_age: int = None
    ):
        """Enables and configures Cross-Origin Resource Sharing (CORS).

        List arguments that are ``None`` or empty are passed to
        :class:`CORSMiddleware` as empty lists. `max_age` falls back to
        :attr:`cors_max_age` only when it is ``None``, so an explicit ``0``
        is passed through unchanged.
        """
        self.add_middleware(
            CORSMiddleware,
            allow_origins=allow_origins or [],
            allow_credentials=allow_credentials,
            allow_methods=allow_methods or [],
            allow_headers=allow_headers or [],
            expose_headers=expose_headers or [],
            max_age=self.cors_max_age if max_age is None else max_age
        )

    def add_openid_provider(self,
        url: str,
        audience: str,
        issuer: str = None,
        tags: list = None
    ): # pragma: no cover
        """Add a trusted OpenID identity provider.

        Args:
            url: location of the provider, forwarded to ``self.openid.add``.
            audience: expected audience, forwarded to ``self.openid.add``.
            issuer: expected issuer; defaults to ``scheme://netloc`` of
                `url`.
            tags: additional tags. ``'openid'`` is always prepended. The
                caller's list is not modified.
        """
        if issuer is None:
            p = urllib.parse.urlparse(url)
            issuer = f'{p.scheme}://{p.netloc}'
        # Build a new list instead of inserting into the caller's list,
        # which would mutate it and accumulate 'openid' on repeated use.
        tags = ['openid', *(tags or [])]
        self.openid.add(url, audience, issuer, tags)

    def get_audience(self, request: Request) -> str:
        """Return the audience used for JWS access tokens.

        Uses the configured audience if set; otherwise falls back to the
        ``scheme://netloc`` of the request URL.
        """
        return self.audience or f'{request.url.scheme}://{request.url.netloc}'

    def get_issuer(self, request: Request) -> str:
        """Return the issuer used for JWS access tokens.

        Uses the configured issuer if set; otherwise falls back to the
        ``scheme://netloc`` of the request URL.
        """
        return self.issuer or f'{request.url.scheme}://{request.url.netloc}'
