#!/usr/bin/env python3


from argparse import ArgumentParser, Namespace
from atexit import register
from dataclasses import dataclass
from enum import Enum
from itertools import chain, cycle, islice, repeat
from math import pi, sin
from os import environ
from random import choice, randint
from shutil import get_terminal_size
from signal import SIG_DFL, SIGPIPE, signal
from sys import stdin
from typing import Callable, Dict, Iterator, List, Tuple, cast


class ColourSpace(Enum):
    """Colour encodings the script can emit.

    The member values are the strings offered as choices for ``--colour``.
    """

    # Rendered with the "\033[38;5;" / "\033[48;5;" prefixes and decor_8.
    EIGHT = "8"
    # Rendered with the "\033[38;2;" / "\033[48;2;" prefixes and decor_24.
    TRUE = "24"


class Flags(Enum):
    """Identifiers of the flags that have an entry in ``FLAG_SPECS``."""

    LES = 1  # selected by -l / --les / --lesbian
    GAY = 2  # selected by -g / --gay
    BI = 3  # selected by -b / --bi / --bisexual
    TRANS = 4  # selected by -t / --trans / --transgender
    ACE = 5  # selected by -a / --ace / --asexual
    PAN = 6  # selected by -p / --pan / --pansexual
    NB = 7  # selected by -n / --nb / --non-binary
    GQ = 8  # selected by --gq / --gender-queer


# A colour written as text in the "#RRGGBB" form used throughout FLAG_SPECS.
HexColour = str
# One colour as (red, green, blue) integer channels.
RGB = Tuple[int, int, int]
# Stripes in drawing order: (hex colour, relative stripe height).
RawPalette = List[Tuple[HexColour, int]]
# Stripe colours after parsing, without the height information.
Palette = List[RGB]


@dataclass
class FlagSpec:
    """Proportions and stripes of a single flag."""

    # Unpacked by paint_flag as (r, c) and used as r / c when sizing the flag.
    aspect_ratio: Tuple[int, int]
    # Stripes in drawing order as (hex colour, relative stripe height).
    palette: RawPalette


# Proportions and stripes of every supported flag, keyed by ``Flags`` member.
# Each palette entry is (hex colour, relative stripe height); entries are
# listed in the order the stripes are drawn.
FLAG_SPECS: Dict[Flags, FlagSpec] = {
    # Lesbian: five equal stripes.
    Flags.LES: FlagSpec(
        aspect_ratio=(3, 5),
        palette=[
            ("#D62E02", 1),
            ("#FD9855", 1),
            ("#FFFFFF", 1),
            ("#D161A2", 1),
            ("#A20160", 1),
        ],
    ),
    # Gay: six equal stripes.
    Flags.GAY: FlagSpec(
        aspect_ratio=(3, 5),
        palette=[
            ("#FF0018", 1),
            ("#FFA52C", 1),
            ("#FFFF41", 1),
            ("#008018", 1),
            ("#0000F9", 1),
            ("#86007D", 1),
        ],
    ),
    # Bisexual: the outer stripes are twice as tall as the middle one.
    Flags.BI: FlagSpec(
        aspect_ratio=(3, 5), palette=[("#D60270", 2), ("#9B4F96", 1), ("#0038A8", 2)],
    ),
    # Transgender: five equal stripes, symmetric top to bottom.
    Flags.TRANS: FlagSpec(
        aspect_ratio=(3, 5),
        palette=[
            ("#55CDFC", 1),
            ("#F7A8B8", 1),
            ("#FFFFFF", 1),
            ("#F7A8B8", 1),
            ("#55CDFC", 1),
        ],
    ),
    # Asexual: four equal stripes.
    Flags.ACE: FlagSpec(
        aspect_ratio=(3, 5),
        palette=[("#000000", 1), ("#A4A4A4", 1), ("#FFFFFF", 1), ("#810081", 1)],
    ),
    # Pansexual: three equal stripes.
    Flags.PAN: FlagSpec(
        aspect_ratio=(3, 5), palette=[("#FF1B8D", 1), ("#FFDA00", 1), ("#1BB3FF", 1)],
    ),
    # Non-binary: four equal stripes.
    Flags.NB: FlagSpec(
        aspect_ratio=(3, 5),
        palette=[("#FFF430", 1), ("#FFFFFF", 1), ("#9C59D1", 1), ("#000000", 1)],
    ),
    # Genderqueer: three equal stripes.
    Flags.GQ: FlagSpec(
        aspect_ratio=(3, 5), palette=[("#B77FDD", 1), ("#FFFFFF", 1), ("#48821E", 1)]
    ),
}


