-
Notifications
You must be signed in to change notification settings - Fork 0
/
my_dataset.py
28 lines (22 loc) · 892 Bytes
/
my_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import numpy as np
import torch
from torch.utils.data import Dataset
from torchvision import transforms
class MyDataSet(Dataset):
    """Dataset that serves arrays loaded from disk together with their labels.

    Each sample is read with ``np.load`` from the path at the requested index
    and paired with the label stored at the same index.
    """

    def __init__(self, impedance_path: list, impedance_class: list):
        """Store the sample locations and their labels.

        Args:
            impedance_path: Paths of the files to read with ``np.load``,
                one per sample.
            impedance_class: Labels, index-aligned with ``impedance_path``.
                Each entry must be convertible with ``int()``.
        """
        super().__init__()
        self.impedance_path = impedance_path
        self.impedance_class = impedance_class

    def __len__(self) -> int:
        """Return the number of samples (one per stored path)."""
        return len(self.impedance_path)

    def __getitem__(self, item):
        """Load one sample and its label.

        Args:
            item: Index into ``impedance_path`` / ``impedance_class``.

        Returns:
            A ``(impedance, label)`` tuple: ``impedance`` is a float64 tensor
            with a new leading dimension of size 1 in front of the stored
            array's shape, and ``label`` is a 0-dim integer tensor.
        """
        # NOTE(security): allow_pickle=True lets np.load unpickle object
        # arrays, which can execute arbitrary code. Only point this dataset
        # at trusted files; switch to allow_pickle=False if the files hold
        # plain numeric arrays.
        impedance = np.load(self.impedance_path[item], allow_pickle=True)
        impedance = impedance.astype(float)

        # The previous torchvision ToTensor call discarded its result, so it
        # never affected the returned tensor; it has been dropped as dead work.
        impedance = torch.from_numpy(impedance).unsqueeze(0)

        label = torch.as_tensor(int(self.impedance_class[item]))
        return impedance, label