# -*- coding:utf-8 -*-
#
# Copyright (C) 2020-2021, Saarland University
# Copyright (C) 2020-2021, Maximilian Köhl <koehl@cs.uni-saarland.de>
# Copyright (C) 2020-2021, Michaela Klauck <klauck@cs.uni-saarland.de>

from __future__ import annotations

import dataclasses as d
import typing as t

import enum
import itertools
import math
import re


from momba import model
from momba.model import expressions, types
from momba.moml import expr, prop


class TankType(enum.Enum):
    """
    Enumeration of the available *tank types*.

    The tank size of a scenario is derived from the track: it is the number
    of blank cells times three, scaled by the tank type's *capacity factor*
    and rounded down.

    Attributes
    ----------
    capacity_factor:
        The capacity factor associated with the tank type.
    """

    SMALL = 0.5
    """ A small tank (capacity factor 0.5). """

    MEDIUM = 0.75
    """ A medium-sized tank (capacity factor 0.75). """

    LARGE = 1
    """ A large tank (capacity factor 1). """

    capacity_factor: float

    def __init__(self, factor: float) -> None:
        # ``enum`` hands the member's value to ``__init__``; expose it under
        # a descriptive attribute name.
        self.capacity_factor = factor


class AccelerationModel(t.Protocol):
    """
    Structural type of functions that map an acceleration expression to the
    *abnormal* acceleration applied when an acceleration does not succeed.
    """

    def __call__(self, acceleration: model.Expression) -> model.Expression:
        """Returns the abnormal acceleration for *acceleration*."""
        ...


class Underground(enum.Enum):
    """
    Enumeration of the supported *undergrounds*.

    An underground adds probabilistic noise to accelerations, modeling
    slippery road conditions. With probability ``acceleration_probability``
    the chosen acceleration is applied as is. Otherwise, the result of
    ``acceleration_model`` is applied instead.

    Attributes
    ----------
    acceleration_probability:
        Expression for the probability that an acceleration succeeds.
    acceleration_model:
        Function computing the *abnormal* acceleration from the chosen one.
    """

    TARMAC = expr("9 / 10"), lambda acc: acc
    """
    A solid, non-slippery underground. The abnormal acceleration equals the
    chosen one, so no noise is introduced.
    """

    SAND = (
        expr("5 / 10"),
        lambda acc: expr("$a > 0 ? $a - 1 : ($a < 0 ? $a + 1 : 0)", a=acc),
    )
    """
    A sandy underground: half of the time the acceleration is weakened by one
    towards zero. Be cautious!
    """

    ICE = expr("3 / 10"), lambda _: expr("0")
    """
    A very slippery underground: the abnormal acceleration is always zero.
    """

    acceleration_probability: model.Expression
    acceleration_model: AccelerationModel

    def __init__(
        self,
        success_probability: model.Expression,
        abnormal_acceleration: AccelerationModel,
    ) -> None:
        # ``enum`` unpacks each member's value tuple into these arguments.
        self.acceleration_probability = success_probability
        self.acceleration_model = abnormal_acceleration


@d.dataclass(frozen=True, order=True)
class Coordinate:
    """
    An immutable, hashable position on the track.

    Coordinates are ordered lexicographically: first by ``x``, then by ``y``.
    """

    x: int
    """ The :math:`x` coordinate (column index). """

    y: int
    """ The :math:`y` coordinate (row index). """


class CellType(enum.Enum):
    """
    Enumeration of the *cell types* of a track.

    Each member's value is the character representing the cell in the
    textual track format.
    """

    BLANK = "."
    """
    A *blank cell*, i.e., a free cell the car may drive on.
    """

    BLOCKED = "x"
    """
    A cell *blocked* by an obstacle.
    """

    START = "s"
    """
    A cell the car may start from.
    """

    GOAL = "g"
    """
    A cell the car should reach.
    """


