Torch native quantization performance check
daniil-lyakhov committed Nov 18, 2024
1 parent a9437cb commit ad294a1
Showing 2 changed files with 43 additions and 2 deletions.
43 changes: 42 additions & 1 deletion tests/torch/fx/performance_check/main.py
@@ -9,6 +9,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

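# Enable TorchInductor freezing (model weights are folded into constants at compile time);
# set before any torch import so the Inductor config picks it up.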
os.environ["TORCHINDUCTOR_FREEZING"] = "1"

import argparse
import re
import subprocess
@@ -222,6 +226,43 @@ def process_model(model_name: str):
    return result


def process_native_torch_int8(model_name: str):
    result = {"name": model_name}
    model_config = MODEL_SCOPE[model_name]
    pt_model = model_config.model_builder.build()
    example_inputs = model_config.model_builder.get_example_inputs()
    export_inputs = example_inputs[0] if isinstance(example_inputs[0], tuple) else example_inputs

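    # Baseline: measure the latency of the FP32 model under torch.compile.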
    with disable_patching():
        latency_fp32 = measure_time(torch.compile(pt_model), export_inputs, model_config.num_iters)
    result["fp32_compile_native_latency"] = latency_fp32
    print(f"fp32 compiled model native latency: {latency_fp32}")

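    # Quantize via the PT2E flow with the X86 Inductor quantizer and its default INT8 config.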
    from torch.ao.quantization.quantize_pt2e import convert_pt2e
    from torch.ao.quantization.quantize_pt2e import prepare_pt2e
    from torch.ao.quantization.quantizer.x86_inductor_quantizer import X86InductorQuantizer
    from torch.ao.quantization.quantizer.x86_inductor_quantizer import get_default_x86_inductor_quantization_config

    quantizer = X86InductorQuantizer()
    quantizer.set_global(get_default_x86_inductor_quantization_config())

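    # Export the model to an FX graph, as required by the PT2E quantization API.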
    with disable_patching():
        with torch.no_grad():
            exported_model = capture_pre_autograd_graph(pt_model, export_inputs)

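    # Insert observers, calibrate with a single forward pass, and convert to a quantized model.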
    with disable_patching():
        prepared_model = prepare_pt2e(exported_model, quantizer)
        prepared_model(*export_inputs)
        compressed_model = convert_pt2e(prepared_model)

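    # Measure the latency of the INT8 model compiled with Inductor.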
    with disable_patching():
        latency_int8 = measure_time(torch.compile(compressed_model), export_inputs, model_config.num_iters)

    result["int8_compile_native_latency"] = latency_int8
    print(f"int8 compiled model native latency: {latency_int8}")
    return result


def process_model_native_to_ov(model_name: str):
    result = {"name": model_name}
    model_config = MODEL_SCOPE[model_name]
@@ -292,7 +333,7 @@ def main():
        print("---------------------------------------------------")
        print(f"name: {model_name}")
        try:
-            results_list.append({**process_model_native_to_ov(model_name)})
+            results_list.append(process_native_torch_int8(model_name))
        except Exception as e:
            print(f"FAILS TO CHECK PERFORMANCE FOR {model_name} MODEL:")
            err_msg = str(e)
2 changes: 1 addition & 1 deletion tests/torch/fx/performance_check/model_scope.py
@@ -29,7 +29,7 @@
class ModelConfig:
    model_builder: BaseModelBuilder
    quantization_params: Dict[str, Any]
-    num_iters: int = 5000
+    num_iters: int = 1000


MODEL_SCOPE = {
