#!/usr/bin/env python3
import sys

import fitz

import io
import numpy as np

from PIL import Image


"""
TODO:

* Test scaling on "hard PDFs" from H
* Do we want to match the mask w/h against the image it is a mask for?
* For certain images, perhaps extract EXIF metadata embedded in the images
* See if we can figure out a way to not have to decode the images in full (in
  particular the JPEG2000 ones), currently we need to do that for the transform
  and bbox
* xres/yres?
* whether the PDF is PDF/A?
* table of contents/chapter info?
* annotations info?

XXX:

* What if a PDF draws vector stuff in colour but has only bitonal images, do we
  want to suggest bitonal still as image stack format? likely not, I imagine?
  (We could render the page to a pixmap and see if the whole page is also
  monochrome, grayscale or coloured!)

"""

# Version of this analysis code; analyse() reports it as
# res['version']['analysis'].
ANALYSIS_VERSION = '0.0.1'
# Reported by analyse() as res['version']['spec'] (presumably the version of
# the output format -- confirm).
SPEC_VERSION = '0.0.1'

# Not referenced by any active code in this file; appears to be intended as a
# key prefix for experimental/unsupported output fields.
UNSUPPORTED_PREFIX = 'XXX'


def pdf_metadata_json_round(x):
    """Round *x* to two decimal places, the precision used in the output."""
    return round(x, 2)


def round_list(v):
    """Return a new list holding every element of *v* rounded to 2 decimals."""
    return [pdf_metadata_json_round(item) for item in v]


def remove_images(doc, page, unwanted):
    """Strip the drawing commands for the named images from *page*.

    The page's contents are first normalised with ``page.clean_contents()``,
    then every content line that is exactly ``/<name> Do`` (compared
    case-insensitively) for a name in *unwanted* is blanked out, and the
    result is written back to the page's first content stream through
    ``doc.update_stream``. The document is modified in memory.

    Parameters:
        doc: the document owning *page*; must provide ``update_stream``.
        page: the page to edit; must provide ``clean_contents``,
            ``get_contents`` and ``read_contents``.
        unwanted: iterable of image resource names (``str``) to drop.

    A page without any content stream is left untouched (apart from the
    ``clean_contents`` call) instead of raising ``IndexError``.
    """
    # Lower-case both sides of the comparison, as the matching is
    # deliberately case-insensitive.
    targets = {(b"/%s Do" % name.encode()).lower() for name in unwanted}

    page.clean_contents()  # unify / format the commands

    content_xrefs = page.get_contents()
    if not content_xrefs:
        # Nothing is drawn on this page, so there is nothing to remove.
        return
    xref = content_xrefs[0]

    # NOTE(review): only the first content stream is rewritten; this assumes
    # clean_contents() left the page with a single stream -- confirm.
    commands = page.read_contents().splitlines()
    kept = [b"" if command.lower() in targets else command
            for command in commands]
    doc.update_stream(xref, b"\n".join(kept))  # replace cleaned command object

def get_page_colour_mode(pdf, page, image_data):
    """Classify the colour content of *page* once its images are removed.

    The images listed in *image_data* (dicts with a ``'label'`` key) are
    stripped from the page via ``remove_images``, so the page object is
    modified by this call. The remaining page is rendered, and the rendered
    pixels are classified.

    Returns:
        ``'RGB'`` when the three channels differ anywhere, ``'Bitonal'`` when
        the render is gray and every pixel equals either the minimum or the
        maximum value present, and ``'Grayscale'`` otherwise.
    """
    remove_images(pdf, page, [entry['label'] for entry in image_data])
    page.clean_contents()

    # Render the page and move the result into a numpy array by way of an
    # uncompressed PNG decoded with PIL.
    rendered = page.get_pixmap()
    png_bytes = rendered.pil_tobytes(format='PNG', optimise=False, compress_level=0)
    del rendered  # drop the pixmap as soon as it is serialised

    with io.BytesIO(png_bytes) as buffer:
        del png_bytes
        decoded = Image.open(buffer)
        decoded.load()  # force the decode while the buffer is still open

    pixels = np.array(decoded)

    _dim0, _dim1, channels = pixels.shape
    assert channels == 3, 'page.get_pixmap did not return RGB image?'

    red = pixels[:, :, 0]
    green = pixels[:, :, 1]
    blue = pixels[:, :, 2]

    # The render is gray only when all three channels agree everywhere.
    channels_agree = (
        np.all(red == green) and np.all(red == blue) and np.all(green == blue)
    )
    if not channels_agree:
        return 'RGB'

    darkest = np.min(pixels)
    lightest = np.max(pixels)

    # Bitonal: no pixel value other than the two extremes that occur.
    only_extremes = (red == darkest) | (red == lightest)
    if np.all(only_extremes):
        return 'Bitonal'

    return 'Grayscale'

