Skip to content

Commit

Permalink
[TorchFX] Performance check
Browse files Browse the repository at this point in the history
  • Loading branch information
daniil-lyakhov committed Aug 26, 2024
1 parent 1104f1b commit 7b94e7e
Show file tree
Hide file tree
Showing 8 changed files with 489 additions and 0 deletions.
47 changes: 47 additions & 0 deletions tests/torch/fx/performance_check/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Torch compile with OpenVino backend performance check

The [main.py](main.py) script checks fp32 and int8 models performance in two setups:

* Compilation via `torch.compile(model, backend="openvino")`
* Export to OpenVino via `torch.export.export` + `ov.convert` functions

## Installation

```bash
# From the root of NNCF repo:
make install-torch-test
pip install -r tests/torch/fx/performance_check/requirements.txt
```

## Usage

Run performance check for all models:

```bash
python main.py
```

Run performance check for a specific model:

```bash
python main.py --model model_name
```

Run performance check for a specific model and save performance check result to a specific location:

```bash
python main.py --model model_name --file_name /path/to/save/resuts.csv
```

Names of the available models could be found in [model_scope.py](model_scope.py) as keys of the `MODEL_SCOPE` dict.
Performance check results are saved to a `result.csv` file by default.

## Artefacts

You will find directories named after the models in current directory. In case errors were not occured during the preformance check, each directory should contain:

* `int8_code.py` - code of the quantized torch.fx.GrpahModule model

* `int8_nncf_graph.dot` - nncf graph visualization of the quantized torch.fx.GrpahModule model

* `result.csv` - results of the performance check the current model.
10 changes: 10 additions & 0 deletions tests/torch/fx/performance_check/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
212 changes: 212 additions & 0 deletions tests/torch/fx/performance_check/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import re
import subprocess
import traceback
import warnings
from pathlib import Path
from time import time

import openvino as ov
import openvino.torch # noqa
import pandas as pd
import torch
from torch._export import capture_pre_autograd_graph
from torch.fx.passes.graph_drawer import FxGraphDrawer
from torch.jit import TracerWarning

import nncf
from nncf.common.factory import NNCFGraphFactory
from nncf.torch.dynamic_graph.patch_pytorch import disable_patching
from tests.torch.fx.performance_check.model_scope import MODEL_SCOPE

warnings.filterwarnings("ignore", category=TracerWarning)
warnings.filterwarnings("ignore", category=UserWarning)

VISUALIZE_FX_INT8_GRAPH = False


def measure_time(model, example_inputs, num_iters=500):
with torch.no_grad():
model(*example_inputs)
total_time = 0
for _ in range(num_iters):
start_time = time()
model(*example_inputs)
total_time += time() - start_time
average_time = (total_time / num_iters) * 1000
return average_time


def measure_time_ov(model, example_inputs, num_iters=500):
ie = ov.Core()
compiled_model = ie.compile_model(model, "CPU")
infer_request = compiled_model.create_infer_request()
infer_request.infer(example_inputs)
total_time = 0
for _ in range(num_iters):
start_time = time()
infer_request.infer(example_inputs)
total_time += time() - start_time
average_time = (total_time / num_iters) * 1000
return average_time


def benchmark_performance(model_path, input_shape) -> float:
command = f"benchmark_app -m {model_path} -d CPU -api async -t 30"
command += f' -shape "[{",".join(str(s) for s in input_shape)}]"'
cmd_output = subprocess.check_output(command, shell=True) # nosec

match = re.search(r"Throughput\: (.+?) FPS", str(cmd_output))
return float(match.group(1))


def process_model(model_name: str):
result = {"name": model_name}
model_config = MODEL_SCOPE[model_name]
pt_model = model_config.model_builder.build()
example_inputs = model_config.model_builder.get_example_inputs()
export_inputs = example_inputs[0] if isinstance(example_inputs[0], tuple) else example_inputs
input_sizes = model_config.model_builder.get_input_sizes()
save_dir = Path(__file__).parent.resolve() / model_name
save_dir.mkdir(exist_ok=True)

with disable_patching():
latency_fp32 = measure_time(torch.compile(pt_model, backend="openvino"), export_inputs, model_config.num_iters)
result["fp32_compile_latency"] = latency_fp32
print(f"fp32 compiled model latency: {latency_fp32}")

try:
with disable_patching():
with torch.no_grad():
ex_model = torch.export.export(pt_model, export_inputs)
ov_model = ov.convert_model(ex_model, example_input=example_inputs[0], input=input_sizes)
ov_model_path = save_dir / "openvino_model.xml"
ov.serialize(ov_model, ov_model_path)
latency_fp32_ov = measure_time_ov(ov_model, example_inputs, model_config.num_iters)
fps_fp32_ov = benchmark_performance(ov_model_path, input_sizes)
except Exception as e:
print("FAILS TO EXPORT FP32 MODEL TO OPENVINO:")
err_msg = str(e)
print(err_msg)
traceback.print_exc()
latency_fp32_ov = -1
fps_fp32_ov = -1

result["fp32_ov_latency"] = latency_fp32_ov
result["fp32_ov_benchmark_fps"] = fps_fp32_ov
print(f"fp32 ov model latency: {latency_fp32_ov}")
print(f"fp32 ov model benchmark fps: {fps_fp32_ov}")

with disable_patching():
with torch.no_grad():
exported_model = capture_pre_autograd_graph(pt_model, export_inputs)

with disable_patching():
with torch.no_grad():
quant_fx_model = nncf.quantize(
exported_model,
nncf.Dataset(example_inputs),
**model_config.quantization_params,
)

