Skip to content

Commit

Permalink
Generalize strides for ReinterpretArray and ReshapedArray (#44027)
Browse files Browse the repository at this point in the history
  • Loading branch information
N5N3 authored Feb 6, 2022
1 parent 5181e36 commit e0a4b77
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 46 deletions.
82 changes: 44 additions & 38 deletions base/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ struct ReinterpretArray{T,N,S,A<:AbstractArray{S},IsReshaped} <: AbstractArray{T
if N != 0 && sizeof(S) != sizeof(T)
ax1 = axes(a)[1]
dim = length(ax1)
if Base.issingletontype(T)
if issingletontype(T)
dim == 0 || throwsingleton(S, T, "a non-empty")
else
rem(dim*sizeof(S),sizeof(T)) == 0 || thrownonint(S, T, dim)
Expand Down Expand Up @@ -75,15 +75,15 @@ struct ReinterpretArray{T,N,S,A<:AbstractArray{S},IsReshaped} <: AbstractArray{T
if sizeof(S) == sizeof(T)
N = ndims(a)
elseif sizeof(S) > sizeof(T)
Base.issingletontype(T) && throwsingleton(S, T, "with reshape a")
issingletontype(T) && throwsingleton(S, T, "with reshape a")
rem(sizeof(S), sizeof(T)) == 0 || throwintmult(S, T)
N = ndims(a) + 1
else
Base.issingletontype(S) && throwfromsingleton(S, T)
issingletontype(S) && throwfromsingleton(S, T)
rem(sizeof(T), sizeof(S)) == 0 || throwintmult(S, T)
N = ndims(a) - 1
N > -1 || throwsize0(S, T, "larger")
axes(a, 1) == Base.OneTo(sizeof(T) ÷ sizeof(S)) || throwsize1(a, T)
axes(a, 1) == OneTo(sizeof(T) ÷ sizeof(S)) || throwsize1(a, T)
end
readable = array_subpadding(T, S)
writable = array_subpadding(S, T)
Expand Down Expand Up @@ -148,33 +148,39 @@ StridedVector{T} = StridedArray{T,1}
StridedMatrix{T} = StridedArray{T,2}
StridedVecOrMat{T} = Union{StridedVector{T}, StridedMatrix{T}}

# the definition of strides for Array{T,N} is tuple() if N = 0, otherwise it is
# a tuple containing 1 and a cumulative product of the first N-1 sizes
# this definition is also used for StridedReshapedArray and StridedReinterpretedArray
# which have the same memory storage as Array
stride(a::Union{DenseArray,StridedReshapedArray,StridedReinterpretArray}, i::Int) = _stride(a, i)

function stride(a::ReinterpretArray, i::Int)
a.parent isa StridedArray || throw(ArgumentError("Parent must be strided."))
return _stride(a, i)
end
strides(a::Union{DenseArray,StridedReshapedArray,StridedReinterpretArray}) = size_to_strides(1, size(a)...)

function _stride(a, i)
if i > ndims(a)
return length(a)
function strides(a::ReshapedReinterpretArray)
ap = parent(a)
els, elp = elsize(a), elsize(ap)
stp = strides(ap)
els == elp && return stp
els < elp && return (1, _checked_strides(stp, els, elp)...)
stp[1] == 1 || throw(ArgumentError("Parent must be contiguous in the 1st dimension!"))
return _checked_strides(tail(stp), els, elp)
end

function strides(a::NonReshapedReinterpretArray)
ap = parent(a)
els, elp = elsize(a), elsize(ap)
stp = strides(ap)
els == elp && return stp
stp[1] == 1 || throw(ArgumentError("Parent must be contiguous in the 1st dimension!"))
return (1, _checked_strides(tail(stp), els, elp)...)
end

@inline function _checked_strides(stp::Tuple, els::Integer, elp::Integer)
if elp > els && rem(elp, els) == 0
N = div(elp, els)
return map(i -> N * i, stp)
end
s = 1
for n = 1:(i-1)
s *= size(a, n)
end
return s
drs = map(i -> divrem(elp * i, els), stp)
all(i->iszero(i[2]), drs) ||
throw(ArgumentError("Parent's strides could not be exactly divided!"))
map(first, drs)
end

function strides(a::ReinterpretArray)
a.parent isa StridedArray || throw(ArgumentError("Parent must be strided."))
size_to_strides(1, size(a)...)
end
strides(a::Union{DenseArray,StridedReshapedArray,StridedReinterpretArray}) = size_to_strides(1, size(a)...)
_checkcontiguous(::Type{Bool}, A::ReinterpretArray) = _checkcontiguous(Bool, parent(A))

similar(a::ReinterpretArray, T::Type, d::Dims) = similar(a.parent, T, d)

Expand Down Expand Up @@ -227,12 +233,12 @@ SCartesianIndices2{K}(indices2::AbstractUnitRange{Int}) where {K} = (@assert K::
eachindex(::IndexSCartesian2{K}, A::ReshapedReinterpretArray) where {K} = SCartesianIndices2{K}(eachindex(IndexLinear(), parent(A)))
@inline function eachindex(style::IndexSCartesian2{K}, A::AbstractArray, B::AbstractArray...) where {K}
iter = eachindex(style, A)
Base._all_match_first(C->eachindex(style, C), iter, B...) || Base.throw_eachindex_mismatch_indices(IndexSCartesian2{K}(), axes(A), axes.(B)...)
_all_match_first(C->eachindex(style, C), iter, B...) || throw_eachindex_mismatch_indices(IndexSCartesian2{K}(), axes(A), axes.(B)...)
return iter
end

size(iter::SCartesianIndices2{K}) where K = (K, length(iter.indices2))
axes(iter::SCartesianIndices2{K}) where K = (Base.OneTo(K), iter.indices2)
axes(iter::SCartesianIndices2{K}) where K = (OneTo(K), iter.indices2)

first(iter::SCartesianIndices2{K}) where {K} = SCartesianIndex2{K}(1, first(iter.indices2))
last(iter::SCartesianIndices2{K}) where {K} = SCartesianIndex2{K}(K, last(iter.indices2))
Expand Down Expand Up @@ -300,27 +306,27 @@ unaliascopy(a::ReshapedReinterpretArray{T}) where {T} = reinterpret(reshape, T,

function size(a::NonReshapedReinterpretArray{T,N,S} where {N}) where {T,S}
psize = size(a.parent)
size1 = Base.issingletontype(T) ? psize[1] : div(psize[1]*sizeof(S), sizeof(T))
size1 = issingletontype(T) ? psize[1] : div(psize[1]*sizeof(S), sizeof(T))
tuple(size1, tail(psize)...)
end
function size(a::ReshapedReinterpretArray{T,N,S} where {N}) where {T,S}
psize = size(a.parent)
sizeof(S) > sizeof(T) && return (div(sizeof(S), sizeof(T)), psize...)
sizeof(S) < sizeof(T) && return Base.tail(psize)
sizeof(S) < sizeof(T) && return tail(psize)
return psize
end
size(a::NonReshapedReinterpretArray{T,0}) where {T} = ()

function axes(a::NonReshapedReinterpretArray{T,N,S} where {N}) where {T,S}
paxs = axes(a.parent)
f, l = first(paxs[1]), length(paxs[1])
size1 = Base.issingletontype(T) ? l : div(l*sizeof(S), sizeof(T))
size1 = issingletontype(T) ? l : div(l*sizeof(S), sizeof(T))
tuple(oftype(paxs[1], f:f+size1-1), tail(paxs)...)
end
function axes(a::ReshapedReinterpretArray{T,N,S} where {N}) where {T,S}
paxs = axes(a.parent)
sizeof(S) > sizeof(T) && return (Base.OneTo(div(sizeof(S), sizeof(T))), paxs...)
sizeof(S) < sizeof(T) && return Base.tail(paxs)
sizeof(S) > sizeof(T) && return (OneTo(div(sizeof(S), sizeof(T))), paxs...)
sizeof(S) < sizeof(T) && return tail(paxs)
return paxs
end
axes(a::NonReshapedReinterpretArray{T,0}) where {T} = ()
Expand Down Expand Up @@ -372,7 +378,7 @@ end
@inline @propagate_inbounds function _getindex_ra(a::NonReshapedReinterpretArray{T,N,S}, i1::Int, tailinds::TT) where {T,N,S,TT}
# Make sure to match the scalar reinterpret if that is applicable
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
if Base.issingletontype(T) # singleton types
if issingletontype(T) # singleton types
@boundscheck checkbounds(a, i1, tailinds...)
return T.instance
end
Expand Down Expand Up @@ -420,7 +426,7 @@ end
@inline @propagate_inbounds function _getindex_ra(a::ReshapedReinterpretArray{T,N,S}, i1::Int, tailinds::TT) where {T,N,S,TT}
# Make sure to match the scalar reinterpret if that is applicable
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
if Base.issingletontype(T) # singleton types
if issingletontype(T) # singleton types
@boundscheck checkbounds(a, i1, tailinds...)
return T.instance
end
Expand Down Expand Up @@ -511,7 +517,7 @@ end
v = convert(T, v)::T
# Make sure to match the scalar reinterpret if that is applicable
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
if Base.issingletontype(T) # singleton types
if issingletontype(T) # singleton types
@boundscheck checkbounds(a, i1, tailinds...)
# setindex! is a noop except for the index check
else
Expand Down Expand Up @@ -577,7 +583,7 @@ end
v = convert(T, v)::T
# Make sure to match the scalar reinterpret if that is applicable
if sizeof(T) == sizeof(S) && (fieldcount(T) + fieldcount(S)) == 0
if Base.issingletontype(T) # singleton types
if issingletontype(T) # singleton types
@boundscheck checkbounds(a, i1, tailinds...)
# setindex! is a noop except for the index check
else
Expand Down
17 changes: 15 additions & 2 deletions base/reshapedarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ end

@inline function _unsafe_getindex(A::ReshapedArray{T,N}, indices::Vararg{Int,N}) where {T,N}
axp = axes(A.parent)
i = offset_if_vec(Base._sub2ind(size(A), indices...), axp)
i = offset_if_vec(_sub2ind(size(A), indices...), axp)
I = ind2sub_rs(axp, A.mi, i)
_unsafe_getindex_rs(parent(A), I)
end
Expand All @@ -266,7 +266,7 @@ end

@inline function _unsafe_setindex!(A::ReshapedArray{T,N}, val, indices::Vararg{Int,N}) where {T,N}
axp = axes(A.parent)
i = offset_if_vec(Base._sub2ind(size(A), indices...), axp)
i = offset_if_vec(_sub2ind(size(A), indices...), axp)
@inbounds parent(A)[ind2sub_rs(axes(A.parent), A.mi, i)...] = val
val
end
Expand All @@ -292,3 +292,16 @@ substrides(strds::NTuple{N,Int}, I::Tuple{ReshapedUnitRange, Vararg{Any}}) where
(size_to_strides(strds[1], size(I[1])...)..., substrides(tail(strds), tail(I))...)
unsafe_convert(::Type{Ptr{T}}, V::SubArray{T,N,P,<:Tuple{Vararg{Union{RangeIndex,ReshapedUnitRange}}}}) where {T,N,P} =
unsafe_convert(Ptr{T}, V.parent) + (first_index(V)-1)*sizeof(T)


_checkcontiguous(::Type{Bool}, A::AbstractArray) = size_to_strides(1, size(A)...) == strides(A)
_checkcontiguous(::Type{Bool}, A::Array) = true
_checkcontiguous(::Type{Bool}, A::ReshapedArray) = _checkcontiguous(Bool, parent(A))
_checkcontiguous(::Type{Bool}, A::FastContiguousSubArray) = _checkcontiguous(Bool, parent(A))

function strides(a::ReshapedArray)
# We can handle non-contiguous parent if it's a StridedVector
ndims(parent(a)) == 1 && return size_to_strides(only(strides(parent(a))), size(a)...)
_checkcontiguous(Bool, a) || throw(ArgumentError("Parent must be contiguous."))
size_to_strides(1, size(a)...)
end
18 changes: 18 additions & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1561,3 +1561,21 @@ end
r = Base.IdentityUnitRange(3:4)
@test reshape(r, :) === reshape(r, (:,)) === r
end

@testset "strides for ReshapedArray" begin
# Type-based contiguous check is tested in test/compiler/inline.jl
# General contiguous check
a = view(rand(10,10), 1:10, 1:10)
@test strides(vec(a)) == (1,)
b = view(parent(a), 1:9, 1:10)
@test_throws "Parent must be contiguous." strides(vec(b))
# StridedVector parent
for n in 1:3
a = view(collect(1:60n), 1:n:60n)
@test strides(reshape(a, 3, 4, 5)) == (n, 3n, 12n)
@test strides(reshape(a, 5, 6, 2)) == (n, 5n, 30n)
b = view(parent(a), 60n:-n:1)
@test strides(reshape(b, 3, 4, 5)) == (-n, -3n, -12n)
@test strides(reshape(b, 5, 6, 2)) == (-n, -5n, -30n)
end
end
7 changes: 7 additions & 0 deletions test/compiler/inline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -907,3 +907,10 @@ end
@test fully_eliminated((String,)) do x
Base.@invoke conditional_escape!(false::Any, x::Any)
end

@testset "strides for ReshapedArray (PR#44027)" begin
# Type-based contiguous check
a = vec(reinterpret(reshape,Int16,reshape(view(reinterpret(Int32,randn(10)),2:11),5,:)))
f(a) = only(strides(a));
@test fully_eliminated(f, Tuple{typeof(a)}) && f(a) == 1
end
62 changes: 56 additions & 6 deletions test/reinterpretarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,12 +157,62 @@ let A = collect(reshape(1:20, 5, 4))
@test reshape(R, :) isa StridedArray
end

# and ensure a reinterpret array containing a strided array can have strides computed
let A = view(reinterpret(Int16, collect(reshape(UnitRange{Int64}(1, 20), 5, 4))), :, 1:2)
R = reinterpret(Int32, A)
@test strides(R) == (1, 10)
@test stride(R, 1) == 1
@test stride(R, 2) == 10
function check_strides(A::AbstractArray)
# Make sure stride(A, i) is equivalent with strides(A)[i] (if 1 <= i <= ndims(A))
dims = ntuple(identity, ndims(A))
map(i -> stride(A, i), dims) == strides(A) || return false
# Test strides via value check.
for i in eachindex(IndexLinear(), A)
A[i] === Base.unsafe_load(pointer(A, i)) || return false
end
return true
end

@testset "strides for NonReshapedReinterpretArray" begin
A = Array{Int32}(reshape(1:88, 11, 8))
for viewax2 in (1:8, 1:2:6, 7:-1:1, 5:-2:1, 2:3:8, 7:-6:1, 3:5:11)
# dim1 is contiguous
for T in (Int16, Float32)
@test check_strides(reinterpret(T, view(A, 1:8, viewax2)))
end
if mod(step(viewax2), 2) == 0
@test check_strides(reinterpret(Int64, view(A, 1:8, viewax2)))
else
@test_throws "Parent's strides" strides(reinterpret(Int64, view(A, 1:8, viewax2)))
end
# non-integer-multipled classified
if mod(step(viewax2), 3) == 0
@test check_strides(reinterpret(NTuple{3,Int16}, view(A, 2:7, viewax2)))
else
@test_throws "Parent's strides" strides(reinterpret(NTuple{3,Int16}, view(A, 2:7, viewax2)))
end
if mod(step(viewax2), 5) == 0
@test check_strides(reinterpret(NTuple{5,Int16}, view(A, 2:11, viewax2)))
else
@test_throws "Parent's strides" strides(reinterpret(NTuple{5,Int16}, view(A, 2:11, viewax2)))
end
# dim1 is not contiguous
for T in (Int16, Int64)
@test_throws "Parent must" strides(reinterpret(T, view(A, 8:-1:1, viewax2)))
end
@test check_strides(reinterpret(Float32, view(A, 8:-1:1, viewax2)))
end
end

@testset "strides for ReshapedReinterpretArray" begin
A = Array{Int32}(reshape(1:192, 3, 8, 8))
for viewax1 in (1:8, 1:2:8, 8:-1:1, 8:-2:1), viewax2 in (1:2, 4:-1:1)
for T in (Int16, Float32)
@test check_strides(reinterpret(reshape, T, view(A, 1:2, viewax1, viewax2)))
@test check_strides(reinterpret(reshape, T, view(A, 1:2:3, viewax1, viewax2)))
end
if mod(step(viewax1), 2) == 0
@test check_strides(reinterpret(reshape, Int64, view(A, 1:2, viewax1, viewax2)))
else
@test_throws "Parent's strides" strides(reinterpret(reshape, Int64, view(A, 1:2, viewax1, viewax2)))
end
@test_throws "Parent must" strides(reinterpret(reshape, Int64, view(A, 1:2:3, viewax1, viewax2)))
end
end

@testset "strides" begin
Expand Down

2 comments on commit e0a4b77

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Executing the daily package evaluation, I will reply here when finished:

@nanosoldier runtests(ALL, isdaily = true)

@nanosoldier
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Your package evaluation job has completed - possible new issues were detected. A full report can be found here.

Please sign in to comment.