diff --git a/docs/src/index.md b/docs/src/index.md index 5d4ff70..72329ec 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -15,6 +15,7 @@ AbstractPolynomialBasis maxdegree_basis basis_covering_monomials FixedPolynomialBasis +OrthonormalCoefficientsBasis ``` ## Monomial basis diff --git a/src/MultivariateBases.jl b/src/MultivariateBases.jl index dc84aed..a065ea2 100644 --- a/src/MultivariateBases.jl +++ b/src/MultivariateBases.jl @@ -13,6 +13,7 @@ include("monomial.jl") include("scaled.jl") export FixedPolynomialBasis, + OrthonormalCoefficientsBasis, AbstractMultipleOrthogonalBasis, ProbabilistsHermiteBasis, PhysicistsHermiteBasis, @@ -31,6 +32,7 @@ export generators, include("fixed.jl") import LinearAlgebra +include("orthonormal.jl") include("orthogonal.jl") include("hermite.jl") include("laguerre.jl") diff --git a/src/orthogonal.jl b/src/orthogonal.jl index c40eff6..5085c4a 100644 --- a/src/orthogonal.jl +++ b/src/orthogonal.jl @@ -209,16 +209,10 @@ function _integral( return sum([_integral(t, basis_type) for t in MP.terms(p)]) end -function MP.coefficients( - p, - basis::AbstractMultipleOrthogonalBasis; - check = true, -) +function MP.coefficients(p, basis::AbstractMultipleOrthogonalBasis) B = typeof(basis) - coeffs = [ + return [ LinearAlgebra.dot(p, el, B) / LinearAlgebra.dot(el, el, B) for el in basis ] - idx = findall(c -> !isapprox(c, 0; atol = 1e-10), coeffs) - return coeffs[idx] end diff --git a/src/orthonormal.jl b/src/orthonormal.jl new file mode 100644 index 0000000..5475a7d --- /dev/null +++ b/src/orthonormal.jl @@ -0,0 +1,48 @@ +""" + struct OrthonormalCoefficientsBasis{PT<:MP.AbstractPolynomialLike, PV<:AbstractVector{PT}} <: AbstractPolynomialBasis + polynomials::PV + end + +Polynomial basis with the polynomials of the vector `polynomials` that are +orthonormal with respect to the inner produce derived from the inner product +of their coefficients. +For instance, `FixedPolynomialBasis([1, x, 2x^2-1, 4x^3-3x])` is the Chebyshev +polynomial basis for cubic polynomials in the variable `x`. +""" +struct OrthonormalCoefficientsBasis{ + PT<:MP.AbstractPolynomialLike, + PV<:AbstractVector{PT}, +} <: AbstractPolynomialVectorBasis{PT,PV} + polynomials::PV +end + +function LinearAlgebra.dot( + p::MP.AbstractPolynomialLike{S}, + q::MP.AbstractPolynomialLike{T}, + ::Type{<:OrthonormalCoefficientsBasis}, +) where {S,T} + s = zero(MA.promote_operation(*, S, T)) + terms_p = MP.terms(p) + terms_q = MP.terms(q) + tsp = iterate(terms_p) + tsq = iterate(terms_q) + while !isnothing(tsp) && !isnothing(tsq) + tp, sp = tsp + tq, sq = tsq + cmp = MP.compare(MP.monomial(tp), MP.monomial(tq)) + if iszero(cmp) + s += conj(MP.coefficient(tp)) * MP.coefficient(tq) + tsp = iterate(terms_p, sp) + tsq = iterate(terms_q, sq) + elseif cmp < 0 + tsp = iterate(terms_p, sp) + else + tsq = iterate(terms_q, sq) + end + end + return s +end + +function MP.coefficients(p, basis::OrthonormalCoefficientsBasis) + return [LinearAlgebra.dot(q, p, typeof(basis)) for q in generators(basis)] +end diff --git a/test/fixed.jl b/test/fixed.jl index a0455b0..d74cd3c 100644 --- a/test/fixed.jl +++ b/test/fixed.jl @@ -14,8 +14,6 @@ using DynamicPolynomials @test collect(basis) == gens @test generators(basis) == gens @test length(basis) == 2 - @test firstindex(basis) == 1 - @test lastindex(basis) == 2 @test mindegree(basis) == 0 @test mindegree(basis, x) == 0 @test maxdegree(basis) == 2 diff --git a/test/monomial.jl b/test/monomial.jl index 3743427..16b598d 100644 --- a/test/monomial.jl +++ b/test/monomial.jl @@ -12,9 +12,6 @@ using DynamicPolynomials @test basis[2] == x @test generators(basis) == [y, x] @test collect(basis) == [y, x] - @test length(basis) == 2 - @test firstindex(basis) == 1 - @test lastindex(basis) == 2 @test variables(basis) == [x, y] @test nvariables(basis) == 2 @test sprint(show, basis) == "MonomialBasis([y, x])" diff --git a/test/orthonormal.jl b/test/orthonormal.jl new file mode 100644 index 0000000..922ec0a --- /dev/null +++ b/test/orthonormal.jl @@ -0,0 +1,82 @@ +using Test +using MultivariateBases +using DynamicPolynomials + +@polyvar x y + +@testset "Polynomials" begin + gens = [1 + x + y + x * y, 1 - x + y - x * y] / 2 + basis = OrthonormalCoefficientsBasis(gens) + @test iszero(dot(gens[1], gens[2], OrthonormalCoefficientsBasis)) + coefficient_test(basis, [2, -3]) + coefficient_test(basis, [-2im, 1 + 5im]) + coefficient_test(basis, [1im, 2im]) + @test polynomial_type(basis, Int) == polynomial_type(x, Float64) + @test polynomial(one, basis) == 1 + y + @test basis[1] == gens[1] + @test basis[2] == gens[2] + @test collect(basis) == gens + @test generators(basis) == gens + @test length(basis) == 2 + @test mindegree(basis) == 0 + @test mindegree(basis, x) == 0 + @test mindegree(basis, y) == 0 + @test maxdegree(basis) == 2 + @test maxdegree(basis, x) == 1 + @test maxdegree(basis, y) == 1 + @test extdegree(basis) == (0, 2) + @test extdegree(basis, x) == (0, 1) + @test extdegree(basis, y) == (0, 1) + @test variables(basis) == [x, y] + @test nvariables(basis) == 2 + @test sprint(show, basis) == + "OrthonormalCoefficientsBasis([0.5 + 0.5y + 0.5x + 0.5xy, 0.5 + 0.5y - 0.5x - 0.5xy])" + @test sprint(print, basis) == + "OrthonormalCoefficientsBasis([0.5 + 0.5*y + 0.5*x + 0.5*x*y, 0.5 + 0.5*y - 0.5*x - 0.5*x*y])" + b2 = basis[2:2] + @test length(b2) == 1 + @test b2[1] == gens[2] + b3 = basis[2:1] + @test isempty(b3) +end +@testset "Linear" begin + basis = OrthonormalCoefficientsBasis([x, y]) + @test polynomial_type(basis, Int) == polynomial_type(x, Int) + @test polynomial(identity, basis) == x + 2y +end +@testset "One variable" begin + basis = OrthonormalCoefficientsBasis([x]) + @test polynomial_type(basis, Int) == polynomial_type(x, Int) + @test polynomial(i -> 5, basis) == 5x + @test typeof(polynomial(i -> 5, basis)) == polynomial_type(basis, Int) + @test typeof(polynomial(ones(Int, 1, 1), basis, Int)) <: + AbstractPolynomial{Int} + @test typeof(polynomial(ones(Int, 1, 1), basis, Float64)) <: + AbstractPolynomial{Float64} +end +@testset "Complex" begin + for a in [1, -1, im, -im] + basis = OrthonormalCoefficientsBasis([a * x]) + @test 5x^2 == + @inferred polynomial(5ones(Int, 1, 1), basis, Complex{Int}) + @test 5x^2 == @inferred polynomial(5ones(Int, 1, 1), basis, Int) + coefficient_test(basis, [2]) + coefficient_test(basis, [-2im]) + coefficient_test(basis, [1 + 5im]) + end +end +@testset "Empty" begin + basis = OrthonormalCoefficientsBasis(typeof(x + 1)[]) + @test isempty(basis) + @test isempty(eachindex(basis)) + p = @inferred polynomial(zeros(Int, 0, 0), basis, Int) + @test iszero(p) +end + +@testset "Enumerate" begin + monos = [1, x, y^2] + basis = OrthonormalCoefficientsBasis(monos) + for (i, e) in enumerate(basis) + @test e == monos[i] + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 3bd703b..d25703b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -12,6 +12,8 @@ function api_test(B::Type{<:AbstractPolynomialBasis}, degree) ] n = binomial(2 + degree, 2) @test length(basis) == n + @test firstindex(basis) == 1 + @test lastindex(basis) == n @test typeof(copy(basis)) == typeof(basis) @test nvariables(basis) == 2 @test variables(basis) == x @@ -87,13 +89,31 @@ function orthogonal_test( end end +function coefficient_test(basis::AbstractPolynomialBasis, p, coefs; kwargs...) + cc = coefficients(p, basis) + @test isapprox(coefs, cc; kwargs...) + @test isapprox(p, polynomial(cc, basis); kwargs...) +end + +function coefficient_test( + basis::AbstractPolynomialBasis, + coefs::AbstractVector; + kwargs..., +) + return coefficient_test( + basis, + sum(generators(basis) .* coefs), + coefs; + kwargs..., + ) +end + function coefficient_test(B::Type{<:AbstractPolynomialBasis}, coefs; kwargs...) @polyvar x y p = x^4 * y^2 + x^2 * y^4 - 3 * x^2 * y^2 + 1 basis = basis_covering_monomials(B, monomials(p)) - cc = coefficients(p, basis) - @test isapprox(coefs, cc; kwargs...) - @test isapprox(p, polynomial(cc, basis); kwargs...) + coefficient_test(basis, p, coefs; kwargs...) + return end @testset "Monomial" begin @@ -105,6 +125,9 @@ end @testset "Fixed" begin include("fixed.jl") end +@testset "Orthonormal" begin + include("orthonormal.jl") +end @testset "Hermite" begin include("hermite.jl") end