# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/01_download.ipynb (unless otherwise specified).

__all__ = ['BaseIO', 'BaseDownloader', 'NumeraiClassicDownloader', 'KaggleDownloader', 'PandasDataReader',
           'AwesomeCustomDownloader']

# Cell
import os
import glob
import json
import shutil
import pandas as pd
from tqdm.auto import tqdm
from rich.tree import Tree
from functools import partial
from numerapi import NumerAPI
import pandas_datareader as web
import matplotlib.pyplot as plt
from google.cloud import storage
from rich.console import Console
from typeguard import typechecked
from datetime import datetime as dt
from pathlib import Path, PosixPath
from abc import ABC, abstractmethod
from rich import print as rich_print
from dateutil.relativedelta import relativedelta
from pandas_datareader._utils import RemoteDataError

from .numerframe import NumerFrame

# Cell
@typechecked
class BaseIO(ABC):
    """
    Basic functionality for IO (downloading and uploading).

    Manages a local base directory and offers helpers to move files and
    directories between that directory and Google Cloud Storage (GCS).

    :param directory_path: Base folder for IO. Will be created if it does not exist.
    """
    def __init__(self, directory_path: str):
        self.dir = Path(directory_path)
        self._create_directory()

    def remove_base_directory(self):
        """Remove directory with all contents."""
        abs_path = self.dir.resolve()
        rich_print(
            f":warning: [red]Deleting directory for '{self.__class__.__name__}[/red]' :warning:\nPath: '{abs_path}'"
        )
        shutil.rmtree(abs_path)

    def download_file_from_gcs(self, bucket_name: str, gcs_path: str):
        """
        Get file from GCS bucket and download to local directory.

        The object is saved in the base directory under its own file name
        (the last '/'-separated component of 'gcs_path').

        :param bucket_name: Name of the GCS bucket.
        :param gcs_path: Path to file on GCS bucket.
        """
        local_dir = self.dir.resolve()
        # GCS object names always use '/' as separator, independent of the local OS.
        local_file = local_dir / gcs_path.rstrip("/").split("/")[-1]
        # The blob must be addressed by its GCS path; the local path is only the target.
        blob = self._get_gcs_blob(bucket_name=bucket_name, blob_path=gcs_path)
        blob.download_to_filename(str(local_file))
        rich_print(
            f":cloud: :page_facing_up: Downloaded GCS object '{gcs_path}' from bucket '{blob.bucket.id}' to local directory '{local_dir}'. :page_facing_up: :cloud:"
        )

    def upload_file_to_gcs(self, bucket_name: str, gcs_path: str, local_path: str):
        """
        Upload file to some GCS bucket.

        :param bucket_name: Name of the GCS bucket.
        :param gcs_path: Path to file on GCS bucket.
        :param local_path: Path of the local file to upload.
        """
        blob = self._get_gcs_blob(bucket_name=bucket_name, blob_path=gcs_path)
        blob.upload_from_filename(local_path)
        rich_print(
            f":cloud: :page_facing_up: Local file '{local_path}' uploaded to '{gcs_path}' in bucket {blob.bucket.id}:page_facing_up: :cloud:"
        )

    def download_directory_from_gcs(self, bucket_name: str, gcs_path: str):
        """
        Copy full directory from GCS bucket to local environment.

        Every object whose name starts with 'gcs_path/' is downloaded into the
        base directory, keeping its path relative to 'gcs_path'.

        :param bucket_name: Name of the GCS bucket.
        :param gcs_path: Name of directory on GCS bucket.
        """
        local_dir = self.dir.resolve()
        bucket = self._get_gcs_bucket(bucket_name)
        # Normalize to exactly one trailing slash so 'data' does not also match 'data_old/...'.
        prefix = gcs_path.rstrip("/")
        prefix = f"{prefix}/" if prefix else ""
        for blob in bucket.list_blobs(prefix=prefix):
            relative_name = blob.name[len(prefix):]
            # Skip the prefix itself and zero-byte "folder" placeholder objects.
            if not relative_name or relative_name.endswith("/"):
                continue
            local_file = local_dir.joinpath(*relative_name.split("/"))
            local_file.parent.mkdir(parents=True, exist_ok=True)
            blob.download_to_filename(str(local_file))
        rich_print(
            f":cloud: :folder: Directory '{gcs_path}' from bucket '{bucket.id}' downloaded to '{local_dir}' :folder: :cloud:"
        )

    def upload_directory_to_gcs(self, bucket_name: str, gcs_path: str):
        """
        Upload full base directory to GCS bucket.

        Each local file becomes its own object named
        'gcs_path/<path relative to the base directory>'.

        :param bucket_name: Name of the GCS bucket.
        :param gcs_path: Name of directory on GCS bucket.
        """
        bucket = self._get_gcs_bucket(bucket_name)
        remote_root = gcs_path.rstrip("/")
        for local_path in glob.glob(str(self.dir) + "/**", recursive=True):
            if os.path.isfile(local_path):
                # One blob per file; a single shared blob would be overwritten on every upload.
                relative_name = Path(os.path.relpath(local_path, self.dir)).as_posix()
                blob_name = f"{remote_root}/{relative_name}" if remote_root else relative_name
                bucket.blob(blob_name).upload_from_filename(local_path)
        rich_print(
            f":cloud: :folder: Directory '{self.dir}' uploaded to '{gcs_path}' in bucket {bucket.id} :folder: :cloud:"
        )

    def _get_gcs_bucket(self, bucket_name: str) -> storage.Bucket:
        """ Create client and retrieve the Google Cloud Storage (GCS) bucket. """
        client = storage.Client()
        # https://console.cloud.google.com/storage/browser/[bucket_name]
        return client.get_bucket(bucket_name)

    def _get_gcs_blob(self, bucket_name: str, blob_path: str) -> storage.Blob:
        """ Create blob that interacts with Google Cloud Storage (GCS). """
        bucket = self._get_gcs_bucket(bucket_name)
        blob = bucket.blob(blob_path)
        return blob

    def _append_folder(self, folder: str) -> Path:
        """
        Return base directory Path object appended with 'folder'.
        Create directory if it does not exist.
        """
        folder_path = Path(self.dir / folder)
        folder_path.mkdir(parents=True, exist_ok=True)
        return folder_path

    def _create_directory(self):
        """ Create base directory if it does not exist. """
        if not self.dir.is_dir():
            rich_print(
                f"No existing directory found at '[blue]{self.dir}[/blue]'. Creating directory..."
            )
            self.dir.mkdir(parents=True, exist_ok=True)

    @property
    def get_all_files(self) -> list:
        """ Return all paths of contents in directory. """
        return list(self.dir.iterdir())

    @property
    def is_empty(self) -> bool:
        """ Check if directory is empty. """
        return not bool(self.get_all_files)


