#!/usr/bin/env python3

# This script fetches Amazon "Order History Reports" and annotates your Mint
# transactions based on actual items in each purchase. It can handle orders
# that are split into multiple shipments/charges, and can even itemize each
# transaction for maximal control over categorization.

import argparse
import atexit
import datetime
from functools import partial
import logging
import os
import sys
import time

from PyQt5.QtCore import (
    Q_ARG, QDate, QEventLoop, Qt, QMetaObject, QObject, QThread, QUrl,
    pyqtSlot, pyqtSignal)
from PyQt5.QtGui import QDesktopServices, QKeySequence
from PyQt5.QtWidgets import (
    QAbstractItemView, QApplication, QCalendarWidget,
    QCheckBox, QComboBox, QDialog, QErrorMessage, QFileDialog,
    QFormLayout, QGroupBox, QHBoxLayout, QInputDialog,
    QLabel, QLineEdit, QMainWindow, QProgressBar,
    QPushButton, QShortcut, QTableView, QWidget, QVBoxLayout)
from outdated import check_outdated

from mintamazontagger import amazon
from mintamazontagger import tagger
from mintamazontagger import VERSION
from mintamazontagger.args import (
    define_gui_args, get_name_to_help_dict, TAGGER_BASE_PATH)
from mintamazontagger.qt import (
    MintUpdatesTableModel, AmazonUnmatchedTableDialog, AmazonStatsDialog,
    TaggerStatsDialog)
from mintamazontagger.mintclient import MintClient
from mintamazontagger.my_progress import QtProgress
from mintamazontagger.orderhistory import fetch_order_history
from mintamazontagger.webdriver import get_webdriver

# Module-level logger; main() attaches the console and file handlers to the
# root logger, which this logger propagates to.
logger = logging.getLogger(name=__name__)

# Tool tip text attached to every email/password line edit.
NEVER_SAVE_MSG = 'Email & password are *never* saved.'


