From 6e45feff8f73e3bf5eb5a37b90126fa18b3545be Mon Sep 17 00:00:00 2001 From: rafaqz Date: Sun, 14 Jan 2024 23:36:22 +0100 Subject: [PATCH] indexing fixes --- src/stack/indexing.jl | 51 +++++++++++++++++++++++++++++++++++-------- test/stack.jl | 6 ++--- 2 files changed, 45 insertions(+), 12 deletions(-) diff --git a/src/stack/indexing.jl b/src/stack/indexing.jl index 6b4c75c40..b7df97814 100644 --- a/src/stack/indexing.jl +++ b/src/stack/indexing.jl @@ -12,25 +12,58 @@ end end # Array-like indexing -@propagate_inbounds Base.getindex(s::AbstractDimStack, i::Int, I::Int...) = - map(A -> Base.getindex(A, i, I...), data(s)) for f in (:getindex, :view, :dotview) @eval begin - - @propagate_inbounds function Base.$f(s::AbstractDimStack, I...; kw...) - indexdims = (I..., kwdims(values(kw))...) - extradims = otherdims(indexdims, dims(s)) + @propagate_inbounds function Base.$f( + s::AbstractDimStack, i1::Union{Integer,CartesianIndex}, Is::Union{Integer,CartesianIndex}... + ) + # Convert to Dimension wrappers to handle mixed size layers + Base.$f(s, DimIndices(s)[i1, Is...]...) + end + @propagate_inbounds function Base.$f(s::AbstractDimStack, i1, Is...) + I = (i1, Is...) + if length(dims(s)) > length(I) + throw(BoundsError(dims(s), I)) + elseif length(dims(s)) < length(I) + # Allow trailing ones + if all(i -> i isa Integer && i == 1, I[length(dims(s)):end]) + I = I[1:length(dims)] + else + throw(BoundsError(dims(s), I)) + end + end + # Convert to Dimension wrappers to handle mixed size layers + Base.$f(s, map(rebuild, dims(s), I)...) + end + @propagate_inbounds function Base.$f(s::AbstractDimStack, i::AbstractArray) + # Multidimensional: return vectors of values + if length(dims(s)) > 1 + Ds = DimIndices(s)[i] + map(s) do A + map(D -> A[D...], Ds) + end + else + map(A -> A[i], s) + end + end + @propagate_inbounds function Base.$f(s::AbstractDimStack, D::Dimension...; kw...) + alldims = (D..., kwdims(values(kw))...) + extradims = otherdims(alldims, dims(s)) length(extradims) > 0 && Dimensions._extradimswarn(extradims) newlayers = map(layers(s)) do A - layerdims = dims(indexdims, dims(A)) - Base.$f(A, layerdims...) + layerdims = dims(alldims, dims(A)) + I = length(layerdims) > 0 ? layerdims : map(_ -> :, size(A)) + Base.$f(A, I...) end if all(map(v -> v isa AbstractDimArray, newlayers)) - rebuildsliced(Base.$f, s, newlayers, (dims2indices(dims(s), (I..., kwdims(values(kw))...)))) + rebuildsliced(Base.$f, s, newlayers, (dims2indices(dims(s), alldims))) else newlayers end end + @propagate_inbounds function Base.$f(s::AbstractDimStack) + map($f, s) + end end end diff --git a/test/stack.jl b/test/stack.jl index c25fd7da6..0e2796a25 100644 --- a/test/stack.jl +++ b/test/stack.jl @@ -63,7 +63,7 @@ end @testset "low level base methods" begin @test keys(data(s)) == (:one, :two, :three) @test keys(data(mixed)) == (:one, :two, :extradim) - @test eltype(mixed) === (one=Float64, two=Float32, extradim=Float64) + @test eltype(mixed) === @NamedTuple{one::Float64, two::Float32, extradim::Float64} @test haskey(s, :one) == true @test haskey(s, :zero) == false @test length(s) == 3 # length is as for NamedTuple @@ -86,11 +86,11 @@ end @test all(map(similar(mixed), mixed) do s, m dims(s) == dims(m) && dims(s) === dims(m) && eltype(s) === eltype(m) end) - @test all(map(==(Int), eltype(similar(s, Int)))) + @test eltype(similar(s, Int)) === @NamedTuple{one::Int, two::Int, three::Int} st2 = similar(mixed, Bool, x, y) @test dims(st2) === (x, y) @test dims(st2[:one]) === (x, y) - @test all(map(==(Bool), eltype(st2))) + @test eltype(st2) === @NamedTuple{one::Bool, two::Bool, extradim::Bool} end @testset "merge" begin