From a8d3149f010830b709eb5cab3c97b18398afb404 Mon Sep 17 00:00:00 2001 From: Sean Moriarity Date: Mon, 21 Oct 2024 10:28:50 -0400 Subject: [PATCH] Fix display --- lib/axon/display.ex | 81 +++++++++++++++++++++++++++++---------------- 1 file changed, 53 insertions(+), 28 deletions(-) diff --git a/lib/axon/display.ex b/lib/axon/display.ex index be2aecfb..11eb4f62 100644 --- a/lib/axon/display.ex +++ b/lib/axon/display.ex @@ -56,9 +56,24 @@ defmodule Axon.Display do vertical_symbol: "|" ) |> then(&(&1 <> "Total Parameters: #{model_info.num_params}\n")) - |> then(&(&1 <> "Total Parameters Memory: #{model_info.total_param_byte_size} bytes\n")) + |> then( + &(&1 <> "Total Parameters Memory: #{readable_size(model_info.total_param_byte_size)}\n") + ) end + defp readable_size(n) when n < 1_000, do: "#{n} bytes" + + defp readable_size(n) when n >= 1_000 and n < 1_000_000, + do: "#{float_format(n / 1_000)} kilobytes" + + defp readable_size(n) when n >= 1_000_000 and n < 1_000_000_000, + do: "#{float_format(n / 1_000_000)} megabytes" + + defp readable_size(n) when n >= 1_000_000_000 and n < 1_000_000_000_000, + do: "#{float_format(n / 1_000_000_000)} gigabytes" + + defp float_format(value), do: :io_lib.format("~.2f", [value]) + defp assert_table_rex!(fn_name) do unless Code.ensure_loaded?(TableRex) do raise RuntimeError, """ @@ -93,7 +108,6 @@ defmodule Axon.Display do defp do_axon_to_rows( %Axon.Node{ id: id, - op: structure, op_name: :container, parent: [parents], name: name_fn @@ -104,7 +118,7 @@ defmodule Axon.Display do op_counts, model_info ) do - {input_names, {cache, op_counts, model_info}} = + {_, {cache, op_counts, model_info}} = Enum.map_reduce(parents, {cache, op_counts, model_info}, fn parent_id, {cache, op_counts, model_info} -> {_, name, _shape, cache, op_counts, model_info} = @@ -119,11 +133,11 @@ defmodule Axon.Display do shape = Axon.get_output_shape(%Axon{output: id, nodes: nodes}, templates) row = [ - "#{name} ( #{op_string} #{inspect(apply(structure, input_names))} )", + "#{name} ( #{op_string} )", "#{inspect({})}", - "#{inspect(shape)}", + render_output_shape(shape), render_options([]), - render_parameters(%{}, []) + render_parameters(nil, %{}, []) ] {row, name, shape, cache, op_counts, model_info} @@ -136,7 +150,7 @@ defmodule Axon.Display do parameters: params, name: name_fn, opts: opts, - policy: %{params: {_, bitsize}}, + policy: %{params: params_policy}, op_name: op_name }, nodes, @@ -145,6 +159,12 @@ defmodule Axon.Display do op_counts, model_info ) do + bitsize = + case params_policy do + nil -> 32 + {_, bitsize} -> bitsize + end + {input_names_and_shapes, {cache, op_counts, model_info}} = Enum.map_reduce(parents, {cache, op_counts, model_info}, fn parent_id, {cache, op_counts, model_info} -> @@ -154,39 +174,34 @@ defmodule Axon.Display do {{name, shape}, {cache, op_counts, model_info}} end) - {input_names, input_shapes} = Enum.unzip(input_names_and_shapes) + {_, input_shapes} = Enum.unzip(input_names_and_shapes) + + inputs = + Map.new(input_names_and_shapes, fn {name, shape} -> + {name, render_output_shape(shape)} + end) num_params = Enum.reduce(params, 0, fn %Parameter{shape: {:tuple, shapes}}, acc -> Enum.reduce(shapes, acc, &(Nx.size(apply(&1, input_shapes)) + &2)) - %Parameter{shape: shape_fn}, acc -> + %Parameter{template: shape_fn}, acc when is_function(shape_fn) -> acc + Nx.size(apply(shape_fn, input_shapes)) end) param_byte_size = num_params * div(bitsize, 8) op_inspect = Atom.to_string(op_name) - - inputs = - case input_names do - [] -> - "" - - [_ | _] = input_names -> - "#{inspect(input_names)}" - end - name = name_fn.(op_name, op_counts) shape = Axon.get_output_shape(%Axon{output: id, nodes: nodes}, templates) row = [ - "#{name} ( #{op_inspect}#{inputs} )", - "#{inspect(input_shapes)}", - "#{inspect(shape)}", + "#{name} ( #{op_inspect} )", + "#{inspect(inputs)}", + render_output_shape(shape), render_options(opts), - render_parameters(params, input_shapes) + render_parameters(params_policy, params, input_shapes) ] model_info = @@ -200,6 +215,14 @@ defmodule Axon.Display do {row, name, shape, cache, op_counts, model_info} end + defp render_output_shape(%Nx.Tensor{} = template) do + type = type_str(Nx.type(template)) + shape = shape_string(Nx.shape(template)) + "#{type}#{shape}" + end + + defp type_str({type, size}), do: "#{Atom.to_string(type)}#{size}" + defp render_options(opts) do opts |> Enum.map(fn {key, val} -> @@ -209,21 +232,23 @@ defmodule Axon.Display do |> Enum.join("\n") end - defp render_parameters(params, input_shapes) do + defp render_parameters(policy, params, input_shapes) do + type = policy || {:f, 32} + params |> Enum.map(fn %Parameter{name: name, shape: {:tuple, shape_fns}} -> shapes = shape_fns |> Enum.map(&apply(&1, input_shapes)) - |> Enum.map(fn shape -> "f32#{shape_string(shape)}" end) + |> Enum.map(fn shape -> "#{type_str(type)}#{shape_string(shape)}" end) |> List.to_tuple() "#{name}: tuple#{inspect(shapes)}" - %Parameter{name: name, shape: shape_fn} -> - shape = apply(shape_fn, input_shapes) - "#{name}: f32#{shape_string(shape)}" + %Parameter{name: name, template: shape_fn} when is_function(shape_fn) -> + shape = Nx.shape(apply(shape_fn, input_shapes)) + "#{name}: #{type_str(type)}#{shape_string(shape)}" end) |> Enum.join("\n") end