class TaggerGui:
    """Main window for configuring and launching a tagging run.

    Every input widget writes its value straight back onto ``self.args`` (the
    parsed argparse namespace) as soon as it changes. Pressing "Start
    Tagging" hands a copy of that namespace to a ``TaggerDialog``.
    """

    def __init__(self, args, arg_name_to_help):
        """Store the option namespace and help texts; no widgets are built.

        Args:
            args: argparse.Namespace whose attributes the widgets read/write.
            arg_name_to_help: dict mapping argument name -> help text, used
                for widget tool tips.
        """
        self.args = args
        self.arg_name_to_help = arg_name_to_help

    def create_gui(self):
        """Build and show the main window, then run the Qt event loop.

        Returns:
            The exit code returned by the application's event loop.
        """
        try:
            # When packaged with fbs, use its application context; otherwise
            # fall back to a plain QApplication.
            from fbs_runtime.application_context.PyQt5 import (
                ApplicationContext)
            appctxt = ApplicationContext()
            app = appctxt.app
        except ImportError:
            app = QApplication(sys.argv)
        app.setStyle('Fusion')
        self.window = QMainWindow()

        # Kept on self so the shortcut objects stay referenced.
        self.quit_shortcuts = []
        for seq in ("Ctrl+Q", "Ctrl+C", "Ctrl+W", "ESC"):
            s = QShortcut(QKeySequence(seq), self.window)
            s.activated.connect(app.exit)
            self.quit_shortcuts.append(s)

        self._warn_if_outdated()

        v_layout = QVBoxLayout()
        h_layout = QHBoxLayout()
        v_layout.addLayout(h_layout)

        amazon_group = QGroupBox('Amazon Order History')
        amazon_group.setMinimumWidth(300)
        amazon_layout = QVBoxLayout()

        amazon_mode = QComboBox()
        amazon_mode.addItem('Fetch Reports')
        amazon_mode.addItem('Use Local Reports')
        amazon_mode.setFocusPolicy(Qt.StrongFocus)

        # Start in "Use Local Reports" mode when any CSV was already given.
        has_csv = any([
            self.args.orders_csv, self.args.items_csv, self.args.refunds_csv])
        self.amazon_mode_layout = (
            self.create_amazon_import_layout()
            if has_csv else self.create_amazon_fetch_layout())
        amazon_mode.setCurrentIndex(1 if has_csv else 0)
        self.fetch_amazon = not has_csv

        def on_amazon_mode_changed(i):
            # Only the two known modes swap the sub-layout.
            if i not in (0, 1):
                return
            # Empty the old sub-layout, then detach and free it so repeated
            # toggling does not pile up empty layouts.
            old_layout = self.amazon_mode_layout
            self.clear_layout(old_layout)
            amazon_layout.removeItem(old_layout)
            old_layout.deleteLater()
            self.fetch_amazon = (i == 0)
            self.amazon_mode_layout = (
                self.create_amazon_fetch_layout() if self.fetch_amazon
                else self.create_amazon_import_layout())
            amazon_layout.addLayout(self.amazon_mode_layout)
        amazon_mode.currentIndexChanged.connect(
            on_amazon_mode_changed)

        amazon_layout.addWidget(amazon_mode)
        amazon_layout.addLayout(self.amazon_mode_layout)
        amazon_group.setLayout(amazon_layout)
        h_layout.addWidget(amazon_group)

        mint_group = QGroupBox('Mint Login && Options')
        mint_group.setMinimumWidth(350)
        mint_layout = QFormLayout()

        mint_layout.addRow(
            'Email:',
            self.create_line_edit('mint_email', tool_tip=NEVER_SAVE_MSG))
        mint_layout.addRow(
            'Password:',
            self.create_line_edit(
                'mint_password', tool_tip=NEVER_SAVE_MSG, password=True))
        mint_layout.addRow(
            'MFA Code:',
            self.create_combobox(
                'mint_mfa_preferred_method',
                ['SMS', 'Email'],
                lambda x: x.lower()))
        mint_layout.addRow(
            'I will login myself',
            self.create_checkbox('mint_user_will_login'))
        mint_layout.addRow(
            'Sync first?',
            self.create_checkbox('mint_wait_for_sync'))

        mint_layout.addRow(
            'Merchant Filter',
            self.create_line_edit('mint_input_merchant_filter'))
        mint_layout.addRow(
            'Include MMerchant',
            self.create_checkbox('mint_input_include_mmerchant'))
        mint_layout.addRow(
            'Include Merchant',
            self.create_checkbox('mint_input_include_merchant'))
        mint_layout.addRow(
            'Input Categories Filter',
            self.create_line_edit('mint_input_categories_filter'))
        mint_group.setLayout(mint_layout)
        h_layout.addWidget(mint_group)

        tagger_group = QGroupBox('Tagger Options')
        tagger_layout = QHBoxLayout()
        tagger_left = QFormLayout()

        tagger_left.addRow(
            'Verbose Itemize',
            self.create_checkbox('verbose_itemize'))
        tagger_left.addRow(
            'Do not Itemize',
            self.create_checkbox('no_itemize'))
        tagger_left.addRow(
            'Retag Changed',
            self.create_checkbox('retag_changed'))

        tagger_right = QFormLayout()
        tagger_right.addRow(
            'Do not tag categories',
            self.create_checkbox('no_tag_categories'))
        tagger_right.addRow(
            'Do not predict categories',
            self.create_checkbox('do_not_predict_categories'))
        tagger_right.addRow(
            'Max days between payment/shipment',
            self.create_combobox(
                'max_days_between_payment_and_shipping',
                ['3', '4', '5', '6', '7', '8', '9', '10'],
                lambda x: int(x)))

        tagger_layout.addLayout(tagger_left)
        tagger_layout.addLayout(tagger_right)
        tagger_group.setLayout(tagger_layout)
        v_layout.addWidget(tagger_group)

        self.start_button = QPushButton('Start Tagging')
        self.start_button.setAutoDefault(True)
        self.start_button.clicked.connect(self.on_start_button_clicked)
        v_layout.addWidget(self.start_button)

        main_widget = QWidget()
        main_widget.setLayout(v_layout)
        self.window.setCentralWidget(main_widget)
        self.window.show()
        return app.exec_()

    def _warn_if_outdated(self):
        """Show a non-fatal notice when a newer release is published.

        The check needs network access. Any failure is logged and ignored so
        that, for example, an offline machine can still start the GUI.
        """
        try:
            is_outdated, _latest_version = check_outdated(
                'mint-amazon-tagger', VERSION)
        except Exception as e:
            logging.getLogger(__name__).warning(
                'Unable to check for a newer version: %s', e)
            return
        if is_outdated:
            outdate_msg = QErrorMessage(self.window)
            outdate_msg.showMessage(
                'A new version is available. Please update for the best '
                'experience. https://github.com/jprouty/mint-amazon-tagger')

    def _help_for(self, name):
        """Return the help text for ``name``, or None when there is none."""
        return self.arg_name_to_help.get(name)

    def create_amazon_fetch_layout(self):
        """Return the form used when the tagger fetches Amazon reports."""
        amazon_fetch_layout = QFormLayout()
        amazon_fetch_layout.addRow(QLabel(
            'Fetches recent Amazon order history for you.'))
        amazon_fetch_layout.addRow(
            'Email:',
            self.create_line_edit('amazon_email', tool_tip=NEVER_SAVE_MSG))
        amazon_fetch_layout.addRow(
            'Password:',
            self.create_line_edit(
                'amazon_password', tool_tip=NEVER_SAVE_MSG, password=True))
        amazon_fetch_layout.addRow(
            'I will login myself',
            self.create_checkbox(
                'amazon_user_will_login'))
        amazon_fetch_layout.addRow(
            'Start date:',
            self.create_date_edit(
                'order_history_start_date',
                'Select Amazon order history start date'))
        amazon_fetch_layout.addRow(
            'End date:',
            self.create_date_edit(
                'order_history_end_date',
                'Select Amazon order history end date'))
        return amazon_fetch_layout

    def create_amazon_import_layout(self):
        """Return the form for picking locally downloaded report CSVs."""
        amazon_import_layout = QFormLayout()

        order_history_link = QLabel()
        order_history_link.setText(
            '''<a href="https://www.amazon.com/gp/b2b/reports">
            Download your Amazon reports</a><br>
            and select them below:''')
        order_history_link.setOpenExternalLinks(True)
        amazon_import_layout.addRow(order_history_link)

        amazon_import_layout.addRow(
            'Items CSV:',
            self.create_file_edit(
                'items_csv',
                'Select Amazon Items Report'
            ))
        amazon_import_layout.addRow(
            'Orders CSV:',
            self.create_file_edit(
                'orders_csv',
                'Select Amazon Orders Report'
            ))
        amazon_import_layout.addRow(
            'Refunds CSV:',
            self.create_file_edit(
                'refunds_csv',
                'Select Amazon Refunds Report'
            ))
        return amazon_import_layout

    def on_quit(self):
        pass

    def on_tagger_dialog_closed(self):
        """Re-enable "Start Tagging" and hand out fresh CSV file handles.

        A previous run (possibly one that failed part way) may have consumed
        or closed the handles, and the user may try again. The old handles
        are closed first so they are not leaked.
        """
        self.start_button.setEnabled(True)
        for attr_name in ('orders_csv', 'items_csv', 'refunds_csv'):
            prev_file = getattr(self.args, attr_name)
            if prev_file:
                # close() is a no-op on an already-closed file.
                prev_file.close()
                setattr(
                    self.args,
                    attr_name,
                    open(prev_file.name, 'r', encoding='utf-8'))

    def on_start_button_clicked(self):
        """Disable the start button and open a TaggerDialog for this run."""
        self.start_button.setEnabled(False)
        # If the fetch tab is selected for Amazon order history, clear out any
        # provided csv file paths, so the tagger actually fetches (versus using
        # the given paths). Work on a copy so self.args keeps the selection.
        args = argparse.Namespace(**vars(self.args))
        if self.fetch_amazon:
            for attr_name in ('orders_csv', 'items_csv', 'refunds_csv'):
                setattr(args, attr_name, None)

        self.tagger = TaggerDialog(
            args=args,
            parent=self.window)
        self.tagger.show()
        self.tagger.finished.connect(self.on_tagger_dialog_closed)

    def clear_layout(self, layout):
        """Recursively empty ``layout``, scheduling its widgets for deletion.

        A falsy ``layout`` (e.g. None) is ignored.
        """
        if layout:
            while layout.count():
                child = layout.takeAt(0)
                if child.widget() is not None:
                    child.widget().deleteLater()
                elif child.layout() is not None:
                    self.clear_layout(child.layout())

    def create_checkbox(self, name, tool_tip=None, invert=False):
        """Return a two-state QCheckBox bound to the boolean ``args.<name>``.

        Args:
            name: attribute of self.args to read and update.
            tool_tip: tool tip text. When omitted, it defaults to
                "When checked, <help>" if help text for ``name`` exists.
            invert: when true, the box is checked while the attribute is
                False, and vice versa.
        """
        x_box = QCheckBox()
        x_box.setTristate(False)
        # XOR with invert so an inverted box starts out consistent with the
        # value its change handler writes back.
        checked = bool(getattr(self.args, name)) != bool(invert)
        x_box.setCheckState(Qt.Checked if checked else Qt.Unchecked)
        if not tool_tip and name in self.arg_name_to_help:
            tool_tip = 'When checked, ' + self.arg_name_to_help[name]
        if tool_tip:
            x_box.setToolTip(tool_tip)

        def on_changed(state):
            setattr(
                self.args, name,
                state != Qt.Checked if invert else state == Qt.Checked)
        x_box.stateChanged.connect(on_changed)
        return x_box

    def advance_focus(self):
        """Move keyboard focus to the next widget in the main window."""
        self.window.focusNextChild()

    def create_line_edit(self, name, tool_tip=None, password=False):
        """Return a QLineEdit bound to the string ``args.<name>``.

        Pressing Return moves focus to the next widget. With
        ``password=True``, the text is masked except while being edited.
        """
        line_edit = QLineEdit(getattr(self.args, name))
        if not tool_tip:
            tool_tip = self._help_for(name)
        if tool_tip:
            line_edit.setToolTip(tool_tip)
        if password:
            line_edit.setEchoMode(QLineEdit.PasswordEchoOnEdit)

        def on_changed(text):
            setattr(self.args, name, text)

        def on_return():
            self.advance_focus()
        line_edit.textChanged.connect(on_changed)
        line_edit.returnPressed.connect(on_return)
        return line_edit

    def create_date_edit(
            self, name, popup_title, max_date=None,
            tool_tip=None):
        """Return a button showing ``args.<name>`` that opens a date picker.

        Args:
            name: attribute of self.args holding a ``datetime.date``.
            popup_title: window title of the calendar popup.
            max_date: latest selectable date. None (the default) means
                "today", evaluated each time the popup opens rather than
                once at import time.
            tool_tip: tool tip text; defaults to the argument's help text.
        """
        date_edit = QPushButton(str(getattr(self.args, name)))
        date_edit.setAutoDefault(True)
        if not tool_tip:
            tool_tip = self._help_for(name)
        if tool_tip:
            date_edit.setToolTip(tool_tip)

        def on_date_edit_clicked():
            latest = max_date if max_date is not None else (
                datetime.date.today())
            dlg = QDialog(self.window)
            dlg.setWindowTitle(popup_title)
            layout = QVBoxLayout()
            cal = QCalendarWidget()
            cal.setMaximumDate(QDate(latest))
            cal.setSelectedDate(QDate(getattr(self.args, name)))
            cal.selectionChanged.connect(lambda: dlg.accept())
            layout.addWidget(cal)
            okay = QPushButton('Select')
            okay.clicked.connect(lambda: dlg.accept())
            layout.addWidget(okay)
            dlg.setLayout(layout)
            dlg.exec()

            setattr(self.args, name, cal.selectedDate().toPyDate())
            date_edit.setText(str(getattr(self.args, name)))

        date_edit.clicked.connect(on_date_edit_clicked)
        return date_edit

    def create_file_edit(
            self, name, popup_title, filter='CSV files (*.csv)',
            tool_tip=None):
        """Return a button that picks a file and stores an open handle.

        The selected file is opened for reading (UTF-8) and stored on
        ``args.<name>``. Any previously selected handle is closed first.
        """
        file_button = QPushButton(
            'Select a file' if not getattr(self.args, name)
            else os.path.split(getattr(self.args, name).name)[1])

        if not tool_tip:
            tool_tip = self._help_for(name)
        if tool_tip:
            file_button.setToolTip(tool_tip)

        def on_button():
            # getOpenFileName is static; it returns (path, selected_filter).
            selection = QFileDialog.getOpenFileName(
                self.window, popup_title, '', filter)
            if selection[0]:
                prev_file = getattr(self.args, name)
                if prev_file:
                    prev_file.close()
                setattr(
                    self.args,
                    name,
                    open(selection[0], 'r', encoding='utf-8'))
                file_button.setText(os.path.split(selection[0])[1])

        file_button.clicked.connect(on_button)
        return file_button

    def create_combobox(self, name, items, transform, tool_tip=None):
        """Return a QComboBox whose choice is stored as ``transform(item)``.

        The initially selected item is the one whose transformed value
        equals the current ``args.<name>`` (if any), so a value given on the
        command line is reflected in the widget.
        """
        combo = QComboBox()
        combo.setFocusPolicy(Qt.StrongFocus)
        if not tool_tip:
            tool_tip = self._help_for(name)
        if tool_tip:
            combo.setToolTip(tool_tip)
        combo.addItems(items)

        # Select the current value before connecting the handler, so this
        # initial sync does not write anything back to args.
        current = getattr(self.args, name, None)
        for i, item in enumerate(items):
            if transform(item) == current:
                combo.setCurrentIndex(i)
                break

        def on_change(option):
            setattr(self.args, name, transform(option))
        combo.currentTextChanged.connect(on_change)
        return combo


