[CPU] Fix ARM precision issue
ARM64 ACL prefers fp16 over fp32, while API 2.0 requires that input/output precisions stay unchanged,
so an fp32 input triggers insertion of a Convert node that converts fp32 to fp16.
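
For context, here is a minimal sketch (not part of this commit) of how the described behavior can be observed from the Python API, assuming an OpenVINO build on an ARM64 macOS host; the "PERF_COUNT" config key, the opset8 choice, and the "demo" model name are assumptions, not taken from this change:

import numpy as np
import openvino.runtime as ov
import openvino.runtime.opset8 as ops

core = ov.Core()
param = ops.parameter([1, 3, 32, 32], np.float32, name="data")
softmax = ops.softmax(param, 1, name="fc_out")
model = ov.Model([softmax], [param], "demo")

# Enable performance counters so executed nodes show up in profiling info.
compiled = core.compile_model(model, "CPU", {"PERF_COUNT": "YES"})
request = compiled.create_infer_request()
request.infer({"data": np.random.rand(1, 3, 32, 32).astype(np.float32)})

# On ARM64 the executed graph may start with an extra Convert node
# (fp32 -> fp16) in front of the fp16 Softmax kernel.
for info in request.profiling_info:
    print(info.node_type, info.node_name)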
riverlijunjie committed Nov 3, 2023
1 parent c243864 commit 9aa8054
Showing 3 changed files with 15 additions and 5 deletions.
3 changes: 2 additions & 1 deletion src/bindings/python/tests/test_runtime/test_core.py
@@ -65,7 +65,8 @@ def test_core_class(device):

     input_tensor = Tensor(input_data)
     results = request.infer({"data": input_tensor})
-    assert np.allclose(results[list(results)[0]], expected_output)
+    # A Convert node may be introduced by API 2.0, which brings some deviation
+    assert np.allclose(results[list(results)[0]], expected_output, 1e-4, 1e-4)


 # request - https://docs.pytest.org/en/7.1.x/reference/reference.html#request
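A side note on the relaxed comparison above: the positional 1e-4, 1e-4 arguments are np.allclose's rtol and atol. A self-contained illustration (the 0.1 value is hypothetical, chosen so the fp16 rounding error fits inside the tolerance):

import numpy as np

expected = np.float32(0.1)
actual = np.float32(np.float16(expected))  # value that round-tripped through fp16
# The test's positional 1e-4, 1e-4 correspond to rtol and atol:
assert np.allclose(actual, expected, rtol=1e-4, atol=1e-4)
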
8 changes: 7 additions & 1 deletion src/bindings/python/tests/test_runtime/test_infer_request.py
@@ -1,6 +1,7 @@
 # -*- coding: utf-8 -*-
 # Copyright (C) 2018-2023 Intel Corporation
 # SPDX-License-Identifier: Apache-2.0
+import platform

 from collections.abc import Iterable
 from copy import deepcopy
@@ -97,7 +98,12 @@ def abs_model_with_data(device, ov_type, numpy_dtype):

 def test_get_profiling_info(device):
     core = Core()
-    param = ops.parameter([1, 3, 32, 32], np.float32, name="data")
+    if platform.system() == "Darwin" and platform.machine() == "arm64":
+        # arm64 prefers fp16, and an fp32 input would trigger a Convert node
+        # to be added, making the 'Softmax' assertion fail.
+        param = ops.parameter([1, 3, 32, 32], np.float16, name="data")
+    else:
+        param = ops.parameter([1, 3, 32, 32], np.float32, name="data")
     softmax = ops.softmax(param, 1, name="fc_out")
     model = Model([softmax], [param], "test_model")

@@ -294,9 +294,12 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
 };
 type_to_fuse_map empty_fuse_map = {};
 const bool keep_precision_sensitive_in_fp32 = true;
-CPU_REGISTER_PASS_COMMON(manager, ov::pass::ConvertPrecision, fp_convert_precision_map,
-                         empty_fuse_map,
-                         keep_precision_sensitive_in_fp32);
+CPU_REGISTER_PASS_COMMON(manager,
+                         ov::pass::ConvertPrecision,
+                         fp_convert_precision_map,
+                         empty_fuse_map,
+                         keep_precision_sensitive_in_fp32,
+                         false);
 }
 #endif
 CPU_REGISTER_PASS_COMMON(manager, ov::pass::KeepConstAndDecompression);
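The appended "false" argument presumably maps to ConvertPrecision's convert_input_output_precision parameter, so internal computation may be lowered to fp16 while model inputs and outputs keep their original element types, as API 2.0 requires; the fp32-to-fp16 Convert node then appears inside the executed graph. A minimal Python check of that contract (same assumptions as the sketch above):

import numpy as np
import openvino.runtime as ov
import openvino.runtime.opset8 as ops

core = ov.Core()
param = ops.parameter([1, 3, 32, 32], np.float32, name="data")
model = ov.Model([ops.softmax(param, 1, name="fc_out")], [param], "demo")
compiled = core.compile_model(model, "CPU")

# API 2.0 contract: the compiled model keeps the original fp32 input/output
# element types; any fp32 -> fp16 conversion happens inside the graph.
assert compiled.input(0).get_element_type() == ov.Type.f32
assert compiled.output(0).get_element_type() == ov.Type.f32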
