#!/usr/bin/env python
import numpy as np
import os
import sys
import parsnip
import time


if __name__ == '__main__':
    # Command-line entry point: build (or overwrite) a ParSNIP model at
    # `model_path`, fit it on the given dataset(s), score it, and append a
    # one-line summary of the run to './parsnip_results.log'.
    t_start = time.time()

    cli = parsnip.build_default_argparse('Train a ParSNIP model.')

    # Required positionals: the model path followed by one or more datasets.
    cli.add_argument('model_path')
    cli.add_argument('dataset_paths', nargs='+')

    # Options controlling the training run.
    cli.add_argument('--overwrite', action='store_true')
    cli.add_argument('--max_epochs', default=1000, type=int)
    cli.add_argument('--split_train_test', action='store_true')
    cli.add_argument('--bands', default=None)

    # Options forwarded to the model constructor as `device` and `threads`.
    cli.add_argument('--device', default='cuda')
    cli.add_argument('--threads', type=int, default=8)

    # Keep the parsed arguments as a plain dict; the same dict is handed to the
    # model as its settings further down.
    args = vars(cli.parse_args())

    # Bail out early if something already exists at the model path, unless the
    # user explicitly asked for it to be overwritten.
    model_path = args['model_path']
    if os.path.exists(model_path):
        if not args['overwrite']:
            print(f"Model '{model_path}' already exists, skipping!")
            sys.exit()
        print(f"Model '{model_path}' already exists, overwriting!")

    # Read in every dataset that was listed on the command line.
    dataset = parsnip.load_datasets(args['dataset_paths'])

    # Band selection: a comma-separated `--bands` value wins; without one, ask
    # parsnip which bands the loaded dataset provides. The 'bands' key is
    # removed from `args` here, so it is not part of the settings dict below.
    requested_bands = args.pop('bands')
    bands = (
        parsnip.get_bands(dataset)
        if requested_bands is None
        else requested_bands.split(',')
    )

    model = parsnip.ParsnipModel(
        model_path,
        bands,
        device=args['device'],
        threads=args['threads'],
        settings=args,
        ignore_unknown_settings=True,
    )

    dataset = model.preprocess(dataset)

    # Decide what to train on. A test set is only handed to `fit` when the
    # train/test split was requested.
    use_split = args['split_train_test']
    fit_options = {'max_epochs': args['max_epochs']}
    if use_split:
        train_dataset, test_dataset = parsnip.split_train_test(dataset)
        fit_options['test_dataset'] = test_dataset
    else:
        train_dataset = dataset
    model.fit(train_dataset, **fit_options)

    # Save the score to a file for quick comparisons. Small datasets get more
    # scoring rounds: enough rounds for roughly 25000 light-curve evaluations.
    rounds_needed = np.ceil(25000 / len(train_dataset))
    rounds = int(rounds_needed)

    train_score = model.score(train_dataset, rounds=rounds)
    # -1 is written in place of a test score when no test set exists.
    test_score = (
        model.score(test_dataset, rounds=10 * rounds) if use_split else -1.
    )

    # Wall-clock duration of the whole run, in minutes.
    elapsed_time = (time.time() - t_start) / 60.

    # Log columns: model path, final epoch, minutes elapsed, train score,
    # test score.
    summary = (f'{model_path} {model.epoch} {elapsed_time:.2f} {train_score:.4f} '
               f'{test_score:.4f}')
    with open('./parsnip_results.log', 'a') as f:
        f.write(summary + '\n')