# Cell
@typechecked
class BaseDownloader(BaseIO):
    """
    Common interface that every concrete downloader implements.

    :param directory_path: Base folder to download files to.
    """
    def __init__(self, directory_path: str):
        super().__init__(directory_path=directory_path)

    @abstractmethod
    def download_training_data(self, *args, **kwargs):
        """ Fetch every file that is required for training. """
        pass

    @abstractmethod
    def download_inference_data(self, *args, **kwargs):
        """ Fetch the smallest set of files required for weekly inference. """
        pass

    @staticmethod
    def _load_json(file_path: str, verbose=False, *args, **kwargs) -> dict:
        """
        Parse the JSON file at 'file_path' and return its contents.

        :param file_path: Location of the JSON file.
        :param verbose: Print the parsed contents when truthy.
        *args, **kwargs will be passed to json.load.
        """
        json_path = Path(file_path)
        with json_path.open() as handle:
            contents = json.load(handle, *args, **kwargs)
        if not verbose:
            return contents
        rich_print(contents)
        return contents

    def __call__(self, *args, **kwargs):
        """
        Calling a downloader instance is shorthand for downloading inference data,
        which is expected to be the most common use case.
        All arguments are forwarded to download_inference_data.
        """
        self.download_inference_data(*args, **kwargs)

