Skip to content

Commit

Permalink
Fix iteration counts (#572)
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmor5 authored May 14, 2024
1 parent 88c7823 commit f3d5bf9
Show file tree
Hide file tree
Showing 3 changed files with 86 additions and 27 deletions.
86 changes: 72 additions & 14 deletions lib/axon/loop.ex
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,7 @@ defmodule Axon.Loop do
loop
|> log(&supervised_log_message_fn/1,
event: :iteration_completed,
filter: [every: log_interval]
filter: [every: {:epoch, log_interval}]
)
|> log(fn _ -> "\n" end, event: :epoch_completed)
else
Expand Down Expand Up @@ -1912,6 +1912,29 @@ defmodule Axon.Loop do
end)
end

# Bumps the counters for an iteration-level event.
#
# Iteration events keep a map with two counters: `:total` (all firings) and
# `:epoch` (firings since the epoch counter was last reset). Both are
# incremented by one; an event with no entry yet starts at one for each.
defp update_counts(%State{event_counts: counts} = state, event)
     when event in [:iteration_started, :iteration_completed] do
  bump = fn value -> value + 1 end

  entry =
    case Map.fetch(counts, event) do
      {:ok, existing} ->
        existing
        |> Map.update!(:total, bump)
        |> Map.update!(:epoch, bump)

      :error ->
        %{total: 1, epoch: 1}
    end

  %{state | event_counts: Map.put(counts, event, entry)}
end

# Handles the end of an epoch (completed or halted).
#
# The per-epoch counters of both iteration events are reset to zero (their
# `:total` counters are left untouched), and the plain integer counter of the
# epoch event itself is incremented, starting at one if it is absent.
defp update_counts(%State{event_counts: counts} = state, event)
     when event in [:epoch_halted, :epoch_completed] do
  reset_epoch = fn entry -> %{entry | epoch: 0} end

  next_counts =
    [:iteration_started, :iteration_completed]
    |> Enum.reduce(counts, fn iteration_event, acc ->
      Map.update(acc, iteration_event, %{total: 0, epoch: 0}, reset_epoch)
    end)
    |> Map.update(event, 1, fn seen -> seen + 1 end)

  %{state | event_counts: next_counts}
end

# Fallback for every other event: its counter is a plain integer that is
# incremented by one, starting at one when the event has no entry yet.
defp update_counts(%State{event_counts: counts} = state, event) do
  next_counts = Map.update(counts, event, 1, &(&1 + 1))
  %{state | event_counts: next_counts}
end
Expand Down Expand Up @@ -2165,29 +2188,53 @@ defmodule Axon.Loop do

:first ->
fn %State{event_counts: counts}, event ->
counts[event] == 1
case counts[event] do
1 -> true
%{total: 1} -> true
_ -> false
end
end

filters when is_list(filters) ->
Enum.reduce(filters, fn _, _ -> true end, fn
{:every, {key, n}}, acc ->
fn state, event ->
acc.(state, event) and filter_every_n(state, event, key, n)
end

{:every, n}, acc ->
fn state, event ->
acc.(state, event) and filter_every_n(state, event, n)
acc.(state, event) and filter_every_n(state, event, :total, n)
end

{:before, {key, n}}, acc ->
fn state, event ->
acc.(state, event) and filter_before_n(state, event, key, n)
end

{:before, n}, acc ->
fn state, event ->
acc.(state, event) and filter_before_n(state, event, n)
acc.(state, event) and filter_before_n(state, event, :total, n)
end

{:after, {key, n}}, acc ->
fn state, event ->
acc.(state, event) and filter_after_n(state, event, key, n)
end

{:after, n}, acc ->
fn state, event ->
acc.(state, event) and filter_after_n(state, event, n)
acc.(state, event) and filter_after_n(state, event, :total, n)
end

{:once, {key, n}}, acc ->
fn state, event ->
acc.(state, event) and filter_once_n(state, event, key, n)
end

{:once, n}, acc ->
fn state, event ->
acc.(state, event) and filter_once_n(state, event, n)
acc.(state, event) and filter_once_n(state, event, :total, n)
end
end)

Expand All @@ -2204,20 +2251,31 @@ defmodule Axon.Loop do
end
end

defp filter_every_n(%State{event_counts: counts}, event, n) do
rem(counts[event] - 1, n) == 0
defp filter_every_n(%State{event_counts: counts}, event, key, n) do
count = get_count(counts, event, key)
rem(count - 1, n) == 0
end

defp filter_after_n(%State{event_counts: counts}, event, n) do
counts[event] > n
defp filter_after_n(%State{event_counts: counts}, event, key, n) do
count = get_count(counts, event, key)
count > n
end

