"""richsmi.py"""

import os
import warnings
from datetime import datetime
from enum import IntEnum, auto
from pathlib import Path
from shutil import get_terminal_size
from typing import Any, Iterable, NamedTuple, Optional

import toml
from pkg_resources import get_distribution
from rich import box
from rich.console import RenderableType
from rich.panel import Panel
from rich.progress import BarColumn, Progress
from rich.table import PaddingDimensions, Table
from rich.text import Text
from tap import Tap
from textual.app import App, events
from textual.reactive import Reactive
from textual.views import GridView
from textual.widget import Widget
from textual.widgets import Footer, Header

from librichsmi import NvmlWrapper  # pylint: disable=no-name-in-module
from . import utils

# Silence ResourceWarning for the whole process.
# NOTE(review): which resource triggers the warning is not visible in this
# file (presumably an unclosed handle inside a dependency) — confirm before
# removing this filter.
warnings.simplefilter("ignore", ResourceWarning)


class ArgumentParser(Tap):
    """Typed command-line arguments for Rich-SMI.

    Attributes:
        loop: Refresh interval in seconds; must be positive.
        id: Index of a single GPU to display, or ``None`` to show all GPUs.
        config: Path to the TOML configuration file. Defaults to the
            ``config.toml`` next to this module. The default is stored as a
            ``str`` so it matches the annotated type (a user-supplied value
            is a ``str`` as well).
    """

    loop: float = 1.0
    id: Optional[int] = None
    config: str = str(Path(__file__).resolve().parent / "config.toml")

    def configure(self):
        """Register short flags and help text for each argument."""

        self.add_argument(
            '-l', '--loop', help="Interval seconds. Defaults to 1.0s.")
        self.add_argument(
            '-i', '--id', help="Target a specific GPU. If not specified, show all GPU info.")
        self.add_argument('--config', help="Path to custom config file.")

    def process_args(self):
        """Validate parsed values (called by Tap after parsing)."""

        # The interval feeds every widget's refresh timer; a zero or negative
        # value is meaningless, so reject it with a usage error (exit code 2).
        if self.loop <= 0:
            self.error(f"--loop must be positive, got {self.loop}")


class MyHeader(Header):  # pylint: disable=too-few-public-methods
    """Header bar showing driver/CUDA versions, the Rich-SMI version and a clock."""

    timestamp: Reactive[RenderableType] = Reactive("")

    def __init__(self, driver_version: str,
                 cuda_version: int, config: dict) -> None:
        """Constructor.

        Args:
            driver_version: Driver version shown on the left.
            cuda_version: CUDA version shown on the left.
            config: Header config table; ``border_style`` is read in ``render``.
        """

        super().__init__(style="")

        # Same formatting as the timer callback, so the very first frame
        # already shows the blinking colons.
        self.timestamp = self._format_timestamp()
        self.driver_version = driver_version
        self.cuda_version = cuda_version
        self.config = config

        self.version = get_distribution("richsmi").version

    @staticmethod
    def _format_timestamp() -> str:
        """Return the current local time as rich markup with blinking colons."""
        return datetime.now().ctime().replace(":", "[blink]:[/]")

    async def on_timer(self, event: events.Timer) -> None:  # pylint: disable=unused-argument
        """[Override] callback on timer: refresh the clock text.

        Assigning the reactive ``timestamp`` is what triggers the re-render.
        """
        self.timestamp = self._format_timestamp()

    def render(self) -> Panel:
        """Render a three-column header (versions | title | clock) in a panel."""

        grid = Table.grid(expand=True)
        grid.add_column(justify="left", ratio=1)
        grid.add_column(justify="center", ratio=2)
        grid.add_column(justify="right", ratio=1)
        grid.add_row(
            Text(f"Driver Version {self.driver_version} CUDA Version {self.cuda_version}",
                 overflow="ellipsis"),
            f"[b]Rich-SMI[/] Ver. {self.version}",
            self.timestamp,
        )

        return Panel(grid, border_style=self.config["border_style"], box=box.SQUARE)


