Skip to content

Commit

Permalink
Merge pull request #151 from rafaqz/dimwise
Browse files Browse the repository at this point in the history
Dimwise function broadcast
  • Loading branch information
rafaqz authored Aug 21, 2020
2 parents 97b7b6f + 0f7d6cf commit c211ab6
Show file tree
Hide file tree
Showing 8 changed files with 198 additions and 11 deletions.
10 changes: 10 additions & 0 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ data
dimnum
dims
hasdim
otherdims
commondims
label
mode
metadata
Expand All @@ -121,18 +123,26 @@ refdims
shortname
units
val
basetypeof
```

And some utility methods for transforming DimensionalData objects:

```@docs
rebuild
modify
dimwise
dimwise!
setdims
swapdims
reorderindex
reorderarray
reorderrelation
reverseindex
reversearray
flipindex
fliparray
fliprelation
```

## Non-exported methods for developers
Expand Down
3 changes: 2 additions & 1 deletion src/DimensionalData.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ export AbstractDimArray, DimArray, AbstractDimensionalArray, DimensionalArray
export data, dims, refdims, mode, metadata, name, shortname,
val, label, units, order, bounds, locus, mode, <|

export dimnum, hasdim, setdims, swapdims, rebuild, modify
export dimnum, hasdim, otherdims, commondims, setdims, swapdims, rebuild,
modify, dimwise, dimwise!

export order, indexorder, arrayorder,
reverseindex, reversearray, reorderindex,
Expand Down
1 change: 1 addition & 0 deletions src/dimension.jl
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ Base.size(dim::Dimension) = size(val(dim))
Base.size(dim::Dimension{<:Val}) = (length(unwrap(val(dim))),)
Base.axes(dim::Dimension) = axes(val(dim))
Base.axes(dim::Dimension{<:Val}) = (Base.OneTo(length(dim)),)
Base.axes(dim::Dimension, i) = axes(val(dim), i)
Base.eachindex(dim::Dimension) = eachindex(val(dim))
Base.length(dim::Dimension{<:Union{AbstractArray,Number}}) = length(val(dim))
Base.length(dim::Dimension{<:Val}) = length(unwrap(val(dim)))
Expand Down
81 changes: 73 additions & 8 deletions src/primitives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ const UnionAllTupleOrVector = Union{Vector{UnionAll},Tuple{UnionAll,Vararg}}
@inline _sortdims(tosort::Tuple, order::Tuple, rejected) =
# Match dims to the order, and also check if the mode has a
# transformed dimension that matches
if _dimsmatch(tosort[1], order[1])
if dimsmatch(tosort[1], order[1])
(tosort[1], _sortdims((rejected..., tail(tosort)...), tail(order), ())...)
else
_sortdims(tail(tosort), order, (rejected..., tosort[1]))
Expand All @@ -32,8 +32,37 @@ const UnionAllTupleOrVector = Union{Vector{UnionAll},Tuple{UnionAll,Vararg}}
@inline _sortdims(tosort::Tuple, order::Tuple{}, rejected) = ()
@inline _sortdims(tosort::Tuple{}, order::Tuple{}, rejected) = ()

@inline _dimsmatch(dim::DimOrDimType, match::DimOrDimType) =
"""
commondims(x, lookup)
This is basically `dims` where the order of the original is kept,
unlike `dims` where the lookup tuple determines the order
"""
commondims(A::AbstractArray, B::AbstractArray) = commondims(dims(A), dims(B))
commondims(A::AbstractArray, lookup) = commondims(dims(A), lookup)
commondims(dims::Tuple, lookup) = commondims(dims, (lookup,))
commondims(dims::Tuple, lookup::Tuple) =
if hasdim(lookup, dims[1])
(dims[1], commondims(tail(dims), lookup)...)
else
commondims(tail(dims), lookup)
end
commondims(dims::Tuple{}, lookup::Tuple) = ()


"""
dimsmatch(dim::DimOrDimType, match::DimOrDimType) => Bool
Compare 2 dimensions are of the same base type, or
are at least rotations/transformations of the same type.
"""
@inline dimsmatch(dims::Tuple, lookups::Tuple) =
all(map(dimsmatch, dims, lookups))
@inline dimsmatch(dim::DimOrDimType, match::DimOrDimType) =
basetypeof(dim) <: basetypeof(match) || basetypeof(dim) <: basetypeof(dims(mode(match)))
@inline dimsmatch(dim::DimOrDimType, match::Nothing) = false
@inline dimsmatch(dim::Nothing, match::DimOrDimType) = false
@inline dimsmatch(dim::Nothing, match::Nothing) = false

"""
dims2indices(dim::Dimension, lookup, [emptyval=Colon()])
Expand Down Expand Up @@ -195,7 +224,7 @@ julia> dimnum(A, Z)

# Match dim and lookup, also check if the mode has a transformed dimension that matches
@inline _dimnum(d::Tuple, lookup::Tuple, rejected, n) =
if !(d[1] isa Nothing) && _dimsmatch(d[1], lookup[1])
if dimsmatch(d[1], lookup[1])
# Replace found dim with nothing so it isn't found again but n is still correct
(n, _dimnum((rejected..., nothing, tail(d)...), tail(lookup), (), 1)...)
else
Expand Down Expand Up @@ -234,18 +263,54 @@ false
@inline hasdim(A::AbstractArray, lookup) = hasdim(dims(A), lookup)
@inline hasdim(d::Tuple, lookup::Tuple) = map(l -> hasdim(d, l), lookup)
@inline hasdim(d::Tuple, lookup::DimOrDimType) =
if _dimsmatch(d[1], lookup)
if dimsmatch(d[1], lookup)
true
else
hasdim(tail(d), lookup)
end
@inline hasdim(::Tuple{}, ::DimOrDimType) = false

"""
setdim(x, newdim)
otherdims(x, lookup) => Tuple{Vararg{<:Dimension,N}}
## Arguments
- `x`: any object with a `dims` method, a `Tuple` of `Dimension`.
- `lookup`: Tuple or single `Dimension` or dimension `Type`.
A tuple holding the unmatched dimensions is always returned.
## Example
```jldoctest
julia> A = DimArray(ones(10, 10, 10), (X, Y, Z));
julia> otherdims(A, X)
(Y: Base.OneTo(10), Z: Base.OneTo(10))
julia> otherdims(A, Ti)
(X: Base.OneTo(10), Y: Base.OneTo(10), Z: Base.OneTo(10))
```
"""
@inline otherdims(A::AbstractArray, lookup) = otherdims(dims(A), lookup)
@inline otherdims(dims::Tuple, lookup::DimOrDimType) = otherdims(dims, (lookup,))
@inline otherdims(dims::Tuple, lookup::Tuple) =
_otherdims(dims, _sortdims(lookup, dims))

#= Work with a sorted lookup where the missing dims are `nothing`.
Then we can compare with `dimsmatch`, and splat away the matches. =#
@inline _otherdims(dims::Tuple, sortedlookup::Tuple) =
(_otherdims(dims[1], sortedlookup[1])...,
_otherdims(tail(dims), tail(sortedlookup))...)
@inline _otherdims(dims::Tuple{}, ::Tuple{}) = ()
@inline _otherdims(dim::DimOrDimType, lookupdim) =
dimsmatch(dim, lookupdim) ? () : (dim,)

"""
setdims(A::AbstractArray, newdims) => AbstractArray
setdims(::Tuple, newdims) => Tuple{Vararg{<:Dimension,N}}
setdims(dim::Dimension, newdims) => Dimension
Replaces the first dim matching `<: basetypeof(newdim)` with newdim, and returns
a new object or tuple with the dimension updated.
Replaces the first dim matching `<: basetypeof(newdim)` with newdim,
and returns a new object or tuple with the dimension updated.
## Arguments
- `x`: any object with a `dims` method, a `Tuple` of `Dimension` or a single `Dimension`.
Expand Down Expand Up @@ -407,7 +472,7 @@ type: Z{Base.OneTo{Int64},NoIndex,Nothing}
@inline dims(d::DimTuple, lookup::Tuple) = _dims(d, lookup, (), d)

@inline _dims(d, lookup::Tuple, rejected, remaining) =
if !(remaining[1] isa Nothing) && _dimsmatch(remaining[1], lookup[1])
if dimsmatch(remaining[1], lookup[1])
# Remove found dim so it isn't found again
(remaining[1], _dims(d, tail(lookup), (), (rejected..., tail(remaining)...))...)
else
Expand Down
71 changes: 71 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,77 @@ modify(f, A::AbstractDimArray) = begin
end


"""
dimwise!(f, A::AbstractDimArray, B::AbstractDimArray)
Dimension-wise application of function `f`.
## Arguments
-`a`: `AbstractDimArray` to broacast from, along dimensions not in `b`.
-`b`: `AbstractDimArray` to broadcast from all diensions.
Dimensions must be a subset of a.
This is like broadcasting over every slice of `A` if it is
sliced by the dimensions of `B`, and storing the value in `dest`.
"""
dimwise(f, A::AbstractDimArray, B::AbstractDimArray) =
dimwise!(f, similar(A, promote_type(eltype(A), eltype(B))), A, B)

"""
dimwise!(f, dest::AbstractDimArray, A::AbstractDimArray, B::AbstractDimArray)
Dimension-wise application of function `f`.
## Arguments
-`dest`: `AbstractDimArray` to update
-`a`: `AbstractDimArray` to broacast from, along dimensions not in `b`.
-`b`: `AbstractDimArray` to broadcast from all diensions.
Dimensions must be a subset of a.
This is like broadcasting over every slice of `A` if it is
sliced by the dimensions of `B`, and storing the value in `dest`.
"""
dimwise!(f, dest::AbstractDimArray{T,N}, a::AbstractDimArray{TA,N}, b::AbstractDimArray{TB,NB}
) where {T,TA,TB,N,NB} = begin
N >= NB || error("B-array cannot have more dimensions than A array")
comparedims(dest, a)
common = commondims(a, dims(b))
generators = dimwise_generators(otherdims(a, common))
# Lazily permute B dims to match the order in A, if required
if !dimsmatch(common, dims(b))
b = PermutedDimsArray(b, common)
end
map(generators) do otherdims
I = (common..., otherdims...)
dest[I...] .= f.(a[I...], b[common...])
end
return dest
end

dimwise_generators(dims::Tuple{<:Dimension}) =
((basetypeof(dims[1])(i),) for i in axes(dims[1], 1))

dimwise_generators(dims::Tuple) = begin
dim_constructors = map(basetypeof, dims)
Base.Generator(
Base.Iterators.ProductIterator(map(d -> axes(d, 1), dims)),
vals -> map(dim_constructors, vals)
)
end

"""
basetypeof(x)
Get the base type of an object - the minimum required to
define the object without it's fields. By default this is the full
`UnionAll` for the type. But custom `basetypeof` methods can be
defined for types with free type parameters.
In DimensionalData this is primariliy used for comparing dimensions,
where `Dim{:x}` is different from `Dim{:y}`.
"""
basetypeof(x) = basetypeof(typeof(x))
@generated function basetypeof(::Type{T}) where T
getfield(parentmodule(T), nameof(T))
Expand Down
18 changes: 17 additions & 1 deletion test/primitives.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
using DimensionalData, Test

using DimensionalData: val, basetypeof, slicedims, dims2indices, formatdims, mode,
@dim, reducedims, XDim, YDim, ZDim, Forward
@dim, reducedims, XDim, YDim, ZDim, Forward, commondims

dimz = (X(), Y())

Expand Down Expand Up @@ -138,6 +138,13 @@ end
@test dims(x) == x
end

@testset "commondims" begin
commondims(da, X) == (dims(da, X),)
# Dims are always in the base order
commondims(da, (X, Y)) == dims(da, (X, Y))
commondims(da, (Y, X)) == dims(da, (X, Y))
end

@testset "hasdim" begin
@test hasdim(da, X) == true
@test hasdim(da, Ti) == false
Expand All @@ -153,6 +160,15 @@ end
@test hasdim(dims(da), (ZDim, ZDim)) == (false, false)
end

@testset "otherdims" begin
A = DimArray(ones(5, 10, 15), (X, Y, Z));
@test otherdims(A, X) == dims(A, (Y, Z))
@test otherdims(A, Y) == dims(A, (X, Z))
@test otherdims(A, Z) == dims(A, (X, Y))
@test otherdims(A, (X, Z)) == dims(A, (Y,))
@test otherdims(A, Ti) == dims(A, (X, Y, Z))
end

@testset "setdims" begin
A = setdims(da, X(LinRange(150,152,2)))
@test val(dims(dims(A), X())) == LinRange(150,152,2)
Expand Down
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@ if VERSION >= v"1.5.0"
Aqua.test_stale_deps(DimensionalData)
end

include("matmul.jl")
include("methods.jl")
include("utils.jl")
include("matmul.jl")
include("dimension.jl")
include("interface.jl")
include("primitives.jl")
Expand Down
22 changes: 22 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,25 @@ end
typeof(parent(mda)) == BitArray{2}
@test_throws ErrorException modify(A -> A[1, :], da)
end

@testset "dimwise" begin
A2 = [1 2 3; 4 5 6]
B1 = [1, 2, 3]
da2 = DimArray(A2, (X, Y))
db1 = DimArray(B1, (Y,))
dc2 = dimwise(+, da2, db1)
@test dc2 == [2 4 6; 5 7 9]

A3 = cat([1 2 3; 4 5 6], [11 12 13; 14 15 16]; dims=3)
da3 = DimArray(A3, (X, Y, Z))
db2 = DimArray(A2, (X, Y))
dc3 = dimwise(+, da3, db2)
@test dc3 == cat([2 4 6; 8 10 12], [12 14 16; 18 20 22]; dims=3)

@testset "works with permuted dims" begin
db2p = permutedims(da2)
dc3p = dimwise(+, da3, db2p)
@test dc3p == cat([2 4 6; 8 10 12], [12 14 16; 18 20 22]; dims=3)
end

end

0 comments on commit c211ab6

Please sign in to comment.