Skip to content

Commit

Permalink
Merge pull request #245 from google-ai-edge/reuseserver
Browse files Browse the repository at this point in the history
Allow to reuse server for pytorch models
  • Loading branch information
jinjingforever authored Nov 14, 2024
2 parents 7e16ef3 + f507df6 commit 1e5e178
Show file tree
Hide file tree
Showing 10 changed files with 169 additions and 35 deletions.
159 changes: 131 additions & 28 deletions ci/playwright_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@
# ==============================================================================

import logging
import model_explorer
import re
import tempfile
import time
import torch
import torchvision
from playwright.sync_api import Page, expect
from PIL import Image, ImageChops
from pathlib import Path
Expand Down Expand Up @@ -75,13 +79,17 @@ def delay_take_screenshot(page: Page, file_path: str):
page.screenshot(path=file_path)


def take_and_compare_screenshot(page: Page, name: str):
actual_image_path = TMP_SCREENSHOT_DIR / name
delay_take_screenshot(page, actual_image_path)
expected_image_path = EXPECTED_SCREENSHOT_DIR / name
assert matched_images(actual_image_path, expected_image_path)


def test_homepage(page: Page):
page.goto(LOCAL_SERVER)
expect(page).to_have_title(re.compile("Model Explorer"))
actual_image_path = TMP_SCREENSHOT_DIR / "homepage.png"
delay_take_screenshot(page, actual_image_path)
expected_image_path = EXPECTED_SCREENSHOT_DIR / "homepage.png"
assert matched_images(actual_image_path, expected_image_path)
take_and_compare_screenshot(page, "homepage.png")


def test_litert_direct_adapter(page: Page):
Expand All @@ -94,10 +102,8 @@ def test_litert_direct_adapter(page: Page):
page.get_by_text("TFLite adapter (Flatbuffer)").click()
delay_view_model(page)
page.locator("canvas").first.click(position={"x": 469, "y": 340})
actual_image_path = TMP_SCREENSHOT_DIR / "litert_direct.png"
delay_take_screenshot(page, actual_image_path)
expected_image_path = EXPECTED_SCREENSHOT_DIR / "litert_direct.png"
assert matched_images(actual_image_path, expected_image_path)

take_and_compare_screenshot(page, "litert_direct.png")


def test_litert_mlir_adapter(page: Page):
Expand All @@ -110,10 +116,8 @@ def test_litert_mlir_adapter(page: Page):
page.get_by_text("TFLite adapter (MLIR)").click()
delay_view_model(page)
page.locator("canvas").first.click(position={"x": 514, "y": 332})
actual_image_path = TMP_SCREENSHOT_DIR / "litert_mlir.png"
delay_take_screenshot(page, actual_image_path)
expected_image_path = EXPECTED_SCREENSHOT_DIR / "litert_mlir.png"
assert matched_images(actual_image_path, expected_image_path)

take_and_compare_screenshot(page, "litert_mlir.png")


def test_tf_mlir_adapter(page: Page):
Expand All @@ -126,10 +130,8 @@ def test_tf_mlir_adapter(page: Page):
page.get_by_text("TF adapter (MLIR) Default").click()
delay_view_model(page)
page.locator("canvas").first.click(position={"x": 444, "y": 281})
actual_image_path = TMP_SCREENSHOT_DIR / "tf_mlir.png"
delay_take_screenshot(page, actual_image_path)
expected_image_path = EXPECTED_SCREENSHOT_DIR / "tf_mlir.png"
assert matched_images(actual_image_path, expected_image_path)

take_and_compare_screenshot(page, "tf_mlir.png")


def test_tf_direct_adapter(page: Page):
Expand All @@ -144,10 +146,8 @@ def test_tf_direct_adapter(page: Page):
page.get_by_text("__inference__traced_save_36", exact=True).click()
page.get_by_text("__inference_add_6").click()
delay_click_canvas(page, 205, 265)
actual_image_path = TMP_SCREENSHOT_DIR / "tf_direct.png"
delay_take_screenshot(page, actual_image_path)
expected_image_path = EXPECTED_SCREENSHOT_DIR / "tf_direct.png"
assert matched_images(actual_image_path, expected_image_path)

take_and_compare_screenshot(page, "tf_direct.png")


def test_tf_graphdef_adapter(page: Page):
Expand All @@ -158,10 +158,8 @@ def test_tf_graphdef_adapter(page: Page):
page.get_by_role("button", name="Add").click()
delay_view_model(page)
page.locator("canvas").first.click(position={"x": 468, "y": 344})
actual_image_path = TMP_SCREENSHOT_DIR / "graphdef.png"
delay_take_screenshot(page, actual_image_path)
expected_image_path = EXPECTED_SCREENSHOT_DIR / "graphdef.png"
assert matched_images(actual_image_path, expected_image_path)

take_and_compare_screenshot(page, "graphdef.png")