# Cell
class NumeraiClassicDownloader(BaseDownloader):
    """
    WARNING: Versions 1 and 2 (legacy data) are deprecated. Only supporting version 3+.

    Downloading from NumerAPI for Numerai Classic data. \n
    :param directory_path: Base folder to download files to. \n
    All *args, **kwargs will be passed to NumerAPI initialization.
    """
    def __init__(self, directory_path: str, *args, **kwargs):
        super().__init__(directory_path=directory_path)
        self.napi = NumerAPI(*args, **kwargs)
        self.current_round = self.napi.get_current_round()
        # NumerAPI filenames corresponding to version, class and data type
        self.version_mapping = {
            "3": {
                "train": {
                    "int8": [
                        "v3/numerai_training_data_int8.parquet",
                        "v3/numerai_validation_data_int8.parquet",
                    ],
                    "float": [
                        "v3/numerai_training_data.parquet",
                        "v3/numerai_validation_data.parquet",
                    ],
                },
                "inference": {
                    "int8": ["v3/numerai_tournament_data_int8.parquet"],
                    "float": ["v3/numerai_tournament_data.parquet"],
                },
                "live": {
                    "int8": ["v3/numerai_live_data_int8.parquet"],
                    "float": ["v3/numerai_live_data.parquet"],
                },
                "example": [
                    "v3/example_predictions.parquet",
                    "v3/example_validation_predictions.parquet",
                ],
            },
            "4": {
                "train": {
                    "int8": [
                        "v4/train_int8.parquet",
                        "v4/validation_int8.parquet",
                    ],
                    "float": [
                        "v4/train.parquet",
                        "v4/validation.parquet",
                    ],
                },
                "inference": {
                    "int8": ["v4/live_int8.parquet"],
                    "float": ["v4/live.parquet"],
                },
                "live": {
                    "int8": ["v4/live_int8.parquet"],
                    "float": ["v4/live.parquet"],
                },
                "example": [
                    "v4/live_example_preds.parquet",
                    "v4/validation_example_preds.parquet",
                ],
            },
        }

    def download_training_data(
        self, subfolder: str = "", version: int = 4, int8: bool = False
    ):
        """
        Get Numerai classic training and validation data.
        :param subfolder: Specify folder to create folder within base directory root.
        Saves in base directory root by default.
        :param version: Numerai dataset version. Must be a key in 'version_mapping' (3 or 4).
        :param int8: Integer version of data
        """
        data_type = "int8" if int8 else "float"
        train_val_files = self._get_version_mapping(version)["train"][data_type]
        self._download_files(train_val_files, subfolder=subfolder)

    def download_inference_data(
        self,
        subfolder: str = "",
        version: int = 4,
        int8: bool = False,
        round_num: int = None,
    ):
        """
        Get Numerai classic inference (tournament) data.
        If only minimal live data is needed, consider .download_live_data.
        :param subfolder: Specify folder to create folder within base directory root.
        Saves in base directory root by default.
        :param version: Numerai dataset version. Must be a key in 'version_mapping' (3 or 4).
        :param int8: Integer version of data
        :param round_num: Numerai tournament round number. Downloads latest round by default.
        """
        data_type = "int8" if int8 else "float"
        inference_files = self._get_version_mapping(version)["inference"][data_type]
        self._download_files(inference_files, subfolder=subfolder, round_num=round_num)

    def download_single_dataset(
        self, filename: str, dest_path: str, round_num: int = None
    ):
        """
        Download one of the available datasets through NumerAPI.

        :param filename: Name as listed in NumerAPI (Check NumerAPI().list_datasets() for full overview)
        :param dest_path: Full path where file will be saved.
        :param round_num: Numerai tournament round number. Downloads latest round by default.
        """
        rich_print(
            f":file_folder: [green]Downloading[/green] '{filename}' :file_folder:"
        )
        self.napi.download_dataset(
            filename=filename,
            dest_path=dest_path,
            round_num=round_num
        )

    def download_live_data(
            self,
            subfolder: str = "",
            version: int = 4,
            int8: bool = False,
            round_num: int = None
    ):
        """
        Download all live data in specified folder for given version (i.e. minimal data needed for inference).

        :param subfolder: Specify folder to create folder within directory root.
        Saves in directory root by default.
        :param version: Numerai dataset version. Must be a key in 'version_mapping' (3 or 4).
        :param int8: Integer version of data
        :param round_num: Numerai tournament round number. Downloads latest round by default.
        """
        data_type = "int8" if int8 else "float"
        live_files = self._get_version_mapping(version)["live"][data_type]
        self._download_files(live_files, subfolder=subfolder, round_num=round_num)

    def download_example_data(
        self, subfolder: str = "", version: int = 4, round_num: int = None
    ):
        """
        Download all example prediction data in specified folder for given version.

        :param subfolder: Specify folder to create folder within base directory root.
        Saves in base directory root by default.
        :param version: Numerai dataset version. Must be a key in 'version_mapping' (3 or 4).
        :param round_num: Numerai tournament round number. Downloads latest round by default.
        """
        example_files = self._get_version_mapping(version)["example"]
        self._download_files(example_files, subfolder=subfolder, round_num=round_num)

    def get_classic_features(self, subfolder: str = "", filename="v4/features.json", *args, **kwargs) -> dict:
        """
        Download feature overview (stats and feature sets) through NumerAPI and load as dict.
        :param subfolder: Specify folder to create folder within base directory root.
        Saves in base directory root by default.
        :param filename: NumerAPI name of the feature overview file.
        Defaults to 'v4/features.json'.
        *args, **kwargs will be passed to the JSON loader.
        """
        dest_path = self.__get_dest_path(subfolder, filename)
        self.download_single_dataset(filename=filename,
                                     dest_path=dest_path)
        json_data = self._load_json(dest_path, *args, **kwargs)
        return json_data

    def _download_files(self, filenames: list, subfolder: str, round_num: int = None):
        """
        Download each NumerAPI file in 'filenames' into 'subfolder' of the base directory.

        :param filenames: NumerAPI dataset names to download.
        :param subfolder: Folder within base directory root ("" for the root itself).
        :param round_num: Numerai tournament round number. Downloads latest round by default.
        """
        for file in filenames:
            dest_path = self.__get_dest_path(subfolder, file)
            self.download_single_dataset(
                filename=file,
                dest_path=dest_path,
                round_num=round_num
            )

    def _get_version_mapping(self, version: int) -> dict:
        """
        Check if data version is supported and return file mapping for version.

        :param version: Numerai dataset version to look up.
        :raises NotImplementedError: If 'version' is not a key in 'version_mapping'.
        """
        try:
            mapping_dictionary = self.version_mapping[str(version)]
        except KeyError as err:
            # Chain the lookup failure so the original cause stays visible in tracebacks.
            raise NotImplementedError(
                f"Version '{version}' is not available. Available versions are {list(self.version_mapping.keys())}"
            ) from err
        return mapping_dictionary

    def __get_dest_path(self, subfolder: str, filename: str) -> str:
        """
        Prepare destination path for downloading.

        Creates 'subfolder' if needed and drops the version prefix
        (everything before the last '/') from 'filename'.
        """
        folder_path = self._append_folder(subfolder)
        dest_path = str(folder_path.joinpath(filename.split("/")[-1]))
        return dest_path

