diff --git a/Project.toml b/Project.toml index d18b95487..cd64702bc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SymbolicUtils" uuid = "d1185830-fcd6-423d-90d6-eec64667417b" authors = ["Shashi Gowda"] -version = "2.1.2" +version = "3.0.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" @@ -43,7 +43,7 @@ Setfield = "0.7, 0.8, 1" SpecialFunctions = "0.10, 1.0, 2" StaticArrays = "0.12, 1.0" SymbolicIndexingInterface = "0.3" -TermInterface = "0.4" +TermInterface = "2.0" TimerOutputs = "0.5" Unityper = "0.1.2" julia = "1.3" diff --git a/docs/src/manual/interface.md b/docs/src/manual/interface.md index 899aff658..601bdaaa4 100644 --- a/docs/src/manual/interface.md +++ b/docs/src/manual/interface.md @@ -13,6 +13,8 @@ You can read the documentation of [TermInterface.jl](https://github.com/JuliaSym ## SymbolicUtils.jl only methods -`promote_symtype(f, arg_symtypes...)` - -Returns the appropriate output type of applying `f` on arguments of type `arg_symtypes`. +```@docs +symtype +issym +promote_symtype +``` diff --git a/docs/src/manual/representation.md b/docs/src/manual/representation.md index 997d33f3a..fea21bf1b 100644 --- a/docs/src/manual/representation.md +++ b/docs/src/manual/representation.md @@ -4,7 +4,7 @@ Performance of symbolic simplification depends on the datastructures used to rep The most basic term representation simply holds a function call and stores the function and the arguments it is called with. This is done by the `Term` type in SymbolicUtils. Functions that aren't commutative or associative, such as `sin` or `hypot` are stored as `Term`s. Commutative and associative operations like `+`, `*`, and their supporting operations like `-`, `/` and `^`, when used on terms of type `<:Number`, stand to gain from the use of more efficient datastrucutres. -All term representations must support `operation` and `arguments` functions. And they must define `istree` to return `true` when called with an instance of the type. Generic term-manipulation programs such as the rule-based rewriter make use of this interface to inspect expressions. In this way, the interface wins back the generality lost by having a zoo of term representations instead of one. (see [interface](/interface/) section for more on this.) +All term representations must support `operation` and `arguments` functions. And they must define `iscall` and `isexpr` to return `true` when called with an instance of the type. Generic term-manipulation programs such as the rule-based rewriter make use of this interface to inspect expressions. In this way, the interface wins back the generality lost by having a zoo of term representations instead of one. (see [interface](/interface/) section for more on this.) ### Preliminary representation of arithmetic diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl index c7c660a8f..32e94ac18 100644 --- a/src/SymbolicUtils.jl +++ b/src/SymbolicUtils.jl @@ -16,8 +16,8 @@ using SymbolicIndexingInterface import Base: +, -, *, /, //, \, ^, ImmutableDict using ConstructionBase using TermInterface -import TermInterface: iscall, isexpr, issym, symtype, head, children, - operation, arguments, metadata, maketerm +import TermInterface: iscall, isexpr, head, children, + operation, arguments, metadata, maketerm, sorted_arguments Base.@deprecate istree iscall export istree, operation, arguments, sorted_arguments, similarterm, iscall diff --git a/src/code.jl b/src/code.jl index 6432bd1f5..4e1589ed9 100644 --- a/src/code.jl +++ b/src/code.jl @@ -9,7 +9,7 @@ export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr, import ..SymbolicUtils import ..SymbolicUtils.Rewriters import SymbolicUtils: @matchable, BasicSymbolic, Sym, Term, iscall, operation, arguments, issym, - symtype, similarterm, sorted_arguments, metadata, isterm, term, maketerm + symtype, sorted_arguments, metadata, isterm, term, maketerm ##== state management ==## @@ -694,7 +694,7 @@ function _cse!(mem, expr) iscall(expr) || return expr op = _cse!(mem, operation(expr)) args = map(Base.Fix1(_cse!, mem), arguments(expr)) - t = similarterm(expr, op, args) + t = maketerm(typeof(expr), op, args, nothing) v, dict = mem update! = let v=v, t=t @@ -763,9 +763,7 @@ function cse_block!(assignments, counter, names, name, state, x) if isterm(x) return term(operation(x), args...) else - return maketerm(typeof(x), operation(x), - args, symtype(x), - metadata(x)) + return maketerm(typeof(x), operation(x), args, metadata(x)) end else return x diff --git a/src/interface.jl b/src/interface.jl deleted file mode 100644 index bea1d47ae..000000000 --- a/src/interface.jl +++ /dev/null @@ -1,84 +0,0 @@ -""" - iscall(x) - -Returns `true` if `x` is a term. If true, `operation`, `arguments` -must also be defined for `x` appropriately. -""" -iscall(x) = false - -""" - symtype(x) - -Returns the symbolic type of `x`. By default this is just `typeof(x)`. -Define this for your symbolic types if you want `SymbolicUtils.simplify` to apply rules -specific to numbers (such as commutativity of multiplication). Or such -rules that may be implemented in the future. -""" -function symtype(x) - typeof(x) -end - -""" - issym(x) - -Returns `true` if `x` is a symbol. If true, `nameof` must be defined -on `x` and must return a Symbol. -""" -issym(x) = false - -""" - operation(x) - -If `x` is a term as defined by `iscall(x)`, `operation(x)` returns the -head of the term if `x` represents a function call, for example, the head -is the function being called. -""" -function operation end - -""" - sorted_arguments(x) - -Get the arguments of `x`, must be defined if `iscall(x)` is `true`. -""" -function sorted_arguments end - -""" - sorted_arguments(x::T) - -If x is a term satisfying `iscall(x)` and your term type `T` provides -an optimized implementation for storing the arguments, this function can -be used to retrieve the arguments when the order of arguments does not matter -but the speed of the operation does. -""" -function arguments end -arity(x) = length(arguments(x)) - -""" - metadata(x) - -Return the metadata attached to `x`. -""" -metadata(x) = nothing - -""" - metadata(x, md) - -Returns a new term which has the structure of `x` but also has -the metadata `md` attached to it. -""" -function metadata(x, data) - error("Setting metadata on $x is not possible") -end - -""" - similarterm(x, head, args, symtype=nothing; metadata=nothing, exprhead=:call) - -Returns a term that is in the same closure of types as `typeof(x)`, -with `head` as the head and `args` as the arguments, `type` as the symtype -and `metadata` as the metadata. By default this will execute `head(args...)`. -`x` parameter can also be a `Type`. The `exprhead` keyword argument is useful -when manipulating `Expr`s. - -`similarterm` is deprecated see help for `maketerm` instead. -""" -function similarterm end diff --git a/src/polyform.jl b/src/polyform.jl index 227755c5f..7d6bc906e 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -121,7 +121,6 @@ function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse) maketerm(typeof(x), op, map(a->PolyForm(a, pvar2sym, sym2term, vtype; Fs, recurse), args), - symtype(x), metadata(x)) else x @@ -176,18 +175,18 @@ isexpr(x::PolyForm) = true iscall(x::Type{<:PolyForm}) = true iscall(x::PolyForm) = true -function maketerm(::Type{<:PolyForm}, f, args, symtype, metadata) - basicsymbolic(t, f, args, symtype, metadata) +function maketerm(t::Type{<:PolyForm}, f, args, metadata) + # TODO: this looks uncovered. + basicsymbolic(f, args, nothing, metadata) end -function maketerm(::Type{<:PolyForm}, f::Union{typeof(*), typeof(+), typeof(^)}, - args, symtype, metadata) +function maketerm(::Type{<:PolyForm}, f::Union{typeof(*), typeof(+), typeof(^)}, args, metadata) f(args...) end head(::PolyForm) = PolyForm operation(x::PolyForm) = MP.nterms(x.p) == 1 ? (*) : (+) -function arguments(x::PolyForm{T}) where {T} +function TermInterface.arguments(x::PolyForm{T}) where {T} function is_var(v) MP.nterms(v) == 1 && @@ -231,10 +230,7 @@ function arguments(x::PolyForm{T}) where {T} PolyForm{T}(t, x.pvar2sym, x.sym2term, nothing)) for t in ts] end end - -sorted_arguments(x::PolyForm) = arguments(x) - -children(x::PolyForm) = [operation(x); arguments(x)] +children(x::PolyForm) = arguments(x) Base.show(io::IO, x::PolyForm) = show_term(io, x) @@ -255,7 +251,7 @@ function unpolyize(x) # we need a special maketerm here because the default one used in Postwalk will call # promote_symtype to get the new type, but we just want to forward that in case # promote_symtype is not defined for some of the expressions here. - Postwalk(identity, maketerm=(T,f,args,sT,m) -> maketerm(T, f, args, symtype(x), m))(x) + Postwalk(identity, maketerm=(T,f,args,m) -> maketerm(T, f, args, m))(x) end function toterm(x::PolyForm) @@ -307,7 +303,8 @@ function add_divs(x, y) end end -function frac_maketerm(T, f, args, stype, metadata) +function frac_maketerm(T, f, args, metadata) + # TODO add stype to T? if f in (*, /, \, +, -) f(args...) elseif f == (^) @@ -317,7 +314,7 @@ function frac_maketerm(T, f, args, stype, metadata) args[1]^args[2] end else - maketerm(T, f, args, stype, metadata) + maketerm(T, f, args, metadata) end end diff --git a/src/rewriters.jl b/src/rewriters.jl index fe5d2bb04..78efc7ee8 100644 --- a/src/rewriters.jl +++ b/src/rewriters.jl @@ -167,11 +167,7 @@ end struct Walk{ord, C, F, threaded} rw::C thread_cutoff::Int - maketerm::F # XXX: for the 2.0 deprecation cycle, we actually store a function - # that behaves like `similarterm` here, we use `compatmaker` to wrap - # maketerm-like input to do this, with a warning if similarterm provided - # we need this workaround to deprecate because similarterm takes value - # but maketerm only knows the type. + maketerm::F end function instrument(x::Walk{ord, C,F,threaded}, f) where {ord,C,F,threaded} @@ -183,25 +179,13 @@ end using .Threads -function compatmaker(similarterm, maketerm) - # XXX: delete this and only use maketerm in a future release. - if similarterm isa Nothing - function (x, f, args, type=_promote_symtype(f, args); metadata) - maketerm(typeof(x), f, args, type, metadata) - end - else - Base.depwarn("Prewalk and Postwalk now take maketerm instead of similarterm keyword argument. similarterm(x, f, args, type; metadata) is now maketerm(typeof(x), f, args, type, metadata)", :similarterm) - similarterm - end -end -function Postwalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm, similarterm=nothing) - maker = compatmaker(similarterm, maketerm) - Walk{:post, typeof(rw), typeof(maker), threaded}(rw, thread_cutoff, maker) + +function Postwalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm) + Walk{:post, typeof(rw), typeof(maketerm), threaded}(rw, thread_cutoff, maketerm) end -function Prewalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm, similarterm=nothing) - maker = compatmaker(similarterm, maketerm) - Walk{:pre, typeof(rw), typeof(maker), threaded}(rw, thread_cutoff, maker) +function Prewalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm) + Walk{:pre, typeof(rw), typeof(maketerm), threaded}(rw, thread_cutoff, maketerm) end struct PassThrough{C} @@ -220,8 +204,8 @@ function (p::Walk{ord, C, F, false})(x) where {ord, C, F} end if iscall(x) - x = p.maketerm(x, operation(x), map(PassThrough(p), - arguments(x)), metadata=metadata(x)) + x = p.maketerm(typeof(x), operation(x), map(PassThrough(p), + arguments(x)), metadata(x)) end return ord === :post ? p.rw(x) : x @@ -245,7 +229,7 @@ function (p::Walk{ord, C, F, true})(x) where {ord, C, F} end end args = map((t,a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, arguments(x)) - t = p.maketerm(x, operation(x), args, metadata=metadata(x)) + t = p.maketerm(typeof(x), operation(x), args, metadata(x)) end return ord === :post ? p.rw(t) : t else diff --git a/src/rule.jl b/src/rule.jl index 13fe86c79..e1531bfe2 100644 --- a/src/rule.jl +++ b/src/rule.jl @@ -408,7 +408,7 @@ function (acr::ACRule)(term) if result !== nothing # Assumption: inds are unique length(args) == length(inds) && return result - return maketerm(typeof(term), f, [result, (args[i] for i in eachindex(args) if i ∉ inds)...], symtype(term), metadata(term)) + return maketerm(typeof(term), f, [result, (args[i] for i in eachindex(args) if i ∉ inds)...], metadata(term)) end end end diff --git a/src/simplify.jl b/src/simplify.jl index 695e57c5a..68fe78f83 100644 --- a/src/simplify.jl +++ b/src/simplify.jl @@ -45,6 +45,6 @@ end has_operation(x, op) = (iscall(x) && (operation(x) == op || any(a->has_operation(a, op), - arguments(x)))) + arguments(x)))) Base.@deprecate simplify(x, ctx; kwargs...) simplify(x; rewriter=ctx, kwargs...) diff --git a/src/substitute.jl b/src/substitute.jl index 51c75e3c4..8fc980c69 100644 --- a/src/substitute.jl +++ b/src/substitute.jl @@ -34,7 +34,6 @@ function substitute(expr, dict; fold=true) maketerm(typeof(expr), op, args, - symtype(expr), metadata(expr)) else expr diff --git a/src/types.jl b/src/types.jl index a46b1bfa0..4b085ef61 100644 --- a/src/types.jl +++ b/src/types.jl @@ -98,8 +98,19 @@ end ### ### Term interface ### -symtype(x::Number) = typeof(x) + +""" +$(SIGNATURES) + +Returns the [numeric type](https://docs.julialang.org/en/v1/base/numbers/#Standard-Numeric-Types) +of `x`. By default this is just `typeof(x)`. +Define this for your symbolic types if you want [`SymbolicUtils.simplify`](@ref) to apply rules +specific to numbers (such as commutativity of multiplication). Or such +rules that may be implemented in the future. +""" +symtype(x) = typeof(x) @inline symtype(::Symbolic{T}) where T = T +@inline symtype(::Type{<:Symbolic{T}}) where T = T # We're returning a function pointer @inline function operation(x::BasicSymbolic) @@ -116,7 +127,7 @@ end @inline head(x::BasicSymbolic) = operation(x) -function sorted_arguments(x::BasicSymbolic) +function TermInterface.sorted_arguments(x::BasicSymbolic) args = arguments(x) @compactified x::BasicSymbolic begin Add => @goto ADD @@ -138,13 +149,11 @@ function sorted_arguments(x::BasicSymbolic) return args end -children(x::BasicSymbolic) = arguments(x) - -sorted_children(x::BasicSymbolic) = sorted_arguments(x) - @deprecate unsorted_arguments(x) arguments(x) -function arguments(x::BasicSymbolic) +TermInterface.children(x::BasicSymbolic) = arguments(x) +TermInterface.sorted_children(x::BasicSymbolic) = sorted_arguments(x) +function TermInterface.arguments(x::BasicSymbolic) @compactified x::BasicSymbolic begin Term => return x.arguments Add => @goto ADDMUL @@ -166,7 +175,7 @@ function arguments(x::BasicSymbolic) if isadd(x) for (k, v) in x.dict push!(args, applicable(*,k,v) ? k*v : - maketerm(k, *, [k, v])) + maketerm(k, *, [k, v], nothing)) end else # MUL for (k, v) in x.dict @@ -196,7 +205,14 @@ isexpr(s::BasicSymbolic) = !issym(s) iscall(s::BasicSymbolic) = isexpr(s) @inline isa_SymType(T::Val{S}, x) where {S} = x isa BasicSymbolic ? Unityper.isa_type_fun(Val(SymbolicUtils.BasicSymbolic), T, x) : false -issym(x::BasicSymbolic) = isa_SymType(Val(:Sym), x) + +""" + issym(x) + +Returns `true` if `x` is a `Sym`. If true, `nameof` must be defined +on `x` and must return a `Symbol`. +""" +issym(x) = isa_SymType(Val(:Sym), x) isterm(x) = isa_SymType(Val(:Term), x) ismul(x) = isa_SymType(Val(:Mul), x) isadd(x) = isa_SymType(Val(:Add), x) @@ -539,8 +555,22 @@ end unflatten(t) = t -function TermInterface.maketerm(::Type{<:BasicSymbolic}, head, args, type, metadata) - basicsymbolic(head, args, type, metadata) +function TermInterface.maketerm(T::Type{<:BasicSymbolic}, head, args, metadata) + st = symtype(T) + pst = _promote_symtype(head, args) + # Use promoted symtype only if not a subtype of the existing symtype of T. + # This is useful when calling `maketerm(BasicSymbolic{Number}, (==), [true, false])` + # Where the result would have a symtype of Bool. + # Please see discussion in https://github.com/JuliaSymbolics/SymbolicUtils.jl/pull/609 + # TODO this should be optimized. + new_st = if pst === Bool + pst + elseif pst === Any || (st === Number && pst <: st) + st + else + pst + end + basicsymbolic(head, args, new_st, metadata) end @@ -663,28 +693,6 @@ function to_symbolic(x) x end -""" - similarterm(x, op, args, symtype=nothing; metadata=nothing) - -""" -function similarterm(x, op, args, symtype=nothing; metadata=nothing) - Base.depwarn("""`similarterm` is deprecated, use `maketerm` instead. - `similarterm(x, op, args, symtype; metadata)` is now - `maketerm(typeof(x), op, args, symtype, metadata)`""", :similarterm) - TermInterface.maketerm(typeof(x), op, args, symtype, metadata) -end - -# Old fallback -function similarterm(T::Type, op, args, symtype=nothing; metadata=nothing) - - Base.depwarn("`similarterm` is deprecated, use `maketerm` instead." * - "See https://github.com/JuliaSymbolics/TermInterface.jl for details.", :similarterm) - op(args...) -end - -export similarterm - - ### ### Pretty printing ### diff --git a/src/utils.jl b/src/utils.jl index 69b6e8e2d..812e229fb 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -53,7 +53,7 @@ function fold(t) # evaluate it return operation(t)(tt...) else - return maketerm(typeof(t), operation(t), tt, symtype(t), metadata(t)) + return maketerm(typeof(t), operation(t), tt, metadata(t)) end else return t @@ -147,19 +147,19 @@ function flatten_term(⋆, x) push!(flattened_args, t) end end - maketerm(typeof(x), ⋆, flattened_args, symtype(x), metadata(x)) + maketerm(typeof(x), ⋆, flattened_args, metadata(x)) end function sort_args(f, t) args = arguments(t) if length(args) < 2 - return maketerm(typeof(t), f, args, symtype(t), metadata(t)) + return maketerm(typeof(t), f, args, metadata(t)) elseif length(args) == 2 x, y = args - return maketerm(typeof(t), f, x <ₑ y ? [x,y] : [y,x], symtype(t), metadata(t)) + return maketerm(typeof(t), f, x <ₑ y ? [x,y] : [y,x], metadata(t)) end args = args isa Tuple ? [args...] : args - maketerm(typeof(t), f, sort(args, lt=<ₑ), symtype(t), metadata(t)) + maketerm(typeof(t), f, sort(args, lt=<ₑ), metadata(t)) end # Linked List interface @@ -225,7 +225,7 @@ macro matchable(expr) SymbolicUtils.arguments(x::$name) = getfield.((x,), ($(QuoteNode.(fields)...),)) SymbolicUtils.children(x::$name) = [SymbolicUtils.operation(x); SymbolicUtils.children(x)] Base.length(x::$name) = $(length(fields) + 1) - SymbolicUtils.maketerm(x::$name, f, args, type, metadata) = f(args...) + SymbolicUtils.maketerm(x::$name, f, args, metadata) = f(args...) end |> esc end diff --git a/test/basics.jl b/test/basics.jl index bb2519d56..e59b008e3 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -216,23 +216,39 @@ end @testset "maketerm" begin @syms a b c - @test isequal(SymbolicUtils.maketerm(typeof(b + c), +, [a, (b+c)], Number, nothing).dict, Dict(a=>1,b=>1,c=>1)) - @test isequal(SymbolicUtils.maketerm(typeof(b^2), ^, [b^2, 1//2], Number, nothing), b) + @test isequal(SymbolicUtils.maketerm(typeof(b + c), +, [a, (b+c)], nothing).dict, Dict(a=>1,b=>1,c=>1)) + @test isequal(SymbolicUtils.maketerm(typeof(b^2), ^, [b^2, 1//2], nothing), b) # test that maketerm doesn't hard-code BasicSymbolic subtype # and is consistent with BasicSymbolic arithmetic operations - @test isequal(SymbolicUtils.maketerm(typeof(a / b), *, [a / b, c], Number, nothing), (a / b) * c) - @test isequal(SymbolicUtils.maketerm(typeof(a * b), *, [0, c], Number, nothing), 0) - @test isequal(SymbolicUtils.maketerm(typeof(a^b), ^, [a * b, 3], Number, nothing), (a * b)^3) + @test isequal(SymbolicUtils.maketerm(typeof(a / b), *, [a / b, c], nothing), (a / b) * c) + @test isequal(SymbolicUtils.maketerm(typeof(a * b), *, [0, c], nothing), 0) + @test isequal(SymbolicUtils.maketerm(typeof(a^b), ^, [a * b, 3], nothing), (a * b)^3) # test that maketerm sets metadata correctly metadata = Base.ImmutableDict{DataType, Any}(Ctx1, "meta_1") - s = SymbolicUtils.maketerm(typeof(a^b), ^, [a * b, 3], Number, metadata) + s = SymbolicUtils.maketerm(typeof(a^b), ^, [a * b, 3], metadata) @test !hasmetadata(s, Ctx1) - s = SymbolicUtils.maketerm(typeof(a^b), *, [a * b, 3], Number, metadata) + s = SymbolicUtils.maketerm(typeof(a^b), *, [a * b, 3], metadata) @test hasmetadata(s, Ctx1) @test getmetadata(s, Ctx1) == "meta_1" + + # Correct symtype propagation + ref_expr = a * b + @test symtype(ref_expr) == Number + new_expr = SymbolicUtils.maketerm(typeof(ref_expr), (==), [a, b], nothing) + @test symtype(new_expr) == Bool + + # Doesn't know return type, promoted symtype is Any + foo(x,y) = x^2 + x + new_expr = SymbolicUtils.maketerm(typeof(ref_expr), foo, [a, b], nothing) + @test symtype(new_expr) == Number + + # Promoted symtype is a subtype of referred + @syms x::Int y::Int + new_expr = SymbolicUtils.maketerm(typeof(ref_expr), (+), [x, y], nothing) + @test symtype(new_expr) == Int64 end toterm(t) = Term{symtype(t)}(operation(t), arguments(t))