Skip to content

Commit

Permalink
Breaking: Stack rand (#567)
Browse files Browse the repository at this point in the history
* test rand on stacks

* add rand for stack

* indexing fixes
  • Loading branch information
rafaqz authored Feb 4, 2024
1 parent d23b606 commit a9dd0e9
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 10 deletions.
50 changes: 45 additions & 5 deletions src/stack/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,58 @@ end
end

# Array-like indexing
@propagate_inbounds Base.getindex(s::AbstractDimStack, i::Int, I::Int...) =
map(A -> Base.getindex(A, i, I...), data(s))
for f in (:getindex, :view, :dotview)
@eval begin
@propagate_inbounds function Base.$f(s::AbstractDimStack, I...; kw...)
newlayers = map(A -> Base.$f(A, I...; kw...), layers(s))
@propagate_inbounds function Base.$f(
s::AbstractDimStack, i1::Union{Integer,CartesianIndex}, Is::Union{Integer,CartesianIndex}...
)
# Convert to Dimension wrappers to handle mixed size layers
Base.$f(s, DimIndices(s)[i1, Is...]...)
end
@propagate_inbounds function Base.$f(s::AbstractDimStack, i1, Is...)
I = (i1, Is...)
if length(dims(s)) > length(I)
throw(BoundsError(dims(s), I))
elseif length(dims(s)) < length(I)
# Allow trailing ones
if all(i -> i isa Integer && i == 1, I[length(dims(s)):end])
I = I[1:length(dims)]
else
throw(BoundsError(dims(s), I))
end
end
# Convert to Dimension wrappers to handle mixed size layers
Base.$f(s, map(rebuild, dims(s), I)...)
end
@propagate_inbounds function Base.$f(s::AbstractDimStack, i::AbstractArray)
# Multidimensional: return vectors of values
if length(dims(s)) > 1
Ds = DimIndices(s)[i]
map(s) do A
map(D -> A[D...], Ds)
end
else
map(A -> A[i], s)
end
end
@propagate_inbounds function Base.$f(s::AbstractDimStack, D::Dimension...; kw...)
alldims = (D..., kwdims(values(kw))...)
extradims = otherdims(alldims, dims(s))
length(extradims) > 0 && Dimensions._extradimswarn(extradims)
newlayers = map(layers(s)) do A
layerdims = dims(alldims, dims(A))
I = length(layerdims) > 0 ? layerdims : map(_ -> :, size(A))
Base.$f(A, I...)
end
if all(map(v -> v isa AbstractDimArray, newlayers))
rebuildsliced(Base.$f, s, newlayers, (dims2indices(dims(s), (I..., kwdims(values(kw))...))))
rebuildsliced(Base.$f, s, newlayers, (dims2indices(dims(s), alldims)))
else
newlayers
end
end
@propagate_inbounds function Base.$f(s::AbstractDimStack)
map($f, s)
end
end
end

Expand Down
7 changes: 7 additions & 0 deletions src/stack/methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,10 @@ for fname in (:one, :oneunit, :zero, :copy)
end

Base.reverse(s::AbstractDimStack; dims=1) = map(A -> reverse(A; dims=dims), s)

# Random
Random.Sampler(RNG::Type{<:AbstractRNG}, st::AbstractDimStack, n::Random.Repetition) =
Random.SamplerSimple(st, Random.Sampler(RNG, DimIndices(st), n))

Random.rand(rng::AbstractRNG, sp::Random.SamplerSimple{<:AbstractDimStack,<:Random.Sampler}) =
@inbounds return sp[][rand(rng, sp.data)...]
2 changes: 1 addition & 1 deletion src/stack/stack.jl
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ Base.axes(s::AbstractDimStack) = map(first ∘ axes, dims(s))
Base.axes(s::AbstractDimStack, dims::DimOrDimType) = axes(s, dimnum(s, dims))
Base.axes(s::AbstractDimStack, dims::Integer) = axes(s)[dims]
Base.similar(s::AbstractDimStack, args...) = map(A -> similar(A, args...), s)
Base.eltype(s::AbstractDimStack, args...) = map(eltype, s)
Base.eltype(s::AbstractDimStack, args...) = NamedTuple{keys(s),Tuple{map(eltype, s)...}}
Base.iterate(s::AbstractDimStack, args...) = iterate(layers(s), args...)
Base.read(s::AbstractDimStack) = map(read, s)
# `merge` for AbstractDimStack and NamedTuple.
Expand Down
15 changes: 11 additions & 4 deletions test/stack.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using DimensionalData, Test, LinearAlgebra, Statistics, ConstructionBase
using DimensionalData, Test, LinearAlgebra, Statistics, ConstructionBase, Random

using DimensionalData: data
using DimensionalData: Sampled, Categorical, AutoLookup, NoLookup, Transformed,
Expand Down Expand Up @@ -63,7 +63,7 @@ end
@testset "low level base methods" begin
@test keys(data(s)) == (:one, :two, :three)
@test keys(data(mixed)) == (:one, :two, :extradim)
@test eltype(mixed) === (one=Float64, two=Float32, extradim=Float64)
@test eltype(mixed) === @NamedTuple{one::Float64, two::Float32, extradim::Float64}
@test haskey(s, :one) == true
@test haskey(s, :zero) == false
@test length(s) == 3 # length is as for NamedTuple
Expand All @@ -86,11 +86,11 @@ end
@test all(map(similar(mixed), mixed) do s, m
dims(s) == dims(m) && dims(s) === dims(m) && eltype(s) === eltype(m)
end)
@test all(map(==(Int), eltype(similar(s, Int))))
@test eltype(similar(s, Int)) === @NamedTuple{one::Int, two::Int, three::Int}
st2 = similar(mixed, Bool, x, y)
@test dims(st2) === (x, y)
@test dims(st2[:one]) === (x, y)
@test all(map(==(Bool), eltype(st2)))
@test eltype(st2) === @NamedTuple{one::Bool, two::Bool, extradim::Bool}
end

@testset "merge" begin
Expand Down Expand Up @@ -332,3 +332,10 @@ end
@test extrema(f, s) == (one=(2.0, 12.0), two=(4.0, 24.0), three=(6.0, 36.0))
@test mean(f, s) == (one=7.0, two=14.0, three=21)
end

@testset "rand sampling" begin
@test rand(s) isa @NamedTuple{one::Float64, two::Float32, three::Int}
@test rand(Xoshiro(), s) isa @NamedTuple{one::Float64, two::Float32, three::Int}
@test rand(mixed) isa @NamedTuple{one::Float64, two::Float32, extradim::Float64}
@test rand(MersenneTwister(), mixed) isa @NamedTuple{one::Float64, two::Float32, extradim::Float64}
end

0 comments on commit a9dd0e9

Please sign in to comment.