Skip to content

Commit

Permalink
Provide build_linear_operator utility.
Browse files Browse the repository at this point in the history
This function allows to build a SciMLOperators.FunctionOperator from a matrix-valued function `A(p)` to represent the matrix-vector product `A(p) * u` in matrix-free form.
  • Loading branch information
lassepe committed Dec 27, 2024
1 parent 259b824 commit de782ae
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 9 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@ version = "0.1.1"

[deps]
FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"

[compat]
FastDifferentiation = "0.3,0.4"
SciMLOperators = "0.3.12"
SparseArrays = "1.10"
Symbolics = "4,5,6"
julia = "1.10"
2 changes: 2 additions & 0 deletions src/SymbolicTracingUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@ module SymbolicTracingUtils
using Symbolics: Symbolics
using FastDifferentiation: FastDifferentiation as FD
using SparseArrays: SparseArrays
using SciMLOperators: FunctionOperator

export build_function,
build_linear_operator,
FastDifferentiationBackend,
get_constant_entries,
get_result_buffer,
Expand Down
21 changes: 21 additions & 0 deletions src/tracing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,4 +78,25 @@ function build_function(
end
end
end

"""
Build a linear SciMLOperators.FunctionOperator from a matrix-valued function `A(p)`
to represent the matrix-vector product `A(p) * u` in matrix-free form.
"""
function build_linear_operator(A_of_p::AbstractMatrix{<:SymbolicNumber}, p; in_place)
u = make_variables(infer_backend(A_of_p), gensym(), size(A_of_p)[end])
A_of_p_times_u = build_function(A_of_p * u, p, u; in_place)
# TODO: also analyze symmetry and other matrix properties to forward to the operator
input_prototype = zeros(size(u))
p_prototype = zeros(size(p))

if in_place
FunctionOperator(input_prototype; p = p_prototype, islinear = true) do result, u, p, _t
A_of_p_times_u(result, p, u)
end
else
FunctionOperator(input_prototype; p = p_prototype, islinear = true) do u, p, _t
A_of_p_times_u(p, u)
end
end
end
35 changes: 26 additions & 9 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
using SymbolicTracingUtils
using Test: @test, @testset, @test_broken
using LinearAlgebra: Diagonal
using LinearAlgebra: Diagonal, mul!
using SparseArrays: spzeros, findnz, nnz, rowvals

function dummy_function(x)
Expand All @@ -17,10 +17,13 @@ end
global x = make_variables(backend, :x, 10)
global fx = dummy_function(x)
global x_value = [1:10;]
global y_true = dummy_function(x_value)
global g_true = dummy_function_gradient(x_value)
global J_true = Diagonal(dummy_function_gradient(x_value))

@testset "non-ad-tracing" begin
f = build_function(fx, x; in_place = false)
f! = build_function(fx, x; in_place = true)
y_true = dummy_function(x_value)
y_out_of_place = f(x_value)
y_in_place = zeros(10)
f!(y_in_place, x_value)
Expand All @@ -32,19 +35,17 @@ end
gx = gradient(sum(fx), x)
g = build_function(gx, x; in_place = false)
g! = build_function(gx, x; in_place = true)
y_true = dummy_function_gradient(x_value)
y_out_of_place = g(x_value)
y_in_place = zeros(10)
g!(y_in_place, x_value)
@test y_out_of_place y_true
@test y_in_place y_true
g_out_of_place = g(x_value)
g_in_place = zeros(10)
g!(g_in_place, x_value)
@test g_out_of_place g_true
@test g_in_place g_true
end

@testset "jacobian" begin
Jx = jacobian(fx, x)
J = build_function(Jx, x; in_place = false)
J! = build_function(Jx, x; in_place = true)
J_true = Diagonal(dummy_function_gradient(x_value))
J_out_of_place = J(x_value)
J_in_place = zeros(10, 10)
J!(J_in_place, x_value)
Expand Down Expand Up @@ -85,6 +86,22 @@ end
@test nnz(J_sparse!) == nnz(Jx) # same structure as symbolic version
@test rowvals(J_sparse!) == rows
end

@testset "build_linear_operator" begin
J_op = build_linear_operator(Jx, x; in_place = false)
J_op! = build_linear_operator(Jx, x; in_place = true)
v_value = [11.0:20.0;]
Jv_true = J_true * v_value

J_op.p = x_value
Jv_out_of_place = J_op * v_value
@test Jv_out_of_place Jv_true

J_op!.p = x_value
Jv_in_place = zeros(10)
mul!(Jv_in_place, J_op!, v_value)
@test Jv_in_place Jv_true
end
end
end
end
Expand Down

0 comments on commit de782ae

Please sign in to comment.