From efbc31703b82df722c4178f75959caedf5cdfa12 Mon Sep 17 00:00:00 2001 From: Giorgia Pitteri Date: Tue, 25 Jun 2024 13:17:06 +0200 Subject: [PATCH 1/3] Add labels to dataset object --- src/explib/datasets.py | 27 +++++++++++++++++++-------- 1 file changed, 19 insertions(+), 8 deletions(-) diff --git a/src/explib/datasets.py b/src/explib/datasets.py index 8a0941b..f756da3 100644 --- a/src/explib/datasets.py +++ b/src/explib/datasets.py @@ -21,6 +21,7 @@ class DequantizedDataset(torch.utils.data.Dataset): def __init__( self, dataset: T.Union[os.PathLike, torch.utils.data.Dataset, np.ndarray], + labels: T.Union[np.ndarray, torch.Tensor] = None, num_bits: int = 8, device: torch.device = None, ): @@ -33,6 +34,11 @@ def __init__( # self.dataset = self.dataset.to(device) + if not isinstance(labels, torch.Tensor): + labels = torch.Tensor(labels) + + self.labels = labels.to(device) + self.num_bits = num_bits self.num_levels = 2**num_bits self.transform = transforms.Compose( @@ -323,19 +329,24 @@ def __init__( MNIST(dataloc, train=train, download=True) dataset = idx2numpy.convert_from_file(path) + if scale: dataset = dataset[:, ::3, ::3] if flatten: dataset = dataset.reshape(dataset.shape[0], -1) + + if train: + rel_path = "MNIST/raw/train-labels-idx1-ubyte" + else: + rel_path = "MNIST/raw/t10k-labels-idx1-ubyte" + path = os.path.join(dataloc, rel_path) + labels = idx2numpy.convert_from_file(path) + if digit is not None: - if train: - rel_path = "MNIST/raw/train-labels-idx1-ubyte" - else: - rel_path = "MNIST/raw/t10k-labels-idx1-ubyte" - path = os.path.join(dataloc, rel_path) - labels = idx2numpy.convert_from_file(path) dataset = dataset[labels == digit] - super().__init__(torch.Tensor(dataset), num_bits=8, device=device) + labels = labels[labels == digit] + + super().__init__(torch.Tensor(dataset), labels=torch.Tensor(labels), num_bits=8, device=device) def __getitem__(self, index: int): if not isinstance(self.dataset, torch.Tensor): @@ -343,7 +354,7 @@ def __getitem__(self, index: int): else: x = self.dataset[index] x = self.transform(x) - return x, 0 + return x, self.labels[index] class MnistSplit(DataSplit): def __init__( From 5ddd683fe1143d6a6ba0313faa482ac8891999de Mon Sep 17 00:00:00 2001 From: giorgiapitteri Date: Tue, 2 Jul 2024 09:44:01 +0200 Subject: [PATCH 2/3] device set to cpu by default --- src/explib/datasets.py | 47 ++++++++++++++++++++++++------------------ 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/src/explib/datasets.py b/src/explib/datasets.py index f756da3..0571ed2 100644 --- a/src/explib/datasets.py +++ b/src/explib/datasets.py @@ -21,9 +21,9 @@ class DequantizedDataset(torch.utils.data.Dataset): def __init__( self, dataset: T.Union[os.PathLike, torch.utils.data.Dataset, np.ndarray], - labels: T.Union[np.ndarray, torch.Tensor] = None, + labels: T.Union[np.ndarray, torch.Tensor], num_bits: int = 8, - device: torch.device = None, + device: torch.device = "cpu", ): if isinstance(dataset, torch.utils.data.Dataset) or isinstance( dataset, np.ndarray @@ -32,11 +32,9 @@ def __init__( else: self.dataset = pd.read_csv(dataset).values - # self.dataset = self.dataset.to(device) if not isinstance(labels, torch.Tensor): labels = torch.Tensor(labels) - self.labels = labels.to(device) self.num_bits = num_bits @@ -49,9 +47,9 @@ def __init__( ) def __getitem__(self, index: int): - x, y = self.dataset[index] + x = self.dataset[index] x = Tensor(self.transform(x)) - return x, y + return x, self.labels[index] def __len__(self): return len(self.dataset) @@ -247,7 +245,8 @@ def __init__( dataloc: os.PathLike = None, train: bool = True, label: T.Optional[int] = None, - scale: bool = False + scale: bool = False, + device: torch.device = "cpu" ): rel_path = ( "FashionMNIST/raw/train-images-idx3-ubyte" @@ -262,21 +261,25 @@ def __init__( if scale: dataset = dataset[:, ::3, ::3] dataset = dataset.reshape(dataset.shape[0], -1) - if label is not None: - rel_path = ( + + rel_path = ( "FashionMNIST/raw/train-labels-idx1-ubyte" if train else "FashionMNIST/raw/t10k-labels-idx1-ubyte" ) - path = os.path.join(dataloc, rel_path) - labels = idx2numpy.convert_from_file(path) + path = os.path.join(dataloc, rel_path) + labels = idx2numpy.convert_from_file(path) + + if label is not None: dataset = dataset[labels == label] - super().__init__(dataset, num_bits=8) + labels = labels[labels == label] + + super().__init__(dataset, torch.Tensor(labels), num_bits=8, device=device) def __getitem__(self, index: int): x = Tensor(self.dataset[index].copy()) x = self.transform(x) - return x, 0 + return x, self.labels[index] class FashionMnistSplit(DataSplit): @@ -285,11 +288,13 @@ def __init__( dataloc: os.PathLike = None, val_split: float = 0.1, label: T.Optional[int] = None, + device: torch.device = "cpu" ): + self.label = label if dataloc is None: dataloc = os.path.join(os.getcwd(), "data") self.dataloc = dataloc - self.train = FashionMnistDequantized(self.dataloc, train=True, label=label) + self.train = FashionMnistDequantized(self.dataloc, train=True, label=label, device=device) shuffle = torch.randperm(len(self.train)) self.val = torch.utils.data.Subset( self.train, shuffle[: int(len(self.train) * val_split)] @@ -297,7 +302,7 @@ def __init__( self.train = torch.utils.data.Subset( self.train, shuffle[int(len(self.train) * val_split) :] ) - self.test = FashionMnistDequantized(self.dataloc, train=False, label=label) + self.test = FashionMnistDequantized(self.dataloc, train=False, label=label, device=device) def get_train(self) -> torch.utils.data.Dataset: return self.train @@ -318,7 +323,7 @@ def __init__( digit: T.Optional[int] = None, flatten=True, scale: bool = False, - device: torch.device = None + device: torch.device = "cpu" ): if train: rel_path = "MNIST/raw/train-images-idx3-ubyte" @@ -329,12 +334,12 @@ def __init__( MNIST(dataloc, train=train, download=True) dataset = idx2numpy.convert_from_file(path) - + if scale: dataset = dataset[:, ::3, ::3] if flatten: dataset = dataset.reshape(dataset.shape[0], -1) - + if train: rel_path = "MNIST/raw/train-labels-idx1-ubyte" else: @@ -346,7 +351,7 @@ def __init__( dataset = dataset[labels == digit] labels = labels[labels == digit] - super().__init__(torch.Tensor(dataset), labels=torch.Tensor(labels), num_bits=8, device=device) + super().__init__(torch.Tensor(dataset), torch.Tensor(labels), num_bits=8, device=device) def __getitem__(self, index: int): if not isinstance(self.dataset, torch.Tensor): @@ -363,8 +368,9 @@ def __init__( val_split: float = 0.1, digit: T.Optional[int] = None, scale: bool = False, - device: torch.device = None + device: torch.device = "cpu" ): + self.digit = digit if dataloc is None: dataloc = os.path.join(os.getcwd(), "data") self.dataloc = dataloc @@ -404,3 +410,4 @@ def __init__( if not os.path.exists(path): CIFAR10(dataloc, train=train, download=True) + From 38f962da8f93681536b3003f3120eefd1d76b034 Mon Sep 17 00:00:00 2001 From: giorgiapitteri Date: Fri, 12 Jul 2024 14:44:10 +0200 Subject: [PATCH 3/3] minor changes --- src/explib/datasets.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/explib/datasets.py b/src/explib/datasets.py index 0571ed2..cdfc037 100644 --- a/src/explib/datasets.py +++ b/src/explib/datasets.py @@ -32,7 +32,11 @@ def __init__( else: self.dataset = pd.read_csv(dataset).values + if not isinstance(self.dataset, torch.Tensor): + self.dataset = torch.tensor(self.dataset) + self.dataset = self.dataset.to(device) + if not isinstance(labels, torch.Tensor): labels = torch.Tensor(labels) self.labels = labels.to(device) @@ -277,7 +281,10 @@ def __init__( super().__init__(dataset, torch.Tensor(labels), num_bits=8, device=device) def __getitem__(self, index: int): - x = Tensor(self.dataset[index].copy()) + if not isinstance(self.dataset, torch.Tensor): + x = Tensor(self.dataset[index].copy()) + else: + x = self.dataset[index] x = self.transform(x) return x, self.labels[index]