Skip to content

Commit

Permalink
add todo tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
pwochner committed Sep 27, 2022
1 parent dc07e83 commit 0c9e9a2
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 15 deletions.
25 changes: 11 additions & 14 deletions catdog_segmentation_model/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,17 @@
class CatDogUNet:
def __init__(self):
filename = "unet_model.ckpt"
if not os.path.exists(filename):
model_path = os.path.join(
"https://connectionsworkshop.blob.core.windows.net/pets", filename
)
r = requests.get(model_path)
with open(filename, "wb") as outfile:
outfile.write(r.content)
self.model = unet.CatDogUNet(num_classes=1)
### TODO ###
### Download the file containing the weights of the pre-trained model from url
### if the file doesn't exist already locally, and write it to a file.
### After that, create an instance of the model class.
### The code to load the model weights from file and evaluate the model is
### already provided.
self.model.load_state_dict(torch.load(filename))
self.model.eval()

def predict(self, image):
# transform input image (as required by model)
# pre-process input image (as required by model)
transform_input = transforms.Compose([transforms.Resize((192, 192)),])
image = image.values
image = image[:, :, 0:3] # make sure we have only 3 channels
Expand All @@ -31,9 +29,8 @@ def predict(self, image):
image = transform_input(image)

# make prediction
prediction = self.model(image)
return prediction
### TODO ###
### Apply the model to the pre-processed input image.
### Return the segmentation mask.
return "FIX ME"


# -------------
# - We need to return an image with class labels
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
name="catdog_segmentation_model",
version="0.0.1",
description="scivision plugin, using a UNet to segment cat/dog images",
url="https://github.com/pwochner/catdog_segmentation_model",
url="TODO: INSERT URL OF THE MODEL GITHUB REPO HERE",
packages=find_packages(),
install_requires=requirements,
python_requires=">=3.7",
Expand Down

0 comments on commit 0c9e9a2

Please sign in to comment.