"""The coba.pipes.filters module contains core classes for filters used in data pipelines."""

import re
import csv
import itertools
import json
import math
import collections.abc

from itertools import islice
from collections import OrderedDict
from typing import Iterable, Any, Sequence, Union, Tuple, List, Dict, Callable, Optional

from coba.random import CobaRandom
from coba.encodings import Encoder, OneHotEncoder, NumericEncoder, StringEncoder, CobaJsonEncoder, CobaJsonDecoder
from coba.exceptions import CobaException

from coba.pipes.primitives import Filter

# Shapes of the rows and data sets passed between the filters in this module.
# A dense row is a positional sequence of values; a sparse row is a dict keyed by column.
# A data set is an iterable of rows of one of these kinds.
_T_DenseRow  = Sequence[Any]
_T_SparseRow = Dict[Any, Any]

_T_DenseData  = Iterable[_T_DenseRow]
_T_SparseData = Iterable[_T_SparseRow]

_T_Row  = Union[_T_DenseRow, _T_SparseRow]
_T_Data = Union[_T_DenseData, _T_SparseData]

class Identity(Filter[Any, Any]):
    """A no-op filter: whatever goes in comes back out untouched."""

    def filter(self, item:Any) -> Any:
        """Hand back `item` itself (the very same object, not a copy)."""
        return item

class Shuffle(Filter[Iterable[Any], Iterable[Any]]):
    """Randomly reorder the items of an iterable."""

    def __init__(self, seed:Optional[int]) -> None:
        """Instantiate a Shuffle filter.

        Args:
            seed: The seed handed to CobaRandom. It must be None or an integer >= 0.

        Raises:
            ValueError: When seed is neither None nor a non-negative integer.
        """
        seed_is_valid = seed is None or (isinstance(seed, int) and seed >= 0)

        if not seed_is_valid:
            raise ValueError(f"Invalid parameter for Shuffle: {seed}. An optional integer value >= 0 was expected.")

        self._seed = seed

    def filter(self, items: Iterable[Any]) -> Iterable[Any]:
        """Materialize `items` into a list and return it shuffled.

        A new CobaRandom is created from the stored seed on every call.
        """
        rng = CobaRandom(self._seed)
        return rng.shuffle(list(items))

class Take(Filter[Iterable[Any], Iterable[Any]]):
    """Take a given number of items from an iterable."""

    def __init__(self, count:Optional[int]) -> None:
        """Instantiate a Take filter.

        Args:
            count: The number of items we wish to take from the given iterable. When None every
                item is taken.

        Raises:
            ValueError: When count is neither None nor a non-negative integer.
        """

        if count is not None and (not isinstance(count,int) or count < 0):
            raise ValueError(f"Invalid parameter for count: {count}. An optional integer value >= 0 was expected.")

        self._count = count

    def filter(self, items: Iterable[Any]) -> Iterable[Any]:
        """Return the first `count` items as a list.

        If fewer than `count` items are available, an empty list is returned. When count is None,
        all items are returned.
        """
        taken = list(islice(items, self._count))

        # With a None count islice takes everything, so there is no target length to compare to
        # (comparing len(taken) == None would always be False and wrongly return []).
        if self._count is None or len(taken) == self._count:
            return taken

        return []

