Skip to content

Commit

Permalink
Merge pull request #609 from JuliaSymbolics/ale/terminterface1
Browse files Browse the repository at this point in the history
TermInterface Version 2
  • Loading branch information
ChrisRackauckas authored Jul 27, 2024
2 parents 7a057d6 + b11fbec commit da9267e
Show file tree
Hide file tree
Showing 14 changed files with 104 additions and 184 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SymbolicUtils"
uuid = "d1185830-fcd6-423d-90d6-eec64667417b"
authors = ["Shashi Gowda"]
version = "2.1.2"
version = "3.0.0"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
Expand Down Expand Up @@ -43,7 +43,7 @@ Setfield = "0.7, 0.8, 1"
SpecialFunctions = "0.10, 1.0, 2"
StaticArrays = "0.12, 1.0"
SymbolicIndexingInterface = "0.3"
TermInterface = "0.4"
TermInterface = "2.0"
TimerOutputs = "0.5"
Unityper = "0.1.2"
julia = "1.3"
Expand Down
8 changes: 5 additions & 3 deletions docs/src/manual/interface.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ You can read the documentation of [TermInterface.jl](https://github.com/JuliaSym

## SymbolicUtils.jl only methods

`promote_symtype(f, arg_symtypes...)`

Returns the appropriate output type of applying `f` on arguments of type `arg_symtypes`.
```@docs
symtype
issym
promote_symtype
```
2 changes: 1 addition & 1 deletion docs/src/manual/representation.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Performance of symbolic simplification depends on the datastructures used to rep

The most basic term representation simply holds a function call and stores the function and the arguments it is called with. This is done by the `Term` type in SymbolicUtils. Functions that aren't commutative or associative, such as `sin` or `hypot` are stored as `Term`s. Commutative and associative operations like `+`, `*`, and their supporting operations like `-`, `/` and `^`, when used on terms of type `<:Number`, stand to gain from the use of more efficient datastrucutres.

All term representations must support `operation` and `arguments` functions. And they must define `istree` to return `true` when called with an instance of the type. Generic term-manipulation programs such as the rule-based rewriter make use of this interface to inspect expressions. In this way, the interface wins back the generality lost by having a zoo of term representations instead of one. (see [interface](/interface/) section for more on this.)
All term representations must support `operation` and `arguments` functions. And they must define `iscall` and `isexpr` to return `true` when called with an instance of the type. Generic term-manipulation programs such as the rule-based rewriter make use of this interface to inspect expressions. In this way, the interface wins back the generality lost by having a zoo of term representations instead of one. (see [interface](/interface/) section for more on this.)


### Preliminary representation of arithmetic
Expand Down
4 changes: 2 additions & 2 deletions src/SymbolicUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ using SymbolicIndexingInterface
import Base: +, -, *, /, //, \, ^, ImmutableDict
using ConstructionBase
using TermInterface
import TermInterface: iscall, isexpr, issym, symtype, head, children,
operation, arguments, metadata, maketerm
import TermInterface: iscall, isexpr, head, children,
operation, arguments, metadata, maketerm, sorted_arguments

Base.@deprecate istree iscall
export istree, operation, arguments, sorted_arguments, similarterm, iscall
Expand Down
8 changes: 3 additions & 5 deletions src/code.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr,
import ..SymbolicUtils
import ..SymbolicUtils.Rewriters
import SymbolicUtils: @matchable, BasicSymbolic, Sym, Term, iscall, operation, arguments, issym,
symtype, similarterm, sorted_arguments, metadata, isterm, term, maketerm
symtype, sorted_arguments, metadata, isterm, term, maketerm

##== state management ==##

Expand Down Expand Up @@ -694,7 +694,7 @@ function _cse!(mem, expr)
iscall(expr) || return expr
op = _cse!(mem, operation(expr))
args = map(Base.Fix1(_cse!, mem), arguments(expr))
t = similarterm(expr, op, args)
t = maketerm(typeof(expr), op, args, nothing)

v, dict = mem
update! = let v=v, t=t
Expand Down Expand Up @@ -763,9 +763,7 @@ function cse_block!(assignments, counter, names, name, state, x)
if isterm(x)
return term(operation(x), args...)
else
return maketerm(typeof(x), operation(x),
args, symtype(x),
metadata(x))
return maketerm(typeof(x), operation(x), args, metadata(x))
end
else
return x
Expand Down
84 changes: 0 additions & 84 deletions src/interface.jl

This file was deleted.

23 changes: 10 additions & 13 deletions src/polyform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,6 @@ function polyize(x, pvar2sym, sym2term, vtype, pow, Fs, recurse)
maketerm(typeof(x),
op,
map(a->PolyForm(a, pvar2sym, sym2term, vtype; Fs, recurse), args),
symtype(x),
metadata(x))
else
x
Expand Down Expand Up @@ -176,18 +175,18 @@ isexpr(x::PolyForm) = true
iscall(x::Type{<:PolyForm}) = true
iscall(x::PolyForm) = true

function maketerm(::Type{<:PolyForm}, f, args, symtype, metadata)
basicsymbolic(t, f, args, symtype, metadata)
function maketerm(t::Type{<:PolyForm}, f, args, metadata)
# TODO: this looks uncovered.
basicsymbolic(f, args, nothing, metadata)
end
function maketerm(::Type{<:PolyForm}, f::Union{typeof(*), typeof(+), typeof(^)},
args, symtype, metadata)
function maketerm(::Type{<:PolyForm}, f::Union{typeof(*), typeof(+), typeof(^)}, args, metadata)
f(args...)
end

head(::PolyForm) = PolyForm
operation(x::PolyForm) = MP.nterms(x.p) == 1 ? (*) : (+)

function arguments(x::PolyForm{T}) where {T}
function TermInterface.arguments(x::PolyForm{T}) where {T}

function is_var(v)
MP.nterms(v) == 1 &&
Expand Down Expand Up @@ -231,10 +230,7 @@ function arguments(x::PolyForm{T}) where {T}
PolyForm{T}(t, x.pvar2sym, x.sym2term, nothing)) for t in ts]
end
end

