Skip to content

Commit

Permalink
Merge branch 'develop' into add_fno_uniform
Browse files Browse the repository at this point in the history
  • Loading branch information
umangyadav authored Sep 15, 2023
2 parents fa8dc98 + 74ba964 commit 63d1420
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion tools/test_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@ def parse_args():
type=str,
default='gpu',
help='Specify where the tests execute (ref, gpu)')
parser.add_argument('--fp16', action='store_true', help='Quantize to fp16')
parser.add_argument('--atol',
type=float,
default=1e-3,
help='The absolute tolerance parameter')
parser.add_argument('--rtol',
type=float,
default=1e-3,
help='The relative tolerance parameter')
args = parser.parse_args()

return args
Expand Down Expand Up @@ -257,6 +266,8 @@ def main():

# read and compile model
model = migraphx.parse_onnx(model_path_name, map_input_dims=param_shapes)
if args.fp16:
migraphx.quantize_fp16(model)
model.compile(migraphx.get_target(target))

# get test cases
Expand All @@ -279,7 +290,10 @@ def main():
output_data = run_one_case(model, input_data)

# check output correctness
ret = check_correctness(gold_outputs, output_data)
ret = check_correctness(gold_outputs,
output_data,
atol=args.atol,
rtol=args.rtol)
if ret:
correct_num += 1

Expand Down

0 comments on commit 63d1420

Please sign in to comment.