From cc7dec610c0030e16dcfd9d44912ffbea317216e Mon Sep 17 00:00:00 2001 From: Sean Moriarity Date: Tue, 15 Oct 2024 21:04:32 -0700 Subject: [PATCH] Raise on ambiguous inputs (#599) --- lib/axon/compiler.ex | 33 +++++++++++++++++++++++++-------- test/axon/compiler_test.exs | 14 ++++++++++++++ test/axon/loop_test.exs | 6 +++++- 3 files changed, 44 insertions(+), 9 deletions(-) diff --git a/lib/axon/compiler.ex b/lib/axon/compiler.ex index 1816e170..e2d90f30 100644 --- a/lib/axon/compiler.ex +++ b/lib/axon/compiler.ex @@ -486,15 +486,16 @@ defmodule Axon.Compiler do name: name_fn, opts: [shape: _input_shape, optional: optional?] }, - _nodes, + nodes, {cache, op_counts, block_cache, model_state_meta}, %{mode: mode, print_values: print_values} ) do name = name_fn.(:input, op_counts) op_counts = Map.update(op_counts, :input, 1, fn x -> x + 1 end) + all_inputs = get_all_inputs(nodes) predict_fun = fn _params, inputs, state, _cache, result_cache, _fn_stacktrace -> - value = get_input(inputs, name, optional?) + value = get_input(all_inputs, inputs, name, optional?) # TODO: Add this back in # validate_input_shape!(value, shape) @@ -509,7 +510,7 @@ defmodule Axon.Compiler do end init_fun = fn template, _cache, result_cache, _fn_stacktrace, _keys -> - input = get_input(template, name, optional?) + input = get_input(all_inputs, template, name, optional?) {Nx.to_template(input), {%{}, result_cache}} end @@ -889,16 +890,32 @@ defmodule Axon.Compiler do {id, model_funs, cache, op_counts, block_cache, model_state_meta} end - defp get_input(inputs, name, optional?) do + defp get_all_inputs(nodes) do + nodes + |> Enum.filter(fn {_, %{op: op}} -> op == :input end) + |> Enum.map(fn {_, %{name: name_fn}} -> + # inputs require a name, so we can just ignore op counts + name_fn.(:input, %{}) + end) + |> Enum.uniq() + end + + defp get_input(all_input_names, inputs, name, optional?) do res = - case inputs do - %Nx.Tensor{} = inputs -> + case {all_input_names, inputs} do + {[^name], %Nx.Tensor{} = inputs} -> inputs - %{} = inputs -> + {_, %Nx.Tensor{}} -> + raise ArgumentError, + "ambiguous input given to the model," <> + " expected inputs with names #{inspect(all_input_names)}" <> + " but received a single tensor as input" + + {_, %{} = inputs} -> inputs[name] - inputs when is_tuple(inputs) -> + {[^name], inputs} when is_tuple(inputs) -> inputs _ -> diff --git a/test/axon/compiler_test.exs b/test/axon/compiler_test.exs index f955690f..80c27c6d 100644 --- a/test/axon/compiler_test.exs +++ b/test/axon/compiler_test.exs @@ -128,6 +128,20 @@ defmodule CompilerTest do assert message =~ "exception found when compiling layer Axon.Layers.add/2 named add_0" assert message =~ "cannot broadcast tensor of dimensions {1, 32} to {1, 64}" end + + test "raises if inputs are ambiguous" do + x = Axon.input("x") + y = Axon.input("y") + model = Axon.add(x, y) + + {_, predict_fn} = Axon.build(model) + + exception = assert_raise ArgumentError, fn -> + predict_fn.(ModelState.empty(), Nx.tensor([1])) + end + + assert Exception.message(exception) =~ "ambiguous" + end end describe "optional" do diff --git a/test/axon/loop_test.exs b/test/axon/loop_test.exs index d88432f3..1b7ec8ac 100644 --- a/test/axon/loop_test.exs +++ b/test/axon/loop_test.exs @@ -132,7 +132,11 @@ defmodule Axon.LoopTest do Loop.trainer(model, [mean_squared_error: 0.5, mean_absolute_error: 0.5], :adam) assert %{model_state: %{}} = - pstate = init_fn.({Nx.tensor([[2]]), Nx.tensor([[2]])}, Axon.ModelState.empty()) + pstate = + init_fn.( + {%{"input_0" => Nx.tensor([[2]]), "input_1" => Nx.tensor([[2]])}, Nx.tensor(0)}, + Axon.ModelState.empty() + ) state = %State{step_state: pstate}