Skip to content

Commit

Permalink
add @d macro
Browse files Browse the repository at this point in the history
  • Loading branch information
rafaqz committed Sep 1, 2024
1 parent 8330424 commit 646c059
Show file tree
Hide file tree
Showing 4 changed files with 234 additions and 15 deletions.
2 changes: 2 additions & 0 deletions src/DimensionalData.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,8 @@ export set, rebuild, reorder, modify, broadcast_dims, broadcast_dims!, mergedims

export groupby, seasons, months, hours, intervals, ranges

export @d

const DD = DimensionalData

# Common
Expand Down
189 changes: 189 additions & 0 deletions src/array/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,3 +100,192 @@ _broadcasted_dims(bc::Broadcasted) = _broadcasted_dims(bc.args...)
_broadcasted_dims(a, bs...) = (_broadcasted_dims(a)..., _broadcasted_dims(bs...)...)
_broadcasted_dims(a::AbstractBasicDimArray) = (dims(a),)
_broadcasted_dims(a) = ()

"""
@d broadcast_expression options
Dimensional broadcast macro.
Will permute and and singleton dimensions
so that all `AbstractDimArray` in the broadcast will
broadcast their matching dimensions.
It is possible to pass options as the second argument of
the macro to control the behaviour, as a single assignment
or as a NamedTuple. Options names must be written explicitly,
not passed in namedtuple variable.
# Options
- `dims`: Pass a Tuple of `Dimension`s, `Dimension` types or `Symbol`s
to fix the dimension order of the output array. Otherwise dimensions
will be in order of appearance.
- `strict`: `true` or `false`. Check that all lookup values match explicitly.
# Example
```julia
da1 = ones(X(3))
da2 = fill(2, Y(4), X(3))
@d da1 .* da2
@d da1 .* da2 .+ 5 dims=(Y, X)
```
"""
macro d(expr::Expr, options::Union{Expr,Nothing}=nothing)
options_dict = _process_d_macro_options(options)
broadcast_expr, var_list = _wrap_broadcast_vars(expr)
var_list_assignments = map(var_list) do (name, expr)
Expr(:(=), name, expr)
end
vars_expr = Expr(:tuple, map(first, var_list)...)
var_list_expr = Expr(:block, var_list_assignments...)
dims_expr = if haskey(options_dict, :dims)
order_dims = options_dict[:dims]
quote
order_dims = $order_dims
found_dims = _find_dims(vars)
all(hasdim(order_dims, found_dims)) || throw(ArgumentError("order $(basedims(order_dims)) dont match dimensions found in arrays $(basedims(found_dims))"))
dims = $DimensionalData.dims(found_dims, order_dims)
end
else
:(dims = _find_dims(vars))
end
quote
let
$var_list_expr
vars = $vars_expr
$dims_expr
$broadcast_expr
end
end
end
macro d(sym::Symbol, options::Union{Expr,Nothing}=nothing)
esc(sym)
end

_process_d_macro_options(::Nothing) = Dict{Symbol,Any}()
function _process_d_macro_options(options::Expr)
options_dict = Dict{Symbol,Any}()
if options.head == :tuple
if options.args[1].head == :parameters
# Keyword syntax (; order=...
for arg in options.args[1].args
arg.head == :kw || throw(ArgumentError("malformed options"))
options_dict[arg.args[1]] = esc(arg.args[2])
end
else
# Tuple syntax (order=...
for arg in options.args
arg.head == :(=) || throw(ArgumentError("malformed options"))
options_dict[arg.args[1]] = esc(arg.args[2])
end
end
elseif options.head == :(=)
# Single assignmen order=...
options_dict[options.args[1]] = esc(options.args[2])
end

return options_dict
end

