Skip to content

Commit

Permalink
Sandbox run src/main.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sweep-nightly[bot] authored Nov 26, 2023
1 parent 35f37cd commit 52cb808
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,30 @@
from torchvision import datasets, transforms

# Step 1: Load MNIST Data and Preprocess
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
transform = transforms.Compose(
[transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)

trainset = datasets.MNIST('.', download=True, train=True, transform=transform)
trainset = datasets.MNIST(".", download=True, train=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)


# Step 2: Define the PyTorch Model
class Net(nn.Module):
def __init__(self):
super().__init__()
self.fc1 = nn.Linear(28 * 28, 128)
self.fc2 = nn.Linear(128, 64)
self.fc3 = nn.Linear(64, 10)

def forward(self, x):
x = x.view(-1, 28 * 28)
x = nn.functional.relu(self.fc1(x))
x = nn.functional.relu(self.fc2(x))
x = self.fc3(x)
return nn.functional.log_softmax(x, dim=1)


class Trainer:
def __init__(self, learning_rate, model_path):
self.model = Net()
Expand All @@ -44,7 +46,7 @@ def train(self, epochs):
loss = self.criterion(output, labels)
loss.backward()
self.optimizer.step()

def save_model(self):
torch.save(self.model.state_dict(), self.model_path)

Expand Down

0 comments on commit 52cb808

Please sign in to comment.