@d.dataclass(frozen=True)
class Track:
    """
    Represents a *track*.

    A track is a grid of ``width`` columns (``x``) and ``height`` rows
    (``y``). Every cell of the grid is expected to belong to one of the cell
    sets below.

    Attributes
    ----------
    width:
        The width of the track.
    height:
        The height of the track.
    blank_cells:
        The set of blank cells.
    blocked_cells:
        The set of blocked cells.
    start_cells:
        The set of start cells.
    goal_cells:
        The set of goal cells.
    """

    width: int
    height: int

    blank_cells: t.FrozenSet[Coordinate]
    blocked_cells: t.FrozenSet[Coordinate]
    start_cells: t.FrozenSet[Coordinate]
    goal_cells: t.FrozenSet[Coordinate]

    def get_cell_type(self, cell: Coordinate) -> CellType:
        """
        Retrieves the type of the given *cell*.

        Raises an :class:`AssertionError` if *cell* is in none of the cell
        sets.
        """
        if cell in self.blank_cells:
            return CellType.BLANK
        elif cell in self.blocked_cells:
            return CellType.BLOCKED
        elif cell in self.start_cells:
            return CellType.START
        else:
            assert cell in self.goal_cells
            return CellType.GOAL

    @property
    def textual_description(self) -> str:
        """
        Converts the track into its textual description.

        The first line reads ``dim: <width> <height>``. It is followed by one
        line per row (increasing ``y``) with one :class:`CellType` character
        per cell (increasing ``x``). This is the format accepted by
        :meth:`from_source`.
        """
        lines = [f"dim: {self.width} {self.height}"]
        for y in range(self.height):
            lines.append(
                "".join(
                    self.get_cell_type(Coordinate(x, y)).value
                    for x in range(self.width)
                )
            )
        return "\n".join(lines)

    @classmethod
    def from_source(cls, source: str) -> Track:
        """
        Converts a textual specification of a track into a :class:`Track`.

        The first line must start with ``dim: <width> <height>``. It is
        followed by ``height`` non-empty rows (increasing ``y``), each
        consisting of ``width`` :class:`CellType` characters (increasing
        ``x``). Empty lines and whitespace around rows are ignored. This is
        the inverse of :attr:`textual_description`.

        Raises
        ------
        AssertionError
            If the dimension line is missing, the dimensions do not match
            the rows, a row contains an unknown cell character, or there is
            no start or no goal cell.
        """
        firstline, _, remainder = source.partition("\n")
        # Width comes first, matching the ``dim: <width> <height>`` line
        # emitted by ``textual_description``. Reading the numbers in the
        # opposite order rejects (or mis-sizes) every non-square track.
        dimension = re.match(r"dim: (?P<width>\d+) (?P<height>\d+)", firstline)
        assert dimension is not None, "invalid format: dimension missing"
        width, height = int(dimension["width"]), int(dimension["height"])

        track = [
            list(line.strip())
            for line in remainder.splitlines(keepends=False)
            if line.strip()
        ]

        assert (
            len(track) == height
        ), "given track height does not match actual track height"
        assert all(
            len(row) == width for row in track
        ), "given track width does not match actual track width"

        # An unknown character would leave its cell outside every cell set,
        # which ``get_cell_type`` (and thus ``textual_description``) rejects.
        known_cell_chars = {cell_type.value for cell_type in CellType}
        assert all(
            cell_char in known_cell_chars for row in track for cell_char in row
        ), "invalid format: unknown cell character"

        def get_coordinates(cell_type: CellType) -> t.FrozenSet[Coordinate]:
            # Collects the coordinates of all cells of the given type.
            return frozenset(
                Coordinate(x, y)
                for y, row in enumerate(track)
                for x, cell_char in enumerate(row)
                if cell_char == cell_type.value
            )

        blank_cells = get_coordinates(CellType.BLANK)
        blocked_cells = get_coordinates(CellType.BLOCKED)
        start_cells = get_coordinates(CellType.START)
        goal_cells = get_coordinates(CellType.GOAL)

        assert len(start_cells) > 0, "no start cell specified"
        assert len(goal_cells) > 0, "no goal cell specified"

        return cls(width, height, blank_cells, blocked_cells, start_cells, goal_cells)


class FuelModel(t.Protocol):
    """
    Structural type of fuel models.

    A fuel model maps a scenario and the velocity components ``dx`` and
    ``dy`` to an expression for the fuel consumption.
    """

    def __call__(
        self, scenario: Scenario, dx: model.Expression, dy: model.Expression
    ) -> model.Expression:
        """Returns the fuel consumption expression for velocity ``(dx, dy)``."""
        ...


