Skip to content

Commit

Permalink
Release v0.7.0 (#596)
Browse files Browse the repository at this point in the history
  • Loading branch information
seanmor5 authored Oct 7, 2024
1 parent 5246eaa commit 35004b4
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 83 deletions.
8 changes: 4 additions & 4 deletions lib/axon/metrics.ex
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ defmodule Axon.Metrics do
iex> y_pred = Nx.tensor([0.8, 0.6, 0.4, 0.2, 0.8, 0.2, 0.2])
iex> Axon.Metrics.true_positives(y_true, y_pred)
#Nx.Tensor<
u64
u32
1
>
"""
Expand Down Expand Up @@ -198,7 +198,7 @@ defmodule Axon.Metrics do
iex> y_pred = Nx.tensor([0.8, 0.6, 0.4, 0.2, 0.8, 0.2, 0.2])
iex> Axon.Metrics.false_negatives(y_true, y_pred)
#Nx.Tensor<
u64
u32
3
>
"""
Expand Down Expand Up @@ -230,7 +230,7 @@ defmodule Axon.Metrics do
iex> y_pred = Nx.tensor([0.8, 0.6, 0.4, 0.2, 0.8, 0.2, 0.2])
iex> Axon.Metrics.true_negatives(y_true, y_pred)
#Nx.Tensor<
u64
u32
1
>
"""
Expand Down Expand Up @@ -262,7 +262,7 @@ defmodule Axon.Metrics do
iex> y_pred = Nx.tensor([0.8, 0.6, 0.4, 0.2, 0.8, 0.2, 0.2])
iex> Axon.Metrics.false_positives(y_true, y_pred)
#Nx.Tensor<
u64
u32
2
>
"""
Expand Down
8 changes: 4 additions & 4 deletions mix.exs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ defmodule Axon.MixProject do
use Mix.Project

@source_url "https://github.com/elixir-nx/axon"
@version "0.6.1"
@version "0.7.0"

def project do
[
Expand Down Expand Up @@ -35,9 +35,9 @@ defmodule Axon.MixProject do
# Run "mix help deps" to learn about dependencies.
defp deps do
[
{:exla, "~> 0.7.0", [only: :test] ++ exla_opts()},
{:torchx, "~> 0.7.0", [only: :test] ++ torchx_opts()},
{:nx, "~> 0.6.0 or ~> 0.7.0", nx_opts()},
{:nx, "~> 0.9", nx_opts()},
{:exla, "~> 0.9", [only: :test] ++ exla_opts()},
{:torchx, "~> 0.9", [only: :test] ++ torchx_opts()},
{:ex_doc, "~> 0.23", only: :docs},
{:table_rex, "~> 3.1.1", optional: true},
{:kino, "~> 0.7", optional: true},
Expand Down
29 changes: 15 additions & 14 deletions mix.lock
Original file line number Diff line number Diff line change
@@ -1,22 +1,23 @@
%{
"complex": {:hex, :complex, "0.5.0", "af2d2331ff6170b61bb738695e481b27a66780e18763e066ee2cd863d0b1dd92", [:mix], [], "hexpm", "2683bd3c184466cfb94fad74cbfddfaa94b860e27ad4ca1bffe3bff169d91ef1"},
"earmark_parser": {:hex, :earmark_parser, "1.4.39", "424642f8335b05bb9eb611aa1564c148a8ee35c9c8a8bba6e129d51a3e3c6769", [:mix], [], "hexpm", "06553a88d1f1846da9ef066b87b57c6f605552cfbe40d20bd8d59cc6bde41944"},
"elixir_make": {:hex, :elixir_make, "0.7.8", "505026f266552ee5aabca0b9f9c229cbb496c689537c9f922f3eb5431157efc7", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "7a71945b913d37ea89b06966e1342c85cfe549b15e6d6d081e8081c493062c07"},
"ex_doc": {:hex, :ex_doc, "0.31.1", "8a2355ac42b1cc7b2379da9e40243f2670143721dd50748bf6c3b1184dae2089", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.1", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1", [hex: :makeup_erlang, repo: "hexpm", optional: false]}], "hexpm", "3178c3a407c557d8343479e1ff117a96fd31bafe52a039079593fb0524ef61b0"},
"exla": {:hex, :exla, "0.7.0", "27fac40a580f0d3816fe3bf35c50dfc2f99597d26ac7e2aca4a3c62b89bb427f", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nx, "~> 0.7.0", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.6.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "d3bfc622deb52cec95efc9d76063891afc7cd33e38eddbb01f3385c53e043c40"},
"earmark_parser": {:hex, :earmark_parser, "1.4.41", "ab34711c9dc6212dda44fcd20ecb87ac3f3fce6f0ca2f28d4a00e4154f8cd599", [:mix], [], "hexpm", "a81a04c7e34b6617c2792e291b5a2e57ab316365c2644ddc553bb9ed863ebefa"},
"elixir_make": {:hex, :elixir_make, "0.8.4", "4960a03ce79081dee8fe119d80ad372c4e7badb84c493cc75983f9d3bc8bde0f", [:mix], [{:castore, "~> 0.1 or ~> 1.0", [hex: :castore, repo: "hexpm", optional: true]}, {:certifi, "~> 2.0", [hex: :certifi, repo: "hexpm", optional: true]}], "hexpm", "6e7f1d619b5f61dfabd0a20aa268e575572b542ac31723293a4c1a567d5ef040"},
"ex_doc": {:hex, :ex_doc, "0.34.2", "13eedf3844ccdce25cfd837b99bea9ad92c4e511233199440488d217c92571e8", [:mix], [{:earmark_parser, "~> 1.4.39", [hex: :earmark_parser, repo: "hexpm", optional: false]}, {:makeup_c, ">= 0.1.0", [hex: :makeup_c, repo: "hexpm", optional: true]}, {:makeup_elixir, "~> 0.14 or ~> 1.0", [hex: :makeup_elixir, repo: "hexpm", optional: false]}, {:makeup_erlang, "~> 0.1 or ~> 1.0", [hex: :makeup_erlang, repo: "hexpm", optional: false]}, {:makeup_html, ">= 0.1.0", [hex: :makeup_html, repo: "hexpm", optional: true]}], "hexpm", "5ce5f16b41208a50106afed3de6a2ed34f4acfd65715b82a0b84b49d995f95c1"},
"exla": {:hex, :exla, "0.9.0", "e048c7a3d33917c214774a7ea1a0c626eb9de01e3fb2423cf9e2b89ef6dada3a", [:make, :mix], [{:elixir_make, "~> 0.6", [hex: :elixir_make, repo: "hexpm", optional: false]}, {:nimble_pool, "~> 1.0", [hex: :nimble_pool, repo: "hexpm", optional: false]}, {:nx, "~> 0.9.0", [hex: :nx, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}, {:xla, "~> 0.8.0", [hex: :xla, repo: "hexpm", optional: false]}], "hexpm", "cbd30b54992d0da01a5aaee361a3160fc29de05a9f6c3dbcbd1fa04b4aa72302"},
"fss": {:hex, :fss, "0.1.1", "9db2344dbbb5d555ce442ac7c2f82dd975b605b50d169314a20f08ed21e08642", [:mix], [], "hexpm", "78ad5955c7919c3764065b21144913df7515d52e228c09427a004afe9c1a16b0"},
"kino": {:hex, :kino, "0.12.3", "a5f48a243c60a7ac18ba23869f697b1c775fc7794e8cd55dd248ba33c6fe9445", [:mix], [{:fss, "~> 0.1.0", [hex: :fss, repo: "hexpm", optional: false]}, {:nx, "~> 0.1", [hex: :nx, repo: "hexpm", optional: true]}, {:table, "~> 0.1.2", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "a6dfa3d54ba0edec9ca6e5940154916b381901001f171c85a2d8c67869dbc2d8"},
"kino_vega_lite": {:hex, :kino_vega_lite, "0.1.11", "d3c2a00b3685b95f91833920d06cc9b1fd7fb293a2663d89affe9aaec16a5b77", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: false]}, {:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}, {:vega_lite, "~> 0.1.8", [hex: :vega_lite, repo: "hexpm", optional: false]}], "hexpm", "5ccd9148ce7cfcc95a137e12596cd8b95b371e9ea107e745bc262c39c5d8d48e"},
"makeup": {:hex, :makeup, "1.1.1", "fa0bc768698053b2b3869fa8a62616501ff9d11a562f3ce39580d60860c3a55e", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "5dc62fbdd0de44de194898b6710692490be74baa02d9d108bc29f007783b0b48"},
"makeup_elixir": {:hex, :makeup_elixir, "0.16.1", "cc9e3ca312f1cfeccc572b37a09980287e243648108384b97ff2b76e505c3555", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "e127a341ad1b209bd80f7bd1620a15693a9908ed780c3b763bccf7d200c767c6"},
"makeup_erlang": {:hex, :makeup_erlang, "0.1.4", "29563475afa9b8a2add1b7a9c8fb68d06ca7737648f28398e04461f008b69521", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "f4ed47ecda66de70dd817698a703f8816daa91272e7e45812469498614ae8b29"},
"kino": {:hex, :kino, "0.14.1", "c499afb1cd0be462feaf0a75c0631aa65aacc545b1c10f431b439b74f104be22", [:mix], [{:fss, "~> 0.1.0", [hex: :fss, repo: "hexpm", optional: false]}, {:nx, "~> 0.1", [hex: :nx, repo: "hexpm", optional: true]}, {:plug, "~> 1.0", [hex: :plug, repo: "hexpm", optional: true]}, {:table, "~> 0.1.2", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "090aea1aaa267e42e5ac24ee6bc5ed515aecc0a9edb8619aa4ee839201e704aa"},
"kino_vega_lite": {:hex, :kino_vega_lite, "0.1.13", "03c00405987a2202e4b8014ee55eb7f5727691b3f13d76a3764f6eeccef45322", [:mix], [{:kino, "~> 0.7", [hex: :kino, repo: "hexpm", optional: false]}, {:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}, {:vega_lite, "~> 0.1.8", [hex: :vega_lite, repo: "hexpm", optional: false]}], "hexpm", "00c72bc270e7b9d3c339f726cdab0012fd3f2fc75e36c7548e0f250fe420fa10"},
"makeup": {:hex, :makeup, "1.1.2", "9ba8837913bdf757787e71c1581c21f9d2455f4dd04cfca785c70bbfff1a76a3", [:mix], [{:nimble_parsec, "~> 1.2.2 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "cce1566b81fbcbd21eca8ffe808f33b221f9eee2cbc7a1706fc3da9ff18e6cac"},
"makeup_elixir": {:hex, :makeup_elixir, "0.16.2", "627e84b8e8bf22e60a2579dad15067c755531fea049ae26ef1020cad58fe9578", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}, {:nimble_parsec, "~> 1.2.3 or ~> 1.3", [hex: :nimble_parsec, repo: "hexpm", optional: false]}], "hexpm", "41193978704763f6bbe6cc2758b84909e62984c7752b3784bd3c218bb341706b"},
"makeup_erlang": {:hex, :makeup_erlang, "1.0.1", "c7f58c120b2b5aa5fd80d540a89fdf866ed42f1f3994e4fe189abebeab610839", [:mix], [{:makeup, "~> 1.0", [hex: :makeup, repo: "hexpm", optional: false]}], "hexpm", "8a89a1eeccc2d798d6ea15496a6e4870b75e014d1af514b1b71fa33134f57814"},
"nimble_parsec": {:hex, :nimble_parsec, "1.4.0", "51f9b613ea62cfa97b25ccc2c1b4216e81df970acd8e16e8d1bdc58fef21370d", [:mix], [], "hexpm", "9c565862810fb383e9838c1dd2d7d2c437b3d13b267414ba6af33e50d2d1cf28"},
"nx": {:hex, :nx, "0.7.0", "cec684cada356e9d268af01daa758882f7372aa952716dbe0369c657abb9e762", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "68edaa48a5841495ecab0dd4cf7b11b2fc0ad809754ae7f82d9c4090b91acf55"},
"nimble_pool": {:hex, :nimble_pool, "1.1.0", "bf9c29fbdcba3564a8b800d1eeb5a3c58f36e1e11d7b7fb2e084a643f645f06b", [:mix], [], "hexpm", "af2e4e6b34197db81f7aad230c1118eac993acc0dae6bc83bac0126d4ae0813a"},
"nx": {:hex, :nx, "0.9.0", "03a622a27d93eaaa2d24ff9b812d9f675cc04eb0340ca3dd065674f3642867d3", [:mix], [{:complex, "~> 0.5", [hex: :complex, repo: "hexpm", optional: false]}, {:telemetry, "~> 0.4.0 or ~> 1.0", [hex: :telemetry, repo: "hexpm", optional: false]}], "hexpm", "3810a5a90db0654b6e538430c0fb473a22bfc11b3d02ea7834db493cf3f56153"},
"polaris": {:hex, :polaris, "0.1.0", "dca61b18e3e801ecdae6ac9f0eca5f19792b44a5cb4b8d63db50fc40fc038d22", [:mix], [{:nx, "~> 0.5", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "13ef2b166650e533cb24b10e2f3b8ab4f2f449ba4d63156e8c569527f206e2c2"},
"table": {:hex, :table, "0.1.2", "87ad1125f5b70c5dea0307aa633194083eb5182ec537efc94e96af08937e14a8", [:mix], [], "hexpm", "7e99bc7efef806315c7e65640724bf165c3061cdc5d854060f74468367065029"},
"table_rex": {:hex, :table_rex, "3.1.1", "0c67164d1714b5e806d5067c1e96ff098ba7ae79413cc075973e17c38a587caa", [:mix], [], "hexpm", "678a23aba4d670419c23c17790f9dcd635a4a89022040df7d5d772cb21012490"},
"telemetry": {:hex, :telemetry, "1.2.1", "68fdfe8d8f05a8428483a97d7aab2f268aaff24b49e0f599faa091f1d4e7f61c", [:rebar3], [], "hexpm", "dad9ce9d8effc621708f99eac538ef1cbe05d6a874dd741de2e689c47feafed5"},
"torchx": {:hex, :torchx, "0.7.0", "c71fd603b0133ed8709450d82aa3434cbcf485a37c9a68e9ebcce86f5e4fb7f0", [:mix], [{:nx, "~> 0.7.0", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "a324079c56bb67750b1da16f859d994982bb467020a8c2cba324639552f3adb8"},
"vega_lite": {:hex, :vega_lite, "0.1.8", "7f6119126ecaf4bc2c1854084370d7091424f5cce4795fbac044eee9963f0752", [:mix], [{:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "6c8a9271f850612dd8a90de8d1ebd433590ed07ffef76fc2397c240dc04d3fdc"},
"xla": {:hex, :xla, "0.6.0", "67bb7695efa4a23b06211dc212de6a72af1ad5a9e17325e05e0a87e4c241feb8", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "dd074daf942312c6da87c7ed61b62fb1a075bced157f1cc4d47af2d7c9f44fb7"},
"telemetry": {:hex, :telemetry, "1.3.0", "fedebbae410d715cf8e7062c96a1ef32ec22e764197f70cda73d82778d61e7a2", [:rebar3], [], "hexpm", "7015fc8919dbe63764f4b4b87a95b7c0996bd539e0d499be6ec9d7f3875b79e6"},
"torchx": {:hex, :torchx, "0.9.0", "936cbd32233f89d73700c39b7ef56f94b3f3541db03c90f8ddf6b3fe73260e28", [:mix], [{:nx, "~> 0.9.0", [hex: :nx, repo: "hexpm", optional: false]}], "hexpm", "4e057d6b93fc91191957230b2c61c408861b888abdf6a900baf0db4125405505"},
"vega_lite": {:hex, :vega_lite, "0.1.9", "d7a288665f916181b68d0a3617f1b3611d16a4dcd5fafb51b847b71db1159d4c", [:mix], [{:table, "~> 0.1.0", [hex: :table, repo: "hexpm", optional: false]}], "hexpm", "c6a056e763162198e73ae6dfb46c09753bb0298474410fd085074e1cdcee7418"},
"xla": {:hex, :xla, "0.8.0", "fef314d085dd3ee16a0816c095239938f80769150e15db16dfaa435553d7cb16", [:make, :mix], [{:elixir_make, "~> 0.4", [hex: :elixir_make, repo: "hexpm", optional: false]}], "hexpm", "739c61c8d93b97e12ba0369d10e76130224c208f1a76ad293e3581f056833e57"},
}
57 changes: 4 additions & 53 deletions test/axon/integration_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -241,55 +241,6 @@ defmodule Axon.IntegrationTest do
end)
end

test "gradient accumulation test" do
{train, _test} = get_test_data(100, 0, 10, {10}, 2, 1337)

train =
train
|> Stream.map(fn {xs, ys} ->
{xs, one_hot(ys, num_classes: 2)}
end)
|> Enum.to_list()

[{x_test, _}] = Enum.take(train, 1)

model =
Axon.input("input")
|> Axon.dense(16)
|> Axon.dropout(rate: 0.1)
|> Axon.dense(2, activation: :softmax)

ExUnit.CaptureIO.capture_io(fn ->
results =
model
|> Axon.Loop.trainer(
:categorical_cross_entropy,
Polaris.Optimizers.adam(learning_rate: 5.0e-3),
gradient_accumulation_steps: 3
)
# TODO: Fix default output transform
|> Map.update(:output_transform, nil, fn _ -> & &1 end)
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.validate(model, train)
|> Axon.Loop.run(train, Axon.ModelState.empty(), epochs: 10)

assert %{step_state: %{model_state: model_state}, metrics: %{9 => last_epoch_metrics}} =
results

eval_results =
model
|> Axon.Loop.evaluator()
|> Axon.Loop.metric(:accuracy)
|> Axon.Loop.run(train, model_state)

assert %{0 => %{"accuracy" => final_model_val_accuracy}} = eval_results

assert_greater_equal(last_epoch_metrics["validation_accuracy"], 0.7)
assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"])
assert Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2}
end)
end

test "deterministic training test" do
{train, _test} = get_test_data(100, 0, 10, {10}, 2, 1337)

Expand Down Expand Up @@ -525,8 +476,8 @@ defmodule Axon.IntegrationTest do
assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"])
assert Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2}

assert Nx.type(model_state.data["dense_0"]["kernel"]) ==
unquote(Macro.escape(policy)).params
params_policy = unquote(Macro.escape(policy)).params || {:f, 32}
assert Nx.type(model_state.data["dense_0"]["kernel"]) == params_policy
end)
end

Expand Down Expand Up @@ -578,8 +529,8 @@ defmodule Axon.IntegrationTest do
assert_all_close(final_model_val_accuracy, last_epoch_metrics["validation_accuracy"])
assert Nx.shape(Axon.predict(model, model_state, x_test)) == {10, 2}

assert Nx.type(model_state.data["dense_0"]["kernel"]) ==
unquote(Macro.escape(policy)).params
params_policy = unquote(Macro.escape(policy)).params || {:f, 32}
assert Nx.type(model_state.data["dense_0"]["kernel"]) == params_policy
end)
end
end
Expand Down
31 changes: 23 additions & 8 deletions test/axon/loss_scale_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -244,15 +244,26 @@ defmodule Axon.LossScaleTest do

non_finite = Nx.tensor([:infinity, :infinity, :infinity])

# TODO: increase to 99 when https://github.com/elixir-nx/complex/issues/26
# is fixed
for i <- 0..62, reduce: state do
for i <- 0..99, reduce: state do
new_state ->
{_, %{loss_scale: loss_scale, counter: counter} = new_state} =
adjust_fn.(non_finite, new_state)

expected_new_scale = Nx.max(1, Nx.divide(init_scale, Nx.pow(factor, i + 1)))
# We want to check if init_scale / factor ** (i + 1) is greater than 1.
# If we rely on `i` directly, we run into integer overflow issues.
# Instead, we accumulate the divisor on the reduce.

scale_divisor = 2 ** (i + 1)

expected_new_scale =
if scale_divisor >= 2 ** 32 do
Nx.tensor(1)
else
Nx.max(1, Nx.divide(init_scale, scale_divisor))
end

assert_equal(counter, Nx.tensor(0))

assert_all_close(loss_scale, expected_new_scale)

new_state
Expand All @@ -277,15 +288,19 @@ defmodule Axon.LossScaleTest do

non_finite = Nx.tensor([:infinity, :infinity, :infinity])

# TODO: increase to 99 when https://github.com/elixir-nx/complex/issues/26
# is fixed
for i <- 0..62, reduce: state do
for i <- 0..99, reduce: state do
new_state ->
{_, %{loss_scale: loss_scale, counter: counter} = new_state} =
adjust_fn.(non_finite, new_state)

scale_divisor = 2 ** (i + 1)

expected_new_scale =
Nx.max(min_loss_scale, Nx.divide(init_scale, Nx.pow(factor, i + 1)))
if scale_divisor >= 2 ** 32 do
Nx.tensor(min_loss_scale)
else
Nx.max(min_loss_scale, Nx.divide(init_scale, scale_divisor))
end

assert_equal(counter, Nx.tensor(0))
assert_all_close(loss_scale, expected_new_scale)
Expand Down

0 comments on commit 35004b4

Please sign in to comment.