Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement hash consing for Sym #658

Merged
merged 20 commits into from
Nov 7, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
d50294b
Add WeakValueDicts
bowenszhu Oct 14, 2024
7a3d8ba
Import `WeakValueDict`
bowenszhu Oct 16, 2024
e9e702b
Construct SymbolicUtils internal `WeakValueDict`
bowenszhu Oct 16, 2024
62403b9
Change `BasicSymbolic` types to `mutable` due to Julia finalizer
bowenszhu Oct 16, 2024
7e93ca4
Define hash extension function for incorporating symtype
bowenszhu Oct 16, 2024
9be3d03
Hash consing for `Sym`
bowenszhu Oct 16, 2024
93c2837
Merge remote-tracking branch 'origin/master' into hash-consing
bowenszhu Oct 17, 2024
f1a9a93
Test hash consing for `Sym` with different symtypes
bowenszhu Oct 18, 2024
70de918
Feat: Incorporate `metadata` into `BasicSymbolic` hash computation
bowenszhu Oct 25, 2024
2dac2a3
Apply hash consing also when there is metadata
bowenszhu Oct 25, 2024
c2d85c3
Create flyweight factory for `BasicSymbolic`
bowenszhu Oct 25, 2024
8957290
Add `isequal2` function for checking metadata comparison
bowenszhu Nov 5, 2024
2779856
Handle hash collision with customized `isequal2`
bowenszhu Nov 5, 2024
d36198f
Call outer constructor for `Sym` in `ConstructionBase.setproperties`
bowenszhu Nov 5, 2024
cf937b0
Test hash consing for `Sym` with metadata
bowenszhu Nov 5, 2024
765293a
Add docstring for the flyweight factory function
bowenszhu Nov 5, 2024
84a0596
Add docstring for `hash2`
bowenszhu Nov 5, 2024
187ce45
Add comment explaining why calling outer constructor in `setproperties`
bowenszhu Nov 5, 2024
55ca2ec
Refactor: Make `wvd` a constant global
bowenszhu Nov 6, 2024
13b642b
Rename the `isequal2` function to `isequal_with_metadata` for clarity.
bowenszhu Nov 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
TermInterface = "8ea1fca8-c5ef-4a55-8b96-4e9afe9c9a3c"
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
Unityper = "a7c27f48-0311-42f6-a7f8-2c11e75eb415"
WeakValueDicts = "897b6980-f191-5a31-bcb0-bf3c4585e0c1"

[weakdeps]
LabelledArrays = "2ee39098-c373-598a-b85f-a56591580800"
Expand Down Expand Up @@ -57,6 +58,7 @@ SymbolicIndexingInterface = "0.3"
TermInterface = "2.0"
TimerOutputs = "0.5"
Unityper = "0.1.2"
WeakValueDicts = "0.1.0"
julia = "1.3"

[extras]
Expand Down
1 change: 1 addition & 0 deletions src/SymbolicUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import TermInterface: iscall, isexpr, head, children,
operation, arguments, metadata, maketerm, sorted_arguments
# For ReverseDiffExt
import ArrayInterface
using WeakValueDicts: WeakValueDict

Base.@deprecate istree iscall
export istree, operation, arguments, sorted_arguments, iscall
Expand Down
95 changes: 85 additions & 10 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,38 +23,38 @@ const EMPTY_DICT = sdict()
const EMPTY_DICT_T = typeof(EMPTY_DICT)