def test_shlo_mlir_adapter(page: Page):
Expand All @@ -173,7 +171,112 @@ def test_shlo_mlir_adapter(page: Page):
delay_view_model(page)
page.get_by_text("unfold_more_double").click()
delay_click_canvas(page, 454, 416)
actual_image_path = TMP_SCREENSHOT_DIR / "shlo_mlir.png"
delay_take_screenshot(page, actual_image_path)
expected_image_path = EXPECTED_SCREENSHOT_DIR / "shlo_mlir.png"
assert matched_images(actual_image_path, expected_image_path)

take_and_compare_screenshot(page, "shlo_mlir.png")


def test_pytorch(page: Page):
# Serialize a pytorch model.
model = torchvision.models.mobilenet_v2().eval()
inputs = (torch.rand([1, 3, 224, 224]),)
ep = torch.export.export(model, inputs)
tmp_dir = tempfile.gettempdir()
pt2_file_path = f"{tmp_dir}/pytorch.pt2"
torch.export.save(ep, pt2_file_path)

# Load into ME.
page.goto(LOCAL_SERVER)
page.get_by_placeholder("Absolute file paths (").fill(pt2_file_path)
page.get_by_role("button", name="Add").click()
delay_view_model(page)
page.locator("canvas").first.click(position={"x": 458, "y": 334})

take_and_compare_screenshot(page, "pytorch.png")


def test_reuse_server_non_pytorch(page: Page):
# Load a tflite model
page.goto(LOCAL_SERVER)
page.get_by_placeholder("Absolute file paths (").fill(
TEST_FILES_DIR / "fully_connected.tflite"
)
page.get_by_role("button", name="Add").click()
delay_view_model(page)

# Load a mlir graph and reuse the existing server.
mlir_model_path = TEST_FILES_DIR / "stablehlo_sin.mlir"
model_explorer.visualize(
model_paths=mlir_model_path.as_posix(), reuse_server=True
)
time.sleep(2) # Delay for the animation

take_and_compare_screenshot(page, "reuse_server_non_pytorch.png")


def test_reuse_server_pytorch(page: Page):
# Load a tflite model
page.goto(LOCAL_SERVER)
page.get_by_placeholder("Absolute file paths (").fill(
TEST_FILES_DIR / "fully_connected.tflite"
)
page.get_by_role("button", name="Add").click()
delay_view_model(page)

# Load a pytorch model and reuse the existing server.
model = torchvision.models.mobilenet_v2().eval()
inputs = (torch.rand([1, 3, 224, 224]),)
ep = torch.export.export(model, inputs)
model_explorer.visualize_pytorch(
name="test pytorch", exported_program=ep, reuse_server=True
)
time.sleep(2) # Delay for the animation

take_and_compare_screenshot(page, "reuse_server_pytorch.png")


def test_reuse_server_pytorch_from_config(page: Page):
# Load a tflite model
page.goto(LOCAL_SERVER)
page.get_by_placeholder("Absolute file paths (").fill(
TEST_FILES_DIR / "fully_connected.tflite"
)
page.get_by_role("button", name="Add").click()
delay_view_model(page)

# Load a pytorch model and reuse the existing server through config.
model = torchvision.models.mobilenet_v2().eval()
inputs = (torch.rand([1, 3, 224, 224]),)
ep = torch.export.export(model, inputs)
config = model_explorer.config()
config.add_model_from_pytorch("test pytorch", ep).set_reuse_server()
model_explorer.visualize_from_config(config)
time.sleep(2) # Delay for the animation

take_and_compare_screenshot(page, "reuse_server_pytorch_from_config.png")


def test_reuse_server_two_pytorch_models(page: Page):
# Serialize a pytorch model (mobilenet v2).
model = torchvision.models.mobilenet_v2().eval()
inputs = (torch.rand([1, 3, 224, 224]),)
ep = torch.export.export(model, inputs)
pt2_file_path = tempfile.NamedTemporaryFile(suffix=".pt2")
torch.export.save(ep, pt2_file_path.name)

# Load it into ME.
page.goto(LOCAL_SERVER)
page.get_by_placeholder("Absolute file paths (").fill(pt2_file_path.name)
page.get_by_role("button", name="Add").click()
delay_view_model(page)

# Load another pytorch model (mobilenet v3) and reuse the existing server.
model2 = torchvision.models.mobilenet_v3_small().eval()
inputs2 = (torch.rand([1, 3, 224, 224]),)
ep2 = torch.export.export(model2, inputs2)
model_explorer.visualize_pytorch(
name="test pytorch", exported_program=ep2, reuse_server=True
)
time.sleep(2) # Delay for the animation