class MyFooter(Footer):  # pylint: disable=too-few-public-methods
    """Footer listing the app's visible key bindings as clickable labels."""

    def make_key_text(self) -> Text:
        """Build a one-line ``Text`` holding a clickable label per shown binding.

        The binding currently under the mouse (``self.highlight_key``) is drawn
        reversed; the others use a red key badge.
        """

        bar = Text(
            style="white on #505050",
            no_wrap=True,
            overflow="ellipsis",
            justify="left",
            end="",
        )
        for key_binding in self.app.bindings.shown_keys:
            label = key_binding.key_display
            if label is None:
                label = key_binding.key.upper()

            if self.highlight_key == key_binding.key:
                badge_style = "reverse"
            else:
                badge_style = "default on red"

            click_meta = {
                "@click": f"app.press('{key_binding.key}')",
                "key": key_binding.key,
            }
            bar.append_text(
                Text.assemble(
                    (f" {label} ", badge_style),
                    f" {key_binding.description} ",
                    meta=click_meta,
                )
            )

        return bar


class BreakPoint(IntEnum):
    """Size class of one terminal dimension; compares as S < M < L."""

    S = 1
    M = 2
    L = 3


class BreakPoints(NamedTuple):
    """Size classes for both terminal dimensions."""

    # Size class of the terminal width.
    columns: BreakPoint
    # Size class of the terminal height.
    lines: BreakPoint


class WidgetUtils:
    """Helpers shared by the body widgets: responsive tables, panels, breakpoints."""

    @staticmethod
    def get_panel(table: Table, config: dict, hover: bool) -> Panel:
        """Wrap ``table`` in a titled, square-bordered panel.

        Args:
            table: Panel content.
            config: Widget config table; ``title``, ``border_style`` and
                ``border_hover_style`` are read.
            hover: Use ``border_hover_style`` instead of ``border_style``.

        Returns:
            The panel; vertical padding is dropped on short terminals.
        """

        bp_lines = WidgetUtils.get_breakpoints(get_terminal_size()).lines

        # Short terminals cannot afford the blank line above/below the content.
        padding = (1, 2, 1, 2) if bp_lines >= BreakPoint.M else (0, 2, 0, 2)

        return Panel(
            table,
            title=config["title"], title_align="left",
            border_style=(config["border_hover_style"]
                          if hover else config["border_style"]),
            padding=padding, box=box.SQUARE
        )

    @staticmethod
    def create_table(headers: Iterable[str], config: dict, ratio: Optional[Iterable[int]] = None,
                     padding: PaddingDimensions = 0) -> Table:
        """Create a ``rich`` table whose layout depends on the terminal height.

        * ``lines`` breakpoint L: boxed ``Table`` with real column headers.
        * M: borderless grid with vertical cell padding.
        * S: tight borderless grid.

        In the grid layouts the headers are emitted as a dimmed first row.

        Args:
            headers: Column titles. Any iterable; it is consumed exactly once.
            config: Widget config; ``table_style`` and ``table_header_style``
                are read in the L layout.
            ratio: Optional width ratios, paired positionally with ``headers``
                (as with ``zip``).
            padding: Cell padding, used only in the L layout.

        Returns:
            The table with columns (and, if compact, the header row) added.
        """

        # Materialize once: headers are used for the columns and, in the
        # compact layouts, again for the pseudo-header row. A generator
        # would otherwise be exhausted by the first pass.
        headers = tuple(headers)

        bp_lines = WidgetUtils.get_breakpoints(get_terminal_size()).lines
        if bp_lines < BreakPoint.M:
            table = Table.grid(expand=True)
        elif bp_lines < BreakPoint.L:
            table = Table.grid(expand=True, padding=(1, 0, 1, 0))
        else:
            table = Table(expand=True, padding=padding,
                          box=box.SIMPLE, style=config["table_style"])

        compact = bp_lines < BreakPoint.L
        column_ratios = [None] * len(headers) if ratio is None else ratio

        for header, column_ratio in zip(headers, column_ratios):
            if compact:
                table.add_column(ratio=column_ratio, justify="left")
            else:
                table.add_column(header, header_style=config["table_header_style"],
                                 ratio=column_ratio, justify="left", no_wrap=True)

        if compact:
            table.add_row(*headers, style="bright_black", end_section=True)

        return table

    @staticmethod
    def _classify(size: int, medium: int, large: int) -> BreakPoint:
        """Map ``size`` to S (< medium), M (medium..large-1) or L (>= large)."""

        if size >= large:
            return BreakPoint.L
        if size >= medium:
            return BreakPoint.M
        return BreakPoint.S

    @staticmethod
    def get_breakpoints(terminal_size: os.terminal_size) -> BreakPoints:
        """Classify a terminal size.

        Columns: S < 124 <= M < 200 <= L. Lines: S < 26 <= M < 30 <= L.
        """

        return BreakPoints(
            WidgetUtils._classify(terminal_size.columns, 124, 200),
            WidgetUtils._classify(terminal_size.lines, 26, 30),
        )


