diff --git a/Project.toml b/Project.toml index eb39b92a85..166f951388 100644 --- a/Project.toml +++ b/Project.toml @@ -35,6 +35,12 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" +[weakdeps] +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" + +[extensions] +OceananigansEnzymeCoreExt = "EnzymeCore" + [compat] Adapt = "3" CUDA = "4, 5" @@ -42,6 +48,8 @@ Crayons = "4" CubedSphere = "0.1, 0.2" Distances = "0.10" DocStringExtensions = "0.8, 0.9" +EnzymeCore = "0.6" +Enzyme = "0.11.9" FFTW = "1" Glob = "1.3" IncompleteLU = "0.2" @@ -64,6 +72,8 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Coverage = "a2441757-f6aa-5fb2-8edb-039e3f45d037" CUDA_Runtime_jll = "76a88914-d11a-5bdc-97e0-2f5a05c973a2" DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240" MPIPreferences = "3da0fdf6-3ccc-4f1b-acd9-58baa6c99267" OpenMPI_jll = "fe0851c0-eecd-5654-98d4-656369965a5c" @@ -74,4 +84,4 @@ TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" TimesDates = "bdfc003b-8df8-5c39-adcd-3a9087f5df4a" [targets] -test = ["BenchmarkTools", "Coverage", "CUDA_Runtime_jll", "DataDeps", "InteractiveUtils", "MPIPreferences", "OpenMPI_jll", "Plots", "Test", "TimerOutputs", "TimesDates", "SafeTestsets"] +test = ["BenchmarkTools", "Coverage", "CUDA_Runtime_jll", "DataDeps", "Enzyme", "InteractiveUtils", "MPIPreferences", "OpenMPI_jll", "Plots", "Test", "TimerOutputs", "TimesDates", "SafeTestsets"] \ No newline at end of file diff --git a/ext/OceananigansEnzymeCoreExt.jl b/ext/OceananigansEnzymeCoreExt.jl new file mode 100644 index 0000000000..9d273f56f1 --- /dev/null +++ b/ext/OceananigansEnzymeCoreExt.jl @@ -0,0 +1,136 @@ +module OceananigansEnzymeCoreExt + +using Oceananigans +using KernelAbstractions + +isdefined(Base, :get_extension) ? (import EnzymeCore) : (import ..EnzymeCore) + +EnzymeCore.EnzymeRules.inactive_noinl(::typeof(Oceananigans.Utils.flatten_reduced_dimensions), x...) = nothing +EnzymeCore.EnzymeRules.inactive(::typeof(Oceananigans.Grids.total_size), x...) = nothing + +@inline batch(::Val{1}, ::Type{T}) where T = T +@inline batch(::Val{N}, ::Type{T}) where {T,N} = NTuple{N,T} + +function EnzymeCore.EnzymeRules.augmented_primal(config, + func::EnzymeCore.Const{Type{Field}}, + ::Type{<:EnzymeCore.Annotation{RT}}, + loc::Union{EnzymeCore.Const{<:Tuple}, + EnzymeCore.Duplicated{<:Tuple}}, + grid::EnzymeCore.Const{<:Oceananigans.Grids.AbstractGrid}, + T::EnzymeCore.Const{<:DataType}; kw...) where RT + primal = if EnzymeCore.EnzymeRules.needs_primal(config) + func.val(loc.val, grid.val, T.val; kw...) + else + nothing + end + + if haskey(kw, :a) + # copy zeroing + kw[:data] = copy(kw[:data]) + end + + shadow = if EnzymeCore.EnzymeRules.width(config) == 1 + func.val(loc.val, grid.val, T.val; kw...) + else + ntuple(Val(EnzymeCore.EnzymeRules.width(config))) do i + Base.@_inline_meta + func.val(loc.val, grid.val, T.val; kw...) + end + end + + return EnzymeCore.EnzymeRules.AugmentedReturn{EnzymeCore.EnzymeRules.needs_primal(config) ? RT : Nothing, batch(Val(EnzymeCore.EnzymeRules.width(config)), RT), Nothing}(primal, shadow, nothing) +end + +function EnzymeCore.EnzymeRules.reverse(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, func::EnzymeCore.Const{Type{Field}}, ::RT, tape, loc::Union{EnzymeCore.Const{<:Tuple}, EnzymeCore.Duplicated{<:Tuple}}, grid::EnzymeCore.Const{<:Oceananigans.Grids.AbstractGrid}, T::EnzymeCore.Const{<:DataType}; kw...) where RT + return (nothing, nothing, nothing) +end + + +function EnzymeCore.EnzymeRules.augmented_primal(config, + func::EnzymeCore.Const{typeof(Oceananigans.Utils.launch!)}, + ::Type{EnzymeCore.Const{Nothing}}, + arch, + grid, + workspec, + kernel!, + kernel_args...; + include_right_boundaries = false, + reduced_dimensions = (), + location = nothing, + only_active_cells = nothing, + kwargs...) + + + workgroup, worksize = Oceananigans.Utils.work_layout(grid.val, workspec.val; + include_right_boundaries, + reduced_dimensions, + location) + + offset = Oceananigans.Utils.offsets(workspec.val) + + if !isnothing(only_active_cells) + workgroup, worksize = Oceananigans.Utils.active_cells_work_layout(workgroup, worksize, only_active_cells, grid.val) + offset = nothing + end + + if worksize != 0 + + # We can only launch offset kernels with Static sizes!!!! + + if isnothing(offset) + loop! = kernel!.val(Oceananigans.Architectures.device(arch.val), workgroup, worksize) + dloop! = (typeof(kernel!) <: EnzymeCore.Const) ? nothing : kernel!.dval(Oceananigans.Architectures.device(arch.val), workgroup, worksize) + else + loop! = kernel!.val(Oceananigans.Architectures.device(arch.val), KernelAbstractions.StaticSize(workgroup), Oceananigans.Utils.OffsetStaticSize(contiguousrange(worksize, offset))) + dloop! = (typeof(kernel!) <: EnzymeCore.Const) ? nothing : kernel!.val(Oceananigans.Architectures.device(arch.val), KernelAbstractions.StaticSize(workgroup), Oceananigans.Utils.OffsetStaticSize(contiguousrange(worksize, offset))) + end + + @debug "Launching kernel $kernel! with worksize $worksize and offsets $offset from $workspec.val" + + + duploop = (typeof(kernel!) <: EnzymeCore.Const) ? EnzymeCore.Const(loop!) : EnzymeCore.Duplicated(loop!, dloop!) + + config2 = EnzymeCore.EnzymeRules.Config{#=needsprimal=#false, #=needsshadow=#false, #=width=#EnzymeCore.EnzymeRules.width(config), EnzymeCore.EnzymeRules.overwritten(config)[5:end]}() + subtape = EnzymeCore.EnzymeRules.augmented_primal(config2, duploop, EnzymeCore.Const{Nothing}, kernel_args...).tape + + tape = (duploop, subtape) + else + tape = nothing + end + + return EnzymeCore.EnzymeRules.AugmentedReturn{Nothing, Nothing, Any}(nothing, nothing, tape) +end + +function EnzymeCore.EnzymeRules.reverse(config::EnzymeCore.EnzymeRules.ConfigWidth{1}, + func::EnzymeCore.Const{typeof(Oceananigans.Utils.launch!)}, + ::Type{EnzymeCore.Const{Nothing}}, + tape, + arch, + grid, + workspec, + kernel!, + kernel_args...; + include_right_boundaries = false, + reduced_dimensions = (), + location = nothing, + only_active_cells = nothing, + kwargs...) + + subrets = if tape !== nothing + duploop, subtape = tape + + config2 = EnzymeCore.EnzymeRules.Config{#=needsprimal=#false, #=needsshadow=#false, #=width=#EnzymeCore.EnzymeRules.width(config), EnzymeCore.EnzymeRules.overwritten(config)[5:end]}() + + EnzymeCore.EnzymeRules.reverse(config2, duploop, EnzymeCore.Const{Nothing}, subtape, kernel_args...) + else + ntuple(Val(length(kernel_args))) do _ + Base.@_inline_meta + nothing + end + end + + return (nothing, nothing, nothing, nothing, subrets...) + +end + +end diff --git a/test/runtests.jl b/test/runtests.jl index 1b0cfb2747..49a8efa9dc 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,6 +12,7 @@ CUDA.allowscalar() do # Core Oceananigans if group == :unit || group == :all @testset "Unit tests" begin + include("test_enzyme.jl") include("test_grids.jl") include("test_operators.jl") include("test_boundary_conditions.jl") diff --git a/test/test_enzyme.jl b/test/test_enzyme.jl new file mode 100644 index 0000000000..3aa983d12a --- /dev/null +++ b/test/test_enzyme.jl @@ -0,0 +1,25 @@ +using Oceananigans +using Enzyme + +# Required presently +Enzyme.API.runtimeActivity!(true) + +EnzymeRules.inactive_type(::Type{<:Oceananigans.Grids.AbstractGrid}) = true + +f(grid) = CenterField(grid) + +@testset "Enzyme Unit Tests" begin + arch=CPU() + FT=Float64 + + N = 100 + topo = (Periodic, Flat, Flat) + grid = RectilinearGrid(arch, FT, topology=topo, size=(N), halo=2, x=(-1, 1), y=(-1, 1), z=(-1, 1)) + fwd, rev = Enzyme.autodiff_thunk(ReverseSplitWithPrimal, Const{typeof(f)}, Duplicated, typeof(Const(grid))) + + tape, primal, shadow = fwd(Const(f), Const(grid) ) + + @show tape, primal, shadow + + @test size(primal) == size(shadow) +end \ No newline at end of file