def pixmap_for_xref(pdf, xref):
    """Create and return a ``fitz.Pixmap`` for object *xref* of *pdf*."""
    return fitz.Pixmap(pdf, xref)


def get_xref_pix_info(pdf, xref):
    """Describe the image stored at *xref* using PIL-style mode names.

    Returns a dict with:
        ``'size'``: ``(width, height)`` of the pixmap,
        ``'depth'``: 1 for monochrome gray images, otherwise 8 (hardcoded),
        ``'mode'``: one of ``'1'``, ``'L'``, ``'LA'``, ``'RGB'``, ``'RGBA'``
            or ``'CMYK'``,
        ``'colorspace'``: the colorspace name, only when the pixmap has one.

    Raises:
        Exception: for a colorspace other than DeviceGray, DeviceRGB or
            DeviceCMYK.
    """
    # XXX: This uses PIL modes for images, should we?
    pix = pixmap_for_xref(pdf, xref)
    colorspace = pix.colorspace

    info = {'size': (pix.width, pix.height)}

    # XXX: pix.n is the bytes required per pixel for all components, which
    # does not directly give the depth; hence the hardcoded 8 below.
    if colorspace is None or colorspace.name == 'DeviceGray':
        if pix.is_monochrome:
            depth, mode = 1, '1'
        elif pix.alpha:
            depth, mode = 8, 'LA'
        else:
            depth, mode = 8, 'L'
    elif colorspace.name == 'DeviceRGB':
        depth = 8
        mode = 'RGBA' if pix.alpha else 'RGB'
    elif colorspace.name == 'DeviceCMYK':
        # This should be always 8, per CMYK spec
        # CMYK doesn't have alpha, so we don't check for that
        depth, mode = 8, 'CMYK'
    else:
        raise Exception('Unknown colorspace:', colorspace.name)

    info['depth'] = depth
    info['mode'] = mode

    if colorspace:
        info['colorspace'] = colorspace.name

    return info


def map_pdf_filter_to_likely_image_format(filter_):
    """Translate a PDF stream filter name into a likely image format name.

    Lossy/bitonal codecs map to their own format; an empty filter and the
    generic lossless filters all map to ``'PNG'``.

    Raises:
        Exception: when *filter_* is not one of the known filter names.
    """
    # Filters that carry no image codec of their own; PNG is used as the
    # generic lossless stand-in (an empty filter is lossless as well).
    lossless_filters = (
        '',
        'FlateDecode',
        'ASCII85Decode',
        'ASCIIHexDecode',
        'LZWDecode',
        'RunLengthDecode',
    )

    likely_format = dict.fromkeys(lossless_filters, 'PNG')
    likely_format.update(
        JPXDecode='JPEG2000',
        JBIG2Decode='JBIG2',
        CCITTFaxDecode='CCITT',
        DCTDecode='JPEG',
    )

    guess = likely_format.get(filter_)
    if guess is None:
        raise Exception('Unknown filter:', filter_)

    return guess

