Skip to content

Commit

Permalink
apply pyink formatter
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanzhedong committed Nov 12, 2024
1 parent 3349224 commit f55736b
Showing 1 changed file with 14 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -157,10 +157,14 @@ def print_tensor(self, tensor: torch.Tensor, size_limit: int = 16):
total_size *= dim

if size_limit < 0 or size_limit >= total_size:
return json.dumps(tensor.cpu().detach().to(torch.float32).numpy().tolist())
return json.dumps(
tensor.cpu().detach().to(torch.float32).numpy().tolist()
)

return json.dumps(
(tensor.cpu().detach().to(torch.float32).numpy().flatten())[:size_limit].tolist()
(tensor.cpu().detach().to(torch.float32).numpy().flatten())[
:size_limit
].tolist()
)

def add_node_attrs(self, fx_node: torch.fx.node.Node, node: GraphNode):
Expand Down Expand Up @@ -204,7 +208,14 @@ def add_outputs_metadata(self, fx_node: torch.fx.node.Node, node: GraphNode):
node.outputsMetadata.append(metadata)
elif isinstance(out_vals, torch.Tensor):
dtype = str(out_vals.dtype)
shape = json.dumps(list(map(lambda x: int(x) if str(x).isdigit() else str(x), out_vals.shape)))
shape = json.dumps(
list(
map(
lambda x: int(x) if str(x).isdigit() else str(x),
out_vals.shape,
)
)
)
metadata = MetadataItem(
id='0', attrs=[KeyValue(key='tensor_shape', value=dtype + shape)]
)
Expand Down

0 comments on commit f55736b

Please sign in to comment.