Feature: build linear operator #2

Merged: 3 commits, Dec 27, 2024
2 changes: 2 additions & 0 deletions Project.toml
@@ -5,11 +5,13 @@ version = "0.1.2"

[deps]
FastDifferentiation = "eb9bf01b-bf85-4b60-bf87-ee5de06c00be"
SciMLOperators = "c0aeaf25-5076-4817-a8d5-81caf7dfa961"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"

[compat]
FastDifferentiation = "0.3,0.4"
SciMLOperators = "0.3.12"
SparseArrays = "1.10"
Symbolics = "4,5,6"
julia = "1.10"
5 changes: 5 additions & 0 deletions src/SymbolicTracingUtils.jl
@@ -9,12 +9,15 @@ module SymbolicTracingUtils
using Symbolics: Symbolics
using FastDifferentiation: FastDifferentiation as FD
using SparseArrays: SparseArrays
using SciMLOperators: FunctionOperator

export build_function,
build_linear_operator,
FastDifferentiationBackend,
get_constant_entries,
get_result_buffer,
gradient,
infer_backend,
jacobian,
make_variables,
sparse_jacobian,
@@ -25,6 +28,8 @@ export build_function,
struct SymbolicsBackend end
struct FastDifferentiationBackend end
const SymbolicNumber = Union{Symbolics.Num,FD.Node}
infer_backend(v::Union{Symbolics.Num,AbstractArray{<:Symbolics.Num}}) = SymbolicsBackend()
infer_backend(v::Union{FD.Node,AbstractArray{<:FD.Node}}) = FastDifferentiationBackend()

include("tracing.jl")
include("derivatives.jl")
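
For illustration, a minimal sketch of how the new `infer_backend` helper can be used (the variable names are arbitrary, not part of the package):

    using SymbolicTracingUtils

    x = make_variables(FastDifferentiationBackend(), :x, 3)  # vector of FD.Node variables
    infer_backend(x)     # returns FastDifferentiationBackend()
    infer_backend(x[1])  # also accepts a single symbolic scalar
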
34 changes: 33 additions & 1 deletion src/tracing.jl
@@ -66,5 +66,37 @@ function build_function(
in_place,
backend_options = (;),
) where {T<:FD.Node}
FD.make_function(f_symbolic, args_symbolic...; in_place, backend_options...)
f = FD.make_function(f_symbolic, args_symbolic...; in_place, backend_options...)

if in_place
function (result, args...)
f(result, reduce(vcat, args))
end
else
function (args...)
f(reduce(vcat, args))
end
end
end
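
As a minimal usage sketch (FastDifferentiation backend; `x`, `y`, and the traced expression are illustrative), the wrapper above lets a function traced over several variable blocks be called with one argument per block:

    using SymbolicTracingUtils

    x = make_variables(FastDifferentiationBackend(), :x, 2)
    y = make_variables(FastDifferentiationBackend(), :y, 2)
    f = build_function(x .+ 2 .* y, x, y; in_place = false)
    f([1.0, 2.0], [3.0, 4.0])  # the wrapper concatenates both blocks before calling the generated code
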

"""
Build a linear SciMLOperators.FunctionOperator from a symbolic matrix `A(p)` (a matrix of expressions
in the parameters `p`) to represent the matrix-vector product `A(p) * u` in matrix-free form.
"""
function build_linear_operator(A_of_p::AbstractMatrix{<:SymbolicNumber}, p; in_place)
u = make_variables(infer_backend(A_of_p), gensym(), size(A_of_p)[end])
A_of_p_times_u = build_function(A_of_p * u, p, u; in_place)
# TODO: also analyze symmetry and other matrix properties to forward to the operator
input_prototype = zeros(size(u))
p_prototype = zeros(size(p))

if in_place
FunctionOperator(input_prototype; p = p_prototype, islinear = true) do result, u, p, _t
A_of_p_times_u(result, p, u)
end
else
FunctionOperator(input_prototype; p = p_prototype, islinear = true) do u, p, _t
A_of_p_times_u(p, u)
end
end
end
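
A minimal usage sketch for the new operator (FastDifferentiation backend; `x .^ 2` is just an illustrative stand-in for any traced function, mirroring the tests below):

    using SymbolicTracingUtils

    x = make_variables(FastDifferentiationBackend(), :x, 10)
    Jx = jacobian(x .^ 2, x)                       # symbolic 10x10 Jacobian in the parameters x
    J_op = build_linear_operator(Jx, x; in_place = false)

    J_op.p = [1.0:10.0;]                           # bind the parameter vector p
    Jv = J_op * [11.0:20.0;]                       # evaluates J(p) * v without materializing J(p)
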
35 changes: 26 additions & 9 deletions test/runtests.jl
@@ -1,6 +1,6 @@
using SymbolicTracingUtils
using Test: @test, @testset, @test_broken
using LinearAlgebra: Diagonal
using LinearAlgebra: Diagonal, mul!
using SparseArrays: spzeros, findnz, nnz, rowvals

function dummy_function(x)
@@ -17,10 +17,13 @@ end
global x = make_variables(backend, :x, 10)
global fx = dummy_function(x)
global x_value = [1:10;]
global y_true = dummy_function(x_value)
global g_true = dummy_function_gradient(x_value)
global J_true = Diagonal(dummy_function_gradient(x_value))

@testset "non-ad-tracing" begin
f = build_function(fx, x; in_place = false)
f! = build_function(fx, x; in_place = true)
y_true = dummy_function(x_value)
y_out_of_place = f(x_value)
y_in_place = zeros(10)
f!(y_in_place, x_value)
@@ -32,19 +35,17 @@ end
gx = gradient(sum(fx), x)
g = build_function(gx, x; in_place = false)
g! = build_function(gx, x; in_place = true)
y_true = dummy_function_gradient(x_value)
y_out_of_place = g(x_value)
y_in_place = zeros(10)
g!(y_in_place, x_value)
@test y_out_of_place ≈ y_true
@test y_in_place ≈ y_true
g_out_of_place = g(x_value)
g_in_place = zeros(10)
g!(g_in_place, x_value)
@test g_out_of_place ≈ g_true
@test g_in_place ≈ g_true
end

@testset "jacobian" begin
Jx = jacobian(fx, x)
J = build_function(Jx, x; in_place = false)
J! = build_function(Jx, x; in_place = true)
J_true = Diagonal(dummy_function_gradient(x_value))
J_out_of_place = J(x_value)
J_in_place = zeros(10, 10)
J!(J_in_place, x_value)
@@ -85,6 +86,22 @@ end
@test nnz(J_sparse!) == nnz(Jx) # same structure as symbolic version
@test rowvals(J_sparse!) == rows
end

@testset "build_linear_operator" begin
J_op = build_linear_operator(Jx, x; in_place = false)
J_op! = build_linear_operator(Jx, x; in_place = true)
v_value = [11.0:20.0;]
Jv_true = J_true * v_value

J_op.p = x_value
Jv_out_of_place = J_op * v_value
@test Jv_out_of_place ≈ Jv_true

J_op!.p = x_value
Jv_in_place = zeros(10)
mul!(Jv_in_place, J_op!, v_value)
@test Jv_in_place ≈ Jv_true
end
end
end
end