def get_scale_from_image_data(image_data, default_ppi=300.):
    """Estimate the scale at which a page must be rendered to keep image detail.

    For every image, the pixel size is compared against the size of its
    bounding box on the page (in points). The result is the largest
    pixels-per-point ratio found on either axis across *all* images, and is
    never below 1.0.

    Parameters:
        image_data: list of dicts with ``'bbox'`` (x0, y0, x1, y1),
            ``'width'`` and ``'height'`` keys; may be empty or ``None``.
        default_ppi: resolution assumed when there are no images at all.

    Returns:
        The scale factor as a float; ``default_ppi / 72`` without images.
    """
    # TODO: what else can we take into account here?

    if not image_data:
        # No images to base an estimate on: fall back to a fixed resolution.
        return default_ppi / 72.

    scale = 1.
    for info in image_data:
        bbox = info['bbox']
        bbox_w = abs(bbox[2] - bbox[0])
        bbox_h = abs(bbox[3] - bbox[1])

        # A degenerate (zero-size) bbox carries no usable ratio and would
        # divide by zero, so that axis is skipped. Taking the running maximum
        # keeps the most demanding image rather than whichever came last.
        if 0 < bbox_w < info['width']:
            scale = max(scale, info['width'] / bbox_w)

        if 0 < bbox_h < info['height']:
            scale = max(scale, info['height'] / bbox_h)

    # TODO: Do we want some min scale? (if it's less than x, set to x)

    return scale


def get_recommended_image_format_from_page_data(page_data):
    """Pick the image stack format that fits every page of the document.

    Parameters:
        page_data: list of per-page dicts, each with a
            ``'page_without_images_color_mode'`` key and optionally an
            ``'image_data'`` list whose entries have a ``'mode'`` key.

    Returns:
        ``'RGB'``, ``'Grayscale'`` or ``'Bitonal'``. Any colour content,
        whether drawn on a page or found in an image (RGB, RGBA or CMYK),
        yields ``'RGB'``; otherwise any gray content yields ``'Grayscale'``.

    Raises:
        ValueError: when neither the page modes nor the image modes contain
            a recognised value (this includes an empty *page_data*).
    """
    page_colour_modes = [page['page_without_images_color_mode'] for page in page_data]
    if 'RGB' in page_colour_modes:
        return 'RGB'

    image_modes = [
        image['mode']
        for page in page_data
        for image in (page.get('image_data') or [])
    ]

    if not image_modes:
        # No images anywhere: only the rendered page content decides.
        if 'Grayscale' in page_colour_modes:
            return 'Grayscale'

        if 'Bitonal' in page_colour_modes:
            return 'Bitonal'

        raise ValueError('Cannot recommend image format from page_colour_modes: %s' % page_colour_modes)

    # CMYK images are colour images too; without this case a document whose
    # only images are CMYK would fall through to the ValueError below.
    if any(mode in ('RGB', 'RGBA', 'CMYK') for mode in image_modes):
        return 'RGB'

    if 'Grayscale' in page_colour_modes:
        return 'Grayscale'

    if 'L' in image_modes or 'LA' in image_modes:
        return 'Grayscale'

    if '1' in image_modes:
        return 'Bitonal'

    raise ValueError('Cannot recommend image format from image modes: %s' % image_modes)