class Reservoir(Filter[Iterable[Any], Iterable[Any]]):
    """Take a random sample of a given size from an iterable in a single pass."""

    def __init__(self, count:Optional[int], seed: int = 1, keep_first:bool = False) -> None:
        """Instantiate a Reservoir filter.

        Args:
            count     : The number of items we wish to take from the given iterable. When None every
                        item is kept.
            seed      : The random seed that determines which items are sampled. When None no
                        sampling happens and the first `count` items are kept in their original order.
            keep_first: Indicate whether the first row should be kept as is (useful for files with headers).

        Raises:
            ValueError: When count is neither None nor a non-negative integer.

        Remarks:
            We use Algorithm L as described by Kim-Hung Li. (1994) to take a random count of items.

        References:
            Kim-Hung Li. 1994. Reservoir-sampling algorithms of time complexity O(n(1 + log(N/n))). 
            ACM Trans. Math. Softw. 20, 4 (Dec. 1994), 481–493. DOI:https://doi.org/10.1145/198429.198435
        """

        if count is not None and (not isinstance(count,int) or count < 0):
            raise ValueError(f"Invalid parameter for Reservoir: {count}. An optional integer value >= 0 was expected.")

        self._count      = count
        self._seed       = seed
        self._keep_first = keep_first

    def filter(self, items: Iterable[Any]) -> Iterable[Any]:
        """Return the kept first item (if any) followed by the sampled items.

        When fewer than `count` items follow the kept first item, only that first item (if any)
        is returned.
        """

        items     = iter(items)
        # islice rather than next() so empty input yields no first row instead of leaking StopIteration.
        first     = list(islice(items, 1)) if self._keep_first else []
        reservoir = list(islice(items, self._count))

        this_count = len(reservoir) if self._count is None else self._count

        if this_count == 0:
            # Nothing to sample (and Algorithm L below would divide by zero). A kept first row still
            # passes through, matching the "too few items" case at the end of this method.
            return first

        if self._seed is not None:
            rng = CobaRandom(self._seed)
            W = 1

            try:
                while True:
                    [r1,r2,r3] = rng.randoms(3)
                    W = W * math.exp(math.log(r1)/this_count)
                    # S is the number of upcoming items to skip before the next replacement.
                    S = math.floor(math.log(r2)/math.log(1-W))
                    # NOTE(review): the -.001 presumably guards against r3*this_count reaching this_count;
                    # int() truncates toward zero, so tiny negative values map to index 0.
                    reservoir[int(r3*this_count-.001)] = next(itertools.islice(items,S,S+1))
            except StopIteration:
                pass # the input is exhausted, so sampling is complete

        return itertools.chain(first, reservoir if len(reservoir) == this_count else [])

class JsonEncode(Filter[Any, str]):
    """Serialize an item to a JSON string, optionally shrinking it to save space."""

    def _min(self,obj):
        """Return a minified copy of a list, tuple, or dict; any other object is returned as is.

        A new list/dict is always built rather than modifying `obj` in place, so encoding an item
        never alters the caller's data. Tuples become lists, which is how JSON encodes them anyway.
        """
        if isinstance(obj, (list, tuple)):
            return [ self._min_value(v) for v in obj ]
        if isinstance(obj, dict):
            return { k:self._min_value(v) for k,v in obj.items() }
        return obj

    def _min_value(self, v):
        """Return the minified form of a single value, which may be a primitive or a container."""

        #JsonEncoder writes floats with .0 regardless of if they are integers so we convert them to int to save space
        #JsonEncoder also writes floats at full repr precision so we truncate them to 5 significant digits here

        if isinstance(v, (int,str)):
            return v

        if isinstance(v, float):
            if v.is_integer():
                return int(v)
            if math.isnan(v) or math.isinf(v):
                return v
            #rounding by any means is considerably slower than this crazy method
            #we format as a truncated string and then manually remove the string
            #indicators from the json via string replace methods in filter
            return f"|{v:0.5g}|"

        return self._min(v)

    def __init__(self, minify=True) -> None:
        """Instantiate a JsonEncode filter.

        Args:
            minify: When True, use compact separators and shrink numbers (see `_min_value`).
        """
        self._minify = minify

        if self._minify:
            self._encoder = CobaJsonEncoder(separators=(',', ':'))
        else:
            self._encoder = CobaJsonEncoder()

    def filter(self, item: Any) -> str:
        """Return `item` encoded as JSON. The item itself is never modified.

        NOTE: the float markers are removed by deleting every '"|' and '|"' in the encoded text. A str
        value or dict key that begins or ends with '|' is indistinguishable from a marker and will lose
        its quote.
        """
        to_encode = self._min_value(item) if self._minify else item
        return self._encoder.encode(to_encode).replace('"|',"").replace('|"',"")