int8_graph_visualization_path = str(save_dir / "int8_nncf_graph.dot")
NNCFGraphFactory.create(quant_fx_model).visualize_graph(int8_graph_visualization_path)
print(f"NNCFGraph visualization of int8 model is saved to {int8_graph_visualization_path}")

int8_code_path = str(save_dir / "int8_code.py")
with open(int8_code_path, "w") as f:
f.write(quant_fx_model.code)
print(f"int8 FX code is saved to {int8_code_path}")

if VISUALIZE_FX_INT8_GRAPH:
int8_model_visualization_path = str(save_dir / "int8_fx_graph.svg")
g = FxGraphDrawer(quant_fx_model, int8_model_visualization_path)
g.get_dot_graph().write_svg(int8_model_visualization_path)
print(f"Visualization of int8 model is saved to {int8_model_visualization_path}")

quant_fx_model = torch.compile(quant_fx_model, backend="openvino")

with disable_patching():
latency_int8 = measure_time(quant_fx_model, export_inputs, model_config.num_iters)
result["int8_compiled_latency"] = latency_int8
print(f"int8 compiled model latency: {latency_int8}")

try:
with disable_patching():
with torch.no_grad():
ex_int8_model = torch.export.export(quant_fx_model, export_inputs)
ov_int8_model = ov.convert_model(ex_int8_model, example_input=example_inputs[0], input=input_sizes)
ov_int8_model_path = save_dir / "openvino_model_int8.xml"
ov.serialize(ov_int8_model, ov_int8_model_path)

latency_int8_ov = measure_time_ov(ov_int8_model, export_inputs, model_config.num_iters)
fps_int8_ov = benchmark_performance(ov_int8_model_path, input_sizes)
except Exception as e:
print("FAILS TO EXPORT INT8 MODEL TO OPENVINO:")
err_msg = str(e)
print(err_msg)
traceback.print_exc()
latency_int8_ov = -1
fps_int8_ov = -1

result["int8_ov_latency"] = latency_int8_ov
result["int8_ov_benchmark_fps"] = fps_int8_ov
print(f"int8 ov model latency: {latency_int8_ov}")
print(f"int8 ov model benchmark fps: {fps_int8_ov}")
print("*" * 100)
print(f"Torch compile latency speed up: {latency_fp32 / latency_int8}")
print(f"Torch export + openvino latenyc speed up: {latency_fp32_ov / latency_int8_ov}")
print(f"Openvino FPS benchmark speed up: {fps_int8_ov / fps_fp32_ov}")
print("*" * 100)

result["compile_latency_diff_speedup"] = latency_fp32 / latency_int8
result["ov_latency_diff_speedup"] = latency_fp32_ov / latency_int8_ov
result["ov_benchmark_fps_speedup"] = fps_int8_ov / fps_fp32_ov
pd.DataFrame([result]).to_csv(save_dir / "result.csv")
return result


def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model", help="Target model name", type=str, default="all")
parser.add_argument("--file_name", help="Output csv file_name", type=str, default="result.csv")

args = parser.parse_args()

target_models = []
if args.model == "all":
for model_name in MODEL_SCOPE:
target_models.append(model_name)
else:
target_models.append(args.model)

results_list = []
for model_name in target_models:
print("---------------------------------------------------")
print(f"name: {model_name}")
try:
results_list.append(process_model(model_name))
except Exception as e:
print(f"FAILS TO CHECK PERFORMANCE FOR {model_name} MODEL:")
err_msg = str(e)
print(err_msg)
traceback.print_exc()

df = pd.DataFrame(results_list)
print(df)
df.to_csv(args.file_name)


if __name__ == "__main__":
main()
28 changes: 28 additions & 0 deletions tests/torch/fx/performance_check/model_builders/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import abstractclassmethod

import torch


class BaseModelBuilder:
@abstractclassmethod
def build(self) -> torch.nn.Module:
pass

@abstractclassmethod
def get_example_inputs(self) -> torch.Tensor:
pass

@abstractclassmethod
def get_input_sizes(self):
pass
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from diffusers import StableDiffusionPipeline

from tests.torch.fx.performance_check.model_builders.base import BaseModelBuilder


class StableDiffusion2UnetBuilder(BaseModelBuilder):
def __init__(self):
latents_shape = (2, 4, 96, 96)
encoder_hidden_state_shape = (2, 77, 1024)
time_shape = ()
self._input_sizes = (latents_shape, time_shape, encoder_hidden_state_shape)
self._example_input = tuple([torch.ones(shape) for shape in self._input_sizes])

def build(self):
pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
pipe = pipe.to("cpu")
return pipe.unet.eval()

def get_example_inputs(self) -> torch.Tensor:
return (self._example_input,)

def get_input_sizes(self):
return self._input_sizes
33 changes: 33 additions & 0 deletions tests/torch/fx/performance_check/model_builders/torchvision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) 2024 Intel Corporation
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from torchvision import models

from tests.torch.fx.performance_check.model_builders.base import BaseModelBuilder


class TorchvisionModelBuilder(BaseModelBuilder):
INPUT_SHAPE = (1, 3, 224, 224)

def __init__(self, model_cls: str, model_weights: models.WeightsEnum):
self._model_cls = model_cls
self._model_weights = model_weights
self._example_input = self._model_weights.transforms()(torch.ones(self.INPUT_SHAPE))

def build(self):
return self._model_cls(weights=self._model_weights).eval()

def get_example_inputs(self) -> torch.Tensor:
return (self._example_input,)

def get_input_sizes(self):
return tuple(self._example_input.shape)
Loading

0 comments on commit 7b94e7e

Please sign in to comment.