Skip to content

Commit

Permalink
Improve GPU functionality (#780)
Browse files Browse the repository at this point in the history
* Improve GPU functionality

* Add missing weakdeps

* Update src/array/broadcast.jl

Co-authored-by: Rafael Schouten <[email protected]>

* Update src/array/broadcast.jl

Co-authored-by: Rafael Schouten <[email protected]>

* Push materialize fix

* Clean up mapreduce and add a bunch of tests for JLArray broadcast

* Add some more JLArray tests

* Just return dest in broadcast

* Update src/array/methods.jl

Co-authored-by: Rafael Schouten <[email protected]>

* Format

---------

Co-authored-by: Rafael Schouten <[email protected]>
  • Loading branch information
ptiede and rafaqz authored Aug 24, 2024
1 parent fe39de7 commit 2b2b380
Show file tree
Hide file tree
Showing 5 changed files with 347 additions and 45 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Interfaces = "0.3"
IntervalSets = "0.5, 0.6, 0.7"
InvertedIndices = "1"
IteratorInterfaceExtensions = "1"
JLArrays = "0.1"
LinearAlgebra = "1"
Makie = "0.19, 0.20, 0.21"
OffsetArrays = "1"
Expand Down Expand Up @@ -85,6 +86,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
ImageFiltering = "6a3955dd-da59-5b1f-98d4-e7296123deb5"
ImageTransformations = "02fcd773-0e25-5acc-982a-7f6622650795"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Expand All @@ -95,4 +97,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["Aqua", "ArrayInterface", "BenchmarkTools", "CategoricalArrays", "ColorTypes", "Combinatorics", "CoordinateTransformations", "DataFrames", "Distributions", "Documenter", "ImageFiltering", "ImageTransformations", "CairoMakie", "OffsetArrays", "Plots", "Random", "SafeTestsets", "StatsPlots", "Test", "Unitful"]
test = ["Aqua", "ArrayInterface", "BenchmarkTools", "CategoricalArrays", "ColorTypes", "Combinatorics", "CoordinateTransformations", "DataFrames", "Distributions", "Documenter", "ImageFiltering", "ImageTransformations", "JLArrays", "CairoMakie", "OffsetArrays", "Plots", "Random", "SafeTestsets", "StatsPlots", "Test", "Unitful"]
30 changes: 14 additions & 16 deletions src/array/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,26 +47,24 @@ function Broadcast.copy(bc::Broadcasted{DimensionalStyle{S}}) where S
end

function Base.copyto!(dest::AbstractArray, bc::Broadcasted{DimensionalStyle{S}}) where S
_dims = comparedims(dims(dest), _broadcasted_dims(bc); ignore_length_one=true, order=true)
#TODO: this will cause a comparisson to happen twice. We should avoid that
comparedims(dims(dest), _broadcasted_dims(bc); ignore_length_one=true, order=true)
copyto!(dest, _unwrap_broadcasted(bc))
A = _firstdimarray(bc)
if A isa Nothing || _dims isa Nothing
dest
else
rebuild(A, dest, _dims, refdims(A))
end
return dest
end
function Base.copyto!(dest::AbstractDimArray, bc::Broadcasted{DimensionalStyle{S}}) where S
_dims = comparedims(dims(dest), _broadcasted_dims(bc); ignore_length_one=true, order=true)
copyto!(parent(dest), _unwrap_broadcasted(bc))
A = _firstdimarray(bc)
if A isa Nothing || _dims isa Nothing
dest
else
rebuild(A, parent(dest), _dims, refdims(A))
end


@inline function Base.Broadcast.materialize!(dest::AbstractDimArray, bc::Base.Broadcast.Broadcasted{<:Any})
# needed because we need to check whether the dims are compatible in dest which are already
# stripped when sent to copyto!
comparedims(dims(dest), _broadcasted_dims(bc); ignore_length_one=true, order=true)
style = DimensionalData.DimensionalStyle(Base.Broadcast.combine_styles(parent(dest), bc))
Base.Broadcast.materialize!(style, parent(dest), bc)
return dest
end



function Base.similar(bc::Broadcast.Broadcasted{DimensionalStyle{S}}, ::Type{T}) where {S,T}
A = _firstdimarray(bc)
rebuildsliced(A, similar(_unwrap_broadcasted(bc), T, axes(bc)...), axes(bc), Symbol(""))
Expand Down
22 changes: 3 additions & 19 deletions src/array/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,25 +53,9 @@ for (m, f) in ((:Statistics, :median), (:Base, :any), (:Base, :all))
end
end

# These are not exported but it makes a lot of things easier using them
function Base._mapreduce_dim(f, op, nt::NamedTuple{(),<:Tuple}, A::AbstractDimArray, dims)
rebuild(A, Base._mapreduce_dim(f, op, nt, parent(A), dimnum(A, _astuple(dims))), reducedims(A, dims))
end
function Base._mapreduce_dim(f, op, nt::NamedTuple{(),<:Tuple}, A::AbstractDimArray, dims::Colon)
Base._mapreduce_dim(f, op, nt, parent(A), dims)
end
function Base._mapreduce_dim(f, op, nt, A::AbstractDimArray, dims)
rebuild(A, Base._mapreduce_dim(f, op, nt, parent(A), dimnum(A, dims)), reducedims(A, dims))
end
function Base._mapreduce_dim(f, op, nt, A::AbstractDimArray, dims::Colon)
rebuild(A, Base._mapreduce_dim(f, op, nt, parent(A), dimnum(A, dims)), reducedims(A, dims))
end

function Base._mapreduce_dim(f, op, nt::Base._InitialValue, A::AbstractDimArray, dims)
rebuild(A, Base._mapreduce_dim(f, op, nt, parent(A), dimnum(A, dims)), reducedims(A, dims))
end
function Base._mapreduce_dim(f, op, nt::Base._InitialValue, A::AbstractDimArray, dims::Colon)
Base._mapreduce_dim(f, op, nt, parent(A), dims)
function Base.mapreduce(f, op, A::AbstractDimArray; dims=Base.Colon(), kw...)
dims === Colon() && return mapreduce(f, op, parent(A); kw...)
rebuild(A, mapreduce(f, op, parent(A); dims=dimnum(A, dims), kw...), reducedims(A, dims))
end


Expand Down
118 changes: 112 additions & 6 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
using DimensionalData, Test

using JLArrays
using DimensionalData: NoLookup

# Tests taken from NamedDims. Thanks @oxinabox

da = ones(X(3))
dajl = rebuild(da, JLArray(parent(da)));
@test Base.BroadcastStyle(typeof(da)) isa DimensionalData.DimensionalStyle

@testset "standard case" begin
Expand All @@ -19,18 +20,35 @@ end
@test da2 .* da2[:, 1:1] == [1, 4, 9, 16] * (1:2:8)'
end

@testset "JLArray broadcast over length one dimension" begin
da2 = DimArray(JLArray((1:4) * (1:2:8)'), (X, Y))
@test Array(da2 .* da2[:, 1:1]) == [1, 4, 9, 16] * (1:2:8)'
end

@testset "in place" begin
@test parent(da .= 1 .* da .+ 7) == 8 * ones(3)
@test dims(da .= 1 .* da .+ 7) == dims(da)
end

@testset "JLArray in place" begin
@test Array(parent(dajl .= 1 .* dajl .+ 7)) == 8 * ones(3)
@test dims(dajl .= 1 .* dajl .+ 7) == dims(da)
end

@testset "Dimension disagreement" begin
@test_throws DimensionMismatch begin
DimArray(zeros(3, 3, 3), (X, Y, Z)) .+
DimArray(ones(3, 3, 3), (Y, Z, X))
end
end

@testset "JLArray Dimension disagreement" begin
@test_throws DimensionMismatch begin
DimArray(JLArray(zeros(3, 3, 3)), (X, Y, Z)) .+
DimArray(JLArray(ones(3, 3, 3)), (Y, Z, X))
end
end

@testset "dims and regular" begin
da = DimArray(ones(3, 3, 3), (X, Y, Z))
left_sum = da .+ ones(3, 3, 3)
Expand All @@ -41,6 +59,16 @@ end
@test dims(right_sum) == dims(da)
end

@testset "JLArray dims and regular" begin
da = DimArray(JLArray(ones(3, 3, 3)), (X, Y, Z))
left_sum = da .+ ones(3, 3, 3)
@test Array(left_sum) == fill(2, 3, 3, 3)
@test dims(left_sum) == dims(da)
right_sum = ones(3, 3, 3) .+ da
@test Array(right_sum) == fill(2, 3, 3, 3)
@test dims(right_sum) == dims(da)
end

@testset "changing type" begin
@test (da .> 0) isa DimArray
@test (da .* da .> 0) isa DimArray
Expand All @@ -51,6 +79,16 @@ end
@test (rand(3) .> 1 .> 0 .* da) isa DimArray
end

@testset "JLArray changing type" begin
@test (dajl .> 0) isa DimArray
@test (dajl .* dajl .> 0) isa DimArray
@test (dajl .> 0 .> rand(3)) isa DimArray
@test (dajl .* rand(3) .> 0.0) isa DimArray
@test (0 .> dajl .> 0 .> rand(3)) isa DimArray
@test (rand(3) .> dajl .> 0 .* rand(3)) isa DimArray
@test (rand(3) .> 1 .> 0 .* dajl) isa DimArray
end

@testset "trailng dimensions" begin
@test zeros(X(10), Y(5)) .* zeros(X(10), Y(1)) ==
zeros(X(10), Y(5)) .* zeros(X(1), Y(1)) ==
Expand Down Expand Up @@ -79,6 +117,18 @@ end
@test dims(s .+ v .+ m) == dims(m .+ s .+ v)
end

@testset "JLArray broadcasting" begin
v = DimArray(JLArray(zeros(3,)), X)
m = DimArray(JLArray(ones(3, 3)), (X, Y))
s = 0
@test Array(v .+ m) == ones(3, 3) == Array(m .+ v)
@test Array(s .+ m) == ones(3, 3) == Array(m .+ s)
@test Array(s .+ v .+ m) == ones(3, 3) == Array(m .+ s .+ v)
@test dims(v .+ m) == dims(m .+ v)
@test dims(s .+ m) == dims(m .+ s)
@test dims(s .+ v .+ m) == dims(m .+ s .+ v)
end

@testset "adjoint broadcasting" begin
a = DimArray(reshape(1:12, (4, 3)), (X, Y))
b = DimArray(1:3, Y)
Expand All @@ -88,6 +138,17 @@ end
@test dims(a .* b') == dims(a)
end

@testset "JLArray adjoint broadcasting" begin
a = DimArray(JLArray(reshape(1:12, (4, 3))), (X, Y))
b = DimArray(JLArray(1:3), Y)
@test_throws DimensionMismatch a .* b
@test_throws DimensionMismatch parent(a) .* parent(b)
@test Array(parent(a) .* parent(b)') == Array(parent(a .* b'))
@test dims(a .* b') == dims(a)
end



@testset "Mixed array types" begin
casts = (
A -> DimArray(A, (X, Y)), # Named Matrix
Expand Down Expand Up @@ -121,13 +182,26 @@ end
@test_throws DimensionMismatch ac .= ab .+ ba

# check that dest is written into:
@test dims(z .= ab .+ ba') == dims(ab .+ ba')
z .= ab .+ ba'
@test z == (ab.data .+ ba.data')
end

@test dims(z .= ab .+ a_) ==
(X(NoLookup(Base.OneTo(2))), Y(NoLookup(Base.OneTo(2))))
@test dims(a_ .= ba' .+ ab) ==
(X(NoLookup(Base.OneTo(2))), Y(NoLookup(Base.OneTo(2))))
@testset "JLArray in-place assignment .=" begin
ab = DimArray(JLArray(rand(2,2)), (X, Y))
ba = DimArray(JLArray(rand(2,2)), (Y, X))
ac = DimArray(JLArray(rand(2,2)), (X, Z))
a_ = DimArray(JLArray(rand(2,2)), (X(), DimensionalData.AnonDim()))
z = JLArray(zeros(2,2))

@test_throws DimensionMismatch z .= ab .+ ba
@test_throws DimensionMismatch z .= ab .+ ac
@test_throws DimensionMismatch a_ .= ab .+ ac
@test_throws DimensionMismatch ab .= a_ .+ ac
@test_throws DimensionMismatch ac .= ab .+ ba

# check that dest is written into:
z .= ab .+ ba'
@test z == (ab.data .+ ba.data')
end

@testset "assign using named indexing and dotview" begin
Expand All @@ -137,6 +211,13 @@ end
@test A == [1.0 1.0; 2.0 2.0; 7.0 7.0]
end

@testset "JLArray assign using named indexing and dotview" begin
A = DimArray(JLArray(zeros(3,2)), (X, Y))
A[X=1:2] .= JLArray([1, 2])
A[X=3] .= 7
@test Array(A) == [1.0 1.0; 2.0 2.0; 7.0 7.0]
end

@testset "0-dimensional array broadcasting" begin
x = DimArray(fill(3), ())
y = DimArray(fill(4), ())
Expand Down Expand Up @@ -168,6 +249,31 @@ end
@test A[DimSelectors(sub)] == C[DimSelectors(sub)]
end

@testset "JLArray DimIndices broadcasting" begin
ds = X(1.0:0.2:2.0), Y(10:2:20)
_A = (rand(ds))
_B = (zeros(ds))
_C = (zeros(ds))

A = rebuild(_A, JLArray(parent(_A)))
B = rebuild(_B, JLArray(parent(_B)))
C = rebuild(_C, JLArray(parent(_C)))

B[DimIndices(B)] .+= A
C[DimSelectors(C)] .+= A
@test Array(A) == Array(B) == Array(C)
sub = A[1:4, 1:3]
B .= 0
C .= 0
B[DimIndices(sub)] .+= sub
C[DimSelectors(sub)] .+= sub
@test Array(A[DimIndices(sub)]) == Array(B[DimIndices(sub)]) == Array(C[DimIndices(sub)])
sub = A[2:4, 2:5]
C .= 0
C[DimSelectors(sub)] .+= sub
@test Array(A[DimSelectors(sub)]) == Array(C[DimSelectors(sub)])
end

# @testset "Competing Wrappers" begin
# da = DimArray(ones(4), X)
# ta = TrackedArray(5 * ones(4))
Expand Down
Loading

0 comments on commit 2b2b380

Please sign in to comment.