Skip to content

Commit

Permalink
Sandbox run src/cnn.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-nightly[bot] authored Nov 26, 2023
1 parent d5b3bff commit 6f6638f
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions src/cnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def forward(self, x):
output = nn.functional.log_softmax(x, dim=1)
return output


def train(network, train_loader, optimizer):
network.train()
for batch_idx, (data, target) in enumerate(train_loader):
Expand All @@ -39,13 +40,13 @@ def train(network, train_loader, optimizer):
loss.backward()
optimizer.step()


if __name__ == "__main__":
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
trainset = datasets.MNIST(".", download=True, train=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)

network = CNN()
Expand Down

0 comments on commit 6f6638f

Please sign in to comment.