From 3ea899f798af9aa8ad49ea6548cc736d1c032b55 Mon Sep 17 00:00:00 2001 From: Akira Kyle Date: Mon, 18 Nov 2024 13:15:07 -0700 Subject: [PATCH] WIP moshi --- Project.toml | 4 +- src/SymbolicUtils.jl | 3 +- src/code.jl | 6 +- src/inspect.jl | 2 +- src/matchers.jl | 18 +- src/methods.jl | 4 +- src/ordering.jl | 2 +- src/polyform.jl | 4 +- src/substitute.jl | 2 + src/types.jl | 416 ++++++++++++++++++++++++++----------------- src/utils.jl | 21 ++- test/basics.jl | 16 +- test/runtests.jl | 5 +- test/types.jl | 113 ++++++++++++ 14 files changed, 423 insertions(+), 193 deletions(-) create mode 100644 test/types.jl diff --git a/Project.toml b/Project.toml index ec3162512..850126888 100644 --- a/Project.toml +++ b/Project.toml @@ -16,6 +16,7 @@ DynamicPolynomials = "7c1d4256-1411-5781-91ec-d7bc3513ac07" IfElse = "615f187c-cbe4-4ef1-ba3b-2fcf58d6d173" LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Moshi = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d" MultivariatePolynomials = "102ac46a-7ee4-5c85-9060-abc95bfdeaa3" NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" @@ -25,7 +26,6 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c" TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f" -Unityper = "a7c27f48-0311-42f6-a7f8-2c11e75eb415" WeakValueDicts = "897b6980-f191-5a31-bcb0-bf3c4585e0c1" [weakdeps] @@ -48,6 +48,7 @@ DocStringExtensions = "0.8, 0.9" DynamicPolynomials = "0.5, 0.6" IfElse = "0.1" LabelledArrays = "1.5" +Moshi = "0.3.5" MultivariatePolynomials = "0.5" NaNMath = "0.3, 1" ReverseDiff = "1" @@ -57,7 +58,6 @@ StaticArrays = "0.12, 1.0" SymbolicIndexingInterface = "0.3" TermInterface = "2.0" TimerOutputs = "0.5" -Unityper = "0.1.2" WeakValueDicts = "0.1.0" julia = "1.3" diff --git a/src/SymbolicUtils.jl b/src/SymbolicUtils.jl index fb13f50b4..20d54df7b 100644 --- a/src/SymbolicUtils.jl +++ b/src/SymbolicUtils.jl @@ -7,7 +7,8 @@ using DocStringExtensions export @syms, term, showraw, hasmetadata, getmetadata, setmetadata -using Unityper +using Moshi.Data: @data, data_type_name, variant_name +using Moshi.Match: @match using TermInterface using DataStructures using Setfield diff --git a/src/code.jl b/src/code.jl index 4128a39fd..9be9cc4d6 100644 --- a/src/code.jl +++ b/src/code.jl @@ -8,7 +8,7 @@ export toexpr, Assignment, (←), Let, Func, DestructuredArgs, LiteralExpr, import ..SymbolicUtils import ..SymbolicUtils.Rewriters -import SymbolicUtils: @matchable, BasicSymbolic, Sym, Term, iscall, operation, arguments, issym, +import SymbolicUtils: @matchable, BasicSymbolicType, Sym, Term, iscall, operation, arguments, issym, symtype, sorted_arguments, metadata, isterm, term, maketerm import SymbolicIndexingInterface: symbolic_type, NotSymbolic @@ -156,7 +156,7 @@ function function_to_expr(::typeof(SymbolicUtils.ifelse), O, st) :($(toexpr(args[1], st)) ? $(toexpr(args[2], st)) : $(toexpr(args[3], st))) end -function function_to_expr(x::BasicSymbolic, O, st) +function function_to_expr(x::BasicSymbolicType, O, st) issym(x) ? get(st.rewrites, O, nothing) : nothing end @@ -766,7 +766,7 @@ end function cse_block(state, t, name=Symbol("var-", hash(t))) assignments = Assignment[] counter = Ref{Int}(1) - names = Dict{Any, BasicSymbolic}() + names = Dict{Any, BasicSymbolicType}() Let(assignments, cse_block!(assignments, counter, names, name, state, t)) end diff --git a/src/inspect.jl b/src/inspect.jl index ab3951725..3e06dfdc2 100644 --- a/src/inspect.jl +++ b/src/inspect.jl @@ -5,7 +5,7 @@ function AbstractTrees.nodevalue(x::Symbolic) iscall(x) ? operation(x) : isexpr(x) ? head(x) : x end -function AbstractTrees.nodevalue(x::BasicSymbolic) +function AbstractTrees.nodevalue(x::BasicSymbolicType) str = if !iscall(x) string(exprtype(x), "(", x, ")") elseif isadd(x) diff --git a/src/matchers.jl b/src/matchers.jl index 7f4dea537..edf7a0484 100644 --- a/src/matchers.jl +++ b/src/matchers.jl @@ -6,9 +6,23 @@ # 3. Callback: takes arguments Dictionary × Number of elements matched # function matcher(val::Any) - iscall(val) && return term_matcher(val) + if isconst(val) + slot = val.val + return matcher(slot) + elseif iscall(val) + return term_matcher(val) + end function literal_matcher(next, data, bindings) - islist(data) && isequal(car(data), val) ? next(bindings, 1) : nothing + if islist(data) + cd = car(data) + if isconst(cd) + cd = cd.val + end + if isequal(cd, val) + return next(bindings, 1) + end + end + nothing end end diff --git a/src/methods.jl b/src/methods.jl index 2baef6424..28036b05e 100644 --- a/src/methods.jl +++ b/src/methods.jl @@ -95,8 +95,8 @@ macro number_methods(T, rhs1, rhs2, options=nothing) number_methods(T, rhs1, rhs2, options) |> esc end -@number_methods(BasicSymbolic{<:Number}, term(f, a), term(f, a, b), skipbasics) -@number_methods(BasicSymbolic{<:LiteralReal}, term(f, a), term(f, a, b), onlybasics) +@number_methods(BasicSymbolicType{<:Number}, term(f, a), term(f, a, b), skipbasics) +@number_methods(BasicSymbolicType{<:LiteralReal}, term(f, a), term(f, a, b), onlybasics) for f in vcat(diadic, [+, -, *, \, /, ^]) @eval promote_symtype(::$(typeof(f)), diff --git a/src/ordering.jl b/src/ordering.jl index 332f11cf8..02f0b0de3 100644 --- a/src/ordering.jl +++ b/src/ordering.jl @@ -78,7 +78,7 @@ function <ₑ(a::Tuple, b::Tuple) return length(a) < length(b) end -function <ₑ(a::BasicSymbolic, b::BasicSymbolic) +function <ₑ(a::BasicSymbolicType, b::BasicSymbolicType) da, db = get_degrees(a), get_degrees(b) fw = monomial_lt(da, db) bw = monomial_lt(db, da) diff --git a/src/polyform.jl b/src/polyform.jl index 7d6bc906e..24054e592 100644 --- a/src/polyform.jl +++ b/src/polyform.jl @@ -29,7 +29,7 @@ PolyForm(sin((x+y)^2), recurse=true) #=> sin((x^2 + (2x)y + y^2)) struct PolyForm{T} <: Symbolic{T} p::MP.AbstractPolynomialLike pvar2sym::Bijection{Any,Any} # @polyvar x --> @sym x etc. - sym2term::Dict{BasicSymbolic,Any} # Symbol("sin-$hash(sin(x+y))") --> sin(x+y) => sin(PolyForm(...)) + sym2term::Dict{BasicSymbolicType,Any} # Symbol("sin-$hash(sin(x+y))") --> sin(x+y) => sin(PolyForm(...)) metadata function (::Type{PolyForm{T}})(p, d1, d2, m=nothing) where {T} p isa Number && return p @@ -63,7 +63,7 @@ end function get_sym2term() v = SYM2TERM[].value if v === nothing - d = Dict{BasicSymbolic,Any}() + d = Dict{BasicSymbolicType,Any}() SYM2TERM[] = WeakRef(d) return d else diff --git a/src/substitute.jl b/src/substitute.jl index 828f88b14..4548f7c29 100644 --- a/src/substitute.jl +++ b/src/substitute.jl @@ -22,6 +22,7 @@ function substitute(expr, dict; fold=true) canfold = !(op isa Symbolic) args = map(arguments(expr)) do x x′ = substitute(x, dict; fold=fold) + x′ = isconst(x) ? x′.val : x′ canfold = canfold && !(x′ isa Symbolic) x′ end @@ -54,6 +55,7 @@ function _occursin(needle, haystack) if iscall(haystack) args = arguments(haystack) for arg in args + arg = isconst(arg) ? arg.val : arg if needle isa Integer || needle isa AbstractFloat isequal(needle, arg) && return true else diff --git a/src/types.jl b/src/types.jl index 898259f44..bd2fc71db 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1,103 +1,113 @@ -#------------------- -#-------------------- -#### Symbolic -#-------------------- abstract type Symbolic{T} end -### -### Uni-type design -### - -@enum ExprType::UInt8 SYM TERM ADD MUL POW DIV +@enum ExprType::UInt8 SYM TERM ADD MUL POW DIV CONST const Metadata = Union{Nothing,Base.ImmutableDict{DataType,Any}} const NO_METADATA = nothing -sdict(kv...) = Dict{Any, Any}(kv...) - using Base: RefValue -const EMPTY_ARGS = [] const EMPTY_HASH = RefValue(UInt(0)) -const NOT_SORTED = RefValue(false) -const EMPTY_DICT = sdict() -const EMPTY_DICT_T = typeof(EMPTY_DICT) - -@compactify show_methods=false begin - @abstract mutable struct BasicSymbolic{T} <: Symbolic{T} - metadata::Metadata = NO_METADATA - end - mutable struct Sym{T} <: BasicSymbolic{T} - name::Symbol = :OOF - end - mutable struct Term{T} <: BasicSymbolic{T} - f::Any = identity # base/num if Pow; issorted if Add/Dict - arguments::Vector{Any} = EMPTY_ARGS - hash::RefValue{UInt} = EMPTY_HASH - end - mutable struct Mul{T} <: BasicSymbolic{T} - coeff::Any = 0 # exp/den if Pow - dict::EMPTY_DICT_T = EMPTY_DICT - hash::RefValue{UInt} = EMPTY_HASH - arguments::Vector{Any} = EMPTY_ARGS - issorted::RefValue{Bool} = NOT_SORTED - end - mutable struct Add{T} <: BasicSymbolic{T} - coeff::Any = 0 # exp/den if Pow - dict::EMPTY_DICT_T = EMPTY_DICT - hash::RefValue{UInt} = EMPTY_HASH - arguments::Vector{Any} = EMPTY_ARGS - issorted::RefValue{Bool} = NOT_SORTED - end - mutable struct Div{T} <: BasicSymbolic{T} - num::Any = 1 - den::Any = 1 - simplified::Bool = false - arguments::Vector{Any} = EMPTY_ARGS - end - mutable struct Pow{T} <: BasicSymbolic{T} - base::Any = 1 - exp::Any = 1 - arguments::Vector{Any} = EMPTY_ARGS - end -end - -function SymbolicIndexingInterface.symbolic_type(::Type{<:BasicSymbolic}) + +# TODO: Actually close the type system by making everything hold only BasicSymbolicType except Const +@data BasicSymbolic{T} <: Symbolic{T} begin + struct Sym + metadata::Metadata = NO_METADATA + name::Symbol = :OOF + end + struct Term + metadata::Metadata = NO_METADATA + f::Any = identity + arguments::Vector{Symbolic} = Symbolic[] + hash::RefValue{UInt} = EMPTY_HASH + end + struct Mul + metadata::Metadata = NO_METADATA + coeff::Any = 0 + dict::Dict{BasicSymbolic.Type, Any} = Dict{BasicSymbolic.Type, Any}() + hash::RefValue{UInt} = EMPTY_HASH + arguments::Vector{Any} = [] + issorted::RefValue{Bool} = RefValue(false) + end + struct Add + metadata::Metadata = NO_METADATA + coeff::Any = 0 + dict::Dict{BasicSymbolic.Type, Any} = Dict{BasicSymbolic.Type, Any}() + hash::RefValue{UInt} = EMPTY_HASH + arguments::Vector{Any} = [] + issorted::RefValue{Bool} = RefValue(false) + end + struct Div + metadata::Metadata = NO_METADATA + num::Any = 1 + den::Any = 1 + simplified::Bool = false + arguments::Vector{Any} = [] + end + struct Pow + metadata::Metadata = NO_METADATA + base::Any = 1 + exp::Any = 1 + arguments::Vector{Any} = [] + end + struct Const + metadata::Metadata = NO_METADATA + val::Any = 1 + end +end + +const BasicSymbolicType = BasicSymbolic.Type +const Term = BasicSymbolic.Term +const Sym = BasicSymbolic.Sym +const Add = BasicSymbolic.Add +const Mul = BasicSymbolic.Mul +const Div = BasicSymbolic.Div +const Pow = BasicSymbolic.Pow +const Const = BasicSymbolic.Const + +function SymbolicIndexingInterface.symbolic_type(::Type{<:BasicSymbolicType}) ScalarSymbolic() end -function exprtype(x::BasicSymbolic) - @compactified x::BasicSymbolic begin - Term => TERM - Add => ADD - Mul => MUL - Div => DIV - Pow => POW - Sym => SYM - _ => error_on_type() +function exprtype(x::BasicSymbolicType) + @match x begin + Term(_) => TERM + Add(_) => ADD + Mul(_) => MUL + Div(_) => DIV + Pow(_) => POW + Sym(_) => SYM + Const(_) => CONST end end -const wvd = WeakValueDict{UInt, BasicSymbolic}() +const wvd = WeakValueDict{UInt, BasicSymbolicType}() # Same but different error messages @noinline error_on_type() = error("Internal error: unreachable reached!") @noinline error_sym() = error("Sym doesn't have a operation or arguments!") +@noinline error_const() = error("Const doesn't have a operation or arguments!") @noinline error_property(E, s) = error("$E doesn't have field $s") # We can think about bits later # flags const SIMPLIFIED = 0x01 << 0 -#@inline is_of_type(x::BasicSymbolic, type::UInt8) = (x.bitflags & type) != 0x00 -#@inline issimplified(x::BasicSymbolic) = is_of_type(x, SIMPLIFIED) +#@inline is_of_type(x::BasicSymbolicType, type::UInt8) = (x.bitflags & type) != 0x00 +#@inline issimplified(x::BasicSymbolicType) = is_of_type(x, SIMPLIFIED) -function ConstructionBase.setproperties(obj::BasicSymbolic{T}, patch::NamedTuple)::BasicSymbolic{T} where T +function ConstructionBase.setproperties(obj::BasicSymbolicType{T}, patch::NamedTuple)::BasicSymbolicType{T} where T nt = getproperties(obj) nt_new = merge(nt, patch) - # Call outer constructor because hash consing cannot be applied in inner constructor - @compactified obj::BasicSymbolic begin - Sym => Sym{T}(nt_new.name; nt_new...) - _ => Unityper.rt_constructor(obj){T}(;nt_new...) + #data_type_name(obj){T}(;nt_new...) + # TODO which to use? + @match obj begin + Sym(_) => Sym{T}(;nt_new...) + Term(_) => Term{T}(;nt_new...) + Add(_) => Add{T}(;nt_new...) + Mul(_) => Mul{T}(;nt_new...) + Div(_) => Div{T}(;nt_new...) + Pow(_) => Pow{T}(;nt_new...) + Const(_) => Const{T}(;nt_new...) end end @@ -119,25 +129,25 @@ symtype(x) = typeof(x) @inline symtype(::Type{<:Symbolic{T}}) where T = T # We're returning a function pointer -@inline function operation(x::BasicSymbolic) - @compactified x::BasicSymbolic begin - Term => x.f - Add => (+) - Mul => (*) - Div => (/) - Pow => (^) - Sym => error_sym() - _ => error_on_type() +@inline function operation(x::BasicSymbolicType) + @match x begin + Term(_) => x.f + Add(_) => (+) + Mul(_) => (*) + Div(_) => (/) + Pow(_) => (^) + Sym(_) => error_sym() + Const(_) => error_const() end end -@inline head(x::BasicSymbolic) = operation(x) +@inline head(x::BasicSymbolicType) = operation(x) -function TermInterface.sorted_arguments(x::BasicSymbolic) +function TermInterface.sorted_arguments(x::BasicSymbolicType) args = arguments(x) - @compactified x::BasicSymbolic begin - Add => @goto ADD - Mul => @goto MUL + @match x begin + Add(_) => @goto ADD + Mul(_) => @goto MUL _ => return args end @label MUL @@ -157,17 +167,17 @@ end @deprecate unsorted_arguments(x) arguments(x) -TermInterface.children(x::BasicSymbolic) = arguments(x) -TermInterface.sorted_children(x::BasicSymbolic) = sorted_arguments(x) -function TermInterface.arguments(x::BasicSymbolic) - @compactified x::BasicSymbolic begin - Term => return x.arguments - Add => @goto ADDMUL - Mul => @goto ADDMUL - Div => @goto DIV - Pow => @goto POW - Sym => error_sym() - _ => error_on_type() +TermInterface.children(x::BasicSymbolicType) = arguments(x) +TermInterface.sorted_children(x::BasicSymbolicType) = sorted_arguments(x) +function TermInterface.arguments(x::BasicSymbolicType) + @match x begin + Term(_) => return x.arguments + Add(_) => @goto ADDMUL + Mul(_) => @goto ADDMUL + Div(_) => @goto DIV + Pow(_) => @goto POW + Sym(_) => error_sym() + Const(_) => error_const() end @label ADDMUL @@ -175,7 +185,7 @@ function TermInterface.arguments(x::BasicSymbolic) args = x.arguments isempty(args) || return args siz = length(x.dict) - idcoeff = E === ADD ? iszero(x.coeff) : isone(x.coeff) + idcoeff = E === ADD ? _iszero(x.coeff) : _isone(x.coeff) sizehint!(args, idcoeff ? siz : siz + 1) idcoeff || push!(args, x.coeff) if isadd(x) @@ -207,10 +217,17 @@ function TermInterface.arguments(x::BasicSymbolic) return args end -isexpr(s::BasicSymbolic) = !issym(s) -iscall(s::BasicSymbolic) = isexpr(s) +function isexpr(x::BasicSymbolicType) + @match x begin + BasicSymbolic.Sym(_) => false + BasicSymbolic.Const(_) => false + _ => true + end +end -@inline isa_SymType(T::Val{S}, x) where {S} = x isa BasicSymbolic ? Unityper.isa_type_fun(Val(SymbolicUtils.BasicSymbolic), T, x) : false +iscall(s::BasicSymbolicType) = isexpr(s) + +@inline isa_SymType(S, x) = x isa BasicSymbolicType ? variant_name(x) == S : false """ issym(x) @@ -218,12 +235,13 @@ iscall(s::BasicSymbolic) = isexpr(s) Returns `true` if `x` is a `Sym`. If true, `nameof` must be defined on `x` and must return a `Symbol`. """ -issym(x) = isa_SymType(Val(:Sym), x) -isterm(x) = isa_SymType(Val(:Term), x) -ismul(x) = isa_SymType(Val(:Mul), x) -isadd(x) = isa_SymType(Val(:Add), x) -ispow(x) = isa_SymType(Val(:Pow), x) -isdiv(x) = isa_SymType(Val(:Div), x) +issym(x) = isa_SymType(:Sym, x) +isterm(x) = isa_SymType(:Term, x) +ismul(x) = isa_SymType(:Mul, x) +isadd(x) = isa_SymType(:Add, x) +ispow(x) = isa_SymType(:Pow, x) +isdiv(x) = isa_SymType(:Div, x) +isconst(x) = isa_SymType(:Const, x) ### ### Base interface @@ -244,7 +262,7 @@ function _allarequal(xs, ys)::Bool return true end -function Base.isequal(a::BasicSymbolic{T}, b::BasicSymbolic{S}) where {T,S} +function Base.isequal(a::BasicSymbolicType{T}, b::BasicSymbolicType{S}) where {T,S} a === b && return true E = exprtype(a) @@ -266,6 +284,8 @@ function _isequal(a, b, E) a1 = arguments(a) a2 = arguments(b) isequal(operation(a), operation(b)) && _allarequal(a1, a2) + elseif E === CONST + isequal(a.val, b.val) else error_on_type() end @@ -274,10 +294,10 @@ end """ $(TYPEDSIGNATURES) -Checks for equality between two `BasicSymbolic` objects, considering both their +Checks for equality between two `BasicSymbolicType` objects, considering both their values and metadata. -The default `Base.isequal` function for `BasicSymbolic` only compares their expressions +The default `Base.isequal` function for `BasicSymbolicType` only compares their expressions and ignores metadata. This does not help deal with hash collisions when metadata is relevant for distinguishing expressions, particularly in hashing contexts. This function provides a stricter equality check that includes metadata comparison, preventing @@ -287,14 +307,14 @@ Modifying `Base.isequal` directly breaks numerous tests in `SymbolicUtils.jl` an downstream packages like `ModelingToolkit.jl`, hence the need for this separate function. """ -function isequal_with_metadata(a::BasicSymbolic, b::BasicSymbolic)::Bool +function isequal_with_metadata(a::BasicSymbolicType, b::BasicSymbolicType)::Bool isequal(a, b) && isequal(metadata(a), metadata(b)) end Base.one( s::Symbolic) = one( symtype(s)) Base.zero(s::Symbolic) = zero(symtype(s)) -Base.nameof(s::BasicSymbolic) = issym(s) ? s.name : error("None Sym BasicSymbolic doesn't have a name") +Base.nameof(s::BasicSymbolicType) = issym(s) ? s.name : error("None Sym BasicSymbolicType doesn't have a name") ## This is much faster than hash of an array of Any hashvec(xs, z) = foldr(hash, xs, init=z) @@ -303,7 +323,8 @@ const ADD_SALT = 0xaddaddaddaddadda % UInt const SUB_SALT = 0xaaaaaaaaaaaaaaaa % UInt const DIV_SALT = 0x334b218e73bbba53 % UInt const POW_SALT = 0x2b55b97a6efb080c % UInt -function Base.hash(s::BasicSymbolic, salt::UInt)::UInt +const COS_SALT = 0xdc3d6b8f18b75e3c % UInt +function Base.hash(s::BasicSymbolicType, salt::UInt)::UInt E = exprtype(s) if E === SYM hash(nameof(s), salt ⊻ SYM_SALT) @@ -328,6 +349,8 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt h′ = hashvec(arguments(s), hash(oph, salt)) s.hash[] = h′ return h′ + elseif E === CONST + return hash(s.val, salt ⊻ COS_SALT) else error_on_type() end @@ -336,17 +359,17 @@ end """ $(TYPEDSIGNATURES) -Calculates a hash value for a `BasicSymbolic` object, incorporating both its metadata and +Calculates a hash value for a `BasicSymbolicType` object, incorporating both its metadata and symtype. -This function provides an alternative hashing strategy to `Base.hash` for `BasicSymbolic` +This function provides an alternative hashing strategy to `Base.hash` for `BasicSymbolicType` objects. Unlike `Base.hash`, which only considers the expression structure, `hash2` also includes the metadata and symtype in the hash calculation. This can be beneficial for hash consing, allowing for more effective deduplication of symbolically equivalent expressions with different metadata or symtypes. """ -hash2(s::BasicSymbolic) = hash2(s, zero(UInt)) -function hash2(s::BasicSymbolic{T}, salt::UInt)::UInt where {T} +hash2(s::BasicSymbolicType) = hash2(s, zero(UInt)) +function hash2(s::BasicSymbolicType{T}, salt::UInt)::UInt where {T} hash(metadata(s), hash(T, hash(s, salt))) end @@ -357,9 +380,9 @@ end """ $(TYPEDSIGNATURES) -Implements hash consing (flyweight design pattern) for `BasicSymbolic` objects. +Implements hash consing (flyweight design pattern) for `BasicSymbolicType` objects. -This function checks if an equivalent `BasicSymbolic` object already exists. It uses a +This function checks if an equivalent `BasicSymbolicType` object already exists. It uses a custom hash function (`hash2`) incorporating metadata and symtypes to search for existing objects in a `WeakValueDict` (`wvd`). Due to the possibility of hash collisions (where different objects produce the same hash), a custom equality check (`isequal_with_metadata`) @@ -369,13 +392,13 @@ otherwise, the input `s` is returned. This reduces memory usage, improves compil for runtime code generation, and supports built-in common subexpression elimination, particularly when working with symbolic objects with metadata. -Using a `WeakValueDict` ensures that only weak references to `BasicSymbolic` objects are +Using a `WeakValueDict` ensures that only weak references to `BasicSymbolicType` objects are stored, allowing objects that are no longer strongly referenced to be garbage collected. Custom functions `hash2` and `isequal_with_metadata` are used instead of `Base.hash` and `Base.isequal` to accommodate metadata without disrupting existing tests reliant on the original behavior of those functions. """ -function BasicSymbolic(s::BasicSymbolic)::BasicSymbolic +function BasicSymbolicEquivalent(s::BasicSymbolicType)::BasicSymbolicType h = hash2(s) t = get!(wvd, h, s) if t === s || isequal_with_metadata(t, s) @@ -385,24 +408,42 @@ function BasicSymbolic(s::BasicSymbolic)::BasicSymbolic end end +# TODO: figure out how to implement BasicSymbolicEquivalent function Sym{T}(name::Symbol; kw...) where {T} - s = Sym{T}(; name, kw...) - BasicSymbolic(s) + #s = Sym{T}(; name=name, kw...) + #BasicSymbolicEquivalent(s) + Sym{T}(; name=name, kw...) end function Term{T}(f, args; kw...) where T - if eltype(args) !== Any - args = convert(Vector{Any}, args) - end + #if eltype(args) !== Symbolic + # args = convert(Vector{Symbolic}, args) + #end - Term{T}(;f=f, arguments=args, hash=Ref(UInt(0)), kw...) + # TODO: revisit convert after https://github.com/Roger-luo/Moshi.jl/issues/32 is resolved + Term{T}(;f=f, arguments=convert(Vector{Any}, args), hash=Ref(UInt(0)), kw...) end function Term(f, args; metadata=NO_METADATA) Term{_promote_symtype(f, args)}(f, args, metadata=metadata) end -function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T +function Const(val::T; kwargs...) where {T} + BasicSymbolic.Const{T}(; val=val, kwargs...) +end + +function Base.convert(::Type{Symbolic}, x) + Const(x) +end + +function Base.convert(::Type{BasicSymbolicType}, x) + Const(x) +end +function Base.convert(::Type{BasicSymbolicType}, x::BasicSymbolicType) + x +end + +function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where {T} if isempty(dict) return coeff elseif _iszero(coeff) && length(dict) == 1 @@ -415,10 +456,12 @@ function Add(::Type{T}, coeff, dict; metadata=NO_METADATA, kw...) where T end end - Add{T}(; coeff, dict, hash=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...) + # TODO: revisit convert after https://github.com/Roger-luo/Moshi.jl/issues/32 is resolved + Add{T}(; coeff=coeff, dict=convert(Dict{BasicSymbolicType, Any}, dict), + hash=Ref(UInt(0)), metadata=metadata, arguments=[], issorted=RefValue(false), kw...) end -function Mul(T, a, b; metadata=NO_METADATA, kw...) +function Mul(::Type{T}, a, b; metadata=NO_METADATA, kw...) where {T} isempty(b) && return a if _isone(a) && length(b) == 1 pair = first(b) @@ -430,7 +473,23 @@ function Mul(T, a, b; metadata=NO_METADATA, kw...) else coeff = a dict = b - Mul{T}(; coeff, dict, hash=Ref(UInt(0)), metadata, arguments=[], issorted=RefValue(false), kw...) + # TODO: revisit convert after https://github.com/Roger-luo/Moshi.jl/issues/32 is resolved + Mul{T}(; coeff=coeff, dict=convert(Dict{BasicSymbolicType, Any}, dict), + hash=Ref(UInt(0)), metadata=metadata, arguments=[], issorted=RefValue(false), kw...) + end +end + +function _iszero(x::BasicSymbolicType) + @match x begin + Const(_) => iszero(x.val) + _ => false + end +end + +function _isone(x::BasicSymbolicType) + @match x begin + Const(_) => isone(x.val) + _ => false end end @@ -495,7 +554,7 @@ function Div{T}(n, d, simplified=false; metadata=nothing) where {T} end end - Div{T}(; num=n, den=d, simplified, arguments=[], metadata) + Div{T}(; num=n, den=d, simplified=simplified, arguments=[], metadata=metadata) end function Div(n,d, simplified=false; kw...) @@ -512,26 +571,26 @@ end function Pow{T}(a, b; metadata=NO_METADATA) where {T} _iszero(b) && return 1 _isone(b) && return a - Pow{T}(; base=a, exp=b, arguments=[], metadata) + Pow{T}(; base=a, exp=b, arguments=[], metadata=metadata) end function Pow(a, b; metadata=NO_METADATA) Pow{promote_symtype(^, symtype(a), symtype(b))}(makepow(a, b)..., metadata=metadata) end -function toterm(t::BasicSymbolic{T}) where T +function toterm(t::BasicSymbolicType{T}) where T E = exprtype(t) if E === SYM || E === TERM return t elseif E === ADD || E === MUL - args = Any[] + args = BasicSymbolicType[] push!(args, t.coeff) for (k, coeff) in t.dict - push!(args, coeff == 1 ? k : Term{T}(E === MUL ? (^) : (*), Any[coeff, k])) + push!(args, coeff == 1 ? k : Term{T}(E === MUL ? (^) : (*), [Const(coeff), k])) end Term{T}(operation(t), args) elseif E === DIV - Term{T}(/, Any[t.num, t.den]) + Term{T}(/, [t.num, t.den]) elseif E === POW Term{T}(^, [t.base, t.exp]) else @@ -546,7 +605,7 @@ Any Muls inside an Add should always have a coeff of 1 and the key (in Add) should instead be used to store the actual coefficient """ function makeadd(sign, coeff, xs...) - d = sdict() + d = Dict{BasicSymbolicType, Any}() for x in xs if isadd(x) coeff += x.coeff @@ -573,7 +632,7 @@ function makeadd(sign, coeff, xs...) coeff, d end -function makemul(coeff, xs...; d=sdict()) +function makemul(coeff, xs...; d=Dict{BasicSymbolicType, Any}()) for x in xs if ispow(x) && x.exp isa Number d[x.base] = x.exp + get(d, x.base, 0) @@ -612,7 +671,7 @@ function term(f, args...; type = nothing) else T = type end - Term{T}(f, Any[args...]) + Term{T}(f, [args...]) end """ @@ -624,7 +683,7 @@ function unflatten(t::Symbolic{T}) where{T} f = operation(t) if f == (+) || f == (*) # TODO check out for other n-ary --> binary ops a = arguments(t) - return foldl((x,y) -> Term{T}(f, Any[x, y]), a) + return foldl((x,y) -> Term{T}(f, [x, y]), a) end end return t @@ -632,11 +691,11 @@ end unflatten(t) = t -function TermInterface.maketerm(T::Type{<:BasicSymbolic}, head, args, metadata) +function TermInterface.maketerm(T::Type{<:BasicSymbolicType}, head, args, metadata) st = symtype(T) pst = _promote_symtype(head, args) # Use promoted symtype only if not a subtype of the existing symtype of T. - # This is useful when calling `maketerm(BasicSymbolic{Number}, (==), [true, false])` + # This is useful when calling `maketerm(BasicSymbolicType{Number}, (==), [true, false])` # Where the result would have a symtype of Bool. # Please see discussion in https://github.com/JuliaSymbolics/SymbolicUtils.jl/pull/609 # TODO this should be optimized. @@ -709,10 +768,10 @@ end issafecanon(f, s) = true function issafecanon(f, s::Symbolic) - if isnothing(metadata(s)) || issym(s) - return true - else - _issafecanon(f, s) + isnothing(metadata(s)) || @match s begin + Sym(_) => true + Const(_) => true + _ => _issafecanon(f, s) end end _issafecanon(::typeof(*), s) = !iscall(s) || !(operation(s) in (+,*,^)) @@ -778,6 +837,10 @@ const show_simplified = Ref(false) isnegative(t::Real) = t < 0 function isnegative(t) + if isconst(t) + val = t.val + return isnegative(val) + end if iscall(t) && operation(t) === (*) coeff = first(arguments(t)) return isnegative(coeff) @@ -812,8 +875,12 @@ function remove_minus(t) !iscall(t) && return -t @assert operation(t) == (*) args = arguments(t) - @assert args[1] < 0 - Any[-args[1], args[2:end]...] + arg1 = args[1] + if isconst(arg1) + arg1 = arg1.val + end + @assert arg1 < 0 + Any[-arg1, args[2:end]...] end @@ -848,17 +915,27 @@ function show_pow(io, args) end function show_mul(io, args) + if isconst(args) + print(io, args.val) + return + end length(args) == 1 && return print_arg(io, *, args[1]) - minus = args[1] isa Number && args[1] == -1 - unit = args[1] isa Number && args[1] == 1 + arg1 = args[1] + if isconst(arg1) + arg1 = arg1.val + end + + minus = arg1 isa Number && arg1 == -1 + unit = arg1 isa Number && arg1 == 1 - paren_scalar = (args[1] isa Complex && !_iszero(imag(args[1]))) || - args[1] isa Rational || - (args[1] isa Number && !isfinite(args[1])) + paren_scalar = (arg1 isa Complex && !_iszero(imag(arg1))) || + arg1 isa Rational || + (arg1 isa Number && !isfinite(arg1)) nostar = minus || unit || - (!paren_scalar && args[1] isa Number && !(args[2] isa Number)) + (!paren_scalar && arg1 isa Number && + !(isconst(args[2]) && args[2].val isa Number)) for (i, t) in enumerate(args) if i != 1 @@ -946,11 +1023,11 @@ end showraw(io, t) = Base.show(IOContext(io, :simplify=>false), t) showraw(t) = showraw(stdout, t) -function Base.show(io::IO, v::BasicSymbolic) - if issym(v) - Base.show_unquoted(io, v.name) - else - show_term(io, v) +function Base.show(io::IO, v::BasicSymbolicType) + @match v begin + Sym(_) => Base.show_unquoted(io, v.name) + Const(_) => print(io, v.val) + _ => show_term(io, v) end end @@ -1002,7 +1079,7 @@ end The output symtype of applying variable `f` to arguments of symtype `arg_symtypes...`. if the arguments are of the wrong type then this function will error. """ -function promote_symtype(f::BasicSymbolic{<:FnType{X,Y}}, args...) where {X, Y} +function promote_symtype(f::BasicSymbolicType{<:FnType{X,Y}}, args...) where {X, Y} if X === Tuple return Y end @@ -1164,6 +1241,12 @@ sub_t(a) = promote_symtype(-, symtype(a)) import Base: (+), (-), (*), (//), (/), (\), (^) function +(a::SN, b::SN) + if isconst(a) + return a.val + b + end + if isconst(b) + return b.val + a + end !issafecanon(+, a,b) && return term(+, a, b) # Don't flatten if args have metadata if isadd(a) && isadd(b) return Add(add_t(a,b), @@ -1180,6 +1263,9 @@ function +(a::SN, b::SN) end function +(a::Number, b::SN) + if isconst(b) + return a + b.val + end !issafecanon(+, b) && return term(+, a, b) # Don't flatten if args have metadata iszero(a) && return b if isadd(b) @@ -1194,6 +1280,7 @@ end +(a::SN) = a function -(a::SN) + isconst(a) && return Const(-a.val) !issafecanon(*, a) && return term(-, a) isadd(a) ? Add(sub_t(a), -a.coeff, mapvalues((_,v) -> -v, a.dict)) : Add(sub_t(a), makeadd(-1, 0, a)...) @@ -1218,6 +1305,8 @@ mul_t(a) = promote_symtype(*, symtype(a)) *(a::SN) = a function *(a::SN, b::SN) + isconst(a) && return a.val * b + isconst(b) && return b.val * a # Always make sure Div wraps Mul !issafecanon(*, a, b) && return term(*, a, b) if isdiv(a) && isdiv(b) @@ -1246,6 +1335,7 @@ function *(a::SN, b::SN) end function *(a::Number, b::SN) + isconst(b) && return a * b.val !issafecanon(*, b) && return term(*, a, b) if iszero(a) a @@ -1256,7 +1346,7 @@ function *(a::Number, b::SN) elseif isone(-a) && isadd(b) # -1(a+b) -> -a - b T = promote_symtype(+, typeof(a), symtype(b)) - Add(T, b.coeff * a, Dict{Any,Any}(k=>v*a for (k, v) in b.dict)) + Add(T, b.coeff * a, Dict{BasicSymbolicType,Any}(k=>v*a for (k, v) in b.dict)) else Mul(mul_t(a, b), makemul(a, b)...) end diff --git a/src/utils.jl b/src/utils.jl index 812e229fb..df09b2482 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -64,8 +64,12 @@ end sym_isa(::Type{T}) where {T} = @nospecialize(x) -> x isa T || symtype(x) <: T -isliteral(::Type{T}) where {T} = x -> x isa T -is_literal_number(x) = isliteral(Number)(x) +function is_literal_number(x) + if isconst(x) + x = get_val(x) + end + x isa Number +end # checking the type directly is faster than dynamic dispatch in type unstable code _iszero(x) = x isa Number && iszero(x) @@ -179,10 +183,15 @@ Base.length(l::LL) = length(l.v)-l.i+1 @inline car(l::LL) = l.v[l.i] @inline cdr(l::LL) = isempty(l) ? empty(l) : LL(l.v, l.i+1) -Base.length(t::Term) = length(arguments(t)) + 1 # PIRACY -Base.isempty(t::Term) = false -@inline car(t::Term) = operation(t) -@inline cdr(t::Term) = arguments(t) +function Base.length(t::BasicSymbolicType) + @match t begin + Term(_) => length(arguments(t)) + 1 + _ => 1 + end +end +Base.isempty(t::BasicSymbolicType) = false +@inline car(t::BasicSymbolicType) = operation(t) +@inline cdr(t::BasicSymbolicType) = arguments(t) @inline car(v) = iscall(v) ? operation(v) : first(v) @inline function cdr(v) diff --git a/test/basics.jl b/test/basics.jl index a5f0b5149..589909a9a 100644 --- a/test/basics.jl +++ b/test/basics.jl @@ -1,4 +1,4 @@ -using SymbolicUtils: Symbolic, Sym, FnType, Term, Add, Mul, Pow, symtype, operation, arguments, issym, isterm, BasicSymbolic, term, isequal_with_metadata +using SymbolicUtils: Symbolic, Sym, FnType, Term, Add, Mul, Pow, symtype, operation, arguments, issym, isterm, BasicSymbolicType, term, isequal_with_metadata using SymbolicUtils using IfElse: ifelse using Setfield @@ -82,10 +82,10 @@ struct Ctx1 end struct Ctx2 end @testset "metadata" begin - @syms a b c - for a = [a, sin(a), a+b, a*b, a^3] + @syms a b + for x = [a, sin(a), a+b, a*b, a^3] - a′ = setmetadata(a, Ctx1, "meta_1") + a′ = setmetadata(x, Ctx1, "meta_1") @test hasmetadata(a′, Ctx1) @test !hasmetadata(a′, Ctx2) @@ -240,8 +240,8 @@ end @test isequal(SymbolicUtils.maketerm(typeof(b + c), +, [a, (b+c)], nothing).dict, Dict(a=>1,b=>1,c=>1)) @test isequal(SymbolicUtils.maketerm(typeof(b^2), ^, [b^2, 1//2], nothing), b) - # test that maketerm doesn't hard-code BasicSymbolic subtype - # and is consistent with BasicSymbolic arithmetic operations + # test that maketerm doesn't hard-code BasicSymbolicType subtype + # and is consistent with BasicSymbolicType arithmetic operations @test isequal(SymbolicUtils.maketerm(typeof(a / b), *, [a / b, c], nothing), (a / b) * c) @test isequal(SymbolicUtils.maketerm(typeof(a * b), *, [0, c], nothing), 0) @test isequal(SymbolicUtils.maketerm(typeof(a^b), ^, [a * b, 3], nothing), (a * b)^3) @@ -288,7 +288,7 @@ end # Check that the Array type does not get changed to AbstractArray new_expr = SymbolicUtils.maketerm( - SymbolicUtils.BasicSymbolic{Vector{Float64}}, sin, [1.0, 2.0], nothing) + SymbolicUtils.BasicSymbolicType{Vector{Float64}}, sin, [1.0, 2.0], nothing) @test symtype(new_expr) == Vector{Float64} end @@ -381,7 +381,7 @@ end @test repr(x*x) == "x * x" @test repr(x*x + x*x) == "(x * x) + (x * x)" for ex in [sin(x), x+x, x*x, x\x, x/x] - @test typeof(sin(x)) <: BasicSymbolic{LiteralReal} + @test typeof(sin(x)) <: BasicSymbolicType{LiteralReal} end @test repr(sin(x) + sin(x)) == "sin(x) + sin(x)" end diff --git a/test/runtests.jl b/test/runtests.jl index 9ea8354a8..4621ee35e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -5,6 +5,7 @@ using Pkg, Test, SafeTestsets @safetestset "Benchmark" begin include("benchmark.jl") end else @safetestset "Doc" begin include("doctest.jl") end + @safetestset "Types" begin include("types.jl") end @safetestset "Basics" begin include("basics.jl") end @safetestset "Order" begin include("order.jl") end @safetestset "PolyForm" begin include("polyform.jl") end @@ -13,8 +14,8 @@ using Pkg, Test, SafeTestsets @safetestset "Code" begin include("code.jl") end @safetestset "CSE" begin include("cse.jl") end @safetestset "Interface" begin include("interface.jl") end - # Disabled until https://github.com/JuliaMath/SpecialFunctions.jl/issues/446 is fixed - @safetestset "Fuzz" begin include("fuzz.jl") end + ## Disabled until https://github.com/JuliaMath/SpecialFunctions.jl/issues/446 is fixed + #@safetestset "Fuzz" begin include("fuzz.jl") end @safetestset "Adjoints" begin include("adjoints.jl") end @safetestset "Hash Consing" begin include("hash_consing.jl") end end diff --git a/test/types.jl b/test/types.jl new file mode 100644 index 000000000..b39b9386c --- /dev/null +++ b/test/types.jl @@ -0,0 +1,113 @@ +using SymbolicUtils: Symbolic, BasicSymbolic, BasicSymbolicType, Sym, Term, Add, Mul, Div, Pow, Const +using SymbolicUtils + +s1 = Sym{Float64}(:abc) +s2 = Sym{Int64}(; name = :def) +@testset "Sym" begin + @test typeof(s1) <: BasicSymbolicType + @test typeof(s1) == BasicSymbolicType{Float64} + @test s1 isa BasicSymbolicType + @test s1 isa SymbolicUtils.Symbolic + @test s1.metadata isa SymbolicUtils.Metadata + @test s1.metadata == SymbolicUtils.NO_METADATA + @test s1.name == :abc + @test typeof(s2) <: BasicSymbolicType + @test typeof(s2) == BasicSymbolicType{Int64} + @test typeof(s2.name) == Symbol + @test s2.name == :def +end + +@testset "Term" begin + t1 = Term(sin, [s1]) + @test typeof(t1) <: BasicSymbolicType + @test typeof(t1) == BasicSymbolicType{Real} + @test t1.f == sin + @test isequal(t1.arguments, [s1]) + @test typeof(t1.arguments) == Vector{Symbolic} +end + +c1 = Const(1) +c2 = Const(3.14) +@testset "Const" begin + @test typeof(c1) <: BasicSymbolicType + @test typeof(c1.val) == Int + @test c1.val == 1 + @test typeof(c2.val) == Float64 + @test c2.val == 3.14 + c3 = Const(big"123456789012345678901234567890") + @test typeof(c3.val) == BigInt + @test c3.val == big"123456789012345678901234567890" + c4 = Const(big"1.23456789012345678901") + @test typeof(c4.val) == BigFloat + @test c4.val == big"1.23456789012345678901" +end + +coeff = c1 +dict = Dict{BasicSymbolicType, Any}(s1 => 3, s2 => 5) +@testset "Add" begin + a1 = Add{Real}(; coeff=coeff, dict=dict) + @test typeof(a1) <: BasicSymbolicType + @test a1.coeff isa BasicSymbolicType + @test isequal(a1.coeff, c1) + @test typeof(a1.dict) == Dict{BasicSymbolicType, Any} + @test a1.dict == dict + @test typeof(a1.arguments) == Vector{BasicSymbolicType} + @test isempty(a1.arguments) + @test typeof(a1.issorted) == Base.RefValue{Bool} + @test !a1.issorted[] +end + +@testset "Mul" begin + m1 = Mul{Real}(; coeff=coeff, dict=dict) + @test typeof(m1) <: BasicSymbolicType + @test m1.coeff isa BasicSymbolicType + @test isequal(m1.coeff, c1) + @test typeof(m1.dict) == Dict{BasicSymbolicType, Any} + @test m1.dict == dict + @test typeof(m1.arguments) == Vector{BasicSymbolicType} + @test isempty(m1.arguments) + @test typeof(m1.issorted) == Base.RefValue{Bool} + @test !m1.issorted[] +end + +@testset "Div" begin + d1 = Div(s1, s2) + @test typeof(d1) <: BasicSymbolicType + @test typeof(d1) == BasicSymbolicType{Float64} + @test isequal(d1.num, s1) + @test isequal(d1.den, s2) + @test typeof(d1.simplified) == Bool + @test !d1.simplified + @test isequal(arguments(d1), [s1, s2]) + d2 = Div{Real}(; num=s1, den=s2) + @test isequal(d2.num, s1) + @test isequal(d2.den, s2) +end + +@testset "Pow" begin + p1 = Pow(s1, s2) + @test typeof(p1) <: BasicSymbolicType + @test isequal(p1.base, s1) + @test isequal(p1.exp, s2) + @test isequal(arguments(p1), [s1, s2]) + p2 = Pow{Real}(; base=s1, exp=s2) + @test isequal(p2.base, s1) + @test isequal(p2.exp, s2) +end + +@testset "BasicSymbolic iszero" begin + c1 = Const(0) + @test SymbolicUtils._iszero(c1) + c2 = Const(1) + @test !SymbolicUtils._iszero(c2) + c3 = Const(0.0) + @test SymbolicUtils._iszero(c3) + c4 = Const(0.00000000000000000000000001) + @test !SymbolicUtils._iszero(c4) + c5 = Const(big"326264532521352634435352152") + @test !SymbolicUtils._iszero(c5) + c6 = Const(big"0.314654523452") + @test !SymbolicUtils._iszero(c6) + s = Sym{Real}(:y) + @test !SymbolicUtils._iszero(s) +end