From 1bacd0d18bea4697486d76a3ad06fb793c064cd2 Mon Sep 17 00:00:00 2001 From: lassepe Date: Sat, 28 Dec 2024 00:01:25 +0100 Subject: [PATCH] Provide `build_linear_operator` utility. This function allows to build a SciMLOperators.FunctionOperator from a matrix-valued function `A(p)` to represent the matrix-vector product `A(p) * u` in matrix-free form. --- Project.toml | 2 ++ src/SymbolicTracingUtils.jl | 2 ++ src/tracing.jl | 21 +++++++++++++++++++++ test/runtests.jl | 35 ++++++++++++++++++++++++++--------- 4 files changed, 51 insertions(+), 9 deletions(-) diff --git a/Project.toml b/Project.toml index 2dea498..12a7fa5 100644 --- a/Project.toml +++ b/Project.toml @@ -5,11 +5,13 @@ version = "0.1.2" [deps] FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be" +SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7" [compat] FastDifferentiation = "0.3,0.4" +SciMLOperators = "0.3.12" SparseArrays = "1.10" Symbolics = "4,5,6" julia = "1.10" diff --git a/src/SymbolicTracingUtils.jl b/src/SymbolicTracingUtils.jl index 4a90739..605b87b 100644 --- a/src/SymbolicTracingUtils.jl +++ b/src/SymbolicTracingUtils.jl @@ -9,8 +9,10 @@ module SymbolicTracingUtils using Symbolics: Symbolics using FastDifferentiation: FastDifferentiation as FD using SparseArrays: SparseArrays +using SciMLOperators: FunctionOperator export build_function, + build_linear_operator, FastDifferentiationBackend, get_constant_entries, get_result_buffer, diff --git a/src/tracing.jl b/src/tracing.jl index 080f4ec..0d04479 100644 --- a/src/tracing.jl +++ b/src/tracing.jl @@ -78,4 +78,25 @@ function build_function( end end end + +""" +Build a linear SciMLOperators.FunctionOperator from a matrix-valued function `A(p)` +to represent the matrix-vector product `A(p) * u` in matrix-free form. +""" +function build_linear_operator(A_of_p::AbstractMatrix{<:SymbolicNumber}, p; in_place) + u = make_variables(infer_backend(A_of_p), gensym(), size(A_of_p)[end]) + A_of_p_times_u = build_function(A_of_p * u, p, u; in_place) + # TODO: also analyze symmetry and other matrix properties to forward to the operator + input_prototype = zeros(size(u)) + p_prototype = zeros(size(p)) + + if in_place + FunctionOperator(input_prototype; p = p_prototype, islinear = true) do result, u, p, _t + A_of_p_times_u(result, p, u) + end + else + FunctionOperator(input_prototype; p = p_prototype, islinear = true) do u, p, _t + A_of_p_times_u(p, u) + end + end end diff --git a/test/runtests.jl b/test/runtests.jl index f060dc9..94aa189 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,6 @@ using SymbolicTracingUtils using Test: @test, @testset, @test_broken -using LinearAlgebra: Diagonal +using LinearAlgebra: Diagonal, mul! using SparseArrays: spzeros, findnz, nnz, rowvals function dummy_function(x) @@ -17,10 +17,13 @@ end global x = make_variables(backend, :x, 10) global fx = dummy_function(x) global x_value = [1:10;] + global y_true = dummy_function(x_value) + global g_true = dummy_function_gradient(x_value) + global J_true = Diagonal(dummy_function_gradient(x_value)) + @testset "non-ad-tracing" begin f = build_function(fx, x; in_place = false) f! = build_function(fx, x; in_place = true) - y_true = dummy_function(x_value) y_out_of_place = f(x_value) y_in_place = zeros(10) f!(y_in_place, x_value) @@ -32,19 +35,17 @@ end gx = gradient(sum(fx), x) g = build_function(gx, x; in_place = false) g! = build_function(gx, x; in_place = true) - y_true = dummy_function_gradient(x_value) - y_out_of_place = g(x_value) - y_in_place = zeros(10) - g!(y_in_place, x_value) - @test y_out_of_place ≈ y_true - @test y_in_place ≈ y_true + g_out_of_place = g(x_value) + g_in_place = zeros(10) + g!(g_in_place, x_value) + @test g_out_of_place ≈ g_true + @test g_in_place ≈ g_true end @testset "jacobian" begin Jx = jacobian(fx, x) J = build_function(Jx, x; in_place = false) J! = build_function(Jx, x; in_place = true) - J_true = Diagonal(dummy_function_gradient(x_value)) J_out_of_place = J(x_value) J_in_place = zeros(10, 10) J!(J_in_place, x_value) @@ -85,6 +86,22 @@ end @test nnz(J_sparse!) == nnz(Jx) # same structure as symbolic version @test rowvals(J_sparse!) == rows end + + @testset "build_linear_operator" begin + J_op = build_linear_operator(Jx, x; in_place = false) + J_op! = build_linear_operator(Jx, x; in_place = true) + v_value = [11.0:20.0;] + Jv_true = J_true * v_value + + J_op.p = x_value + Jv_out_of_place = J_op * v_value + @test Jv_out_of_place ≈ Jv_true + + J_op!.p = x_value + Jv_in_place = zeros(10) + mul!(Jv_in_place, J_op!, v_value) + @test Jv_in_place ≈ Jv_true + end end end end