Add new CNN class for MNIST handling and update model usage in API #76
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Description
This PR adds a new CNN class in
src/cnn.py
that handles the MNIST dataset. The CNN class is responsible for loading and preprocessing the data, defining the CNN architecture, and training the model. The existing code insrc/main.py
for loading and preprocessing the MNIST dataset and defining the PyTorch model has been removed and replaced with the new CNN class. Additionally, the usage of the model insrc/api.py
has been updated to use the new CNN model instead of the previous Net model.Summary of Changes
src/cnn.py
to contain the new CNN class.src/cnn.py
for building the CNN model.CNN
insrc/cnn.py
that inherits fromtorch.nn.Module
.__init__
method inCNN
class to define the layers of the CNN.forward
method inCNN
class to perform the forward pass of the CNN.load_data
function insrc/cnn.py
to load and preprocess the MNIST dataset.train
function insrc/cnn.py
to train the CNN model on the MNIST dataset.main
function insrc/cnn.py
to create an instance of the CNN class, load the data, and train the model.src/main.py
to import theCNN
class fromsrc/cnn.py
.src/main.py
for loading and preprocessing the MNIST dataset and defining the PyTorch model.src/main.py
to create an instance of theCNN
class and call the train method.src/api.py
to import theCNN
class fromsrc/cnn.py
.Net
model with the newCNN
model insrc/api.py
.src/api.py
to match the location of the saved CNN model.Please review and merge this PR to incorporate the changes.
Fixes #9.
🎉 Latest improvements to Sweep:
💡 To get Sweep to edit this pull request, you can: