"""
This module provides access to pre-trained synthesizability machine
learning models for predicting synthesizability of arbitrary crystal
compounds.
"""


import json
import logging
import os
import tarfile
import urllib.request
from pathlib import Path

import numpy as np
import pandas as pd
from matminer.featurizers.composition import Meredig, CohesiveEnergy
from matminer.featurizers.structure import DensityFeatures, GlobalSymmetryFeatures
from monty.serialization import loadfn, dumpfn
from pumml.learners import PULearner
from pymatgen import Composition
from pymatgen import Structure, Lattice
from pymatgen.ext.matproj import MPRester

# Route every record (DEBUG and above) emitted through the root logger by this
# module into a plain-text file in the current working directory.
_LOG_FORMAT = "%(levelname)s - %(message)s"
logging.basicConfig(level=logging.DEBUG, format=_LOG_FORMAT, filename="output.log")


class PUPredict:
    """Synthesizability predictor backed by pre-trained PU-learning models.

    Compounds are looked up in the Materials Project database. Compounds that
    already carry a score in the pre-trained data are reported directly; the
    remaining theoretical compounds are featurized with Matminer, appended to
    the matching (F-block or non F-block) training set and scored by
    re-training a ``PULearner``.
    """

    # FigShare location of the archive holding the pre-trained data.
    _DATA_URL = (
        "https://s3-eu-west-1.amazonaws.com/pfigshare-u-files/23873285/"
        "pupredict_data.tar.gz"
    )
    _ARCHIVE_NAME = "pupredict_data.tar.gz"
    # Directory (relative to ``model_dir``) the archive is extracted into.
    _EXTRACT_DIR = "Model_Data"
    # Path components (relative to ``model_dir``) of the extracted data files.
    _DATA_SUBDIR = ("Model_Data", "0384056001594797986", "PUPredict_Data")
    # Meredig features used as model inputs, paired with their position
    # counted from the end of ``Meredig().featurize(...)``'s output vector.
    _MEREDIG_FEATURES = (
        ("mean AtomicWeight", -17),
        ("range AtomicRadius", -12),
        ("mean AtomicRadius", -11),
        ("range Electronegativity", -10),
        ("mean Electronegativity", -9),
    )

    def __init__(self, api_key):
        """Predict synthesizability of all compounds in Materials Project
        Database.

        Features for training data are generated by first-principles
        (density functional theory) calculations, or structural or chemical
        data looked up from Materials Project Database.

        Two models each for F-Block compounds and Non F-Block compounds
        have been pre-trained which are used to predict synthesizability of
        compounds based on the block they belong to.

        The pre-trained data archive is downloaded into the current working
        directory and extracted into ``Model_Data/`` there. Both steps are
        skipped when their result already exists, so creating several
        instances does not re-download or re-extract the data.

        Args:
            api_key (str): API key for Materials Project API available from
                https://materialsproject.org/open.

        """
        self.api_key = api_key
        # All model data is resolved relative to the directory at creation time.
        self.model_dir = os.getcwd()

        if self._data_dir().is_dir():
            return  # already extracted by a previous instance

        archive = os.path.join(self.model_dir, self._ARCHIVE_NAME)
        if not os.path.exists(archive):
            # Download under a temporary name and rename on success so an
            # interrupted transfer never leaves a truncated archive behind.
            partial = archive + ".part"
            urllib.request.urlretrieve(self._DATA_URL, partial)
            os.replace(partial, archive)

        # The "data" filter (Python versions that provide it) refuses members
        # with absolute paths or links escaping the target directory.
        extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}
        with tarfile.open(archive) as tar:
            tar.extractall(
                os.path.join(self.model_dir, self._EXTRACT_DIR), **extract_kwargs
            )

    def _data_dir(self):
        """Return the directory that holds the extracted pre-trained files."""
        return Path(self.model_dir).joinpath(*self._DATA_SUBDIR)

    def _data_path(self, filename):
        """Return the full path of a pre-trained data file called ``filename``."""
        return self._data_dir() / filename

    def _mpr(self):
        """Get MPRester.

        Returns:
            MPRester: client authenticated with ``self.api_key``; callers use
                it as a context manager.

        """
        return MPRester(api_key=self.api_key)

    def _get_mp_id(self, updated_input):
        """Get the Materials Project IDs matching a query.

        Args:
            updated_input (dict): Query criteria passed to ``MPRester.query``.

        Returns:
            list: MP-IDs of every matching entry (possibly empty).

        """
        with self._mpr() as m:
            docs = m.query(updated_input, ["material_id"], chunk_size=0)
        return [doc["material_id"] for doc in docs]

    def _get_formula(self, updated_input):
        """Get the full formula of the first entry matching a query.

        Args:
            updated_input (dict): Query criteria passed to ``MPRester.query``.

        Returns:
            str: ``full_formula`` of the first matching entry.

        Raises:
            IndexError: if no entry matches ``updated_input``.

        """
        with self._mpr() as m:
            docs = m.query(updated_input, ["full_formula"], chunk_size=0)
        return docs[0]["full_formula"]

    def _get_data(self, mp_ids, chunk_size=1000):
        """Get Materials Project properties for a list of MP-IDs.

        Args:
            mp_ids (list): List of MP-IDs.
            chunk_size (int): Number of IDs requested per query.

        Returns:
            DataFrame: One row per returned entry with the queried properties
                plus a ``PU_label`` column: 0 when ``theoretical`` is True,
                1 otherwise. Has the same columns (and no rows) when
                ``mp_ids`` is empty.

        """
        # MPRester.supported_properties
        properties = [
            "theoretical",
            "energy_per_atom",
            "formation_energy_per_atom",
            "e_above_hull",
            "icsd_ids",
            "material_id",
            "structure",
        ]

        data = []
        if mp_ids:
            with self._mpr() as m:
                for start in range(0, len(mp_ids), chunk_size):
                    batch = mp_ids[start : start + chunk_size]
                    data += m.query(
                        {"material_id": {"$in": batch}}, properties=properties
                    )

        # Keep the expected columns even when nothing was returned.
        df_input = pd.DataFrame(data) if data else pd.DataFrame(columns=properties)

        # PU_label is 1 (0) if experimental crystal structure exists (doesn't exist)
        df_input["PU_label"] = (~df_input["theoretical"].eq(True)).astype(int)

        return df_input

    def _check_for_block(self, comp):
        """Classify a composition as F-block or non F-block.

        Args:
            comp (pymatgen.core.composition.Composition): Composition of the
                input sample.

        Returns:
            int: 0 if it contains an F-block element, 1 otherwise.

        """
        return 0 if comp.contains_element_type("f-block") else 1

    def _check_for_theoretical(self, updated_input):
        """Fetch the matching entries and drop the experimental ones.

        Every entry whose ``theoretical`` flag is False is logged and removed.

        Args:
            updated_input (dict): Query criteria passed to ``MPRester.query``.

        Returns:
            tuple: ``(mp_ids, df_input)`` -- all matching MP-IDs and a
                DataFrame of the remaining (theoretical) entries.

        """
        mp_id = self._get_mp_id(updated_input)
        df_input = self._get_data(mp_id)

        experimental = df_input["theoretical"].eq(False)
        for material_id in df_input.loc[experimental, "material_id"]:
            logging.info('The given input with material_id "%s" exists.', material_id)

        return mp_id, df_input[~experimental]

    @staticmethod
    def _split_pretrained(df_input, trained):
        """Separate entries already scored by a pre-trained model.

        Args:
            df_input (DataFrame): Entries with a ``material_id`` column.
            trained (DataFrame): Pre-trained results with ``material_id`` and
                ``synth_score`` columns.

        Returns:
            tuple: ``(scores, remaining)`` -- a list with one array of
                ``synth_score`` values per row of ``df_input`` found in
                ``trained`` (in row order), and the rows of ``df_input`` not
                found there.

        """
        known = set(trained["material_id"].values)
        scores = [
            trained["synth_score"][trained["material_id"] == mid].values
            for mid in df_input["material_id"].values
            if mid in known
        ]
        remaining = df_input[~df_input["material_id"].isin(known)]
        return scores, remaining

    def _check_if_already_exists(self, comp, updated_input):
        """Look up theoretical entries in the matching pre-trained model.

        Args:
            comp (pymatgen.core.composition.Composition): Composition of the
                input sample; selects the F-block or non F-block model.
            updated_input (dict): Query criteria passed to ``MPRester.query``.

        Returns:
            tuple: ``(fblock, mp_ids, synth_scores, df_input)`` -- the block
                flag (0 F-block, 1 non F-block), all matching MP-IDs, the
                pre-trained scores found, and the theoretical entries that
                still need to be scored.

        """
        mp_id, df_input = self._check_for_theoretical(updated_input)
        fblock = self._check_for_block(comp)

        model_file = (
            "fblock_trained_model.json"
            if fblock == 0
            else "non_fblock_trained_model.json"
        )
        trained = pd.read_json(self._data_path(model_file))
        synth_score, df_input = self._split_pretrained(df_input, trained)

        return fblock, mp_id, synth_score, df_input

    def _extract_features(self, df_input):
        """Featurize the ``structure`` column of ``df_input`` with Matminer.

        Args:
            df_input (DataFrame): Theoretical entries as built by
                ``_get_data``; it is not modified.

        Returns:
            DataFrame: Numeric features per entry, NaNs filled with column
                means, rows with non-positive cohesive energy dropped,
                ``PU_label`` as the last column and ``theoretical``,
                ``structure`` and ``icsd_ids`` removed.

        """
        # Work on a copy so the caller's frame is left untouched.
        df = df_input.drop(columns=["theoretical"])

        dfeat = DensityFeatures()
        symmfeat = GlobalSymmetryFeatures()
        mfeat = Meredig()
        cefeat = CohesiveEnergy()

        structures = list(df["structure"])
        # Run each featurizer once per structure and pick the needed entries.
        density = [dfeat.featurize(s) for s in structures]
        symmetry = [symmfeat.featurize(s) for s in structures]
        meredig = [mfeat.featurize(s.composition) for s in structures]

        df["density_test"] = [row[0] for row in density]
        df["vpa"] = [row[1] for row in density]
        df["packing fraction"] = [row[2] for row in density]
        df["spacegroup_num"] = [row[0] for row in symmetry]
        df["cohesive_energy"] = [
            cefeat.featurize(s.composition, formation_energy_per_atom=e_form)[0]
            for s, e_form in zip(structures, df["formation_energy_per_atom"])
        ]
        for column, position in self._MEREDIG_FEATURES:
            df[column] = [row[position] for row in meredig]

        df = df.drop(columns=["structure"])

        # ignore compounds that failed to featurize; only numeric columns have
        # a mean (material_id / icsd_ids would raise on pandas >= 2).
        df_extracted = df.fillna(df.mean(numeric_only=True)).query(
            "cohesive_energy > 0.0"
        )

        # Move PU_label to the last column and drop icsd_ids.
        ordered = [c for c in df_extracted.columns if c != "PU_label"] + ["PU_label"]
        return df_extracted[ordered].drop(columns=["icsd_ids"])

    def _append_data(self, df_extracted, fblock):
        """Append featurized inputs to the matching training data set.

        Args:
            df_extracted (DataFrame): Featurized inputs from
                ``_extract_features``.
            fblock (int): 0 for the F-block data set, anything else for the
                non F-block data set.

        Returns:
            tuple: ``(input_ids, df)`` -- MP-IDs of the inputs (numpy array)
                and the training features followed by the inputs, with a
                fresh 0..n-1 index.

        """
        input_ids = df_extracted["material_id"].values

        features_file = (
            "fblock_extracted_features.json"
            if fblock == 0
            else "non_fblock_extracted_features.json"
        )
        training = pd.read_json(self._data_path(features_file))
        # DataFrame.append was removed in pandas 2.0; concat is the replacement.
        df = pd.concat([training, df_extracted], ignore_index=True)

        return input_ids, df

    def _do_pulearner(self, input_ids, df, n_splits=10, n_repeats=3, n_bags=100):
        """Train a PULearner on ``df`` and return scores for ``input_ids``.

        ``df`` is written to ``test.json`` in the current working directory
        because ``PULearner.cv_baggingDT`` reads its data from a file.

        Args:
            input_ids (numpy array): MP-IDs of the compounds to score.
            df (DataFrame): Training features followed by the input features.
            n_splits (int): Number of k-fold cross-validation splits.
            n_repeats (int): Number of times the k-fold CV is repeated.
            n_bags (int): Number of bags for bootstrap aggregating.

        Returns:
            list: One array of synthesizability scores (between 0 and 1) per
                entry of ``input_ids``.

        """
        df.to_json("test.json")
        pul = PULearner()

        pu_stats = pul.cv_baggingDT(
            "test.json", splits=n_splits, repeats=n_repeats, bags=n_bags
        )
        df_unlabeled = pul.df_U.copy()
        df_unlabeled["synth_score"] = pu_stats["prob"]

        return [
            df_unlabeled.synth_score[df_unlabeled.material_id == mid].values
            for mid in input_ids
        ]

    def _score(self, comp, updated_input):
        """Score every theoretical entry matching ``updated_input``.

        Pre-trained scores are used when available; the remaining entries are
        featurized and scored by re-training.

        Returns:
            list: Pre-trained score arrays, followed (when re-training was
                needed) by one element holding the list of newly computed
                score arrays. Empty when every match is experimental.

        """
        fblock, _mp_id, synth_score, df_input = self._check_if_already_exists(
            comp, updated_input
        )

        if not df_input.empty:
            df_extracted = self._extract_features(df_input)
            input_ids, df = self._append_data(df_extracted, fblock)
            # Appended as a single element to keep the historical return shape.
            synth_score.append(self._do_pulearner(input_ids, df))

        if synth_score:
            logging.info("The synthesizability scores are %s", synth_score)
        else:
            logging.info("No theoretical compounds to score for %s", updated_input)
        return synth_score

    def synth_score_from_mpid(self, mpid):
        """
        Takes input in form of Materials Project ID and returns
        the synthesizability score for the same.

        Args:
            mpid (str): For example, 'mp-1234'

        Returns:
            synth_scores (list): Synthesizability scores (between 0 and 1) of
                given sample; empty if the compound is experimental. A message
                string is returned when the ID is not in the database.

        """
        updated_input = {"material_id": mpid}

        # Validate before fetching the formula, which needs a matching entry.
        if not self._get_mp_id(updated_input):
            return "No such compound exists in Materials Project Database"

        comp = Composition(self._get_formula(updated_input))
        return self._score(comp, updated_input)

    def synth_score_from_formula(self, formula):
        """
        Takes input in form of an unreduced crystal composition and returns
        the synthesizability score for the same.

        Args:
            formula (str): For example: 'Ba2Yb2Al4Si2N10O4'

        Returns:
            synth_scores (list): Synthesizability scores (between 0 and 1) of
                given sample; empty if all matches are experimental. A message
                string is returned when the formula is not in the database.

        """
        updated_input = {"full_formula": formula}
        comp = Composition(formula)

        if not self._get_mp_id(updated_input):
            return "No such compound exists in Materials Project Database"

        return self._score(comp, updated_input)

    def synth_score_from_structure(self, structure):
        """
        Takes input in form of Pymatgen Structure Object and returns the
        synthesizability score for the same.

        Args:
            structure (pymatgen.core.structure.Structure): structure input.

        Returns:
            synth_scores (list): Same as ``synth_score_from_formula`` for the
                structure's formula.

        """
        # The structure's formula with spaces removed is used as full_formula.
        formula = structure.formula.replace(" ", "")
        return self.synth_score_from_formula(formula)