#!/usr/bin/env python

import argparse
from time import time

import numpy as np
from scipy import ndimage
from PIL import Image, ImageOps
from skimage.restoration import denoise_nl_means, estimate_sigma

from internetarchivepdf.mrc import threshold_image3, denoise_bregman, \
        invert_mask


if __name__ == '__main__':
    # Command-line driver: load an image, optionally pre-blur it, threshold
    # it with threshold_image3, optionally merge in the threshold of the
    # inverted image / denoise / invert the result, and save it.  Timings and
    # noise (sigma) estimates are printed at each stage.
    parser = argparse.ArgumentParser('PDF recoder thresholder')
    parser.add_argument('--invert-mask', default=False, action='store_true')
    parser.add_argument('--with-invert-image', default=False, action='store_true')
    parser.add_argument('--with-preboxblur', default=False, action='store_true')
    parser.add_argument('--with-pregaussblur', default=False, action='store_true')
    parser.add_argument('--with-predenoise', default=False, action='store_true')
    parser.add_argument('--with-postdenoise', default=False, action='store_true')
    # NOTE: --sigma and --size are accepted but only referenced by the
    # commented-out alternative filter calls below.
    parser.add_argument('--sigma', default=1, type=int)
    parser.add_argument('--size', default=1, type=int)
    parser.add_argument('infile', nargs='?', default=None)
    parser.add_argument('outfile', nargs='?', default=None)
    args = parser.parse_args()

    # Both positionals are declared optional, but nothing can be done without
    # them; fail early with a usage message rather than a traceback from
    # Image.open(None) or, worse, from save(None) after all the work is done.
    if args.infile is None or args.outfile is None:
        parser.error('infile and outfile are required')

    t = time()
    img = Image.open(args.infile)
    img.load()
    print('Loading jpeg2000 took:', time()-t)

    # Normalise to 8-bit grayscale; 'LA' input is converted too, which drops
    # its alpha channel.
    if img.mode != 'L':
        img = img.convert('L')

    imgarr = np.array(img)
    print(np.mean(imgarr))
    print(np.median(imgarr))

    # Noise estimate of the unprocessed image; it drives the blur strengths.
    imgf = np.array(imgarr, dtype=float)
    sigma_est = np.mean(estimate_sigma(imgf))
    print('Sigma:', sigma_est)

    # Optional pre-processing.  Every non-raising branch leaves `img` as a
    # numpy array (no longer a PIL image).
    if args.with_preboxblur and sigma_est > 1.:
        t = time()
        img = ndimage.uniform_filter(imgarr, size=sigma_est*0.5)
        #img = ndimage.uniform_filter(imgarr, size=args.size)
        print('box pre-blur took', time()-t)
    elif args.with_pregaussblur and sigma_est > 1.:
        t = time()
        print('Blurring with sigma=', sigma_est*0.1)
        img = ndimage.gaussian_filter(imgarr, sigma=sigma_est*0.1)
        #img = ndimage.gaussian_filter(imgarr, sigma=args.sigma)
        print('gauss pre-blur took', time()-t)
    elif args.with_predenoise:
        raise NotImplementedError('Not implemented')
    else:
        img = imgarr  # XXX: hack for denoise

    # Re-estimate the noise after the (optional) first blur.
    imgf = np.array(img, dtype=float)
    n_sigma_est = np.mean(estimate_sigma(imgf))
    print('Sigma:', n_sigma_est)

    if sigma_est > 1.0 and n_sigma_est > 1.0:
        # TODO: Compare the sigmas
        # The estimate moved by less than 1 (this is also the case when no
        # pre-blur ran at all), so apply a stronger gaussian blur.
        if abs(n_sigma_est - sigma_est) < 1.:
            print('Running more aggressive blur!')
            t = time()
            img = ndimage.gaussian_filter(img, sigma=sigma_est*0.5)
            print('2nd gauss pre-blur took', time()-t)

    t = time()
    arr = threshold_image3(np.array(img, dtype=np.uint8))
    print('Threshold took:', time()-t)

    sigma_est = np.mean(estimate_sigma(np.array(arr*255, dtype=float)))
    print('Threshold sigma estimate:', sigma_est)

    if args.with_invert_image:
        # Threshold the inverted image as well and OR it into the result.
        t = time()
        inverted_image = ImageOps.invert(Image.fromarray(img))
        arr |= threshold_image3(np.array(inverted_image, dtype=np.uint8))
        print('Inverted threshold took:', time()-t)

        sigma_est = np.mean(estimate_sigma(np.array(arr*255, dtype=float)))
        print('Merged threshold sigma estimate:', sigma_est)

    if args.with_postdenoise:
        t = time()
        arr = denoise_bregman(arr)
        print('Denoise took:', time()-t)

        sigma_est = np.mean(estimate_sigma(np.array(arr*255, dtype=float)))
        print('Threshold sigma estimate:', sigma_est)

    if args.invert_mask:
        t = time()
        arr = invert_mask(arr)
        print('Invert took:', time()-t)

    t = time()
    outimg = Image.fromarray(arr)
    outimg.save(args.outfile, compress_level=0, compress=0)
    print('Saving took:', time()-t)