class TaggerDialog(QDialog):
    """Modal dialog that drives one tagging run.

    A TaggerWorker does the slow work on a separate QThread and reports back
    via signals. This dialog shows progress, presents the proposed Mint
    updates for review, and asks the worker to send the selected ones.
    """

    # Column of the updates table that holds the Amazon order id.
    ORDER_ID_COLUMN = 5

    def __init__(self, args, **kwargs):
        """Build the UI and immediately start the worker thread.

        Args:
            args: argparse.Namespace passed through to the worker.
            **kwargs: forwarded to QDialog (e.g. ``parent``).
        """
        super(TaggerDialog, self).__init__(**kwargs)

        self.reviewing = False
        self.args = args

        # Build the widgets before the worker can report anything back.
        self.init_ui()

        self.worker = TaggerWorker()
        self.thread = QThread()
        self.worker.moveToThread(self.thread)

        self.worker.on_error.connect(self.on_error)
        self.worker.on_review_ready.connect(self.on_review_ready)
        self.worker.on_stopped.connect(self.on_stopped)
        self.worker.on_progress.connect(self.on_progress)
        self.worker.on_updates_sent.connect(self.on_updates_sent)
        self.worker.on_mint_mfa.connect(self.on_mint_mfa)

        # Let the thread's event loop exit once this dialog is closed, instead
        # of leaving the QThread running for the rest of the session.
        self.finished.connect(self.thread.quit)

        self.thread.started.connect(
            partial(self.worker.create_updates, args, self))
        self.thread.start()

    def init_ui(self):
        """Create the status label, busy progress bar and Cancel button."""
        self.setWindowTitle('Tagger is running...')
        self.setModal(True)
        self.v_layout = QVBoxLayout()
        self.setLayout(self.v_layout)

        self.label = QLabel()
        self.v_layout.addWidget(self.label)

        self.progress = 0
        self.progress_bar = QProgressBar()
        # A (0, 0) range shows an indeterminate "busy" bar.
        self.progress_bar.setRange(0, 0)
        self.v_layout.addWidget(self.progress_bar)

        self.button_bar = QHBoxLayout()
        self.v_layout.addLayout(self.button_bar)

        self.cancel_button = QPushButton('Cancel')
        self.button_bar.addWidget(self.cancel_button)
        self.cancel_button.clicked.connect(self.on_cancel)

    def on_error(self, msg):
        """Log ``msg``, show it in red and relabel Cancel as Close."""
        logger.error(msg)
        self.label.setText('Error: {}'.format(msg))
        self.label.setStyleSheet(
            'QLabel { color: red; font-weight: bold; }')
        # on_cancel already closes the dialog, so relabeling is enough.
        # Connecting close() here would stack a duplicate handler per error.
        self.cancel_button.setText('Close')

    def open_amazon_order_id(self, order_id):
        """Open the Amazon invoice page for ``order_id`` (ignored if empty)."""
        if order_id:
            QDesktopServices.openUrl(QUrl(
                amazon.get_invoice_url(order_id)))

    def on_activated(self, index):
        """Open the invoice when the order-id cell itself is clicked."""
        # Only handle clicks on the order_id cell.
        if index.column() != self.ORDER_ID_COLUMN:
            return
        order_id = self.updates_table_model.data(index, Qt.DisplayRole)
        self.open_amazon_order_id(order_id)

    def on_double_click(self, index):
        """Open the row's invoice when any other cell is double-clicked."""
        if index.column() == self.ORDER_ID_COLUMN:
            # Single clicks on the order_id cell are handled by on_activated.
            return
        order_id_cell = self.updates_table_model.createIndex(
            index.row(), self.ORDER_ID_COLUMN)
        order_id = self.updates_table_model.data(order_id_cell, Qt.DisplayRole)
        self.open_amazon_order_id(order_id)

    def on_review_ready(self, results):
        """Show the proposed updates plus stats buttons and "Send to Mint"."""
        self.reviewing = True
        self.progress_bar.hide()

        self.label.setText('Select below which updates to send to Mint.')

        self.updates_table_model = MintUpdatesTableModel(results.updates)
        self.updates_table = QTableView()
        self.updates_table.doubleClicked.connect(self.on_double_click)
        self.updates_table.clicked.connect(self.on_activated)

        def resize():
            self.updates_table.resizeColumnsToContents()
            self.updates_table.resizeRowsToContents()
            min_width = sum(
                self.updates_table.columnWidth(i) for i in range(6))
            self.updates_table.setMinimumSize(min_width + 20, 600)

        self.updates_table.setSelectionMode(QAbstractItemView.SingleSelection)
        self.updates_table.setSelectionBehavior(QAbstractItemView.SelectRows)
        self.updates_table.setModel(self.updates_table_model)
        self.updates_table.setSortingEnabled(True)
        resize()
        self.updates_table_model.layoutChanged.connect(resize)

        # Insert between the progress bar and the button bar.
        self.v_layout.insertWidget(2, self.updates_table)

        unmatched_button = QPushButton('View Unmatched Amazon orders')
        self.button_bar.addWidget(unmatched_button)
        unmatched_button.clicked.connect(
            partial(self.on_open_unmatched, results.unmatched_orders))

        amazon_stats_button = QPushButton('Amazon Stats')
        self.button_bar.addWidget(amazon_stats_button)
        amazon_stats_button.clicked.connect(
            partial(self.on_open_amazon_stats,
                    results.items,
                    results.orders,
                    results.refunds))

        tagger_stats_button = QPushButton('Tagger Stats')
        self.button_bar.addWidget(tagger_stats_button)
        tagger_stats_button.clicked.connect(
            partial(self.on_open_tagger_stats, results.stats))

        self.confirm_button = QPushButton('Send to Mint')
        self.button_bar.addWidget(self.confirm_button)
        self.confirm_button.clicked.connect(self.on_send)

        self.setGeometry(50, 50, self.width(), self.height())

    def on_updates_sent(self, num_sent):
        """Report how many transactions were tagged."""
        self.label.setText(
            'All done! {} newly tagged Mint transactions'.format(num_sent))
        self.cancel_button.setText('Close')

    def on_open_unmatched(self, unmatched):
        self.unmatched_dialog = AmazonUnmatchedTableDialog(unmatched)
        self.unmatched_dialog.show()

    def on_open_amazon_stats(self, items, orders, refunds):
        self.amazon_stats_dialog = AmazonStatsDialog(items, orders, refunds)
        self.amazon_stats_dialog.show()

    def on_open_tagger_stats(self, stats):
        self.tagger_stats_dialog = TaggerStatsDialog(stats)
        self.tagger_stats_dialog.show()

    def on_send(self):
        """Tear down the review table and ask the worker to send updates."""
        self.progress_bar.show()
        updates = self.updates_table_model.get_selected_updates()

        self.confirm_button.hide()
        self.updates_table.hide()
        self.confirm_button.deleteLater()
        self.updates_table.deleteLater()
        self.adjustSize()

        # Queued so send_updates runs on the worker's thread.
        QMetaObject.invokeMethod(
            self.worker, 'send_updates', Qt.QueuedConnection,
            Q_ARG(list, updates),
            Q_ARG(object, self.args))

    def on_stopped(self):
        self.close()

    def on_progress(self, msg, max, value):
        """Show ``msg`` and set the bar to ``value`` out of ``max``."""
        self.label.setText(msg)
        self.progress_bar.setRange(0, max)
        self.progress_bar.setValue(value)

    def on_cancel(self):
        """Close the dialog, first asking a still-running worker to stop."""
        if not self.reviewing:
            # The worker thread is typically busy inside create_updates, so a
            # queued 'stop' call would only run after the work finishes. Set
            # the flag directly so pending results are discarded. The queued
            # call is kept for consistency.
            self.worker.stopping = True
            QMetaObject.invokeMethod(
                self.worker, 'stop', Qt.QueuedConnection)
        self.close()

    def on_mint_mfa(self):
        """Prompt for the Mint MFA code and hand it to the waiting worker.

        The returned text is always delivered, even if the prompt was
        cancelled, so the worker is released from its wait.
        """
        # getText is static; no need to instantiate a throwaway dialog.
        mfa_code, _ok = QInputDialog.getText(
            self, 'Please enter your Mint Code.',
            'Mint Code:')
        # Set the attribute directly so the code is in place before the
        # "done" signal wakes the worker. The queued slot sets the same value.
        self.worker.mfa_code = mfa_code
        QMetaObject.invokeMethod(
            self.worker, 'mfa_code', Qt.QueuedConnection,
            Q_ARG(str, mfa_code))
        self.worker.on_mint_mfa_done.emit()


