From 091b924c5c9fa1c5b16f174993397aa872e1815e Mon Sep 17 00:00:00 2001 From: Rafael Schouten Date: Sat, 2 Nov 2024 00:07:10 +0100 Subject: [PATCH] bugfixes for `@d` (#832) * various bugfixes for @d * DimVector lower down --- src/array/array.jl | 2 ++ src/array/broadcast.jl | 33 +++++++++++++++++++-------------- test/broadcast.jl | 15 +++++++++++++++ 3 files changed, 36 insertions(+), 14 deletions(-) diff --git a/src/array/array.jl b/src/array/array.jl index 2ae57038e..faf94b1f0 100644 --- a/src/array/array.jl +++ b/src/array/array.jl @@ -447,6 +447,8 @@ const DimVecOrMat = Union{DimVector,DimMatrix} DimVector(A::AbstractVector, dim::Dimension, args...; kw...) = DimArray(A, (dim,), args...; kw...) DimVector(A::AbstractVector, args...; kw...) = DimArray(A, args...; kw...) +DimVector(f::Function, dim::Dimension; kw...) = + DimArray(f::Function, dim::Dimension; kw...) DimMatrix(A::AbstractMatrix, args...; kw...) = DimArray(A, args...; kw...) Base.convert(::Type{DimArray}, A::AbstractDimArray) = DimArray(A) diff --git a/src/array/broadcast.jl b/src/array/broadcast.jl index b92abbb7b..57f4966d0 100644 --- a/src/array/broadcast.jl +++ b/src/array/broadcast.jl @@ -223,29 +223,34 @@ function _process_d_macro_options(options::Expr) return options_dict, options_expr end +# Handle existing variable names +_find_broadcast_vars(sym::Symbol) = esc(sym), Pair{Symbol,Any}[] +# Handle e.g. 1 in the expression +function _find_broadcast_vars(x) + var = Symbol(gensym(), :_d) + esc(var), Pair{Symbol,Any}[var => x] +end # Walk the broadcast expression, finding broadcast arguments and # pulling them out of the main broadcast into separate variables. # This lets us get `dims` from all of them and use it to reshape # and permute them so they all match. -_find_broadcast_vars(sym::Symbol) = esc(sym), Pair{Symbol,Any}[] function _find_broadcast_vars(expr::Expr) if expr.head == :macrocall && expr.args[1] == Symbol("@__dot__") return _find_broadcast_vars(Base.Broadcast.__dot__(expr.args[3])) end mdb = :($DimensionalData._maybe_dimensional_broadcast) arg_list = Pair{Symbol,Any}[] - if expr.head == :. # function dot broadcast + if expr.head == :. && !(expr.args[2] isa QuoteNode) # function dot broadcast if expr.args[2] isa Expr wrapped_args = map(expr.args[2].args) do arg var = Symbol(gensym(), :_d) - out = if arg isa Expr - expr1, arg_list1 = _find_broadcast_vars(arg) + expr1, arg_list1 = _find_broadcast_vars(arg) + out = if isempty(arg_list1) + push!(arg_list, var => arg) + esc(var) + else append!(arg_list, arg_list1) expr1 - else - arg1 = arg - push!(arg_list, var => arg1) - esc(var) end Expr(:call, mdb, out, :dims, :options) end @@ -255,13 +260,13 @@ function _find_broadcast_vars(expr::Expr) elseif expr.head == :call && string(expr.args[1])[1] == '.' # infix broadcast wrapped_args = map(expr.args[2:end]) do arg var = Symbol(gensym(), :_d) - out = if arg isa Expr - expr1, arg_list1 = _find_broadcast_vars(arg) - append!(arg_list, arg_list1) - expr1 - else + expr1, arg_list1 = _find_broadcast_vars(arg) + out = if isempty(arg_list1) push!(arg_list, var => arg) esc(var) + else + append!(arg_list, arg_list1) + expr1 end Expr(:call, mdb, out, :dims, :options) end @@ -442,6 +447,6 @@ end end end @inline _find_dims((d, args...)::Tuple{<:Dimension,Vararg}) = - (d, otherdims(_find_dims(args), (d,))) + (d, otherdims(_find_dims(args), (d,))...) @inline _find_dims(::Tuple{}) = () @inline _find_dims((_, args...)::Tuple) = _find_dims(args) diff --git a/test/broadcast.jl b/test/broadcast.jl index e2fafdb84..e4a9cea2d 100644 --- a/test/broadcast.jl +++ b/test/broadcast.jl @@ -391,6 +391,21 @@ end @testset "Dimension" begin @test (@d X(1:3) .* X(10:10:30) strict=false) == [10, 40, 90] + da = @d string.(X(10:10:30), Y([:a, :b, :c]), Z(1:2:5)) + @test da == (xs -> string(xs...)).(DimPoints((X(10:10:30), Y([:a, :b, :c]), Z(1:2:5)))) + end + + @testset "stack fields" begin + xs = 1.0:10.0 + v1 = DimVector(identity, X(xs); name=:v1) + v2 = DimVector(x -> 2x, X(xs); name=:v2) + ds = DimStack(v1) + @test (@d v1 .* v2) == (@d ds.v1 .* v2) + end + + @testset "numbers etc" begin + dv = DimArray(identity, X(1.0:10.0); name=:x) + @test (@d dv .* 2) == (dv .* 2) end end