@compactify show_methods=false begin
@abstract struct BasicSymbolic{T} <: Symbolic{T}
@abstract mutable struct BasicSymbolic{T} <: Symbolic{T}
metadata::Metadata = NO_METADATA
end
struct Sym{T} <: BasicSymbolic{T}
mutable struct Sym{T} <: BasicSymbolic{T}
name::Symbol = :OOF
end
struct Term{T} <: BasicSymbolic{T}
mutable struct Term{T} <: BasicSymbolic{T}
f::Any = identity # base/num if Pow; issorted if Add/Dict
arguments::Vector{Any} = EMPTY_ARGS
hash::RefValue{UInt} = EMPTY_HASH
end
struct Mul{T} <: BasicSymbolic{T}
mutable struct Mul{T} <: BasicSymbolic{T}
coeff::Any = 0 # exp/den if Pow
dict::EMPTY_DICT_T = EMPTY_DICT
hash::RefValue{UInt} = EMPTY_HASH
arguments::Vector{Any} = EMPTY_ARGS
issorted::RefValue{Bool} = NOT_SORTED
end
struct Add{T} <: BasicSymbolic{T}
mutable struct Add{T} <: BasicSymbolic{T}
coeff::Any = 0 # exp/den if Pow
dict::EMPTY_DICT_T = EMPTY_DICT
hash::RefValue{UInt} = EMPTY_HASH
arguments::Vector{Any} = EMPTY_ARGS
issorted::RefValue{Bool} = NOT_SORTED
end
struct Div{T} <: BasicSymbolic{T}
mutable struct Div{T} <: BasicSymbolic{T}
num::Any = 1
den::Any = 1
simplified::Bool = false
arguments::Vector{Any} = EMPTY_ARGS
end
struct Pow{T} <: BasicSymbolic{T}
mutable struct Pow{T} <: BasicSymbolic{T}
base::Any = 1
exp::Any = 1
arguments::Vector{Any} = EMPTY_ARGS
Expand All @@ -77,6 +77,8 @@ function exprtype(x::BasicSymbolic)
end
end

const wvd = WeakValueDict{UInt, BasicSymbolic}()

# Same but different error messages
@noinline error_on_type() = error("Internal error: unreachable reached!")
@noinline error_sym() = error("Sym doesn't have a operation or arguments!")
Expand All @@ -92,7 +94,11 @@ const SIMPLIFIED = 0x01 << 0
function ConstructionBase.setproperties(obj::BasicSymbolic{T}, patch::NamedTuple)::BasicSymbolic{T} where T
nt = getproperties(obj)
nt_new = merge(nt, patch)
Unityper.rt_constructor(obj){T}(;nt_new...)
# Call outer constructor because hash consing cannot be applied in inner constructor
@compactified obj::BasicSymbolic begin
Sym => Sym{T}(nt_new.name; nt_new...)
_ => Unityper.rt_constructor(obj){T}(;nt_new...)
end
end

###
Expand Down Expand Up @@ -265,6 +271,26 @@ function _isequal(a, b, E)
end
end

"""
$(TYPEDSIGNATURES)

Checks for equality between two `BasicSymbolic` objects, considering both their
values and metadata.

The default `Base.isequal` function for `BasicSymbolic` only compares their expressions
and ignores metadata. This does not help deal with hash collisions when metadata is
relevant for distinguishing expressions, particularly in hashing contexts. This function
provides a stricter equality check that includes metadata comparison, preventing
such collisions.

Modifying `Base.isequal` directly breaks numerous tests in `SymbolicUtils.jl` and
downstream packages like `ModelingToolkit.jl`, hence the need for this separate
function.
"""
function isequal2(a::BasicSymbolic, b::BasicSymbolic)::Bool
bowenszhu marked this conversation as resolved.
Show resolved Hide resolved
isequal(a, b) && isequal(metadata(a), metadata(b))
end

Base.one( s::Symbolic) = one( symtype(s))
Base.zero(s::Symbolic) = zero(symtype(s))

Expand Down Expand Up @@ -307,12 +333,61 @@ function Base.hash(s::BasicSymbolic, salt::UInt)::UInt
end
end

"""
$(TYPEDSIGNATURES)

Calculates a hash value for a `BasicSymbolic` object, incorporating both its metadata and
symtype.

This function provides an alternative hashing strategy to `Base.hash` for `BasicSymbolic`
objects. Unlike `Base.hash`, which only considers the expression structure, `hash2` also
includes the metadata and symtype in the hash calculation. This can be beneficial for hash
consing, allowing for more effective deduplication of symbolically equivalent expressions
with different metadata or symtypes.
"""
hash2(s::BasicSymbolic) = hash2(s, zero(UInt))
function hash2(s::BasicSymbolic{T}, salt::UInt)::UInt where {T}
hash(metadata(s), hash(T, hash(s, salt)))
end

