Skip to content

Commit

Permalink
Remove dead code related to namespaces (#602)
Browse files Browse the repository at this point in the history
They were removed in this commit 44f36b5
  • Loading branch information
ahamez authored Nov 13, 2024
1 parent d1dc13b commit e5c8a37
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 94 deletions.
4 changes: 2 additions & 2 deletions lib/axon.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3875,7 +3875,7 @@ defmodule Axon do
## `init_fn`
The `init_fn` receives two arguments, the input template and
an optional map with initial parameters for layers or namespaces:
an optional map with initial parameters for layers:
{init_fn, predict_fn} = Axon.build(model)
init_fn.(Nx.template({1, 1}, {:f, 32}), %{"dense_0" => dense_params})
Expand Down Expand Up @@ -3968,7 +3968,7 @@ defmodule Axon do
purposes.
You may optionally specify initial parameters for some layers or
namespaces by passing a partial parameter map:
by passing a partial parameter map:
Axon.trace_init(model, %{"dense_0" => dense_params})
Expand Down
92 changes: 0 additions & 92 deletions lib/axon/compiler.ex
Original file line number Diff line number Diff line change
Expand Up @@ -687,98 +687,6 @@ defmodule Axon.Compiler do
{id, model_funs, cache, op_counts, block_cache, model_state_meta}
end

defp recur_model_funs(
%Axon.Node{id: id, op: :namespace, name: name_fn, parent: [parent]},
nodes,
{cache, op_counts, block_cache, model_state_meta},
config
) do
name = name_fn.(:namespace, op_counts)
# To ensure that a namespace always has the same layer names,
# we reset op_counts, input layers always belong to the global
# namespace, so we include those regardless
input_count = op_counts[:input] || 0
namespace_op_counts = %{input: input_count}
namespace_model_state_meta = %{parameters: %{}, state: %{}, frozen_parameters: %{}}

# All of the children of this namespace belong to it, so
# we forward this name to the namespace, but everything after
# it belongs to whatever namespace we're currently in
{parent_id, {cache, namespace_op_counts, block_cache, namespace_model_state_meta}} =
to_model_funs(
parent,
nodes,
{cache, namespace_op_counts, block_cache, namespace_model_state_meta},
config
)

# Update the global op_count of input layers, since they
# are a global operation regardless of where they are
input_count = namespace_op_counts[:input] || 0
op_counts = Map.put(op_counts, :input, input_count)

# Update the model state meta to include the namespace model state meta
model_state_meta =
model_state_meta
|> Map.update!(:parameters, &Map.put(&1, name, namespace_model_state_meta[:parameters]))
|> Map.update!(:state, &Map.put(&1, name, namespace_model_state_meta[:state]))
|> Map.update!(
:frozen_parameters,
&Map.put(&1, name, namespace_model_state_meta[:frozen_parameters])
)

# The function just returns the result of it's child,
# or parent depending on how you view the tree
predict_fun = fn params, inputs, state, cache, result_cache, fn_stacktrace ->
# We're only concerned with this namespaces parameters, so we pair
# down parameters first given the namespace
namespace_params = params[name]

# TODO: How should hooks be handled here?
# TODO: I think we can actually handle parameter freezing and access
# better here by only forwarding params[namespace] to the child function
{out, {state, result_cache}} =
call_predict_cache(
parent_id,
namespace_params,
inputs,
state,
cache,
result_cache,
fn_stacktrace
)

state =
if map_size(state) == 0 do
state
else
%{name => state}
end

{out, {state, result_cache}}
end

init_fun = fn template, cache, result_cache, fn_stacktrace, keys ->
{_parent_template, {namespace_params, result_cache}} =
call_init_cache(parent_id, template, %{}, cache, result_cache, fn_stacktrace, keys)

params =
if namespace_params == %{} do
%{}
else
%{name => namespace_params}
end

{pred_expr, {_, result_cache}} =
predict_fun.(params, template, %{}, cache, result_cache, fn_stacktrace)

{Nx.to_template(pred_expr), {params, result_cache}}
end

model_funs = %{predict: predict_fun, init: init_fun}
{id, model_funs, cache, op_counts, block_cache, model_state_meta}
end

defp recur_model_funs(
%Axon.Node{
id: id,
Expand Down

0 comments on commit e5c8a37

Please sign in to comment.