diff --git a/src/main.py b/src/main.py index eb11d9c..c010fb3 100644 --- a/src/main.py +++ b/src/main.py @@ -7,14 +7,14 @@ from torchvision import datasets, transforms # Step 1: Load MNIST Data and Preprocess -transform = transforms.Compose([ - transforms.ToTensor(), - transforms.Normalize((0.5,), (0.5,)) -]) +transform = transforms.Compose( + [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))] +) -trainset = datasets.MNIST('.', download=True, train=True, transform=transform) +trainset = datasets.MNIST(".", download=True, train=True, transform=transform) trainloader = DataLoader(trainset, batch_size=64, shuffle=True) + # Step 2: Define the PyTorch Model class Net(nn.Module): def __init__(self): @@ -22,13 +22,15 @@ def __init__(self): self.fc1 = nn.Linear(28 * 28, 128) self.fc2 = nn.Linear(128, 64) self.fc3 = nn.Linear(64, 10) - + def forward(self, x): x = x.view(-1, 28 * 28) x = nn.functional.relu(self.fc1(x)) x = nn.functional.relu(self.fc2(x)) x = self.fc3(x) return nn.functional.log_softmax(x, dim=1) + + class Trainer: def __init__(self, learning_rate, model_path): self.model = Net() @@ -44,7 +46,7 @@ def train(self, epochs): loss = self.criterion(output, labels) loss.backward() self.optimizer.step() - + def save_model(self): torch.save(self.model.state_dict(), self.model_path)