From e5c8a37d8777356af5220eea03fc9e71bdf6fa4f Mon Sep 17 00:00:00 2001 From: Alexandre Hamez <199517+ahamez@users.noreply.github.com> Date: Wed, 13 Nov 2024 18:36:48 +0100 Subject: [PATCH] Remove dead code related to namespaces (#602) They were removed in this commit 44f36b58ed0b505cae8ce0bde121fb6e9fb0c2e8 --- lib/axon.ex | 6 +++--- lib/axon/compiler.ex | 92 -------------------------------------------- 2 files changed, 3 insertions(+), 95 deletions(-) diff --git a/lib/axon.ex b/lib/axon.ex index b342c6c3..e1745214 100644 --- a/lib/axon.ex +++ b/lib/axon.ex @@ -3875,7 +3875,7 @@ defmodule Axon do ## `init_fn` The `init_fn` receives two arguments, the input template and - an optional map with initial parameters for layers or namespaces: + an optional map with initial parameters for layers: {init_fn, predict_fn} = Axon.build(model) init_fn.(Nx.template({1, 1}, {:f, 32}), %{"dense_0" => dense_params}) @@ -3968,7 +3968,7 @@ defmodule Axon do purposes. - You may optionally specify initial parameters for some layers or - namespaces by passing a partial parameter map: + You may optionally specify initial parameters for some layers + by passing a partial parameter map: Axon.trace_init(model, %{"dense_0" => dense_params}) diff --git a/lib/axon/compiler.ex b/lib/axon/compiler.ex index e2d90f30..50b0d565 100644 --- a/lib/axon/compiler.ex +++ b/lib/axon/compiler.ex @@ -687,98 +687,6 @@ defmodule Axon.Compiler do {id, model_funs, cache, op_counts, block_cache, model_state_meta} end - defp recur_model_funs( - %Axon.Node{id: id, op: :namespace, name: name_fn, parent: [parent]}, - nodes, - {cache, op_counts, block_cache, model_state_meta}, - config - ) do - name = name_fn.(:namespace, op_counts) - # To ensure that a namespace always has the same layer names, - # we reset op_counts, input layers always belong to the global - # namespace, so we include those regardless - input_count = op_counts[:input] || 0 - namespace_op_counts = %{input: input_count} - namespace_model_state_meta = %{parameters: %{}, state: %{}, frozen_parameters: %{}} - - # All of 
the children of this namespace belong to it, so - # we forward this name to the namespace, but everything after - # it belongs to whatever namespace we're currently in - {parent_id, {cache, namespace_op_counts, block_cache, namespace_model_state_meta}} = - to_model_funs( - parent, - nodes, - {cache, namespace_op_counts, block_cache, namespace_model_state_meta}, - config - ) - - # Update the global op_count of input layers, since they - # are a global operation regardless of where they are - input_count = namespace_op_counts[:input] || 0 - op_counts = Map.put(op_counts, :input, input_count) - - # Update the model state meta to include the namespace model state meta - model_state_meta = - model_state_meta - |> Map.update!(:parameters, &Map.put(&1, name, namespace_model_state_meta[:parameters])) - |> Map.update!(:state, &Map.put(&1, name, namespace_model_state_meta[:state])) - |> Map.update!( - :frozen_parameters, - &Map.put(&1, name, namespace_model_state_meta[:frozen_parameters]) - ) - - # The function just returns the result of it's child, - # or parent depending on how you view the tree - predict_fun = fn params, inputs, state, cache, result_cache, fn_stacktrace -> - # We're only concerned with this namespaces parameters, so we pair - # down parameters first given the namespace - namespace_params = params[name] - - # TODO: How should hooks be handled here? 
- # TODO: I think we can actually handle parameter freezing and access - # better here by only forwarding params[namespace] to the child function - {out, {state, result_cache}} = - call_predict_cache( - parent_id, - namespace_params, - inputs, - state, - cache, - result_cache, - fn_stacktrace - ) - - state = - if map_size(state) == 0 do - state - else - %{name => state} - end - - {out, {state, result_cache}} - end - - init_fun = fn template, cache, result_cache, fn_stacktrace, keys -> - {_parent_template, {namespace_params, result_cache}} = - call_init_cache(parent_id, template, %{}, cache, result_cache, fn_stacktrace, keys) - - params = - if namespace_params == %{} do - %{} - else - %{name => namespace_params} - end - - {pred_expr, {_, result_cache}} = - predict_fun.(params, template, %{}, cache, result_cache, fn_stacktrace) - - {Nx.to_template(pred_expr), {params, result_cache}} - end - - model_funs = %{predict: predict_fun, init: init_fun} - {id, model_funs, cache, op_counts, block_cache, model_state_meta} - end - defp recur_model_funs( %Axon.Node{ id: id,