From 4b3ae15c6a8c1c8a007f4842ea90b5cf845bbe6f Mon Sep 17 00:00:00 2001 From: Shashi Gowda Date: Fri, 17 May 2024 09:05:11 -0400 Subject: [PATCH] Revert "Revert "Merge pull request #584 from JuliaSymbolics/ale/terminterface-new"" This reverts commit 0b75a322dd686c18398b42904f6571fb73183868. --- Project.toml | 4 +- docs/src/api.md | 9 ---- docs/src/index.md | 11 +---- src/SymbolicUtils.jl | 15 ++++--- src/code.jl | 16 ++++---- src/inspect.jl | 6 +-- src/interface.jl | 10 ++--- src/matchers.jl | 4 +- src/ordering.jl | 4 +- src/polyform.jl | 24 ++++++----- src/rewriters.jl | 7 ++-- src/rule.jl | 4 +- src/simplify.jl | 2 +- src/simplify_rules.jl | 10 ++--- src/substitute.jl | 4 +- src/types.jl | 96 ++++++++++++++++++++++++++----------------- src/utils.jl | 24 ++++++----- test/interface.jl | 2 +- test/runtests.jl | 3 +- 19 files changed, 138 insertions(+), 117 deletions(-) diff --git a/Project.toml b/Project.toml index f51e10333..4bc2bb891 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SymbolicUtils" uuid = "d1185830-fcd6-423d-90d6-eec64667417b" authors = ["Shashi Gowda"] -version = "1.7.0" +version = "2.0.0" [deps] AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c" @@ -22,6 +22,7 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" +TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" Unityper = "a7c27f48-0311-42f6-a7f8-2c11e75eb415" @@ -42,6 +43,7 @@ Setfield = "0.7, 0.8, 1" SpecialFunctions = "0.10, 1.0, 2" StaticArrays = "0.12, 1.0" SymbolicIndexingInterface = "0.3" +TermInterface = "0.4" TimerOutputs = "0.5" Unityper = "0.1.2" julia = "1.3" diff --git a/docs/src/api.md b/docs/src/api.md index eb93f3ee0..45266df82 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -12,15 +12,6 @@ SymbolicUtils.Pow SymbolicUtils.promote_symtype ``` -## Interfacing - -```@docs -SymbolicUtils.istree -SymbolicUtils.operation -SymbolicUtils.arguments -SymbolicUtils.similarterm -``` - ## Rewriters ```@docs diff --git a/docs/src/index.md b/docs/src/index.md index 4d48f72fc..85e2dac40 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -108,15 +108,8 @@ g(2//5, g(1, β)) Symbolic expressions are of type `Term{T}`, `Add{T}`, `Mul{T}`, `Pow{T}` or `Div{T}` and denote some function call where one or more arguments are themselves such expressions or `Sym`s. See more about the representation [here](/representation/). -All the expression types support the following: - -- `istree(x)` -- always returns `true` denoting, `x` is not a leaf node like Sym or a literal. -- `operation(x)` -- the function being called -- `arguments(x)` -- a vector of arguments -- `symtype(x)` -- the "inferred" type (`T`) - -See more on the interface [here](/interface) - +All the expression types support the [TermInterface.jl](https://github.com/0x0f0f0f/TermInterface.jl) interface. +Please refer to the package for the complete reference of the interface. ## Term rewriting diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl index d7e6a6a80..286ba3bc1 100644 --- a/src/SymbolicUtils.jl +++ b/src/SymbolicUtils.jl @@ -4,21 +4,26 @@ $(DocStringExtensions.README) module SymbolicUtils using DocStringExtensions + export @syms, term, showraw, hasmetadata, getmetadata, setmetadata using Unityper - -# Sym, Term, -# Add, Mul and Pow +using TermInterface using DataStructures using Setfield import Setfield: PropertyLens using SymbolicIndexingInterface import Base: +, -, *, /, //, \, ^, ImmutableDict using ConstructionBase -include("interface.jl") +using TermInterface +import TermInterface: iscall, isexpr, issym, symtype, head, children, + operation, arguments, metadata, maketerm + +Base.@deprecate_binding istree iscall +export istree, operation, arguments, unsorted_arguments, similarterm +# Sym, Term, +# Add, Mul and Pow include("types.jl") -export istree, operation, arguments, similarterm # Methods on symbolic objects using SpecialFunctions, NaNMath diff --git a/src/code.jl b/src/code.jl index 9c4e0e24a..208fc91ea 100644 --- a/src/code.jl +++ b/src/code.jl @@ -8,7 +8,7 @@ export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr, import ..SymbolicUtils import ..SymbolicUtils.Rewriters -import SymbolicUtils: @matchable, BasicSymbolic, Sym, Term, istree, operation, arguments, issym, +import SymbolicUtils: @matchable, BasicSymbolic, Sym, Term, iscall, operation, arguments, issym, symtype, similarterm, unsorted_arguments, metadata, isterm, term ##== state management ==## @@ -162,7 +162,7 @@ end toexpr(O::Expr, st) = O function substitute_name(O, st) - if (issym(O) || istree(O)) && haskey(st.rewrites, O) + if (issym(O) || iscall(O)) && haskey(st.rewrites, O) st.rewrites[O] else O @@ -176,13 +176,13 @@ function toexpr(O, st) end O = substitute_name(O, st) - !istree(O) && return O + !iscall(O) && return O op = operation(O) expr′ = function_to_expr(op, O, st) if expr′ !== nothing return expr′ else - !istree(O) && return O + !iscall(O) && return O args = arguments(O) return Expr(:call, toexpr(op, st), map(x->toexpr(x, st), args)...) end @@ -221,7 +221,7 @@ get_rewrites(args::DestructuredArgs) = () function get_rewrites(args::Union{AbstractArray, Tuple}) cflatten(map(get_rewrites, args)) end -get_rewrites(x) = istree(x) ? (x,) : () +get_rewrites(x) = iscall(x) ? (x,) : () cflatten(x) = Iterators.flatten(x) |> collect # Used in Symbolics @@ -691,7 +691,7 @@ end @inline newsym(::Type{T}) where T = Sym{T}(gensym("cse")) function _cse!(mem, expr) - istree(expr) || return expr + iscall(expr) || return expr op = _cse!(mem, operation(expr)) args = map(Base.Fix1(_cse!, mem), arguments(expr)) t = similarterm(expr, op, args) @@ -742,7 +742,7 @@ end function cse_state!(state, t) - !istree(t) && return t + !iscall(t) && return t state[t] = Base.get(state, t, 0) + 1 foreach(x->cse_state!(state, x), unsorted_arguments(t)) end @@ -758,7 +758,7 @@ function cse_block!(assignments, counter, names, name, state, x) counter[] += 1 return sym end - elseif istree(x) + elseif iscall(x) args = map(a->cse_block!(assignments, counter, names, name, state,a), unsorted_arguments(x)) if isterm(x) return term(operation(x), args...) diff --git a/src/inspect.jl b/src/inspect.jl index f62551893..42b0b1be5 100644 --- a/src/inspect.jl +++ b/src/inspect.jl @@ -2,11 +2,11 @@ import AbstractTrees const inspect_metadata = Ref{Bool}(false) function AbstractTrees.nodevalue(x::Symbolic) - istree(x) ? operation(x) : x + iscall(x) ? operation(x) : isexpr(x) ? head(x) : x end function AbstractTrees.nodevalue(x::BasicSymbolic) - str = if !istree(x) + str = if !iscall(x) string(exprtype(x), "(", x, ")") elseif isadd(x) string(exprtype(x), @@ -27,7 +27,7 @@ function AbstractTrees.nodevalue(x::BasicSymbolic) end function AbstractTrees.children(x::Symbolic) - istree(x) ? arguments(x) : () + iscall(x) ? arguments(x) : isexpr(x) ? children(x) : () end """ diff --git a/src/interface.jl b/src/interface.jl index 255ef584f..687a802a8 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -1,10 +1,10 @@ """ - istree(x) + iscall(x) Returns `true` if `x` is a term. If true, `operation`, `arguments` must also be defined for `x` appropriately. """ -istree(x) = false +iscall(x) = false """ symtype(x) @@ -29,7 +29,7 @@ issym(x) = false """ operation(x) -If `x` is a term as defined by `istree(x)`, `operation(x)` returns the +If `x` is a term as defined by `iscall(x)`, `operation(x)` returns the head of the term if `x` represents a function call, for example, the head is the function being called. """ @@ -38,14 +38,14 @@ function operation end """ arguments(x) -Get the arguments of `x`, must be defined if `istree(x)` is `true`. +Get the arguments of `x`, must be defined if `iscall(x)` is `true`. """ function arguments end """ unsorted_arguments(x::T) -If x is a term satisfying `istree(x)` and your term type `T` provides +If x is a term satisfying `iscall(x)` and your term type `T` provides an optimized implementation for storing the arguments, this function can be used to retrieve the arguments when the order of arguments does not matter but the speed of the operation does. diff --git a/src/matchers.jl b/src/matchers.jl index 531bc1535..7f4dea537 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -6,7 +6,7 @@ # 3. Callback: takes arguments Dictionary × Number of elements matched # function matcher(val::Any) - istree(val) && return term_matcher(val) + iscall(val) && return term_matcher(val) function literal_matcher(next, data, bindings) islist(data) && isequal(car(data), val) ? next(bindings, 1) : nothing end @@ -89,7 +89,7 @@ function term_matcher(term) function term_matcher(success, data, bindings) !islist(data) && return nothing - !istree(car(data)) && return nothing + !iscall(car(data)) && return nothing function loop(term, bindings′, matchers′) # Get it to compile faster if !islist(matchers′) diff --git a/src/ordering.jl b/src/ordering.jl index 81d64a9b4..3417f3f85 100644 --- a/src/ordering.jl +++ b/src/ordering.jl @@ -20,7 +20,7 @@ function get_degrees(expr) if issym(expr) ((Symbol(expr),) => 1,) - elseif istree(expr) + elseif iscall(expr) op = operation(expr) args = arguments(expr) if operation(expr) == (^) && args[2] isa Number @@ -62,7 +62,7 @@ function lexlt(degs1, degs2) return false # they are equal end -_arglen(a) = istree(a) ? length(unsorted_arguments(a)) : 0 +_arglen(a) = iscall(a) ? length(unsorted_arguments(a)) : 0 function <ₑ(a::Tuple, b::Tuple) for (x, y) in zip(a, b) diff --git a/src/polyform.jl b/src/polyform.jl index 873e332a8..21e04ac9b 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -6,7 +6,7 @@ using Bijections Abstracts a [MultivariatePolynomials.jl](https://juliaalgebra.github.io/MultivariatePolynomials.jl/stable/) as a SymbolicUtils expression and vice-versa. -The SymbolicUtils term interface (`istree`, `operation, and `arguments`) works on PolyForm lazily: +The SymbolicUtils term interface (`isexpr`/`iscall`, `operation, and `arguments`) works on PolyForm lazily: the `operation` and `arguments` are created by converting one level of arguments into SymbolicUtils expressions. They may further contain PolyForm within them. We use this to hold polynomials in memory while doing `simplify_fractions`. @@ -97,7 +97,7 @@ _isone(p::PolyForm) = isone(p.p) function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse) if x isa Number return x - elseif istree(x) + elseif iscall(x) if !(symtype(x) <: Number) error("Cannot convert $x of symtype $(symtype(x)) into a PolyForm") end @@ -170,8 +170,10 @@ function PolyForm(x, PolyForm{symtype(x)}(p, pvar2sym, sym2term, metadata) end -istree(x::Type{<:PolyForm}) = true -istree(x::PolyForm) = true +isexpr(x::Type{<:PolyForm}) = true +isexpr(x::PolyForm) = true +iscall(x::Type{<:PolyForm}) = true +iscall(x::PolyForm) = true function similarterm(t::PolyForm, f, args, symtype; metadata=nothing) basic_similarterm(t, f, args, symtype; metadata=metadata) @@ -181,6 +183,7 @@ function similarterm(::PolyForm, f::Union{typeof(*), typeof(+), typeof(^)}, f(args...) end +head(::PolyForm) = PolyForm operation(x::PolyForm) = MP.nterms(x.p) == 1 ? (*) : (+) function arguments(x::PolyForm{T}) where {T} @@ -227,6 +230,7 @@ function arguments(x::PolyForm{T}) where {T} PolyForm{T}(t, x.pvar2sym, x.sym2term, nothing)) for t in ts] end end +children(x::PolyForm) = [operation(x); arguments(x)] Base.show(io::IO, x::PolyForm) = show_term(io, x) @@ -336,7 +340,7 @@ function simplify_fractions(x; polyform=false) end function add_with_div(x, flatten=true) - (!istree(x) || operation(x) != (+)) && return x + (!iscall(x) || operation(x) != (+)) && return x aa = unsorted_arguments(x) !any(a->isdiv(a), aa) && return x # no rewrite necessary @@ -361,7 +365,7 @@ function flatten_fractions(x) end function fraction_iszero(x) - !istree(x) && return _iszero(x) + !iscall(x) && return _iszero(x) ff = flatten_fractions(x) # fast path and then slow path any(_iszero, numerators(ff)) || @@ -369,18 +373,18 @@ function fraction_iszero(x) end function fraction_isone(x) - !istree(x) && return _isone(x) + !iscall(x) && return _isone(x) _isone(simplify_fractions(flatten_fractions(x))) end function needs_div_rules(x) (isdiv(x) && !(x.num isa Number) && !(x.den isa Number)) || - (istree(x) && operation(x) === (+) && count(has_div, unsorted_arguments(x)) > 1) || - (istree(x) && any(needs_div_rules, unsorted_arguments(x))) + (iscall(x) && operation(x) === (+) && count(has_div, unsorted_arguments(x)) > 1) || + (iscall(x) && any(needs_div_rules, unsorted_arguments(x))) end function has_div(x) - return isdiv(x) || (istree(x) && any(has_div, unsorted_arguments(x))) + return isdiv(x) || (iscall(x) && any(has_div, unsorted_arguments(x))) end flatten_pows(xs) = map(xs) do x diff --git a/src/rewriters.jl b/src/rewriters.jl index fb6ab3a08..81ae2dbe0 100644 --- a/src/rewriters.jl +++ b/src/rewriters.jl @@ -31,6 +31,7 @@ rewriters. """ module Rewriters using SymbolicUtils: @timer +using TermInterface import SymbolicUtils: similarterm, istree, operation, arguments, unsorted_arguments, metadata, node_count export Empty, IfElse, If, Chain, RestartedChain, Fixpoint, Postwalk, Prewalk, PassThrough @@ -196,7 +197,7 @@ instrument(x::PassThrough, f) = PassThrough(instrument(x.rw, f)) passthrough(x, default) = x === nothing ? default : x function (p::Walk{ord, C, F, false})(x) where {ord, C, F} @assert ord === :pre || ord === :post - if istree(x) + if iscall(x) if ord === :pre x = p.rw(x) end @@ -214,11 +215,11 @@ end function (p::Walk{ord, C, F, true})(x) where {ord, C, F} @assert ord === :pre || ord === :post - if istree(x) + if iscall(x) if ord === :pre x = p.rw(x) end - if istree(x) + if iscall(x) _args = map(arguments(x)) do arg if node_count(arg) > p.thread_cutoff Threads.@spawn p(arg) diff --git a/src/rule.jl b/src/rule.jl index 704db596c..05941b764 100644 --- a/src/rule.jl +++ b/src/rule.jl @@ -120,7 +120,7 @@ end getdepth(r::Rule) = r.depth function rule_depth(rule, d=0, maxdepth=0) - if istree(rule) + if iscall(rule) maxdepth = reduce(max, (rule_depth(r, d+1, maxdepth) for r in arguments(rule)), init=1) elseif rule isa Slot || rule isa Segment maxdepth = max(d, maxdepth) @@ -389,7 +389,7 @@ Base.show(io::IO, acr::ACRule) = print(io, "ACRule(", acr.rule, ")") function (acr::ACRule)(term) r = Rule(acr) - if !istree(term) + if !iscall(term) r(term) else f = operation(term) diff --git a/src/simplify.jl b/src/simplify.jl index 0c0bfd44a..87bc95954 100644 --- a/src/simplify.jl +++ b/src/simplify.jl @@ -43,7 +43,7 @@ function simplify(x; SymbolicUtils.simplify_fractions(x) : x end -has_operation(x, op) = (istree(x) && (operation(x) == op || +has_operation(x, op) = (iscall(x) && (operation(x) == op || any(a->has_operation(a, op), unsorted_arguments(x)))) diff --git a/src/simplify_rules.jl b/src/simplify_rules.jl index b12a9f272..a612036cb 100644 --- a/src/simplify_rules.jl +++ b/src/simplify_rules.jl @@ -2,9 +2,9 @@ using .Rewriters """ is_operation(f) Returns a single argument anonymous function predicate, that returns `true` if and only if -the argument to the predicate satisfies `istree` and `operation(x) == f` +the argument to the predicate satisfies `iscall` and `operation(x) == f` """ -is_operation(f) = @nospecialize(x) -> istree(x) && (operation(x) == f) +is_operation(f) = @nospecialize(x) -> iscall(x) && (operation(x) == f) let CANONICALIZE_PLUS = [ @@ -132,7 +132,7 @@ let ] function number_simplifier() - rule_tree = [If(istree, Chain(ASSORTED_RULES)), + rule_tree = [If(iscall, Chain(ASSORTED_RULES)), If(x -> !isadd(x) && is_operation(+)(x), Chain(CANONICALIZE_PLUS)), If(is_operation(+), Chain(PLUS_DISTRIBUTE)), # This would be useful even if isadd @@ -173,12 +173,12 @@ let end # reduce overhead of simplify by defining these as constant - serial_simplifier = If(istree, Fixpoint(default_simplifier())) + serial_simplifier = If(iscall, Fixpoint(default_simplifier())) threaded_simplifier(cutoff) = Fixpoint(default_simplifier(threaded=true, thread_cutoff=cutoff)) - serial_expand_simplifier = If(istree, + serial_expand_simplifier = If(iscall, Fixpoint(Chain((expand, Fixpoint(default_simplifier()))))) diff --git a/src/substitute.jl b/src/substitute.jl index 9a5213f0b..73ea7659d 100644 --- a/src/substitute.jl +++ b/src/substitute.jl @@ -16,7 +16,7 @@ julia> substitute(1+sqrt(y), Dict(y => 2), fold=false) function substitute(expr, dict; fold=true) haskey(dict, expr) && return dict[expr] - if istree(expr) + if iscall(expr) op = substitute(operation(expr), dict; fold=fold) if fold canfold = !(op isa Symbolic) @@ -53,7 +53,7 @@ Base.occursin(needle::Symbolic, haystack) = _occursin(needle, haystack) function _occursin(needle, haystack) isequal(needle, haystack) && return true - if istree(haystack) + if iscall(haystack) args = arguments(haystack) for arg in args occursin(needle, arg) && return true diff --git a/src/types.jl b/src/types.jl index 2202a8df1..f9ea33121 100644 --- a/src/types.jl +++ b/src/types.jl @@ -114,6 +114,8 @@ symtype(x::Number) = typeof(x) end end +@inline head(x::BasicSymbolic) = BasicSymbolic + function arguments(x::BasicSymbolic) args = unsorted_arguments(x) @compactified x::BasicSymbolic begin @@ -135,6 +137,9 @@ function arguments(x::BasicSymbolic) end return args end + +unsorted_arguments(x) = arguments(x) +children(x::BasicSymbolic) = [operation(x); arguments(x)] function unsorted_arguments(x::BasicSymbolic) @compactified x::BasicSymbolic begin Term => return x.arguments @@ -157,7 +162,7 @@ function unsorted_arguments(x::BasicSymbolic) if isadd(x) for (k, v) in x.dict push!(args, applicable(*,k,v) ? k*v : - similarterm(k, *, [k, v])) + maketerm(k, *, [k, v])) end else # MUL for (k, v) in x.dict @@ -183,7 +188,9 @@ function unsorted_arguments(x::BasicSymbolic) return args end -istree(s::BasicSymbolic) = !issym(s) +isexpr(s::BasicSymbolic) = !issym(s) +iscall(s::BasicSymbolic) = isexpr(s) + @inline isa_SymType(T::Val{S}, x) where {S} = x isa BasicSymbolic ? Unityper.isa_type_fun(Val(SymbolicUtils.BasicSymbolic), T, x) : false issym(x::BasicSymbolic) = isa_SymType(Val(:Sym), x) isterm(x) = isa_SymType(Val(:Term), x) @@ -400,7 +407,7 @@ end @inline function numerators(x) isdiv(x) && return numerators(x.num) - istree(x) && operation(x) === (*) ? arguments(x) : Any[x] + iscall(x) && operation(x) === (*) ? arguments(x) : Any[x] end @inline denominators(x) = isdiv(x) ? numerators(x.den) : Any[1] @@ -516,7 +523,7 @@ end Binarizes `Term`s with n-ary operations """ function unflatten(t::Symbolic{T}) where{T} - if istree(t) + if iscall(t) f = operation(t) if f == (+) || f == (*) # TODO check out for other n-ary --> binary ops a = arguments(t) @@ -528,28 +535,12 @@ end unflatten(t) = t -""" - similarterm(t, f, args, symtype; metadata=nothing) - -Create a term that is similar in type to `t`. Extending this function allows packages -using their own expression types with SymbolicUtils to define how new terms should -be created. Note that `similarterm` may return an object that has a -different type than `t`, because `f` also influences the result. - -## Arguments +function TermInterface.maketerm(::Type{<:BasicSymbolic}, head, args, type, metadata) + basicsymbolic(first(args), args[2:end], type, metadata) +end -- `t` the reference term to use to create similar terms -- `f` is the operation of the term -- `args` is the arguments -- The `symtype` of the resulting term. Best effort will be made to set the symtype of the - resulting similar term to this type. -""" -similarterm(t::Symbolic, f, args; metadata=nothing) = - similarterm(t, f, args, _promote_symtype(f, args); metadata=metadata) -similarterm(t::BasicSymbolic, f, args, - symtype; metadata=nothing) = basic_similarterm(t, f, args, symtype; metadata=metadata) -function basic_similarterm(t, f, args, stype; metadata=nothing) +function basicsymbolic(f, args, stype, metadata) if f isa Symbol error("$f must not be a Symbol") end @@ -559,7 +550,7 @@ function basic_similarterm(t, f, args, stype; metadata=nothing) end if T <: LiteralReal Term{T}(f, args, metadata=metadata) - elseif stype <: Number && (f in (+, *) || (f in (/, ^) && length(args) == 2)) && all(x->symtype(x) <: Number, args) + elseif T <: Number && (f in (+, *) || (f in (/, ^) && length(args) == 2)) && all(x->symtype(x) <: Number, args) res = f(args...) if res isa Symbolic @set! res.metadata = metadata @@ -580,16 +571,17 @@ function hasmetadata(s::Symbolic, ctx) metadata(s) isa AbstractDict && haskey(metadata(s), ctx) end -function issafecanon(f, s) +issafecanon(f, s) = true +function issafecanon(f, s::Symbolic) if isnothing(metadata(s)) || issym(s) return true else _issafecanon(f, s) end end -_issafecanon(::typeof(*), s) = !istree(s) || !(operation(s) in (+,*,^)) -_issafecanon(::typeof(+), s) = !istree(s) || !(operation(s) in (+,*)) -_issafecanon(::typeof(^), s) = !istree(s) || !(operation(s) in (*, ^)) +_issafecanon(::typeof(*), s) = !iscall(s) || !(operation(s) in (+,*,^)) +_issafecanon(::typeof(+), s) = !iscall(s) || !(operation(s) in (+,*)) +_issafecanon(::typeof(^), s) = !iscall(s) || !(operation(s) in (*, ^)) issafecanon(f, ss...) = all(x->issafecanon(f, x), ss) @@ -641,12 +633,42 @@ end function to_symbolic(x) Base.depwarn("`to_symbolic(x)` is deprecated, define the interface for your " * - "symbolic structure using `istree(x)`, `operation(x)`, `arguments(x)` " * + "symbolic structure using `iscall(x)`, `operation(x)`, `arguments(x)` " * "and `similarterm(::YourType, f, args, symtype)`", :to_symbolic, force=true) x end +""" + similarterm(x, op, args, symtype=nothing; metadata=nothing) + +""" +function similarterm(x, op, args, symtype=nothing; metadata=nothing) + Base.depwarn("""`similarterm` is deprecated, use `maketerm` instead. + See https://github.com/JuliaSymbolics/TermInterface.jl for details. + The present call can be replaced by + `maketerm(typeof(x), $(head(x)), [op, args...], symtype, metadata)`""", :similarterm) + + TermInterface.maketerm(typeof(x), callhead(x), [op, args...], symtype, metadata) +end + +# Old fallback +function similarterm(T::Type, op, args, symtype=nothing; metadata=nothing) + Base.depwarn("`similarterm` is deprecated, use `maketerm` instead." * + "See https://github.com/JuliaSymbolics/TermInterface.jl for details.", :similarterm) + op(args...) +end + +export similarterm + + +""" + callhead(x) +Used in this deprecation cycle of `similarterm` to find the `head` argument to +`maketerm`. Do not implement this, or use `similarterm` if you're using this package. +""" +callhead(x) = typeof(x) + ### ### Pretty printing ### @@ -654,7 +676,7 @@ const show_simplified = Ref(false) isnegative(t::Real) = t < 0 function isnegative(t) - if istree(t) && operation(t) === (*) + if iscall(t) && operation(t) === (*) coeff = first(arguments(t)) return isnegative(coeff) end @@ -666,7 +688,7 @@ setargs(t, args) = Term{symtype(t)}(operation(t), args) cdrargs(args) = setargs(t, cdr(args)) print_arg(io, x::Union{Complex, Rational}; paren=true) = print(io, "(", x, ")") -isbinop(f) = istree(f) && !istree(operation(f)) && Base.isbinaryoperator(nameof(operation(f))) +isbinop(f) = iscall(f) && !iscall(operation(f)) && Base.isbinaryoperator(nameof(operation(f))) function print_arg(io, x; paren=false) if paren && isbinop(x) print(io, "(", x, ")") @@ -685,7 +707,7 @@ function print_arg(io, f, x) end function remove_minus(t) - !istree(t) && return -t + !iscall(t) && return -t @assert operation(t) == (*) args = arguments(t) @assert args[1] < 0 @@ -756,9 +778,9 @@ function show_ref(io, f, args) x = args[1] idx = args[2:end] - istree(x) && print(io, "(") + iscall(x) && print(io, "(") print(io, x) - istree(x) && print(io, ")") + iscall(x) && print(io, ")") print(io, "[") for i=1:length(idx) print_arg(io, idx[i]) @@ -768,7 +790,7 @@ function show_ref(io, f, args) end function show_call(io, f, args) - fname = istree(f) ? Symbol(repr(f)) : nameof(f) + fname = iscall(f) ? Symbol(repr(f)) : nameof(f) len_args = length(args) if Base.isunaryoperator(fname) && len_args == 1 print(io, "$fname") @@ -810,7 +832,7 @@ function show_term(io::IO, t) show_pow(io, args) elseif f === (getindex) show_ref(io, f, args) - elseif f === (identity) && !issym(args[1]) && !istree(args[1]) + elseif f === (identity) && !issym(args[1]) && !iscall(args[1]) show(io, args[1]) else show_call(io, f, args) diff --git a/src/utils.jl b/src/utils.jl index acf9e92d6..90d7c407c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -35,7 +35,7 @@ pow(x::Symbolic,y::Symbolic) = Base.:^(x,y) # Simplification utilities function has_trig_exp(term) - !istree(term) && return false + !iscall(term) && return false fns = (sin, cos, tan, cot, sec, csc, exp, cosh, sinh) op = operation(term) @@ -47,7 +47,7 @@ function has_trig_exp(term) end function fold(t) - if istree(t) + if iscall(t) tt = map(fold, arguments(t)) if !any(x->x isa Symbolic, tt) # evaluate it @@ -81,7 +81,7 @@ function isnotflat(⋆) function (x) args = arguments(x) for t in args - if istree(t) && operation(t) === (⋆) + if iscall(t) && operation(t) === (⋆) return true end end @@ -141,7 +141,7 @@ function flatten_term(⋆, x) # flatten nested ⋆ flattened_args = [] for t in args - if istree(t) && operation(t) === (⋆) + if iscall(t) && operation(t) === (⋆) append!(flattened_args, arguments(t)) else push!(flattened_args, t) @@ -170,7 +170,7 @@ struct LL{V} i::Int end -islist(x) = istree(x) || !isempty(x) +islist(x) = iscall(x) || !isempty(x) Base.empty(l::LL) = empty(l.v) Base.isempty(l::LL) = l.i > length(l.v) @@ -184,9 +184,9 @@ Base.isempty(t::Term) = false @inline car(t::Term) = operation(t) @inline cdr(t::Term) = arguments(t) -@inline car(v) = istree(v) ? operation(v) : first(v) +@inline car(v) = iscall(v) ? operation(v) : first(v) @inline function cdr(v) - if istree(v) + if iscall(v) arguments(v) else islist(v) ? LL(v, 2) : error("asked cdr of empty") @@ -200,7 +200,7 @@ end if n === 0 return ll else - istree(ll) ? drop_n(arguments(ll), n-1) : drop_n(cdr(ll), n-1) + iscall(ll) ? drop_n(arguments(ll), n-1) : drop_n(cdr(ll), n-1) end end @inline drop_n(ll::Union{Tuple, AbstractArray}, n) = drop_n(LL(ll, 1), n) @@ -218,10 +218,12 @@ macro matchable(expr) get_name(e::Expr) = (@assert(e.head == :(::)); e.args[1]) fields = map(get_name, fields) quote + # TODO: fix this to be not a call. Make pattern matcher work for these $expr - SymbolicUtils.istree(::$name) = true + SymbolicUtils.head(::$name) = $name SymbolicUtils.operation(::$name) = $name SymbolicUtils.arguments(x::$name) = getfield.((x,), ($(QuoteNode.(fields)...),)) + SymbolicUtils.children(x::$name) = [SymbolicUtils.operation(x); SymbolicUtils.children(x)] Base.length(x::$name) = $(length(fields) + 1) SymbolicUtils.similarterm(x::$name, f, args, type; kw...) = f(args...) end |> esc @@ -229,7 +231,7 @@ end """ node_count(t) -Count the nodes in a symbolic expression tree satisfying `istree` and `arguments`. +Count the nodes in a symbolic expression tree satisfying `iscall` and `arguments`. """ -node_count(t) = istree(t) ? reduce(+, node_count(x) for x in arguments(t), init = 0) + 1 : 1 +node_count(t) = iscall(t) ? reduce(+, node_count(x) for x in arguments(t), init = 0) + 1 : 1 diff --git a/test/interface.jl b/test/interface.jl index d98d97328..af83fc89f 100644 --- a/test/interface.jl +++ b/test/interface.jl @@ -1,5 +1,5 @@ using SymbolicUtils, Test -import SymbolicUtils: istree, issym, operation, arguments, symtype +import SymbolicUtils: iscall, issym, operation, arguments, symtype issym(s::Symbol) = true Base.nameof(s::Symbol) = s diff --git a/test/runtests.jl b/test/runtests.jl index 004b26b7d..3098331ab 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -33,6 +33,7 @@ else include("code.jl") include("cse.jl") include("interface.jl") - include("fuzz.jl") + # Disabled until https://github.com/JuliaMath/SpecialFunctions.jl/issues/446 is fixed + # include("fuzz.jl") include("adjoints.jl") end