#
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
#


import logging
from typing import Any, List, Mapping, MutableMapping, Optional, Tuple

import backoff
import pendulum
import requests
from deprecated import deprecated

from ..exceptions import DefaultBackoffException
from .core import HttpAuthenticator

logger = logging.getLogger("airbyte")


@deprecated(version="0.1.20", reason="Use airbyte_cdk.sources.streams.http.requests_native_auth.Oauth2Authenticator instead")
class Oauth2Authenticator(HttpAuthenticator):
    """
    Generates OAuth2.0 access tokens from an OAuth2.0 refresh token and client credentials.
    The generated access token is attached to each request via the Authorization header.

    The access token is cached on the instance and only refreshed once
    `token_has_expired()` reports that the cached expiry date has passed.
    """

    def __init__(
        self,
        token_refresh_endpoint: str,
        client_id: str,
        client_secret: str,
        refresh_token: str,
        scopes: Optional[List[str]] = None,
        refresh_access_token_headers: Optional[Mapping[str, Any]] = None,
        refresh_access_token_authenticator: Optional[HttpAuthenticator] = None,
    ):
        """
        :param token_refresh_endpoint: URL that the refresh request is POSTed to.
        :param client_id: sent as `client_id` in the refresh request body.
        :param client_secret: sent as `client_secret` in the refresh request body.
        :param refresh_token: sent as `refresh_token` in the refresh request body.
        :param scopes: optional list sent as `scopes` in the refresh request body when non-empty.
        :param refresh_access_token_headers: optional extra headers for the refresh request.
        :param refresh_access_token_authenticator: optional authenticator whose auth header is
            merged into the refresh request headers.
        """
        self.token_refresh_endpoint = token_refresh_endpoint
        self.client_secret = client_secret
        self.client_id = client_id
        self.refresh_token = refresh_token
        self.scopes = scopes
        self.refresh_access_token_headers = refresh_access_token_headers
        self.refresh_access_token_authenticator = refresh_access_token_authenticator

        # Start with an expiry date in the past so that the first call to
        # get_access_token() always performs a refresh.
        self._token_expiry_date = pendulum.now().subtract(days=1)
        self._access_token = None

    def get_auth_header(self) -> Mapping[str, Any]:
        """Return the Authorization header carrying a (possibly refreshed) bearer token."""
        return {"Authorization": f"Bearer {self.get_access_token()}"}

    def get_access_token(self) -> str:
        """Return the cached access token, refreshing it first if it has expired."""
        if self.token_has_expired():
            # Capture the time *before* the request so the computed expiry date
            # is not pushed later by the duration of the refresh call.
            t0 = pendulum.now()
            token, expires_in = self.refresh_access_token()
            self._access_token = token
            self._token_expiry_date = t0.add(seconds=expires_in)

        return self._access_token

    def token_has_expired(self) -> bool:
        """Return True when the current time is past the cached token expiry date."""
        return pendulum.now() > self._token_expiry_date

    def get_refresh_request_body(self) -> Mapping[str, Any]:
        """Override to define additional parameters"""
        payload: MutableMapping[str, Any] = {
            "grant_type": "refresh_token",
            "client_id": self.client_id,
            "client_secret": self.client_secret,
            "refresh_token": self.refresh_token,
        }

        if self.scopes:
            payload["scopes"] = self.scopes

        return payload

    @backoff.on_exception(
        backoff.expo,
        DefaultBackoffException,
        on_backoff=lambda details: logger.info(
            f"Caught retryable error after {details['tries']} tries. Waiting {details['wait']} seconds then retrying..."
        ),
        max_time=300,
    )
    def refresh_access_token(self) -> Tuple[str, int]:
        """
        returns a tuple of (access_token, token_lifespan_in_seconds)

        The call is retried with exponential backoff (see the decorator) whenever a
        DefaultBackoffException is raised, which happens for HTTP 429 and 5xx responses.

        :raises DefaultBackoffException: for HTTP 429 / 5xx responses (triggers the retry).
        :raises requests.exceptions.RequestException: for any other request failure,
            including failures that carry no response at all.
        :raises Exception: for any other error, e.g. a response body missing the expected keys.
        """
        try:
            response = requests.request(
                method="POST",
                url=self.token_refresh_endpoint,
                data=self.get_refresh_request_body(),
                headers=self.get_refresh_access_token_headers(),
            )
            response.raise_for_status()
            response_json = response.json()
            return response_json["access_token"], int(response_json["expires_in"])
        except requests.exceptions.RequestException as e:
            # A RequestException may carry no response at all; dereferencing it
            # blindly would raise AttributeError and hide the real error.
            # Compare against None explicitly: a Response for a 4xx/5xx status is
            # falsy, so a plain truthiness check would skip exactly the cases we want.
            failed_response = e.response
            if failed_response is not None and (failed_response.status_code == 429 or failed_response.status_code >= 500):
                raise DefaultBackoffException(request=failed_response.request, response=failed_response) from e
            raise
        except Exception as e:
            raise Exception(f"Error while refreshing access token: {e}") from e

    def get_refresh_access_token_headers(self) -> Mapping[str, Any]:
        """
        Build the headers for the token refresh request.

        Combines the configured `refresh_access_token_headers` with the auth header of
        `refresh_access_token_authenticator` (the latter wins on key collisions).
        A fresh dict is returned on every call so the configured mapping is never mutated.
        """
        headers: MutableMapping[str, Any] = dict(self.refresh_access_token_headers or {})
        if self.refresh_access_token_authenticator:
            refresh_auth_headers = self.refresh_access_token_authenticator.get_auth_header()
            headers.update(refresh_auth_headers)
        return headers
