Skip to content

Commit

Permalink
make attach_hook API compatible with Axon.Node and fix doc example on…
Browse files Browse the repository at this point in the history
… Axon.map_nodes (#519)
  • Loading branch information
robinmonjo authored Aug 13, 2023
1 parent a3a05ac commit 80152d3
Showing 1 changed file with 16 additions and 10 deletions.
26 changes: 16 additions & 10 deletions lib/axon.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3177,19 +3177,25 @@ defmodule Axon do
"""
@doc type: :debug
def attach_hook(%Axon{output: id, nodes: nodes} = axon, fun, opts \\ []) do
opts = Keyword.validate!(opts, on: :forward, mode: :both)
on_event = opts[:on]
mode = opts[:mode]
def attach_hook(x, fun, opts \\ [])

def attach_hook(%Axon{output: id, nodes: nodes} = axon, fun, opts) do
updated_nodes =
Map.update!(nodes, id, fn axon_node ->
%{axon_node | hooks: [{on_event, mode, fun} | axon_node.hooks]}
attach_hook(axon_node, fun, opts)
end)

%{axon | nodes: updated_nodes}
end

def attach_hook(%Axon.Node{hooks: hooks} = axon_node, fun, opts) do
opts = Keyword.validate!(opts, on: :forward, mode: :both)
on_event = opts[:on]
mode = opts[:mode]

%{axon_node | hooks: [{on_event, mode, fun} | hooks]}
end

## Graph Manipulation and Utilities

# TODO: Revisit later with new decoupled structs
Expand Down Expand Up @@ -3353,12 +3359,12 @@ defmodule Axon do
you can use this function to visualize intermediate activations
of all convolutional layers in a model:
instrumented_model = Axon. (model, fn
%Axon{op: :conv} = graph ->
Axon.attach_hook(graph, &visualize_activations/1)
instrumented_model = Axon.map_nodes(model, fn
%Axon.Node{op: :conv} = axon_node ->
Axon.attach_hook(axon_node, &visualize_activations/1)
graph ->
graph
axon_node ->
axon_node
end)
Another use case is to replace entire classes of layers
Expand Down

0 comments on commit 80152d3

Please sign in to comment.