class JsonDecode(Filter[str, Any]):
    """Deserialize JSON text back into Python objects."""

    def __init__(self, decoder: json.decoder.JSONDecoder = CobaJsonDecoder()) -> None:
        """Instantiate a JsonDecode filter.

        Args:
            decoder: The object whose `decode` method is applied to every item. The default instance
                is created once, when this class is defined, and shared by every JsonDecode built
                without an explicit decoder.
        """
        self._decoder = decoder

    def filter(self, item: str) -> Any:
        """Decode one JSON string with the configured decoder."""
        decoded = self._decoder.decode(item)
        return decoded

class ArffReader(Filter[Iterable[str], _T_Data]):
    """Parse the lines of an ARFF file into a header row followed by (optionally encoded) data rows.

    Format reference:
        https://waikato.github.io/weka-wiki/formats_and_processing/arff_stable/
    """

    def __init__(self, skip_encoding: Union[bool,Sequence[Union[str,int]]] = False, **dialect):
        """Instantiate an ArffReader.

        Args:
            skip_encoding: True to return the data without any encoding. A sequence of attribute names
                and/or attribute indexes gives those columns a StringEncoder. False encodes every column
                according to its declared ARFF type.
            dialect: Keyword arguments passed through to the CsvReader that parses the data section.
        """

        self._skip_encoding = skip_encoding

        # Match a comment
        self._r_comment = re.compile(r'^%')

        # Match an empty line (not used while parsing; kept so the attribute still exists)
        self.r_empty = re.compile(r'^\s+$')

        #@ lines give metadata describing the file. These always come at the top of the file
        self._r_meta = re.compile(r'^\s*@\S*')

        #The @relation line simply names the data. In practice we don't really care about it.
        self._r_relation = re.compile(r'^@[Rr][Ee][Ll][Aa][Tt][Ii][Oo][Nn]\s*(\S*)')

        #The @attribute lines contain typing information for 'columns'
        self._r_attribute = re.compile(r'^\s*@[Aa][Tt][Tt][Rr][Ii][Bb][Uu][Tt][Ee]\s*(..*$)')

        #The @data line indicates when the data begins. After @data there should be no more @ lines.
        self._r_data = re.compile(r'^@[Dd][Aa][Tt][Aa]')

        self._dialect = dialect

    def _determine_encoder(self, index:int, name: str, tipe: str) -> Encoder:
        """Choose the encoder for one attribute from its position, name, and declared ARFF type."""

        if self._skip_encoding != False and (self._skip_encoding == True or index in self._skip_encoding or name in self._skip_encoding):
            return StringEncoder()

        # ARFF type keywords are case-insensitive, so 'REAL' or 'Numeric' must be numeric too.
        if tipe.strip().lower() in ['numeric', 'integer', 'real']:
            return NumericEncoder()

        if '{' in tipe:
            return OneHotEncoder([v.strip() for v in tipe.strip("}{").split(',')])

        return StringEncoder()

    def _split_attribute(self, attribute_text: str) -> Tuple[str, str]:
        """Split the text following '@attribute' into its (unquoted) name and its declared type.

        Names may be quoted (and then contain whitespace). Name and type are separated by any run
        of whitespace. The quotes are removed because CsvReader strips them from the header row, and
        the encoder keys must match those header names.

        Raises:
            CobaException: When no name/type pair can be found.
        """
        quote = attribute_text[:1]

        if quote in ("'", '"'):
            close = attribute_text.find(quote, 1)
            if close != -1:
                tipe = attribute_text[close+1:].strip()
                if tipe: return attribute_text[1:close], tipe
        else:
            parts = re.split(r'\s+', attribute_text, maxsplit=1)
            if len(parts) == 2: return parts[0], parts[1]

        raise CobaException(f"Unable to parse the ARFF attribute definition: {attribute_text}")

    def _parse_file(self, lines: Iterable[str]) -> Tuple[_T_Data, Dict[str,Encoder]]:
        """Read the attribute definitions and the data section; return (CsvReader output, encoders)."""
        in_meta_section=True
        in_data_section=False

        headers   : List[str        ] = []
        encoders  : Dict[str,Encoder] = {}
        data_lines: List[str        ] = []

        for line in lines:

            line = line.strip()

            # ARFF comments may appear anywhere, so they are skipped in the data section as well.
            if self._r_comment.match(line): continue

            if in_meta_section:

                if self._r_relation.match(line): continue

                attribute_match = self._r_attribute.match(line)

                if attribute_match:
                    attribute_name, attribute_type = self._split_attribute(attribute_match.group(1).strip())
                    attribute_index = len(headers)

                    headers.append(attribute_name)
                    encoders[attribute_name] = self._determine_encoder(attribute_index,attribute_name,attribute_type)

                if self._r_data.match(line):
                    in_data_section = True
                    in_meta_section = False
                    continue

            if in_data_section and line != '':
                data_lines.append(line)

        parsed_data = CsvReader(True, **self._dialect).filter(itertools.chain([",".join(headers)], data_lines))

        return parsed_data, encoders

    def filter(self, source: Iterable[str]) -> _T_Data:
        """Return the header row followed by the data rows, encoded unless skip_encoding is True."""

        data, encoders = self._parse_file(source)

        return data if self._skip_encoding == True else Encode(encoders).filter(data)

