# AUTOGENERATED! DO NOT EDIT! File to edit: nbs/02_shap.core.ipynb (unless otherwise specified).

# Empty export list: `from <this module> import *` brings in no names; the helpers defined below are underscore-prefixed (private)
__all__ = []

# Cell
from fastai.tabular.all import *

# Cell
def _prepare_data(dl:TabDataLoader, n_samples: Optional[int]=None):
    "Prepares dataloader data for `SHAP`"
    # Try to avoid concatenate big dataframes -> Sample dataframe before merging them
    if n_samples is not None and len(dl.cats) > n_samples:
        cats_df = dl.cats.sample(n=n_samples)
        conts_df = dl.conts.loc[cats_df.index]
    else:
        cats_df, conts_df = dl.cats, dl.conts

    return pd.merge(cats_df, conts_df, left_index=True, right_index=True)

# Cell
def _prepare_test_data(learn:Learner, test_data=None, n_samples:int=128):
    "Build the `DataFrame` of samples passed to `SHAP`, from `test_data` or from `learn`'s own dataloaders"
    # Accepted `test_data`:
    #   - `DataFrame`: processed through `learn.dls.test_dl` and used in full
    #   - `TabDataLoader`: used as-is and in full
    #   - `None`: the validation dataloader (or the training one when there is no
    #     validation dataloader), randomly subsampled to at most `n_samples` rows
    # Raises `ValueError` for any other input type.
    if isinstance(test_data, pd.DataFrame):
        dl = learn.dls.test_dl(test_data)
    elif isinstance(test_data, TabDataLoader):
        dl = test_data
    elif test_data is None:
        try:
            dl = learn.dls[1]
        except IndexError:
            print('No validation dataloader, using `train`')
            dl = learn.dls[0]
    else:
        # Name the type that is actually accepted (`TabDataLoader`) and report what was received
        raise ValueError(
            'Input is not supported. Please use either a `DataFrame` or `TabDataLoader`, '
            f'got `{type(test_data).__name__}`')

    # Data supplied explicitly by the user is never subsampled
    return _prepare_data(dl, n_samples if test_data is None else None)

# Cell
def _predict(learn:TabularLearner, data:np.array):
    "Predict function for some data on a fastai model"
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = learn.model.to(device)
    dl = learn.dls[0]
    nb_cat_cols = len(dl.dataset.cat_names)
    nb_cont_cols = len(dl.dataset.cont_names)
    x_cat = torch.from_numpy(data[:, :nb_cat_cols]).to(device, torch.int64)
    x_cont = torch.from_numpy(data[:, -nb_cont_cols:]).to(device, torch.float32)
    with torch.no_grad():
        pred_probs = learn.model(x_cat, x_cont).cpu().numpy()
    return pred_probs