Skip to content

Commit

Permalink
bugfixes for @d (#832)
Browse files Browse the repository at this point in the history
* various bugfixes for @d

* DimVector lower down
  • Loading branch information
rafaqz authored Nov 1, 2024
1 parent 05250f7 commit 091b924
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 14 deletions.
2 changes: 2 additions & 0 deletions src/array/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,8 @@ const DimVecOrMat = Union{DimVector,DimMatrix}
DimVector(A::AbstractVector, dim::Dimension, args...; kw...) =
DimArray(A, (dim,), args...; kw...)
DimVector(A::AbstractVector, args...; kw...) = DimArray(A, args...; kw...)
DimVector(f::Function, dim::Dimension; kw...) =
DimArray(f::Function, dim::Dimension; kw...)
DimMatrix(A::AbstractMatrix, args...; kw...) = DimArray(A, args...; kw...)

Base.convert(::Type{DimArray}, A::AbstractDimArray) = DimArray(A)
Expand Down
33 changes: 19 additions & 14 deletions src/array/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,29 +223,34 @@ function _process_d_macro_options(options::Expr)
return options_dict, options_expr
end

# Handle existing variable names
_find_broadcast_vars(sym::Symbol) = esc(sym), Pair{Symbol,Any}[]
# Handle e.g. 1 in the expression
function _find_broadcast_vars(x)
var = Symbol(gensym(), :_d)
esc(var), Pair{Symbol,Any}[var => x]
end
# Walk the broadcast expression, finding broadcast arguments and
# pulling them out of the main broadcast into separate variables.
# This lets us get `dims` from all of them and use it to reshape
# and permute them so they all match.
_find_broadcast_vars(sym::Symbol) = esc(sym), Pair{Symbol,Any}[]
function _find_broadcast_vars(expr::Expr)
if expr.head == :macrocall && expr.args[1] == Symbol("@__dot__")
return _find_broadcast_vars(Base.Broadcast.__dot__(expr.args[3]))
end
mdb = :($DimensionalData._maybe_dimensional_broadcast)
arg_list = Pair{Symbol,Any}[]
if expr.head == :. # function dot broadcast
if expr.head == :. && !(expr.args[2] isa QuoteNode) # function dot broadcast
if expr.args[2] isa Expr
wrapped_args = map(expr.args[2].args) do arg
var = Symbol(gensym(), :_d)
out = if arg isa Expr
expr1, arg_list1 = _find_broadcast_vars(arg)
expr1, arg_list1 = _find_broadcast_vars(arg)
out = if isempty(arg_list1)
push!(arg_list, var => arg)
esc(var)
else
append!(arg_list, arg_list1)
expr1
else
arg1 = arg
push!(arg_list, var => arg1)
esc(var)
end
Expr(:call, mdb, out, :dims, :options)
end
Expand All @@ -255,13 +260,13 @@ function _find_broadcast_vars(expr::Expr)
elseif expr.head == :call && string(expr.args[1])[1] == '.' # infix broadcast
wrapped_args = map(expr.args[2:end]) do arg
var = Symbol(gensym(), :_d)
out = if arg isa Expr
expr1, arg_list1 = _find_broadcast_vars(arg)
append!(arg_list, arg_list1)
expr1
else
expr1, arg_list1 = _find_broadcast_vars(arg)
out = if isempty(arg_list1)
push!(arg_list, var => arg)
esc(var)
else
append!(arg_list, arg_list1)
expr1
end
Expr(:call, mdb, out, :dims, :options)
end
Expand Down Expand Up @@ -442,6 +447,6 @@ end
end
end
@inline _find_dims((d, args...)::Tuple{<:Dimension,Vararg}) =
(d, otherdims(_find_dims(args), (d,)))
(d, otherdims(_find_dims(args), (d,))...)
@inline _find_dims(::Tuple{}) = ()
@inline _find_dims((_, args...)::Tuple) = _find_dims(args)
15 changes: 15 additions & 0 deletions test/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -391,6 +391,21 @@ end

@testset "Dimension" begin
@test (@d X(1:3) .* X(10:10:30) strict=false) == [10, 40, 90]
da = @d string.(X(10:10:30), Y([:a, :b, :c]), Z(1:2:5))
@test da == (xs -> string(xs...)).(DimPoints((X(10:10:30), Y([:a, :b, :c]), Z(1:2:5))))
end

@testset "stack fields" begin
xs = 1.0:10.0
v1 = DimVector(identity, X(xs); name=:v1)
v2 = DimVector(x -> 2x, X(xs); name=:v2)
ds = DimStack(v1)
@test (@d v1 .* v2) == (@d ds.v1 .* v2)
end

@testset "numbers etc" begin
dv = DimArray(identity, X(1.0:10.0); name=:x)
@test (@d dv .* 2) == (dv .* 2)
end
end

Expand Down

0 comments on commit 091b924

Please sign in to comment.