class CsvReader(Filter[Iterable[str], _T_Data]):
    """Parse delimited text into a header row followed by dense (list) or sparse (dict) data rows."""

    def __init__(self, has_header: bool, **dialect):
        """Instantiate a CsvReader.

        Args:
            has_header: Whether the first non-blank line holds column names.
            dialect: Keyword arguments passed through to `csv.reader`.
        """
        self._has_header = has_header
        self._dialect = dialect

    def filter(self, items: Iterable[str]) -> _T_Data:
        """Return the header (a list of names, or None) followed by the parsed rows.

        The format is detected from the first data row. If its first field starts with '{' and its
        last field ends with '}', every row is parsed as sparse. Otherwise rows are the lists that
        `csv.reader` produces.
        """

        # csv.reader yields [] for blank lines; filter(None, ...) discards them.
        lines = iter(filter(None, csv.reader( (i.strip() for i in items), **self._dialect)))

        # Bind header before reading so empty input cannot leave it unassigned in the except below.
        header = None

        try:
            if self._has_header:
                header = [ h.strip().strip('\'"') for h in next(lines)]
            first_data = next(lines)
        except StopIteration:
            return [header] # every other filter method assumes there is some kind of a header row (possibly None).

        is_sparse = first_data[0].startswith("{") and first_data[-1].endswith("}")
        parser    = self._sparse_parser if is_sparse else self._dense_parser

        return parser(header, itertools.chain([first_data], lines))

    def _dense_parser(self, header: Optional[Sequence[str]],  lines: Iterable[Sequence[str]]) -> Iterable[Sequence[str]]:
        """Yield the header followed by the csv rows unchanged."""
        return itertools.chain([header], lines)

    def _sparse_parser(self, header: Optional[Sequence[str]],  lines: Iterable[Sequence[str]]) -> Iterable[Dict[int,str]]:
        """Yield the header as {name: position} (or None), then each row as {int index: str value}.

        Each csv field of a sparse row is '<index> <value>', with '{' on the first field and '}' on the
        last. Rows with no entries (e.g. '{}' or '{ }') are skipped and not yielded.
        """
        yield header if header is None else OrderedDict(zip(header, itertools.count()))

        for line in lines:
            entries = [ field.strip("}{").strip() for field in line ]
            entries = [ entry for entry in entries if entry ]

            if entries:
                yield OrderedDict((int(k),v) for k,v in (entry.split(' ', 1) for entry in entries))

