-
Notifications
You must be signed in to change notification settings - Fork 3
/
sampler.py
66 lines (50 loc) · 2.07 KB
/
sampler.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import pandas as pd
import torch
import torch.utils.data
import torchvision
class Imbalanced_Dataset_Sampler(torch.utils.data.sampler.Sampler):
    """Sample dataset indices with replacement, weighted by inverse class frequency.

    Every index in ``range(len(dataset))`` gets weight ``1 / count(label)``,
    so each class is drawn with roughly equal probability. One pass over the
    sampler yields ``len(dataset)`` indices.

    Arguments:
        dataset: a sized dataset. When ``labels`` is None it must provide a
            ``get_multi_labels()`` method returning one label per sample.
        labels: optional sequence of hashable labels, one per sample. If None,
            labels are taken from ``dataset.get_multi_labels()``.

    Raises:
        ValueError: if the number of labels differs from ``len(dataset)``.
    """

    def __init__(self, dataset, labels=None):
        # Every element of the dataset is a candidate index.
        self.indices = list(range(len(dataset)))
        # Each iteration draws as many samples as there are indices.
        self.num_samples = len(self.indices)

        label_series = pd.Series(
            self._get_labels(dataset) if labels is None else labels
        )
        if len(label_series) != self.num_samples:
            raise ValueError(
                f"expected {self.num_samples} labels (one per sample), "
                f"got {len(label_series)}"
            )

        # Distribution of classes in the dataset.
        label_to_count = label_series.value_counts()
        # ``map`` looks each label's count up by value. Indexing
        # ``label_to_count[label_series]`` would instead be read as a boolean
        # mask when the labels are bools.
        weights = 1.0 / label_series.map(label_to_count)
        self.weights = torch.DoubleTensor(weights.to_list())

    def _get_labels(self, dataset):
        """Return per-sample labels from ``dataset.get_multi_labels()``."""
        return dataset.get_multi_labels()

    def __iter__(self):
        # torch.multinomial rejects an empty weight tensor, so an empty
        # dataset yields nothing instead of raising.
        if self.num_samples == 0:
            return iter(())
        draws = torch.multinomial(self.weights, self.num_samples, replacement=True)
        return (self.indices[i] for i in draws.tolist())

    def __len__(self):
        return self.num_samples

    def __call__(self):
        # Calling the sampler returns the sampler itself.
        return self
from torch.utils.data import WeightedRandomSampler
class Weighted_Random_Sampler:
    """Factory for a ``WeightedRandomSampler`` that rebalances classes.

    Each sample is weighted by ``len(dataset) / count(its label)``, so every
    class carries the same total weight. Calling the instance returns a fresh
    ``torch.utils.data.WeightedRandomSampler`` drawing ``len(dataset)``
    samples with replacement.

    Arguments:
        dataset: a sized dataset; only ``len(dataset)`` is used.
        labels: sequence of hashable class labels with at least
            ``len(dataset)`` entries. Class counts use every entry, but only
            the first ``len(dataset)`` samples receive weights.

    Attributes:
        class_counts: per-class sample counts, ordered by sorted class label
            (for labels ``0..K-1``, entry ``k`` is the count of class ``k``).
        class_weights: ``num_samples / count``, aligned with ``class_counts``.
        weights: per-sample weights for the first ``num_samples`` samples.

    Raises:
        ValueError: if fewer labels than ``len(dataset)`` are given.
    """

    def __init__(self, dataset, labels):
        self.num_samples = len(dataset)
        self.labels = labels
        label_series = pd.Series(labels)
        if len(label_series) < self.num_samples:
            raise ValueError(
                f"need at least {self.num_samples} labels, got {len(label_series)}"
            )

        # ``value_counts()`` orders by frequency. Sort by label so that
        # position k in the lists below corresponds to class k.
        counts = label_series.value_counts().sort_index()
        self.class_counts = counts.to_list()
        self.class_weights = [self.num_samples / c for c in self.class_counts]

        # Look each sample's count up by label value (not list position),
        # which also works for non-contiguous or non-integer labels.
        per_sample_counts = label_series.iloc[: self.num_samples].map(counts)
        self.weights = (self.num_samples / per_sample_counts).to_list()

    def __call__(self):
        return WeightedRandomSampler(torch.DoubleTensor(self.weights), int(self.num_samples))