class TaggerWorker(QObject):
    """Does the slow tagging work off the Qt main thread.

    This class is required to prevent locking up the main Qt thread. It is
    moved onto a QThread and communicates with the dialog through the
    signals below.
    """
    on_error = pyqtSignal(str)
    on_review_ready = pyqtSignal(tagger.UpdatesResult)
    on_updates_sent = pyqtSignal(int)
    on_stopped = pyqtSignal()
    on_mint_mfa = pyqtSignal()
    on_mint_mfa_done = pyqtSignal()
    on_progress = pyqtSignal(str, int, int)
    stopping = False
    webdriver = None

    @pyqtSlot()
    def stop(self):
        """Flag that the current run's results should be discarded."""
        self.stopping = True

    @pyqtSlot(str)
    def mfa_code(self, code):
        """Receive the Mint MFA code entered by the user."""
        # Never log the code itself: it is a secret, and log output is also
        # written to a file. Note that this assignment shadows this slot on
        # the instance; do_create_updates reads the code from the attribute.
        logger.info('Received Mint MFA code')
        self.mfa_code = code

    # NOTE(review): declared as taking one object, but takes (args, parent).
    # It is connected through functools.partial, so the declared signature
    # appears unused; confirm before relying on it.
    @pyqtSlot(object)
    def create_updates(self, args, parent):
        """Run do_create_updates, reporting any exception via on_error."""
        try:
            self.do_create_updates(args, parent)
        except Exception as e:
            msg = 'Internal error while creating updates: {}'.format(e)
            self.on_error.emit(msg)
            logger.exception(msg)

    @pyqtSlot(list, object)
    def send_updates(self, updates, args):
        """Run do_send_updates, reporting any exception via on_error."""
        try:
            self.do_send_updates(updates, args)
        except Exception as e:
            msg = 'Internal error while sending updates: {}'.format(e)
            self.on_error.emit(msg)
            logger.exception(msg)

    def close_webdriver(self):
        """Close the cached webdriver, if any."""
        if self.webdriver:
            self.webdriver.close()
            self.webdriver = None

    def get_webdriver(self, args):
        """Return the cached webdriver, creating it on first use."""
        if self.webdriver:
            logger.info('Using existing webdriver')
            return self.webdriver
        logger.info('Creating a new webdriver')
        self.webdriver = get_webdriver(args.headless, args.session_path)
        return self.webdriver

    def do_create_updates(self, args, parent):
        """Fetch Amazon history, match it against Mint, emit the results.

        Emits on_review_ready on success, or on_stopped if a stop was
        requested while running.
        """
        def on_mint_mfa(prompt):
            # Block this worker thread in a local event loop until the dialog
            # signals that the user has entered a code.
            logger.info('Asking for Mint MFA')
            loop = QEventLoop()
            # Connect before asking, so the "done" signal cannot be missed.
            self.on_mint_mfa_done.connect(loop.quit)
            self.on_mint_mfa.emit()
            loop.exec_()
            return self.mfa_code

        # Factory that handles indeterminate, determinate, and counter style.
        def progress_factory(msg, max=0):
            return QtProgress(msg, max, self.on_progress.emit)

        atexit.register(self.close_webdriver)

        bound_webdriver_factory = partial(self.get_webdriver, args)
        self.mint_client = MintClient(
            args,
            bound_webdriver_factory,
            mfa_input_callback=on_mint_mfa)

        if not fetch_order_history(
                args, bound_webdriver_factory, progress_factory):
            self.on_error.emit(
                'Failed to fetch Amazon order history. Check credentials')
            return

        results = tagger.create_updates(
            args, self.mint_client,
            on_critical=self.on_error.emit,
            indeterminate_progress_factory=progress_factory,
            determinate_progress_factory=progress_factory,
            counter_progress_factory=progress_factory)

        if self.stopping:
            # Cancelled: let the dialog close instead of showing results.
            self.on_stopped.emit()
        elif results.success:
            self.on_review_ready.emit(results)

    def do_send_updates(self, updates, args):
        """Send ``updates`` to Mint, then report how many were applied."""
        try:
            num_updates = self.mint_client.send_updates(
                updates,
                progress=QtProgress(
                    'Sending updates to Mint',
                    len(updates),
                    self.on_progress.emit),
                ignore_category=args.no_tag_categories)
        finally:
            # Release the browser even if sending failed part way.
            self.close_webdriver()
        self.on_updates_sent.emit(num_updates)


def main():
    """Configure logging, parse command-line options and run the GUI."""
    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    root_logger.addHandler(logging.StreamHandler())
    # Disable noisy log spam from filelock from within tldextract.
    logging.getLogger("filelock").setLevel(logging.WARNING)
    # For helping remote debugging, also log to file.
    # Developers should be vigilant to NOT log any PII, ever (including being
    # mindful of what exceptions might be thrown).
    log_directory = os.path.join(TAGGER_BASE_PATH, 'Tagger Logs')
    os.makedirs(log_directory, exist_ok=True)
    log_filename = os.path.join(log_directory, '{}.log'.format(
        time.strftime("%Y-%m-%d_%H-%M-%S")))
    # Explicit UTF-8 so non-ASCII log text cannot fail under the platform's
    # default encoding (e.g. cp1252 on Windows).
    root_logger.addHandler(
        logging.FileHandler(log_filename, encoding='utf-8'))

    parser = argparse.ArgumentParser(
        description='Tag Mint transactions based on itemized Amazon history.')
    define_gui_args(parser)
    args = parser.parse_args()

    sys.exit(TaggerGui(args, get_name_to_help_dict(parser)).create_gui())


if __name__ == '__main__':
    main()
