from typing import Any, Union, TYPE_CHECKING, Optional
from collections.abc import Iterable

from airflow.hooks.dbapi import DbApiHook
from clickhouse_driver.dbapi.connection import (
    Connection as ChConnection,
    Cursor,
)

if TYPE_CHECKING:
    from airflow.models.connection import Connection


class ClickhouseHook(DbApiHook):
    """@author = klimenko.iv@gmail.com."""

    def bulk_dump(self, table, tmp_file):
        """bulk_dump method."""

    def bulk_load(self, table, tmp_file):
        """bulk_load method."""

    conn_name_attr = "click_conn_id"
    default_conn_name = "click_default"
    conn_type = "clickhouse"
    hook_name = "ClickHouse"
    database = ""

    @staticmethod
    def get_ui_field_behaviour() -> dict:
        """Returns custom field behaviour."""

        return {
            "hidden_fields": ["extra"],
            "relabeling": {"schema": "Database"},
        }

    def get_conn(
        self,
        conn_name_attr: Optional[str] = None,
    ) -> ChConnection:
        """get_conn method."""

        if conn_name_attr:
            self.conn_name_attr = conn_name_attr
        conn: Connection = self.get_connection(getattr(self, self.conn_name_attr))
        host: str = conn.host
        port: int = int(conn.port) if conn.port else 9000
        user: str = conn.login
        password: str = conn.password
        database: str = conn.schema
        click_kwargs = conn.extra_dejson.copy()
        if password is None:
            password = ""
        click_kwargs.update(port=port)
        click_kwargs.update(user=user)
        click_kwargs.update(password=password)
        if database:
            click_kwargs.update(database=database)

        return ChConnection(host=host or "localhost", **click_kwargs)

    def run(
        self,
        sql: Union[str, Iterable[str]],
        parameters: Optional[dict] = None,
        with_column_types: bool = True,
        **kwargs: dict[str, Any],
    ) -> Any:
        """Run method."""

        if isinstance(sql, str):
            queries = (sql,)
        client = self.get_conn()
        cursor: Cursor = client.cursor()
        result = None
        index = 0
        for query in queries:
            index += 1
            self.log.info("Query_%s  to database : %s", index, query)
            result = cursor.execute(
                query=query,
                params=parameters,
                with_column_types=with_column_types,
            )
            self.log.info("Query_%s completed", index)
        return result