sorted_arguments(x::PolyForm) = arguments(x)

children(x::PolyForm) = [operation(x); arguments(x)]
children(x::PolyForm) = arguments(x)

Base.show(io::IO, x::PolyForm) = show_term(io, x)

Expand All @@ -255,7 +251,7 @@ function unpolyize(x)
# we need a special maketerm here because the default one used in Postwalk will call
# promote_symtype to get the new type, but we just want to forward that in case
# promote_symtype is not defined for some of the expressions here.
Postwalk(identity, maketerm=(T,f,args,sT,m) -> maketerm(T, f, args, symtype(x), m))(x)
Postwalk(identity, maketerm=(T,f,args,m) -> maketerm(T, f, args, m))(x)
end

function toterm(x::PolyForm)
Expand Down Expand Up @@ -307,7 +303,8 @@ function add_divs(x, y)
end
end

function frac_maketerm(T, f, args, stype, metadata)
function frac_maketerm(T, f, args, metadata)
# TODO add stype to T?
if f in (*, /, \, +, -)
f(args...)
elseif f == (^)
Expand All @@ -317,7 +314,7 @@ function frac_maketerm(T, f, args, stype, metadata)
args[1]^args[2]
end
else
maketerm(T, f, args, stype, metadata)
maketerm(T, f, args, metadata)
end
end

Expand Down
34 changes: 9 additions & 25 deletions src/rewriters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -167,11 +167,7 @@ end
struct Walk{ord, C, F, threaded}
rw::C
thread_cutoff::Int
maketerm::F # XXX: for the 2.0 deprecation cycle, we actually store a function
# that behaves like `similarterm` here, we use `compatmaker` to wrap
# maketerm-like input to do this, with a warning if similarterm provided
# we need this workaround to deprecate because similarterm takes value
# but maketerm only knows the type.
maketerm::F
end

function instrument(x::Walk{ord, C,F,threaded}, f) where {ord,C,F,threaded}
Expand All @@ -183,25 +179,13 @@ end

using .Threads

function compatmaker(similarterm, maketerm)
# XXX: delete this and only use maketerm in a future release.
if similarterm isa Nothing
function (x, f, args, type=_promote_symtype(f, args); metadata)
maketerm(typeof(x), f, args, type, metadata)
end
else
Base.depwarn("Prewalk and Postwalk now take maketerm instead of similarterm keyword argument. similarterm(x, f, args, type; metadata) is now maketerm(typeof(x), f, args, type, metadata)", :similarterm)
similarterm
end
end
function Postwalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm, similarterm=nothing)
maker = compatmaker(similarterm, maketerm)
Walk{:post, typeof(rw), typeof(maker), threaded}(rw, thread_cutoff, maker)

function Postwalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm)
Walk{:post, typeof(rw), typeof(maketerm), threaded}(rw, thread_cutoff, maketerm)
end

function Prewalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm, similarterm=nothing)
maker = compatmaker(similarterm, maketerm)
Walk{:pre, typeof(rw), typeof(maker), threaded}(rw, thread_cutoff, maker)
function Prewalk(rw; threaded::Bool=false, thread_cutoff=100, maketerm=maketerm)
Walk{:pre, typeof(rw), typeof(maketerm), threaded}(rw, thread_cutoff, maketerm)
end

struct PassThrough{C}
Expand All @@ -220,8 +204,8 @@ function (p::Walk{ord, C, F, false})(x) where {ord, C, F}
end

if iscall(x)
x = p.maketerm(x, operation(x), map(PassThrough(p),
arguments(x)), metadata=metadata(x))
x = p.maketerm(typeof(x), operation(x), map(PassThrough(p),
arguments(x)), metadata(x))
end

return ord === :post ? p.rw(x) : x
Expand All @@ -245,7 +229,7 @@ function (p::Walk{ord, C, F, true})(x) where {ord, C, F}
end
end
args = map((t,a) -> passthrough(t isa Task ? fetch(t) : t, a), _args, arguments(x))
t = p.maketerm(x, operation(x), args, metadata=metadata(x))
t = p.maketerm(typeof(x), operation(x), args, metadata(x))
end
return ord === :post ? p.rw(t) : t
else
Expand Down
2 changes: 1 addition & 1 deletion src/rule.jl
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ function (acr::ACRule)(term)
if result !== nothing
# Assumption: inds are unique
length(args) == length(inds) && return result
return maketerm(typeof(term), f, [result, (args[i] for i in eachindex(args) if i inds)...], symtype(term), metadata(term))
return maketerm(typeof(term), f, [result, (args[i] for i in eachindex(args) if i inds)...], metadata(term))
end
end
end
Expand Down
2 changes: 1 addition & 1 deletion src/simplify.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,6 @@ end

has_operation(x, op) = (iscall(x) && (operation(x) == op ||
any(a->has_operation(a, op),
arguments(x))))
arguments(x))))

Base.@deprecate simplify(x, ctx; kwargs...) simplify(x; rewriter=ctx, kwargs...)
1 change: 0 additions & 1 deletion src/substitute.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ function substitute(expr, dict; fold=true)
maketerm(typeof(expr),
op,
args,
symtype(expr),
metadata(expr))
else
expr
Expand Down
Loading

0 comments on commit da9267e

Please sign in to comment.