#!/usr/bin/env python3
"""
Create a summary PDF report and metrics CSV covering all QC sessions

AUTHORS
-------
Mike Tyszka, Ph.D., Caltech Brain Imaging Center

MIT License

Copyright (c) 2020 Mike Tyszka

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import os
import shutil
import tempfile
import numpy as np
from datetime import datetime

import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from sklearn.cluster import DBSCAN
from sklearn.preprocessing import StandardScaler

from pandas.plotting import register_matplotlib_converters
from reportlab.lib.enums import TA_JUSTIFY
from reportlab.lib.pagesizes import letter
from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
from reportlab.lib.units import inch
from reportlab.platypus import (SimpleDocTemplate,
                                Paragraph,
                                Spacer,
                                Image,
                                Table,
                                PageBreak)

from .graphics import metric_trend_plot


class Summarize:
    """
    Build a summary PDF report and metrics CSV for all QC sessions of one subject.

    All work is done in the constructor: outlier sessions are flagged, the PDF
    (cover page + metric trend plots) is built, and the CSV of metrics of
    interest is written to report_dir.
    """

    # DBSCAN outlier-detection parameters. The eps value is also quoted in the
    # PDF text, so keep it in one place.
    _DBSCAN_EPS = 1.5
    _DBSCAN_MIN_SAMPLES = 5

    # Accepted AcquisitionDateTime string formats, tried in order
    _DATETIME_FORMATS = ('%Y-%m-%dT%H:%M:%S.%f', '%Y-%m-%dT%H:%M:%S')

    def __init__(self, report_dir, metrics_df, past_months):
        """
        Create summary report PDF and CSV file for all sessions

        Note: metrics_df is modified in place ('Date' and 'Outlier' columns are added).

        :param report_dir: str, report output directory in derivatives
        :param metrics_df: DataFrame, session metric dataframe
        :param past_months: int, number of past months to summarize
        :raises ValueError: if metrics_df has no rows or an AcquisitionDateTime
            value matches none of the accepted formats
        """

        if metrics_df.empty:
            raise ValueError('metrics_df contains no sessions - nothing to summarize')

        # For datetime axis labeling without warnings
        register_matplotlib_converters()

        # Add plot-friendly Date column
        metrics_df['Date'] = [self._parse_acq_datetime(dt)
                              for dt in metrics_df['AcquisitionDateTime'].values]

        # Metrics of interest to plot and save to CSV
        self._metric_names = [
            'SNR',
            'SFNR',
            'NoiseFloor',
            'Drift',
            'NyquistSpikes',
            'AirSpikes'
        ]

        self._report_dir = report_dir
        self._metrics_df = metrics_df
        # Positional lookup: label 0 may not exist in a filtered/reindexed DataFrame
        self._subject = metrics_df['Subject'].iloc[0]
        self._past_months = past_months
        self._summary_pdf = os.path.join(report_dir, '{}_summary.pdf'.format(self._subject))
        self._summary_csv = self._summary_pdf.replace('.pdf', '.csv')

        # Identify outliers before generating plots and writing CSV
        self._outliers()

        #
        # Summary PDF construction
        #

        # Create working directory for images
        self._work_dir = tempfile.mkdtemp()

        try:
            # Contents - list of flowables to be built into a document
            self._contents = []

            # Add a justified paragraph style
            self._pstyles = getSampleStyleSheet()
            self._pstyles.add(ParagraphStyle(name='Justify', alignment=TA_JUSTIFY))

            self._init_pdf()
            self._add_coverpage()
            self._add_metric_graphs()

            # Images are read from the work directory during build, so the
            # directory must survive until build() returns
            self._doc.build(self._contents)
        finally:
            # Always remove the working directory, even if the build failed.
            # Cleanup errors are ignored so they cannot mask the real exception.
            shutil.rmtree(self._work_dir, ignore_errors=True)

        # Finally write CSV for metrics of interest
        self._write_csv()

    @classmethod
    def _parse_acq_datetime(cls, dt_str):
        """
        Parse an AcquisitionDateTime string, with or without fractional seconds

        :param dt_str: str, e.g. '2020-01-31T13:45:10.250000' or '2020-01-31T13:45:10'
        :return: datetime
        :raises ValueError: if dt_str matches none of the accepted formats
        """
        for fmt in cls._DATETIME_FORMATS:
            try:
                return datetime.strptime(dt_str, fmt)
            except ValueError:
                continue
        raise ValueError('Unrecognized AcquisitionDateTime: {!r}'.format(dt_str))

    def _init_pdf(self):
        """
        Create the letter-size PDF document template with 0.5 inch margins
        """

        # Create a new PDF document
        self._doc = SimpleDocTemplate(self._summary_pdf,
                                      pagesize=letter,
                                      rightMargin=0.5 * inch,
                                      leftMargin=0.5 * inch,
                                      topMargin=0.5 * inch,
                                      bottomMargin=0.5 * inch)

    def _add_coverpage(self):
        """
        Append the cover page flowables: title, generation timestamp,
        subject/scanner details and the list of potential outlier sessions
        """

        ptext = '<font size=24>CBIC Quality Control Summary</font>'
        self._contents.append(Paragraph(ptext, self._pstyles['Justify']))
        self._contents.append(Spacer(1, 0.5 * inch))

        timestamp = datetime.now().strftime('%Y-%m-%d at %H:%M:%S')
        ptext = '<font size=12>Generated by CBICQC on {}</font>'.format(timestamp)
        self._contents.append(Paragraph(ptext, self._pstyles['Justify']))
        self._contents.append(Spacer(1, 0.25 * inch))

        ptext = '<font size=14><b>Subject Details</b></font>'
        self._contents.append(Paragraph(ptext, self._pstyles['Justify']))
        self._contents.append(Spacer(1, 0.1 * inch))

        # First session metadata - assume identical for all sessions
        m = self._metrics_df.iloc[0]

        # format() rather than '+' so a numerically-typed serial number does not raise
        meta = [['Subject', self._subject],
                ['Scanner', '{} {}'.format(m['StationName'], m['DeviceSerialNumber'])],
                ['Software Version', m['SoftwareVersions']],
                ['Coil Name', m['ReceiveCoilName']]]

        self._contents.append(Table(meta, hAlign='LEFT'))
        self._contents.append(Spacer(1, 0.25 * inch))

        ptext = '<font size=14><b>Potential Outlier Sessions</b></font>'
        self._contents.append(Paragraph(ptext, self._pstyles['Justify']))
        self._contents.append(Spacer(1, 0.1 * inch))

        ptext = ('<font size=12>Identified by conservative DBSCAN clustering '
                 '(epsilon = {})</font>'.format(self._DBSCAN_EPS))
        self._contents.append(Paragraph(ptext, self._pstyles['Justify']))
        self._contents.append(Spacer(1, 0.1 * inch))

        is_outlier = self._metrics_df['Outlier'] == 'Outlier'
        outlier_rows = self._metrics_df.loc[is_outlier, ['Date', 'Outlier']].values.tolist()

        if outlier_rows:
            self._contents.append(Table(outlier_rows, hAlign='LEFT'))
        else:
            # reportlab refuses to build an empty Table by default, so report
            # the clean result as text instead
            ptext = '<font size=12>No potential outlier sessions identified</font>'
            self._contents.append(Paragraph(ptext, self._pstyles['Justify']))

        self._contents.append(Spacer(1, 0.25 * inch))

    def _add_metric_graphs(self):
        """
        Append a new page with the trend and histogram plots for all metrics of interest
        """

        # Page break
        self._contents.append(PageBreak())

        ptext = '<font size=14><b>Session Metric Trends</b></font>'
        self._contents.append(Paragraph(ptext, self._pstyles['Justify']))
        self._contents.append(Spacer(1, 0.25 * inch))

        # 7 x 9 inch image matches the 14 x 18 inch figure aspect ratio
        trend_png_fname = self._plot_trends()
        trends_img = Image(trend_png_fname, 7.0 * inch, 9.0 * inch, hAlign='LEFT')
        self._contents.append(trends_img)

    def _plot_trends(self):
        """
        Generate a PNG of the trend and histogram for each metric of interest

        :return: str, path to the PNG in the working directory
        """

        # Number of metrics to plot
        n_metrics = len(self._metric_names)

        # Output PNG filename
        png_fname = os.path.join(self._work_dir, 'metric_trends.png')

        # Setup plot grid: one row per metric, wide trend column + narrow histogram column
        fig = plt.figure(figsize=(14, 18))

        try:
            gs = gridspec.GridSpec(n_metrics, 2, width_ratios=[3, 1])

            # Fill each of the subplots
            for mc, m_name in enumerate(self._metric_names):
                metric_trend_plot(mc, m_name, self._metrics_df, gridspec=gs, past_months=self._past_months)

            # Tweak subplot margins and spacing
            plt.tight_layout()

            # Save plot to file
            plt.savefig(png_fname, dpi=300)
        finally:
            # Release the figure; pyplot otherwise keeps it alive indefinitely
            plt.close(fig)

        return png_fname

    def _write_csv(self):
        """
        Write subject, session, date and metrics of interest to the summary CSV
        """

        # Full column list including subject, session, date and metrics
        cols_to_write = ['Subject', 'Session', 'Date'] + self._metric_names

        print()
        print('Writing metrics of interest to {:s}'.format(self._summary_csv))

        self._metrics_df.to_csv(self._summary_csv,
                                columns=cols_to_write,
                                header=True,
                                index=False)

    def _outliers(self):
        """
        Identify within sample outliers using DBSCAN
        Adds an 'Outlier' column to the metric DataFrame with values
        'Outlier' (DBSCAN noise points) or 'Inlier'

        :return: None
        """

        # Feature matrix: one row per session, one column per metric of interest
        X = np.array(self._metrics_df[self._metric_names])

        # Standardize mean, sd to 0, 1
        X = StandardScaler().fit_transform(X)

        # Cluster using DBSCAN
        clustering = DBSCAN(eps=self._DBSCAN_EPS, min_samples=self._DBSCAN_MIN_SAMPLES).fit(X)

        # DBSCAN labels noise points -1
        self._metrics_df['Outlier'] = ['Outlier' if label < 0 else 'Inlier'
                                       for label in clustering.labels_]