Skip to content
This repository has been archived by the owner on Oct 24, 2023. It is now read-only.

TORCHSCRIPT_CLASSIFIER node for any user-provided .torchscript AI model #234

Merged
merged 24 commits into from
Sep 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
e065808
add torchscript classifier node
Roulbac Aug 13, 2023
a8fb815
add class_names.csv
Roulbac Aug 14, 2023
178985e
update app and example.md
Roulbac Aug 14, 2023
39fd03e
rename to TORCHSCRIPT_CLASSIFIER.md
Roulbac Aug 14, 2023
e990300
delete example.md
Roulbac Aug 14, 2023
5c8fce0
rename back to example.md
Roulbac Aug 14, 2023
183ce9b
move model path to param
Roulbac Aug 14, 2023
32ebd40
run black formatter
Roulbac Aug 14, 2023
80345d6
Update flojoy
jjerphan Sep 6, 2023
6ae0ff0
Merge branch 'develop' into reda-torchscript-classifier
jjerphan Sep 6, 2023
0dacaa1
Rename to app.json
jjerphan Sep 6, 2023
5584e87
Loosen dependencies for integration within studio
jjerphan Sep 6, 2023
482912e
Update AI_ML/CLASSIFICATION/TORCHSCRIPT_CLASSIFIER/TORCHSCRIPT_CLASSI…
Roulbac Sep 7, 2023
72b046a
Update requirements.txt
Roulbac Sep 7, 2023
3aad2d1
merge jjerphan-fork/reda-torchscript-classifier into HEAD
Roulbac Sep 7, 2023
c5c431c
Merge remote-tracking branch 'origin/develop' into reda-torchscript-c…
Roulbac Sep 7, 2023
968fe10
Merge branch 'develop' into reda-torchscript-classifier
Roulbac Sep 7, 2023
1890292
replace the json with the csv
Roulbac Sep 7, 2023
07314b6
Relax some dependencies for compatibility, fix CI (#279)
Roulbac Sep 7, 2023
b022c96
pull mobilenetv3 from the net
Roulbac Sep 7, 2023
0a0d8ad
Merge branch 'reda-torchscript-classifier' of https://github.com/floj…
Roulbac Sep 7, 2023
b13a46f
remove model torchscript
Roulbac Sep 7, 2023
11ec685
run black formatter
Roulbac Sep 7, 2023
fb4dd1e
Merge branch 'develop' into reda-torchscript-classifier
Roulbac Sep 8, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions .github/workflows/pytest-slow-nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ jobs:
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v3
with:
ref: reda-fix-slow-tests

- uses: actions/setup-python@v4
with:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from flojoy import flojoy, run_in_venv, Image, DataFrame


@flojoy
@run_in_venv(
pip_dependencies=[
"torch==2.0.1",
"torchvision==0.15.2",
"numpy",
"Pillow",
]
)
def TORCHSCRIPT_CLASSIFIER(
input_image: Image, class_names: DataFrame, model_path: str
) -> DataFrame:
"""
Execute a torchscript classifier against an input image.

Inputs
----------
input_image : Image
The image to classify.
class_names : DataFrame
A dataframe containing the class names.

Parameters
----------
model_path : str
The path to the torchscript model.

Returns
----------
DataFrame
A dataframe containing the class name and confidence score.
"""

import torch
import torchvision
import pandas as pd
import numpy as np
import PIL.Image

# Load model
model = torch.jit.load(model_path)
channels = [input_image.r, input_image.g, input_image.b]
mode = "RGB"

if input_image.a is not None:
channels.append(input_image.a)
mode += "A"

input_image_pil = PIL.Image.fromarray(
np.stack(channels).transpose(1, 2, 0), mode=mode
).convert("RGB")
input_tensor = torchvision.transforms.functional.to_tensor(
input_image_pil
).unsqueeze(0)

# Run model
with torch.inference_mode():
output = model(input_tensor)

# Get class name and confidence score
_, pred = torch.max(output, 1)
class_name = class_names.m.iloc[pred.item()].item()
confidence = torch.nn.functional.softmax(output, dim=1)[0][pred.item()].item()

return DataFrame(
df=pd.DataFrame({"class_name": [class_name], "confidence": [confidence]})
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import pytest
import os
import tempfile
import pandas as pd
import numpy as np
import PIL
from flojoy import run_in_venv, Image, DataFrame


@pytest.fixture
def torchscript_model_path():
# Download and save a test model, this requires
# torch which is why we need to run this in a venv.
@run_in_venv(pip_dependencies=["torch~=2.0.1", "torchvision~=0.15.2"])
def _download_test_model(path: str):
import torch
import torchvision

class ModelWithTransform(torch.nn.Module):
def __init__(self):
super().__init__()
self.model = torch.hub.load(
Roulbac marked this conversation as resolved.
Show resolved Hide resolved
"pytorch/vision:v0.15.2",
"mobilenet_v3_small",
pretrained=True,
skip_validation=True, # This will save us from github rate limiting, https://github.com/pytorch/vision/issues/4156#issuecomment-939680999
)
self.model.eval()
self.transforms = (
torchvision.models.MobileNet_V3_Small_Weights.DEFAULT.transforms()
)

def forward(self, x):
return self.model(self.transforms(x))

model = ModelWithTransform()
scripted = torch.jit.script(model)
torch.jit.save(scripted, path)

with tempfile.TemporaryDirectory() as tempdir:
model_path = os.path.join(tempdir, "mbnet_v3_small.torchscript")
_download_test_model(model_path)
yield model_path


@pytest.fixture
def class_names():
csv_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "assets", "class_names.csv"
)
return DataFrame(df=pd.read_csv(csv_path))


@pytest.fixture
def obama_image():
_image_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"assets",
"President_Barack_Obama.jpg",
)
image = np.array(PIL.Image.open(_image_path).convert("RGB"))
return Image(r=image[:, :, 0], g=image[:, :, 1], b=image[:, :, 2], a=None)


@pytest.mark.slow
def test_TORHSCRIPT_CLASSIFIER(
mock_flojoy_decorator,
mock_flojoy_venv_cache_directory,
obama_image,
torchscript_model_path,
class_names,
):
import TORCHSCRIPT_CLASSIFIER

# Test the model
clf_output = TORCHSCRIPT_CLASSIFIER.TORCHSCRIPT_CLASSIFIER(
input_image=obama_image,
model_path=torchscript_model_path,
class_names=class_names,
)

assert clf_output.m.iloc[0].class_name == "suit, suit of clothes"
Loading