Skip to content

Commit

Permalink
Make the dependency on PyTorch optional
Browse files Browse the repository at this point in the history
  • Loading branch information
kartynnik committed Nov 1, 2024
1 parent 01f1eb4 commit 760ccc2
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 9 deletions.
1 change: 0 additions & 1 deletion src/server/package/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ dependencies = [
"requests",
"termcolor",
"typing-extensions",
"torch >= 2.2",
"numpy < 2",
]

Expand Down
3 changes: 1 addition & 2 deletions src/server/package/src/model_explorer/apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

from typing import TypedDict, Union

import torch
from typing_extensions import NotRequired

from . import server
Expand Down Expand Up @@ -107,7 +106,7 @@ def visualize(

def visualize_pytorch(
name: str,
exported_program: torch.export.ExportedProgram,
exported_program: 'torch.export.ExportedProgram',
host=DEFAULT_HOST,
port=DEFAULT_PORT,
extensions: list[str] = [],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,16 @@

from typing import Dict

import torch
import torch.fx
try:
import torch
except ImportError:
torch = None

from .adapter import Adapter, AdapterMetadata
from .types import ModelExplorerGraphs
from .pytorch_exported_program_adater_impl import PytorchExportedProgramAdapterImpl

if torch is not None:
from .pytorch_exported_program_adater_impl import PytorchExportedProgramAdapterImpl


class BuiltinPytorchExportedProgramAdapter(Adapter):
Expand All @@ -40,5 +44,9 @@ def __init__(self):
super().__init__()

def convert(self, model_path: str, settings: Dict) -> ModelExplorerGraphs:
if torch is None:
raise ImportError(
'Please install the `torch` package via: `pip install torch`'
)
ep = torch.export.load(model_path)
return PytorchExportedProgramAdapterImpl(ep, settings).convert()
12 changes: 9 additions & 3 deletions src/server/package/src/model_explorer/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,20 @@
from urllib.parse import quote

import requests
import torch
from typing_extensions import NotRequired

try:
import torch
except ImportError:
torch = None

from .consts import DEFAULT_HOST, DEFAULT_SETTINGS
from .node_data_builder import NodeData
from .pytorch_exported_program_adater_impl import PytorchExportedProgramAdapterImpl
from .types import ModelExplorerGraphs

if torch is not None:
from .pytorch_exported_program_adater_impl import PytorchExportedProgramAdapterImpl

ModelSource = TypedDict(
'ModelSource', {'url': str, 'adapterId': NotRequired[str]}
)
Expand Down Expand Up @@ -79,7 +85,7 @@ def add_model_from_path(
def add_model_from_pytorch(
self,
name: str,
exported_program: torch.export.ExportedProgram,
exported_program: 'torch.export.ExportedProgram',
settings=DEFAULT_SETTINGS,
) -> 'ModelExplorerConfig':
"""Adds the given pytorch model.
Expand Down

0 comments on commit 760ccc2

Please sign in to comment.