###
### Constructors
###

function Sym{T}(name::Symbol; kw...) where T
Sym{T}(; name=name, kw...)
"""
$(TYPEDSIGNATURES)

Implements hash consing (flyweight design pattern) for `BasicSymbolic` objects.

This function checks if an equivalent `BasicSymbolic` object already exists. It uses a
custom hash function (`hash2`) incorporating metadata and symtypes to search for existing
objects in a `WeakValueDict` (`wvd`). Due to the possibility of hash collisions (where
different objects produce the same hash), a custom equality check (`isequal2`) which
includes metadata comparison, is used to confirm the equivalence of objects with matching
hashes. If an equivalent object is found, the existing object is returned; otherwise, the
input `s` is returned. This reduces memory usage, improves compilation time for runtime
code generation, and supports built-in common subexpression elimination, particularly when
working with symbolic objects with metadata.

Using a `WeakValueDict` ensures that only weak references to `BasicSymbolic` objects are
stored, allowing objects that are no longer strongly referenced to be garbage collected.
Custom functions `hash2` and `isequal2` are used instead of `Base.hash` and `Base.isequal`
to accommodate metadata without disrupting existing tests reliant on the original behavior
of those functions.
"""
function BasicSymbolic(s::BasicSymbolic)::BasicSymbolic
h = hash2(s)
t = get!(wvd, h, s)
if t === s || isequal2(t, s)
return t
else
return s
end
end

function Sym{T}(name::Symbol; kw...) where {T}
s = Sym{T}(; name, kw...)
BasicSymbolic(s)
end

function Term{T}(f, args; kw...) where T
Expand Down
9 changes: 8 additions & 1 deletion test/basics.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using SymbolicUtils: Symbolic, Sym, FnType, Term, Add, Mul, Pow, symtype, operation, arguments, issym, isterm, BasicSymbolic, term
using SymbolicUtils: Symbolic, Sym, FnType, Term, Add, Mul, Pow, symtype, operation, arguments, issym, isterm, BasicSymbolic, term, isequal2
using SymbolicUtils
using IfElse: ifelse
using Setfield
Expand Down Expand Up @@ -336,6 +336,13 @@ end

@test !isequal(a, missing)
@test !isequal(missing, b)

a1 = setmetadata(a, Ctx1, "meta_1")
a2 = setmetadata(a, Ctx1, "meta_1")
a3 = setmetadata(a, Ctx2, "meta_2")
@test !isequal2(a, a1)
@test isequal2(a1, a2)
@test !isequal2(a1, a3)
end

@testset "subtyping" begin
Expand Down
26 changes: 26 additions & 0 deletions test/hash_consing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using SymbolicUtils, Test

struct Ctx1 end
struct Ctx2 end

@testset "Sym" begin
x1 = only(@syms x)
x2 = only(@syms x)
@test x1 === x2
x3 = only(@syms x::Float64)
@test x1 !== x3
x4 = only(@syms x::Float64)
@test x1 !== x4
@test x3 === x4
x5 = only(@syms x::Int)
x6 = only(@syms x::Int)
@test x1 !== x5
@test x3 !== x5
@test x5 === x6

xm1 = setmetadata(x1, Ctx1, "meta_1")
xm2 = setmetadata(x1, Ctx1, "meta_1")
@test xm1 === xm2
xm3 = setmetadata(x1, Ctx2, "meta_2")
@test xm1 !== xm3
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@ using Pkg, Test, SafeTestsets
# Disabled until https://github.com/JuliaMath/SpecialFunctions.jl/issues/446 is fixed
@safetestset "Fuzz" begin include("fuzz.jl") end
@safetestset "Adjoints" begin include("adjoints.jl") end
@safetestset "Hash Consing" begin include("hash_consing.jl") end
end
end
Loading