Skip to content

Commit

Permalink
Update pytorch api and web app
Browse files Browse the repository at this point in the history
  • Loading branch information
jinjingforever committed May 13, 2024
1 parent 4810306 commit e918572
Show file tree
Hide file tree
Showing 6 changed files with 1,447 additions and 1,446 deletions.
3 changes: 2 additions & 1 deletion example_colabs/quick_start.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -117,9 +117,10 @@
"# Get mobilnet v2 pytorch model as an example.\n",
"model = torchvision.models.mobilenet_v2().eval()\n",
"inputs = (torch.rand([1, 3, 224, 224]),)\n",
"ep = torch.export.export(model, inputs)\n",
"\n",
"# Visualize\n",
"model_explorer.visualize_pytorch('mobilenet', model, inputs)"
"model_explorer.visualize_pytorch('mobilenet', exported_program=ep)"
]
}
],
Expand Down
2 changes: 1 addition & 1 deletion src/server/package/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "ai-edge-model-explorer"
version = "0.0.95"
version = "0.0.96"
authors = [
{ name="Google LLC", email="[email protected]" },
]
Expand Down
12 changes: 6 additions & 6 deletions src/server/package/src/model_explorer/apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
# limitations under the License.
# ==============================================================================

from typing import Any, Callable, Tuple, Union
from typing import Union

import torch

from . import server
from .config import ModelExplorerConfig
Expand Down Expand Up @@ -56,24 +58,22 @@ def visualize(

def visualize_pytorch(
name: str,
model: Callable,
inputs: Tuple[Any, ...],
exported_program: torch.export.ExportedProgram,
host=DEFAULT_HOST,
port=DEFAULT_PORT,
colab_height=DEFAULT_COLAB_HEIGHT) -> None:
"""Visualizes a pytorch model.
Args:
name: The name of the model for display purpose.
model: The callable to trace.
inputs: Example positional inputs.
exported_program: The ExportedProgram from torch.export.export.
host: The host of the server. Default to localhost.
port: The port of the server. Default to 8080.
colab_height: The height of the embedded iFrame when running in colab.
"""
# Construct config.
cur_config = config()
cur_config.add_model_from_pytorch(name, model, inputs)
cur_config.add_model_from_pytorch(name, exported_program=exported_program)

# Start server.
server.start(
Expand Down
2 changes: 1 addition & 1 deletion src/server/package/src/model_explorer/cmdline.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from .consts import DEFAULT_HOST, DEFAULT_PORT

parser = argparse.ArgumentParser(
prog='model_explorer',
prog='model-explorer',
description='A modern model graph visualizer and debugger',
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('model_paths',
Expand Down
9 changes: 3 additions & 6 deletions src/server/package/src/model_explorer/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,19 +65,16 @@ def add_model_from_path(self, path: str, adapterId: str = '') -> 'ModelExplorerC

def add_model_from_pytorch(self,
name: str,
model: Callable,
inputs: Tuple[Any, ...]) -> 'ModelExplorerConfig':
exported_program: torch.export.ExportedProgram) -> 'ModelExplorerConfig':
"""Adds the given pytorch model.
Args:
name: the name of the model for display purpose.
model: the callable to trace.
inputs: Example positional inputs.
exported_program: the ExportedProgram from torch.export.export.
"""
# Convert the given model to model explorer graphs.
print('Converting pytorch model to model explorer graphs...')
exported = torch.export.export(model, inputs)
adapter = PytorchExportedProgramAdapterImpl(exported)
adapter = PytorchExportedProgramAdapterImpl(exported_program)
graphs = adapter.convert()
graphs_index = len(self.graphs_list)
self.graphs_list.append(graphs)
Expand Down
Loading

0 comments on commit e918572

Please sign in to comment.