From d5a88eef498f30c2624004777cb56a3cc19a45c2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Beno=C3=AEt=20Legat?= Date: Wed, 10 Jul 2024 14:22:25 +0200 Subject: [PATCH] Fixes --- src/lagrange.jl | 67 ++++++++++++++++++++++++++++++------------------- 1 file changed, 41 insertions(+), 26 deletions(-) diff --git a/src/lagrange.jl b/src/lagrange.jl index 22cc72d..d60b150 100644 --- a/src/lagrange.jl +++ b/src/lagrange.jl @@ -29,41 +29,57 @@ function sample(s::BoxSampling{T}, n::Integer) where {T} return samples end -struct LagrangePolynomial{T,V} - point::V +struct LagrangePolynomial{T,P,V} + variables::V + point::P end -struct ImplicitLagrangeBasis{T,V,N<:AbstractNodes{T,V}} <: - SA.ImplicitBasis{LagrangePolynomial{T,V},V} - sampling::AbstractNodes{T,V} - function ImplicitLagrangeBasis(nodes::AbstractNodes{T,V}) where {T,V} - return new{T,V,typeof(nodes)}(nodes) +struct ImplicitLagrangeBasis{T,P,N<:AbstractNodes{T,P},V} <: + SA.ImplicitBasis{LagrangePolynomial{T,P,V},Pair{V,P}} + variables::V + nodes::AbstractNodes{T,P} + function ImplicitLagrangeBasis(variables, nodes::AbstractNodes{T,P}) where {T,P} + return new{T,P,typeof(nodes),typeof(variables)}(variables, nodes) end end -struct LagrangeBasis{T,P,V<:AbstractVector{P}} <: - SA.ExplicitBasis{LagrangePolynomial{T,V},Int} - points::V - function LagrangeBasis(points::AbstractVector) +struct LagrangeBasis{T,P,U<:AbstractVector{P},V} <: + SA.ExplicitBasis{LagrangePolynomial{T,P,V},Int} + variables::V + points::U + function LagrangeBasis(variables, points::AbstractVector) P = eltype(points) - return new{eltype(P), P, typeof(points)}(points) + return new{eltype(P),P,typeof(points),typeof(variables)}(variables, points) end end Base.length(basis::LagrangeBasis) = length(basis.points) +MP.nvariables(basis::LagrangeBasis) = length(basis.variables) +MP.variables(basis::LagrangeBasis) = basis.variables function Base.getindex(basis::LagrangeBasis, I::AbstractVector{<:Integer}) - return LagrangeBasis(basis.points[I]) + return LagrangeBasis(basis.variables, basis.points[I]) end -function eval_basis!(univariate_buffer, result, basis::SubBasis{B}, values) where {B} +function explicit_basis_type(::Type{<:ImplicitLagrangeBasis{T,_P,N,V}}) where {T,_P,N,V} + points = eachcol(ones(T, 1, 1)) + P = eltype(points) + return LagrangeBasis{eltype(P),P,typeof(points),V} +end + +function eval_basis!(univariate_buffer, result, basis::SubBasis{B}, variables, values) where {B} + for v in MP.variables(basis) + if !(v in variables) + error("Cannot evaluate `$basis` as its variable `$v` is not part of the variables `$variables` of the `LagrangeBasis`") + end + end for i in eachindex(values) - univariate_eval!(B, view(univariate_buffer, :, i), values[i]) + l = MP.maxdegree(basis.monomials, variables[i]) + 1 + univariate_eval!(B, view(univariate_buffer, 1:l, i), values[i]) end for i in eachindex(basis) result[i] = one(eltype(result)) - exp = MP.exponents(basis.monomials[i]) - @assert length(exp) == length(values) for j in eachindex(values) - result[i] = MA.operate!!(*, result[i], univariate_buffer[exp[j] + 1, j]) + d = MP.degree(basis.monomials[i], variables[j]) + result[i] = MA.operate!!(*, result[i], univariate_buffer[d + 1, j]) end end return result @@ -72,10 +88,10 @@ end function transformation_to(basis::SubBasis, lag::LagrangeBasis{T}) where {T} # To avoid allocating this too often, we allocate it once here # and reuse it for each sample - univariate_buffer = Matrix{T}(undef, length(basis), MP.nvariables(basis)) + univariate_buffer = Matrix{T}(undef, length(basis), MP.nvariables(lag)) V = Matrix{T}(undef, length(lag), length(basis)) for i in eachindex(lag) - eval_basis!(univariate_buffer, view(V, i, :), basis, lag.points[i]) + eval_basis!(univariate_buffer, view(V, i, :), basis, MP.variables(lag), lag.points[i]) end return V end @@ -96,21 +112,20 @@ function num_samples(sample_factor, dim) return sample_factor * dim end -function sample(s::AbstractNodes, basis::SubBasis) - full = LagrangeBasis(eachcol(sample(s, num_samples(s.sample_factor, length(basis))))) +function sample(variables, s::AbstractNodes, basis::SubBasis) + samples = sample(s, num_samples(s.sample_factor, length(basis))) + full = LagrangeBasis(variables, eachcol(samples)) V = transformation_to(basis, full) - display(V) F = LinearAlgebra.qr!(Matrix(V'), LinearAlgebra.ColumnNorm()) - display(F) kept_indices = F.p[1:length(basis)] - return full[kept_indices] + return LagrangeBasis(variables, eachcol(samples[:, kept_indices])) end function explicit_basis_covering( implicit::ImplicitLagrangeBasis, basis::SubBasis, ) - return sample(implicit.sampling, basis) + return sample(implicit.variables, implicit.nodes, basis) end function SA.coeffs(coeffs, source::SubBasis, target::LagrangeBasis)