-
Notifications
You must be signed in to change notification settings - Fork 1
/
dataset.py
72 lines (65 loc) · 2.76 KB
/
dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import numpy as np
from torch.utils.data.dataset import Dataset
import torch
class YinYangDataset(Dataset):
    """Synthetic 2-D "Yin Yang" classification dataset.

    Points are drawn uniformly from the disc of radius ``r_big`` centred at
    ``(r_big, r_big)`` and labelled by region (see :meth:`which_class`).
    Each item is a tuple ``(features, label)`` where ``features`` is the
    NumPy array ``[x, y, x_flipped, y_flipped]`` and ``label`` is an int in
    ``{0, 1, 2}``.

    Args:
        r_small: Radius of the two small dots.
        r_big: Radius of the outer circle.
        size: Number of samples to generate.
        seed: Seed for the dataset's private ``np.random.RandomState``.
        offset: Constant added to ``x`` and ``y`` after sampling. The mirrored
            coordinates are computed as ``(1 + offset) - (x + offset)``, which
            equals ``1 - x`` of the un-shifted point, so they are NOT shifted.
            NOTE(review): confirm this asymmetry is intended.
        transform: Optional callable applied to each ``(features, label)``
            tuple in :meth:`__getitem__`.
    """

    def __init__(self, r_small=0.1, r_big=0.5, size=1000, seed=42, offset=0, transform=None):
        super().__init__()
        # using a numpy RNG to allow compatibility to other deep learning frameworks
        self.rng = np.random.RandomState(seed)
        self.transform = transform
        self.r_small = r_small
        self.r_big = r_big
        self.__vals = []
        self.__cs = []
        # NOTE(review): which_class() returns 1 for the region it calls "yin",
        # but class_names[1] is 'yang'. Labels are kept as-is; confirm naming.
        self.class_names = ['yin', 'yang', 'dot']
        for _ in range(size):
            # Keep the number of class instances balanced: pick the target class
            # first, then rejection-sample a point belonging to it.
            goal_class = self.rng.randint(3)
            x, y, c = self.get_sample(goal=goal_class)
            x, y = x + offset, y + offset
            # Mirrored axis values. The constant 1.0 equals 2 * r_big only for
            # the default r_big=0.5 (NOTE(review): confirm for other radii).
            x_flipped = (1. + offset) - x
            y_flipped = (1. + offset) - y
            self.__vals.append(np.array([x, y, x_flipped, y_flipped]))
            self.__cs.append(c)

    def get_sample(self, goal=None):
        """Rejection-sample one point inside the big circle.

        Args:
            goal: Required class label (0, 1 or 2), or ``None`` to accept the
                first point that falls inside the circle.

        Returns:
            ``(x, y, c)``: the sampled coordinates (without ``offset``) and
            the class label returned by :meth:`which_class`.

        Raises:
            ValueError: if ``goal`` is neither ``None`` nor a valid label.
                Such a goal can never be satisfied, so sampling would
                otherwise loop forever.
        """
        if goal is not None and goal not in (0, 1, 2):
            raise ValueError(f"goal must be None, 0, 1 or 2, got {goal!r}")
        while True:
            # sample x, y uniformly in the bounding square [0, 2 * r_big)^2
            x, y = self.rng.rand(2) * 2. * self.r_big
            # reject points outside the yin-yang circle
            if np.sqrt((x - self.r_big)**2 + (y - self.r_big)**2) > self.r_big:
                continue
            c = self.which_class(x, y)
            if goal is None or c == goal:
                return x, y, c

    def which_class(self, x, y):
        """Return the class label of point ``(x, y)``.

        Returns 2 if the point is strictly within ``r_small`` of either dot
        centre. Otherwise returns 1 if any of the three "yin" criteria below
        hold, else 0. A point exactly at distance ``r_small`` from the right
        dot satisfies criterion1 but not the strict dot test, so it gets 1.
        """
        # equations inspired by
        # https://link.springer.com/content/pdf/10.1007/11564126_19.pdf
        d_right = self.dist_to_right_dot(x, y)
        d_left = self.dist_to_left_dot(x, y)
        criterion1 = d_right <= self.r_small
        criterion2 = d_left > self.r_small and d_left <= 0.5 * self.r_big
        criterion3 = y > self.r_big and d_right > 0.5 * self.r_big
        is_yin = criterion1 or criterion2 or criterion3
        is_circles = d_right < self.r_small or d_left < self.r_small
        if is_circles:
            return 2
        return int(is_yin)

    def dist_to_right_dot(self, x, y):
        """Euclidean distance from ``(x, y)`` to ``(1.5 * r_big, r_big)``."""
        return np.sqrt((x - 1.5 * self.r_big)**2 + (y - self.r_big)**2)

    def dist_to_left_dot(self, x, y):
        """Euclidean distance from ``(x, y)`` to ``(0.5 * r_big, r_big)``."""
        return np.sqrt((x - 0.5 * self.r_big)**2 + (y - self.r_big)**2)

    def __getitem__(self, index):
        """Return ``(features_copy, label)``, passed through ``transform`` if set."""
        # copy so callers cannot mutate the stored sample in place
        sample = (self.__vals[index].copy(), self.__cs[index])
        if self.transform:
            sample = self.transform(sample)
        return sample

    def __len__(self):
        """Number of generated samples."""
        return len(self.__cs)