class LibSvmReader(Filter[Iterable[str], _T_SparseData]):
    """Read sparse data in the libsvm text format.

    Each non-blank line is '<labels> <index>:<value> <index>:<value> ...' where <labels> is a
    comma-separated list. Rows are yielded as dicts mapping the int index to its float value, with
    the list of label strings stored under key 0.

    Format references:
        https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/
        https://github.com/cjlin1/libsvm
    """

    def filter(self, lines: Iterable[str]) -> _T_SparseData:
        """Yield None (there is no header row) followed by one dict per non-blank line."""

        yield None # we yield None because there is no header row

        for line in filter(None,lines):

            # Any run of whitespace separates tokens, so tabs and repeated spaces are tolerated.
            tokens = line.split()

            # Whitespace-only lines (e.g. a bare "\n") hold no data point.
            if not tokens: continue

            labels = tokens[0].split(',')
            row    = { int(k):float(v) for token in tokens[1:] for k,v in [token.split(":")] }
            row[0] = labels

            yield row

class ManikReader(Filter[Iterable[str], _T_SparseData]):
    """Read sparse data from the Extreme Classification (Manik Varma) repository format.

    The files are libsvm-style lines preceded by a single metadata line.

    Format references:
        http://manikvarma.org/downloads/XC/XMLRepository.html
        https://drive.google.com/file/d/1u7YibXAC_Wz1RDehN1KjB5vu21zUnapV/view
    """

    def filter(self, lines: Iterable[str]) -> _T_SparseData:
        """Drop the leading metadata line and parse the remaining lines as libsvm data."""
        data_lines = islice(lines, 1, None)
        reader     = LibSvmReader()
        return reader.filter(data_lines)

class Encode(Filter[_T_Data, _T_Data]):
    """Encode the values of each row, in place, with the encoder registered for its column."""

    def __init__(self, encoders: Dict[str,Encoder], fit_using=None, has_header:bool = True):
        """Instantiate an Encode filter.

        Args:
            encoders  : Encoders keyed by column name, or directly by row key/index when there is
                        no header (or the header is None).
            fit_using : How many leading rows to use when fitting encoders that are not already fit
                        (None means every row).
            has_header: Whether the first item of the data is a header row (which may itself be None).
        """
        self._encoders   = encoders
        self._fit_using  = fit_using
        self._has_header = has_header

    def filter(self, items: _T_Data) -> _T_Data:
        """Yield the header (if any) and then every row with its values encoded.

        Rows are modified in place. Fitting never alters this filter's own `encoders` mapping, so the
        filter can be reused on other data.
        """

        items = iter(items) # one iterator, so rows pulled out for fitting are not read twice

        if not self._has_header:
            # Copy so fitting below never replaces the encoders this filter was constructed with.
            header_synced_encoders = dict(self._encoders)

        else:
            try:
                header = next(items)
            except StopIteration:
                return # empty input (a bare next() here would become a RuntimeError in this generator)

            if header is None:
                yield header
                header_synced_encoders = dict(self._encoders)

            elif isinstance(header, dict):
                # a dict header maps each column name to the key its values have in the rows
                yield header
                header_synced_encoders = { header[k]:encoder for k,encoder in self._encoders.items() if k in header } 

            elif isinstance(header, list):
                yield header
                header_synced_encoders = { header.index(k):encoder for k,encoder in self._encoders.items() if k in header }

            else:
                raise CobaException(f"Unrecognized type ({type(header).__name__}) passed to Encodes.")

        # Rows are only pulled aside for fitting when at least one encoder still needs it.
        fit_using  = 0 if all(e.is_fit for e in header_synced_encoders.values()) else self._fit_using
        fit_items  = list(islice(items, fit_using))
        fit_values = collections.defaultdict(list)

        for item in fit_items:
            for k,v in (item.items() if isinstance(item,dict) else enumerate(item)):
                fit_values[k].append(v)

        for k,values in fit_values.items():
            if not header_synced_encoders[k].is_fit:
                header_synced_encoders[k] = header_synced_encoders[k].fit(values)

        for item in itertools.chain(fit_items, items):

            for k,v in (item.items() if isinstance(item,dict) else enumerate(item)):
                item[k] = header_synced_encoders[k].encode([v])[0]

            yield item