class Memory(Widget):
    """Memory widget: per-GPU id, name, usage bar, percentage and used/total."""

    mouse_over: Reactive[bool] = Reactive(False)
    nvml: NvmlWrapper

    def __init__(self, nvml: NvmlWrapper, config: dict,
                 loop: float, *, name: Optional[str] = None) -> None:
        """Constructor.

        Args:
            nvml: Source of GPU data.
            config: Full application config; only its ``body`` table is kept.
            loop: Refresh interval in seconds.
            name: Optional widget name.
        """

        super().__init__(name=name)

        self.nvml = nvml
        self.loop = loop

        self.config = config["body"]

    def on_mount(self):
        """Re-render this widget every ``loop`` seconds once mounted."""
        self.set_interval(self.loop, self.refresh)

    def render(self) -> RenderableType:
        """Render one row per GPU inside a titled panel."""

        table: Table = WidgetUtils.create_table(
            ["ID", "Name", "Usage", "Pct.", "Used/Total"], self.config, ratio=(1, 4, 4, 2, 4)
        )

        gpu_names = self.nvml.get_gpu_names()
        memories = self.nvml.get_memories()

        for i, memory in enumerate(memories):
            used = memory["used"]
            total = memory["total"]
            # A zero total would raise ZeroDivisionError and abort the whole
            # render; show 0 % for that GPU instead.
            percent = used / total * 100. if total else 0.
            value_style = utils.get_style(percent, self.config)

            memory_bar = Progress(BarColumn(complete_style=value_style))
            memory_bar.add_task(f"{i}", completed=used, total=total)

            table.add_row(
                Text(f"{i}", style=self.config["gpu_id_style"]),
                Text(f"{gpu_names[i]}", style=self.config["gpu_name_style"]),
                memory_bar,
                Text(f"{percent:>3.0f}", style=value_style)
                + Text(" %", style=self.config["suffix_style"]),
                utils.render_file_unit(used, self.config, total),
            )

        return WidgetUtils.get_panel(table, self.config["memory"],
                                     self.mouse_over)

    async def on_enter(self, event: events.Enter) -> None:  # pylint: disable=unused-argument
        """[Override] callback on enter: highlight the panel border."""
        self.mouse_over = True

    async def on_leave(self, event: events.Leave) -> None:  # pylint: disable=unused-argument
        """[Override] callback on leave: restore the normal border."""
        self.mouse_over = False


class Clock(Widget):
    """Clock widget: per-GPU clock frequencies in MHz.

    Graphics and memory clocks are always shown; SM and video clocks are added
    once the terminal is at least the M width breakpoint.
    """

    mouse_over: Reactive[RenderableType] = Reactive(False)
    nvml: NvmlWrapper

    def __init__(self, nvml: NvmlWrapper, config: dict,
                 loop: float, *, name: Optional[str] = None) -> None:
        """Keep the NVML handle, refresh interval and the ``body`` config table."""

        super().__init__(name=name)

        self.nvml = nvml
        self.loop = loop

        self.config = config["body"]

    def on_mount(self):
        """Schedule ``refresh`` every ``loop`` seconds."""
        self.set_interval(self.loop, self.refresh)

    def render(self) -> RenderableType:
        """Build the clock table for every GPU and wrap it in a panel."""

        terminal_bp = WidgetUtils.get_breakpoints(get_terminal_size())
        is_narrow = terminal_bp.columns < BreakPoint.M
        headers = (["Graphics", "Memory"] if is_narrow
                   else ["Graphics", "Memory", "SM", "Video"])

        table: Table = WidgetUtils.create_table(headers, self.config)

        for gpu_clocks in self.nvml.get_clocks():
            # Each header doubles as the (lower-cased) key into the clock dict.
            table.add_row(*(
                Text(f"{gpu_clocks[column.lower()]}", style=self.config["val_normal_style"])
                + Text(" MHz", style=self.config["suffix_style"])
                for column in headers
            ))

        return WidgetUtils.get_panel(table, self.config["clock"], self.mouse_over)

    async def on_enter(self, event: events.Enter) -> None:  # pylint: disable=unused-argument
        """[Override] mouse entered: use the hover border."""
        self.mouse_over = True

    async def on_leave(self, event: events.Leave) -> None:  # pylint: disable=unused-argument
        """[Override] mouse left: use the normal border."""
        self.mouse_over = False


