Skip to content

Commit

Permalink
Merge pull request #168 from rafaqz/fix_dimwise
Browse files Browse the repository at this point in the history
Fix dimwise
  • Loading branch information
rafaqz authored Aug 30, 2020
2 parents 5786c8d + 31527a1 commit 90fe7f1
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 9 deletions.
22 changes: 15 additions & 7 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,24 +218,32 @@ dimwise!(f, dest::AbstractDimArray{T,N}, a::AbstractDimArray{TA,N}, b::AbstractD
if !dimsmatch(common, dims(b))
b = PermutedDimsArray(b, common)
end
map(generators) do otherdims
I = (common..., otherdims...)
dest[I...] .= f.(a[I...], b[common...])
# Broadcast over b for each combination of dimensional indices D
map(generators) do D
dest[D...] .= f.(a[D...], b)
end
return dest
end

# Single dimension generator
dimwise_generators(dims::Tuple{<:Dimension}) =
((basetypeof(dims[1])(i),) for i in axes(dims[1], 1))

# Multi dimensional generators
dimwise_generators(dims::Tuple) = begin
dim_constructors = map(basetypeof, dims)
Base.Generator(
Base.Iterators.ProductIterator(map(d -> axes(d, 1), dims)),
vals -> map(dim_constructors, vals)
)
# Get the axes of the dims to iterate over
dimaxes = map(d -> axes(d, 1), dims)
# Make an iterator over all axes
proditr = Base.Iterators.ProductIterator(dimaxes)
# Wrap the produced index I in dimensions as it is generated
Base.Generator(proditr) do I
map((D, i) -> D(i), dim_constructors, I)
end
end



"""
basetypeof(x) => Type
Expand Down
10 changes: 8 additions & 2 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ end
@testset "dimwise" begin
A2 = [1 2 3; 4 5 6]
B1 = [1, 2, 3]
da2 = DimArray(A2, (X, Y))
db1 = DimArray(B1, (Y,))
da2 = DimArray(A2, (X([20, 30]), Y([:a, :b, :c])))
db1 = DimArray(B1, (Y([:a, :b, :c]),))
dc2 = dimwise(+, da2, db1)
@test dc2 == [2 4 6; 5 7 9]

Expand All @@ -82,6 +82,12 @@ end
dc3 = dimwise(+, da3, db2)
@test dc3 == cat([2 4 6; 8 10 12], [12 14 16; 18 20 22]; dims=3)

A3 = cat([1 2 3; 4 5 6], [11 12 13; 14 15 16]; dims=3)
da3 = DimArray(A3, (X([20, 30]), Y([:a, :b, :c]), Z(10:10:20)))
db1 = DimArray(B1, (Y([:a, :b, :c]),))
dc3 = dimwise(+, da3, db1)
@test dc3 == cat([2 4 6; 5 7 9], [12 14 16; 15 17 19]; dims=3)

@testset "works with permuted dims" begin
db2p = permutedims(da2)
dc3p = dimwise(+, da3, db2p)
Expand Down

0 comments on commit 90fe7f1

Please sign in to comment.