diff --git a/src/utils.jl b/src/utils.jl index dd876cd11..a3a8eed53 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -218,24 +218,32 @@ dimwise!(f, dest::AbstractDimArray{T,N}, a::AbstractDimArray{TA,N}, b::AbstractD if !dimsmatch(common, dims(b)) b = PermutedDimsArray(b, common) end - map(generators) do otherdims - I = (common..., otherdims...) - dest[I...] .= f.(a[I...], b[common...]) + # Broadcast over b for each combination of dimensional indices D + map(generators) do D + dest[D...] .= f.(a[D...], b) end return dest end +# Single dimension generator dimwise_generators(dims::Tuple{<:Dimension}) = ((basetypeof(dims[1])(i),) for i in axes(dims[1], 1)) +# Multi dimensional generators dimwise_generators(dims::Tuple) = begin dim_constructors = map(basetypeof, dims) - Base.Generator( - Base.Iterators.ProductIterator(map(d -> axes(d, 1), dims)), - vals -> map(dim_constructors, vals) - ) + # Get the axes of the dims to iterate over + dimaxes = map(d -> axes(d, 1), dims) + # Make an iterator over all axes + proditr = Base.Iterators.ProductIterator(dimaxes) + # Wrap the produced index I in dimensions as it is generated + Base.Generator(proditr) do I + map((D, i) -> D(i), dim_constructors, I) + end end + + """ basetypeof(x) => Type diff --git a/test/utils.jl b/test/utils.jl index 48c90fb3b..de2dea360 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -71,8 +71,8 @@ end @testset "dimwise" begin A2 = [1 2 3; 4 5 6] B1 = [1, 2, 3] - da2 = DimArray(A2, (X, Y)) - db1 = DimArray(B1, (Y,)) + da2 = DimArray(A2, (X([20, 30]), Y([:a, :b, :c]))) + db1 = DimArray(B1, (Y([:a, :b, :c]),)) dc2 = dimwise(+, da2, db1) @test dc2 == [2 4 6; 5 7 9] @@ -82,6 +82,12 @@ end dc3 = dimwise(+, da3, db2) @test dc3 == cat([2 4 6; 8 10 12], [12 14 16; 18 20 22]; dims=3) + A3 = cat([1 2 3; 4 5 6], [11 12 13; 14 15 16]; dims=3) + da3 = DimArray(A3, (X([20, 30]), Y([:a, :b, :c]), Z(10:10:20))) + db1 = DimArray(B1, (Y([:a, :b, :c]),)) + dc3 = dimwise(+, da3, db1) + @test dc3 == cat([2 4 6; 5 7 9], [12 14 16; 15 17 19]; dims=3) + @testset "works with permuted dims" begin db2p = permutedims(da2) dc3p = dimwise(+, da3, db2p)