import redis
import threading
from loguru import logger
import traceback
from time import sleep
from typing import Union


class base_tredis_msg_sender:
    """Default message handler used by Tredis.

    Each received pub/sub payload is treated as a Redis key and looked up on
    the owning Tredis instance's client (``tredis.r``).
    """

    def __init__(self, tredis=None):
        # Back-reference to the owning Tredis; Tredis.__init__ assigns it.
        self.tredis: Tredis = tredis

    def redis_msg_sender(self, channel, data):
        """Look up ``str(data)`` as a key and return its stored value.

        :param channel: channel the message arrived on (only logged).
        :param data: message payload; its ``str()`` is used as the key.
        :return: the stored value, or None when there is no owning Tredis,
                 its client is None, or the key does not exist.
        """
        logger.info(f'This is what I got: {data} from {channel}')
        owner = self.tredis
        if owner is None or owner.r is None:
            return None

        value = owner.r.get(str(data))
        if value is None:
            logger.info(f'Cannot find {data} on server')
            return None

        logger.info(f'This is what I have found {value} {type(value)}')
        return value


class Tredis(threading.Thread):
    """Background thread that listens on Redis pub/sub channels.

    Construction creates a Redis client, subscribes to ``channel`` and starts
    the thread. Every published message received on a subscribed channel is
    forwarded as ``(channel, data)`` to ``redis_msg_sender.redis_msg_sender``
    when that attribute exists and is callable. Publishing
    ``'<str_thread_exit_magic> <str(self)>'`` on ``channel`` (what ``stop()``
    does) ends the listening loop.
    """

    default_port = 6379
    default_db = 0

    def __init__(self,
                 server='localhost',
                 port=default_port,
                 db=default_db,
                 password='',
                 channel='test',
                 prefix='test',
                 redis_msg_sender=None):
        """Create the client, subscribe to ``channel`` and start listening.

        :param server: Redis host.
        :param port: Redis port.
        :param db: Redis database index.
        :param password: Redis password ('' means none).
        :param channel: primary channel; also carries the exit message.
        :param prefix: stored on the instance (not used by this class).
        :param redis_msg_sender: handler object; gets its ``tredis`` attribute
            set to this instance. When None, a new ``base_tredis_msg_sender``
            is created for this instance.
        """
        threading.Thread.__init__(self)
        self.server = server
        self.port = port
        self.db = db
        self.password = password
        self.channel = channel
        self.prefix = prefix
        # Create the default per instance. A default object evaluated once in
        # the signature would be shared, and each new Tredis would overwrite
        # its back-reference.
        if redis_msg_sender is None:
            redis_msg_sender = base_tredis_msg_sender()
        self.redis_msg_sender = redis_msg_sender
        self.redis_msg_sender.tredis = self
        self.str_thread_exit_magic = 'redis thread exit'

        self.r: redis.client.Redis = redis.StrictRedis(host=self.server,
                                                       port=self.port,
                                                       db=self.db,
                                                       encoding="utf-8",
                                                       decode_responses=True,
                                                       password=self.password)

        logger.info(f'Redis connected to {self.server}, port {self.port}, db: {self.db}')
        self.sub: Union[redis.client.PubSub, None] = self.r.pubsub()
        # Subscribe before the thread starts. subscribe() calls made right
        # after construction then cannot race the listener on the shared,
        # lazily-connected PubSub.
        self.sub.subscribe(self.channel)
        logger.info(f'Redis subscribe to channel [{self.channel}]')

        self.start()

    def subscribe(self, channel):
        """Subscribe to an additional channel; its messages go to the same handler."""
        logger.info(f'Redis subscribe to extra channel [{channel}]')
        self.sub.subscribe(channel)

    def run(self):
        """Dispatch pub/sub messages until this instance's exit message arrives."""
        try:
            for message in self.sub.listen():
                if not message:
                    continue
                logger.info(f'REDIS got message: [{message}]')
                # listen() also yields (un)subscribe acknowledgements whose
                # 'data' is a subscription count; only publications are handled.
                if message.get('type') not in ('message', 'pmessage'):
                    continue
                try:
                    channel = message['channel']
                    data = message['data']
                    logger.info(f'redis_working_thread got msg: {data}')

                    if isinstance(data, str) and data.startswith(self.str_thread_exit_magic):
                        # Exit messages addressed to another instance are ignored.
                        if data[len(self.str_thread_exit_magic) + 1:] == str(self):
                            logger.info(f'{self} to exit')
                            break
                    else:
                        function_send = getattr(self.redis_msg_sender, 'redis_msg_sender', None)
                        if function_send is not None and callable(function_send):
                            logger.info(f'function_send: {data}')
                            function_send(channel, data)
                except Exception as e:
                    # Keep listening: a failing handler must not kill the thread.
                    traceback.print_exc()
                    logger.exception(f'\033[1;33mredis subscribe\n{message}\n{e}\033[0m')
        finally:
            # Release the dedicated pub/sub connection once listening ends.
            self.sub.close()
        logger.info(f'redis thread: {self} stopped.')

    @logger.catch()
    def stop(self):
        """Ask the listener thread to exit; returns without joining it.

        The token is ``str(self)``, which for a Thread includes its status and
        ident. It therefore matches what run() computes only while the thread
        is alive.
        """
        self.r.publish(self.channel, f'{self.str_thread_exit_magic} {str(self)}')



if __name__ == '__main__':
    # Demo: listen on 'test' plus two extra channels, publish one message to
    # each, then idle until Ctrl-C and ask the listener thread to exit.
    # NOTE(review): loguru's disable() matches by module name; run as a script
    # this module is '__main__', so this likely does not silence its own logs.
    logger.disable('mypylib.tredis')
    tredis = Tredis()

    tredis.subscribe('test1')
    tredis.subscribe('test2')

    tredis.r.publish('test', 'This is a book')
    tredis.r.publish('test1', 'That is a pencil.')
    tredis.r.publish('test2', 'My name is John.')

    try:
        # Keep the main thread alive until Ctrl-C.
        while True:
            sleep(1)
    except KeyboardInterrupt:
        pass
    finally:
        tredis.stop()




