import argparse
import hashlib
import os
import os.path as osp
from typing import List, Tuple

from ..files.base import find_all_file
from ..tools.logger import Logger
from .tasks import multi_tasks


def md5sum(filename: str) -> str:
    """Return the hexadecimal MD5 digest of a file's contents.

    The file is read in binary mode in fixed-size chunks, so it is never
    loaded into memory as a whole.

    Adapted from https://www.aiuai.cn/aifarm1114.html

    Args:
        filename (str): Path of the file to hash.

    Returns:
        str: The MD5 digest as a hex string.
    """
    digest = hashlib.md5()
    # The context manager closes the handle even when a read fails part-way.
    with open(filename, "rb") as f:
        # iter(callable, sentinel) stops at the empty bytes object read() returns at EOF.
        for chunk in iter(lambda: f.read(8096), b""):
            digest.update(chunk)
    return digest.hexdigest()


def __md5sum_folder(folder: str, num_workers: int = 10) -> List[Tuple[str, str]]:
    """Compute the MD5 digest of every file found under a folder.

    Files are enumerated with ``find_all_file`` (the original documentation
    states that nested directories are supported) and hashed through
    ``multi_tasks`` with ``md5sum`` as the per-file task.

    Args:
        folder (str): Folder to scan.
        num_workers (int): Worker count forwarded to ``multi_tasks``.

    Returns:
        List[Tuple[str, str]]: ``[(md5, filename), ...]``
    """
    filename_list = list(find_all_file(folder))
    md5_list = multi_tasks(md5sum, filename_list, return_results=True, desc="md5sum", num_workers=num_workers)
    # NOTE(review): pairing by position assumes multi_tasks returns results in
    # input order - confirm against its implementation.
    # Materialise the pairs so the value matches the annotated list type and
    # can be iterated more than once (a bare zip is a one-shot iterator).
    return list(zip(md5_list, filename_list))


def save_md5sum(path: str, md5_filename: str = "md5.txt", num_workers: int = 10) -> None:
    """Compute MD5 digests for ``path`` and save them to ``md5_filename``.

    One ``"<md5> <filename>"`` line is written per hashed file. Filenames that
    start with ``path + "/"`` have that prefix removed before being written.

    Nothing is written when ``md5_filename`` already exists (a warning is
    logged), when ``path`` does not exist, or when ``path`` is neither a
    regular file nor a directory.

    Args:
        path (str): File or folder to hash.
        md5_filename (str): Output file, ``md5.txt`` by default.
        num_workers (int): Worker count used when hashing a folder.
    """
    if osp.exists(md5_filename):
        Logger.warn(f"[exists]: {md5_filename}")
        return
    # Stop before the output file is created; otherwise a missing path would
    # leave an empty md5 file behind and then fail on an unbound variable.
    # NOTE(review): Logger.exists presumably logs the missing path itself - confirm.
    if not Logger.exists(path):
        return
    if osp.isfile(path):
        md5_filename_list = [(md5sum(path), path)]
    elif osp.isdir(path):
        md5_filename_list = __md5sum_folder(path, num_workers=num_workers)
    else:
        # Exists but is neither a regular file nor a directory.
        Logger.error(f"[unsupported path]: {path}")
        return

    # Prefix to strip so that folder entries are stored relative to ``path``.
    remove_head = path if path.endswith("/") else path + "/"
    with open(md5_filename, "w", encoding="utf-8") as fw:
        for md5, filename in md5_filename_list:
            if filename.startswith(remove_head):
                filename = filename[len(remove_head):]
            fw.write(f"{md5} {filename}\n")


def check_md5sum(path: str, md5_filename: str = "md5.txt", num_workers: int = 10) -> None:
    """Verify files under ``path`` against the digests listed in ``md5_filename``.

    Each non-blank line of ``md5_filename`` is expected to hold an MD5 digest
    followed by whitespace and a filename relative to ``path``. An error is
    logged for every listed file whose current digest differs from the
    recorded one.

    Args:
        path (str): Folder that the recorded filenames are relative to.
        md5_filename (str): Digest list to read, ``md5.txt`` by default.
        num_workers (int): Worker count forwarded to ``multi_tasks``.
    """
    if not Logger.exists(md5_filename):
        return
    md5_list = []
    # Read as UTF-8 because save_md5sum writes the list with that encoding.
    with open(md5_filename, "r", encoding="utf-8") as fr:
        for line in fr:
            # Split only on the first whitespace run: this accepts one or two
            # separating spaces and keeps spaces inside the filename intact.
            fields = line.strip().split(None, 1)
            if not fields:
                continue  # blank line
            if len(fields) != 2:
                Logger.error(f"[invalid md5 line]: {line.strip()}")
                continue
            md5, filename = fields
            md5_list.append([md5, filename])

    def check(md5, filename):
        """Log an error when the file's current digest differs from ``md5``."""
        full_path = osp.join(path, filename)
        if Logger.exists(full_path):
            if md5sum(full_path) != md5:
                Logger.error(f"[文件被修改]:{filename}")

    # TODO: files added after the list was saved are not detected yet.
    # NOTE(review): each task argument is a [md5, filename] pair; presumably
    # multi_tasks unpacks it into check's two parameters - confirm.
    multi_tasks(check, md5_list, return_results=False, desc="check_md5sum", num_workers=num_workers)


def run_save_md5sum():
    """Command-line entry point: save MD5 digests for the current directory.

    Accepts ``--cpu`` (int, default 10), which is passed to ``save_md5sum``
    as ``num_workers``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--cpu", type=int, default=10, help="num_workers")
    worker_count = cli.parse_args().cpu

    save_md5sum(".", num_workers=worker_count)


def run_check_md5sum():
    """Command-line entry point: check MD5 digests for the current directory.

    Accepts ``--cpu`` (int, default 10), which is passed to ``check_md5sum``
    as ``num_workers``.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--cpu", type=int, default=10, help="num_workers")
    worker_count = cli.parse_args().cpu

    check_md5sum(".", num_workers=worker_count)