defp filter_before_n(%State{event_counts: counts}, event, n) do
counts[event] < n
defp filter_before_n(%State{event_counts: counts}, event, key, n) do
count = get_count(counts, event, key)
count < n
end

defp filter_once_n(%State{event_counts: counts}, event, n) do
counts[event] == n
defp filter_once_n(%State{event_counts: counts}, event, key, n) do
count = get_count(counts, event, key)
count == n
end

# Resolves the numeric count recorded for `event` in `counts`.
#
#   * When the entry is a map of counters (as used for iteration events, e.g.
#     `%{total: _, epoch: _}`), the counter stored under `key` is returned.
#   * When the entry is not a map (a plain integer counter), it is returned
#     as-is and `key` is ignored.
#
# Raises `ArgumentError` when the entry is a map that does not contain `key`.
# Without this check the whole map would be returned as the "count"; since
# maps sort after numbers in term ordering, comparisons such as `count > n`
# would silently evaluate to `true` instead of reporting the bad filter key.
defp get_count(counts, event, key) do
  case counts[event] do
    %{^key => count} ->
      count

    %{} = per_key ->
      raise ArgumentError,
            "invalid count key #{inspect(key)} for event #{inspect(event)}, " <>
              "expected one of #{inspect(Map.keys(per_key))}"

    count ->
      count
  end
end

# JIT-compiles the given function if jit_compile? is true
Expand Down
4 changes: 2 additions & 2 deletions lib/axon/loop/state.ex
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ defmodule Axon.Loop.State do
event_counts: %{
started: 0,
epoch_started: 0,
iteration_started: 0,
iteration_completed: 0,
iteration_started: %{total: 0, epoch: 0},
iteration_completed: %{total: 0, epoch: 0},
epoch_completed: 0,
epoch_halted: 0,
halted: 0,
Expand Down
23 changes: 12 additions & 11 deletions test/axon/loop_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -636,26 +636,26 @@ defmodule Axon.LoopTest do
started: 1,
epoch_started: 1,
epoch_completed: 1,
iteration_started: 10,
iteration_completed: 10
iteration_started: %{total: 10, epoch: 0},
iteration_completed: %{total: 10, epoch: 0}
}}

assert_received {:epoch_started,
%{
started: 1,
epoch_started: 2,
epoch_completed: 1,
iteration_started: 10,
iteration_completed: 10
iteration_started: %{total: 10, epoch: 0},
iteration_completed: %{total: 10, epoch: 0}
}}

assert_received {:epoch_completed,
%{
started: 1,
epoch_started: 2,
epoch_completed: 2,
iteration_started: 20,
iteration_completed: 20
iteration_started: %{total: 20, epoch: 0},
iteration_completed: %{total: 20, epoch: 0}
}}

refute_received _
Expand Down Expand Up @@ -786,7 +786,7 @@ defmodule Axon.LoopTest do

test "supports function filter" do
fun = fn
%{event_counts: counts}, event -> counts[event] == 5
%{event_counts: counts}, event -> counts[event][:total] == 5
end

run_dummy_loop!(:iteration_started, fun, 5, 10)
Expand Down Expand Up @@ -854,18 +854,19 @@ defmodule Axon.LoopTest do
test "saves a checkpoint on custom events", %{loop: loop} do
data = List.duplicate({Nx.iota({1, 1}), Nx.iota({1, 1})}, 5)

assert %Axon.Loop.State{epoch: 3, iteration: 0, event_counts: %{iteration_completed: 15}} =
assert %Axon.Loop.State{epoch: 3, iteration: 0, event_counts: %{iteration_completed: %{total: 15}}} =
loop
|> Map.put(:output_transform, & &1)
|> Loop.checkpoint(event: :iteration_completed, filter: [every: 2])
|> Loop.checkpoint(event: :iteration_completed, filter: [every: {:epoch, 2}])
|> Loop.run(data, Axon.ModelState.empty(), epochs: 3)

assert [
"checkpoint_0_0.ckpt",
"checkpoint_0_2.ckpt",
"checkpoint_0_4.ckpt",
"checkpoint_1_1.ckpt",
"checkpoint_1_3.ckpt",
"checkpoint_1_0.ckpt",
"checkpoint_1_2.ckpt",
"checkpoint_1_4.ckpt",
"checkpoint_2_0.ckpt",
"checkpoint_2_2.ckpt",
"checkpoint_2_4.ckpt"
Expand Down

0 comments on commit f3d5bf9

Please sign in to comment.