def on_exit() -> None:
    """Write the SGR reset sequence to stdout and flush it.

    ``main`` registers this with ``atexit`` so the terminal is not left in a
    coloured state when the program ends.
    """
    reset_sequence = "\033[0m"
    print(reset_sequence, end="", flush=True)


def parse_args() -> Namespace:
    """Parse the command line into a namespace.

    The returned namespace carries:

    * ``colour``: a ``ColourSpace`` value string; defaults to ``"24"`` when
      ``$COLORTERM`` is ``truecolor`` or ``24bit`` and to ``"8"`` otherwise.
    * ``flag_only``: ``True`` when ``-f``/``--flag`` was given.
    * ``flag``: the ``Flags`` member chosen by the last flag option on the
      command line, or a random member when none was given.
    * ``period``: a positive integer; random in ``5..10`` by default.

    Invalid arguments make argparse print a usage error and exit.
    """
    # Local import: only the ``--period`` validator below needs it.
    from argparse import ArgumentTypeError

    def positive_int(text: str) -> int:
        """argparse ``type`` callable that accepts integers >= 1 only."""
        value = int(text)  # a ValueError here is reported by argparse itself
        if value < 1:
            raise ArgumentTypeError(f"must be a positive integer, got {value}")
        return value

    rand_flag = choice(list(Flags))
    colour_space = (
        ColourSpace.TRUE
        if environ.get("COLORTERM") in {"truecolor", "24bit"}
        else ColourSpace.EIGHT
    )
    # Every flag option stores into the same ``flag`` destination, so the
    # random fallback is supplied by pre-seeding the namespace rather than
    # through a per-option default.
    namespace = Namespace(flag=rand_flag)
    parser = ArgumentParser()

    parser.add_argument(
        "-c",
        "--colour",
        choices=[c.value for c in ColourSpace],
        default=colour_space.value,
    )

    parser.add_argument("-f", "--flag", dest="flag_only", action="store_true")

    flag_options = (
        (Flags.LES, ("-l", "--les", "--lesbian")),
        (Flags.GAY, ("-g", "--gay")),
        (Flags.BI, ("-b", "--bi", "--bisexual")),
        (Flags.TRANS, ("-t", "--trans", "--transgender")),
        (Flags.ACE, ("-a", "--ace", "--asexual")),
        (Flags.PAN, ("-p", "--pan", "--pansexual")),
        (Flags.NB, ("-n", "--nb", "--non-binary")),
        (Flags.GQ, ("--gq", "--gender-queer")),
    )
    for flag, option_strings in flag_options:
        parser.add_argument(
            *option_strings, action="store_const", dest="flag", const=flag,
        )

    # The period is used as a step count for the colour gradient, which only
    # makes sense for values of at least 1, so reject anything smaller here.
    parser.add_argument("--period", type=positive_int, default=randint(5, 10))

    return parser.parse_args(namespace=namespace)


def flag_colours(flag: FlagSpec) -> Palette:
    """Return the parsed RGB colour of every stripe of *flag*, in order.

    The stripe heights stored alongside the colours are ignored.
    """
    colours: Palette = []
    for hex_colour, _height in flag.palette:
        colours.append(parse_colour(hex_colour))
    return colours


