from functools import reduce
from .model import *
from .select import *
from .query import *
from .util import key_to_index


class CRUDMixin(SelectMixin):
    """
    This class provides class methods available on all model types.

    Every method takes the DB connection object as its first argument.

    The following arguments are commonly used by several methods of this mixin class:

    - `pks` represents value(s) of primary key(s).
        - If `pks` is a dictionary, each item is considered to be a pair of column name and value respectively.
        - Otherwise, the model must have a single primary key and `pks` is its value.
    - `qualifier` is a dictionary whose key is a column name and the value is a function generating SQL expression around the placeholder for the column.
        - This argument is used to overwrite the expression generated by `Marker` in insert and update query.
        - The function must take an argument to which the default expression is passed.
        - Be aware that the actual value stored in database can be different from value set on model when this argument exists.
        - `None` (the default) is treated as an empty dictionary.
    """
    @classmethod
    def _pk_condition(cls, pks):
        """
        Build a condition matching every primary key column to its value given by `pks`.

        Parameters
        ----------
        pks: object | {str: object}
            Representation of primary key(s).

        Returns
        -------
        Conditional
            Conjunction of equality conditions, one per primary key column.
        """
        cols, vals = parse_pks(cls, pks)
        return Conditional.all([Q.eq(**{c: v}) for c, v in zip(cols, vals)])

    @classmethod
    def count(cls, db, condition=Q.of()):
        """
        Count rows which satisfy the condition.

        Parameters
        ----------
        db: Connection
            DB connection.
        condition: Conditional
            Condition.

        Returns
        -------
        int
            The number of rows.
        """
        wc, wp = where(condition)
        c = db.stmt().execute(f"SELECT COUNT(*) FROM {cls.name}{_spacer(wc)}", *wp)
        return c.fetchone()[0]

    @classmethod
    def fetch(cls, db, pks, lock=None):
        """
        Fetch a record by primary key(s).

        Parameters
        ----------
        db: Connection
            DB connection.
        pks: object | {str: object}
            Representation of primary key(s).
        lock: object
            An object whose string form is appended to the end of the query. Nothing is appended when it is falsy.

        Returns
        -------
        cls | None
            A model of the record, or `None` when no row matches.
        """
        wc, wp = where(cls._pk_condition(pks))
        s = cls.select()
        c = db.stmt().execute(f"SELECT {s} FROM {cls.name}{_spacer(wc)}{_spacer(lock)}", *wp)
        row = c.fetchone()
        return read_row(row, s)[0] if row else None

    @classmethod
    def fetch_where(cls, db, condition=Q.of(), orders=None, limit=None, offset=None, lock=None):
        """
        Fetch records which satisfy a condition.

        Parameters
        ----------
        db: Connection
            DB connection.
        condition: Conditional
            Condition.
        orders: {str: bool} | None
            Ordering parameters. Each key is column and its value denotes direction; `True` is ascending and `False` is descending.
            `None` (the default) is treated as an empty dictionary.
        limit: int
            The number of rows to fetch. If `None`, all rows are obtained.
        offset: int
            The number of rows to skip.
        lock: object
            An object whose string form is appended to the end of the query. Nothing is appended when it is falsy.

        Returns
        -------
        [cls]
            Models of records.
        """
        # A fresh dict per call avoids sharing one mutable default object across calls.
        if orders is None:
            orders = {}
        wc, wp = where(condition)
        rc, rp = ranged_by(limit, offset)
        s = cls.select()
        tail = f"{_spacer(wc)}{_spacer(order_by(orders))}{_spacer(rc)}{_spacer(lock)}"
        # Unpack separately so that differing sequence types (e.g. list and tuple) still combine.
        c = db.stmt().execute(f"SELECT {s} FROM {cls.name}{tail}", *wp, *rp)
        return [read_row(row, s)[0] for row in c.fetchall()]

    @classmethod
    def insert(cls, db, record, qualifier=None):
        """
        Insert a record.

        Values reported by `last_sequences` (auto-generated columns) are set on the returned model
        even if they were not set beforehand. Default values defined on columns are not set on the model.

        Parameters
        ----------
        db: Connection
            DB connection.
        record: cls | {str: object}
            Model object or dictionary of columns and values.
        qualifier: {str: str -> str} | None
            Functions converting place holders. `None` means no conversion.

        Returns
        -------
        Model
            Model object.
        """
        record = record if isinstance(record, cls) else cls(**record)
        value_dict = model_values(cls, record)
        check_columns(cls, value_dict)
        cols, vals = list(value_dict.keys()), list(value_dict.values())
        qualifier = key_to_index(qualifier if qualifier is not None else {}, cols)

        db.stmt().execute(f"INSERT INTO {cls.name} ({', '.join(cols)}) VALUES {values(len(cols), 1, qualifier)}", *vals)

        # Write generated values back onto the model; the base `last_sequences` reports none.
        for c, v in cls.last_sequences(db, 1):
            setattr(record, c.name, v)

        return record

    @classmethod
    def update(cls, db, pks, values, qualifier=None):
        """
        Update a record by primary key(s).

        Only columns that appear in `values` and are not primary keys are updated.

        Parameters
        ----------
        db: Connection
            DB connection.
        pks: object | {str: object}
            Representation of primary key(s).
        values: cls | {str: object}
            Model object or dictionary of columns and values.
        qualifier: {str: str -> str} | None
            Functions converting place holders. `None` means no conversion.

        Returns
        -------
        bool
            `True` only when exactly one row was reported as updated. When no row has the given primary key(s),
            or the cursor does not expose `rowcount`, `False` is returned.
        """
        qualifier = qualifier if qualifier is not None else {}
        return cls.update_where(db, values, cls._pk_condition(pks), qualifier) == 1

    @classmethod
    def update_where(cls, db, values, condition, qualifier=None, allow_all=True):
        """
        Update records which satisfy a condition.

        Parameters
        ----------
        db: Connection
            DB connection.
        values: cls | {str: object}
            Model object or dictionary of columns and values. Primary key columns are excluded.
            A value that is an `Expression` is embedded as SQL together with its own parameters.
        condition: Conditional
            Condition.
        qualifier: {str: str -> str} | None
            Functions converting place holders. `None` means no conversion.
        allow_all: bool
            Empty condition raises an exception if this is `False`.

        Returns
        -------
        int | None
            The number of affected rows, or `None` if the cursor has no `rowcount` attribute.

        Raises
        ------
        ValueError
            If the condition is empty and `allow_all` is `False`.
        """
        value_dict = model_values(cls, values, excludes_pk=True)
        check_columns(cls, value_dict)
        cols, vals = list(value_dict.keys()), list(value_dict.values())
        qualifier = key_to_index(qualifier if qualifier is not None else {}, cols)

        setters, params = [], []
        for i, (col, val) in enumerate(zip(cols, vals)):
            if isinstance(val, Expression):
                # Raw SQL expression: embed its text and bind its own parameters.
                expr, expr_params = val.expression, val.params
            else:
                # Plain value: bind it through a placeholder.
                expr, expr_params = "$_", [val]
            qualify = qualifier.get(i)
            setters.append(f"{col} = {qualify(expr) if qualify is not None else expr}")
            params.extend(expr_params)

        wc, wp = where(condition)
        # Use a falsy test (like `_spacer`) so that any empty clause is caught, not only "".
        if not wc and not allow_all:
            raise ValueError("Update query to update all records is not allowed.")

        c = db.stmt().execute(f"UPDATE {cls.name} SET {', '.join(setters)}{_spacer(wc)}", *params, *wp)

        return getattr(c, "rowcount", None)

    @classmethod
    def delete(cls, db, pks):
        """
        Delete a record by primary key(s).

        Parameters
        ----------
        db: Connection
            DB connection.
        pks: object | {str: object}
            Representation of primary key(s).

        Returns
        -------
        bool
            `True` only when exactly one row was reported as deleted. When no row has the given primary key(s),
            or the cursor does not expose `rowcount`, `False` is returned.
        """
        return cls.delete_where(db, cls._pk_condition(pks)) == 1

    @classmethod
    def delete_where(cls, db, condition, allow_all=True):
        """
        Delete records which fulfill a condition.

        Parameters
        ----------
        db: Connection
            DB connection.
        condition: Conditional
            Condition.
        allow_all: bool
            Empty condition raises an exception if this is `False`.

        Returns
        -------
        int | None
            The number of affected rows, or `None` if the cursor has no `rowcount` attribute.

        Raises
        ------
        ValueError
            If the condition is empty and `allow_all` is `False`.
        """
        wc, wp = where(condition)
        # Use a falsy test (like `_spacer`) so that any empty clause is caught, not only "".
        if not wc and not allow_all:
            raise ValueError("Delete query to delete all records is not allowed.")

        c = db.stmt().execute(f"DELETE FROM {cls.name}{_spacer(wc)}", *wp)

        return getattr(c, "rowcount", None)

    @classmethod
    def last_sequences(cls, db, num):
        """
        Returns latest auto generated numbers in this table.

        This base implementation reports nothing and returns an empty list.

        Parameters
        ----------
        db: Connection
            DB connection.
        num: int
            The number of records inserted by the latest insert query.

        Returns
        -------
        [(Column, int)]
            A list of pairs of column and the generated number.
        """
        return []


def _spacer(s):
    return (" " + str(s)) if s else ""