From 57994ff8ea3a869ec0a457fe766032faee7941b4 Mon Sep 17 00:00:00 2001 From: Matt Fishman Date: Wed, 13 Nov 2024 19:47:02 -0500 Subject: [PATCH] [BlockSparseArrays] Direct sum/`cat` (#1579) * [BlockSparseArrays] Direct sum/`cat` * [NDTensors] Bump to v0.3.64 --- NDTensors/Project.toml | 2 +- .../src/BlockSparseArrays.jl | 2 + .../src/abstractblocksparsearray/cat.jl | 7 ++ .../src/blocksparsearrayinterface/cat.jl | 26 ++++++++ .../lib/BlockSparseArrays/test/test_basics.jl | 27 ++++++++ .../src/SparseArrayInterface.jl | 2 + .../src/abstractsparsearray/cat.jl | 4 ++ .../src/sparsearrayinterface/cat.jl | 64 +++++++++++++++++++ .../src/sparsearrayinterface/indexing.jl | 19 ++++++ .../test/test_abstractsparsearray.jl | 32 ++++++++++ 10 files changed, 184 insertions(+), 1 deletion(-) create mode 100644 NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/cat.jl create mode 100644 NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/cat.jl create mode 100644 NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/cat.jl create mode 100644 NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/cat.jl diff --git a/NDTensors/Project.toml b/NDTensors/Project.toml index 23a9edd645..6a559b114e 100644 --- a/NDTensors/Project.toml +++ b/NDTensors/Project.toml @@ -1,7 +1,7 @@ name = "NDTensors" uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf" authors = ["Matthew Fishman "] -version = "0.3.63" +version = "0.3.64" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" diff --git a/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl b/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl index d0a1e4cdd7..dc43ba560a 100644 --- a/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl +++ b/NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl @@ -7,6 +7,7 @@ include("blocksparsearrayinterface/broadcast.jl") include("blocksparsearrayinterface/map.jl") include("blocksparsearrayinterface/arraylayouts.jl") 
include("blocksparsearrayinterface/views.jl") +include("blocksparsearrayinterface/cat.jl") include("abstractblocksparsearray/abstractblocksparsearray.jl") include("abstractblocksparsearray/wrappedabstractblocksparsearray.jl") include("abstractblocksparsearray/abstractblocksparsematrix.jl") @@ -17,6 +18,7 @@ include("abstractblocksparsearray/sparsearrayinterface.jl") include("abstractblocksparsearray/broadcast.jl") include("abstractblocksparsearray/map.jl") include("abstractblocksparsearray/linearalgebra.jl") +include("abstractblocksparsearray/cat.jl") include("blocksparsearray/defaults.jl") include("blocksparsearray/blocksparsearray.jl") include("BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl") diff --git a/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/cat.jl b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/cat.jl new file mode 100644 index 0000000000..eac4ea1b02 --- /dev/null +++ b/NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/cat.jl @@ -0,0 +1,7 @@ +# TODO: Change to `AnyAbstractBlockSparseArray`. +function Base.cat(as::BlockSparseArrayLike...; dims) + # TODO: Use `sparse_cat` instead, currently + # that erroneously allocates too many blocks that are + # zero and shouldn't be stored. + return blocksparse_cat(as...; dims) +end diff --git a/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/cat.jl b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/cat.jl new file mode 100644 index 0000000000..22d1a24a02 --- /dev/null +++ b/NDTensors/src/lib/BlockSparseArrays/src/blocksparsearrayinterface/cat.jl @@ -0,0 +1,26 @@ +using BlockArrays: AbstractBlockedUnitRange, blockedrange, blocklengths +using NDTensors.SparseArrayInterface: SparseArrayInterface, allocate_cat_output, sparse_cat! + +# TODO: Maybe move to `SparseArrayInterfaceBlockArraysExt`. +# TODO: Handle dual graded unit ranges, for example in a new `SparseArrayInterfaceGradedAxesExt`. 
+function SparseArrayInterface.axis_cat( + a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange +) + return blockedrange(vcat(blocklengths(a1), blocklengths(a2))) +end + +# TODO: Delete this in favor of `sparse_cat!`, which currently +# erroneously allocates too many zero blocks that shouldn't be stored. +function blocksparse_cat!(a_dest::AbstractArray, as::AbstractArray...; dims) + sparse_cat!(blocks(a_dest), blocks.(as)...; dims) + return a_dest +end + +# TODO: Delete this in favor of `sparse_cat`, currently +# that erroneously allocates too many blocks that are +# zero and shouldn't be stored. +function blocksparse_cat(as::AbstractArray...; dims) + a_dest = allocate_cat_output(as...; dims) + blocksparse_cat!(a_dest, as...; dims) + return a_dest +end diff --git a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl index 4374be541c..32990471a0 100644 --- a/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl +++ b/NDTensors/src/lib/BlockSparseArrays/test/test_basics.jl @@ -866,6 +866,33 @@ using .NDTensorsTestUtils: devices_list, is_supported_eltype @test a1' * a2 ≈ Array(a1)' * Array(a2) @test dot(a1, a2) ≈ a1' * a2 end + @testset "cat" begin + a1 = dev(BlockSparseArray{elt}([2, 3], [2, 3])) + a1[Block(2, 1)] = dev(randn(elt, size(@view(a1[Block(2, 1)])))) + a2 = dev(BlockSparseArray{elt}([2, 3], [2, 3])) + a2[Block(1, 2)] = dev(randn(elt, size(@view(a2[Block(1, 2)])))) + + a_dest = cat(a1, a2; dims=1) + @test block_nstored(a_dest) == 2 + @test blocklengths.(axes(a_dest)) == ([2, 3, 2, 3], [2, 3]) + @test issetequal(block_stored_indices(a_dest), [Block(2, 1), Block(3, 2)]) + @test a_dest[Block(2, 1)] == a1[Block(2, 1)] + @test a_dest[Block(3, 2)] == a2[Block(1, 2)] + + a_dest = cat(a1, a2; dims=2) + @test block_nstored(a_dest) == 2 + @test blocklengths.(axes(a_dest)) == ([2, 3], [2, 3, 2, 3]) + @test issetequal(block_stored_indices(a_dest), [Block(2, 1), Block(1, 4)]) + @test a_dest[Block(2, 1)] == a1[Block(2, 1)] + @test a_dest[Block(1, 
4)] == a2[Block(1, 2)] + + a_dest = cat(a1, a2; dims=(1, 2)) + @test block_nstored(a_dest) == 2 + @test blocklengths.(axes(a_dest)) == ([2, 3, 2, 3], [2, 3, 2, 3]) + @test issetequal(block_stored_indices(a_dest), [Block(2, 1), Block(3, 4)]) + @test a_dest[Block(2, 1)] == a1[Block(2, 1)] + @test a_dest[Block(3, 4)] == a2[Block(1, 2)] + end @testset "TensorAlgebra" begin a1 = dev(BlockSparseArray{elt}([2, 3], [2, 3])) a1[Block(1, 1)] = dev(randn(elt, size(@view(a1[Block(1, 1)])))) diff --git a/NDTensors/src/lib/SparseArrayInterface/src/SparseArrayInterface.jl b/NDTensors/src/lib/SparseArrayInterface/src/SparseArrayInterface.jl index 33647bf476..f192225f27 100644 --- a/NDTensors/src/lib/SparseArrayInterface/src/SparseArrayInterface.jl +++ b/NDTensors/src/lib/SparseArrayInterface/src/SparseArrayInterface.jl @@ -12,6 +12,7 @@ include("sparsearrayinterface/broadcast.jl") include("sparsearrayinterface/conversion.jl") include("sparsearrayinterface/wrappers.jl") include("sparsearrayinterface/zero.jl") +include("sparsearrayinterface/cat.jl") include("sparsearrayinterface/SparseArrayInterfaceLinearAlgebraExt.jl") include("abstractsparsearray/abstractsparsearray.jl") include("abstractsparsearray/abstractsparsematrix.jl") @@ -24,6 +25,7 @@ include("abstractsparsearray/broadcast.jl") include("abstractsparsearray/map.jl") include("abstractsparsearray/baseinterface.jl") include("abstractsparsearray/convert.jl") +include("abstractsparsearray/cat.jl") include("abstractsparsearray/SparseArrayInterfaceSparseArraysExt.jl") include("abstractsparsearray/SparseArrayInterfaceLinearAlgebraExt.jl") end diff --git a/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/cat.jl b/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/cat.jl new file mode 100644 index 0000000000..a9db504e38 --- /dev/null +++ b/NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/cat.jl @@ -0,0 +1,4 @@ +# TODO: Change to `AnyAbstractSparseArray`. 
+function Base.cat(as::SparseArrayLike...; dims) + return sparse_cat(as...; dims) +end diff --git a/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/cat.jl b/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/cat.jl new file mode 100644 index 0000000000..9f2b3179a5 --- /dev/null +++ b/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/cat.jl @@ -0,0 +1,64 @@ +unval(x) = x +unval(::Val{x}) where {x} = x + +# TODO: Assert that `a1` and `a2` start at one. +axis_cat(a1::AbstractUnitRange, a2::AbstractUnitRange) = Base.OneTo(length(a1) + length(a2)) +function axis_cat( + a1::AbstractUnitRange, a2::AbstractUnitRange, a_rest::AbstractUnitRange... +) + return axis_cat(axis_cat(a1, a2), a_rest...) +end +function cat_axes(as::AbstractArray...; dims) + return ntuple(length(first(axes.(as)))) do dim + return if dim in unval(dims) + axis_cat(map(axes -> axes[dim], axes.(as))...) + else + axes(first(as))[dim] + end + end +end + +function allocate_cat_output(as::AbstractArray...; dims) + eltype_dest = promote_type(eltype.(as)...) + axes_dest = cat_axes(as...; dims) + # TODO: Promote the block types of the inputs rather than using + # just the first input. + # TODO: Make this customizable with `cat_similar`. + # TODO: Base the zero element constructor on those of the inputs, + # for example block sparse arrays. + return similar(first(as), eltype_dest, axes_dest...) +end + +# https://github.com/JuliaLang/julia/blob/v1.11.1/base/abstractarray.jl#L1748-L1857 +# https://docs.julialang.org/en/v1/base/arrays/#Concatenation-and-permutation +# This is very similar to the `Base.cat` implementation but handles zero values better. +function cat_offset!( + a_dest::AbstractArray, offsets, a1::AbstractArray, a_rest::AbstractArray...; dims +) + inds = ntuple(ndims(a_dest)) do dim + dim in unval(dims) ? offsets[dim] .+ axes(a1, dim) : axes(a_dest, dim) + end + a_dest[inds...] = a1 + new_offsets = ntuple(ndims(a_dest)) do dim + dim in unval(dims) ? 
offsets[dim] + size(a1, dim) : offsets[dim] + end + cat_offset!(a_dest, new_offsets, a_rest...; dims) + return a_dest +end +function cat_offset!(a_dest::AbstractArray, offsets; dims) + return a_dest +end + +# TODO: Define a generic `cat!` function. +function sparse_cat!(a_dest::AbstractArray, as::AbstractArray...; dims) + offsets = ntuple(zero, ndims(a_dest)) + # TODO: Fill `a_dest` with zeros if needed. + cat_offset!(a_dest, offsets, as...; dims) + return a_dest +end + +function sparse_cat(as::AbstractArray...; dims) + a_dest = allocate_cat_output(as...; dims) + sparse_cat!(a_dest, as...; dims) + return a_dest +end diff --git a/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/indexing.jl b/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/indexing.jl index ae0f6f6d61..f416ca421e 100644 --- a/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/indexing.jl +++ b/NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/indexing.jl @@ -137,12 +137,31 @@ function sparse_setindex!(a::AbstractArray, value, I::Vararg{Int}) return a end +# Fix ambiguity error +function sparse_setindex!(a::AbstractArray, value) + sparse_setindex!(a, value, CartesianIndex()) + return a +end + # Linear indexing function sparse_setindex!(a::AbstractArray, value, I::CartesianIndex{1}) sparse_setindex!(a, value, CartesianIndices(a)[I]) return a end +# Slicing +# TODO: Make this handle more general slicing operations, +# base it off of `ArrayLayouts.sub_materialize`. +function sparse_setindex!(a::AbstractArray, value, I::AbstractUnitRange...) 
+ inds = CartesianIndices(I) + for i in stored_indices(value) + if i in CartesianIndices(inds) + a[inds[i]] = value[i] + end + end + return a +end + # Handle trailing indices function sparse_setindex!(a::AbstractArray, value, I::CartesianIndex) t = Tuple(I) diff --git a/NDTensors/src/lib/SparseArrayInterface/test/test_abstractsparsearray.jl b/NDTensors/src/lib/SparseArrayInterface/test/test_abstractsparsearray.jl index 743f457d43..47cf6668c6 100644 --- a/NDTensors/src/lib/SparseArrayInterface/test/test_abstractsparsearray.jl +++ b/NDTensors/src/lib/SparseArrayInterface/test/test_abstractsparsearray.jl @@ -342,6 +342,38 @@ using Test: @test, @testset @test a_dest isa SparseArray{elt} @test SparseArrayInterface.nstored(a_dest) == 2 + # cat + a1 = SparseArray{elt}(2, 3) + a1[1, 2] = 12 + a1[2, 1] = 21 + a2 = SparseArray{elt}(2, 3) + a2[1, 1] = 11 + a2[2, 2] = 22 + + a_dest = cat(a1, a2; dims=1) + @test size(a_dest) == (4, 3) + @test SparseArrayInterface.nstored(a_dest) == 4 + @test a_dest[1, 2] == a1[1, 2] + @test a_dest[2, 1] == a1[2, 1] + @test a_dest[3, 1] == a2[1, 1] + @test a_dest[4, 2] == a2[2, 2] + + a_dest = cat(a1, a2; dims=2) + @test size(a_dest) == (2, 6) + @test SparseArrayInterface.nstored(a_dest) == 4 + @test a_dest[1, 2] == a1[1, 2] + @test a_dest[2, 1] == a1[2, 1] + @test a_dest[1, 4] == a2[1, 1] + @test a_dest[2, 5] == a2[2, 2] + + a_dest = cat(a1, a2; dims=(1, 2)) + @test size(a_dest) == (4, 6) + @test SparseArrayInterface.nstored(a_dest) == 4 + @test a_dest[1, 2] == a1[1, 2] + @test a_dest[2, 1] == a1[2, 1] + @test a_dest[3, 4] == a2[1, 1] + @test a_dest[4, 5] == a2[2, 2] + ## # Sparse matrix of matrix multiplication ## TODO: Make this work, seems to require ## a custom zero constructor.