diff --git a/src/cmult.jl b/src/cmult.jl index aea7228..acc5a77 100644 --- a/src/cmult.jl +++ b/src/cmult.jl @@ -28,13 +28,7 @@ function _operate_exponents_to!(output::Vector{Int}, op::F, z1::Vector{Int}, z2: @. output = op(z1, z2) return end -function _operate_exponents_to!(output::Vector{Int}, op::F, z1::Vector{Int}, z2::Vector{Int}, n) where {F<:Function} - resize!(output, n) - @. output = op(z1, z2) - return -end -function _operate_exponents_to!(output::Vector{Int}, op::F, z1::Vector{Int}, z2::Vector{Int}, n, maps) where {F<:Function} - resize!(output, n) +function _operate_exponents_to!(output::Vector{Int}, op::F, z1::Vector{Int}, z2::Vector{Int}, maps) where {F<:Function} I = maps[1]; i = 1; lI = length(I) J = maps[2]; j = 1; lJ = length(J) while i <= lI || j <= lJ @@ -42,11 +36,11 @@ function _operate_exponents_to!(output::Vector{Int}, op::F, z1::Vector{Int}, z2: output[J[j]] = op(0, z2[j]) j += 1 elseif j > lJ || (i <= lI && I[i] < J[j]) - output[I[i]] = op(z1[I[i]], 0) + output[I[i]] = op(z1[i], 0) i += 1 else @assert I[i] == J[j] - output[I[i]] = op(z1[I[i]], z2[j]) + output[I[i]] = op(z1[i], z2[j]) i += 1 j += 1 end @@ -57,19 +51,19 @@ end #function _operate_exponents!(op::F, Z::Vector{Vector{Int}}, z2::Vector{Int}, args::Vararg{Any,N}) where {F<:Function,N} # return Vector{Int}[_operate_exponents!(op, z, z2, args...) for z in Z] #end -function multdivmono!(output, output_variables::Vector{PolyVar{true}}, +function _multdivmono!(output, output_variables::Vector{PolyVar{true}}, v::Vector{PolyVar{true}}, x::Monomial{true}, op, z) if v == x.vars - if output_variables == v - _operate_exponents_to!(output, op, z, x.z) - else + if output_variables != v resize!(output_variables, length(v)) copyto!(output_variables, v) - _operate_exponents_to!(output, op, z, x.z, length(output_variables)) + resize!(output, length(output_variables)) end + _operate_exponents_to!(output, op, z, x.z) else maps = mergevars_to!(output_variables, [v, x.vars]) - _operate_exponents_to!(output, op, z, x.z, length(output_variables), maps) + resize!(output, length(output_variables)) + _operate_exponents_to!(output, op, z, x.z, maps) end return end @@ -110,11 +104,25 @@ function multdivmono(v::Vector{PolyVar{true}}, x::Monomial{true}, op, z) return w, z_new end function MP.mapexponents_to!(output::Monomial{true}, f::Function, x::Monomial{true}, y::Monomial{true}) - multdivmono!(output.z, output.vars, x.vars, y, f, x.z) + if x.vars == y.vars + if output.vars != x.vars + n = length(x.vars) + resize!(output.vars, n) + copyto!(output.vars, x.vars) + resize!(output.z, n) + end + _operate_exponents_to!(x.z, f, x.z, y.z) + else + _multdivmono!(output.z, output.vars, x.vars, y, f, x.z) + end return output end function MP.mapexponents!(f::Function, x::Monomial{true}, y::Monomial{true}) - multdivmono!(x.z, x.vars, x.vars, y, f, x.z) + if x.vars == y.vars + _operate_exponents_to!(x.z, f, x.z, y.z) + else + _multdivmono!(x.z, x.vars, copy(x.vars), y, f, copy(x.z)) + end return x end function MP.mapexponents(f::Function, x::Monomial{true}, y::Monomial{true}) diff --git a/test/mono.jl b/test/mono.jl index a7f7d01..89b60aa 100644 --- a/test/mono.jl +++ b/test/mono.jl @@ -133,4 +133,8 @@ @test (x^2)(3) == 9 @test (x)(3) == 3 end + @testset "TODO remove when added to MP" begin + @polyvar x y + @test x == DynamicPolynomials.MP.mapexponents!(div, x^1, x * y^2) + end end