From 95b503c0f70e3e8eec6078c6cbef9f9a9d2d93d5 Mon Sep 17 00:00:00 2001 From: Sheehan Olver Date: Sat, 7 Dec 2024 12:01:49 +0000 Subject: [PATCH] hcat and vcat too --- src/lazyconcat.jl | 74 +++++++++++++++++++++++---------------------- test/concattests.jl | 2 ++ 2 files changed, 40 insertions(+), 36 deletions(-) diff --git a/src/lazyconcat.jl b/src/lazyconcat.jl index 7ed7bf8..8ee34c3 100644 --- a/src/lazyconcat.jl +++ b/src/lazyconcat.jl @@ -21,35 +21,37 @@ Vcat() = Vcat{Any}() @inline function applied_instantiate(::typeof(vcat), args...) iargs = map(instantiate, args) if !isempty(iargs) - m = _vcat_size(iargs[1],2) + m = _cat_size(iargs[1],2) for k=2:length(iargs) - _vcat_size(iargs[k],2) == m || throw(ArgumentError("number of columns of each array must match (got $(map(x->_vcat_size(x,2), args)))")) + _cat_size(iargs[k],2) == m || throw(ArgumentError("number of columns of each array must match (got $(map(x->_cat_size(x,2), args)))")) end end vcat, iargs end -_vcat_axes(a, k) = Base.OneTo(1) -_vcat_axes(a::AbstractArray, k) = axes(a, k) -_vcat_size(a, k) = 1 -_vcat_size(a::AbstractArray, k) = size(a, k) -_vcat_ndims(a) = 1 -_vcat_ndims(a::AbstractArray) = ndims(a) -_vcat_eltype(a) = typeof(a) -_vcat_eltype(a::AbstractArray) = eltype(a) -_vcat_length(a) = 1 -_vcat_length(a::AbstractArray) = length(a) -_vcat_getindex(a, k...) = a -_vcat_getindex(a::AbstractArray, k...) = a[k...] +_cat_axes(a, k) = Base.OneTo(1) +_cat_axes(a::AbstractArray, k) = axes(a, k) +_cat_size(a, k) = 1 +_cat_size(a::AbstractArray, k) = size(a, k) +_cat_ndims(a) = 1 +_cat_ndims(a::AbstractArray) = ndims(a) +_cat_eltype(a) = typeof(a) +_cat_eltype(a::AbstractArray) = eltype(a) +_cat_length(a) = 1 +_cat_length(a::AbstractArray) = length(a) +_cat_getindex(a, k...) = a +_cat_getindex(a::AbstractArray, k...) = a[k...] +_cat_colsupport(a, k...) = 1 +_cat_colsupport(a::AbstractArray, k...) = colsupport(a, k) @inline applied_eltype(::typeof(vcat)) = Any -@inline applied_eltype(::typeof(vcat), args...) = promote_type(map(_vcat_eltype, args)...) -@inline applied_ndims(::typeof(vcat), args...) = max(1,maximum(map(_vcat_ndims,args))) +@inline applied_eltype(::typeof(vcat), args...) = promote_type(map(_cat_eltype, args)...) +@inline applied_ndims(::typeof(vcat), args...) = max(1,maximum(map(_cat_ndims,args))) @inline applied_ndims(::typeof(vcat)) = 1 @inline axes(f::Vcat{<:Any,1,Tuple{}}) = (OneTo(0),) -@inline axes(f::Vcat{<:Any,1}) = tuple(oneto(+(map(_vcat_length,f.args)...))) -@inline axes(f::Vcat{<:Any,2}) = (oneto(+(map(a -> _vcat_size(a,1), f.args)...)), _vcat_axes(f.args[1],2)) +@inline axes(f::Vcat{<:Any,1}) = tuple(oneto(+(map(_cat_length,f.args)...))) +@inline axes(f::Vcat{<:Any,2}) = (oneto(+(map(a -> _cat_size(a,1), f.args)...)), _cat_axes(f.args[1],2)) @inline size(f::Vcat) = map(length, axes(f)) @@ -71,8 +73,8 @@ end f, idx::Tuple{Integer}, A, args...) k, = idx T = eltype(f) - n = _vcat_length(A) - k ≤ n && return convert(T, _vcat_getindex(A,k))::T + n = _cat_length(A) + k ≤ n && return convert(T, _cat_getindex(A,k))::T vcat_getindex_recursive(f, (k - n, ), args...) end @@ -80,8 +82,8 @@ end f, idx::Tuple{Integer,Integer}, A, args...) k, j = idx T = eltype(f) - n = _vcat_size(A, 1) - k ≤ n && return convert(T, _vcat_getindex(A, k, j))::T + n = _cat_size(A, 1) + k ≤ n && return convert(T, _cat_getindex(A, k, j))::T vcat_getindex_recursive(f, (k - n, j), args...) end @@ -146,9 +148,9 @@ Hcat(A...) = ApplyArray(hcat, A...) Hcat() = Hcat{Any}() Hcat{T}(A...) where T = ApplyArray{T}(hcat, A...) -@inline applied_eltype(::typeof(hcat), args...) = promote_type(map(eltype,args)...) +@inline applied_eltype(::typeof(hcat), args...) = promote_type(map(_cat_eltype,args)...) @inline applied_ndims(::typeof(hcat), args...) = 2 -@inline applied_size(::typeof(hcat), args...) = (size(args[1],1), +(map(a -> size(a,2), args)...)) +@inline applied_size(::typeof(hcat), args...) = (_cat_size(args[1],1), +(map(a -> _cat_size(a,2), args)...)) @inline applied_size(::typeof(hcat)) = (0,0) @inline hcat_getindex(f, k, j::Integer) = hcat_getindex_recursive(f, (k, j), f.args...) @@ -156,16 +158,16 @@ Hcat{T}(A...) where T = ApplyArray{T}(hcat, A...) @inline function hcat_getindex_recursive(f, idx::Tuple{Integer,Integer}, A, args...) k, j = idx T = eltype(f) - n = size(A, 2) - j ≤ n && return convert(T, A[k, j])::T + n = _cat_size(A, 2) + j ≤ n && return convert(T, _cat_getindex(A,k, j))::T hcat_getindex_recursive(f, (k, j - n), args...) end @inline function hcat_getindex_recursive(f, idx::Tuple{Union{Colon,AbstractVector},Integer}, A, args...) kr, j = idx T = eltype(f) - n = size(A, 2) - j ≤ n && return convert(AbstractVector{T}, A[kr, j]) + n = _cat_size(A, 2) + j ≤ n && return convert(AbstractVector{T}, _cat_getindex(A, kr, j)) hcat_getindex_recursive(f, (kr, j - n), args...) end @@ -197,27 +199,27 @@ end # Hvcat #### -@inline applied_eltype(::typeof(hvcat), a, b...) = promote_type(map(eltype, b)...) +@inline applied_eltype(::typeof(hvcat), a, b...) = promote_type(map(_cat_eltype, b)...) @inline applied_ndims(::typeof(hvcat), args...) = 2 -@inline applied_size(::typeof(hvcat), n::Int, b...) = sum(size.(b[1:n:end],1)),sum(size.(b[1:n],2)) +@inline applied_size(::typeof(hvcat), n::Int, b...) = sum(_cat_size.(b[1:n:end],1)),sum(_cat_size.(b[1:n],2)) @inline function applied_size(::typeof(hvcat), n::NTuple{N,Int}, b...) where N as = tuple(2, (2 .+ cumsum(Base.front(n)))...) - sum(size.(getindex.(Ref((n, b...)), as),1)),sum(size.(b[1:n[1]],2)) + sum(_cat_size.(getindex.(Ref((n, b...)), as),1)),sum(_cat_size.(b[1:n[1]],2)) end @inline hvcat_getindex(f, k, j::Integer) = hvcat_getindex_recursive(f, (k, j), f.args...) -@inline _hvcat_size(A) = size(A) -@inline _hvcat_size(A::Number) = (1,1) +@inline _hvcat_size(A::AbstractArray) = size(A) +@inline _hvcat_size(A) = (1,1) @inline _hvcat_size(A::AbstractVector) = (size(A,1),1) @inline function hvcat_getindex_recursive(f, (k,j)::Tuple{Integer,Integer}, N::Int, A, args...) T = eltype(f) m,n = _hvcat_size(A) N ≤ 0 && throw(BoundsError(f, (k,j))) # ran out of arrays - k ≤ m && j ≤ n && return convert(T, A[k, j])::T + k ≤ m && j ≤ n && return convert(T, _cat_getindex(A, k, j))::T k ≤ m && return hvcat_getindex_recursive(f, (k, j - n), N-1, args...) hvcat_getindex_recursive(f, (k - m, j), N, args[N:end]...) end @@ -769,8 +771,8 @@ end function colsupport(lay::ApplyLayout{typeof(hcat)}, H::AbstractArray, j::Integer) ξ = j for A in arguments(lay,H) - n = size(A,2) - ξ ≤ n && return colsupport(A, ξ) + n = _cat_size(A,2) + ξ ≤ n && return _cat_colsupport(A, ξ) ξ -= n end return 1:0 diff --git a/test/concattests.jl b/test/concattests.jl index dea890e..b6e4e95 100644 --- a/test/concattests.jl +++ b/test/concattests.jl @@ -662,6 +662,8 @@ import LazyArrays: MemoryLayout, DenseColumnMajor, materialize!, call, paddeddat @test Vcat("hi", "bye") == ["hi", "bye"] @test Vcat(["hi" "bye"], [2 3]) == ["hi" "bye"; 2 3] @test Vcat("hi", [2;;]) == ["hi"; 2 ;;] + @test Hcat("hi", "bye") == ["hi" "bye"] + @test ApplyArray(hvcat, 2, "hi", 2, 3, "bye") == ApplyArray(hvcat, (2,1), "hi", 2, [3 "bye"]) == ["hi" 2; 3 "bye"] end end