class Utilization(Widget):
    """Utilization widget: per-GPU compute and memory utilization in percent."""

    mouse_over: Reactive[RenderableType] = Reactive(False)
    nvml: NvmlWrapper

    def __init__(self, nvml: NvmlWrapper, config: dict,
                 loop: float, *, name: Optional[str] = None) -> None:
        """Keep the NVML handle, refresh interval and the ``body`` config table."""

        super().__init__(name=name)

        self.nvml = nvml
        self.loop = loop

        self.config = config["body"]

    def on_mount(self):
        """Schedule ``refresh`` every ``loop`` seconds."""
        self.set_interval(self.loop, self.refresh)

    def render(self) -> RenderableType:
        """Build one row per GPU (GPU %, memory %) and wrap it in a panel."""

        table: Table = WidgetUtils.create_table(["GPU Util.", "Memory Util."], self.config)

        for gpu_util in self.nvml.get_utilizations():
            cells = []
            for metric in ("gpu", "memory"):
                value = gpu_util[metric]
                cells.append(
                    Text(f"{value}", style=utils.get_style(value, self.config))
                    + Text(" %", style=self.config["suffix_style"])
                )
            table.add_row(*cells)

        return WidgetUtils.get_panel(table, self.config["utilization"], self.mouse_over)

    async def on_enter(self, event: events.Enter) -> None:  # pylint: disable=unused-argument
        """[Override] mouse entered: use the hover border."""
        self.mouse_over = True

    async def on_leave(self, event: events.Leave) -> None:  # pylint: disable=unused-argument
        """[Override] mouse left: use the normal border."""
        self.mouse_over = False


class Power(Widget):
    """Power widget: per-GPU draw vs. limit, plus total energy on wider terminals."""

    mouse_over: Reactive[bool] = Reactive(False)
    nvml: NvmlWrapper

    def __init__(self, nvml: NvmlWrapper, config: dict,
                 loop: float, *, name: Optional[str] = None) -> None:
        """Constructor.

        Args:
            nvml: Source of GPU data.
            config: Full application config; only its ``body`` table is kept.
            loop: Refresh interval in seconds.
            name: Optional widget name.
        """

        super().__init__(name=name)

        self.nvml = nvml
        self.loop = loop

        self.config = config["body"]

    def on_mount(self):
        """Re-render this widget every ``loop`` seconds once mounted."""
        self.set_interval(self.loop, self.refresh)

    def render(self) -> RenderableType:
        """Render one row per GPU inside a titled panel.

        The "Total Energy" column is shown only when the terminal width is at
        least the M breakpoint.
        """

        bp_columns = WidgetUtils.get_breakpoints(get_terminal_size()).columns
        show_energy = bp_columns >= BreakPoint.M

        headers = ["Usage/Limit", "Total Energy"] if show_energy else ["Usage/Limit"]

        table: Table = WidgetUtils.create_table(headers, self.config)

        powers = self.nvml.get_powers()

        for power in powers:
            usage = power["usage"]
            limit = power["limit"]
            # A zero limit would raise ZeroDivisionError and abort the whole
            # render; treat it as 0 % for styling.
            percent = usage / limit * 100. if limit else 0.

            cells = [
                Text(f"{usage:.1f}", style=utils.get_style(percent, self.config))
                + Text(f"/{limit:.1f} W", style=self.config["suffix_style"]),
            ]
            if show_energy:
                cells.append(utils.render_energy_unit(power["energy"], self.config))

            table.add_row(*cells)

        return WidgetUtils.get_panel(table, self.config["power"],
                                     self.mouse_over)

    async def on_enter(self, event: events.Enter) -> None:  # pylint: disable=unused-argument
        """[Override] callback on enter: highlight the panel border."""
        self.mouse_over = True

    async def on_leave(self, event: events.Leave) -> None:  # pylint: disable=unused-argument
        """[Override] callback on leave: restore the normal border."""
        self.mouse_over = False


