diff --git a/src/functions.jl b/src/functions.jl index 48d31eb..fb272a0 100644 --- a/src/functions.jl +++ b/src/functions.jl @@ -12,7 +12,7 @@ function nameddimsarray_result(original_nda, reduced_data, reduction_dims::Colon return reduced_data end -################################################################################### +################################################ # Overloads # 1 Arg @@ -48,19 +48,6 @@ for (mod, funs) in ( end end -if VERSION > v"1.1-" - function Base.eachslice(a::NamedDimsArray{L}; dims, kwargs...) where L - numerical_dims = dim(a, dims) - slices = eachslice(parent(a); dims=numerical_dims, kwargs...) - return Base.Generator(slices) do slice - # For unknown reasons (something to do with hoisting?) having this in the - # function passed to `Generator` actually results in less memory being allocated - names = remaining_dimnames_after_dropping(L, numerical_dims) - return NamedDimsArray(slice, names) - end - end -end - # 1 arg before - no default for `dims` keyword for (mod, funs) in ( (:Base, (:mapslices,)), @@ -110,3 +97,33 @@ function Base.append!(A::NamedDimsArray{L,T,1}, B::AbstractVector) where {L,T} data = append!(parent(A), unname(B)) return NamedDimsArray{newL}(data) end + +################################################ +# Generators + +@static if VERSION >= v"1.1" + Base.eachslice(A::NamedDimsArray; dims) = _eachslice(A, dims) +else + eachcol(A::AbstractVecOrMat) = (view(A, :, i) for i in axes(A, 2)) + eachrow(A::AbstractVecOrMat) = (view(A, i, :) for i in axes(A, 1)) + # every line identical to Base, but no _eachslice(A, dims) to disatch on. + eachslice(A::AbstractArray; dims) = _eachslice(A, dims) +end + +function _eachslice(A::AbstractArray, dims::Symbol) + numerical_dims = dim(A, dims) + return _eachslice(A, numerical_dims) +end +function _eachslice(A::AbstractArray, dims::Tuple) + length(dims) == 1 || throw(ArgumentError("only single dimensions are supported")) + return _eachslice(A, first(dims)) +end +@inline function _eachslice(A::AbstractArray, dim::Int) + dim <= ndims(A) || throw(DimensionMismatch("A doesn't have $dim dimensions")) + idx1, idx2 = ntuple(d->(:), dim-1), ntuple(d->(:), ndims(A)-dim) + return (view(A, idx1..., i, idx2...) for i in axes(A, dim)) +end + +function Base.collect(itr::Base.Generator{<:NamedDimsArray{L}}) where {L} + NamedDimsArray{L}(collect(Base.Generator(itr.f, parent(itr.iter)))) +end diff --git a/test/functions.jl b/test/functions.jl index fdc77ca..470b33c 100644 --- a/test/functions.jl +++ b/test/functions.jl @@ -45,7 +45,9 @@ using Statistics end @testset "eachslice" begin - if VERSION > v"1.1-" + if VERSION < v"1.1-" + using NamedDims: eachslice + end slices = [[111 121; 211 221], [112 122; 212 222]] a = cat(slices...; dims=3) nda = NamedDimsArray(a, (:a, :b, :c)) @@ -67,7 +69,24 @@ using Statistics names(first(eachslice(nda; dims=2))) == (:a, :c) ) + end + + @testset "eachcol, eachrow" begin + if VERSION < v"1.1-" + using NamedDims: eachrow, eachcol end + nda = NamedDimsArray([10 20; 31 40], (:x, :y)) + + @test names(first(eachcol(nda))) == (:x,) + @test names(first(eachrow(nda))) == (:y,) + end + + @testset "collect" begin + ndv = NamedDimsArray([1, 9, 7, 3], :vec) + @test names([sqrt(x) for x in ndv]) == (:vec,) + + nda = NamedDimsArray([10 20; 31 40], (:x, :y)) + @test names([x^2 for x in nda]) == (:x, :y) end @testset "mapslices" begin