Skip to content

Commit

Permalink
Merge pull request #53 from Herb-AI/contains-subtree
Browse files Browse the repository at this point in the history
Contains Subtree Constraint
  • Loading branch information
ReubenJ authored May 14, 2024
2 parents 10248f5 + d20e354 commit be8a0b6
Show file tree
Hide file tree
Showing 17 changed files with 800 additions and 35 deletions.
25 changes: 17 additions & 8 deletions src/HerbConstraints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,24 +18,28 @@ abstract type AbstractGrammarConstraint <: AbstractConstraint end
abstract type AbstractLocalConstraint <: AbstractConstraint
Abstract type representing all local constraints.
Local constraints correspond to a specific (partial) [`AbstractRuleNode`](@ref) tree.
Each local constraint contains a `path` that points to a specific location in the tree.
The constraint is propagated on any tree manipulation at or below that `path`.
Each local constraint contains a `path` that points to a specific location in the tree at which the constraint applies.
Each local constraint should implement a [`propagate!`](@ref)-function.
Inside the [`propagate!`](@ref) function, the constraint can use the following solver functions:
- `remove!`: Elementary tree manipulation. Removes a value from a domain. (other tree manipulations are: `remove_above!`, `remove_below!`, `remove_all_but!`)
- `deactivate!`: Prevent repropagation. Call this as soon as the constraint is satisfied.
- `set_infeasible!`: Report a non-trivial inconsistency. Call this if the constraint can never be satisfied. An empty domain is considered a trivial inconsistency, such inconsistencies are already handled by tree manipulations.
- `isfeasible`: Check if the current tree is still feasible. Return from the propagate function, as soon as infeasibility is detected.
!!! warning
By default, [`AbstractLocalConstraint`](@ref)s are only propagated once.
Constraints that have to be propagated more frequently should subscribe to an event. This part of the solver is still WIP.
Currently, the solver supports only one type of subscription: `propagate_on_tree_manipulation!`.
"""
abstract type AbstractLocalConstraint <: AbstractConstraint end


"""
function get_priority(::AbstractLocalConstraint)
Used to determine which constraint to propagate first in [`fix_point!`](@ref).
Constraints with fast propagators and/or strong inference should be propagated first.
"""
function get_priority(::AbstractLocalConstraint)
return 0
end

include("csg_annotated/csg_annotated.jl")

include("varnode.jl")
Expand All @@ -57,16 +61,19 @@ include("solver/domainutils.jl")

include("patternmatch.jl")
include("lessthanorequal.jl")
include("makeequal.jl")

include("localconstraints/local_forbidden.jl")
include("localconstraints/local_ordered.jl")
include("localconstraints/local_contains.jl")
include("localconstraints/local_contains_subtree.jl")
include("localconstraints/local_forbidden_sequence.jl")
include("localconstraints/local_unique.jl")

include("grammarconstraints/forbidden.jl")
include("grammarconstraints/ordered.jl")
include("grammarconstraints/contains.jl")
include("grammarconstraints/contains_subtree.jl")
include("grammarconstraints/forbidden_sequence.jl")
include("grammarconstraints/unique.jl")

Expand All @@ -83,13 +90,15 @@ export
Forbidden,
Ordered,
Contains,
ContainsSubtree,
ForbiddenSequence,
Unique,

#local constraints
LocalForbidden,
LocalOrdered,
LocalContains,
LocalContainsSubtree,
LocalForbiddenSequence,
LocalUnique,

Expand Down
31 changes: 31 additions & 0 deletions src/grammarconstraints/contains_subtree.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""
ContainsSubtree <: AbstractGrammarConstraint
This [`AbstractGrammarConstraint`] enforces that a given `subtree` appears in the program tree at least once.
!!! warning:
This constraint can only be propagated by the UniformSolver
"""
struct ContainsSubtree <: AbstractGrammarConstraint
tree::AbstractRuleNode
end

function on_new_node(solver::UniformSolver, c::ContainsSubtree, path::Vector{Int})
if length(path) == 0
post!(solver, LocalContainsSubtree(path, c.tree, nothing, nothing))
end
end

function on_new_node(::GenericSolver, ::ContainsSubtree, ::Vector{Int}) end

Check warning on line 19 in src/grammarconstraints/contains_subtree.jl

View check run for this annotation

Codecov / codecov/patch

src/grammarconstraints/contains_subtree.jl#L19

Added line #L19 was not covered by tests

