Skip to content

Commit

Permalink
Add support for DomainRuleNode in make_equal!
Browse files Browse the repository at this point in the history
  • Loading branch information
Whebon committed Jun 4, 2024
1 parent ae30b63 commit 8228f13
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 5 deletions.
42 changes: 37 additions & 5 deletions src/makeequal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ Tree manipulation that enforces `node1` == `node2` if unambiguous.
function make_equal!(
solver::Solver,
hole1::Union{RuleNode, AbstractHole},
hole2::Union{RuleNode, AbstractHole}
hole2::Union{RuleNode, AbstractHole, DomainRuleNode}
)::MakeEqualResult
make_equal!(solver, hole1, hole2, Dict{Symbol, AbstractRuleNode}())
end
Expand Down Expand Up @@ -128,9 +128,41 @@ function make_equal!(
end

function make_equal!(
::Solver,
::Union{RuleNode, AbstractHole},
::DomainRuleNode
solver::Solver,
node::Union{RuleNode, AbstractHole},
domainrulenode::DomainRuleNode,
vars::Dict{Symbol, AbstractRuleNode}
)::MakeEqualResult
throw("NotImplementedException: DomainRuleNodes are not yet support in make_equal!")
softfailed = false
if isfilled(node)
#(RuleNode, DomainRuleNode)
if !domainrulenode.domain[get_rule(node)]
set_infeasible!(solver)
return MakeEqualHardFail()

Check warning on line 141 in src/makeequal.jl

View check run for this annotation

Codecov / codecov/patch

src/makeequal.jl#L139-L141

Added lines #L139 - L141 were not covered by tests
end
else
#(AbstractHole, DomainRuleNode)
rules = get_intersection(node.domain, domainrulenode.domain)
if length(rules) == 0
return MakeEqualHardFail()

Check warning on line 147 in src/makeequal.jl

View check run for this annotation

Codecov / codecov/patch

src/makeequal.jl#L147

Added line #L147 was not covered by tests
elseif length(rules) == 1
path = get_path(solver, node)
remove_all_but!(solver, path, rules[1])
node = get_node_at_location(solver, path)
else
softfailed = true
end
end

for (child1, child2) zip(get_children(node), get_children(domainrulenode))
result = make_equal!(solver, child1, child2, vars)
@match result begin
::MakeEqualSuccess => ();
::MakeEqualHardFail => return result;
::MakeEqualSoftFail => begin

Check warning on line 162 in src/makeequal.jl

View check run for this annotation

Codecov / codecov/patch

src/makeequal.jl#L162

Added line #L162 was not covered by tests
softfailed = true
end
end
end
return softfailed ? MakeEqualSoftFail() : MakeEqualSuccess()
end
73 changes: 73 additions & 0 deletions test/test_contains_subtree.jl
Original file line number Diff line number Diff line change
Expand Up @@ -208,4 +208,77 @@
end
end
end

@testset "DomainRuleNode" begin
tests = [
(
"SoftFail large domain",
BitVector((0, 0, 0, 1, 1, 1)), # domain_root
BitVector((0, 0, 0, 1, 1, 1)), # domain_root_target

BitVector((1, 1, 1, 0, 0, 0)), # domain_leaf
BitVector((1, 1, 1, 0, 0, 0)), # domain_leaf_target
),
(
"SoftFail small domain",
BitVector((0, 0, 0, 1, 1, 0)), # domain_root
BitVector((0, 0, 0, 1, 1, 0)), # domain_root_target

BitVector((1, 1, 0, 0, 0, 0)), # domain_leaf
BitVector((1, 1, 0, 0, 0, 0)), # domain_leaf_target
),
(
"Deduction in Root",
BitVector((0, 0, 0, 1, 0, 1)), # domain_root
BitVector((0, 0, 0, 1, 0, 0)), # domain_root_target

BitVector((1, 1, 0, 0, 0, 0)), # domain_leaf
BitVector((1, 1, 0, 0, 0, 0)), # domain_leaf_target
),
(
"Deduction in Leaf",
BitVector((0, 0, 0, 1, 1, 0)), # domain_root
BitVector((0, 0, 0, 1, 1, 0)), # domain_root_target

BitVector((0, 1, 1, 0, 0, 0)), # domain_leaf
BitVector((0, 1, 0, 0, 0, 0)), # domain_leaf_target
),
(
"Deduction in Root and Leaf",
BitVector((0, 0, 0, 1, 0, 1)), # domain_root
BitVector((0, 0, 0, 1, 0, 0)), # domain_root_target

BitVector((0, 1, 1, 0, 0, 0)), # domain_leaf
BitVector((0, 1, 0, 0, 0, 0)), # domain_leaf_target
)
]

@testset "$name" for (name, domain_root, domain_root_target, domain_leaf, domain_leaf_target) tests
grammar = @csgrammar begin
S = 1
S = 2
S = 3
S = 4, S
S = 5, S
S = 6, S
end

# must contain at least rule 4 or 5 in the root.
# must contain at least rule 1 or 2 in the leaf.
addconstraint!(grammar, ContainsSubtree(DomainRuleNode(grammar, [4, 5], [
DomainRuleNode(grammar, [1, 2])
])))

tree = UniformHole(domain_root, [
UniformHole(domain_leaf, [])
])
solver = UniformSolver(grammar, tree)
tree = get_tree(solver)

for rule 1:6
@test domain_root_target[rule] == tree.domain[rule]
@test domain_leaf_target[rule] == tree.children[1].domain[rule]
end
end
end
end

0 comments on commit 8228f13

Please sign in to comment.