From 646c05919cf61b1bf43886b67886f019fe3f7d7f Mon Sep 17 00:00:00 2001 From: rafaqz Date: Sun, 1 Sep 2024 18:16:49 +0200 Subject: [PATCH] add @d macro --- src/DimensionalData.jl | 2 + src/array/broadcast.jl | 189 +++++++++++++++++++++++++++++++++++++++++ src/utils.jl | 15 +--- test/broadcast.jl | 43 +++++++++- 4 files changed, 234 insertions(+), 15 deletions(-) diff --git a/src/DimensionalData.jl b/src/DimensionalData.jl index 9eb0c0ae2..71991ea8c 100644 --- a/src/DimensionalData.jl +++ b/src/DimensionalData.jl @@ -79,6 +79,8 @@ export set, rebuild, reorder, modify, broadcast_dims, broadcast_dims!, mergedims export groupby, seasons, months, hours, intervals, ranges +export @d + const DD = DimensionalData # Common diff --git a/src/array/broadcast.jl b/src/array/broadcast.jl index 556326335..af49ff1ed 100644 --- a/src/array/broadcast.jl +++ b/src/array/broadcast.jl @@ -100,3 +100,192 @@ _broadcasted_dims(bc::Broadcasted) = _broadcasted_dims(bc.args...) _broadcasted_dims(a, bs...) = (_broadcasted_dims(a)..., _broadcasted_dims(bs...)...) _broadcasted_dims(a::AbstractBasicDimArray) = (dims(a),) _broadcasted_dims(a) = () + +""" + @d broadcast_expression options + +Dimensional broadcast macro. + +Will permute and and singleton dimensions +so that all `AbstractDimArray` in the broadcast will +broadcast their matching dimensions. + +It is possible to pass options as the second argument of +the macro to control the behaviour, as a single assignment +or as a NamedTuple. Options names must be written explicitly, +not passed in namedtuple variable. + +# Options + +- `dims`: Pass a Tuple of `Dimension`s, `Dimension` types or `Symbol`s + to fix the dimension order of the output array. Otherwise dimensions + will be in order of appearance. +- `strict`: `true` or `false`. Check that all lookup values match explicitly. + +# Example + +```julia +da1 = ones(X(3)) +da2 = fill(2, Y(4), X(3)) + +@d da1 .* da2 +@d da1 .* da2 .+ 5 dims=(Y, X) +``` + +""" +macro d(expr::Expr, options::Union{Expr,Nothing}=nothing) + options_dict = _process_d_macro_options(options) + broadcast_expr, var_list = _wrap_broadcast_vars(expr) + var_list_assignments = map(var_list) do (name, expr) + Expr(:(=), name, expr) + end + vars_expr = Expr(:tuple, map(first, var_list)...) + var_list_expr = Expr(:block, var_list_assignments...) + dims_expr = if haskey(options_dict, :dims) + order_dims = options_dict[:dims] + quote + order_dims = $order_dims + found_dims = _find_dims(vars) + all(hasdim(order_dims, found_dims)) || throw(ArgumentError("order $(basedims(order_dims)) dont match dimensions found in arrays $(basedims(found_dims))")) + dims = $DimensionalData.dims(found_dims, order_dims) + end + else + :(dims = _find_dims(vars)) + end + quote + let + $var_list_expr + vars = $vars_expr + $dims_expr + $broadcast_expr + end + end +end +macro d(sym::Symbol, options::Union{Expr,Nothing}=nothing) + esc(sym) +end + +_process_d_macro_options(::Nothing) = Dict{Symbol,Any}() +function _process_d_macro_options(options::Expr) + options_dict = Dict{Symbol,Any}() + if options.head == :tuple + if options.args[1].head == :parameters + # Keyword syntax (; order=... + for arg in options.args[1].args + arg.head == :kw || throw(ArgumentError("malformed options")) + options_dict[arg.args[1]] = esc(arg.args[2]) + end + else + # Tuple syntax (order=... + for arg in options.args + arg.head == :(=) || throw(ArgumentError("malformed options")) + options_dict[arg.args[1]] = esc(arg.args[2]) + end + end + elseif options.head == :(=) + # Single assignmen order=... + options_dict[options.args[1]] = esc(options.args[2]) + end + + return options_dict +end + +_wrap_broadcast_vars(sym::Symbol) = esc(sym), Expr[] +function _wrap_broadcast_vars(expr::Expr) + arg_list = Pair{Symbol,Expr}[] + if expr.head == :. # function dot broadcast + if expr.args[2] isa Expr + tuple_args = map(expr.args[2].args) do arg + if arg isa Expr + expr1, arg_list1 = _wrap_broadcast_vars(arg) + append!(arg_list, arg_list1) + expr1 + else + var = Symbol(gensym(), :var) + push!(arg_list, var => esc(arg)) + Expr(:call, :_maybe_dimensional_broadcast, var, :dims) + end + end + expr2 = Expr(expr.head, esc(expr.args[1]), Expr(:tuple, tuple_args...)) + return expr2, arg_list + end + elseif expr.head == :call && string(expr.args[1])[1] == '.' # infix broadcast + args = map(expr.args[2:end]) do arg + if arg isa Expr + expr1, arg_list1 = _wrap_broadcast_vars(arg) + append!(arg_list, arg_list1) + expr1 + else + var = Symbol(gensym(), :var) + push!(arg_list, var => esc(arg)) + Expr(:call, :_maybe_dimensional_broadcast, var, :dims) + end + end + expr2 = Expr(expr.head, expr.args[1], args...) + return expr2, arg_list + else # Not part of the broadcast, just wrap return it + expr2 = esc(expr) + return expr2, arg_list + end +end + +@inline function _find_dims((A, args...)::Tuple{<:AbstractBasicDimArray,Vararg})::DimTupleOrEmpty + expanded = _find_dims(args) + if expanded === () + dims(A) + else + (dims(A)..., otherdims(expanded, dims(A))...) + end +end +@inline _find_dims((d, args...)::Tuple{<:Dimension,Vararg}) = + (d, otherdims(_find_dims(args), (d,))) +@inline _find_dims(::Tuple{}) = () +@inline _find_dims((_, args...)::Tuple) = _find_dims(args) + +_maybe_dimensional_broadcast(x, _) = x +function _maybe_dimensional_broadcast(A::AbstractBasicDimArray, dest_dims) + len1s = basedims(otherdims(dest_dims, dims(A))) + # Reshape first to avoid a ReshapedArray wrapper if possible + A1 = _maybe_insert_length_one_dims(A, dest_dims) + # Then permute and reorder + A2 = _maybe_lazy_permute(A1, dest_dims) + # Then rebuild with the new data and dims + data = parent(A2) + return rebuild(A; data, dims=format(dims(A2), data)) +end +_maybe_dimensional_broadcast(d::Dimension, dims) = + _maybe_dimensional_broadcast(DimArray(parent(d), d), dims) + +function _maybe_lazy_permute(A::AbstractBasicDimArray, dest) + if dimsmatch(commondims(dims(A), dims(dest)), commondims(dims(dest), dims(A))) + A + else + PermutedDimsArray(A, commondims(dims(dest), dims(A))) + end +end + +function _maybe_insert_length_one_dims(A::AbstractBasicDimArray, dims) + if all(hasdim(A, dims)) + A + else + _insert_length_one_dims(A, dims) + end +end + +function _insert_length_one_dims(A::AbstractBasicDimArray, alldims) + if basedims(dims(A)) == basedims(dims(A), alldims) + lengths = map(alldims) do d + hasdim(A, d) ? size(A, d) : 1 + end + newdims = map(alldims) do d + hasdim(A, d) ? dims(A, d) : rebuild(d, Lookups.Length1NoLookup()) + end + else + odims = otherdims(alldims, DD.dims(A)) + lengths = (size(A)..., map(_ -> 1, odims)...) + newdims = (dims(A)..., map(d -> rebuild(d, Lookups.Length1NoLookup()), odims)...) + end + newdata = reshape(parent(A), lengths) + A1 = rebuild(A, newdata, format(newdims, newdata)) + return A1 +end diff --git a/src/utils.jl b/src/utils.jl index 472e0df82..40dc06e6c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -155,11 +155,7 @@ function broadcast_dims!(f, dest::AbstractDimArray{<:Any,N}, As::AbstractBasicDi isempty(otherdims(A, dims(dest))) || throw(DimensionMismatch("Cannot broadcast over dimensions not in the dest array")) # comparedims(dest, dims(A, dims(dest))) # Lazily permute B dims to match the order in A, if required - if !dimsmatch(commondims(A, dest), commondims(dest, A)) - PermutedDimsArray(A, commondims(dest, A)) - else - A - end + _maybe_lazy_permute(A, dims(dest)) end od = map(A -> otherdims(dest, dims(A)), As) return _broadcast_dims_inner!(f, dest, As, od) @@ -173,20 +169,13 @@ function _broadcast_dims_inner!(f, dest, As, od) else not_shared_dims = combinedims(od...) reshaped = map(As) do A - all(hasdim(A, dims(dest))) ? parent(A) : _insert_length_one_dims(A, dims(dest)) + _maybe_insert_length_one_dims(A, dims(dest)) end dest .= f.(reshaped...) end return dest end -function _insert_length_one_dims(A, alldims) - lengths = map(alldims) do d - hasdim(A, d) ? size(A, d) : 1 - end - return reshape(parent(A), lengths) -end - @deprecate dimwise broadcast_dims @deprecate dimwise! broadcast_dims! diff --git a/test/broadcast.jl b/test/broadcast.jl index facbff544..7576d6c31 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -25,7 +25,7 @@ dajl = rebuild(da, JLArray(parent(da))); end @testset "broadcast over length one dimension" begin - da2 = DimArray((1:4) * (1:2:8)', (X, Y)); + da2 = DimArray((1:4) * (1:2:8)', (X, Y)) @test (@inferred da2 .* da2[:, 1:1]) == [1, 4, 9, 16] * (1:2:8)' @test (@inferred da2[:, 1:1] .* da2) == [1, 4, 9, 16] * (1:2:8)' end @@ -65,7 +65,6 @@ end Sampled(1.0:1:3.0; span=Regular(1.0), sampling=Points(), order=ForwardOrdered()), Sampled(1.0:1:3.0; span=Regular(1.0), sampling=Intervals(Start()), order=ForwardOrdered()), ) - l = first(ls) for l in ls @test (@inferred lookup(zeros(X(l),) .* zeros(X(3),), X)) == NoLookup(Base.OneTo(3)) @test (@inferred lookup(zeros(X(l),) .* zeros(X(1),), X)) == NoLookup(Base.OneTo(3)) @@ -327,6 +326,46 @@ end @test Array(A[DimSelectors(sub)]) == Array(C[DimSelectors(sub)]) end +@testset "@d macro" begin + f(x, y) = x * y + da1 = ones(X(3)) + da2 = fill(2, X(3), Y(4)) + da2a = fill(2, Y(4), X(3)) + da3 = fill(3, Y(4), Z(5), X(3)) + @d da1 .* da2 + @d f.(da1, da2) + @d 0 .+ f.(da2, da1) .* f.(da1 ./ 1, da2a) + @d da1 .* da2 + @d da2 + @d da3 .+ f.(da2, da1) .* f.(da1 ./ 1, da2a) + + res = @d da3 .* f.(da2, da1) .* f.(da1 ./ 1, da2a) (; dims=(X, Y, Z),) + @test all(==(12.0), res) + @test DimensionalData.basedims(res) == (X(), Y(), Z()) + @test size(res) == (3, 4, 5) + @test_throws ArgumentError @d da3 .+ f.(da2, da1) .* f.(da1 ./ 1, da2a) dims=(X, Y) + + res = @d da3 .* f.(da2, da1) .* f.(da1 ./ 1, da2a) (; order=(X, Y, Z),) + + p(da1, da2, da3) = @d da3 .* f.(da2, da1) .* f.(da1 ./ 1, da2) dims=(X(), Y(), Z()) + p(da1, da2, da3, n) = for i in 1:n p(da1, da2, da3) end + p(da1, da2, da3, 10000) + + using ProfileView + @profview p(da1, da2, da3, 100000) + + x, y, z = X(1:3), Y(DateTime(2000):Month(2):DateTime(2001)), Z(5) + da1 = ones(y) .* (1.0:7.0) + da2 = fill(2, x, y) .* (1:3) + da3 = fill(3, y, z, x) .* (1:7) + f(da1, da2, da3, 100) + + # Shape and permutaton do not matter + @test f(da1, da2, da3) == + f(da1, permutedims(da2, (Y, X)), da3) + f(da1, da2, permutedims(da3, (X, Y, Z))) +end + # @testset "Competing Wrappers" begin # da = DimArray(ones(4), X) # ta = TrackedArray(5 * ones(4))