import random
import time
from concurrent.futures import ThreadPoolExecutor
from queue import Queue, Empty
from threading import Lock, Condition, RLock
from typing import Dict, List, Callable, Optional, Any

import requests
from requests import Session

from simple_proxy2.data.proxy_info import ProxyInfo
from simple_proxy2.pool.proxy_pool import ProxyPool
from simple_proxy2.tools.random_user_agent import get_random as random_agent
from simple_proxy2.tools.simple_timer import SimpleTimer


class Task:
    """A one-shot unit of work whose result can be awaited from another thread.

    The wrapped function runs at most once, on the first call; later calls are
    ignored. Any thread may block in :meth:`get` until that run has finished.
    """

    def __init__(self, fn: Callable):
        """
        :param fn: the function to run when the task is called
        """
        self._fn = fn

        self._result = None
        self._error: Optional[BaseException] = None  # exception raised by fn, if any
        self._started = False
        self._done = False
        self._mutex = RLock()
        self._condition = Condition(self._mutex)

    def get(self):
        """Block until the task has run, then return the wrapped function's result.

        Returns immediately if the task has already finished.

        :return: the value returned by the wrapped function
        :raises Exception: re-raises whatever the wrapped function raised
        """
        with self._condition:
            # Wait on the predicate rather than a bare wait(): this tolerates
            # spurious wakeups and the case where the task finished first.
            self._condition.wait_for(lambda: self._done)
            if self._error is not None:
                raise self._error
            return self._result

    def __call__(self, *args):
        """Run the wrapped function with ``args``; calls after the first are no-ops."""
        with self._mutex:
            if self._started:
                return
            self._started = True

        result = None
        error = None
        try:
            # Run outside the lock so waiters sit on the condition rather than
            # contending for the mutex for the whole duration of fn.
            result = self._fn(*args)
        except Exception as ex:
            print(ex)
            error = ex
        finally:
            # Always mark completion so get() can never block forever.
            with self._condition:
                self._result = result
                self._error = error
                self._done = True
                self._condition.notify_all()