class Temperature(Widget):
    """Temperature widget: per-GPU board temperature (with warnings) and fan speed."""

    mouse_over: Reactive[bool] = Reactive(False)
    nvml: NvmlWrapper

    def __init__(self, nvml: NvmlWrapper, config: dict,
                 loop: float, *, name: Optional[str] = None) -> None:
        """Constructor.

        Args:
            nvml: Source of GPU data.
            config: Full application config; only its ``body`` table is kept.
            loop: Refresh interval in seconds.
            name: Optional widget name.
        """

        super().__init__(name=name)

        self.nvml = nvml
        self.loop = loop

        self.config = config["body"]

    def on_mount(self):
        """Re-render this widget every ``loop`` seconds once mounted."""
        self.set_interval(self.loop, self.refresh)

    def render(self) -> RenderableType:
        """Render one row per GPU inside a titled panel.

        The temperature is styled as a warning within 5 degrees below the
        reported slowdown threshold and as an error within 5 degrees below the
        shutdown threshold. In those cases a "(Slow Down)" / "(Shut Down)"
        note follows the unit.
        """

        table: Table = WidgetUtils.create_table(
            ["Temperature", "Fan Speed"], self.config
        )

        temperatures = self.nvml.get_temperatures()
        fan_speeds = self.nvml.get_fan_speeds()

        # NOTE(review): zip stops at the shorter list; if a GPU ever reports no
        # fan entry it would be dropped from this table — confirm NvmlWrapper
        # always returns aligned lists.
        for temperature, fan_speed in zip(temperatures, fan_speeds):
            board = temperature['board']
            # Warn 5 degrees before each threshold is actually reached.
            slowdown = temperature['thresh_slowdown'] - 5
            shutdown = temperature['thresh_shutdown'] - 5
            speed = fan_speed['speed']

            style = self.config["val_normal_style"]
            warning = ""
            if slowdown < board <= shutdown:
                style = self.config["val_warn_style"]
                warning = " (Slow Down)"
            elif shutdown < board:
                style = self.config["val_error_style"]
                warning = " (Shut Down)"

            # Keep the unit attached to the number ("90 C (Slow Down)")
            # instead of trailing the warning ("90 (Slow Down) C").
            table.add_row(
                Text(f"{board}", style=style)
                + Text(" C", style=self.config["suffix_style"])
                + Text(warning, style=style),
                Text(f"{speed}", style=utils.get_style(speed, self.config))
                + Text(" %", style=self.config["suffix_style"]),
            )

        return WidgetUtils.get_panel(table, self.config["temperature"],
                                     self.mouse_over)

    async def on_enter(self, event: events.Enter) -> None:  # pylint: disable=unused-argument
        """[Override] callback on enter: highlight the panel border."""
        self.mouse_over = True

    async def on_leave(self, event: events.Leave) -> None:  # pylint: disable=unused-argument
        """[Override] callback on leave: restore the normal border."""
        self.mouse_over = False


