# AUTOGENERATED! DO NOT EDIT! File to edit: notebooks/data_loader.ipynb (unless otherwise specified).

# Names exported by a star-import of this module.
__all__ = ["DataLoader", "create_weighted_random_sampler", "create_sample_weights_from_multilabels", "lookup"]

# Cell
from collections import Counter

from torch.utils.data import DataLoader as PytorchDataLoader
from torch.utils.data import Dataset, WeightedRandomSampler
from tqdm import tqdm


class DataLoader(PytorchDataLoader):
    """Subclass of ``torch.utils.data.DataLoader`` that adds nothing of its own.

    The class body is empty, so construction arguments, iteration and every
    other behavior are inherited unchanged from the PyTorch base class.
    """

    pass


def create_weighted_random_sampler(weights, n_samples=None, replacement=True):
    """Build a ``WeightedRandomSampler`` from per-sample weights.

    Args:
        weights: Sequence of sampling weights, one entry per sample. It must
            support ``len()`` when ``n_samples`` is not given.
        n_samples: Number of indices the sampler draws. When ``None``, one
            draw per weight is used, i.e. ``len(weights)``.
        replacement: Forwarded to the sampler; whether indices are drawn with
            replacement.

    Returns:
        The configured ``WeightedRandomSampler``.
    """
    draw_count = len(weights) if n_samples is None else n_samples
    return WeightedRandomSampler(
        weights,
        num_samples=draw_count,
        replacement=replacement,
    )


def create_sample_weights_from_multilabels(multilabels, no_label_sample_factor=1):
    """Compute one sampling weight per sample from its set of labels.

    Each label is weighted by the inverse of how often it occurs across all
    samples. A sample's weight is the mean of the weights of its labels, so
    samples carrying rare labels receive larger weights.

    Args:
        multilabels: Sized iterable with one collection of hashable labels
            per sample. Iteration is wrapped in ``tqdm``.
        no_label_sample_factor: Exponent used for samples that have no labels
            at all; such a sample gets ``1 / n_samples ** no_label_sample_factor``.

    Returns:
        A list of weights, one per sample, in the order of ``multilabels``.
    """
    total_samples = len(multilabels)

    # Keep our own reference to every sample's labels while scanning the input.
    per_sample_labels = []
    for sample_labels in tqdm(multilabels):
        per_sample_labels.append(sample_labels)

    # Number of occurrences of each label over the whole data set.
    occurrences = Counter(
        label for sample_labels in per_sample_labels for label in sample_labels
    )
    inverse_frequency = {label: 1 / count for label, count in occurrences.items()}

    weights = []
    for sample_labels in per_sample_labels:
        label_count = len(sample_labels)
        if label_count != 0:
            summed = sum(inverse_frequency[label] for label in sample_labels)
            weights.append(summed / label_count)
        else:
            # Evaluated only for unlabeled samples, so an empty input never
            # reaches this division.
            weights.append(1 / total_samples ** no_label_sample_factor)
    return weights


# Registry mapping each public name (as a string) to the object it names.
# NOTE(review): presumably used to resolve these objects from string
# identifiers (e.g. configuration) — confirm against callers.
lookup = {
    "DataLoader": DataLoader,
    "create_weighted_random_sampler": create_weighted_random_sampler,
    "create_sample_weights_from_multilabels": create_sample_weights_from_multilabels,
}