def _collect_page_images(pdf, page):
    """Return the ``image_data`` entries for the images placed on *page*."""
    # We'd like to get rid of this call later, as we only use it to get
    # transform/bboxes, and it's slow. It reads the page contents for image
    # operations, but also actually decompresses the images, which is
    # something we'd rather not do.
    placements = {}
    for placed in page.get_image_info(xrefs=True):
        # Keep the first placement reported for each xref.
        placements.setdefault(placed['xref'], placed)

    image_data = []
    for page_image in page.get_images():
        (img_id, mask_id, w, h, bpc,
         _colorspace, _alt_colorspace, label, _filter) = page_image

        iminfo = placements.get(img_id)
        if iminfo is None:
            # This can occur since page.get_images can return more images
            # than are actually on the page, whereas page.get_image_info
            # does not, but we need this info as well for the masks
            print('Image not actually present on the page. Skipping this image.', file=sys.stderr)
            continue

        img_info = {
            # Info from page.get_images
            'xref': img_id,
            'width': w,
            'height': h,
            'depth': bpc,
            'label': label,
            # Info from page.get_image_info(xrefs=True)
            'bbox': round_list(iminfo['bbox']),
            'transform': round_list(iminfo['transform']),
        }
        # Disabled experimental fields (stream filter, colorspace, suggested
        # image format) would be added here, keyed with UNSUPPORTED_PREFIX.

        # Info from the pixmap of the image
        img_info['mode'] = get_xref_pix_info(pdf, img_id)['mode']
        img_info['mask'] = None

        if mask_id > 0:
            mask_pix_info = get_xref_pix_info(pdf, mask_id)
            img_info['mask'] = {
                'xref': mask_id,
                'width': mask_pix_info['size'][0],
                'height': mask_pix_info['size'][1],
                'depth': mask_pix_info['depth'],
                'mode': mask_pix_info['mode'],
            }

        image_data.append(img_info)

    return image_data


def _collect_uri_links(page):
    """Return one dict (uri, xref, bbox) per URI hyperlink found on *page*."""
    uri_links = []
    for link in page.get_links():
        if link['kind'] != fitz.LINK_URI:
            continue

        link_rect = link['from']
        uri_links.append({
            'uri': link['uri'],
            'xref': link['xref'],
            'bbox': round_list([link_rect.x0, link_rect.y0, link_rect.x1, link_rect.y1]),
        })

    return uri_links


def analyse(filename):
    """Analyse the PDF at *filename* and return its metadata as a dict.

    The result contains:
        ``'version'``: analysis/spec/pymupdf/mupdf version strings,
        ``'page_count'``: number of pages,
        ``'page_data'``: one dict per page (geometry, rotation, language,
            estimated render scale/ppi, ``'has_text_layer'``, optional
            ``'hyperlinks'`` and ``'image_data'``, and
            ``'page_without_images_color_mode'``),
        ``'imagestack_image_format'``: the recommended format for the whole
            document.

    The document is opened here and always closed again, including when the
    analysis fails part-way. Images are stripped from the in-memory pages
    while determining the page colour mode; the file on disk is not written.

    Raises:
        ValueError: when no image format can be recommended.
        Exception: for images using an unsupported colorspace.
    """
    pymupdf_version, mupdf_version, _build_date = fitz.version

    res = {
        'version': {
            'analysis': ANALYSIS_VERSION,
            'spec': SPEC_VERSION,
            'pymupdf': pymupdf_version,
            'mupdf': mupdf_version,
        },
    }

    pdf = fitz.open(filename)
    try:
        res['page_count'] = pdf.page_count
        res['page_data'] = []

        for page in pdf:
            page_data = {
                'page_number': page.number,
                'page_rotation': page.rotation,
                'page_language': page.language,
                'page_rect': list(page.rect),
            }

            image_data = _collect_page_images(pdf, page)

            scale = get_scale_from_image_data(image_data)
            page_data['estimated_scale'] = pdf_metadata_json_round(scale)
            page_data['estimated_ppi'] = int(72 * scale)
            page_data['estimated_default_render_res'] = round_list(
                [coord * scale for coord in page_data['page_rect']])

            # Capture hyperlinks
            link_uri = _collect_uri_links(page)

            flgs = (fitz.TEXT_PRESERVE_WHITESPACE
                    | fitz.TEXT_PRESERVE_LIGATURES
                    | fitz.TEXT_PRESERVE_IMAGES)
            pagetext = page.get_text(option='rawdict', flags=flgs)
            # With TEXT_PRESERVE_IMAGES the block list also holds image
            # blocks, so only text blocks (type 0) count as a text layer.
            page_data['has_text_layer'] = any(
                block.get('type') == 0 for block in pagetext['blocks'])

            if link_uri:
                page_data['hyperlinks'] = link_uri

            if image_data:
                page_data['image_data'] = image_data

            # This strips the images from the page, so it has to run after
            # everything above that still needs them.
            page_data['page_without_images_color_mode'] = get_page_colour_mode(
                pdf, page, page_data.get('image_data', []))

            res['page_data'].append(page_data)
    finally:
        pdf.close()

    res['imagestack_image_format'] = get_recommended_image_format_from_page_data(res['page_data'])

    # TODO (maybe):
    # - permissions
    # - password required?
    # - language(?)
    # - is_repaired?

    return res


