diff --git a/NDTensors/Project.toml b/NDTensors/Project.toml index e25bea22a2..4c992ae743 100644 --- a/NDTensors/Project.toml +++ b/NDTensors/Project.toml @@ -1,7 +1,7 @@ name = "NDTensors" uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf" authors = ["Matthew Fishman "] -version = "0.3.12" +version = "0.3.13" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" diff --git a/NDTensors/src/lib/TensorAlgebra/src/blockedpermutation.jl b/NDTensors/src/lib/TensorAlgebra/src/blockedpermutation.jl index 160cd3aec8..c13b68a642 100644 --- a/NDTensors/src/lib/TensorAlgebra/src/blockedpermutation.jl +++ b/NDTensors/src/lib/TensorAlgebra/src/blockedpermutation.jl @@ -160,9 +160,7 @@ end BlockArrays.blocks(blockedperm::BlockedPermutation) = getfield(blockedperm, :blocks) -function blockedperm(length::Val, permblocks_maybe_empty::Tuple{Vararg{Int}}...) - # Drop empty blocks - permblocks = filter(!isempty, permblocks_maybe_empty) +function blockedperm(length::Val, permblocks::Tuple{Vararg{Int}}...) @assert value(length) == sum(Base.length, permblocks; init=zero(Bool)) blockedperm = _BlockedPermutation(permblocks) @assert isperm(blockedperm) diff --git a/NDTensors/src/lib/TensorAlgebra/src/contract/allocate_output.jl b/NDTensors/src/lib/TensorAlgebra/src/contract/allocate_output.jl index 7732fa1258..2beff4c5bc 100644 --- a/NDTensors/src/lib/TensorAlgebra/src/contract/allocate_output.jl +++ b/NDTensors/src/lib/TensorAlgebra/src/contract/allocate_output.jl @@ -30,11 +30,43 @@ function output_axes( α::Number=true, ) axes_contracted = blockpermute(axes(a1), perm1) - axes_contracted2 = blockpermute(axes(a2), perm2) - @assert axes_contracted == axes_contracted2 + axes_contracted′ = blockpermute(axes(a2), perm2) + @assert axes_contracted == axes_contracted′ return () end +# Vec-mat. +function output_axes( + ::typeof(contract), + perm_dest::BlockedPermutation{1}, + a1::AbstractArray, + perm1::BlockedPermutation{1}, + a2::AbstractArray, + biperm2::BlockedPermutation{2}, + α::Number=true, +) + (axes_contracted,) = blockpermute(axes(a1), perm1) + axes_contracted′, axes_dest = blockpermute(axes(a2), biperm2) + @assert axes_contracted == axes_contracted′ + return genperm((axes_dest...,), invperm(Tuple(perm_dest))) +end + +# Mat-vec. +function output_axes( + ::typeof(contract), + perm_dest::BlockedPermutation{1}, + a1::AbstractArray, + perm1::BlockedPermutation{2}, + a2::AbstractArray, + biperm2::BlockedPermutation{1}, + α::Number=true, +) + axes_dest, axes_contracted = blockpermute(axes(a1), perm1) + (axes_contracted′,) = blockpermute(axes(a2), biperm2) + @assert axes_contracted == axes_contracted′ + return genperm((axes_dest...,), invperm(Tuple(perm_dest))) +end + # TODO: Use `ArrayLayouts`-like `MulAdd` object, # i.e. `ContractAdd`? function allocate_output( diff --git a/NDTensors/src/lib/TensorAlgebra/src/contract/blockedperms.jl b/NDTensors/src/lib/TensorAlgebra/src/contract/blockedperms.jl index 22d103c293..60009de9ef 100644 --- a/NDTensors/src/lib/TensorAlgebra/src/contract/blockedperms.jl +++ b/NDTensors/src/lib/TensorAlgebra/src/contract/blockedperms.jl @@ -22,8 +22,11 @@ function blockedperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2) perm_codomain2 = BaseExtensions.indexin(contracted, dimnames2) perm_domain2 = BaseExtensions.indexin(domain, dimnames2) - biperm_dest = blockedperm(perm_codomain_dest, perm_domain_dest) - biperm1 = blockedperm(perm_codomain1, perm_domain1) - biperm2 = blockedperm(perm_codomain2, perm_domain2) + permblocks_dest = (perm_codomain_dest, perm_domain_dest) + biperm_dest = blockedperm(filter(!isempty, permblocks_dest)...) + permblocks1 = (perm_codomain1, perm_domain1) + biperm1 = blockedperm(filter(!isempty, permblocks1)...) + permblocks2 = (perm_codomain2, perm_domain2) + biperm2 = blockedperm(filter(!isempty, permblocks2)...) return biperm_dest, biperm1, biperm2 end diff --git a/NDTensors/src/lib/TensorAlgebra/src/contract/contract_matricize/contract.jl b/NDTensors/src/lib/TensorAlgebra/src/contract/contract_matricize/contract.jl index 3ddc38ff76..beb70104bb 100644 --- a/NDTensors/src/lib/TensorAlgebra/src/contract/contract_matricize/contract.jl +++ b/NDTensors/src/lib/TensorAlgebra/src/contract/contract_matricize/contract.jl @@ -39,3 +39,19 @@ function _mul!( a_dest[] = transpose(a1) * a2 * α + a_dest[] * β return a_dest end + +# Vec-mat. +function _mul!( + a_dest::AbstractVector, a1::AbstractVector, a2::AbstractMatrix, α::Number, β::Number +) + mul!(transpose(a_dest), transpose(a1), a2, α, β) + return a_dest +end + +# Mat-vec. +function _mul!( + a_dest::AbstractVector, a1::AbstractMatrix, a2::AbstractVector, α::Number, β::Number +) + mul!(a_dest, a1, a2, α, β) + return a_dest +end diff --git a/NDTensors/src/lib/TensorAlgebra/test/test_basics.jl b/NDTensors/src/lib/TensorAlgebra/test/test_basics.jl index 40eefecc8b..22feb8c2aa 100644 --- a/NDTensors/src/lib/TensorAlgebra/test/test_basics.jl +++ b/NDTensors/src/lib/TensorAlgebra/test/test_basics.jl @@ -22,6 +22,18 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64}) @test blocklasts(p) == (3, 5) @test invperm(p) == blockedperm((5, 4, 1), (2, 3)) + # Empty block. + p = blockedperm((3, 2), (), (1,)) + @test Tuple(p) === (3, 2, 1) + @test isperm(p) + @test length(p) == 3 + @test blocks(p) == ((3, 2), (), (1,)) + @test blocklength(p) == 3 + @test blocklengths(p) == (2, 0, 1) + @test blockfirsts(p) == (1, 3, 3) + @test blocklasts(p) == (2, 2, 3) + @test invperm(p) == blockedperm((3, 2), (), (1,)) + # Split collection into `BlockedPermutation`. p = blockedperm_indexin(("a", "b", "c", "d"), ("c", "a"), ("b", "d")) @test p == blockedperm((3, 1), (2, 4)) @@ -120,6 +132,8 @@ end for (d1s, d2s, d_dests) in ( ((1, 2), (1, 2), ()), ((1, 2), (2, 1), ()), + ((1, 2), (2, 1, 3), (3,)), + ((1, 2, 3), (2, 1), (3,)), ((1, 2), (2, 3), (1, 3)), ((1, 2), (2, 3), (3, 1)), ((2, 1), (2, 3), (3, 1)),