from __future__ import annotations
import asyncio
import inspect
from typing import Any
from urllib.parse import parse_qsl, urlencode, urljoin, urlsplit

from rnet import Client, Impersonate, Method  # type: ignore[import]

from .exceptions import HttpError
from .request import Request
from .response import Response, HTMLResponse
from .logging import get_logger


class HttpClient:
    def __init__(
        self,
        *,
        concurrency: int = 16,
        impersonate: Impersonate = Impersonate.Firefox139,
        default_headers: dict[str, str] | None = None,
        timeout: float | None = None,
        html_max_size_bytes: int = 5_000_000,
        follow_redirects: bool = True,
        max_redirects: int = 10,
        **client_kwargs: Any,
    ) -> None:
        self._client = Client(impersonate=impersonate, **client_kwargs)
        self._sem = asyncio.Semaphore(concurrency)
        self._default_headers = default_headers or {}
        self._timeout = timeout
        self._html_max_size_bytes = html_max_size_bytes
        self._follow_redirects = follow_redirects
        if max_redirects < 0:
            raise ValueError("max_redirects must be non-negative")
        self._max_redirects = max_redirects
        self.logger = get_logger(component="http")

    async def fetch(self, req: Request) -> Response:
        proxy = req.meta.get("proxy")
        current_req = req
        redirects_followed = 0
        visited_urls: set[str] = set()
        total_start = asyncio.get_running_loop().time()

        # Response data captured from the final request in any redirect chain
        body: bytes = b""
        status: int = 0
        headers: dict[str, str] = {}
        elapsed: float = 0.0

        while True:
            resp: Any | None = None
            method = self._normalize_method(current_req.method)
            url = self._build_url(current_req)
            visited_urls.add(url)

            try:
                async with self._sem:
                    timeout = (
                        current_req.timeout
                        if current_req.timeout is not None
                        else self._timeout
                    )
                    request_kwargs = dict(
                        headers={**self._default_headers, **current_req.headers},
                        data=current_req.data,
                        json=current_req.json,
                        proxy=proxy,
                    )
                    if timeout is not None:
                        request_kwargs["timeout"] = timeout

                    # Adjust keyword arguments to actual rnet.Client.request signature
                    resp = await self._client.request(method, url, **request_kwargs)

                    status = resp.status
                    headers = self._normalize_headers(resp.headers)

                    if self._should_follow_redirect(status, headers):
                        if redirects_followed >= self._max_redirects:
                            raise HttpError(
                                f"Exceeded maximum redirects ({self._max_redirects})"
                            )

                        redirect_url = self._resolve_redirect_url(
                            url, headers.get("location", "")
                        )
                        if redirect_url in visited_urls:
                            raise HttpError("Redirect loop detected")

                        redirects_followed += 1
                        self.logger.debug(
                            "Following redirect",
                            from_url=url,
                            to_url=redirect_url,
                            status=status,
                        )
                        current_req = self._redirect_request(
                            current_req, redirect_url, status, method
                        )
                        await self._close_response(resp)
                        resp = None
                        continue

                    body = await self._read_body(resp)
                    elapsed = (asyncio.get_running_loop().time() - total_start) * 1000
                break
            except HttpError:
                raise
            except Exception as exc:
                raise HttpError(f"Request to {req.url} failed") from exc
            finally:
                await self._close_response(resp)

        self.logger.debug(
            "HTTP response",
            url=url,
            status=status,
            elapsed_ms=round(elapsed, 2),
            proxy=bool(proxy),
            redirects=redirects_followed,
        )
        content_type = headers.get("content-type", "")
        if "html" in content_type:
            return HTMLResponse(
                url=url,
                status=status,
                headers=headers,
                body=body,
                request=current_req,
                doc_max_size_bytes=self._html_max_size_bytes,
            )

        return Response(
            url=url,
            status=status,
            headers=headers,
            body=body,
            request=current_req,
        )

    async def _read_body(self, resp: Any) -> bytes:
        """
        rnet responses may expose the payload differently; try common attributes.
        """
        # preferred path: read() coroutine or function
        if hasattr(resp, "read"):
            reader = resp.read
            if callable(reader):
                result = reader()
                if inspect.isawaitable(result):
                    result = await result
                return self._ensure_bytes(result)

        # fallbacks: content/body attributes or callables
        for attr in ("content", "body"):
            if hasattr(resp, attr):
                value = getattr(resp, attr)
                if callable(value):
                    value = value()
                if inspect.isawaitable(value):
                    value = await value
                return self._ensure_bytes(value)

        # last resort: text/aread style
        for attr in ("text",):
            if hasattr(resp, attr):
                value = getattr(resp, attr)
                if callable(value):
                    value = value()
                if inspect.isawaitable(value):
                    value = await value
                return self._ensure_bytes(value)

        raise TypeError("Unable to read response body")

    async def _close_response(self, resp: Any | None) -> None:
        """Release the underlying HTTP response if it exposes a close hook."""
        if resp is None:
            return

        closer = getattr(resp, "aclose", None) or getattr(resp, "close", None)
        if closer and callable(closer):
            try:
                result = closer()
                if inspect.isawaitable(result):
                    await result
            except Exception:
                # Best-effort cleanup; avoid surfacing close errors.
                self.logger.debug("Failed to close response", exc_info=True)

    def _ensure_bytes(self, data: Any) -> bytes:
        if isinstance(data, bytes):
            return data
        if isinstance(data, str):
            return data.encode("utf-8", errors="replace")
        if data is None:
            return b""
        return bytes(data)

    def _normalize_headers(self, raw_headers: Any) -> dict[str, str]:
        """
        rnet's Response.headers may be a mapping or a list of raw header lines;
        coerce both shapes into a plain dict without raising.
        """

        def _as_str(val: Any) -> str:
            if isinstance(val, bytes):
                return val.decode("utf-8", errors="ignore")
            return str(val)

        if raw_headers is None:
            return {}

        # Best effort if it already looks like a mapping
        if isinstance(raw_headers, dict):
            return {_as_str(k).lower(): _as_str(v) for k, v in raw_headers.items()}
        if hasattr(raw_headers, "items"):
            try:
                return {_as_str(k).lower(): _as_str(v) for k, v in raw_headers.items()}
            except Exception:
                pass

        headers: dict[str, str] = {}
        if isinstance(raw_headers, (list, tuple)):
            for entry in raw_headers:
                if isinstance(entry, (list, tuple)) and len(entry) == 2:
                    k, v = entry
                elif isinstance(entry, (bytes, str)):
                    text = (
                        entry.decode("utf-8", errors="ignore")
                        if isinstance(entry, bytes)
                        else entry
                    )
                    if ":" not in text:
                        continue
                    k, v = text.split(":", 1)
                else:
                    continue
                k = _as_str(k).strip().lower()
                v = _as_str(v).strip()
                headers[k] = v
            if headers:
                return headers

        try:
            return {
                _as_str(k).lower(): _as_str(v) for k, v in dict(raw_headers).items()
            }
        except Exception:
            return {}

    def _build_url(self, req: Request) -> str:
        if not req.params:
            return req.url

        parts = urlsplit(req.url)
        existing = dict(parse_qsl(parts.query, keep_blank_values=True))
        existing.update(req.params)
        query = urlencode(existing, doseq=True)
        return parts._replace(query=query).geturl()

    def _normalize_method(self, method: str | Method) -> Method:
        if isinstance(method, Method):
            return method

        upper = method.upper()
        if hasattr(Method, upper):
            return getattr(Method, upper)

        raise ValueError(f"Unsupported HTTP method: {method!r}")

    def _method_name(self, method: Method) -> str:
        return (
            getattr(method, "name", None)
            or getattr(method, "value", None)
            or str(method)
        )

    def _should_follow_redirect(self, status: int, headers: dict[str, str]) -> bool:
        if not self._follow_redirects:
            return False

        return status in {301, 302, 303, 307, 308} and "location" in headers

    def _resolve_redirect_url(self, current_url: str, location: str) -> str:
        return urljoin(current_url, location.strip())

    def _redirect_request(
        self, req: Request, redirect_url: str, status: int, method: Method
    ) -> Request:
        method_name = self._method_name(method).upper()
        new_method = method_name
        new_data = req.data
        new_json = req.json

        if status in {301, 302, 303} and method_name not in {"GET", "HEAD"}:
            new_method = "GET"
            new_data = None
            new_json = None

        updated = req.replace(
            url=redirect_url,
            method=new_method,
            data=new_data,
            json=new_json,
            params={},  # don't re-append original query params to redirect targets
        )

        redirects = updated.meta.get("redirect_times", 0) + 1
        updated.meta["redirect_times"] = redirects
        return updated

    async def close(self) -> None:
        closer = getattr(self._client, "aclose", None) or getattr(
            self._client, "close", None
        )
        if closer is None or not callable(closer):
            return

        try:
            result = closer()
            if inspect.isawaitable(result):
                await result
        except Exception as exc:
            # Best-effort cleanup; suppress shutdown errors so the engine can exit.
            self.logger.debug("Failed to close HTTP client cleanly", error=str(exc))