#def pick_pdf_format(pdf_data):
#    pass
#
#
#def get_pil_info_fast(pdf, xref, filter_):
#    data = {}
#    if filter_ in ('DCTDecode'):
#        xrefdata = pdf.xref_stream_raw(xref)
#    else:
#        xrefdata = pdf.xref_stream(xref)
#
#    imgfd = io.BytesIO()
#    imgfd.write(xrefdata)
#    image = Image.open(imgfd)
#    data['size'] = image.size
#    data['mode'] = image.mode
#    # TODO: make this more complete.
#    if image.mode == '1':
#        data['depth'] = 1
#    if image.mode in ('L', 'LA', 'P', 'RGB', 'RGBA'):
#        data['depth'] = 8
#
#    # TODO: get image.info (dpi, etc)
#    image.close()
#    imgfd.close()
#
#    return data
#
# This lacks bbox/transform, which we -might- need?
#def get_xref_info_fast(pdf, xref, mask=False):
#    keys = pdf.xref_get_keys(xref)
#
#    typ_, width = pdf.xref_get_key(xref, 'Width')
#    if typ_ != 'int': raise Exception('Invalid width')
#    width = int(width)
#
#    typ_, height = pdf.xref_get_key(xref, 'Height')
#    if typ_ != 'int': raise Exception('Invalid height')
#    height = int(height)
#
#    typ_, depth = pdf.xref_get_key(xref, 'BitsPerComponent')
#    if typ_ != 'int': raise Exception('Invalid BitsPerComponent')
#    depth = int(depth)
#
#    typ_, cs = pdf.xref_get_key(xref, 'ColorSpace')
#
#    typ_, filter_ = pdf.xref_get_key(xref, 'Filter')
#
#    # XXX: how do we potentially encoded dpi? do we care for jbig2?
#
#    #if mask and depth != 1:
#    #    raise ValueError('image with depth != 1???')
#
#    d = {'size': (width, height),
#         'mode': '1',
#         'depth': depth,
#         'filter': filter_,
#         'colorspace': cs}
#    #try:
#    #    mask_info = get_pil_info_fast(pdf, xref)
#    #except:
#    #    mask_info = get_pil_info(pdf, xref)
#    return d


## TODO: this lacks mask info
#for page_image in page.get_image_info():
#    print(page_image)
#    img_info = {}

#    img_info['bbox'] = page_image['bbox']
#    img_info['transform'] = page_image['transform']
#    img_info['colorspace'] = page_image['cs-name']
#    img_info['width'] = page_image['width']
#    img_info['height'] = page_image['height']
#    img_info['depth'] = page_image['bpc']

#    image_data.append(img_info)



if __name__ == '__main__':
    # Command-line entry point: analyse the PDF named by the first argument
    # and write the result to stdout as JSON indented by four spaces.
    import json

    # Give a usage hint rather than an IndexError traceback when the path
    # argument is missing.
    if len(sys.argv) < 2:
        print('Usage: %s <file.pdf>' % sys.argv[0], file=sys.stderr)
        sys.exit(2)

    json.dump(analyse(sys.argv[1]), sys.stdout, indent=' ' * 4)