def parse_colour(colour: HexColour) -> RGB:
    """Convert a ``#RRGGBB`` string into an ``(r, g, b)`` tuple of ints.

    Upper- and lower-case hexadecimal digits are both accepted.

    Raises:
        ValueError: if *colour* is not ``#`` followed by exactly six
            hexadecimal digits.
    """
    digits = colour[1:]
    well_formed = (
        colour.startswith("#")
        and len(digits) == 6
        and all(ch in "0123456789abcdefABCDEF" for ch in digits)
    )
    if not well_formed:
        raise ValueError(f"expected a colour like '#RRGGBB', got {colour!r}")
    # Two hexadecimal digits per channel.
    return int(digits[0:2], 16), int(digits[2:4], 16), int(digits[4:6], 16)


def lerp(c1: RGB, c2: RGB, mix: float) -> RGB:
    """Blend two colours channel by channel.

    *mix* is the weight given to *c1* and ``1 - mix`` the weight given to
    *c2*, so ``mix == 1`` returns *c1* and ``mix == 0`` returns *c2*. Each
    blended channel is passed through ``round`` to get an integer.
    """
    blended = []
    for first, second in zip(c1, c2):
        weighted = first * mix + second * (1 - mix)
        blended.append(int(round(weighted)))
    return cast(RGB, tuple(blended))


def rgb_gen(palette: Palette, rep: int) -> Iterator[RGB]:
    """Return an endless colour gradient that cycles through *palette*.

    For each pair of consecutive palette colours (wrapping from the last
    back to the first) ``rep + 1`` colours are produced. The weight of the
    earlier colour follows ``cos(t / rep * pi / 2)`` for ``t = 0..rep``, so
    each run starts on the earlier colour and ends on the later one; that
    end colour is then produced again as the start of the next run.

    Args:
        palette: colours to cycle through; must not be empty.
        rep: number of interpolation steps between two colours; at least 1.

    Raises:
        ValueError: if *palette* is empty or *rep* is smaller than 1. The
            check happens when the function is called, not on first use.
    """
    if not palette:
        raise ValueError("palette must contain at least one colour")
    if rep < 1:
        # rep == 0 would divide by zero below; a negative rep would give an
        # empty step range and an iterator that never produces a value.
        raise ValueError(f"rep must be a positive integer, got {rep!r}")

    period = pi * pi / 2

    def wave(t: float) -> float:
        x = t / rep * period
        # sin(a + pi / 2) == cos(a): falls from 1 at t == 0 to ~0 at t == rep.
        return sin(x / pi + pi / 2)

    # The weights are identical for every pair of colours, so compute once.
    weights = [wave(t) for t in range(0, rep + 1)]

    def gradient() -> Iterator[RGB]:
        colours = cycle(palette)
        prev = next(colours)
        for curr in colours:
            for mix in weights:
                yield lerp(prev, curr, mix)
            prev = curr

    return gradient()


def decor_8(rgb: RGB) -> Iterator[str]:
    """Yield one string: the palette index used for *rgb* in "8" mode.

    Each channel is scaled from the 0-255 range to a level in 0-5, and the
    three levels are combined as ``16 + 36 * r + 6 * g + b`` (the layout of
    the 6x6x6 colour cube in the xterm 256-colour palette).
    """
    levels = [int(round(channel / 255 * 5)) for channel in rgb]
    red, green, blue = levels
    index = 16 + 36 * red + 6 * green + blue
    yield str(index)


def decor_24(rgb: RGB) -> Iterator[str]:
    """Yield the channels of *rgb* as decimal strings separated by ``";"``.

    For ``(1, 2, 3)`` the yielded pieces are ``"1", ";", "2", ";", "3"``,
    which joined together form the parameter part that follows the
    ``38;2;`` / ``48;2;`` prefix. An empty sequence yields nothing.
    """
    for position, channel in enumerate(rgb):
        if position:
            # The separator goes between channels only, never before the first.
            yield ";"
        yield str(channel)


