From c1333ece180d5e2412e2ec2e2d68d752b38fbdc0 Mon Sep 17 00:00:00 2001 From: "sweep-nightly[bot]" <131841235+sweep-nightly[bot]@users.noreply.github.com> Date: Sun, 26 Nov 2023 00:53:06 +0000 Subject: [PATCH] feat: Updated src/api.py --- src/api.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/api.py b/src/api.py index 36c257a..db42381 100644 --- a/src/api.py +++ b/src/api.py @@ -1,13 +1,13 @@ -from fastapi import FastAPI, UploadFile, File -from PIL import Image import torch -from torchvision import transforms +from fastapi import FastAPI, File, UploadFile from main import Net # Importing Net class from main.py +from main import Trainer # Importing Trainer class from main.py +from PIL import Image +from torchvision import transforms # Load the model -model = Net() -model.load_state_dict(torch.load("mnist_model.pth")) -model.eval() +trainer = Trainer() +trainer.load_model("mnist_model.pth") # Assuming load_model method exists # Transform used for preprocessing the image transform = transforms.Compose([ @@ -23,6 +23,6 @@ async def predict(file: UploadFile = File(...)): image = transform(image) image = image.unsqueeze(0) # Add batch dimension with torch.no_grad(): - output = model(image) + output = trainer.get_model()(image) # Assuming get_model method exists _, predicted = torch.max(output.data, 1) return {"prediction": int(predicted[0])}