from collections import OrderedDict, defaultdict
from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union

from dagster import check
from dagster.core.errors import (
    DagsterRunAlreadyExists,
    DagsterRunNotFoundError,
    DagsterSnapshotDoesNotExist,
)
from dagster.core.events import EVENT_TYPE_TO_PIPELINE_RUN_STATUS, DagsterEvent, DagsterEventType
from dagster.core.execution.backfill import BulkActionStatus, PartitionBackfill
from dagster.core.snap import (
    ExecutionPlanSnapshot,
    PipelineSnapshot,
    create_execution_plan_snapshot_id,
    create_pipeline_snapshot_id,
)
from dagster.daemon.types import DaemonHeartbeat
from dagster.utils import frozendict, merge_dicts

from ..pipeline_run import PipelineRun, PipelineRunsFilter, RunRecord
from .base import RunStorage


class InMemoryRunStorage(RunStorage):
    """Run storage that keeps runs, tags, snapshots and backfills in process memory.

    All state lives in ``OrderedDict`` instances keyed by id, so insertion order is
    retained and listing methods can return the most recently added entries first.
    Nothing is written anywhere else; ``wipe`` simply replaces the containers.

    Args:
        preload: Optional iterable of payload objects used to seed the storage. Each
            payload must expose ``pipeline_run``, ``pipeline_snapshot`` and
            ``execution_plan_snapshot`` attributes.
    """

    def __init__(self, preload=None):
        self._init_storage()
        if preload:
            for payload in preload:
                run = payload.pipeline_run
                self._runs[run.run_id] = run
                # Mirror add_run so that preloaded runs also contribute to get_run_tags.
                if run.tags:
                    self._run_tags[run.run_id] = frozendict(run.tags)
                self._pipeline_snapshots[run.pipeline_snapshot_id] = payload.pipeline_snapshot
                self._ep_snapshots[
                    run.execution_plan_snapshot_id
                ] = payload.execution_plan_snapshot

        super().__init__()

    # separate method so it can be reused in wipe
    def _init_storage(self):
        """Create (or reset to empty) every in-memory container used by this storage."""
        self._runs: Dict[str, PipelineRun] = OrderedDict()
        self._run_tags: Dict[str, dict] = defaultdict(dict)
        self._pipeline_snapshots: Dict[str, PipelineSnapshot] = OrderedDict()
        self._ep_snapshots: Dict[str, ExecutionPlanSnapshot] = OrderedDict()
        self._bulk_actions: Dict[str, PartitionBackfill] = OrderedDict()

    def add_run(self, pipeline_run: PipelineRun) -> PipelineRun:
        """Store a new run and index its tags.

        Args:
            pipeline_run: The run to add.

        Returns:
            The same ``pipeline_run`` that was passed in.

        Raises:
            DagsterRunAlreadyExists: If a run with the same ``run_id`` is already stored.
            DagsterSnapshotDoesNotExist: If the run references a pipeline snapshot id
                that has not been added to this storage.
        """
        check.inst_param(pipeline_run, "pipeline_run", PipelineRun)
        if pipeline_run.run_id in self._runs:
            raise DagsterRunAlreadyExists(
                f"Can not add same run twice for run_id {pipeline_run.run_id}",
            )
        if pipeline_run.pipeline_snapshot_id and not self.has_pipeline_snapshot(
            pipeline_run.pipeline_snapshot_id
        ):
            raise DagsterSnapshotDoesNotExist(
                f"pipeline_snapshot_id {pipeline_run.pipeline_snapshot_id} "
                "does not exist in run storage."
            )

        self._runs[pipeline_run.run_id] = pipeline_run
        if pipeline_run.tags:
            self._run_tags[pipeline_run.run_id] = frozendict(pipeline_run.tags)

        return pipeline_run

    def handle_run_event(self, run_id: str, event: DagsterEvent):
        """Update the stored status of a run in response to an event.

        Events for unknown runs, and events whose type has no entry in
        ``EVENT_TYPE_TO_PIPELINE_RUN_STATUS``, are ignored.

        Args:
            run_id: Id of the run the event belongs to.
            event: The event to apply.
        """
        check.str_param(run_id, "run_id")
        check.inst_param(event, "event", DagsterEvent)
        if run_id not in self._runs:
            return
        # Only event types that map to a run status can change the run; indexing the
        # mapping unguarded would raise KeyError for every other event type.
        if event.event_type not in EVENT_TYPE_TO_PIPELINE_RUN_STATUS:
            return

        self._runs[run_id] = self._runs[run_id].with_status(
            EVENT_TYPE_TO_PIPELINE_RUN_STATUS[event.event_type]
        )

    def get_runs(
        self,
        filters: Optional[PipelineRunsFilter] = None,
        cursor: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> List[PipelineRun]:
        """Return stored runs, most recently added first.

        Args:
            filters: Optional filter; a run must satisfy every criterion that is set.
            cursor: Optional run id; only runs listed after this run are returned.
            limit: Optional maximum number of runs to return.

        Returns:
            The matching runs after cursor/limit slicing (see ``_slice``).
        """
        check.opt_inst_param(filters, "filters", PipelineRunsFilter)
        check.opt_str_param(cursor, "cursor")
        check.opt_int_param(limit, "limit")

        # Reverse insertion order so that the newest run comes first.
        newest_first = list(self._runs.values())[::-1]

        if not filters:
            return self._slice(newest_first, cursor, limit)

        def run_filter(run):
            if filters.run_ids and run.run_id not in filters.run_ids:
                return False

            if filters.statuses and run.status not in filters.statuses:
                return False

            if filters.pipeline_name and filters.pipeline_name != run.pipeline_name:
                return False

            if filters.mode and filters.mode != run.mode:
                return False

            if filters.tags and not all(
                run.tags.get(key) == value for key, value in filters.tags.items()
            ):
                return False

            if filters.snapshot_id and filters.snapshot_id != run.pipeline_snapshot_id:
                return False

            return True

        matching_runs = [run for run in newest_first if run_filter(run)]
        return self._slice(matching_runs, cursor=cursor, limit=limit)

    def get_runs_count(self, filters: Optional[PipelineRunsFilter] = None) -> int:
        """Return the number of stored runs matching ``filters`` (all runs if omitted)."""
        check.opt_inst_param(filters, "filters", PipelineRunsFilter)

        return len(self.get_runs(filters))

    def _slice(
        self,
        items: List,
        cursor: Optional[str],
        limit: Optional[int],
        key_fn: Callable = lambda _: _.run_id,
    ):
        """Apply cursor/limit pagination to an already ordered list.

        Args:
            items: Ordered items to paginate.
            cursor: Key of the item to start *after*. A falsy cursor starts from the
                beginning; a cursor that matches no item yields an empty list.
            limit: Maximum number of items to return. A falsy limit (``None`` or ``0``)
                means no limit.
            key_fn: Extracts the key compared against ``cursor``; defaults to ``run_id``.

        Returns:
            A new list containing the selected page.
        """
        if cursor:
            try:
                index = next(i for i, item in enumerate(items) if key_fn(item) == cursor)
            except StopIteration:
                return []
            start = index + 1
        else:
            start = 0

        end: Optional[int] = start + limit if limit else None

        return list(items)[start:end]

    def get_run_by_id(self, run_id: str) -> Optional[PipelineRun]:
        """Return the run with the given id, or ``None`` if it is not stored."""
        check.str_param(run_id, "run_id")
        return self._runs.get(run_id)

    def get_run_records(
        self,
        filters: Optional[PipelineRunsFilter] = None,
        limit: Optional[int] = None,
        order_by: Optional[str] = None,
        ascending: bool = False,
    ) -> List[RunRecord]:
        """Not supported by this storage; always raises ``NotImplementedError``."""
        raise NotImplementedError("In memory run storage does not track timestamp yet.")

    def get_run_tags(self) -> List[Tuple[str, Set[str]]]:
        """Return every tag key together with the set of values seen for it.

        Returns:
            ``(key, values)`` pairs sorted by key.
        """
        all_tags: Dict[str, Set[str]] = defaultdict(set)
        for tags in self._run_tags.values():
            for key, value in tags.items():
                all_tags[key].add(value)

        return sorted(all_tags.items(), key=lambda item: item[0])

    def add_run_tags(self, run_id: str, new_tags: Dict[str, str]):
        """Merge ``new_tags`` into the tags of an existing run.

        Args:
            run_id: Id of the run to update.
            new_tags: Tags to merge into the run's existing tags via ``merge_dicts``.

        Raises:
            KeyError: If no run with ``run_id`` is stored.
        """
        check.str_param(run_id, "run_id")
        check.dict_param(new_tags, "new_tags", key_type=str, value_type=str)
        run = self._runs[run_id]
        run_tags = merge_dicts(run.tags or {}, new_tags)
        self._runs[run_id] = run.with_tags(run_tags)
        self._run_tags[run_id] = frozendict(run_tags)

    def has_run(self, run_id: str) -> bool:
        """Return whether a run with the given id is stored."""
        check.str_param(run_id, "run_id")
        return run_id in self._runs

    def delete_run(self, run_id: str):
        """Remove a run and its tag index entry.

        Raises:
            KeyError: If no run with ``run_id`` is stored.
        """
        check.str_param(run_id, "run_id")
        del self._runs[run_id]
        # pop with a default: runs without tags never got an entry in the tag index.
        self._run_tags.pop(run_id, None)

    def has_pipeline_snapshot(self, pipeline_snapshot_id: str) -> bool:
        """Return whether a pipeline snapshot with the given id is stored."""
        check.str_param(pipeline_snapshot_id, "pipeline_snapshot_id")
        return pipeline_snapshot_id in self._pipeline_snapshots

    def add_pipeline_snapshot(self, pipeline_snapshot: PipelineSnapshot) -> str:
        """Store a pipeline snapshot and return the id computed for it."""
        check.inst_param(pipeline_snapshot, "pipeline_snapshot", PipelineSnapshot)
        pipeline_snapshot_id = create_pipeline_snapshot_id(pipeline_snapshot)
        self._pipeline_snapshots[pipeline_snapshot_id] = pipeline_snapshot
        return pipeline_snapshot_id

    def get_pipeline_snapshot(self, pipeline_snapshot_id: str) -> PipelineSnapshot:
        """Return the stored pipeline snapshot; raises ``KeyError`` if it is unknown."""
        check.str_param(pipeline_snapshot_id, "pipeline_snapshot_id")
        return self._pipeline_snapshots[pipeline_snapshot_id]

    def has_execution_plan_snapshot(self, execution_plan_snapshot_id: str) -> bool:
        """Return whether an execution plan snapshot with the given id is stored."""
        check.str_param(execution_plan_snapshot_id, "execution_plan_snapshot_id")
        return execution_plan_snapshot_id in self._ep_snapshots

    def add_execution_plan_snapshot(self, execution_plan_snapshot: ExecutionPlanSnapshot) -> str:
        """Store an execution plan snapshot and return the id computed for it."""
        check.inst_param(execution_plan_snapshot, "execution_plan_snapshot", ExecutionPlanSnapshot)
        execution_plan_snapshot_id = create_execution_plan_snapshot_id(execution_plan_snapshot)
        self._ep_snapshots[execution_plan_snapshot_id] = execution_plan_snapshot
        return execution_plan_snapshot_id

    def get_execution_plan_snapshot(self, execution_plan_snapshot_id: str) -> ExecutionPlanSnapshot:
        """Return the stored execution plan snapshot; raises ``KeyError`` if it is unknown."""
        check.str_param(execution_plan_snapshot_id, "execution_plan_snapshot_id")
        return self._ep_snapshots[execution_plan_snapshot_id]

    def wipe(self):
        """Discard all stored runs, tags, snapshots and backfills."""
        self._init_storage()

    def build_missing_indexes(
        self, print_fn: Optional[Callable] = None, force_rebuild_all: bool = False
    ):
        """No-op: this storage keeps no secondary indexes to build."""

    def get_run_group(self, run_id: str) -> Optional[Tuple[str, List[PipelineRun]]]:
        """Return the run group that the given run belongs to.

        A group is a root run plus every stored run whose ``root_run_id`` points at it.

        Args:
            run_id: Id of any run in the group.

        Returns:
            ``(root_run_id, runs)`` with the root run first, or ``None`` when the run
            references a root run that is not stored.

        Raises:
            DagsterRunNotFoundError: If no run with ``run_id`` is stored.
        """
        check.str_param(run_id, "run_id")
        pipeline_run = self._runs.get(run_id)
        if not pipeline_run:
            raise DagsterRunNotFoundError(
                f"Run {run_id} was not found in instance.", invalid_run_id=run_id
            )
        # if the run doesn't have root_run_id, itself is the root
        root_run = (
            self.get_run_by_id(pipeline_run.root_run_id)
            if pipeline_run.root_run_id
            else pipeline_run
        )
        if not root_run:
            return None
        run_group = [root_run]
        run_group.extend(
            curr_run for curr_run in self._runs.values() if curr_run.root_run_id == root_run.run_id
        )
        # The group is keyed by the root run's own id; a root run's root_run_id is unset.
        return (root_run.run_id, run_group)

    def get_run_groups(
        self,
        filters: Optional[PipelineRunsFilter] = None,
        cursor: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> Dict[str, Dict[str, Union[Iterable[PipelineRun], int]]]:
        """Group the runs selected by ``get_runs`` by their root run.

        Args:
            filters: Optional filter passed through to ``get_runs``.
            cursor: Optional cursor passed through to ``get_runs``.
            limit: Optional limit passed through to ``get_runs``.

        Returns:
            A mapping from root run id to ``{"runs": [...], "count": n}``. ``runs`` holds
            the selected runs of that group plus the root run when it is stored;
            ``count`` is the total number of stored runs in the group, including runs
            that did not match the filter.
        """
        runs = self.get_runs(filters=filters, cursor=cursor, limit=limit)
        root_run_id_to_group: Dict[str, Dict[str, PipelineRun]] = defaultdict(dict)
        for run in runs:
            root_run_id = run.get_root_run_id()
            if root_run_id is not None:
                root_run_id_to_group[root_run_id][run.run_id] = run
            else:
                # this run is the root run
                root_run_id_to_group[run.run_id][run.run_id] = run

        # add root run to the group if it's not already there
        for root_run_id, run_group in root_run_id_to_group.items():
            if root_run_id not in run_group:
                root_pipeline_run = self.get_run_by_id(root_run_id)
                if root_pipeline_run:
                    run_group[root_run_id] = root_pipeline_run

        # counts total number of runs in a run group, including the ones don't match the given filter
        root_run_id_to_count: Dict[str, int] = defaultdict(int)
        for run in self.get_runs():
            root_run_id = run.get_root_run_id() or run.run_id
            if root_run_id in root_run_id_to_group:
                root_run_id_to_count[root_run_id] += 1

        return {
            root_run_id: {
                "runs": list(run_group.values()),
                "count": root_run_id_to_count[root_run_id],
            }
            for root_run_id, run_group in root_run_id_to_group.items()
        }

    # Daemon Heartbeats

    def add_daemon_heartbeat(self, daemon_heartbeat: DaemonHeartbeat):
        """Not supported by this storage; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "The dagster daemon lives in a separate process. It cannot use in memory storage."
        )

    def get_daemon_heartbeats(self) -> Dict[str, DaemonHeartbeat]:
        """Not supported by this storage; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "The dagster daemon lives in a separate process. It cannot use in memory storage."
        )

    def wipe_daemon_heartbeats(self):
        """Not supported by this storage; always raises ``NotImplementedError``."""
        raise NotImplementedError(
            "The dagster daemon lives in a separate process. It cannot use in memory storage."
        )

    def get_backfills(
        self,
        status: Optional[BulkActionStatus] = None,
        cursor: Optional[str] = None,
        limit: Optional[int] = None,
    ) -> List[PartitionBackfill]:
        """Return stored backfills, most recently added first.

        Args:
            status: If given, only backfills with this status are returned.
            cursor: Optional backfill id; only backfills listed after it are returned.
            limit: Optional maximum number of backfills to return.
        """
        check.opt_inst_param(status, "status", BulkActionStatus)
        backfills = [
            backfill
            for backfill in self._bulk_actions.values()
            if not status or status == backfill.status
        ]
        return self._slice(backfills[::-1], cursor, limit, key_fn=lambda _: _.backfill_id)

    def get_backfill(self, backfill_id: str) -> Optional[PartitionBackfill]:
        """Return the backfill with the given id, or ``None`` if it is not stored."""
        check.str_param(backfill_id, "backfill_id")
        return self._bulk_actions.get(backfill_id)

    def add_backfill(self, partition_backfill: PartitionBackfill):
        """Store a backfill under its ``backfill_id``, replacing any existing entry."""
        check.inst_param(partition_backfill, "partition_backfill", PartitionBackfill)
        self._bulk_actions[partition_backfill.backfill_id] = partition_backfill

    def update_backfill(self, partition_backfill: PartitionBackfill):
        """Replace the stored backfill that has the same ``backfill_id``.

        The backfill is stored even if no entry with that id existed before.
        """
        check.inst_param(partition_backfill, "partition_backfill", PartitionBackfill)
        self._bulk_actions[partition_backfill.backfill_id] = partition_backfill