class Drop(Filter[_T_Data, _T_Data]):
    """Remove columns and/or rows from data whose first item is a header row (or None)."""

    def __init__(self, drop_cols: Sequence[Any] = [], drop_row: Optional[Callable[[_T_Row], bool]] = None) -> None:
        """Instantiate a Drop filter.

        Args:
            drop_cols: Column names to remove from the header and every row (row keys/indexes when
                the header is None). This sequence is only read, never modified.
            drop_row: Optional predicate; rows for which it returns a truthy value are removed. It
                sees the row before any columns are dropped.
        """
        self._drop_cols = drop_cols
        self._drop_row  = drop_row

    def filter(self, data: _T_Data) -> _T_Data:
        """Return `data` itself when there is nothing to drop, otherwise a lazy iterator of the result."""

        # This method is deliberately not a generator: a `return data` inside a generator would hand
        # back an empty iterator instead of the data.
        if not self._drop_cols and not self._drop_row:
            return data

        return self._drop(data)

    def _drop(self, data: _T_Data) -> _T_Data:
        """Yield the reduced header and then every kept row with the dropped columns removed in place."""

        data = iter(data)

        try:
            header = next(data)
        except StopIteration:
            return # empty input (a bare next() here would become a RuntimeError in this generator)

        if header is None:
            drop_keys = sorted(self._drop_cols,reverse=True)
        elif isinstance(header,dict):
            drop_keys = sorted([ header.pop(k) for k in self._drop_cols ],reverse=True)
        elif isinstance(header,list):
            # descending order so each pop leaves the positions of the remaining drops unchanged
            drop_keys = sorted([ header.index(k) for k in self._drop_cols ],reverse=True)
            for i in drop_keys: header.pop(i)
        else:
            raise CobaException(f"Unrecognized type ({type(header).__name__}) passed to Drops.")

        yield header

        for row in data:

            if self._drop_row and self._drop_row(row):
                continue

            if isinstance(row, dict):
                # sparse rows omit entries, so a dropped column may legitimately be absent
                for k in drop_keys: row.pop(k, None)
            else:
                for k in drop_keys: row.pop(k)

            yield row

class Structure(Filter[_T_Data, Iterable[Any]]):
    """Rearrange the columns of each row into a (possibly nested) list.

    The first item of the data is treated as a header (or None) and is not yielded. `split_cols`
    names the columns to pull out of each row. Every None in it is replaced by the row itself after
    all named columns have been removed from it, i.e. "the remaining columns".
    """

    def __init__(self, split_cols: Sequence[Any]) -> None:
        """Instantiate a Structure filter.

        Args:
            split_cols: A column name (row key/index when the header is None), None, or an arbitrarily
                nested list/tuple of these describing the output structure.
        """
        self._col_structure = split_cols

    def filter(self, data: _T_Data) -> _T_Data:
        """Yield one structured value per row (the header itself is consumed, not yielded)."""

        data = iter(data)

        try:
            header = next(data)
        except StopIteration:
            return # empty input (a bare next() here would become a RuntimeError in this generator)

        key_structure = self._recursive_structure_keys(header, self._col_structure)

        for row in data:
            yield self._recursive_structure_rows(row, key_structure)

    def _recursive_structure_keys(self, header, cols):
        """Translate column names in `cols` into row keys/indexes using the header."""

        if header is None:
            return cols

        elif isinstance(cols,(list,tuple)):
            return [ self._recursive_structure_keys(header,s) for s in cols ]

        elif cols is None:
            return cols

        elif isinstance(header,list):
            return header.index(cols)

        elif isinstance(header,dict):
            return header[cols]

        else:
            raise CobaException(f"Unrecognized type ({type(header).__name__}) passed to Structure.")

    def _recursive_structure_rows(self, row, keys):
        """Fill `keys`' structure with the row's values, removing those values from `row` in place."""

        leaves = list(self._leaf_keys(keys))
        values = iter([ row[k] for k in leaves ])

        # Every value is read before anything is removed. Dense rows are then popped from the highest
        # index down, so one pop can never shift the position of a column still to be removed.
        if isinstance(row, list):
            size          = len(row)
            removal_order = sorted({ k+size if k < 0 else k for k in leaves }, reverse=True)
        else:
            removal_order = list(dict.fromkeys(leaves))

        for k in removal_order:
            row.pop(k)

        return self._fill_structure(row, keys, values)

    def _leaf_keys(self, keys):
        """Yield the non-None keys of a key structure in depth-first, left-to-right order."""
        if keys is None:
            return
        if isinstance(keys,(list,tuple)):
            for k in keys: yield from self._leaf_keys(k)
        else:
            yield keys

    def _fill_structure(self, row, keys, values):
        """Rebuild the key structure using `row` for None and the next of `values` for each key."""
        if keys is None:
            return row
        if isinstance(keys,(list,tuple)):
            return [ self._fill_structure(row,k,values) for k in keys ]
        return next(values)