def fuel_model_linear(
    scenario: Scenario, dx: model.Expression, dy: model.Expression
) -> model.Expression:
    """
    Linear fuel model: the consumption is ``abs(dx) + abs(dy)``.

    The *scenario* is not used; it is only part of the :class:`FuelModel`
    signature.
    """
    manhattan_norm = expr("abs($dx) + abs($dy)", dx=dx, dy=dy)
    return manhattan_norm


def fuel_model_quadratic(
    scenario: Scenario, dx: model.Expression, dy: model.Expression
) -> model.Expression:
    """
    Quadratic fuel model: the square of the linear consumption, i.e.,
    ``(abs(dx) + abs(dy)) ** 2``.
    """
    linear_consumption = fuel_model_linear(scenario, dx, dy)
    return expr("$linear ** 2", linear=linear_consumption)


def fuel_model_regular(
    scenario: Scenario, dx: model.Expression, dy: model.Expression
) -> model.Expression:
    """
    Default fuel model: a constant offset of ``1 + max_acceleration`` plus
    the quadratic consumption ``(abs(dx) + abs(dy)) ** 2``.
    """
    offset = scenario.max_acceleration
    quadratic_part = fuel_model_quadratic(scenario, dx, dy)
    return expr(
        "(1 + $max_acceleration) + $quadratic",
        max_acceleration=offset,
        quadratic=quadratic_part,
    )


@d.dataclass(frozen=True)
class Scenario:
    """
    A scenario description comprising a track, start cell, tank type,
    underground, maximal speed and acceleration values, and a fuel model.

    Attributes
    ----------
    track:
        The track to drive on.
    start_cell:
        The cell the car starts in; must be one of the track's start cells.
    tank_type:
        The tank type determining the tank size.
    underground:
        The underground determining the acceleration noise.
    max_speed:
        Optional bound on the absolute value of each velocity component
        (``None`` means unbounded).
    max_acceleration:
        Bound on the absolute value of each acceleration component.
    fuel_model:
        Optional fuel model (``None`` means no fuel is modeled).
    """

    track: Track

    start_cell: Coordinate

    tank_type: TankType = TankType.LARGE
    underground: Underground = Underground.TARMAC

    max_speed: t.Optional[int] = None
    max_acceleration: int = 1

    fuel_model: t.Optional[FuelModel] = fuel_model_regular

    def __post_init__(self) -> None:
        allowed_starts = self.track.start_cells
        assert self.start_cell in allowed_starts, f"invalid start cell {self.start_cell}"

    @property
    def tank_size(self) -> int:
        """
        The tank capacity: the capacity factor times three times the number
        of blank cells, rounded down.
        """
        blank_count = len(self.track.blank_cells)
        capacity = self.tank_type.capacity_factor * 3 * blank_count
        return math.floor(capacity)

    @property
    def possible_accelerations(self) -> t.Iterable[int]:
        """
        All acceleration values per axis, from ``-max_acceleration`` to
        ``max_acceleration`` (inclusive).
        """
        bound = self.max_acceleration
        return tuple(range(-bound, bound + 1))

    def compute_consumption(
        self, dx: model.Expression, dy: model.Expression
    ) -> model.Expression:
        """
        Builds the fuel consumption expression for velocity ``(dx, dy)``
        using the scenario's fuel model. The fuel model must not be ``None``.
        """
        fuel_model = self.fuel_model
        assert fuel_model is not None, "no fuel model has been defined"
        return fuel_model(self, dx, dy)