class ProxyManager:
    """Run user functions through randomly chosen proxies on a worker pool.

    Work submitted through :meth:`proxy_session` is queued and executed by
    ``num_executors`` worker threads once :meth:`start` has been called (or the
    manager is entered as a context manager). Success and failure counts across
    all attempts are tracked and exposed through :meth:`success_rate`, which is
    also handed to the underlying :class:`ProxyPool`.
    """

    def __init__(self,
                 test_url: str,
                 proxy_info_dict: Dict[str, List[str]],
                 num_executors=32,
                 executor_queue_size=1000,
                 verbose=False):
        """
        :param test_url: stored on the manager; not used within this class
         (presumably consumed elsewhere — TODO confirm)
        :param proxy_info_dict: mapping of protocol -> list of proxy addresses
        :param num_executors: number of worker threads that execute tasks
        :param executor_queue_size: maximum number of pending tasks; submitting
         more blocks until a slot frees up
        :param verbose: print the running success rate after every attempt
        """
        self._test_url = test_url
        self._proxy_info_dict = proxy_info_dict
        self._verbose = verbose

        self._pool = self._init_pool(proxy_info_dict)

        self._executor_running = False
        self._executor_workers = num_executors
        self._executor_queue = Queue(executor_queue_size)
        self._executor = ThreadPoolExecutor(max_workers=self._executor_workers)

        self._metrics_lock = Lock()
        self._success = 0
        self._trials = 0

    def _init_pool(self, info_dict: Dict[str, List[str]]) -> ProxyPool:
        """Build a ProxyPool from every (protocol, address) pair, in shuffled order."""
        info_list = [ProxyInfo(protocol, address)
                     for protocol, addresses in info_dict.items()
                     for address in addresses]

        random.shuffle(info_list)
        return ProxyPool(self.success_rate, info_list)

    def _on_success(self):
        """Record one successful attempt."""
        with self._metrics_lock:
            self._trials += 1
            self._success += 1

    def _on_fail(self):
        """Record one failed attempt."""
        with self._metrics_lock:
            self._trials += 1

    def success_rate(self):
        """Return successes / attempts so far, or 0.0 before any attempt."""
        with self._metrics_lock:
            if self._trials < 1:
                return 0.0
            return self._success / self._trials

    def proxy_as_request_session(self) -> Session:
        """Return a requests-compatible Session that routes requests through this manager."""
        return SessionProxy(self)

    def proxy_session(self,
                      fn: Callable[[Session, tuple], Optional[Any]],
                      callback_async: Callable[[
                          Optional[Any]], None] = lambda x: None,
                      exception_handle: Callable[[
                          ProxyInfo, Exception], None] = lambda info, ex: print(info, ex),
                      args: tuple = (),
                      wait=False):
        """
        Run the given function with a Session bound to a proxy polled from the pool. 'fn' is retried
        indefinitely, with a freshly polled proxy each time, until an attempt succeeds.

        :param fn: function to execute, called as ``fn(session, *args)``
        :param callback_async: invoked with the return value of 'fn' once it succeeds. A tuple result
         is unpacked into separate arguments. If the callback raises, the attempt counts as failed and
         'fn' is retried.
        :param exception_handle: called with ``(proxy_info, exception)`` for each failed attempt
        :param args: extra positional arguments passed to 'fn' after the session
        :param wait: block the calling thread until 'fn' has finished
        :return: if 'wait' is True, the return value of 'fn'; None otherwise
        """

        def task_fn():
            success = False

            while not success:
                with self._pool.poll() as proxy:
                    session = requests.Session()

                    session.proxies.update(proxy.info().as_requests_dict())
                    session.headers.update({'User-Agent': random_agent()})

                    timer = SimpleTimer()
                    try:
                        with timer:
                            result = fn(session, *args)

                        # Exact tuples (not subclasses) are spread into the callback.
                        if type(result) is tuple:
                            callback_async(*result)
                        else:
                            callback_async(result)

                        success = True
                        return result
                    except Exception as ex:
                        exception_handle(proxy.info(), ex)
                    finally:
                        # Runs for both outcomes, including the successful return above.
                        proxy.update_response_time(timer.time_elapsed())

                        if success:
                            self._on_success()
                        else:
                            self._on_fail()

                        if self._verbose:
                            print("success rate:", self.success_rate())

        task = Task(task_fn)
        self._executor_queue.put(task)

        if wait:
            return task.get()

    def start(self):
        """Start the proxy pool and launch the worker threads that drain the task queue."""
        self._pool.start()

        def worker():
            while self._executor_running:
                try:
                    # Block briefly (block=True is required for the timeout to apply)
                    # so idle workers neither spin nor delay new tasks, yet still
                    # notice stop() within about a second.
                    task = self._executor_queue.get(timeout=1)
                except Empty:
                    continue

                try:
                    task()
                except Exception as ex:
                    # Keep the worker alive; an escaping exception would kill it silently.
                    print(ex)
                finally:
                    # Always acknowledge, otherwise stop()'s queue.join() would hang.
                    self._executor_queue.task_done()

        self._executor_running = True
        for _ in range(self._executor_workers):
            self._executor.submit(worker)

    def stop(self):
        """Wait for every queued task to finish, then stop the workers and the pool."""
        self._executor_queue.join()

        self._executor_running = False
        self._executor.shutdown(True)

        self._pool.end()

    def __enter__(self):
        self.start()

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()


class SessionProxy(Session):
    """A requests Session whose ``request`` is executed through a ProxyManager.

    Each request is submitted via ``manager.proxy_session(..., wait=True)``.
    The caller blocks until a proxied attempt succeeds, and that attempt's
    response is returned. The ``proxies`` argument is ignored because the
    manager supplies the proxy.
    """

    def __init__(self, manager: ProxyManager):
        super().__init__()
        assert manager._executor_running, "manager hasn't started."

        self._manager = manager

    def request(self, method, url,
                params=None, data=None, headers=None, cookies=None, files=None,
                auth=None, timeout=None, allow_redirects=True, proxies=None,
                hooks=None, stream=True, verify=None, cert=None, json=None):
        """Send the request through the manager's proxy pool and return the response."""

        def _send_via_proxy(proxied: Session):
            # Extra headers are also set on the proxied session itself.
            if headers is not None:
                proxied.headers.update(headers)

            forwarded = dict(
                params=params,
                data=data,
                headers=headers,
                cookies=cookies,
                files=files,
                auth=auth,
                timeout=timeout,
                allow_redirects=allow_redirects,
                proxies=None,  # proxying is configured on `proxied` by the manager
                hooks=hooks,
                stream=stream,
                verify=verify,
                cert=cert,
                json=json,
            )
            return proxied.request(method, url, **forwarded)

        return self._manager.proxy_session(_send_via_proxy, wait=True)
