from copy import deepcopy
from multiprocessing import Manager
from threading import Thread
from typing import Sequence, Iterable, Any

from coba.exceptions import CobaFatal
from coba.config     import CobaConfig, BasicLogger, IndentLogger
from coba.pipes      import Filter, Sink, Pipe, QueueIO, MultiprocessFilter

class CobaMultiprocessFilter(Filter[Iterable[Any], Iterable[Any]]):
    """Run a sequence of filters over items via MultiprocessFilter using Coba's config.

    The filters are wrapped in a ConfiguredFilter that carries copies of the
    current ``CobaConfig.logger`` and ``CobaConfig.cacher`` and installs them
    before running. Anything written to the shared stderr queue is relayed to
    the parent's ``CobaConfig.logger`` by a background thread.
    """

    class ConfiguredFilter:
        """Wrap filters so CobaConfig's logger/cacher are installed before they run."""

        def __init__(self, filters: Sequence[Filter], logger_sink: Sink, with_name: bool) -> None:
            """Snapshot the current logger and cacher and redirect the logger's output.

            Args:
                filters: The filters joined into a single pipe when ``filter`` is called.
                logger_sink: Assigned as the copied logger's ``_sink`` (CobaMultiprocessFilter
                    passes its queue-backed stderr here).
                with_name: Assigned as the copied logger's ``_with_name``.
            """

            # Deep copies keep this wrapper's logger/cacher independent of the
            # objects currently held by CobaConfig.
            self._logger = deepcopy(CobaConfig.logger)
            self._cacher = deepcopy(CobaConfig.cacher)

            # Both logger types are redirected through the same private attributes.
            if isinstance(self._logger, (IndentLogger, BasicLogger)):
                self._logger._with_name = with_name
                self._logger._sink      = logger_sink

            self._filters = filters

        def filter(self, item: Iterable[Any]) -> Iterable[Any]:
            """Install the copied logger/cacher into CobaConfig, then apply the joined filters to ``item``."""

            #placing this here means this is only set inside the process 
            CobaConfig.logger = self._logger
            CobaConfig.cacher = self._cacher

            return Pipe.join(self._filters).filter(item)

    def __init__(self, filters: Sequence[Filter], processes: int = 1, maxtasksperchild=None) -> None:
        """
        Args:
            filters: The filters applied to the items.
            processes: Number of processes passed to MultiprocessFilter; when greater
                than 1 the copied logger is configured with ``_with_name = True``.
            maxtasksperchild: Passed to MultiprocessFilter; a value of -1 is
                normalized to None.
        """
        self._filters          = filters
        self._processes        = processes
        self._maxtasksperchild = None if maxtasksperchild == -1 else maxtasksperchild

    def filter(self, items: Iterable[Any]) -> Iterable[Any]:
        """Lazily yield the results of running the filters over ``items`` in worker processes.

        Raises:
            CobaFatal: When a RuntimeError escapes while the manager and workers run;
                the original RuntimeError is attached as ``__cause__``.
        """

        try:

            with Manager() as manager:

                stderr = QueueIO(manager.Queue())

                def log_stderr():
                    # Relay everything written to the shared stderr queue to the
                    # parent's logger. A str goes straight to the sink, an
                    # Exception is logged; anything else is treated as a sequence
                    # whose index 2 is logged as an exception and whose index 3
                    # (an iterable of strings, presumably traceback lines) is printed.
                    for err in stderr.read():
                        if isinstance(err, str):
                            CobaConfig.logger.sink.write(err)
                        elif isinstance(err, Exception):
                            CobaConfig.logger.log_exception(err)
                        else:
                            CobaConfig.logger.log_exception(err[2])
                            print("".join(err[3]))

                # Daemon thread so a still-running relay never blocks interpreter exit.
                log_thread = Thread(target=log_stderr, daemon=True)
                log_thread.start()

                # Named to avoid shadowing the builtin ``filter``.
                configured_filter = CobaMultiprocessFilter.ConfiguredFilter(self._filters, stderr, self._processes > 1)
                multiprocess_filter = MultiprocessFilter([configured_filter], self._processes, self._maxtasksperchild, stderr)

                # ``yield from`` forwards close()/throw() to the inner iterator, so it
                # is shut down before the Manager context exits.
                yield from multiprocess_filter.filter(items)

        except RuntimeError as e:
            #This happens when importing main causes this code to run again
            raise CobaFatal(str(e)) from e