diff --git a/src/probe/guided_search_iterator.jl b/src/probe/guided_search_iterator.jl index 2d18262..03700cb 100644 --- a/src/probe/guided_search_iterator.jl +++ b/src/probe/guided_search_iterator.jl @@ -4,7 +4,7 @@ symboltable::SymbolTable ) -@kwdef mutable struct GuidedSearchState +Base.@kwdef mutable struct GuidedSearchState level::Int64 bank::Vector{Vector{RuleNode}} eval_cache::Set @@ -17,19 +17,20 @@ function Base.iterate(iter::GuidedSearchIterator) level=-1, bank=[], eval_cache=Set(), - iter=NewProgramsIterator(0, [], iter.grammar), + iter=NewProgramsIterator(0, [], get_grammar(iter.solver)), next_iter=nothing )) end function Base.iterate(iter::GuidedSearchIterator, state::GuidedSearchState)::Union{Tuple{RuleNode, GuidedSearchState}, Nothing} + grammar = get_grammar(iter.solver) # wrap in while true to optimize for tail call while true while state.next_iter === nothing state.level += 1 push!(state.bank, []) - state.iter = NewProgramsIterator(state.level, state.bank, iter.grammar) + state.iter = NewProgramsIterator(state.level, state.bank, grammar) state.next_iter = iterate(state.iter) if state.level > 0 @info ("Finished level $(state.level - 1) with $(length(state.bank[state.level])) programs") @@ -45,7 +46,7 @@ function Base.iterate(iter::GuidedSearchIterator, state::GuidedSearchState)::Uni # evaluate program eval_observation = [] - expr = rulenode2expr(prog, iter.grammar) + expr = rulenode2expr(prog, grammar) for example ∈ iter.spec output = execute_on_input(iter.symboltable, expr, example.in) push!(eval_observation, output) diff --git a/src/probe/new_program_iterator.jl b/src/probe/new_program_iterator.jl index e4a95df..2c11e71 100644 --- a/src/probe/new_program_iterator.jl +++ b/src/probe/new_program_iterator.jl @@ -67,3 +67,24 @@ function Base.iterate(iter::NewProgramsIterator, state::NewProgramsState) end return nothing end + +function calculate_rule_cost_prob(rule_index, grammar, log_base = 2) + log_prob = grammar.log_probabilities[rule_index] / log(log_base) + return convert(Int64, round(-log_prob)) +end + +function calculate_rule_cost_size(rule_index, grammar) + return 1 +end + +calculate_rule_cost(rule_index::Int, grammar::ContextSensitiveGrammar) = calculate_rule_cost_size(rule_index, grammar) + +""" + calculate_program_cost(program::RuleNode, grammar::ContextSensitiveGrammar) +Calculates the cost of a program by summing up the cost of the children and the cost of the rule +""" +function calculate_program_cost(program::RuleNode, grammar::ContextSensitiveGrammar) + cost_children = sum([calculate_program_cost(child, grammar) for child ∈ program.children], init=0) + cost_rule = calculate_rule_cost(program.ind, grammar) + return cost_children + cost_rule +end \ No newline at end of file diff --git a/src/probe/probe_iterator.jl b/src/probe/probe_iterator.jl index d27674f..523e8ba 100644 --- a/src/probe/probe_iterator.jl +++ b/src/probe/probe_iterator.jl @@ -1,6 +1,3 @@ -include("sum_iterator.jl") -include("new_program_iterator.jl") -include("guided_search_iterator.jl") """ struct ProgramCache @@ -17,15 +14,28 @@ function Base.:(==)(a::ProgramCache, b::ProgramCache) end Base.hash(a::ProgramCache) = hash(a.program) -select(partial_sols::Vector{ProgramCache}, all_selected_psols::Set{ProgramCache}) = HerbSearch.selectpsol_largest_subset(partial_sols, all_selected_psols) -update!(grammar::ContextSensitiveGrammar, PSols_with_eval_cache::Vector{ProgramCache}, examples::Vector{<:IOExample}) = update_grammar(grammar, PSols_with_eval_cache, examples) +include("sum_iterator.jl") +include("new_program_iterator.jl") +include("guided_search_iterator.jl") + +include("select_partial_sols.jl") +include("update_grammar.jl") + +select_partial_solution(partial_sols::Vector{ProgramCache}, all_selected_psols::Set{ProgramCache}) = HerbSearch.selectpsol_largest_subset(partial_sols, all_selected_psols) +update_grammar!(grammar::ContextSensitiveGrammar, PSols_with_eval_cache::Vector{ProgramCache}, examples::Vector{<:IOExample}) = update_grammar(grammar, PSols_with_eval_cache, examples) +""" + probe(examples::Vector{<:IOExample}, iterator::ProgramIterator, max_time::Int, iteration_size::Int) + +Probe for a solution using the given `iterator` and `examples` with a time limit of `max_time` and `iteration_size`. +""" function probe(examples::Vector{<:IOExample}, iterator::ProgramIterator, max_time::Int, iteration_size::Int) start_time = time() # store a set of all the results of evaluation programs eval_cache = Set() state = nothing - symboltable = SymbolTable(iterator.grammar) + grammar = get_grammar(iterator.solver) + symboltable = SymbolTable(grammar) # all partial solutions that were found so far all_selected_psols = Set{ProgramCache}() # start next iteration while there is time left @@ -40,7 +50,7 @@ function probe(examples::Vector{<:IOExample}, iterator::ProgramIterator, max_tim # evaluate program eval_observation = [] correct_examples = Vector{Int}() - expr = rulenode2expr(program, iterator.grammar) + expr = rulenode2expr(program, grammar) for (example_index, example) ∈ enumerate(examples) output = execute_on_input(symboltable, expr, example.in) push!(eval_observation, output) @@ -52,13 +62,13 @@ function probe(examples::Vector{<:IOExample}, iterator::ProgramIterator, max_tim nr_correct_examples = length(correct_examples) if nr_correct_examples == length(examples) # found solution - println("Last level: $(length(state.bank[state.level + 1])) programs") + @info "Last level: $(length(state.bank[state.level + 1])) programs" return program elseif eval_observation in eval_cache # result already in cache next = iterate(iterator, state) continue elseif nr_correct_examples >= 1 # partial solution - program_cost = calculate_program_cost(program, iterator.grammar) + program_cost = calculate_program_cost(program, grammar) push!(psol_with_eval_cache, ProgramCache(program, correct_examples, program_cost)) end @@ -72,12 +82,11 @@ function probe(examples::Vector{<:IOExample}, iterator::ProgramIterator, max_tim if next === nothing return nothing end - partial_sols = filter(x -> x ∉ all_selected_psols, select(psol_with_eval_cache, all_selected_psols)) + partial_sols = filter(x -> x ∉ all_selected_psols, select_partial_solution(psol_with_eval_cache, all_selected_psols)) if !isempty(partial_sols) - print(rulenode2expr(partial_sols[1].program, iterator.grammar)) push!(all_selected_psols, partial_sols...) # update probabilites if any promising partial solutions - update!(iterator.grammar, partial_sols, examples) # update probabilites + update_grammar!(grammar, partial_sols, examples) # update probabilites # restart iterator eval_cache = Set() state = nothing @@ -85,7 +94,7 @@ function probe(examples::Vector{<:IOExample}, iterator::ProgramIterator, max_tim #for loop to update all_selected_psols with new costs for prog_with_cache ∈ all_selected_psols program = prog_with_cache.program - new_cost = calculate_program_cost(program, iterator.grammar) + new_cost = calculate_program_cost(program, grammar) prog_with_cache.cost = new_cost end end @@ -93,153 +102,3 @@ function probe(examples::Vector{<:IOExample}, iterator::ProgramIterator, max_tim return nothing end - -function update_grammar(grammar::ContextSensitiveGrammar, PSols_with_eval_cache::Vector{ProgramCache}, examples::Vector{<:IOExample}) - sum = 0 - for rule_index in eachindex(grammar.rules) # iterate for each rule_index - highest_correct_nr = 0 - for psol in PSols_with_eval_cache - program = psol.program - len_correct_examples = length(psol.correct_examples) - # check if the program tree has rule_index somewhere inside it using a recursive function - if contains_rule(program, rule_index) && len_correct_examples > highest_correct_nr - highest_correct_nr = len_correct_examples - end - end - fitnes = highest_correct_nr / length(examples) - p_uniform = 1 / length(grammar.rules) - - # compute (log2(p_u) ^ (1 - fit)) = (1-fit) * log2(p_u) - sum += p_uniform^(1 - fitnes) - log_prob = ((1 - fitnes) * log(2, p_uniform)) - grammar.log_probabilities[rule_index] = log_prob - end - total_sum = 0 - for rule_index in eachindex(grammar.rules) - grammar.log_probabilities[rule_index] = grammar.log_probabilities[rule_index] - log(2, sum) - total_sum += 2^(grammar.log_probabilities[rule_index]) - end - @assert abs(total_sum - 1) <= 1e-4 "Total sum is $(total_sum) " -end - -""" - contains_rule(program::RuleNode, rule_index::Int) - -Check if a given `RuleNode` contains has used a derivation rule with the specified `rule_index` - -# Arguments -- `program::RuleNode`: The `RuleNode` to check. -- `rule_index::Int`: The index of the rule to check for. - -""" -function contains_rule(program::RuleNode, rule_index::Int) - if program.ind == rule_index # if the rule is good return true - return true - else - for child in program.children - if contains_rule(child, rule_index) # if a child has that rule then return true - return true - end - end - return false # if no child has that rule return false - end -end - - - -""" - selectpsol_largest_subset(partial_sols::Vector{ProgramCache}}, all_selected_psols::Set{ProgramCache})) - -This scheme selects a single cheapest program (first enumerated) that -satisfies the largest subset of examples encountered so far across all partial_sols. -""" -function selectpsol_largest_subset(partial_sols::Vector{ProgramCache}, all_selected_psols::Set{ProgramCache}) - if isempty(partial_sols) - return Vector{ProgramCache}() - end - push!(partial_sols, all_selected_psols...) - largest_subset_length = 0 - cost = typemax(Int) - best_sol = partial_sols[begin] - for psol in partial_sols - len = length(psol.correct_examples) - if len > largest_subset_length || len == largest_subset_length && psol.cost < cost - largest_subset_length = len - best_sol = psol - cost = psol.cost - end - end - return [best_sol] -end - -""" - selectpsol_first_cheapest(partial_sols::Vector{ProgramCache}}, all_selected_psols::Set{ProgramCache})) - -This scheme selects a single cheapest program (first enumerated) that -satisfies a unique subset of examples. -""" -function selectpsol_first_cheapest(partial_sols::Vector{ProgramCache}, all_selected_psol::Set{ProgramCache}) - # maps subset of examples to the cheapest program - mapping = Dict{Vector{Int},ProgramCache}() - for sol ∈ partial_sols - examples = sol.correct_examples - if !haskey(mapping, examples) - mapping[examples] = sol - else - # if the cost of the new program is less than the cost of the previous program with the same subset of examples replace it - if sol.cost < mapping[examples].cost - mapping[examples] = sol - end - end - end - # get the cheapest programs that satisfy unique subsets of examples - return collect(values(mapping)) -end - -""" - selectpsol_all_cheapest(partial_sols::Vector{ProgramCache}, all_selected_psol::Set{ProgramCache}) - -This scheme selects all cheapest programs that satisfies a unique subset of examples. -""" -function selectpsol_all_cheapest(partial_sols::Vector{ProgramCache}, all_selected_psol::Set{ProgramCache}) - # maps subset of examples to the cheapest program - mapping = Dict{Vector{Int},Vector{ProgramCache}}() - for sol ∈ partial_sols - examples = sol.correct_examples - if !haskey(mapping, examples) - mapping[examples] = [sol] - else - # if the cost of the new program is less than the cost of the first program - progs = mapping[examples] - if sol.cost < progs[begin].cost - mapping[examples] = [sol] - elseif sol.cost == progs[begin].cost - # append to the list of cheapest programs - push!(progs, sol) - end - end - end - # get all cheapest programs that satisfy unique subsets of examples - return collect(Iterators.flatten(values(mapping))) -end - -function calculate_rule_cost_prob(rule_index, grammar) - log_prob = grammar.log_probabilities[rule_index] - return convert(Int64, round(-log_prob)) -end - -function calculate_rule_cost_size(rule_index, grammar) - return 1 -end - -calculate_rule_cost(rule_index::Int, grammar::ContextSensitiveGrammar) = calculate_rule_cost_size(rule_index, grammar) - -""" - calculate_program_cost(program::RuleNode, grammar::ContextSensitiveGrammar) -Calculates the cost of a program by summing up the cost of the children and the cost of the rule -""" -function calculate_program_cost(program::RuleNode, grammar::ContextSensitiveGrammar) - cost_children = sum([calculate_program_cost(child, grammar) for child ∈ program.children], init=0) - cost_rule = calculate_rule_cost(program.ind, grammar) - return cost_children + cost_rule -end \ No newline at end of file diff --git a/src/probe/select_partial_sols.jl b/src/probe/select_partial_sols.jl new file mode 100644 index 0000000..a039f36 --- /dev/null +++ b/src/probe/select_partial_sols.jl @@ -0,0 +1,75 @@ +""" + selectpsol_largest_subset(partial_sols::Vector{ProgramCache}}, all_selected_psols::Set{ProgramCache})) + +This scheme selects a single cheapest program (first enumerated) that +satisfies the largest subset of examples encountered so far across all partial_sols. +""" +function selectpsol_largest_subset(partial_sols::Vector{ProgramCache}, all_selected_psols::Set{ProgramCache}) + if isempty(partial_sols) + return Vector{ProgramCache}() + end + push!(partial_sols, all_selected_psols...) + largest_subset_length = 0 + cost = typemax(Int) + best_sol = partial_sols[begin] + for psol in partial_sols + len = length(psol.correct_examples) + if len > largest_subset_length || len == largest_subset_length && psol.cost < cost + largest_subset_length = len + best_sol = psol + cost = psol.cost + end + end + return [best_sol] +end + +""" + selectpsol_first_cheapest(partial_sols::Vector{ProgramCache}}, all_selected_psols::Set{ProgramCache})) + +This scheme selects a single cheapest program (first enumerated) that +satisfies a unique subset of examples. +""" +function selectpsol_first_cheapest(partial_sols::Vector{ProgramCache}, all_selected_psol::Set{ProgramCache}) + # maps subset of examples to the cheapest program + mapping = Dict{Vector{Int},ProgramCache}() + for sol ∈ partial_sols + examples = sol.correct_examples + if !haskey(mapping, examples) + mapping[examples] = sol + else + # if the cost of the new program is less than the cost of the previous program with the same subset of examples replace it + if sol.cost < mapping[examples].cost + mapping[examples] = sol + end + end + end + # get the cheapest programs that satisfy unique subsets of examples + return collect(values(mapping)) +end + +""" + selectpsol_all_cheapest(partial_sols::Vector{ProgramCache}, all_selected_psol::Set{ProgramCache}) + +This scheme selects all cheapest programs that satisfies a unique subset of examples. +""" +function selectpsol_all_cheapest(partial_sols::Vector{ProgramCache}, all_selected_psol::Set{ProgramCache}) + # maps subset of examples to the cheapest program + mapping = Dict{Vector{Int},Vector{ProgramCache}}() + for sol ∈ partial_sols + examples = sol.correct_examples + if !haskey(mapping, examples) + mapping[examples] = [sol] + else + # if the cost of the new program is less than the cost of the first program + progs = mapping[examples] + if sol.cost < progs[begin].cost + mapping[examples] = [sol] + elseif sol.cost == progs[begin].cost + # append to the list of cheapest programs + push!(progs, sol) + end + end + end + # get all cheapest programs that satisfy unique subsets of examples + return collect(Iterators.flatten(values(mapping))) +end \ No newline at end of file diff --git a/src/probe/sum_iterator.jl b/src/probe/sum_iterator.jl index f913922..d4eabab 100644 --- a/src/probe/sum_iterator.jl +++ b/src/probe/sum_iterator.jl @@ -17,7 +17,7 @@ for option ∈ sum_iter end ``` """ -@kwdef struct SumIterator +Base.@kwdef struct SumIterator number_of_elements::Int desired_sum::Int max_value::Int diff --git a/src/probe/update_grammar.jl b/src/probe/update_grammar.jl new file mode 100644 index 0000000..1943f55 --- /dev/null +++ b/src/probe/update_grammar.jl @@ -0,0 +1,62 @@ + +""" + update_grammar(grammar::ContextSensitiveGrammar, PSols_with_eval_cache::Vector{ProgramCache}, examples::Vector{<:IOExample}) + +Update the given `grammar` using the provided `PSols_with_eval_cache` and `examples`. + +# Arguments +- `grammar::ContextSensitiveGrammar`: The grammar to be updated. +- `PSols_with_eval_cache::Vector{ProgramCache}`: The program solutions with evaluation cache. +- `examples::Vector{<:IOExample}`: The input-output examples. + +""" +function update_grammar(grammar::ContextSensitiveGrammar, PSols_with_eval_cache::Vector{ProgramCache}, examples::Vector{<:IOExample}) + sum = 0 + for rule_index in eachindex(grammar.rules) # iterate for each rule_index + highest_correct_nr = 0 + for psol in PSols_with_eval_cache + program = psol.program + len_correct_examples = length(psol.correct_examples) + # check if the program tree has rule_index somewhere inside it using a recursive function + if contains_rule(program, rule_index) && len_correct_examples > highest_correct_nr + highest_correct_nr = len_correct_examples + end + end + fitnes = highest_correct_nr / length(examples) + p_uniform = 1 / length(grammar.rules) + + # compute (log2(p_u) ^ (1 - fit)) = (1-fit) * log2(p_u) + sum += p_uniform^(1 - fitnes) + log_prob = ((1 - fitnes) * log(2, p_uniform)) + grammar.log_probabilities[rule_index] = log_prob + end + total_sum = 0 + for rule_index in eachindex(grammar.rules) + grammar.log_probabilities[rule_index] = grammar.log_probabilities[rule_index] - log(2, sum) + total_sum += 2^(grammar.log_probabilities[rule_index]) + end + @assert abs(total_sum - 1) <= 1e-4 "Total sum is $(total_sum) " +end + +""" + contains_rule(program::RuleNode, rule_index::Int) + +Check if a given `RuleNode` contains has used a derivation rule with the specified `rule_index` + +# Arguments +- `program::RuleNode`: The `RuleNode` to check. +- `rule_index::Int`: The index of the rule to check for. + +""" +function contains_rule(program::RuleNode, rule_index::Int) + if program.ind == rule_index # if the rule is good return true + return true + else + for child in program.children + if contains_rule(child, rule_index) # if a child has that rule then return true + return true + end + end + return false # if no child has that rule return false + end +end \ No newline at end of file diff --git a/test/test_newprograms.jl b/test/test_newprograms.jl index 9edeb6e..e8f40e9 100644 --- a/test/test_newprograms.jl +++ b/test/test_newprograms.jl @@ -21,7 +21,7 @@ function fast_sol(nr_elements, desired_sum, maximum_value) end return array end -@testset "Test that the sum itearator works" begin +@testset "Test that the sum iterator works" begin @testset "Property based testing" begin max_value = 10 @@ -54,7 +54,6 @@ end # deep copy is needed because the iterator mutates the state in place push!(options, deepcopy(option)) end - println("Options", options) @test options == [ [1, 1, 1, 2], [1, 1, 2, 1], diff --git a/test/test_probe.jl b/test/test_probe.jl index 6420305..254ec5d 100644 --- a/test/test_probe.jl +++ b/test/test_probe.jl @@ -1,12 +1,12 @@ -my_replace(x,y,z) = replace(x,y => z, count = 1) +my_replace(x, y, z) = replace(x, y => z, count=1) grammar = @pcsgrammar begin - 0.188 : S = arg - 0.188 : S = "" - 0.188 : S = "<" - 0.188 : S = ">" - 0.188 : S = my_replace(S,S,S) - 0.059 : S = S * S + 0.188:S = arg + 0.188:S = "" + 0.188:S = "<" + 0.188:S = ">" + 0.188:S = my_replace(S, S, S) + 0.059:S = S * S end @testset "Simulate using the grammar from paper" begin @@ -17,6 +17,10 @@ end execute_on_input(grammar, program, Dict(:arg => "hello")) end end + @testset "Cost probabilities are computed correctly" begin + rule_costs = [HerbSearch.calculate_rule_cost_prob(rule, grammar) for rule ∈ eachindex(grammar.rules)] + @test rule_costs == [2, 2, 2, 2, 2, 4] + end @testset "Selection schemes for partial solutions" begin using HerbSearch: ProgramCache prog1 = RuleNode(1) @@ -27,84 +31,84 @@ end ( HerbSearch.selectpsol_largest_subset, [ - ProgramCache(prog1,[1,2,3,4],100), - ProgramCache(prog2,[1,2,3,4],2), # <- smallest cost solving most examples - ProgramCache(prog3,[1,2,3],1), + ProgramCache(prog1, [1, 2, 3, 4], 100), + ProgramCache(prog2, [1, 2, 3, 4], 2), # <- smallest cost solving most examples + ProgramCache(prog3, [1, 2, 3], 1), ], [prog2] ), ( HerbSearch.selectpsol_first_cheapest, [ - ProgramCache(prog1,[1,2,3,4],100), - ProgramCache(prog2,[1,2,3,4],2), # <- smallest cost solving 4 examples - ProgramCache(prog3,[1,2,3],1), # <- smallest cost solving 3 examples + ProgramCache(prog1, [1, 2, 3, 4], 100), + ProgramCache(prog2, [1, 2, 3, 4], 2), # <- smallest cost solving 4 examples + ProgramCache(prog3, [1, 2, 3], 1), # <- smallest cost solving 3 examples ], [prog2, prog3] ), ( HerbSearch.selectpsol_first_cheapest, [ - ProgramCache(prog1,[1,2,3,4],100), # <- smallest cost solving 4 examples - ProgramCache(prog2,[1,2],2), # <- smallest cost solving 2 examples - ProgramCache(prog3,[1,2,3],1), # <- smallest cost solving 3 examples + ProgramCache(prog1, [1, 2, 3, 4], 100), # <- smallest cost solving 4 examples + ProgramCache(prog2, [1, 2], 2), # <- smallest cost solving 2 examples + ProgramCache(prog3, [1, 2, 3], 1), # <- smallest cost solving 3 examples ], [prog1, prog2, prog3] ), ( HerbSearch.selectpsol_largest_subset, [ - ProgramCache(prog1,[1,2,3,4],100), # <- smallest cost solving 4 examples (but first) - ProgramCache(prog2,[1,2,3,4],100), # <- smallest cost solving 4 examples (but not considered) - ProgramCache(prog3,[1,2],2), + ProgramCache(prog1, [1, 2, 3, 4], 100), # <- smallest cost solving 4 examples (but first) + ProgramCache(prog2, [1, 2, 3, 4], 100), # <- smallest cost solving 4 examples (but not considered) + ProgramCache(prog3, [1, 2], 2), ], [prog1] ), ( HerbSearch.selectpsol_first_cheapest, [ - ProgramCache(prog1,[1,2,3,4],100), # <- smallest cost solving 4 examples (but first) - ProgramCache(prog2,[1,2,3,4],100), # <- smallest cost solving 4 examples (but not considered) - ProgramCache(prog3,[1,2],2), # <- smallest cost solving 2 examples + ProgramCache(prog1, [1, 2, 3, 4], 100), # <- smallest cost solving 4 examples (but first) + ProgramCache(prog2, [1, 2, 3, 4], 100), # <- smallest cost solving 4 examples (but not considered) + ProgramCache(prog3, [1, 2], 2), # <- smallest cost solving 2 examples ], [prog1, prog3] ), ( HerbSearch.selectpsol_all_cheapest, [ - ProgramCache(prog1,[1,2,3,4],100), # <- smallest cost solving 4 examples - ProgramCache(prog2,[1,2,3,4],100), # <- smallest cost solving 4 examples - ProgramCache(prog3,[1,2],2), # <- smallest cost solving 2 examples + ProgramCache(prog1, [1, 2, 3, 4], 100), # <- smallest cost solving 4 examples + ProgramCache(prog2, [1, 2, 3, 4], 100), # <- smallest cost solving 4 examples + ProgramCache(prog3, [1, 2], 2), # <- smallest cost solving 2 examples ], [prog1, prog2, prog3] ), ( HerbSearch.selectpsol_largest_subset, [ - ProgramCache(prog1,[1,2,3,4,5],100), # <- solves most programs - ProgramCache(prog2,[1,2,3,4],2), - ProgramCache(prog3,[1,2,3],1), + ProgramCache(prog1, [1, 2, 3, 4, 5], 100), # <- solves most programs + ProgramCache(prog2, [1, 2, 3, 4], 2), + ProgramCache(prog3, [1, 2, 3], 1), ], [prog1] ), ( HerbSearch.selectpsol_largest_subset, [ - ProgramCache(prog3,[1],1), # only one program + ProgramCache(prog3, [1], 1), # only one program ], [prog3] ), ( HerbSearch.selectpsol_first_cheapest, [ - ProgramCache(prog3,[1],1), # only one program + ProgramCache(prog3, [1], 1), # only one program ], [prog3] ), ( HerbSearch.selectpsol_largest_subset, [ - ProgramCache(prog1,[],1), # no solved examples + ProgramCache(prog1, [], 1), # no solved examples ], [prog1] ), @@ -127,8 +131,8 @@ end [] ) ], - function test_select_function(func_to_call,partial_sols, expected) - partial_sols_filtered = func_to_call(partial_sols) + function test_select_function(func_to_call, partial_sols, expected) + partial_sols_filtered = func_to_call(partial_sols, Set{ProgramCache}()) mapped_to_programs = map(cache -> cache.program, partial_sols_filtered) @test sort(mapped_to_programs) == sort(expected) end @@ -195,44 +199,39 @@ end output = [example.out for example in examples] symboltable = SymbolTable(grammar) - @testset "Running using sized based enumeration" begin - HerbSearch.calculate_rule_cost(rule_index::Int, grammar::ContextSensitiveGrammar) = HerbSearch.calculate_rule_cost_size(rule_index, grammar) - iter = HerbSearch.GuidedSearchIterator(grammar, :S, examples, symboltable) - runtime = @timed program = probe(examples, iter, identity, identity, 1, 10000) - expression = rulenode2expr(program, grammar) - @test runtime.time <= 1 - - received = execute_on_input(symboltable, expression, input) - @test output == received + cost_functions = [HerbSearch.calculate_rule_cost_size, HerbSearch.calculate_rule_cost_prob] + select_functions = [HerbSearch.selectpsol_all_cheapest, HerbSearch.selectpsol_first_cheapest, HerbSearch.selectpsol_largest_subset] + uniform_grammar = @pcsgrammar begin + 1:S = arg + 1:S = "" + 1:S = "<" + 1:S = ">" + 1:S = my_replace(S, S, S) + 1:S = S * S end - - @testset "Running using probability based enumeration" begin - # test currently fails.. - examples = [ - IOExample(Dict(:arg => "a < 4 and a > 0"), "a 4 and a 0") - IOExample(Dict(:arg => ""), "open and close") - IOExample(Dict(:arg => " to number"), "Change string to a number") - ] - input = [example.in for example in examples] - output = [example.out for example in examples] - - HerbSearch.calculate_rule_cost(rule_index::Int, grammar::ContextSensitiveGrammar) = HerbSearch.calculate_rule_cost_prob(rule_index, grammar) - @testset "Check rule cost computation" begin - for i in 1:5 - @test HerbSearch.calculate_rule_cost(i, grammar) == 2 + for cost_func ∈ cost_functions + for select_func ∈ select_functions + for grammar_to_use ∈ [uniform_grammar, grammar] + @testset "Uniform grammar is uniform" begin + sum(exp.(grammar.log_probabilities)) ≈ 1 + end + # overwrite calculate cost + HerbSearch.calculate_rule_cost(rule_index::Int, g::ContextSensitiveGrammar) = cost_func(rule_index, g) + # overwrite select function + HerbSearch.select_partial_solution(partial_sols::Vector{ProgramCache}, all_selected_psols::Set{ProgramCache}) = select_func(partial_sols, all_selected_psols) + + deep_copy_grammar = deepcopy(grammar_to_use) + iter = HerbSearch.GuidedSearchIterator(deep_copy_grammar, :S, examples, symboltable) + max_time = 5 + runtime = @timed program = probe(examples, iter, max_time, 100) + expression = rulenode2expr(program, grammar_to_use) + @test runtime.time <= max_time + + received = execute_on_input(symboltable, expression, input) + @test output == received end - # the rule with S * S should have cost 4 - @test HerbSearch.calculate_rule_cost(6, grammar) == 4 end - iter = HerbSearch.GuidedSearchIterator(grammar, :S, examples, symboltable) - runtime = @timed program = probe(examples, iter, identity, identity, 5, 10000) - - expression = rulenode2expr(program, grammar) - @test runtime.time <= 5 - - received = execute_on_input(symboltable, expression, input) - @test output == received end end end \ No newline at end of file