Skip to content

Commit

Permalink
[TensorAlgebra] Mat-vecs in contract, change handling of empty blocke…
Browse files Browse the repository at this point in the history
…dperm blocks (#1459)

* [TensorAlgebra] Mat-vecs in contract, change handling of empty blockedperm blocks

* [NDTensors] Bump to v0.3.13
  • Loading branch information
mtfishman authored May 28, 2024
1 parent b4a6880 commit 0e6c219
Show file tree
Hide file tree
Showing 6 changed files with 72 additions and 9 deletions.
2 changes: 1 addition & 1 deletion NDTensors/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "NDTensors"
uuid = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
authors = ["Matthew Fishman <[email protected]>"]
version = "0.3.12"
version = "0.3.13"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
4 changes: 1 addition & 3 deletions NDTensors/src/lib/TensorAlgebra/src/blockedpermutation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,9 +160,7 @@ end

BlockArrays.blocks(blockedperm::BlockedPermutation) = getfield(blockedperm, :blocks)

function blockedperm(length::Val, permblocks_maybe_empty::Tuple{Vararg{Int}}...)
# Drop empty blocks
permblocks = filter(!isempty, permblocks_maybe_empty)
function blockedperm(length::Val, permblocks::Tuple{Vararg{Int}}...)
@assert value(length) == sum(Base.length, permblocks; init=zero(Bool))
blockedperm = _BlockedPermutation(permblocks)
@assert isperm(blockedperm)
Expand Down
36 changes: 34 additions & 2 deletions NDTensors/src/lib/TensorAlgebra/src/contract/allocate_output.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,43 @@ function output_axes(
α::Number=true,
)
axes_contracted = blockpermute(axes(a1), perm1)
axes_contracted2 = blockpermute(axes(a2), perm2)
@assert axes_contracted == axes_contracted2
axes_contracted′ = blockpermute(axes(a2), perm2)
@assert axes_contracted == axes_contracted′
return ()
end

# Vec-mat.
function output_axes(
::typeof(contract),
perm_dest::BlockedPermutation{1},
a1::AbstractArray,
perm1::BlockedPermutation{1},
a2::AbstractArray,
biperm2::BlockedPermutation{2},
α::Number=true,
)
(axes_contracted,) = blockpermute(axes(a1), perm1)
axes_contracted′, axes_dest = blockpermute(axes(a2), biperm2)
@assert axes_contracted == axes_contracted′
return genperm((axes_dest...,), invperm(Tuple(perm_dest)))
end

# Mat-vec.
function output_axes(
::typeof(contract),
perm_dest::BlockedPermutation{1},
a1::AbstractArray,
perm1::BlockedPermutation{2},
a2::AbstractArray,
biperm2::BlockedPermutation{1},
α::Number=true,
)
axes_dest, axes_contracted = blockpermute(axes(a1), perm1)
(axes_contracted′,) = blockpermute(axes(a2), biperm2)
@assert axes_contracted == axes_contracted′
return genperm((axes_dest...,), invperm(Tuple(perm_dest)))
end

# TODO: Use `ArrayLayouts`-like `MulAdd` object,
# i.e. `ContractAdd`?
function allocate_output(
Expand Down
9 changes: 6 additions & 3 deletions NDTensors/src/lib/TensorAlgebra/src/contract/blockedperms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,11 @@ function blockedperms(::typeof(contract), dimnames_dest, dimnames1, dimnames2)
perm_codomain2 = BaseExtensions.indexin(contracted, dimnames2)
perm_domain2 = BaseExtensions.indexin(domain, dimnames2)

biperm_dest = blockedperm(perm_codomain_dest, perm_domain_dest)
biperm1 = blockedperm(perm_codomain1, perm_domain1)
biperm2 = blockedperm(perm_codomain2, perm_domain2)
permblocks_dest = (perm_codomain_dest, perm_domain_dest)
biperm_dest = blockedperm(filter(!isempty, permblocks_dest)...)
permblocks1 = (perm_codomain1, perm_domain1)
biperm1 = blockedperm(filter(!isempty, permblocks1)...)
permblocks2 = (perm_codomain2, perm_domain2)
biperm2 = blockedperm(filter(!isempty, permblocks2)...)
return biperm_dest, biperm1, biperm2
end
Original file line number Diff line number Diff line change
Expand Up @@ -39,3 +39,19 @@ function _mul!(
a_dest[] = transpose(a1) * a2 * α + a_dest[] * β
return a_dest
end

# Vec-mat.
function _mul!(
a_dest::AbstractVector, a1::AbstractVector, a2::AbstractMatrix, α::Number, β::Number
)
mul!(transpose(a_dest), transpose(a1), a2, α, β)
return a_dest
end

# Mat-vec.
function _mul!(
a_dest::AbstractVector, a1::AbstractMatrix, a2::AbstractVector, α::Number, β::Number
)
mul!(a_dest, a1, a2, α, β)
return a_dest
end
14 changes: 14 additions & 0 deletions NDTensors/src/lib/TensorAlgebra/test/test_basics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,18 @@ const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
@test blocklasts(p) == (3, 5)
@test invperm(p) == blockedperm((5, 4, 1), (2, 3))

# Empty block.
p = blockedperm((3, 2), (), (1,))
@test Tuple(p) === (3, 2, 1)
@test isperm(p)
@test length(p) == 3
@test blocks(p) == ((3, 2), (), (1,))
@test blocklength(p) == 3
@test blocklengths(p) == (2, 0, 1)
@test blockfirsts(p) == (1, 3, 3)
@test blocklasts(p) == (2, 2, 3)
@test invperm(p) == blockedperm((3, 2), (), (1,))

# Split collection into `BlockedPermutation`.
p = blockedperm_indexin(("a", "b", "c", "d"), ("c", "a"), ("b", "d"))
@test p == blockedperm((3, 1), (2, 4))
Expand Down Expand Up @@ -120,6 +132,8 @@ end
for (d1s, d2s, d_dests) in (
((1, 2), (1, 2), ()),
((1, 2), (2, 1), ()),
((1, 2), (2, 1, 3), (3,)),
((1, 2, 3), (2, 1), (3,)),
((1, 2), (2, 3), (1, 3)),
((1, 2), (2, 3), (3, 1)),
((2, 1), (2, 3), (3, 1)),
Expand Down

2 comments on commit 0e6c219

@mtfishman
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register subdir=NDTensors

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/107835

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a NDTensors-v0.3.13 -m "<description of version>" 0e6c21955008b0b5830161151608cec35ca320df
git push origin NDTensors-v0.3.13

Please sign in to comment.