_wrap_broadcast_vars(sym::Symbol) = esc(sym), Expr[]
function _wrap_broadcast_vars(expr::Expr)
arg_list = Pair{Symbol,Expr}[]
if expr.head == :. # function dot broadcast
if expr.args[2] isa Expr
tuple_args = map(expr.args[2].args) do arg
if arg isa Expr
expr1, arg_list1 = _wrap_broadcast_vars(arg)
append!(arg_list, arg_list1)
expr1
else
var = Symbol(gensym(), :var)
push!(arg_list, var => esc(arg))
Expr(:call, :_maybe_dimensional_broadcast, var, :dims)
end
end
expr2 = Expr(expr.head, esc(expr.args[1]), Expr(:tuple, tuple_args...))
return expr2, arg_list
end
elseif expr.head == :call && string(expr.args[1])[1] == '.' # infix broadcast
args = map(expr.args[2:end]) do arg
if arg isa Expr
expr1, arg_list1 = _wrap_broadcast_vars(arg)
append!(arg_list, arg_list1)
expr1
else
var = Symbol(gensym(), :var)
push!(arg_list, var => esc(arg))
Expr(:call, :_maybe_dimensional_broadcast, var, :dims)
end
end
expr2 = Expr(expr.head, expr.args[1], args...)
return expr2, arg_list
else # Not part of the broadcast, just wrap return it
expr2 = esc(expr)
return expr2, arg_list
end
end

@inline function _find_dims((A, args...)::Tuple{<:AbstractBasicDimArray,Vararg})::DimTupleOrEmpty
expanded = _find_dims(args)
if expanded === ()
dims(A)
else
(dims(A)..., otherdims(expanded, dims(A))...)
end
end
@inline _find_dims((d, args...)::Tuple{<:Dimension,Vararg}) =
(d, otherdims(_find_dims(args), (d,)))
@inline _find_dims(::Tuple{}) = ()
@inline _find_dims((_, args...)::Tuple) = _find_dims(args)

_maybe_dimensional_broadcast(x, _) = x
function _maybe_dimensional_broadcast(A::AbstractBasicDimArray, dest_dims)
len1s = basedims(otherdims(dest_dims, dims(A)))
# Reshape first to avoid a ReshapedArray wrapper if possible
A1 = _maybe_insert_length_one_dims(A, dest_dims)
# Then permute and reorder
A2 = _maybe_lazy_permute(A1, dest_dims)
# Then rebuild with the new data and dims
data = parent(A2)
return rebuild(A; data, dims=format(dims(A2), data))
end
_maybe_dimensional_broadcast(d::Dimension, dims) =
_maybe_dimensional_broadcast(DimArray(parent(d), d), dims)

function _maybe_lazy_permute(A::AbstractBasicDimArray, dest)
if dimsmatch(commondims(dims(A), dims(dest)), commondims(dims(dest), dims(A)))
A
else
PermutedDimsArray(A, commondims(dims(dest), dims(A)))
end
end

function _maybe_insert_length_one_dims(A::AbstractBasicDimArray, dims)
if all(hasdim(A, dims))
A
else
_insert_length_one_dims(A, dims)
end
end

