import asyncio
from typing import Any, AsyncIterator, Dict, Mapping, Optional, Tuple, Type, Union

from redis.asyncio import BlockingConnectionPool
from redis.exceptions import NoScriptError

from cashews.backends.interface import Backend

from .client import Redis, SafeRedis

# Lua scripts below are loaded with every newline replaced by a space before
# SCRIPT LOAD, so they must not contain Lua `--` line comments (the rest of the
# script would become part of the comment).
#
# Compare-and-delete used to release a lock: deletes KEYS[1] only when its
# current value equals ARGV[1] and returns the DEL result; returns 0 when the
# key is missing or holds a different value.
_UNLOCK = """
if redis.call("GET", KEYS[1]) == ARGV[1] then
    return redis.call("DEL", KEYS[1])
else
    return 0
end
"""
# Sliding-window counter over a sorted set.
#   KEYS[1] - sorted-set key
#   ARGV[1] - window start; members scored strictly below it are removed
#             (the "(" prefix makes the ZREMRANGEBYSCORE upper bound exclusive)
#   ARGV[2] - window end; used as both the score and the member of a new entry
#   ARGV[3] - limit; a new entry is added only while the count is below it
#   ARGV[4] - TTL in milliseconds, applied with PEXPIRE only when > 0
# Returns the number of members scored within [ARGV[1], ARGV[2]], plus one
# when a new entry was added.
# NOTE(review): the member is ARGV[2] itself, so two calls with an identical
# end value address the same member. ZADD then only updates the score, yet the
# returned count is still incremented. Confirm callers pass unique end values.
_INCR_SLICE = """
redis.call("ZREMRANGEBYSCORE", KEYS[1], 0, "(" .. ARGV[1])
local current_count = redis.call("ZCOUNT", KEYS[1], ARGV[1], ARGV[2])
if current_count < tonumber(ARGV[3]) then
    current_count = current_count + 1
    redis.call("ZADD", KEYS[1], ARGV[2], ARGV[2])
    if tonumber(ARGV[4]) > 0 then
        redis.call("PEXPIRE", KEYS[1], ARGV[4])
    end
end
return current_count
"""
# Sentinel passed as `default` to get_many by get_match, so keys that were
# deleted between SCAN and MGET can be told apart from real values and skipped.
_empty = object()
# pylint: disable=arguments-differ
# pylint: disable=abstract-method


