From 33492243a7dc1f51ce1a991b5a556aca3a4a4f2b Mon Sep 17 00:00:00 2001 From: Yuanzhe Dong Date: Tue, 12 Nov 2024 12:20:05 -0800 Subject: [PATCH 1/2] Support bf16 weight and dynamic shape --- .../model_explorer/pytorch_exported_program_adater_impl.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/server/package/src/model_explorer/pytorch_exported_program_adater_impl.py b/src/server/package/src/model_explorer/pytorch_exported_program_adater_impl.py index aba77aed..7e7f05d6 100644 --- a/src/server/package/src/model_explorer/pytorch_exported_program_adater_impl.py +++ b/src/server/package/src/model_explorer/pytorch_exported_program_adater_impl.py @@ -157,10 +157,10 @@ def print_tensor(self, tensor: torch.Tensor, size_limit: int = 16): total_size *= dim if size_limit < 0 or size_limit >= total_size: - return json.dumps(tensor.cpu().detach().numpy().tolist()) + return json.dumps(tensor.cpu().detach().to(torch.float32).numpy().tolist()) return json.dumps( - (tensor.cpu().detach().numpy().flatten())[:size_limit].tolist() + (tensor.cpu().detach().to(torch.float32).numpy().flatten())[:size_limit].tolist() ) def add_node_attrs(self, fx_node: torch.fx.node.Node, node: GraphNode): @@ -204,7 +204,7 @@ def add_outputs_metadata(self, fx_node: torch.fx.node.Node, node: GraphNode): node.outputsMetadata.append(metadata) elif isinstance(out_vals, torch.Tensor): dtype = str(out_vals.dtype) - shape = json.dumps(out_vals.shape) + shape = json.dumps(list(map(lambda x: int(x) if str(x).isdigit() else str(x), out_vals.shape))) metadata = MetadataItem( id='0', attrs=[KeyValue(key='tensor_shape', value=dtype + shape)] ) From f55736bdf827bbbf98e4ae47d97e9fd39a419351 Mon Sep 17 00:00:00 2001 From: Yuanzhe Dong Date: Tue, 12 Nov 2024 14:15:34 -0800 Subject: [PATCH 2/2] apply pyink formatter --- .../pytorch_exported_program_adater_impl.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/server/package/src/model_explorer/pytorch_exported_program_adater_impl.py b/src/server/package/src/model_explorer/pytorch_exported_program_adater_impl.py index 7e7f05d6..9ff7b7c9 100644 --- a/src/server/package/src/model_explorer/pytorch_exported_program_adater_impl.py +++ b/src/server/package/src/model_explorer/pytorch_exported_program_adater_impl.py @@ -157,10 +157,14 @@ def print_tensor(self, tensor: torch.Tensor, size_limit: int = 16): total_size *= dim if size_limit < 0 or size_limit >= total_size: - return json.dumps(tensor.cpu().detach().to(torch.float32).numpy().tolist()) + return json.dumps( + tensor.cpu().detach().to(torch.float32).numpy().tolist() + ) return json.dumps( - (tensor.cpu().detach().to(torch.float32).numpy().flatten())[:size_limit].tolist() + (tensor.cpu().detach().to(torch.float32).numpy().flatten())[ + :size_limit + ].tolist() ) def add_node_attrs(self, fx_node: torch.fx.node.Node, node: GraphNode): @@ -204,7 +208,14 @@ def add_outputs_metadata(self, fx_node: torch.fx.node.Node, node: GraphNode): node.outputsMetadata.append(metadata) elif isinstance(out_vals, torch.Tensor): dtype = str(out_vals.dtype) - shape = json.dumps(list(map(lambda x: int(x) if str(x).isdigit() else str(x), out_vals.shape))) + shape = json.dumps( + list( + map( + lambda x: int(x) if str(x).isdigit() else str(x), + out_vals.shape, + ) + ) + ) metadata = MetadataItem( id='0', attrs=[KeyValue(key='tensor_shape', value=dtype + shape)] )