Enzyme.gradient allocates on SVector #1968

Open
gdalle opened this issue Oct 16, 2024 · 5 comments

@gdalle
Contributor

gdalle commented Oct 16, 2024

Hi!
As you know, @ExpandingMan and I are looking to optimize performance for StaticArrays. Forward mode works splendidly, but reverse mode still makes one allocation during the gradient call:

using StaticArrays, Enzyme, BenchmarkTools
f(x) = sum(abs2, x);
x = SVector(1.0, 2.0);
@btime Enzyme.gradient(Enzyme.Reverse, f, $x)  # 8.999 ns (1 allocation: 32 bytes)
@btime Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active($x))  # 4.218 ns (0 allocations: 0 bytes)

I found it surprising because Enzyme guesses the right activity for SVector:

Enzyme.guess_activity(typeof(x), Enzyme.Reverse)  # Active{SVector{2, Float64}}

The allocation happens on the following line:

Ref(make_zero($arg))

From what I understand, the generated function Enzyme.gradient puts a Ref there to treat every argument as (Mixed)Duplicated. This means that all gradient results are stored in the passed arguments:
(; derivs = ($(resargs...),), val = res[2])

Otherwise, you would have to recover some gradients from the result and others from the arguments, which is understandably tricky.
Do you think there is an easy fix in Enzyme? Otherwise, since DI only has one differentiated argument, I assume it will be rather straightforward to call Enzyme.autodiff directly inside DI.gradient and recover allocation-free behavior.
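A minimal sketch of that single-argument route, assuming the convention (Enzyme 0.12/0.13) that `autodiff` in reverse mode returns a tuple whose first element holds one derivative per annotated argument. `active_gradient` is a hypothetical helper name, not part of Enzyme or DI:

```julia
using StaticArrays, Enzyme

# Hypothetical helper (not part of Enzyme or DifferentiationInterface):
# gradient of a scalar-valued f with respect to one immutable argument x.
function active_gradient(f::F, x) where {F}
    # Annotating x as Active makes Enzyme return its derivative directly,
    # so no Ref shadow has to be created to receive it.
    res = Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active, Enzyme.Active(x))
    # first(res) is the tuple of derivatives; x is the only annotated argument.
    return only(first(res))
end

f(x) = sum(abs2, x)
x = SVector(1.0, 2.0)
active_gradient(f, x)  # derivative of sum(abs2, x) is 2x
```

The catch, as noted above, is that this only works because there is a single differentiated argument whose derivative is known to come back in the return value.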

Related:

@wsmoses
Member

wsmoses commented Oct 16, 2024

sure, PR welcome!

@gdalle
Contributor Author

gdalle commented Oct 16, 2024

Sure! I'll try to handle this case correctly in DI first, because it still errors at the moment. Once I have a handle on the single-argument solution, I'll try to tamper with the generated function to do the same for multiple arguments.

@wsmoses
Member

wsmoses commented Nov 3, 2024

bump @gdalle

@wsmoses
Member

wsmoses commented Nov 8, 2024

gentle ping @gdalle

@gdalle
Contributor Author

gdalle commented Nov 8, 2024

Essentially this comes down to adding the option for Active inputs here:

Enzyme.jl/src/Enzyme.jl

Lines 1714 to 1720 in 42ecd12

if $arg isa Enzyme.Const
$arg
elseif $act
MixedDuplicated($arg, $shad)
else
Duplicated($arg, $shad)
end
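For a single argument, an extra `Active` branch could look like the sketch below. It uses the public `Enzyme.guess_activity` (the call shown at the top of this issue) as a stand-in for the internal activity check in the generated function; `gradient_one` is a hypothetical name. A multi-argument version would additionally have to pull `Active` derivatives out of the return value instead of out of the shadows, and `MixedDuplicated` types are deliberately left unhandled here:

```julia
using StaticArrays, Enzyme

# Hypothetical single-argument gradient that picks the annotation by activity.
function gradient_one(f::F, x) where {F}
    A = Enzyme.guess_activity(typeof(x), Enzyme.Reverse)
    if A <: Enzyme.Const
        # Nothing differentiable in x (e.g. an integer).
        return nothing
    elseif A <: Enzyme.Active
        # Immutable, purely differentiable value (Float64, SVector, ...):
        # the derivative is returned, so no shadow storage is needed.
        res = Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active, Enzyme.Active(x))
        return only(first(res))
    elseif A <: Enzyme.Duplicated
        # Mutable data (Vector, ...): accumulate into a zeroed shadow.
        dx = Enzyme.make_zero(x)
        Enzyme.autodiff(Enzyme.Reverse, f, Enzyme.Active, Enzyme.Duplicated(x, dx))
        return dx
    else
        error("MixedDuplicated arguments are not handled in this sketch")
    end
end

f(x) = sum(abs2, x)
gradient_one(f, SVector(1.0, 2.0))  # Active branch
gradient_one(f, [1.0, 2.0])         # Duplicated branch
```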

The variable interpolated as $act is a boolean defined as follows, which is a bit obscure to me. Care to shed some light?

Enzyme.jl/src/Enzyme.jl

Lines 1686 to 1692 in 42ecd12

!($argidx isa Enzyme.Const) &&
Compiler.active_reg_inner(
Core.Typeof($argidx),
(),
nothing,
Val(true),
) == Compiler.ActiveState #=justActive=#
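One way to probe this from the outside is the public `Enzyme.guess_activity`. As far as I can tell from the Enzyme source, it also relies on the internal activity analysis and returns `Active{T}` when the state for `T` is `ActiveState`, which is the condition the expression above tests for (besides the argument not being `Const`). The loop below is only a probe under that assumption; it says nothing about what the `Val(true)` flag changes:

```julia
using StaticArrays, Enzyme

# Print the reverse-mode activity Enzyme infers for a few argument types.
for T in (Float64, SVector{2,Float64}, Vector{Float64}, Int)
    println(T, " => ", Enzyme.guess_activity(T, Enzyme.Reverse))
end
```

Types reported as `Active` are the ones that would take the proposed `Active` branch instead of being wrapped in `Ref(make_zero(...))`.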