class _Redis(Backend):
    """Cache backend that stores values in Redis through a redis-py asyncio client.

    Responses are kept as raw bytes (``decode_responses`` is forced to False).
    The client and its connection pool are created by :meth:`init`; every other
    coroutine assumes ``init`` has already been awaited.
    """

    _client: Union[Redis, SafeRedis]
    _client_class: Union[Type[Redis], Type[SafeRedis]]

    def __init__(self, address, safe: bool = True, **kwargs: Any) -> None:
        """Store connection settings; no connection is opened here.

        :param address: Redis URL handed to the pool class's ``from_url``.
        :param safe: use ``SafeRedis`` when true, plain ``Redis`` otherwise.
        :param kwargs: pool/connection options. ``local_cache`` and ``prefix``
            are discarded; ``connection_pool_class`` selects the pool class
            (``BlockingConnectionPool`` by default); for blocking pools
            ``wait_for_connection_timeout`` (default 0.1) becomes ``timeout``.
        """
        # These options are not meant for the Redis pool and are dropped here.
        kwargs.pop("local_cache", None)
        kwargs.pop("prefix", None)
        kwargs.setdefault("client_name", "cashews")
        kwargs.setdefault("health_check_interval", 10)
        kwargs.setdefault("max_connections", 10)
        kwargs.setdefault("socket_keepalive", True)
        kwargs.setdefault("retry_on_timeout", False)
        kwargs.setdefault("socket_timeout", 0.1)
        # Values are handled as bytes; never let redis-py decode them to str.
        kwargs["decode_responses"] = False

        self._pool_class = kwargs.pop("connection_pool_class", BlockingConnectionPool)
        # Subclasses of BlockingConnectionPool accept ``timeout`` too.
        if isinstance(self._pool_class, type) and issubclass(self._pool_class, BlockingConnectionPool):
            kwargs["timeout"] = kwargs.pop("wait_for_connection_timeout", 0.1)
        # Script name -> SHA1 returned by SCRIPT LOAD (see _run_script).
        self._sha: Dict[str, Any] = {}
        self._client_class = SafeRedis if safe else Redis
        self._kwargs = kwargs
        self._address = address
        self.__is_init = False
        super().__init__()

    @property
    def is_init(self) -> bool:
        """True between a successful :meth:`init` and :meth:`close`."""
        return self.__is_init

    async def init(self):
        """Create the connection pool and client, then initialize the client."""
        self._client = self._client_class(connection_pool=self._pool_class.from_url(self._address, **self._kwargs))
        await self._client.initialize()
        self.__is_init = True

    async def clear(self):
        """Remove every key from the current database (FLUSHDB)."""
        return await self._client.flushdb()

    @staticmethod
    def _expire_args(expire: Optional[float]) -> Dict[str, Any]:
        """Translate a TTL into redis-py ``SET`` expiry keyword arguments.

        Floats are sent as milliseconds (``px``) to keep sub-second precision;
        any other value (int or None) is passed through as seconds (``ex``).
        """
        if isinstance(expire, float):
            return {"ex": None, "px": int(expire * 1000)}
        return {"ex": expire, "px": None}

    async def set(self, key: str, value: Any, expire: Optional[float] = None, exist=None) -> bool:
        """SET ``key`` to ``value``.

        :param expire: TTL in seconds; floats are applied with ms precision.
        :param exist: True -> only overwrite an existing key (XX);
            False -> only create a missing key (NX); None -> unconditional.
        :return: whether Redis reported the value as written.
        """
        nx = True if exist is False else None
        xx = True if exist is True else None
        return bool(await self._client.set(key, value, nx=nx, xx=xx, **self._expire_args(expire)))

    async def set_many(self, pairs: Mapping[str, Any], expire: Optional[float] = None):
        """Write all ``pairs``; when ``expire`` is given every key gets that TTL (seconds)."""
        if not pairs:
            # MSET with no arguments is rejected by Redis.
            return
        if expire is None:
            await self._client.mset(pairs)
            return
        pexpire = int(expire * 1000)
        # One MULTI/EXEC so no key is ever visible without its TTL.
        async with self._client.pipeline(transaction=True) as pipe:
            await pipe.mset(pairs)
            for key in pairs:
                await pipe.pexpire(key, pexpire)
            await pipe.execute()

    async def get_expire(self, key: str) -> int:
        """Return the remaining TTL of ``key`` in seconds (Redis ``TTL`` reply)."""
        return await self._client.ttl(key)

    async def expire(self, key: str, timeout: float):
        """Set the TTL of ``key`` to ``timeout`` seconds (applied as milliseconds)."""
        return await self._client.pexpire(key, int(timeout * 1000))

    async def set_lock(self, key: str, value, expire: float) -> bool:
        """Create ``key`` with ``value`` only if it does not exist (SET NX).

        The raw client is used on purpose so ``value`` is stored exactly as
        given; :meth:`unlock` compares it byte-for-byte.
        """
        return bool(await self._client.set(key, value, nx=True, **self._expire_args(expire)))

    async def is_locked(
        self,
        key: str,
        wait: Optional[float] = None,
        step: float = 0.1,
    ) -> bool:
        """Report whether the lock ``key`` is held.

        With ``wait=None`` this is a single existence check. Otherwise the key
        is polled every ``step`` seconds until it disappears (returns False) or
        about ``wait`` seconds have passed with the key still present (returns
        True). The key is always checked at least once, and once more after
        the final sleep, so ``wait <= 0`` no longer reports True blindly.
        """
        if wait is None:
            return await self.exists(key)
        # A dedicated connection for the whole polling loop.
        async with self._client.client() as conn:
            while True:
                if not await conn.exists(key):
                    return False
                if wait <= 0.0:
                    return True
                wait -= step
                await asyncio.sleep(step)

    async def _run_script(self, name: str, script: str, key: str, *args: Any) -> Any:
        """Run the Lua ``script`` against one ``key`` via EVALSHA, caching its SHA.

        The SHA is cached under ``name`` in ``self._sha``. If the server no
        longer knows the script (restart, failover, SCRIPT FLUSH), it is
        reloaded once and the call retried instead of failing permanently.
        """

        async def load() -> Any:
            # Newlines are flattened before loading, as the scripts expect.
            sha = await self._client.script_load(script.replace("\n", " "))
            if sha is not None:
                self._sha[name] = sha
            return sha

        sha = self._sha.get(name)
        if sha is None:
            sha = await load()
        try:
            return await self._client.evalsha(sha, 1, key, *args)
        except NoScriptError:
            sha = await load()
            return await self._client.evalsha(sha, 1, key, *args)

    async def unlock(self, key: str, value: Any) -> bool:
        """Delete ``key`` only if it still holds ``value``; returns the script's DEL result (0 otherwise)."""
        return await self._run_script("UNLOCK", _UNLOCK, key, value)

    async def delete(self, key: str) -> bool:
        """Remove ``key`` (UNLINK); True if it existed."""
        return bool(await self._client.unlink(key))

    async def exists(self, key: str) -> bool:
        """Return True if ``key`` exists."""
        return bool(await self._client.exists(key))

    async def scan(self, pattern: str, batch_size: int = 100) -> AsyncIterator[str]:  # type: ignore
        """Yield decoded key names matching the glob ``pattern`` using SCAN."""
        cursor = 0
        while True:
            cursor, keys = await self._client.scan(cursor, match=pattern, count=batch_size)
            for key in keys:
                yield key.decode()
            if not cursor:
                return

    async def delete_many(self, *keys: str):
        """Remove all given keys; a call with no keys is a no-op."""
        if not keys:
            # DEL/UNLINK with no arguments is rejected by Redis.
            return
        # UNLINK, consistent with delete() and delete_match().
        await self._client.unlink(*keys)

    async def delete_match(self, pattern: str):
        """Remove keys matching ``pattern``.

        A pattern without ``*`` is treated as a literal key name. Otherwise the
        keyspace is walked once with SCAN and each non-empty batch is unlinked.
        """
        if "*" not in pattern:
            await self._client.unlink(pattern)
            return
        cursor = 0
        while True:
            cursor, keys = await self._client.scan(cursor, match=pattern, count=100)
            if keys:
                await self._client.unlink(*keys)
            # Stop when SCAN reports a full iteration, whether or not the last
            # batch was empty (previously a non-empty last batch restarted SCAN).
            if not cursor:
                return

    async def get_match(self, pattern: str, batch_size: int = 100) -> AsyncIterator[Tuple[str, Any]]:  # type: ignore
        """Yield ``(key, value)`` pairs for keys matching ``pattern``.

        Keys found by SCAN but deleted before MGET are skipped.
        """
        cursor = 0
        while True:
            cursor, keys = await self._client.scan(cursor, match=pattern, count=batch_size)
            if not keys:
                if not cursor:
                    return
                continue
            keys = [key.decode() for key in keys]
            values = await self.get_many(*keys, default=_empty)
            for key, value in zip(keys, values):
                if value is not _empty:  # key can be deleted after scan
                    yield key, value
            if not cursor:
                return

    async def get_size(self, key: str) -> int:
        """Return MEMORY USAGE of ``key`` in bytes, or 0 when Redis reports nothing."""
        size = await self._client.memory_usage(key) or 0
        return int(size)

    async def get(self, key: str, default=None) -> Any:
        """Return the stored value of ``key``, or ``default`` when it is missing.

        Consistent with :meth:`get_many`; previously ``default`` was ignored.
        """
        value = await self._client.get(key)
        return default if value is None else value

    async def get_many(self, *keys: str, default: Optional[Any] = None) -> Tuple[Optional[Any], ...]:
        """MGET ``keys``; missing keys map to ``default``. Empty input -> empty tuple."""
        if not keys:
            return tuple()
        values = await self._client.mget(keys[0], *keys[1:])
        if not values:
            return tuple()
        return tuple(value if value is not None else default for value in values)

    async def incr(self, key: str) -> int:
        """INCR ``key`` and return the new value."""
        return await self._client.incr(key)

    async def get_bits(self, key: str, *indexes: int, size: int = 1):
        """
        Read unsigned ``size``-bit fields at the given field indexes.

        https://redis.io/commands/bitfield
        """
        bitops = self._client.bitfield(key)
        for index in indexes:
            bitops.get(fmt=f"u{size}", offset=f"#{index}")
        return tuple(await bitops.execute() or [])

    async def incr_bits(self, key: str, *indexes: int, size: int = 1, by: int = 1) -> Tuple[int, ...]:
        """Increment unsigned ``size``-bit fields by ``by`` with saturating overflow.

        A falsy reply is treated as "no results", matching :meth:`get_bits`.
        """
        bitops = self._client.bitfield(key)
        for index in indexes:
            bitops.incrby(fmt=f"u{size}", offset=f"#{index}", increment=by, overflow="SAT")
        return tuple(await bitops.execute() or [])

    async def ping(self, message: Optional[bytes] = None) -> bytes:
        """PING the server; echo ``message`` on a truthy reply, else return b"PONG"."""
        pong = await self._client.ping()
        if pong and message:
            return message
        return b"PONG"

    async def set_raw(self, key: str, value: Any, **kwargs: Any):
        """Pass-through to the client's SET with arbitrary redis-py keyword arguments."""
        return await self._client.set(key, value, **kwargs)

    async def get_raw(self, key: str) -> Any:
        """Return the client's GET reply unchanged (None when missing)."""
        return await self._client.get(key)

    async def slice_incr(self, key: str, start: int, end: int, maxvalue: int, expire: Optional[float] = None) -> int:
        """Sliding-window counter (see the ``_INCR_SLICE`` script).

        :param expire: TTL in seconds for the sorted set; None/0 means no PEXPIRE.
        :return: the count returned by the script.
        """
        pexpire = int((expire or 0) * 1000)
        return await self._run_script("INCR_SLICE", _INCR_SLICE, key, start, end, maxvalue, pexpire)

    async def close(self):
        """Close the client and mark the backend as not initialized."""
        await self._client.close()
        self.__is_init = False
