diff --git a/src/cnn.py b/src/cnn.py index f94dc99..0b1be16 100644 --- a/src/cnn.py +++ b/src/cnn.py @@ -36,12 +36,12 @@ def train_cnn(self, trainloader, epochs=3): torch.save(self.state_dict(), "mnist_cnn_model.pth") -transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) -]) -trainset = datasets.MNIST('.', download=True, train=True, transform=transform) +transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] +) + +trainset = datasets.MNIST(".", download=True, train=True, transform=transform) trainloader = DataLoader(trainset, batch_size=64, shuffle=True) cnn = CNN()