# Cell
class KaggleDownloader(BaseDownloader):
    """
    Download awesome financial data from Kaggle.

    For authentication, make sure you have a directory called .kaggle in your home directory
    with therein a kaggle.json file. kaggle.json should have the following structure: \n
    `{"username": USERNAME, "key": KAGGLE_API_KEY}` \n
    More info on authentication: github.com/Kaggle/kaggle-api#api-credentials \n

    More info on the Kaggle Python API: kaggle.com/donkeys/kaggle-python-api \n

    :param directory_path: Base folder to download files to.
    """
    def __init__(self, directory_path: str):
        # Fail early, before creating the base directory, if kaggle cannot be imported.
        self.__check_kaggle_import()
        super().__init__(directory_path=directory_path)

    def download_inference_data(self, kaggle_dataset_path: str):
        """
        Download arbitrary Kaggle dataset.
        Delegates to download_training_data.
        :param kaggle_dataset_path: Path on Kaggle (URL slug on kaggle.com/)
        """
        self.download_training_data(kaggle_dataset_path)

    def download_training_data(self, kaggle_dataset_path: str):
        """
        Download arbitrary Kaggle dataset and unzip it into the base directory.
        :param kaggle_dataset_path: Path on Kaggle (URL slug on kaggle.com/)
        """
        # Imported lazily: the import in __check_kaggle_import is local to that method.
        import kaggle
        kaggle.api.dataset_download_files(kaggle_dataset_path,
                                          path=self.dir, unzip=True)

    @staticmethod
    def __check_kaggle_import():
        """
        Verify that the kaggle package can be imported.

        :raises OSError: If importing kaggle raises OSError. The message points
        to the '.kaggle' folder in the current user's home directory.
        """
        try:
            import kaggle  # noqa: F401
        except OSError as err:
            # Report the current user's home instead of a hard-coded machine-specific path.
            kaggle_dir = Path.home() / ".kaggle"
            raise OSError(f"Could not find kaggle.json credentials. Make sure it's located in {kaggle_dir}. Or use the environment method. Check github.com/Kaggle/kaggle-api#api-credentials for more information on authentication.") from err