"""
check_tree(c::ContainsSubtree, tree::AbstractRuleNode)::Bool
Checks if the given [`AbstractRuleNode`](@ref) tree abides the [`ContainsSubtree`](@ref) constraint.
"""
function check_tree(c::ContainsSubtree, tree::AbstractRuleNode)::Bool
if pattern_match(c.tree, tree) isa PatternMatchSuccess
return true
end
return any(check_tree(c, child) for child get_children(tree))
end
8 changes: 5 additions & 3 deletions src/lessthanorequal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,6 @@ function make_less_than_or_equal!(
guards::Vector{Tuple{AbstractHole, Int}}
)::LessThanOrEqualResult
@assert isfeasible(solver)
path1 = get_path(get_tree(solver), hole1)
path2 = get_path(get_tree(solver), hole2)
@match (isfilled(hole1), isfilled(hole2)) begin
(true, true) => begin
#(RuleNode, RuleNode)
Expand All @@ -119,6 +117,7 @@ function make_less_than_or_equal!(
return LessThanOrEqualSoftFail(hole2)
end
end
path2 = get_path(solver, hole2)
remove_below!(solver, path2, get_rule(hole1))
if !isfeasible(solver)
return LessThanOrEqualHardFail()
Expand Down Expand Up @@ -151,6 +150,7 @@ function make_less_than_or_equal!(
return LessThanOrEqualSoftFail(hole1)
end
end
path1 = get_path(solver, hole1)
remove_above!(solver, path1, get_rule(hole2))
if !isfeasible(solver)
return LessThanOrEqualHardFail()
Expand Down Expand Up @@ -183,6 +183,8 @@ function make_less_than_or_equal!(
return LessThanOrEqualSoftFail(hole1, hole2)
end
end
path1 = get_path(solver, hole1)
path2 = get_path(solver, hole2)
# Example:
# Before: {2, 3, 5} <= {1, 3, 4}
# After: {2, 3} <= {3, 4}
Expand Down Expand Up @@ -249,7 +251,7 @@ function make_less_than_or_equal!(
return result

Check warning on line 251 in src/lessthanorequal.jl

View check run for this annotation

Codecov / codecov/patch

src/lessthanorequal.jl#L251

Added line #L251 was not covered by tests
elseif length(guards) == 1
# a single guard is involved, preventing equality on the guard prevents the hardfail on the tiebreak
path = get_path(get_tree(solver), guards[1][1])
path = get_path(solver, guards[1][1])
remove!(solver, path, guards[1][2])
return LessThanOrEqualSuccessLessThan()
else
Expand Down
132 changes: 132 additions & 0 deletions src/localconstraints/local_contains_subtree.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@

"""
LocalContains
Enforces that a given `tree` appears at or below the given `path` at least once.
!!! warning:
This is a stateful constraint can only be propagated by the UniformSolver.
The `indices` and `candidates` fields should not be set by the user.
"""
mutable struct LocalContainsSubtree <: AbstractLocalConstraint
path::Vector{Int}
tree::AbstractRuleNode
candidates::Union{Vector{AbstractRuleNode}, Nothing}
indices::Union{StateSparseSet, Nothing}
end

"""
LocalContainsSubtree(path::Vector{Int}, tree::AbstractRuleNode)
Enforces that a given `tree` appears at or below the given `path` at least once.
"""
function LocalContainsSubtree(path::Vector{Int}, tree::AbstractRuleNode)
LocalContainsSubtree(path, tree, Vector{AbstractRuleNode}(), nothing)

Check warning on line 24 in src/localconstraints/local_contains_subtree.jl

View check run for this annotation

Codecov / codecov/patch

src/localconstraints/local_contains_subtree.jl#L23-L24

Added lines #L23 - L24 were not covered by tests
end


"""
function propagate!(::GenericSolver, ::LocalContainsSubtree)
!!! warning:
LocalContainsSubtree uses stateful properties and can therefore not be propagated in the GenericSolver.
(The GenericSolver shares constraints among different states, so they cannot use stateful properties)
"""
function propagate!(::GenericSolver, ::LocalContainsSubtree)
throw(ArgumentError("LocalContainsSubtree cannot be propagated by the GenericSolver"))

Check warning on line 36 in src/localconstraints/local_contains_subtree.jl

View check run for this annotation

Codecov / codecov/patch

src/localconstraints/local_contains_subtree.jl#L35-L36

Added lines #L35 - L36 were not covered by tests
end


"""
function propagate!(solver::UniformSolver, c::LocalContainsSubtree)
Enforce that the `tree` appears at or below the `path` at least once.
Nodes that can potentially become the target sub-tree are considered `candidates`.
In case of multiple candidates, a stateful set of `indices` is used to keep track of active candidates.
"""
function propagate!(solver::UniformSolver, c::LocalContainsSubtree)
track!(solver, "LocalContainsSubtree propagation")
if isnothing(c.candidates)
# Initial propagation: pattern match all nodes, only store the candidates for re-propagation
c.candidates = Vector{AbstractRuleNode}()
for node get_nodes(solver)
@match pattern_match(c.tree, node) begin
::PatternMatchHardFail => ()
::PatternMatchSuccess => begin
track!(solver, "LocalContainsSubtree satisfied (initial propagation)")
deactivate!(solver, c);

Check warning on line 57 in src/localconstraints/local_contains_subtree.jl

View check run for this annotation

Codecov / codecov/patch

src/localconstraints/local_contains_subtree.jl#L55-L57

Added lines #L55 - L57 were not covered by tests
return
end
::PatternMatchSoftFail || ::PatternMatchSuccessWhenHoleAssignedTo => push!(c.candidates, node)
end
end
n = length(c.candidates)
if n == 0
track!(solver, "LocalContainsSubtree inconsistency (initial propagation)")
set_infeasible!(solver)
return
elseif n == 1
@match make_equal!(solver, c.candidates[1], c.tree) begin
::MakeEqualHardFail => begin
@assert false "pattern_match failed to detect a hardfail"

Check warning on line 71 in src/localconstraints/local_contains_subtree.jl

View check run for this annotation

Codecov / codecov/patch

src/localconstraints/local_contains_subtree.jl#L70-L71

Added lines #L70 - L71 were not covered by tests
end
::MakeEqualSuccess => begin

Check warning on line 73 in src/localconstraints/local_contains_subtree.jl

View check run for this annotation

Codecov / codecov/patch

src/localconstraints/local_contains_subtree.jl#L73

Added line #L73 was not covered by tests
track!(solver, "LocalContainsSubtree deduction (initial)")
deactivate!(solver, c);
return
end
::MakeEqualSoftFail => begin

Check warning on line 78 in src/localconstraints/local_contains_subtree.jl

View check run for this annotation

Codecov / codecov/patch

src/localconstraints/local_contains_subtree.jl#L78

Added line #L78 was not covered by tests
track!(solver, "LocalContainsSubtree softfail (1 candidate) (initial)")
return
end
end
else
track!(solver, "LocalContainsSubtree softfail (>=2 candidates) (initial)")
c.indices = StateSparseSet(solver.sm, n)
return
end
else
# Re-propagation
if !isnothing(c.indices) && (size(c.indices) >= 2)
# Update the candidates by pattern matching them again
for i c.indices
match = pattern_match(c.candidates[i], c.tree)
@match match begin
::PatternMatchHardFail => remove!(c.indices, i)
::PatternMatchSuccess => begin
track!(solver, "LocalContainsSubtree satisfied")
deactivate!(solver, c);

Check warning on line 98 in src/localconstraints/local_contains_subtree.jl

View check run for this annotation

Codecov / codecov/patch

src/localconstraints/local_contains_subtree.jl#L96-L98

Added lines #L96 - L98 were not covered by tests
return
end
::PatternMatchSoftFail || ::PatternMatchSuccessWhenHoleAssignedTo => ()
end
end
end
n = isnothing(c.indices) ? 1 : size(c.indices)
if n == 1
# If there is a single candidate remaining, set it equal to the target tree
index = isnothing(c.indices) ? 1 : findfirst(c.indices)
@match make_equal!(solver, c.candidates[index], c.tree) begin
::MakeEqualHardFail => begin
track!(solver, "LocalContainsSubtree inconsistency")
set_infeasible!(solver)

Check warning on line 112 in src/localconstraints/local_contains_subtree.jl

View check run for this annotation

Codecov / codecov/patch

src/localconstraints/local_contains_subtree.jl#L110-L112

Added lines #L110 - L112 were not covered by tests
return
end
::MakeEqualSuccess => begin

Check warning on line 115 in src/localconstraints/local_contains_subtree.jl

View check run for this annotation

Codecov / codecov/patch

src/localconstraints/local_contains_subtree.jl#L115

Added line #L115 was not covered by tests
track!(solver, "LocalContainsSubtree deduction")
deactivate!(solver, c);
return
end
::MakeEqualSoftFail => begin
track!(solver, "LocalContainsSubtree softfail (1 candidate)")

Check warning on line 121 in src/localconstraints/local_contains_subtree.jl

View check run for this annotation

Codecov / codecov/patch

src/localconstraints/local_contains_subtree.jl#L120-L121

Added lines #L120 - L121 were not covered by tests
return
end
end
elseif n == 0
track!(solver, "LocalContainsSubtree inconsistency")
set_infeasible!(solver)
return

Check warning on line 128 in src/localconstraints/local_contains_subtree.jl

View check run for this annotation

Codecov / codecov/patch

src/localconstraints/local_contains_subtree.jl#L125-L128

Added lines #L125 - L128 were not covered by tests
end
track!(solver, "LocalContainsSubtree softfail (>=2 candidates)")

Check warning on line 130 in src/localconstraints/local_contains_subtree.jl

View check run for this annotation

Codecov / codecov/patch

src/localconstraints/local_contains_subtree.jl#L130

Added line #L130 was not covered by tests
end
end
7 changes: 7 additions & 0 deletions src/localconstraints/local_forbidden_sequence.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,18 @@ function propagate!(solver::Solver, c::LocalForbiddenSequence)
if (node isa RuleNode) || (node isa StateHole && isfilled(node))
rule = get_rule(node)
if (rule c.ignore_if)
#softfail (ignore if)
return

Check warning on line 97 in src/localconstraints/local_forbidden_sequence.jl

View check run for this annotation

Codecov / codecov/patch

src/localconstraints/local_forbidden_sequence.jl#L97

Added line #L97 was not covered by tests
elseif (rule == forbidden_rule)
i -= 1
end
elseif isnothing(forbidden_assignment)
for r c.ignore_if
if node.domain[r]
#softfail (ignore if)
return
end
end

Check warning on line 107 in src/localconstraints/local_forbidden_sequence.jl

View check run for this annotation

Codecov / codecov/patch

src/localconstraints/local_forbidden_sequence.jl#L107

Added line #L107 was not covered by tests
forbidden_assignment = (path_idx, forbidden_rule)
i -= 1
end
Expand Down
47 changes: 39 additions & 8 deletions src/localconstraints/local_unique.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,16 @@
LocalUnique <: AbstractLocalConstraint
Enforces that a given `rule` appears at or below the given `path` at most once.
In case of the UniformSolver, cache the list of `holes`, since no new holes can appear.
"""
struct LocalUnique <: AbstractLocalConstraint
path::Vector{Int}
rule::Int
holes::Vector{AbstractHole}
end

LocalUnique(path::Vector{Int}, rule::Int) = LocalUnique(path, rule, Vector{AbstractHole}())

"""
function propagate!(solver::Solver, c::LocalUnique)
Expand All @@ -17,25 +21,30 @@ Uses a helper function to retrieve a list of holes that can potentially hold the
If there is only a single hole that can potentially hold the target rule, that hole will be filled with that rule.
"""
function propagate!(solver::Solver, c::LocalUnique)
node = get_node_at_location(solver, c.path)
holes = Vector{AbstractHole}()
count = _count_occurrences!(node, c.rule, holes)
track!(solver, "LocalUnique propagation")
if (solver isa GenericSolver) | isempty(c.holes)
empty!(c.holes)
node = get_node_at_location(solver, c.path)
count = _count_occurrences!(node, c.rule, c.holes)
else
#only search for the target rule in the cached list of holes
count = _count_occurrences(c.holes, c.rule)

Check warning on line 31 in src/localconstraints/local_unique.jl

View check run for this annotation

Codecov / codecov/patch

src/localconstraints/local_unique.jl#L31

Added line #L31 was not covered by tests
end
if count >= 2
set_infeasible!(solver)
track!(solver, "LocalUnique inconsistency")
elseif count == 1
if all(isuniform(hole) for hole holes)
if all(isuniform(hole) for hole c.holes)
track!(solver, "LocalUnique deactivate")
deactivate!(solver, c)

Check warning on line 39 in src/localconstraints/local_unique.jl

View check run for this annotation

Codecov / codecov/patch

src/localconstraints/local_unique.jl#L38-L39

Added lines #L38 - L39 were not covered by tests
end
for hole holes
for hole c.holes
deductions = 0
if hole.domain[c.rule] == true
path = get_path(get_tree(solver), hole)
if (hole.domain[c.rule] == true) && !isfilled(hole)
path = get_path(solver, hole)
remove!(solver, path, c.rule)
deductions += 1
track!(solver, "LocalUnique deduction ($(deductions))")
track!(solver, "LocalUnique deduction")
end
end
end
Expand Down Expand Up @@ -76,3 +85,25 @@ function _count_occurrences!(node::AbstractRuleNode, rule::Int, holes::Vector{Ab
end
return count
end

"""
function _count_occurrences(holes::Vector{AbstractHole}, rule::Int)
Counts the occurences of the `rule` in the cached list of `holes`.
!!! warning:
Stops counting if the rule occurs more than once.
Counting beyond 2 is not needed for LocalUnique.
"""
function _count_occurrences(holes::Vector{AbstractHole}, rule::Int)
count = 0
for hole holes
if isfilled(hole) && get_rule(hole) == rule
count += 1
if count >= 2
break

Check warning on line 104 in src/localconstraints/local_unique.jl

View check run for this annotation

Codecov / codecov/patch

src/localconstraints/local_unique.jl#L98-L104

Added lines #L98 - L104 were not covered by tests
end
end
end
count

Check warning on line 108 in src/localconstraints/local_unique.jl

View check run for this annotation

Codecov / codecov/patch

src/localconstraints/local_unique.jl#L107-L108

Added lines #L107 - L108 were not covered by tests
end
Loading

0 comments on commit be8a0b6

Please sign in to comment.