This repository has been archived by the owner on Oct 24, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
TORCHSCRIPT_CLASSIFIER node for any user-provided .torchscript AI mod…
…el (#234) * add torchscript classifier node * add class_names.csv * update app and example.md * rename to TORCHSCRIPT_CLASSIFIER.md * delete example.md * rename back to example.md * move model path to param * run black formatter * Update flojoy This resolves issue on the CI. Signed-off-by: Julien Jerphanion <[email protected]> * Rename to app.json Signed-off-by: Julien Jerphanion <[email protected]> * Loosen dependencies for integration within studio Signed-off-by: Julien Jerphanion <[email protected]> * Update AI_ML/CLASSIFICATION/TORCHSCRIPT_CLASSIFIER/TORCHSCRIPT_CLASSIFIER_test_.py Co-authored-by: Julien Jerphanion <[email protected]> * Update requirements.txt Co-authored-by: Julien Jerphanion <[email protected]> * replace the json with the csv * Relax some dependencies for compatibility, fix CI (#279) Co-authored-by: Julien Jerphanion <[email protected]> * pull mobilenetv3 from the net * remove model torchscript * run black formatter --------- Signed-off-by: Julien Jerphanion <[email protected]> Co-authored-by: Julien Jerphanion <[email protected]>
- Loading branch information
Showing
11 changed files
with
1,556 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
70 changes: 70 additions & 0 deletions
70
AI_ML/CLASSIFICATION/TORCHSCRIPT_CLASSIFIER/TORCHSCRIPT_CLASSIFIER.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
from flojoy import flojoy, run_in_venv, Image, DataFrame | ||
|
||
|
||
@flojoy | ||
@run_in_venv( | ||
pip_dependencies=[ | ||
"torch==2.0.1", | ||
"torchvision==0.15.2", | ||
"numpy", | ||
"Pillow", | ||
] | ||
) | ||
def TORCHSCRIPT_CLASSIFIER( | ||
input_image: Image, class_names: DataFrame, model_path: str | ||
) -> DataFrame: | ||
""" | ||
Execute a torchscript classifier against an input image. | ||
Inputs | ||
---------- | ||
input_image : Image | ||
The image to classify. | ||
class_names : DataFrame | ||
A dataframe containing the class names. | ||
Parameters | ||
---------- | ||
model_path : str | ||
The path to the torchscript model. | ||
Returns | ||
---------- | ||
DataFrame | ||
A dataframe containing the class name and confidence score. | ||
""" | ||
|
||
import torch | ||
import torchvision | ||
import pandas as pd | ||
import numpy as np | ||
import PIL.Image | ||
|
||
# Load model | ||
model = torch.jit.load(model_path) | ||
channels = [input_image.r, input_image.g, input_image.b] | ||
mode = "RGB" | ||
|
||
if input_image.a is not None: | ||
channels.append(input_image.a) | ||
mode += "A" | ||
|
||
input_image_pil = PIL.Image.fromarray( | ||
np.stack(channels).transpose(1, 2, 0), mode=mode | ||
).convert("RGB") | ||
input_tensor = torchvision.transforms.functional.to_tensor( | ||
input_image_pil | ||
).unsqueeze(0) | ||
|
||
# Run model | ||
with torch.inference_mode(): | ||
output = model(input_tensor) | ||
|
||
# Get class name and confidence score | ||
_, pred = torch.max(output, 1) | ||
class_name = class_names.m.iloc[pred.item()].item() | ||
confidence = torch.nn.functional.softmax(output, dim=1)[0][pred.item()].item() | ||
|
||
return DataFrame( | ||
df=pd.DataFrame({"class_name": [class_name], "confidence": [confidence]}) | ||
) |
82 changes: 82 additions & 0 deletions
82
AI_ML/CLASSIFICATION/TORCHSCRIPT_CLASSIFIER/TORCHSCRIPT_CLASSIFIER_test_.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
import pytest | ||
import os | ||
import tempfile | ||
import pandas as pd | ||
import numpy as np | ||
import PIL | ||
from flojoy import run_in_venv, Image, DataFrame | ||
|
||
|
||
@pytest.fixture | ||
def torchscript_model_path(): | ||
# Download and save a test model, this requires | ||
# torch which is why we need to run this in a venv. | ||
@run_in_venv(pip_dependencies=["torch~=2.0.1", "torchvision~=0.15.2"]) | ||
def _download_test_model(path: str): | ||
import torch | ||
import torchvision | ||
|
||
class ModelWithTransform(torch.nn.Module): | ||
def __init__(self): | ||
super().__init__() | ||
self.model = torch.hub.load( | ||
"pytorch/vision:v0.15.2", | ||
"mobilenet_v3_small", | ||
pretrained=True, | ||
skip_validation=True, # This will save us from github rate limiting, https://github.com/pytorch/vision/issues/4156#issuecomment-939680999 | ||
) | ||
self.model.eval() | ||
self.transforms = ( | ||
torchvision.models.MobileNet_V3_Small_Weights.DEFAULT.transforms() | ||
) | ||
|
||
def forward(self, x): | ||
return self.model(self.transforms(x)) | ||
|
||
model = ModelWithTransform() | ||
scripted = torch.jit.script(model) | ||
torch.jit.save(scripted, path) | ||
|
||
with tempfile.TemporaryDirectory() as tempdir: | ||
model_path = os.path.join(tempdir, "mbnet_v3_small.torchscript") | ||
_download_test_model(model_path) | ||
yield model_path | ||
|
||
|
||
@pytest.fixture | ||
def class_names(): | ||
csv_path = os.path.join( | ||
os.path.dirname(os.path.realpath(__file__)), "assets", "class_names.csv" | ||
) | ||
return DataFrame(df=pd.read_csv(csv_path)) | ||
|
||
|
||
@pytest.fixture | ||
def obama_image(): | ||
_image_path = os.path.join( | ||
os.path.dirname(os.path.realpath(__file__)), | ||
"assets", | ||
"President_Barack_Obama.jpg", | ||
) | ||
image = np.array(PIL.Image.open(_image_path).convert("RGB")) | ||
return Image(r=image[:, :, 0], g=image[:, :, 1], b=image[:, :, 2], a=None) | ||
|
||
|
||
@pytest.mark.slow | ||
def test_TORHSCRIPT_CLASSIFIER( | ||
mock_flojoy_decorator, | ||
mock_flojoy_venv_cache_directory, | ||
obama_image, | ||
torchscript_model_path, | ||
class_names, | ||
): | ||
import TORCHSCRIPT_CLASSIFIER | ||
|
||
# Test the model | ||
clf_output = TORCHSCRIPT_CLASSIFIER.TORCHSCRIPT_CLASSIFIER( | ||
input_image=obama_image, | ||
model_path=torchscript_model_path, | ||
class_names=class_names, | ||
) | ||
|
||
assert clf_output.m.iloc[0].class_name == "suit, suit of clothes" |
Oops, something went wrong.