function _insert_length_one_dims(A::AbstractBasicDimArray, alldims)
if basedims(dims(A)) == basedims(dims(A), alldims)
lengths = map(alldims) do d
hasdim(A, d) ? size(A, d) : 1
end
newdims = map(alldims) do d
hasdim(A, d) ? dims(A, d) : rebuild(d, Lookups.Length1NoLookup())
end
else
odims = otherdims(alldims, DD.dims(A))
lengths = (size(A)..., map(_ -> 1, odims)...)
newdims = (dims(A)..., map(d -> rebuild(d, Lookups.Length1NoLookup()), odims)...)
end
newdata = reshape(parent(A), lengths)
A1 = rebuild(A, newdata, format(newdims, newdata))
return A1
end
15 changes: 2 additions & 13 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,7 @@ function broadcast_dims!(f, dest::AbstractDimArray{<:Any,N}, As::AbstractBasicDi
isempty(otherdims(A, dims(dest))) || throw(DimensionMismatch("Cannot broadcast over dimensions not in the dest array"))
# comparedims(dest, dims(A, dims(dest)))
# Lazily permute B dims to match the order in A, if required
if !dimsmatch(commondims(A, dest), commondims(dest, A))
PermutedDimsArray(A, commondims(dest, A))
else
A
end
_maybe_lazy_permute(A, dims(dest))
end
od = map(A -> otherdims(dest, dims(A)), As)
return _broadcast_dims_inner!(f, dest, As, od)
Expand All @@ -173,20 +169,13 @@ function _broadcast_dims_inner!(f, dest, As, od)
else
not_shared_dims = combinedims(od...)
reshaped = map(As) do A
all(hasdim(A, dims(dest))) ? parent(A) : _insert_length_one_dims(A, dims(dest))
_maybe_insert_length_one_dims(A, dims(dest))
end
dest .= f.(reshaped...)
end
return dest
end

function _insert_length_one_dims(A, alldims)
lengths = map(alldims) do d
hasdim(A, d) ? size(A, d) : 1
end
return reshape(parent(A), lengths)
end

@deprecate dimwise broadcast_dims
@deprecate dimwise! broadcast_dims!

Expand Down
43 changes: 41 additions & 2 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ dajl = rebuild(da, JLArray(parent(da)));
end

@testset "broadcast over length one dimension" begin
da2 = DimArray((1:4) * (1:2:8)', (X, Y));
da2 = DimArray((1:4) * (1:2:8)', (X, Y))
@test (@inferred da2 .* da2[:, 1:1]) == [1, 4, 9, 16] * (1:2:8)'
@test (@inferred da2[:, 1:1] .* da2) == [1, 4, 9, 16] * (1:2:8)'
end
Expand Down Expand Up @@ -65,7 +65,6 @@ end
Sampled(1.0:1:3.0; span=Regular(1.0), sampling=Points(), order=ForwardOrdered()),
Sampled(1.0:1:3.0; span=Regular(1.0), sampling=Intervals(Start()), order=ForwardOrdered()),
)
l = first(ls)
for l in ls
@test (@inferred lookup(zeros(X(l),) .* zeros(X(3),), X)) == NoLookup(Base.OneTo(3))
@test (@inferred lookup(zeros(X(l),) .* zeros(X(1),), X)) == NoLookup(Base.OneTo(3))
Expand Down Expand Up @@ -327,6 +326,46 @@ end
@test Array(A[DimSelectors(sub)]) == Array(C[DimSelectors(sub)])
end

@testset "@d macro" begin
f(x, y) = x * y
da1 = ones(X(3))
da2 = fill(2, X(3), Y(4))
da2a = fill(2, Y(4), X(3))
da3 = fill(3, Y(4), Z(5), X(3))
@d da1 .* da2
@d f.(da1, da2)
@d 0 .+ f.(da2, da1) .* f.(da1 ./ 1, da2a)
@d da1 .* da2
@d da2
@d da3 .+ f.(da2, da1) .* f.(da1 ./ 1, da2a)

res = @d da3 .* f.(da2, da1) .* f.(da1 ./ 1, da2a) (; dims=(X, Y, Z),)
@test all(==(12.0), res)
@test DimensionalData.basedims(res) == (X(), Y(), Z())
@test size(res) == (3, 4, 5)
@test_throws ArgumentError @d da3 .+ f.(da2, da1) .* f.(da1 ./ 1, da2a) dims=(X, Y)

res = @d da3 .* f.(da2, da1) .* f.(da1 ./ 1, da2a) (; order=(X, Y, Z),)

p(da1, da2, da3) = @d da3 .* f.(da2, da1) .* f.(da1 ./ 1, da2) dims=(X(), Y(), Z())
p(da1, da2, da3, n) = for i in 1:n p(da1, da2, da3) end
p(da1, da2, da3, 10000)

using ProfileView
@profview p(da1, da2, da3, 100000)

x, y, z = X(1:3), Y(DateTime(2000):Month(2):DateTime(2001)), Z(5)
da1 = ones(y) .* (1.0:7.0)
da2 = fill(2, x, y) .* (1:3)
da3 = fill(3, y, z, x) .* (1:7)
f(da1, da2, da3, 100)

# Shape and permutaton do not matter
@test f(da1, da2, da3) ==
f(da1, permutedims(da2, (Y, X)), da3)
f(da1, da2, permutedims(da3, (X, Y, Z)))
end

# @testset "Competing Wrappers" begin
# da = DimArray(ones(4), X)
# ta = TrackedArray(5 * ones(4))
Expand Down

0 comments on commit 646c059

Please sign in to comment.