diff --git a/catdog_segmentation_model/model.py b/catdog_segmentation_model/model.py index f8cac2d..c76ed78 100644 --- a/catdog_segmentation_model/model.py +++ b/catdog_segmentation_model/model.py @@ -9,19 +9,17 @@ class CatDogUNet: def __init__(self): filename = "unet_model.ckpt" - if not os.path.exists(filename): - model_path = os.path.join( - "https://connectionsworkshop.blob.core.windows.net/pets", filename - ) - r = requests.get(model_path) - with open(filename, "wb") as outfile: - outfile.write(r.content) - self.model = unet.CatDogUNet(num_classes=1) + ### TODO ### + ### Download the file containing the weights of the pre-trained model from url + ### if the file doesn't exist already locally, and write it to a file. + ### After that, create an instance of the model class. + ### The code to load the model weights from file and evaluate the model is + ### already provided. self.model.load_state_dict(torch.load(filename)) self.model.eval() def predict(self, image): - # transform input image (as required by model) + # pre-process input image (as required by model) transform_input = transforms.Compose([transforms.Resize((192, 192)),]) image = image.values image = image[:, :, 0:3] # make sure we have only 3 channels @@ -31,9 +29,8 @@ def predict(self, image): image = transform_input(image) # make prediction - prediction = self.model(image) - return prediction + ### TODO ### + ### Apply the model to the pre-processed input image. + ### Return the segmentation mask. + return "FIX ME" - -# ------------- -# - We need to return an image with class labels diff --git a/setup.py b/setup.py index 9745b46..961c965 100644 --- a/setup.py +++ b/setup.py @@ -12,7 +12,7 @@ name="catdog_segmentation_model", version="0.0.1", description="scivision plugin, using a UNet to segment cat/dog images", - url="https://github.com/pwochner/catdog_segmentation_model", + url="TODO: INSERT URL OF THE MODEL GITHUB REPO HERE", packages=find_packages(), install_requires=requirements, python_requires=">=3.7",