From 8b26f8996b9c71ee728a248ddf3d4c12d862a9b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Fri, 14 Jun 2024 10:57:44 +0200 Subject: [PATCH] Multiplication of 3 or more AlgebraElement (#46) * Multiplication of 3 or more AlgebraElement * No aggregate_constants * Add test * Fix * Add test --- src/arithmetic.jl | 68 ++++++++++++++++++++++++++++------------- src/diracs_augmented.jl | 3 +- src/mstructures.jl | 39 +++++++++++------------ src/mtables.jl | 3 ++ src/sparse_coeffs.jl | 6 ++++ test/monoid_algebra.jl | 7 +++++ 6 files changed, 82 insertions(+), 44 deletions(-) diff --git a/src/arithmetic.jl b/src/arithmetic.jl index 0ae8e1b..b2f078e 100644 --- a/src/arithmetic.jl +++ b/src/arithmetic.jl @@ -1,14 +1,12 @@ -function _preallocate_output(X::AlgebraElement, a::Number, op) - T = MA.promote_operation(op, eltype(X), typeof(a)) - return similar(X, T) -end +_coeff_type(X::AlgebraElement) = eltype(X) +_coeff_type(a) = typeof(a) -function _preallocate_output(X::AlgebraElement, Y::AlgebraElement, op) - T = MA.promote_operation(op, eltype(X), eltype(Y)) - if coeffs(Y) isa DenseArray # what a hack :) - return similar(Y, T) +function _preallocate_output(op, args::Vararg{Any,N}) where {N} + T = MA.promote_operation(op, _coeff_type.(args)...) + if args[2] isa AlgebraElement && coeffs(args[2]) isa DenseArray # what a hack :) + return similar(args[2], T) end - return similar(X, T) + return similar(args[1], T) end # module structure: @@ -18,23 +16,26 @@ Base.:(/)(X::AlgebraElement, a::Number) = inv(a) * X Base.:(//)(X::AlgebraElement, a::Number) = 1 // a * X function Base.:-(X::AlgebraElement) - return MA.operate_to!(_preallocate_output(X, -1, *), -, X) + return MA.operate_to!(_preallocate_output(*, X, -1), -, X) end function Base.:*(a::Number, X::AlgebraElement) - return MA.operate_to!(_preallocate_output(X, a, *), *, X, a) + return MA.operate_to!(_preallocate_output(*, X, a), *, X, a) end function Base.:div(X::AlgebraElement, a::Number) - return MA.operate_to!(_preallocate_output(X, a, div), div, X, a) + return MA.operate_to!(_preallocate_output(div, X, a), div, X, a) end function Base.:+(X::AlgebraElement, Y::AlgebraElement) - return MA.operate_to!(_preallocate_output(X, Y, +), +, X, Y) + return MA.operate_to!(_preallocate_output(+, X, Y), +, X, Y) end function Base.:-(X::AlgebraElement, Y::AlgebraElement) - return MA.operate_to!(_preallocate_output(X, Y, -), -, X, Y) + return MA.operate_to!(_preallocate_output(-, X, Y), -, X, Y) end function Base.:*(X::AlgebraElement, Y::AlgebraElement) - return MA.operate_to!(_preallocate_output(X, Y, *), *, X, Y) + return MA.operate_to!(_preallocate_output(*, X, Y), *, X, Y) +end +function Base.:*(args::Vararg{AlgebraElement,N}) where {N} + return MA.operate_to!(_preallocate_output(*, args...), *, args...) end Base.:^(a::AlgebraElement, p::Integer) = Base.power_by_squaring(a, p) @@ -99,11 +100,36 @@ end function MA.operate_to!( res::AlgebraElement, ::typeof(*), - X::AlgebraElement, - Y::AlgebraElement, -) - @assert parent(res) === parent(X) === parent(Y) - mstr = mstructure(basis(parent(res))) - MA.operate_to!(coeffs(res), mstr, coeffs(X), coeffs(Y)) + args::Vararg{AlgebraElement,N}, +) where {N} + for arg in args + if arg isa AlgebraElement + @assert parent(res) == parent(arg) + end + end + mstr = mstructure(basis(res)) + MA.operate_to!(coeffs(res), mstr, coeffs.(args)...) return res end + +function MA.operate!( + ::UnsafeAddMul{typeof(*)}, + res::AlgebraElement, + args::Vararg{AlgebraElement,N}, +) where {N} + for arg in args + if arg isa AlgebraElement + @assert parent(res) == parent(arg) + end + end + mstr = mstructure(basis(res)) + MA.operate!(UnsafeAddMul(mstr), coeffs(res), coeffs.(args)...) + return res +end + +# TODO just push to internal vectors once canonical `does` not just +# call `dropzeros!` but also reorders +function unsafe_push!(a::SparseArrays.SparseVector, k, v) + a[k] = MA.add!!(a[k], v) + return a +end diff --git a/src/diracs_augmented.jl b/src/diracs_augmented.jl index e555472..e0f8012 100644 --- a/src/diracs_augmented.jl +++ b/src/diracs_augmented.jl @@ -125,8 +125,7 @@ function coeffs!( MA.operate!( UnsafeAddMul(*), res, - v, - SparseCoefficients((target[Augmented(x)],), (1,)), + SparseCoefficients((target[Augmented(x)],), (v,)), ) end MA.operate!(canonical, res) diff --git a/src/mstructures.jl b/src/mstructures.jl index 2cf3e73..1fab9d1 100644 --- a/src/mstructures.jl +++ b/src/mstructures.jl @@ -44,37 +44,34 @@ struct UnsafeAddMul{M<:Union{typeof(*),MultiplicativeStructure}} structure::M end -function MA.operate_to!(res, ms::MultiplicativeStructure, v, w) - if res === v || res === w +function MA.operate_to!(res, ms::MultiplicativeStructure, args::Vararg{Any,N}) where {N} + if any(Base.Fix1(===, res), args) throw(ArgumentError("No alias allowed")) end MA.operate!(zero, res) - MA.operate!(UnsafeAddMul(ms), res, v, w) + MA.operate!(UnsafeAddMul(ms), res, args...) MA.operate!(canonical, res) return res end -function MA.operate!( - ::UnsafeAddMul{typeof(*)}, - mc::SparseCoefficients, - val, - c::AbstractCoefficients, -) - append!(mc.basis_elements, keys(c)) - vals = values(c) - if vals isa AbstractVector - append!(mc.values, val .* vals) - else - append!(mc.values, val * collect(values(c))) +function MA.operate!(::UnsafeAddMul, res, c) + for (k, v) in nonzero_pairs(c) + unsafe_push!(res, k, v) end - return mc + return res end -function MA.operate!(ms::UnsafeAddMul, res, v, w) - for (kv, a) in nonzero_pairs(v) - for (kw, b) in nonzero_pairs(w) - c = ms.structure(kv, kw) - MA.operate!(UnsafeAddMul(*), res, a * b, c) +function MA.operate!(op::UnsafeAddMul, res, b, c, args::Vararg{Any, N}) where {N} + for (kb, vb) in nonzero_pairs(b) + for (kc, vc) in nonzero_pairs(c) + for (k, v) in nonzero_pairs(op.structure(kb, kc)) + MA.operate!( + op, + res, + SparseCoefficients((_key(op.structure, k),), (vb * vc * v,)), + args..., + ) + end end end return res diff --git a/src/mtables.jl b/src/mtables.jl index 6b9fac0..20b5bee 100644 --- a/src/mtables.jl +++ b/src/mtables.jl @@ -92,6 +92,9 @@ function complete!(mt::MTable) return mt end +_key(_, k) = k +_key(mstr::MTable, k) = mstr[k] + function MA.operate!( ms::UnsafeAddMul{<:MTable}, res::AbstractCoefficients, diff --git a/src/sparse_coeffs.jl b/src/sparse_coeffs.jl index f5b8010..bd828c9 100644 --- a/src/sparse_coeffs.jl +++ b/src/sparse_coeffs.jl @@ -72,6 +72,12 @@ function MA.operate!(::typeof(canonical), res::SparseCoefficients) return MA.operate!(canonical, res, comparable(key_type(res))) end +function unsafe_push!(res::SparseCoefficients, key, value) + push!(res.basis_elements, key) + push!(res.values, value) + return res +end + # `::C` is needed to force Julia specialize on the function type # Otherwise, we get one allocation when we call `issorted` # See https://docs.julialang.org/en/v1/manual/performance-tips/#Be-aware-of-when-Julia-avoids-specializing diff --git a/test/monoid_algebra.jl b/test/monoid_algebra.jl index 92b22cc..a49f738 100644 --- a/test/monoid_algebra.jl +++ b/test/monoid_algebra.jl @@ -18,6 +18,7 @@ @test iszero(zero(fRG)) @test zero(g) == zero(fRG) @test iszero(0 * g) + @test isone(*(g, g, g)) @testset "Translations between bases" begin Z = zero(RG) @@ -131,6 +132,12 @@ @test @allocated(MA.operate_to!(d, *, a, 2)) == 0 @test d == 2a + + MA.operate!(zero, d) + MA.operate!(SA.UnsafeAddMul(*), d, a, b, b) + MA.operate!(SA.canonical, SA.coeffs(d)) + @test a * b^2 == *(a, b, b) + @test d == *(a, b, b) end end end