class ComputeProcs(Widget):
    """Compute-process widget: PID, name and memory usage of compute processes."""

    mouse_over: Reactive[RenderableType] = Reactive(False)
    nvml: NvmlWrapper

    def __init__(self, nvml: NvmlWrapper, config: dict,
                 loop: float, *, name: Optional[str] = None) -> None:
        """Keep the NVML handle, refresh interval and the ``body`` config table."""

        super().__init__(name=name)

        self.nvml = nvml
        self.loop = loop

        self.config = config["body"]

    def on_mount(self):
        """Schedule ``refresh`` every ``loop`` seconds."""
        self.set_interval(self.loop, self.refresh)

    def render(self) -> RenderableType:
        """List every compute process in a table wrapped in a panel."""

        table: Table = WidgetUtils.create_table(
            ["PID", "Name", "Memory Usage"], self.config, ratio=(1, 5, 2)
        )

        # Nested iterable of (name, pid, memory) tuples; presumably one inner
        # sequence per GPU — confirm against NvmlWrapper.
        for gpu_procs in self.nvml.get_compute_procs():
            for proc_name, pid, used_memory in gpu_procs:
                clean_name = proc_name.strip('\n')
                table.add_row(
                    Text(f"{pid}", style=self.config["proc_id_style"]),
                    Text(f"{clean_name}", style=self.config["proc_name_style"]),
                    utils.render_file_unit(used_memory, self.config),
                )

        return WidgetUtils.get_panel(table, self.config["compute_procs"], self.mouse_over)

    async def on_enter(self, event: events.Enter) -> None:  # pylint: disable=unused-argument
        """[Override] mouse entered: use the hover border."""
        self.mouse_over = True

    async def on_leave(self, event: events.Leave) -> None:  # pylint: disable=unused-argument
        """[Override] mouse left: use the normal border."""
        self.mouse_over = False


class GraphicsProcs(Widget):
    """Graphics-process widget: PID, name and memory usage of graphics processes."""

    mouse_over: Reactive[RenderableType] = Reactive(False)
    nvml: NvmlWrapper

    def __init__(self, nvml: NvmlWrapper, config: dict,
                 loop: float, *, name: Optional[str] = None) -> None:
        """Keep the NVML handle, refresh interval and the ``body`` config table."""

        super().__init__(name=name)

        self.nvml = nvml
        self.loop = loop

        self.config = config["body"]

    def on_mount(self):
        """Schedule ``refresh`` every ``loop`` seconds."""
        self.set_interval(self.loop, self.refresh)

    def render(self) -> RenderableType:
        """List every graphics process in a table wrapped in a panel."""

        table: Table = WidgetUtils.create_table(
            ["PID", "Name", "Memory Usage"], self.config, ratio=(1, 5, 2)
        )

        # Nested iterable of (name, pid, memory) tuples; presumably one inner
        # sequence per GPU — confirm against NvmlWrapper.
        for gpu_procs in self.nvml.get_graphics_procs():
            for proc_name, pid, used_memory in gpu_procs:
                clean_name = proc_name.strip('\n')
                table.add_row(
                    Text(f"{pid}", style=self.config["proc_id_style"]),
                    Text(f"{clean_name}", style=self.config["proc_name_style"]),
                    utils.render_file_unit(used_memory, self.config),
                )

        return WidgetUtils.get_panel(table, self.config["graphics_procs"], self.mouse_over)

    async def on_enter(self, event: events.Enter) -> None:  # pylint: disable=unused-argument
        """[Override] mouse entered: use the hover border."""
        self.mouse_over = True

    async def on_leave(self, event: events.Leave) -> None:  # pylint: disable=unused-argument
        """[Override] mouse left: use the normal border."""
        self.mouse_over = False


