Skip to content

Commit

Permalink
Not specific to Float64
Browse files Browse the repository at this point in the history
  • Loading branch information
blegat committed Jul 3, 2024
1 parent cdb13fc commit fffee1c
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions src/scaled.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,13 +51,13 @@ end
_float(::Type{T}) where {T<:Number} = float(T)
# Could be for instance `MathOptInterface.ScalarAffineFunction{Float64}`
# which does not implement `float`
_float(::Type{F}) where {F} = F
_float(::Type{F}) where {F} = MA.promote_operation(+, F, F)

_promote_coef(::Type{T}, ::Type{ScaledMonomial}) where {T} = _float(T)

function MP.polynomial(f::Function, basis::SubBasis{ScaledMonomial})
function MP.polynomial(f::Function, basis::SubBasis{ScaledMonomial}, ::Type{T}) where {T}
return MP.polynomial(
i -> scaling(basis.monomials[i]) * f(i),
i -> scaling(T, basis.monomials[i]) * f(i),
basis.monomials,
)
end
Expand All @@ -69,12 +69,12 @@ function Base.promote_rule(
return SubBasis{Monomial,M,V}
end

function scaling(m::MP.AbstractMonomial)
return (
function scaling(::Type{T}, m::MP.AbstractMonomial) where {T}
return (T(
factorial(MP.degree(m)) / prod(factorial, MP.exponents(m); init = 1),
)
))
end
unscale_coef(t::MP.AbstractTerm) = MP.coefficient(t) / scaling(MP.monomial(t))
unscale_coef(t::MP.AbstractTerm) = MP.coefficient(t) / scaling(MP.coefficient_type(t), MP.monomial(t))
function SA.coeffs(
t::MP.AbstractTermLike,
::FullBasis{ScaledMonomial},
Expand All @@ -84,10 +84,10 @@ function SA.coeffs(
return MP.term(mono * MP.coefficient(t), mono)
end
function MP.coefficients(p, ::FullBasis{ScaledMonomial})
return unscale_coef.(MP.terms(p))
return unscale_coef.(MP.coefficient_type(p), MP.terms(p))
end
function MP.coefficients(p, basis::SubBasis{ScaledMonomial})
return MP.coefficients(p, basis.monomials) ./ scaling.(MP.monomials(p))
return MP.coefficients(p, basis.monomials) ./ scaling.(MP.coefficient_type(p), MP.monomials(p))
end
function SA.coeffs(
p::MP.AbstractPolynomialLike,
Expand All @@ -109,7 +109,7 @@ function SA.coeffs!(
SA.unsafe_push!(
res,
target[Polynomial{Monomial}(mono)],
v * scaling(mono),
v * scaling(eltype(res), mono),
)
end
MA.operate!(SA.canonical, res)
Expand All @@ -128,7 +128,7 @@ function SA.coeffs!(
SA.unsafe_push!(
res,
target[Polynomial{ScaledMonomial}(mono)],
v / scaling(mono),
v / scaling(eltype(T), mono),
)
end
MA.operate!(SA.canonical, res)
Expand Down

0 comments on commit fffee1c

Please sign in to comment.