From 867ecb358ffe0fe6533008906ecce4287a7f4bbb Mon Sep 17 00:00:00 2001 From: Reuben Gardos Reid <5456207+ReubenJ@users.noreply.github.com> Date: Tue, 26 Nov 2024 17:53:40 +0300 Subject: [PATCH] Re-add empty-child RuleNode and UniformHole constructors Use functions instead of constructors because otherwise this constitutes type-piracy. --- src/HerbGrammar.jl | 8 +++- src/rulenode_operators.jl | 91 ++++++++++++++++++++++++++++++++++----- 2 files changed, 86 insertions(+), 13 deletions(-) diff --git a/src/HerbGrammar.jl b/src/HerbGrammar.jl index ed42aec..a85d16a 100644 --- a/src/HerbGrammar.jl +++ b/src/HerbGrammar.jl @@ -7,13 +7,14 @@ using Serialization # grammar_io using HerbCore include("grammar_base.jl") -include("rulenode_operators.jl") include("utils.jl") include("nodelocation.jl") include("csg/csg.jl") include("csg/probabilistic_csg.jl") +include("rulenode_operators.jl") + include("grammar_io.jl") export @@ -67,6 +68,9 @@ export read_pcsg, add_rule!, remove_rule!, - cleanup_removed_rules! + cleanup_removed_rules!, + holes_from_child_types, + rulenode_with_empty_children, + uniform_hole_with_empty_children end # module HerbGrammar diff --git a/src/rulenode_operators.jl b/src/rulenode_operators.jl index 4581e30..34f58de 100644 --- a/src/rulenode_operators.jl +++ b/src/rulenode_operators.jl @@ -1,3 +1,72 @@ +""" + holes_from_child_types(index::Integer, grammar::ContextSensitiveGrammar) + +Given the `index` of a rule in a `grammar`, create a vector of [`Hole`](@ref)s +corresponding to the children of the rule at `index`. +""" +function holes_from_child_types(index::Integer, grammar::ContextSensitiveGrammar) + return [Hole(get_domain(grammar, type)) for type in grammar.childtypes[index]] +end + +""" + RuleNode(ind::Int, grammar::ContextSensitiveGrammar) + +Create a [`RuleNode`](@ref) with [`Hole`](@ref)s as children. The holes +are initialized with the types of the children of the rule at `ind`. + +# Examples +```jldoctest +julia> g = @csgrammar begin + A = 1 | 2 | 3 + B = A + A + end +1: A = 1 +2: A = 2 +3: A = 3 +4: B = A + A + + +julia> rulenode_with_empty_children(4, g) +4{hole[Bool[1, 1, 1, 0]],hole[Bool[1, 1, 1, 0]]} +``` +""" +function rulenode_with_empty_children(ind::Int, _val::Union{Any,Nothing}, grammar::ContextSensitiveGrammar) + child_holes = holes_from_child_types(ind, grammar) + return RuleNode(ind, _val, child_holes) +end + +function rulenode_with_empty_children(ind::Int, grammar::ContextSensitiveGrammar) + return rulenode_with_empty_children(ind, nothing, grammar) +end + +""" + uniform_hole_with_empty_children(domain::BitVector, grammar::AbstractGrammar) + +Create a [`UniformHole`](@ref) with [`Hole`](@ref)s as children. The holes +are initialized with the types of the children of the rule at `ind`. + +# Examples +```jldoctest +julia> g = @csgrammar begin + A = 1 | 2 | 3 + B = (A + A) | (A - A) + end +1: A = 1 +2: A = 2 +3: A = 3 +4: B = A + A +5: B = A - A + + +julia> uniform_hole_with_empty_children(BitVector([0, 0, 0, 1, 1]), g) +fshole[Bool[0, 0, 0, 1, 1]]{hole[Bool[1, 1, 1, 0, 0]],hole[Bool[1, 1, 1, 0, 0]]} +``` +""" +function uniform_hole_with_empty_children(domain::BitVector, grammar::AbstractGrammar) + child_holes = holes_from_child_types(findfirst(domain), grammar) + return UniformHole(domain, child_holes) +end + rulesoftype(::Hole, ::Set{Int}) = Set{Int}() """ @@ -74,7 +143,7 @@ end Replace child `i` of a node, a part of larger `expr`, with `new_expr`. """ function swap_node(expr::RuleNode, node::RuleNode, child_index::Int, new_expr::RuleNode) - if expr == node + if expr == node node.children[child_index] = new_expr else for child ∈ expr.children @@ -125,14 +194,14 @@ function rulesonleft(node::RuleNode, path::Vector{Int})::Set{Int} for ch in node.children union!(ruleset, rulesonleft(ch, Vector{Int}())) end - return ruleset + return ruleset elseif length(path) == 1 # if there is only one element left in the path, collect all children except the one indicated in the path ruleset = Set{Int}(get_rule(node)) for i in 1:path[begin]-1 union!(ruleset, rulesonleft(node.children[i], Vector{Int}())) end - return ruleset + return ruleset else # collect all subtrees up to the child indexed in the path ruleset = Set{Int}(get_rule(node)) @@ -140,7 +209,7 @@ function rulesonleft(node::RuleNode, path::Vector{Int})::Set{Int} union!(ruleset, rulesonleft(node.children[i], Vector{Int}())) end union!(ruleset, rulesonleft(node.children[path[begin]], path[2:end])) - return ruleset + return ruleset end end @@ -159,7 +228,7 @@ function rulenode2expr(rulenode::AbstractRuleNode, grammar::AbstractGrammar) end root = deepcopy(grammar.rules[get_rule(rulenode)]) if !grammar.isterminal[get_rule(rulenode)] # not terminal - root,_ = _rulenode2expr(root, rulenode, grammar) + root, _ = _rulenode2expr(root, rulenode, grammar) end return root end @@ -172,15 +241,15 @@ end function _rulenode2expr(expr::Expr, rulenode::AbstractRuleNode, grammar::AbstractGrammar, j=0) if isfilled(rulenode) - for (k,arg) in enumerate(expr.args) + for (k, arg) in enumerate(expr.args) if isa(arg, Expr) - expr.args[k],j = _rulenode2expr(arg, rulenode, grammar, j) + expr.args[k], j = _rulenode2expr(arg, rulenode, grammar, j) elseif haskey(grammar.bytype, arg) child = rulenode.children[j+=1] if isfilled(child) expr.args[k] = deepcopy(grammar.rules[get_rule(child)]) if !isterminal(grammar, child) - expr.args[k],_ = _rulenode2expr(expr.args[k], child, grammar, 0) + expr.args[k], _ = _rulenode2expr(expr.args[k], child, grammar, 0) end else expr.args[k] = _get_hole_type(child, grammar) @@ -199,7 +268,7 @@ function _rulenode2expr(typ::Symbol, rulenode::AbstractRuleNode, grammar::Abstra child = rulenode.children[1] retval = deepcopy(grammar.rules[get_rule(child)]) if !grammar.isterminal[get_rule(child)] - retval,_ = _rulenode2expr(retval, child, grammar, 0) + retval, _ = _rulenode2expr(retval, child, grammar, 0) end end retval, j @@ -222,7 +291,7 @@ rulenode_log_probability(::Hole, ::AbstractGrammar) = 1 Returns true if the expression represented by the [`RuleNode`](@ref) is a complete expression, meaning that it is fully defined and doesn't have any [`Hole`](@ref)s. """ -function iscomplete(grammar::AbstractGrammar, node::RuleNode) +function iscomplete(grammar::AbstractGrammar, node::RuleNode) if isterminal(grammar, node) return true elseif isempty(node.children) @@ -333,7 +402,7 @@ function contains_returntype(node::RuleNode, grammar::AbstractGrammar, sym::Symb return true end for c in node.children - if contains_returntype(c, grammar, sym, maxdepth-1) + if contains_returntype(c, grammar, sym, maxdepth - 1) return true end end