import logging
import time
from concurrent import futures
from dataclasses import dataclass
from enum import Enum
from typing import Callable, Dict, List, Optional, Tuple

from snowflake.connector import SnowflakeConnection
from snowflake.connector.cursor import SnowflakeCursor

from metaphor.snowflake.filter import DatabaseFilter, SnowflakeFilter

# Module-level logger. Its level is pinned to INFO, so logger.debug() calls in
# this module (e.g. the query/params trace in async_query) are suppressed unless
# the level is lowered elsewhere.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Thread pool size used by async_execute when the caller passes max_workers=None.
DEFAULT_THREAD_POOL_SIZE = 10
# Seconds async_query sleeps between query-status polls.
DEFAULT_SLEEP_TIME = 0.1  # 0.1 s


class SnowflakeTableType(Enum):
    """Kinds of Snowflake table objects, keyed by their type string.

    NOTE(review): the values look like Snowflake's information_schema
    ``TABLE_TYPE`` strings — confirm against the queries that consume this.
    """

    BASE_TABLE = "BASE TABLE"
    VIEW = "VIEW"
    TEMPORARY_TABLE = "TEMPORARY TABLE"


@dataclass
class DatasetInfo:
    """Identifies a single dataset by database, schema and name, plus its type."""

    database: str
    schema: str
    name: str
    # Presumably one of the SnowflakeTableType values (e.g. "BASE TABLE") —
    # confirm against the code that populates it.
    type: str


@dataclass
class QueryWithParam:
    """A SQL statement plus optional bind parameters, as consumed by async_query."""

    # SQL text passed as the first argument to cursor.execute_async().
    query: str
    # Bind parameters; when not None they are passed as the second argument to
    # cursor.execute_async(), otherwise the query runs without parameters.
    params: Optional[Tuple] = None


def async_query(conn: SnowflakeConnection, query: QueryWithParam) -> SnowflakeCursor:
    """Execute a Snowflake query asynchronously and block until it finishes.

    The query is submitted with ``cursor.execute_async`` and its status is
    polled every ``DEFAULT_SLEEP_TIME`` seconds until the connection no longer
    reports it as running. The results are then attached to the cursor with
    ``cursor.get_results_from_sfqid``.

    Args:
        conn: Open Snowflake connection used to create the cursor and poll
            the query status.
        query: SQL text and optional bind parameters.

    Returns:
        The cursor holding the query's results. The caller owns the returned
        cursor and is responsible for closing it.

    Raises:
        Any exception raised by the connector while submitting, polling, or
        retrieving results. The cursor is closed before the exception
        propagates, so it is not leaked.
    """
    cursor = conn.cursor()
    try:
        if query.params is not None:
            # Lazy %-formatting: the message is only built if DEBUG is enabled.
            logger.debug("Query %s params %s", query.query, query.params)
            cursor.execute_async(query.query, query.params)
        else:
            cursor.execute_async(query.query)

        query_id = cursor.sfqid

        # Poll until the connector no longer reports the query as running.
        while conn.is_still_running(conn.get_query_status(query_id)):
            time.sleep(DEFAULT_SLEEP_TIME)

        cursor.get_results_from_sfqid(query_id)
    except BaseException:
        # The caller never receives the cursor on failure, so close it here.
        cursor.close()
        raise
    return cursor


def async_execute(
    conn: SnowflakeConnection,
    queries: Dict[str, QueryWithParam],
    query_name: str = "",
    max_workers: Optional[int] = None,
    results_processor: Optional[Callable[[str, List[Tuple]], None]] = None,
) -> Dict[str, List]:
    """Run a set of Snowflake queries concurrently on a thread pool.

    Each entry of ``queries`` is executed with ``async_query`` on a worker
    thread, all sharing ``conn``. Results are handled as queries complete, in
    completion order rather than submission order.

    Args:
        conn: Snowflake connection shared by all worker threads.
        queries: Mapping of caller-chosen key -> query to execute.
        query_name: Label used only in error log messages.
        max_workers: Thread pool size; ``DEFAULT_THREAD_POOL_SIZE`` when None.
        results_processor: Optional callback invoked as
            ``results_processor(key, rows)`` for each successful query.

    Returns:
        If ``results_processor`` is None, a dict mapping each successful key to
        the rows returned by ``fetchall()``. If ``results_processor`` is given,
        rows are passed to it instead and the returned dict is empty. Keys
        whose query (or fetch) raised are logged and omitted.

    Raises:
        Exceptions raised by ``results_processor`` are not caught and propagate
        to the caller.
    """
    workers = max_workers if max_workers is not None else DEFAULT_THREAD_POOL_SIZE
    results_map: Dict[str, List] = {}

    with futures.ThreadPoolExecutor(max_workers=workers) as executor:
        future_map = {
            executor.submit(async_query, conn, query): key
            for key, query in queries.items()
        }

        for future in futures.as_completed(future_map):
            key = future_map[future]
            try:
                cursor = future.result()
                try:
                    results = cursor.fetchall()
                finally:
                    # Release the cursor once its rows are materialized so one
                    # cursor per query is not leaked.
                    cursor.close()
            except Exception:
                # Best-effort: record the failure (with traceback) and keep
                # processing the remaining queries.
                logger.exception("Error executing %s for %s", query_name, key)
                continue

            if results_processor is None:
                results_map[key] = results
            else:
                results_processor(key, results)

    return results_map


def include_table(
    database: str, schema: str, table: str, filter: SnowflakeFilter
) -> bool:
    """Decide whether a table passes the include/exclude filter.

    Names are lower-cased before lookup. A table is kept when it is covered by
    ``filter.includes`` (or includes is None) and is not covered by
    ``filter.excludes`` (or excludes is None). "Covered" walks the nested
    database -> schema -> table mapping; a None or empty level matches
    everything beneath it.
    """
    database_lower, schema_lower, table_lower = (
        name.lower() for name in (database, schema, table)
    )

    def is_unrestricted(level) -> bool:
        # A missing (None) or empty level places no restriction below it.
        return level is None or len(level) == 0

    def covered_by_filter(database_filter: DatabaseFilter) -> bool:
        if database_lower not in database_filter:
            return False

        schemas = database_filter[database_lower]
        if is_unrestricted(schemas):
            return True

        if schema_lower not in schemas:
            return False

        tables = schemas[schema_lower]
        return is_unrestricted(tables) or table_lower in tables

    # Short-circuit: excludes are only consulted when the table is included.
    included = filter.includes is None or covered_by_filter(filter.includes)
    return included and (
        filter.excludes is None or not covered_by_filter(filter.excludes)
    )