class Body(GridView):
    """Main area: a 12-column x 3-row grid holding all GPU widgets."""

    def __init__(self, nvml: NvmlWrapper, config: dict, loop: float, *, name: str):
        """Create every body widget with the same NVML handle, config and interval.

        Args:
            nvml: Source of GPU data shared by all widgets.
            config: Full application config (each widget picks its ``body`` table).
            loop: Refresh interval in seconds.
            name: Widget name.
        """

        super().__init__(name=name)

        # NOTE(review): not read anywhere in this file; kept for compatibility
        # with any external users of the attribute.
        self.terminal_size = get_terminal_size()

        self.memory = Memory(
            nvml, config, loop, name="Memory"
        )
        self.clock = Clock(
            nvml, config, loop, name="Clock"
        )
        self.utilization = Utilization(
            nvml, config, loop, name="Utilization"
        )
        self.power = Power(
            nvml, config, loop, name="Power"
        )
        self.temperature = Temperature(
            nvml, config, loop, name="Temperature"
        )
        self.compute_procs = ComputeProcs(
            nvml, config, loop, name="ComputeProcs"
        )
        self.graphics_procs = GraphicsProcs(
            nvml, config, loop, name="GraphicsProcs"
        )

    async def on_mount(self, event: events.Mount) -> None:  # pylint: disable=unused-argument
        """Define the grid tracks and areas, then place the widgets.

        Layout (columns x0..x11, rows y0..y2):
            y0: memory (x0-x7) | clock (x8-x11)
            y1: utilization (x0-x3) | power (x4-x7) | temperature (x8-x11)
            y2: compute procs (x0-x5) | graphics procs (x6-x11)
        """

        # Plain loops: these calls are made for their side effects only.
        for i in range(12):
            self.grid.add_column(name=f"x{i}")
        for j in range(3):
            self.grid.add_row(name=f"y{j}")

        self.grid.add_areas(
            memory="x0-start|x7-end,y0",
            clock="x8-start|x11-end,y0",
            utilization="x0-start|x3-end,y1",
            power="x4-start|x7-end,y1",
            temperature="x8-start|x11-end,y1",
            compute_procs="x0-start|x5-end,y2",
            graphics_procs="x6-start|x11-end,y2",
        )

        self.grid.place(
            memory=self.memory,
            clock=self.clock,
            utilization=self.utilization,
            power=self.power,
            temperature=self.temperature,
            compute_procs=self.compute_procs,
            graphics_procs=self.graphics_procs,
        )


class RichSMI(App):
    """Rich-SMI textual application: header, footer and the GPU body grid."""

    def __init__(self, screen: bool = True, driver_class: Optional[Any] = None,
                 log: str = "", log_verbosity: int = 1, title: str = "Textual Application"):
        """Constructor.

        Parameters are forwarded unchanged to ``textual.app.App``. This also
        parses the command line, loads the TOML config file and creates the
        ``NvmlWrapper`` for the selected GPU. ``-1`` is used when ``--id`` is
        omitted, which the CLI help describes as "show all GPU info".
        """

        super().__init__(screen, driver_class, log, log_verbosity, title)

        args = ArgumentParser(underscores_to_dashes=True).parse_args()

        with open(args.config, mode="rt", encoding="utf-8") as fp:
            self.config = toml.load(fp)

        self.loop = args.loop
        self.gpu_id = -1 if args.id is None else args.id
        self.nvml = NvmlWrapper(self.gpu_id)

        self.header: MyHeader

    async def on_load(self) -> None:
        """Sent before going in to application mode: register key bindings."""

        await self.bind("q", "quit", "Quit")
        await self.bind("ctrl+c", "quit", show=False)

    async def on_mount(self) -> None:
        """Call after terminal goes in to application mode: dock the views.

        Also starts the app-level timer that drives ``on_timer`` every
        ``loop`` seconds.
        """

        self.header = MyHeader(
            self.nvml.get_driver_version(),
            self.nvml.get_cuda_version(),
            self.config["header"]
        )

        await self.view.dock(self.header, edge="top")
        await self.view.dock(MyFooter(), edge="bottom")
        await self.view.dock(Body(self.nvml, self.config, self.loop, name="body"))

        # No callback: ticks are handled by the ``on_timer`` override below.
        self.set_interval(self.loop)

    async def on_timer(self, event: events.Timer) -> None:  # pylint: disable=unused-argument
        """[Override] callback on the app timer: re-query NVML data.

        The body widgets re-render on their own timers.
        """

        # NOTE(review): the meaning of the ``False`` argument is defined by
        # librichsmi and is not visible here.
        self.nvml.query(False)


def main():
    """Entry point: run the Rich-SMI application until the user quits.

    Returns:
        0 after the application exits.
    """

    # ``App.run`` is a classmethod that constructs the app itself. Calling it
    # on an instance would first build a throwaway RichSMI, parsing the
    # arguments, loading the config and creating an NvmlWrapper twice.
    RichSMI.run(title="RichSMI")

    return 0


if __name__ == "__main__":
    main()
