import sqlite3
from typing import List, Optional

from .logger import Logger


class SqliteDatabase(object):
    """Thin convenience wrapper around a single ``sqlite3`` connection.

    Table and column names are interpolated into the SQL text (SQLite cannot
    bind identifiers as parameters) and are upper-cased on the way, so they
    must come from trusted code. Row values are always bound as ``?``
    parameters.

    Software to open: DB Browser for SQLite https://sqlitebrowser.org/
    """

    def __init__(self, db_filename: Optional[str]) -> None:
        """Create the wrapper and connect when a file name is given.

        db_filename: database file name (``*.db``). Pass ``None`` to create
            an unconnected instance and call ``connect`` later.
        """

        self.conn = None
        self.cursor = None
        self.db_filename = None
        if db_filename is not None:
            self.connect(db_filename)

    def connect(self, db_filename: str):
        """Close any open database, then connect to ``db_filename``."""

        self.close()

        self.conn = sqlite3.connect(db_filename)
        self.cursor = self.conn.cursor()
        self.db_filename = db_filename
        Logger.info(f"Connect database:{self.db_filename}")

    def close(self):
        """Close the database (if one is open) and reset the instance state."""

        if self.conn is not None:
            self.conn.close()
            Logger.info(f"Close database:{self.db_filename}")

        self.conn = None
        self.cursor = None
        self.db_filename = None

    def create_table(self, table_name: str):
        """Create a table whose only column is an auto-increment ``ID`` key."""
        table_name = table_name.upper()
        order = f"CREATE TABLE {table_name} (ID INTEGER PRIMARY KEY AUTOINCREMENT);"
        self.execute(order)
        Logger.warn(f"Create table: {table_name}")

    def create_index(self, table_name, index_name, column_name):
        """Create an index named ``index_name`` on ``table_name(column_name)``."""
        order = f"CREATE INDEX {index_name} on {table_name} ({column_name});"
        self.execute(order.upper())

    def add_column(self, table_name: str, column_name: str, column_type: str):
        """Add a column to an existing table.

        column_type: must be ``"TEXT"`` or ``"INT"``.

        Raises AssertionError when ``column_type`` is not one of the allowed
        values.
        """
        # Checked explicitly instead of with ``assert`` so the validation is
        # not stripped under ``python -O`` (the value goes into the SQL text).
        # AssertionError is kept so existing callers see the same exception.
        if column_type not in ("TEXT", "INT"):
            raise AssertionError(f"Unsupported column type: {column_type!r}")
        order = f"ALTER TABLE {table_name} ADD COLUMN {column_name} {column_type}"
        self.execute(order.upper())
        Logger.warn(f"添加字段: {column_name}")

    def insert_batch(self, table_name: str, column_name_list: List[str], value_list: List[List], commit=True):
        """Insert several rows in one ``executemany`` call.

        table_name: target table.
        column_name_list: names of the columns being filled.
        value_list: one inner sequence per row, ordered like
            ``column_name_list``.
        commit: commit the transaction after inserting.

        If ``value_list`` is empty, or its first row does not have one value
        per column, an error is logged and nothing is inserted.
        """
        table_name = table_name.upper()
        # Validate the data: only the first row's width is checked here.
        if len(value_list) == 0 or len(column_name_list) != len(value_list[0]):
            Logger.error("数据维度不一致，插入无效")
            return
        columns = ",".join(column_name_list).upper()
        placeholders = ",".join(["?"] * len(column_name_list))
        order = f"INSERT INTO {table_name} ({columns}) VALUES ({placeholders});"
        self.cursor.executemany(order, value_list)
        if commit:
            self.conn.commit()

    def insert(self, table_name: str, column_name_list: List[str], value_list: List, commit=True):
        """Insert a single row; ``value_list`` holds one value per column."""
        self.insert_batch(table_name, column_name_list, [value_list], commit=commit)

    def update(self, table_name: str, id_: int, column_name: str, value, commit=True):
        """Set ``column_name`` to ``value`` on the row whose ``ID`` is ``id_``.

        table_name: target table.
        id_: value of the ``ID`` column identifying the row.
        column_name: column to overwrite.
        value: new value; bound as a parameter, so it may be ``None`` or
            contain quote characters.
        commit: commit the transaction after updating.
        """
        table_name = table_name.upper()
        column_name = column_name.upper()
        # Bind the value and the id instead of formatting them into the SQL.
        # A double-quoted token is parsed by SQLite as an identifier first, so
        # a string value equal to a column name would copy that column rather
        # than store the text.
        order = f"UPDATE {table_name} SET {column_name} = ? WHERE ID = ?"
        self.execute(order, params=(value, id_), commit=commit)

    def search(self, table_name: str, column_name: str, search_text: str, return_column_list: Optional[List[str]] = None) -> List[List]:
        """Return the rows whose ``column_name`` contains ``search_text``.

        table_name: table to search.
        column_name: column to match against.
        search_text: text to look for; it is wrapped as ``*text*`` and matched
            with SQLite ``GLOB``, so GLOB wildcards inside it stay active.
        return_column_list: columns to return; ``None`` returns every column.

        Returns one list per matching row.
        """
        table_name = table_name.upper()
        column_name = column_name.upper()
        columns = "*" if return_column_list is None else ",".join(return_column_list)
        columns = columns.upper()

        order = f"SELECT {columns} from {table_name} where {column_name} glob ?"
        rows = self.execute(order, params=[f"*{search_text}*"])
        return [list(row) for row in rows]

    def max_id(self, table_name: str) -> Optional[int]:
        """Return the largest ``ID`` in ``table_name``.

        Returns ``None`` when the table has no rows, because ``MAX(ID)`` is
        NULL in that case.
        """
        table_name = table_name.upper()
        order = f"SELECT MAX(ID) FROM {table_name}"
        rows = self.execute(order)
        if not rows:
            # An aggregate query is expected to yield exactly one row, so an
            # empty result means something is badly wrong.
            Logger.error("There is a serious problem with the database!!!")
            return None
        return rows[-1][0]

    def execute(self, order: str, params=None, commit: bool = True):
        """Execute one SQL statement and return all of its result rows.

        order: SQL text, optionally containing ``?`` placeholders.
        params: values for the placeholders; omitted when falsy.
        commit: commit the transaction after the statement has run.

        Returns a list of row tuples (empty for statements without results).
        """
        if params:
            self.cursor.execute(order, params)
        else:
            self.cursor.execute(order)

        # Read every row before committing, so a row-returning statement is
        # never committed while it is still in progress.
        rows = self.cursor.fetchall()
        if commit:
            self.conn.commit()
        return rows

    def commit(self):
        """Commit the current transaction."""
        self.conn.commit()