def decor_for(space: ColourSpace) -> Tuple[str, str, Callable[[RGB], Iterator[str]]]:
    """Return ``(foreground prefix, background prefix, encoder)`` for *space*.

    The prefixes are the opening parts of the escape sequences; the encoder
    turns an RGB colour into the pieces that follow the prefix. Callers
    finish the sequence themselves with ``"m"``.

    Raises:
        ValueError: if *space* is neither ``ColourSpace`` member.
    """
    if space == ColourSpace.EIGHT:
        foreground, background = "\033[38;5;", "\033[48;5;"
        return foreground, background, decor_8
    if space == ColourSpace.TRUE:
        foreground, background = "\033[38;2;", "\033[48;2;"
        return foreground, background, decor_24
    raise ValueError()


def paint_flag(colour_space: ColourSpace, spec: FlagSpec) -> Iterator[str]:
    """Yield the string pieces that draw the flag in *spec* on the terminal.

    Every row spans the full terminal width and consists of a background
    colour sequence, a run of spaces, a reset sequence and a newline. Stripe
    heights from the palette are multiplied by a common integer scale
    (at least 1) derived from the terminal size and the flag's aspect ratio.
    """
    width, terminal_rows = get_terminal_size((80, 140))
    _, background, encode = decor_for(colour_space)
    ratio_rows, ratio_cols = spec.aspect_ratio
    total_units = sum(units for _, units in spec.palette)
    # NOTE(review): the 0.5 presumably compensates for character cells being
    # roughly twice as tall as they are wide -- confirm.
    cell_ratio = ratio_rows / ratio_cols * 0.5
    # NOTE(review): four rows are kept free, presumably for the shell prompt.
    limit_by_rows = (terminal_rows - 4) / total_units
    limit_by_cols = width / total_units * cell_ratio
    scale = max(int(min(limit_by_rows, limit_by_cols)), 1)
    blank = " " * width
    for hex_colour, units in spec.palette:
        rgb = parse_colour(hex_colour)
        # All rows of one stripe are identical, so build the row once.
        row = (background, *encode(rgb), "m", blank, "\033[0m", "\n")
        for _ in range(units * scale):
            yield from row


def colourize(
    colour_space: ColourSpace, spec: FlagSpec, period: int, lines: Iterator[str],
) -> Iterator[str]:
    """Yield the pieces of *lines* with every character given its own colour.

    Colours come from a gradient over the colours of *spec* with *period*
    steps between neighbouring colours; the gradient carries on across line
    boundaries. Each character is preceded by a foreground colour sequence,
    and each line is followed by a reset sequence and a newline.
    """
    gradient = rgb_gen(flag_colours(spec), period)
    foreground, _, encode = decor_for(colour_space)
    for text in lines:
        # zip stops when the text runs out, before taking another colour,
        # so exactly one colour is consumed per character.
        for character, rgb in zip(text, gradient):
            yield foreground
            yield from encode(rgb)
            yield "m"
            yield character
        yield "\033[0m"
        yield "\n"


def main() -> None:
    """Run the program: draw a flag, or colourize everything read from stdin.

    With ``-f``/``--flag`` the selected flag is painted; otherwise standard
    input is read to the end, split into lines and written back with
    gradient colours.
    """
    # Restore the default SIGPIPE action so that writing to a closed pipe
    # (e.g. `... | head`) ends the process instead of raising BrokenPipeError.
    signal(SIGPIPE, SIG_DFL)
    args = parse_args()
    register(on_exit)
    colour_space = ColourSpace(args.colour)
    spec = FLAG_SPECS[args.flag]

    if args.flag_only:
        pieces = paint_flag(colour_space=colour_space, spec=spec)
    else:
        lines = stdin.read().splitlines()
        pieces = colourize(
            colour_space=colour_space, spec=spec, period=args.period, lines=iter(lines),
        )
    print("".join(pieces), end="")


# Script entry point. Note that this runs on import as well as on direct
# execution, because there is no ``__name__ == "__main__"`` guard.
try:
    main()
except KeyboardInterrupt:
    # Exit with status 130 on Ctrl-C without printing a traceback. The
    # builtin `exit` is injected by the `site` module and is unavailable
    # under `python -S`; raising SystemExit directly always works.
    raise SystemExit(130)
