diff --git a/src/server/package/src/model_explorer/pytorch_exported_program_adater_impl.py b/src/server/package/src/model_explorer/pytorch_exported_program_adater_impl.py index 7e7f05d6..9ff7b7c9 100644 --- a/src/server/package/src/model_explorer/pytorch_exported_program_adater_impl.py +++ b/src/server/package/src/model_explorer/pytorch_exported_program_adater_impl.py @@ -157,10 +157,14 @@ def print_tensor(self, tensor: torch.Tensor, size_limit: int = 16): total_size *= dim if size_limit < 0 or size_limit >= total_size: - return json.dumps(tensor.cpu().detach().to(torch.float32).numpy().tolist()) + return json.dumps( + tensor.cpu().detach().to(torch.float32).numpy().tolist() + ) return json.dumps( - (tensor.cpu().detach().to(torch.float32).numpy().flatten())[:size_limit].tolist() + (tensor.cpu().detach().to(torch.float32).numpy().flatten())[ + :size_limit + ].tolist() ) def add_node_attrs(self, fx_node: torch.fx.node.Node, node: GraphNode): @@ -204,7 +208,14 @@ def add_outputs_metadata(self, fx_node: torch.fx.node.Node, node: GraphNode): node.outputsMetadata.append(metadata) elif isinstance(out_vals, torch.Tensor): dtype = str(out_vals.dtype) - shape = json.dumps(list(map(lambda x: int(x) if str(x).isdigit() else str(x), out_vals.shape))) + shape = json.dumps( + list( + map( + lambda x: int(x) if str(x).isdigit() else str(x), + out_vals.shape, + ) + ) + ) metadata = MetadataItem( id='0', attrs=[KeyValue(key='tensor_shape', value=dtype + shape)] )