#!/home/cusano/Temp/python/lpips/venv/bin/python3

import photocmp
import PIL.Image
import numpy as np
import argparse
import sys
import os
import glob
import concurrent.futures


def parse_args():
    """Parse the command line and return the resulting namespace.

    The returned namespace has the attributes: reference, other (list),
    metric (list of metric names), sort, match_names, detail, workers,
    write and resize (a (width, height) tuple or None).

    """
    # Local import: textwrap is needed only to format the help text.
    import textwrap

    epilog = textwrap.dedent("""\
        Compare sets of pictures.  The first set given is the reference
        set.  The others are compared against the reference.

        Sets can be specified in three different ways:
        (1) as a string including a '{filename}' pattern;
        (2) as a directory;
        (3) as a pattern like 'dir/*.png'.

        In the first case a list will be generated by replacing the
        pattern with the content of the file whose name is in the
        brackets.  One item for each line in the file will be generated.

        In the second case the generated list will include all the items in
        the directory.

        In the third case the list will include paths matching the
        pattern.  '**' can be used to recursively scan a directory.

        In the last two cases the resulting list is sorted.  In the first
        case sorting is enabled by the option --sort.

        Images are compared by following the order in the sets, unless the
        --match-names option is given.


        METRICS
        Available metrics are:
        - mse: Mean Squared Error
        - psnr: Peak Signal to Noise Ratio
        - ssim: Structural Similarity Index
        - lpips: Learned Perceptual Image Patch Similarity
        - lpips_vgg: as lpips but with VGG instead of AlexNet
        - lpips_squeeze: as lpips but with SqueezeNet instead of AlexNet

        The LPIPS metric requires that the "lpips" module is installed (pip install lpips).
        """)
    # The raw formatter keeps the line breaks of the epilog, so that the
    # enumerations above are not reflowed into a single paragraph.
    parser = argparse.ArgumentParser(
        description="Compare sets of pictures.",
        epilog=epilog,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    a = parser.add_argument
    a("reference", help="Reference set")
    a("other", nargs="+", help="Set of pictures to compare")
    a("--metric", "-m", choices=list(METRICS), nargs="+", default=["mse", "psnr", "ssim"],
      help="difference metric(s)")
    a("--sort", "-s", action="store_true", help="sort the sets in alphabetical order")
    a("--match-names", "-n", action="store_true", help="match images by filename")
    a("--detail", "-d", action="store_true", help="show individual comparisons")
    a("--workers", "-t", type=int, help="number of parallel threads")
    a("--write", "-w", help="save the results in a file")

    def size(s):
        """Convert 'WxH' (or a single 'N' for NxN) to a (width, height) tuple.

        A malformed value makes int() raise ValueError, which argparse
        reports as an invalid argument.
        """
        if s is None:
            return None
        w, x, h = s.lower().partition("x")
        return ((int(w), int(h)) if x == "x" else (int(w), int(w)))

    a("--resize", "-r", type=size, help="resize images to the given dimension WxH " +
      "(use a single dimension for square images)")
    args = parser.parse_args()
    return args


def replace_filename(path):
    """Expand a '{filename}' placeholder into a list of paths.

    The placeholder is the brace pair opened by the last '{' that
    precedes the last '}' of path.  The text between the braces is the
    name of a text file; every non-blank line of that file (stripped of
    surrounding whitespace) replaces the placeholder once, producing one
    item of the returned list.

    Raises ValueError if path contains no {...} pattern.

    """
    # Locate the placeholder: start from the last closing brace, find the
    # opening brace before it, then the first closing brace after that.
    last_close = path.rindex("}")
    open_at = path.rindex("{", 0, last_close)
    close_at = path.index("}", open_at)

    head = path[:open_at]
    listfile = path[open_at + 1:close_at]
    tail = path[close_at + 1:]

    # One output item for each non-empty line of the list file.
    expanded = []
    with open(listfile) as f:
        for raw in f:
            name = raw.strip()
            if name:
                expanded.append(head + name + tail)
    return expanded


def imagename(path):
    """Return the file name of path without directories and extension."""
    filename = os.path.basename(path)
    stem, _extension = os.path.splitext(filename)
    return stem


def match_names(reference, other):
    """Return the paths of other reordered to follow reference.

    Paths are paired by image name (file name without directories and
    extension).  The result has one entry for each path in reference.

    ValueError is raised if a name in reference has no counterpart in
    other, or if it corresponds to more than one path in other.

    """
    # Index the candidates by name, remembering the names seen more than once.
    first_seen = {}
    ambiguous = set()
    for candidate in other:
        key = imagename(candidate)
        if key in first_seen:
            ambiguous.add(key)
        else:
            first_seen[key] = candidate

    # Pick, in reference order, the unique candidate with the same name.
    ordered = []
    for ref_path in reference:
        key = imagename(ref_path)
        if key not in first_seen:
            raise ValueError(f"Non-matching name in path '{ref_path}'")
        if key in ambiguous:
            raise ValueError(f"Multiple matching names for '{key}'")
        ordered.append(first_seen[key])
    return ordered


def load_set(lst, sort):
    """Build the list of paths described by lst.

    After expanding a leading '~', lst is interpreted as (in this order):
      - a directory: the result contains all the entries of the
        directory, sorted;
      - a string including a {filename} pattern: one item is generated
        for each line of the file whose name is in the brackets, by
        replacing the pattern with the line.  The result is sorted only
        if the parameter sort is true;
      - a glob pattern like dir/*.png: the result contains the matching
        paths, sorted.  '**' can be used to scan a directory
        recursively.

    """
    location = os.path.expanduser(lst)

    # Case 1: a directory.
    if os.path.isdir(location):
        entries = [os.path.join(location, entry) for entry in os.listdir(location)]
        entries.sort()
        return entries

    # Case 2: a {filename} pattern.  ValueError signals that it is not there.
    try:
        generated = replace_filename(location)
    except ValueError:
        # Case 3: fall back to a glob pattern.
        matches = glob.glob(location, recursive=True)
        matches.sort()
        return matches

    if sort:
        generated.sort()
    return generated


def load_image(path, image_size=None):
    """Load an image given the path to the file.

    The image is converted to RGB and returned as a numpy array.  If
    image_size (a (width, height) tuple) is not None and differs from
    the size of the image, the image is resampled to that size with
    bilinear interpolation.

    """
    # The context manager closes the underlying file as soon as the pixels
    # have been converted, instead of waiting for garbage collection; this
    # matters when many images are loaded by parallel workers.
    with PIL.Image.open(path) as source:
        img = source.convert("RGB")
    if image_size is not None and image_size != (img.width, img.height):
        img = img.resize(image_size, PIL.Image.BILINEAR)
    return np.array(img)


# Table of the available metrics, indexed by the name used on the command
# line (the keys are the valid choices of the --metric option).  Each value
# is called by compare() as f(reference, other) on a pair of images.
METRICS = {
    "mse": photocmp.mse,
    "psnr": photocmp.psnr,
    "ssim": photocmp.ssim,
    "lpips": photocmp.lpips,
    # Variants of lpips selecting a different network through the 'net' argument.
    "lpips_vgg": (lambda x, y: photocmp.lpips(x, y, net="vgg")),
    "lpips_squeeze": (lambda x, y: photocmp.lpips(x, y, net="squeeze"))
}


def compare(paths, metrics, resize):
    """Compare the first image in paths against all the others.

    paths[0] is the reference image.  When resize is None the other
    images are resampled to the size of the reference, otherwise all the
    images (reference included) are resampled to resize.

    Returns a dictionary {metric_name: [values]} with one value for each
    image after the first, in the same order as in paths.

    """
    reference = load_image(paths[0], resize)

    # Without an explicit size, the reference dictates the size of the others.
    target_size = resize
    if target_size is None:
        height, width = reference.shape[0], reference.shape[1]
        target_size = (width, height)

    others = []
    for other_path in paths[1:]:
        others.append(load_image(other_path, target_size))

    results = {}
    for mname in metrics:
        scores = []
        for other in others:
            scores.append(METRICS[mname](reference, other))
        results[mname] = scores
    return results


def display_str(s, length):
    """Make sure that s has exactly the requested length.

    Short strings are padded with spaces on the right.  Long strings are
    cut: if fewer than six characters are available only the beginning
    of s is kept, otherwise the result is made of the first three
    characters, an ellipsis ("...") and as many trailing characters as
    needed to fill the requested length.

    """
    if len(s) <= length:
        return s.ljust(length)
    if length < 6:
        return s[:length]
    # Number of trailing characters that fit after the "abc..." part.  The
    # slice is computed from the start of the string because s[-0:] would
    # return the whole string when no trailing character fits (length == 6).
    tail = length - 6
    return s[:3] + "..." + s[len(s) - tail:]


def append_metrics(all, single):
    """Add the values of one comparison to the accumulated values.

    single has the structure {metric_name: [metric_values]},
    all has the structure {metric_name: [[metric_values]]}.

    For each metric in all, the i-th value in single is appended to the
    i-th list in all.  The lists are modified in place.
    """
    for name, collected in all.items():
        for position, bucket in enumerate(collected):
            bucket.append(single[name][position])


# Summary statistics printed at the end of the report.  The keys are the
# row labels, the values are the numpy reductions applied (with their
# default arguments) to the list of values collected for each metric.
# The rows are printed in the order in which they are listed here.
STATISTICS = {
    "MEAN": np.mean,
    "STDDEV": np.std,
    "MIN": np.min,
    "MAX": np.max
}


def header(prefix, metrics, num, column1, columns):
    """Build the header row of a table.

    The first cell contains prefix, padded or cut to column1 characters.
    It is followed, for each metric, by num cells labelled
    'metric[index]', right-aligned in columns characters.  Groups of
    cells belonging to different metrics are separated by extra space.

    """
    cells = [prefix.ljust(column1)[:column1]]
    for metric in metrics:
        cells.append(" ")  # spacer between groups of columns
        cells.extend(f"{metric}[{index:d}]".rjust(columns) for index in range(num))
    return " ".join(cells)


def line(prefix, metrics, values, column1, columns):
    """Build a data row of a table.

    The first cell contains prefix, padded or cut to column1 characters.
    It is followed, for each metric, by the numbers in values[metric],
    formatted with four decimal digits and right-aligned in columns
    characters.  Groups of cells belonging to different metrics are
    separated by extra space.

    """
    cells = [prefix.ljust(column1)[:column1]]
    for metric in metrics:
        cells.append(" ")  # spacer between groups of columns
        cells.extend(f"{number:.4f}".rjust(columns) for number in values[metric])
    return " ".join(cells)


def compute_stats(values, fun):
    """Reduce each list of metric values to a single number.

    values has the structure {metric_name: [[metric_values]]}; the
    result has the structure {metric_name: [fun(metric_values)]}.
    """
    aggregated = {}
    for key in values:
        aggregated[key] = [fun(series) for series in values[key]]
    return aggregated


def dump_line(paths, metric_names, values):
    """Format one record as a string of space-separated fields.

    The fields are the paths followed by the values of the metrics,
    taken in the order given by metric_names.
    """
    fields = [*paths]
    for name in metric_names:
        fields += [str(value) for value in values[name]]
    return " ".join(fields)


class NoOutput:
    """A writable sink that silently discards everything.

    It offers the small part of the file interface needed to be used as
    the target of print() and as a context manager.
    """

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        # Nothing to release; exceptions are not suppressed.
        return None

    def write(self, *_ignored):
        return None

    def close(self):
        return None


def main():
    """Run the comparison described by the command line.

    The sets of images are loaded, corresponding images are compared in
    parallel, and a table with summary statistics is printed.  With
    --detail the individual comparisons are printed too; with --write
    they are also saved to a file.

    Raises ValueError if the reference set is empty, if the sets have
    different sizes, or if the names cannot be matched.

    """
    PATH_WIDTH = 16
    VALUE_WIDTH = 9
    STAT_WIDTH = 9

    args = parse_args()
    # 1) Generate the lists of images.
    reference = load_set(args.reference, args.sort)
    if not reference:
        # Without this check empty sets would be accepted (their sizes
        # agree) and the failure would surface only later, as an obscure
        # numpy error raised while computing the statistics.
        raise ValueError(f"No images found in the reference set '{args.reference}'.")
    sets = [reference]
    for s in args.other:
        other = load_set(s, args.sort)
        if args.match_names:
            other = match_names(reference, other)
        if len(other) != len(reference):
            raise ValueError("The size of the sets does not agree " +
                             f"({len(reference)} vs. {len(other)}).")
        sets.append(other)

    # 2) Compare the images.
    detail = (sys.stdout if args.detail else NoOutput())
    print(header("REFERENCE", args.metric, len(sets) - 1, PATH_WIDTH, VALUE_WIDTH), file=detail)
    allmetrics = {m: [[] for _ in range(len(sets) - 1)] for m in args.metric}
    # Each tuple holds a reference image followed by the corresponding
    # image of every other set.
    tuples = list(zip(*sets))

    def task(t):
        return compare(t, args.metric, args.resize)

    with concurrent.futures.ThreadPoolExecutor(max_workers=args.workers) as executor, \
         (open(args.write, "w") if args.write else NoOutput()) as outfile:
        # executor.map yields the results in the same order as the tuples.
        for paths, metrics in zip(tuples, executor.map(task, tuples)):
            append_metrics(allmetrics, metrics)
            disp = display_str(imagename(paths[0]), PATH_WIDTH)
            print(line(disp, args.metric, metrics, PATH_WIDTH, VALUE_WIDTH), file=detail)
            print(dump_line(paths, args.metric, metrics), file=outfile)
    print(file=detail)

    # 3) Print a summary of the results.
    print(header("", args.metric, len(sets) - 1, STAT_WIDTH, VALUE_WIDTH))
    for stat in STATISTICS:
        stat_values = compute_stats(allmetrics, STATISTICS[stat])
        print(line(stat, args.metric, stat_values, STAT_WIDTH, VALUE_WIDTH))


if __name__ == "__main__":
    try:
        main()
    except Exception as e:
        # Report the failure on the error stream and with a non-zero exit
        # status, so that the message does not mix with the results and
        # calling scripts can detect that the comparison failed.
        print(e, file=sys.stderr)
        sys.exit(1)