def construct_model(scenario: Scenario) -> model.Network:
    """
    Constructs an MDP network from the provided scenario description.

    The network consists of the following automata, which synchronize on the
    actions ``accelerate``, ``move_tick``, ``check_tick``, and ``delegate``:

    * ``car``: on ``accelerate``, picks an acceleration ``(ax, ay)`` and
      updates the velocity ``(car_dx, car_dy)``. With the underground's
      ``acceleration_probability`` the chosen acceleration is applied.
      Otherwise, the underground's ``acceleration_model`` result is applied.
    * ``environment``: after ``accelerate``, moves the car in
      ``max(abs(car_dx), abs(car_dy))`` steps (``move_tick``). After each
      step it checks (``check_tick``) whether the car is off the track, on a
      goal cell, or on a blocked cell, and finally hands control back to the
      car (``delegate``). If any of these conditions holds, no
      ``check_tick`` is enabled and the model has no further transitions.
    * ``tank`` (only if the scenario has a fuel model): consumes fuel on
      every ``check_tick``. It only allows ``check_tick`` while
      ``fuel >= consumption``.

    The property ``goalProbability`` is defined on the context as well.

    Parameters
    ----------
    scenario:
        The scenario to construct the model for.

    Returns
    -------
    model.Network
        The constructed network.
    """

    ctx = model.Context(model.ModelType.MDP)
    network = ctx.create_network(name="Featured Racetrack")

    track = scenario.track

    # Track dimensions; referenced by name in ``is_off_track`` below.
    ctx.global_scope.declare_constant("WIDTH", types.INT, value=track.width)
    ctx.global_scope.declare_constant("HEIGHT", types.INT, value=track.height)

    # Bound for the velocity variables.
    # NOTE(review): presumably sufficient because a car whose speed exceeds
    # the track's extent leaves the track during its next move — confirm.
    speed_bound = (
        max(scenario.track.width, scenario.track.height) + scenario.max_acceleration
    )

    ctx.global_scope.declare_variable(
        "car_dx",
        types.INT.bound(-speed_bound, speed_bound),
        initial_value=0,
    )
    ctx.global_scope.declare_variable(
        "car_dy",
        types.INT.bound(-speed_bound, speed_bound),
        initial_value=0,
    )

    # Positions range from -1 to WIDTH/HEIGHT (one beyond each border). Each
    # move tick changes a coordinate by at most one cell, and the position
    # is checked after every tick, so the car can only be one cell off.
    ctx.global_scope.declare_variable(
        "car_x",
        types.INT.bound(-1, track.width),
        initial_value=scenario.start_cell.x,
    )
    ctx.global_scope.declare_variable(
        "car_y",
        types.INT.bound(-1, track.height),
        initial_value=scenario.start_cell.y,
    )

    # Fuel is only modeled when a fuel model is given; the tank starts full.
    if scenario.fuel_model is not None:
        ctx.global_scope.declare_variable(
            "fuel",
            types.INT.bound(0, scenario.tank_size),
            initial_value=scenario.tank_size,
        )

    # The car decides on an acceleration.
    accelerate = ctx.create_action_type("accelerate").create_pattern()
    # The environment is about to move the car.
    move_tick = ctx.create_action_type("move_tick").create_pattern()
    # The environment is about to check the state of the car.
    check_tick = ctx.create_action_type("check_tick").create_pattern()
    # The environment is about to delegate the decision back to the car.
    delegate = ctx.create_action_type("delegate").create_pattern()

    def is_at_cell(
        cells: t.Iterable[Coordinate],
        car_x: model.Expression = expr("car_x"),
        car_y: model.Expression = expr("car_y"),
    ) -> model.Expression:
        """
        Builds a disjunction that holds iff ``(car_x, car_y)`` equals one of
        the given *cells*.
        """
        return expressions.logic_any(
            *(
                expr(
                    "$car_x == $cell_x and $car_y == $cell_y ",
                    car_x=car_x,
                    car_y=car_y,
                    cell_x=cell.x,
                    cell_y=cell.y,
                )
                for cell in cells
            )
        )

    def is_at_goal(
        car_x: model.Expression = expr("car_x"),
        car_y: model.Expression = expr("car_y"),
    ) -> model.Expression:
        """Holds iff the car is on one of the track's goal cells."""
        return is_at_cell(track.goal_cells, car_x, car_y)

    def is_at_blocked(
        car_x: model.Expression = expr("car_x"),
        car_y: model.Expression = expr("car_y"),
    ) -> model.Expression:
        """Holds iff the car is on one of the track's blocked cells."""
        return is_at_cell(track.blocked_cells, car_x, car_y)

    def is_off_track(
        car_x: model.Expression = expr("car_x"),
        car_y: model.Expression = expr("car_y"),
    ) -> model.Expression:
        """Holds iff the car's position lies outside the track's grid."""
        return expr(
            "$car_x >= WIDTH or $car_x < 0 or $car_y >= HEIGHT or $car_y < 0",
            car_x=car_x,
            car_y=car_y,
        )

    # In case the fuel is empty before reaching the goal, the model goes
    # into a dead state without transitions. Hence, this property also
    # covers the consumption of fuel.
    ctx.define_property(
        "goalProbability",
        prop("min({ Pmax(F($is_at_goal)) | initial })", is_at_goal=is_at_goal()),
    )

    def construct_car_automaton() -> model.Automaton:
        """
        Builds the ``car`` automaton: a single location with one
        ``accelerate`` edge per acceleration pair ``(ax, ay)``.
        """
        automaton = ctx.create_automaton(name="car")
        initial = automaton.create_location(initial=True)

        def compute_speed(
            current: model.Expression, acceleration: expressions.ValueOrExpression
        ) -> model.Expression:
            """
            Returns ``current + acceleration``, clamped to
            ``[-max_speed, max_speed]`` if the scenario defines a maximal speed.
            """
            if scenario.max_speed is None:
                return expr(
                    "$current + $acceleration",
                    current=current,
                    acceleration=acceleration,
                )
            else:
                return expr(
                    "max(min($current + $acceleration, $max_speed), -$max_speed)",
                    current=current,
                    acceleration=acceleration,
                    max_speed=scenario.max_speed,
                )

        for ax, ay in itertools.product(scenario.possible_accelerations, repeat=2):
            # Two probabilistic outcomes: the chosen acceleration succeeds
            # with probability p; otherwise the underground's abnormal
            # acceleration is applied (probability 1 - p).
            automaton.create_edge(
                source=initial,
                destinations={
                    model.create_destination(
                        location=initial,
                        assignments={
                            "car_dx": compute_speed(expr("car_dx"), ax),
                            "car_dy": compute_speed(expr("car_dy"), ay),
                        },
                        probability=scenario.underground.acceleration_probability,
                    ),
                    model.create_destination(
                        location=initial,
                        assignments={
                            "car_dx": compute_speed(
                                expr("car_dx"),
                                scenario.underground.acceleration_model(ax),
                            ),
                            "car_dy": compute_speed(
                                expr("car_dy"),
                                scenario.underground.acceleration_model(ay),
                            ),
                        },
                        probability=expr(
                            "1 - $p",
                            p=scenario.underground.acceleration_probability,
                        ),
                    ),
                },
                action_pattern=accelerate,
                # Records the chosen acceleration on the edge.
                annotation={"ax": ax, "ay": ay},
            )

        return automaton

    def construct_tank_automaton() -> model.Automaton:
        """
        Builds the ``tank`` automaton. It subtracts the consumption for the
        current velocity on every ``check_tick``. Because the edge is guarded
        by ``fuel >= consumption``, an insufficient tank blocks
        ``check_tick``.
        """
        automaton = ctx.create_automaton(name="tank")
        initial = automaton.create_location(initial=True)

        consumption = scenario.compute_consumption(expr("car_dx"), expr("car_dy"))
        automaton.create_edge(
            source=initial,
            destinations={
                model.create_destination(
                    initial,
                    assignments={
                        "fuel": expr(
                            "fuel - floor($consumption)", consumption=consumption
                        )
                    },
                )
            },
            action_pattern=check_tick,
            guard=expr(
                "fuel >= $consumption",
                consumption=consumption,
            ),
        )

        return automaton

    def construct_environment_automaton() -> model.Automaton:
        """
        Builds the ``environment`` automaton. It cycles through the locations
        ``wait_for_car`` -> ``move_car`` -> ``env_check`` -> ``move_car`` ...
        and finally returns to ``wait_for_car`` via ``delegate``.
        """
        automaton = ctx.create_automaton(name="environment")

        # Position of the car at the time of the last ``accelerate``; the
        # move is interpolated from this point.
        automaton.scope.declare_variable(
            "start_x", typ=types.INT.bound(-1, track.width), initial_value=0
        )
        automaton.scope.declare_variable(
            "start_y", typ=types.INT.bound(-1, track.height), initial_value=0
        )
        # Number of move ticks already performed for the current move.
        # NOTE(review): the bound presumably suffices because the car leaves
        # the track after at most max(width, height) + 1 ticks — confirm.
        automaton.scope.declare_variable(
            "counter",
            typ=types.INT.bound(0, max(track.width, track.height) + 1),
            initial_value=0,
        )

        wait_for_car = automaton.create_location("wait_for_car", initial=True)
        move_car = automaton.create_location("move_car")
        env_check = automaton.create_location("env_check")

        # Number of ticks for one move: the larger absolute velocity
        # component, so that each tick advances at most one cell per axis.
        move_ticks = expr("max(abs(car_dx), abs(car_dy))")

        # Wait for the decision of the car.
        automaton.create_edge(
            source=wait_for_car,
            destinations={
                model.create_destination(
                    location=move_car,
                    assignments={
                        "counter": expr("0"),
                        "start_x": expr("car_x"),
                        "start_y": expr("car_y"),
                    },
                )
            },
            action_pattern=accelerate,
        )

        # Move the car or delegate the decision back to the car.
        # After ``counter + 1`` ticks the car is placed on the line from
        # ``(start_x, start_y)`` in direction ``(car_dx, car_dy)``, with each
        # coordinate rounded via ``floor(... + 0.5)``. The guard ensures
        # ``move_ticks > 0`` whenever this edge (and its division) is taken.
        automaton.create_edge(
            source=move_car,
            destinations={
                model.create_destination(
                    env_check,
                    assignments={
                        "counter": expr("counter + 1"),
                        "car_x": expr(
                            "start_x + floor((counter + 1) * (car_dx / $move_ticks) + 0.5)",
                            move_ticks=move_ticks,
                        ),
                        "car_y": expr(
                            "start_y + floor((counter + 1) * (car_dy / $move_ticks) + 0.5)",
                            move_ticks=move_ticks,
                        ),
                    },
                )
            },
            guard=expr("counter < $move_ticks", move_ticks=move_ticks),
            action_pattern=move_tick,
        )
        automaton.create_edge(
            source=move_car,
            destinations={model.create_destination(wait_for_car)},
            guard=expr("counter >= $move_ticks", move_ticks=move_ticks),
            action_pattern=delegate,
        )

        # Check whether we should terminate or continue moving the car. There
        # is intentionally no edge for the terminating case: the model then
        # has no further transitions.
        should_terminate = expr(
            "$is_off_track or $is_at_goal or $is_at_blocked",
            is_off_track=is_off_track(),
            is_at_goal=is_at_goal(),
            is_at_blocked=is_at_blocked(),
        )
        automaton.create_edge(
            source=env_check,
            destinations={model.create_destination(move_car)},
            guard=expr("not $should_terminate", should_terminate=should_terminate),
            action_pattern=check_tick,
        )

        return automaton

    car = construct_car_automaton().create_instance()
    environment = construct_environment_automaton().create_instance()

    # ``check_tick`` involves the environment and, if fuel is modeled, the
    # tank, so that fuel is consumed on every check.
    check_tick_vector = {environment: check_tick}

    if scenario.fuel_model:
        tank = construct_tank_automaton().create_instance()
        check_tick_vector[tank] = check_tick

    network.create_link({car: accelerate, environment: accelerate}, result=accelerate)
    network.create_link({environment: move_tick}, result=move_tick)
    network.create_link(check_tick_vector, result=check_tick)
    network.create_link({environment: delegate}, result=delegate)

    return network


def generate_scenarios(
    track: Track, speed_bound: int, acceleration_bound: int
) -> t.Iterator[Scenario]:
    """
    Lazily enumerates scenarios for *track*.

    One scenario is yielded for every combination of the following, with
    the tank type varying fastest:

    * a start cell of the track,
    * a maximal speed in ``1..speed_bound``,
    * a maximal acceleration in ``1..acceleration_bound``,
    * an underground,
    * a tank type.

    The fuel model keeps the default of :class:`Scenario`.
    """
    combinations = itertools.product(
        track.start_cells,
        range(1, speed_bound + 1),
        range(1, acceleration_bound + 1),
        Underground,
        TankType,
    )
    for start, speed_limit, accel_limit, ground, tank in combinations:
        yield Scenario(
            track=track,
            start_cell=start,
            tank_type=tank,
            underground=ground,
            max_speed=speed_limit,
            max_acceleration=accel_limit,
        )