# The screenshot should show "V3" in the center node.
take_and_compare_screenshot(page, "reuse_server_two_pytorch_models.png")
13 changes: 13 additions & 0 deletions src/server/package/src/model_explorer/apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ def visualize_pytorch(
node_data: Union[NodeDataInfo, list[NodeDataInfo]] = [],
colab_height=DEFAULT_COLAB_HEIGHT,
settings=DEFAULT_SETTINGS,
reuse_server: bool = False,
reuse_server_host: str = DEFAULT_HOST,
reuse_server_port: Union[int, None] = None,
) -> None:
"""Visualizes a pytorch model.
Expand All @@ -130,6 +133,11 @@ def visualize_pytorch(
node_data: The node data or a list of node data to display.
colab_height: The height of the embedded iFrame when running in colab.
settings: The settings that config the visualization.
reuse_server: Whether to reuse the current server/browser tab(s) to
visualize.
reuse_server_host: the host of the server to reuse. Default to localhost.
reuse_server_port: the port of the server to reuse. If unspecified, it will
try to find a running server from port 8080 to 8099.
"""
# Construct config.
cur_config = config()
Expand All @@ -139,6 +147,11 @@ def visualize_pytorch(

_add_node_data_to_config(node_data=node_data, config=cur_config)

if reuse_server:
cur_config.set_reuse_server(
server_host=reuse_server_host, server_port=reuse_server_port
)

# Start server.
server.start(
host=host,
Expand Down
18 changes: 16 additions & 2 deletions src/server/package/src/model_explorer/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from .consts import DEFAULT_HOST, DEFAULT_SETTINGS
from .node_data_builder import NodeData
from .types import ModelExplorerGraphs
from .utils import convert_adapter_response

if torch is not None:
from .pytorch_exported_program_adater_impl import PytorchExportedProgramAdapterImpl
Expand All @@ -52,7 +53,8 @@ class ModelExplorerConfig:

def __init__(self) -> None:
self.model_sources: list[ModelSource] = []
self.graphs_list: list[ModelExplorerGraphs] = []
# Array of ModelExplorerGraphs or json string.
self.graphs_list: list[Union[ModelExplorerGraphs, str]] = []
self.node_data_sources: list[str] = []
# List of model names to apply node data to. For the meaning of
# "model name", see comments in `add_node_data_from_path` method below.
Expand Down Expand Up @@ -220,7 +222,9 @@ def to_url_param_value(self) -> str:
# Return its json string.
return quote(json.dumps(encoded_url_data))

def get_model_explorer_graphs(self, index: int) -> ModelExplorerGraphs:
def get_model_explorer_graphs(
self, index: int
) -> Union[ModelExplorerGraphs, str]:
return self.graphs_list[index]

def get_node_data(self, index: int) -> Union[NodeData, str]:
Expand All @@ -230,13 +234,23 @@ def has_data_to_encode_in_url(self) -> bool:
return len(self.model_sources) > 0 or len(self.node_data_sources) > 0

def get_transferrable_data(self) -> Dict:
# Convert the graphs list to a list of strings.
graphs_list = []
for graph in self.graphs_list:
if isinstance(graph, str):
graphs_list.append(graph)
else:
graphs_list.append(json.dumps(convert_adapter_response(graph)))

return {
'graphs_list': json.dumps(graphs_list),
'model_sources': self.model_sources,
'node_data_sources': self.node_data_sources,
'node_data_target_models': self.node_data_target_models,
}

def set_transferrable_data(self, data: Dict):
self.graphs_list = json.loads(data['graphs_list'])
self.model_sources = data['model_sources']
self.node_data_sources = data['node_data_sources']
self.node_data_target_models = data['node_data_target_models']
Expand Down
13 changes: 8 additions & 5 deletions src/server/package/src/model_explorer/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,8 +73,10 @@ def _get_latest_version_from_repo(package_json_url: str) -> str:

def _get_release_from_github(version: str) -> dict:
# Get release data through github API.
api_url_base = 'https://api.github.com/repos'
repo_name = 'google-ai-edge/model-explorer'
req = requests.get(
f'https://api.github.com/repos/google-ai-edge/model-explorer/releases/tags/model-explorer-v{version}'
f'{api_url_base}/{repo_name}/releases/tags/model-explorer-v{version}'
)
req_json = json.loads(req.text.encode('utf-8'))

Expand Down Expand Up @@ -318,9 +320,11 @@ def load_graphs_json():
if graph_index_str is None:
return {}
graph_index = int(graph_index_str)
return _make_json_response(
convert_adapter_response(config.get_model_explorer_graphs(graph_index))
)
graphs = config.get_model_explorer_graphs(graph_index)
if isinstance(graphs, str):
return _make_json_response(json.loads(graphs))
else:
return _make_json_response(convert_adapter_response(graphs))

@app.route('/api/v1/load_node_data')
def load_node_data():
Expand Down Expand Up @@ -358,7 +362,6 @@ def check_health():

@app.route('/apipost/v1/update_config', methods=['POST'])
def update_config():
# TODO(do not submit): Update confnig.
config_data = request.json
if config and config_data:
config.set_transferrable_data(config_data)
Expand Down
1 change: 1 addition & 0 deletions src/server/scripts/setup_local_dev.sh
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,5 @@
# ==============================================================================

cd package
pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
pip install -e .
Binary file added test/screenshots_golden/chrome-linux/pytorch.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 1e5e178

Please sign in to comment.