#!/usr/bin/env python3
"""Apply fixes to an existing database

This script isn't for general users

"""


import argparse
import logging
import sqlite3
import os
import sys

import coloredlogs
import yaml


# Root logger (getLogger() called without a name); all messages of this
# script are emitted through it.
_logger = logging.getLogger()

# Default location of the fixes file: "fixes.yaml" in the directory part of
# the path the script was invoked with.
# os.path.join is used instead of string concatenation because dirname()
# returns "" when the script is started as "python3 script.py"; plain
# concatenation would then produce "/fixes.yaml" (the filesystem root)
# instead of the relative path "fixes.yaml".
fixes_dir = os.path.dirname(sys.argv[0])
fixes_fn = os.path.join(fixes_dir, "fixes.yaml")


def parse_args(argv):
    """Parse the command-line arguments of the fix runner.

    Args:
        argv: List of argument strings without the program name
            (the script passes ``sys.argv[1:]``).

    Returns:
        An ``argparse.Namespace`` with the attributes ``db_filename``,
        ``fix``, ``fixes_filename``, ``start_step`` (int) and
        ``vacuum`` (bool).

    argparse itself prints a usage message and exits when a required
    option is missing or ``--start-step`` is not an integer.
    """
    ap = argparse.ArgumentParser(
        description=__doc__,
    )
    ap.add_argument("--db-filename", "-d",
                    required=True,
                    help="SQLite database file to apply the fix to")
    ap.add_argument("--fix", "-f",
                    required=True,
                    help="key of the fix to apply, as named in the fixes file")
    ap.add_argument("--fixes-filename", "-F",
                    default=fixes_fn,
                    help="YAML file containing the fix definitions")
    # type=int is required: without it a value given on the command line
    # stays a str and cannot be compared with the 1-based step counter.
    ap.add_argument("--start-step", "-s",
                    type=int,
                    default=1,
                    help="1-based number of the first command to execute; "
                         "earlier commands are skipped")
    ap.add_argument("--vacuum", "-V",
                    default=False,
                    action="store_true",
                    help="vacuum the database after applying the fix")
    return ap.parse_args(argv)

    

def _apply_commands(db, fix_key, commands, start_step):
    """Execute the SQL commands of one fix in order, committing after each.

    Steps are numbered from 1. Steps numbered below ``start_step`` are
    logged as skipped and not executed.
    """
    for step, command in enumerate(commands, start=1):
        if step < start_step:
            _logger.info(f"{fix_key}#{step}: skipping")
            continue
        _logger.info(f"{fix_key}#{step}: {command}")
        db.execute(command)
        db.commit()


def _run_tests(db, tests):
    """Run the test queries of one fix and return how many passed.

    Each test is a mapping with the keys ``sql`` and ``expected``. The
    first column of the first result row is compared with ``expected``
    after converting both to ``str``. Failures are logged as errors.
    """
    passed = 0
    for test in tests:
        row = db.execute(test["sql"]).fetchone()
        if row is None:
            # A query without result rows cannot match anything; report it
            # as a failed test instead of crashing on row[0].
            _logger.error(f"Failed {test['sql']}; expected {test['expected']}; got no rows")
            continue
        result = row[0]
        if str(result) == str(test["expected"]):
            passed += 1
        else:
            _logger.error(f"Failed {test['sql']}; expected {test['expected']}; got {result}")
    return passed


def main(argv):
    """Apply one fix from the fixes file to the database.

    Args:
        argv: Command-line arguments without the program name.

    Returns:
        Process exit status: 0 when the fix was applied, 1 when the
        requested fix is not defined in the fixes file. Failed tests are
        only logged and do not change the exit status.
    """
    coloredlogs.install(level="INFO")

    opts = parse_args(argv)

    # Close the fixes file as soon as it is parsed. An empty YAML document
    # loads as None, which is treated as "no fixes defined".
    with open(opts.fixes_filename) as fixes_file:
        fixes = yaml.load(fixes_file, Loader=yaml.SafeLoader) or {}
    _logger.info(f"Read {len(fixes)} fixes from {opts.fixes_filename}")

    # Validate the key before connecting: sqlite3.connect() creates the
    # database file if it does not exist, which is not wanted for a typo.
    fix_key = opts.fix
    if fix_key not in fixes:
        known = ", ".join(str(key) for key in fixes)
        _logger.error(f"Unknown fix {fix_key}; available: {known}")
        return 1
    fix = fixes[fix_key]
    commands = fix["sql"]

    # Coerce here as well so that a start step given as a string compares
    # correctly with the integer step counter.
    start_step = int(opts.start_step)

    db = sqlite3.connect(opts.db_filename)
    try:
        _logger.info(f"Opened {opts.db_filename}")

        _logger.info(f"Starting fix {fix_key}: {fix.get('descr', '')} ({len(commands)} commands)")
        _apply_commands(db, fix_key, commands, start_step)

        if "tests" not in fix:
            _logger.warning("No tests defined")
        else:
            tests = fix["tests"]
            test_ok = _run_tests(db, tests)
            _logger.info(f"Passed {test_ok}/{len(tests)} tests")

        if opts.vacuum:
            _logger.info("Vacuuming")
            db.execute("vacuum")
    finally:
        # Release the database file even when a command or test raises.
        db.close()
    return 0


if __name__ == "__main__":
    sys.exit(main(sys.argv[1:]))
