From ce2e24791fb69a9d0e7f9ec7e1983541064b7846 Mon Sep 17 00:00:00 2001 From: Barna Kovacs Date: Sun, 13 Oct 2024 22:52:43 +0200 Subject: [PATCH] Fix: call apply/3 as intended (#598) * Fix: call apply/3 as intended * Add tests for Axon.Quantizaiton.weight_only_quantized_dense --- lib/axon/quantization.ex | 2 +- test/axon/quantization_test.exs | 14 ++++++++++++++ 2 files changed, 15 insertions(+), 1 deletion(-) diff --git a/lib/axon/quantization.ex b/lib/axon/quantization.ex index b48d18ff..ed976b8d 100644 --- a/lib/axon/quantization.ex +++ b/lib/axon/quantization.ex @@ -132,7 +132,7 @@ defmodule Axon.Quantization do fun = case opts[:kernel_initializer] do init when is_atom(init) -> - apply(Axon.Initializers, []) + apply(Axon.Initializers, init, []) fun when is_function(fun) -> fun diff --git a/test/axon/quantization_test.exs b/test/axon/quantization_test.exs index 4a289ce0..3d728158 100644 --- a/test/axon/quantization_test.exs +++ b/test/axon/quantization_test.exs @@ -42,4 +42,18 @@ defmodule Axon.QuantizationTest do assert_equal(predict_fn.(quantized_model_state, inp), real_fn.(quantized_model_state, inp)) end end + + describe "weight_only_quantized_dense" do + test "inits and executes properly" do + model = + Axon.input("input") + |> Axon.Quantization.weight_only_quantized_dense(10) + + assert {init_fn, _} = Axon.build(model) + assert %ModelState{} = model_state = init_fn.(Nx.template({1, 1}, :f32), ModelState.empty()) + + assert {_, predict_fn} = Axon.build(model) + assert predict_fn.(model_state, Nx.broadcast(1.0, {1, 1})) + end + end end