# Cell
class PandasDataReader(BaseDownloader):
    """
    Download financial data using Pandas Datareader.

    :param directory_path: Base folder to download files to. \n
    :param tickers: list of tickers used for downloading. \n
    :param backend: Data provider you want to use. Yahoo Finance by default. \n
    Check pydata.github.io/pandas-datareader/stable/readers/index.html to see all data readers.
    """
    def __init__(self, directory_path: str, tickers: list, backend: str = 'yahoo'):
        super().__init__(directory_path=directory_path)
        self.tickers = tickers
        self.backend = backend
        # Fixed at construction time; used as end date for every download.
        self.current_time = dt.now()

    def download_inference_data(self, save_path: str = None, *args, **kwargs):
        """
        Download a year of data and store it as Parquet.
        :param save_path: Path for Parquet file. Defaults to a file in the base directory.
        *args, **kwargs will be passed to the pandas_datareader DataReader call.
        """
        start = self.current_time - relativedelta(years=1)
        # 'start' is passed positionally so extra positional args cannot collide with it.
        dataf = self._get_all_ticker_data(start, *args, **kwargs)
        save_path = save_path if save_path else self.__format_default_save_path(start)
        dataf.to_parquet(save_path)

    def download_training_data(self, start: dt, save_path: str = None, *args, **kwargs):
        """
        Download full training dataset with given start_date.
        :param start: datetime object defining starting date.
        :param save_path: Path for Parquet file. Defaults to a file in the base directory.
        *args, **kwargs will be passed to the pandas_datareader DataReader call.
        """
        dataf = self._get_all_ticker_data(start, *args, **kwargs)
        save_path = save_path if save_path else self.__format_default_save_path(start)
        dataf.to_parquet(save_path)

    def download_live_data(self, save_path: str = None, *args, **kwargs):
        """
        Download a month of data and store it as Parquet.
        :param save_path: Path for Parquet file. Defaults to a file in the base directory.
        *args, **kwargs will be passed to the pandas_datareader DataReader call.
        """
        start = self.current_time - relativedelta(months=1)
        save_path = save_path if save_path else self.__format_default_save_path(start)
        dataf = self.get_live_data(*args, **kwargs)
        dataf.to_parquet(save_path)

    def get_live_data(self, *args, **kwargs) -> NumerFrame:
        """
        Get a month of data as NumerFrame.
        *args, **kwargs will be passed to the pandas_datareader DataReader call.
        """
        start = self.current_time - relativedelta(months=1)
        return NumerFrame(self._get_all_ticker_data(start, *args, **kwargs))

    def _get_all_ticker_data(self, start: dt, *args, **kwargs) -> pd.DataFrame:
        """
        Get data for all tickers defined in class using given starting date.
        Tickers for which the backend raises RemoteDataError are skipped with a warning.

        :param start: datetime object defining starting date.
        :raises ValueError: If no data could be retrieved for any ticker.
        """
        results = []
        for tick in tqdm(self.tickers):
            try:
                # Positional call: mixing 'ticker=' / 'start=' keywords with *args
                # would raise "multiple values for argument" for any extra positional.
                res = self.__get_ticker_data(tick, start, *args, **kwargs)
            except RemoteDataError:
                rich_print(f":warning: WARNING: No data found for ticker: [red]'{tick}'[/red]. :warning:")
                continue
            results.append(res)
        if not results:
            # pd.concat would raise an opaque "No objects to concatenate" ValueError here.
            raise ValueError(
                f"No data could be retrieved for any of the tickers {self.tickers} using backend '{self.backend}'."
            )
        dataf = pd.concat(results)
        return dataf

    def __get_ticker_data(self, ticker: str, start: dt, *args, **kwargs) -> pd.DataFrame:
        """
        Get data for a single ticker from 'start' until 'current_time'.
        Adds a 'ticker' column and exposes the index as a 'date' column.
        """
        dataf = web.DataReader(ticker, self.backend, start, self.current_time, *args, **kwargs)
        dataf['ticker'] = ticker
        dataf.index.names = ['date']
        dataf = dataf.reset_index(drop=False)
        return dataf

    def __format_default_save_path(self, start: dt):
        """ Build '<dir>/<backend>_<start>_<end>.parquet' with dates formatted as YYYYMMDD. """
        return f"{self.dir}/{self.backend}_{start.strftime('%Y%m%d')}_{self.current_time.strftime('%Y%m%d')}.parquet"

# Cell
class AwesomeCustomDownloader(BaseDownloader):
    """
    TEMPLATE -
    Skeleton for a downloader that fetches financial data from a custom source.
    Fill in both download methods to implement it.

    :param directory_path: Base folder to download files to.
    """
    def __init__(self, directory_path: str):
        super().__init__(directory_path=directory_path)

    def download_inference_data(self, *args, **kwargs):
        """ Placeholder: implement the (minimal) weekly inference download here. """
        return None

    def download_training_data(self, *args, **kwargs):
        """ Placeholder: implement the training + validation download here. """
        return None