Skip to content

Commit

Permalink
indexing fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
rafaqz committed Jan 14, 2024
1 parent b72c7b3 commit 6e45fef
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 12 deletions.
51 changes: 42 additions & 9 deletions src/stack/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,58 @@ end
end

# Array-like indexing
@propagate_inbounds Base.getindex(s::AbstractDimStack, i::Int, I::Int...) =
map(A -> Base.getindex(A, i, I...), data(s))
for f in (:getindex, :view, :dotview)
@eval begin

@propagate_inbounds function Base.$f(s::AbstractDimStack, I...; kw...)
indexdims = (I..., kwdims(values(kw))...)
extradims = otherdims(indexdims, dims(s))
@propagate_inbounds function Base.$f(
s::AbstractDimStack, i1::Union{Integer,CartesianIndex}, Is::Union{Integer,CartesianIndex}...
)
# Convert to Dimension wrappers to handle mixed size layers
Base.$f(s, DimIndices(s)[i1, Is...]...)
end
@propagate_inbounds function Base.$f(s::AbstractDimStack, i1, Is...)
I = (i1, Is...)
if length(dims(s)) > length(I)
throw(BoundsError(dims(s), I))
elseif length(dims(s)) < length(I)
# Allow trailing ones
if all(i -> i isa Integer && i == 1, I[length(dims(s)):end])
I = I[1:length(dims)]
else
throw(BoundsError(dims(s), I))
end
end
# Convert to Dimension wrappers to handle mixed size layers
Base.$f(s, map(rebuild, dims(s), I)...)
end
@propagate_inbounds function Base.$f(s::AbstractDimStack, i::AbstractArray)
# Multidimensional: return vectors of values
if length(dims(s)) > 1
Ds = DimIndices(s)[i]
map(s) do A
map(D -> A[D...], Ds)
end
else
map(A -> A[i], s)
end
end
@propagate_inbounds function Base.$f(s::AbstractDimStack, D::Dimension...; kw...)
alldims = (D..., kwdims(values(kw))...)
extradims = otherdims(alldims, dims(s))
length(extradims) > 0 && Dimensions._extradimswarn(extradims)
newlayers = map(layers(s)) do A
layerdims = dims(indexdims, dims(A))
Base.$f(A, layerdims...)
layerdims = dims(alldims, dims(A))
I = length(layerdims) > 0 ? layerdims : map(_ -> :, size(A))
Base.$f(A, I...)
end
if all(map(v -> v isa AbstractDimArray, newlayers))
rebuildsliced(Base.$f, s, newlayers, (dims2indices(dims(s), (I..., kwdims(values(kw))...))))
rebuildsliced(Base.$f, s, newlayers, (dims2indices(dims(s), alldims)))
else
newlayers
end
end
@propagate_inbounds function Base.$f(s::AbstractDimStack)
map($f, s)
end
end
end

Expand Down
6 changes: 3 additions & 3 deletions test/stack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ end
@testset "low level base methods" begin
@test keys(data(s)) == (:one, :two, :three)
@test keys(data(mixed)) == (:one, :two, :extradim)
@test eltype(mixed) === (one=Float64, two=Float32, extradim=Float64)
@test eltype(mixed) === @NamedTuple{one::Float64, two::Float32, extradim::Float64}
@test haskey(s, :one) == true
@test haskey(s, :zero) == false
@test length(s) == 3 # length is as for NamedTuple
Expand All @@ -86,11 +86,11 @@ end
@test all(map(similar(mixed), mixed) do s, m
dims(s) == dims(m) && dims(s) === dims(m) && eltype(s) === eltype(m)
end)
@test all(map(==(Int), eltype(similar(s, Int))))
@test eltype(similar(s, Int)) === @NamedTuple{one::Int, two::Int, three::Int}
st2 = similar(mixed, Bool, x, y)
@test dims(st2) === (x, y)
@test dims(st2[:one]) === (x, y)
@test all(map(==(Bool), eltype(st2)))
@test eltype(st2) === @NamedTuple{one::Bool, two::Bool, extradim::Bool}
end

@testset "merge" begin
Expand Down

0 comments on commit 6e45fef

Please sign in to comment.