class Flatten(Filter[_T_Data, _T_Data]):
    """Flatten one level of list/tuple values inside each row, modifying the rows in place.

    Every item (including any header row) must be a dict or a list; anything else, such as a None
    header, raises a CobaException. Flattened values are moved to the end of the row. In dict rows
    the element i of the value at key k is stored under key (k, i).
    """

    def filter(self, data: _T_Data) -> _T_Data:
        """Yield each row after flattening its list/tuple values."""

        for row in data:

            if isinstance(row,dict):
                # snapshot the keys because the dict changes size while flattening
                for k in list(row.keys()):
                    if isinstance(row[k],(list,tuple)):
                        row.update([((k,i), v) for i,v in enumerate(row.pop(k))])

            elif isinstance(row,list):
                # Rebuild the row rather than popping while indexing: a pop shifts the next element
                # into the current slot, and an index loop would then skip it.
                nested = [ v for v in row if isinstance(v,(list,tuple)) ]
                if nested:
                    row[:] = [ v for v in row if not isinstance(v,(list,tuple)) ] + [ x for v in nested for x in v ]

            else:
                raise CobaException(f"Unrecognized type ({type(row).__name__}) passed to Flattens.")

            yield row

class Default(Filter[_T_Data, _T_Data]):
    """Fill in values missing from sparse (dict) rows with per-column defaults."""

    def __init__(self, defaults: Dict[str, Any]) -> None:
        """Instantiate a Default filter.

        Args:
            defaults: Default values keyed by column name (by row key when the header is None).
                Every name must exist in the header.
        """
        self._defaults = defaults

    def filter(self, data: _T_Data) -> _T_Data:
        """Return `data` itself when there are no defaults, otherwise a lazy iterator of the result."""

        # This method is deliberately not a generator: a `return data` inside a generator would hand
        # back an empty iterator instead of the data.
        if not self._defaults:
            return data

        return self._fill(data)

    def _fill(self, data: _T_Data) -> _T_Data:
        """Yield the header and then every row; dict rows get each missing default added in place."""

        data = iter(data)

        try:
            header = next(data)
        except StopIteration:
            return # empty input (a bare next() here would become a RuntimeError in this generator)

        if header is not None and not isinstance(header, (dict, list)):
            raise CobaException(f"Unrecognized type ({type(header).__name__}) passed to Defaults.")

        yield header

        if header is None:
            header_synced_defaults = self._defaults
        elif isinstance(header, dict):
            header_synced_defaults = { header[k]:default for k,default in self._defaults.items() }
        else:
            header_synced_defaults = { header.index(k):default for k,default in self._defaults.items() }

        for row in data:

            # only sparse rows can be missing entries; dense (list) rows pass through unchanged
            if isinstance(row,dict):
                for k,v in header_synced_defaults.items():
                    if k not in row:
                        row[k] = v

            yield row
