Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Grammar update probabilities #97

Merged
merged 15 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/HerbSearch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,5 +82,7 @@ export
misclassification,
validate_iterator,
sample,
rand
rand,
probe,
guided_search
nicolaefilat marked this conversation as resolved.
Show resolved Hide resolved
end # module HerbSearch
27 changes: 27 additions & 0 deletions src/getting_started.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@

using HerbGrammar, HerbSpecification, HerbSearch

"""
    my_replace(x, y, z)

Return a copy of `x` in which only the first occurrence of `y` is replaced by `z`.
"""
function my_replace(x, y, z)
    substitution = y => z
    return replace(x, substitution; count = 1)
end

# Probabilistic grammar with the single non-terminal `S`. Every rule is prefixed
# with a weight; the six weights sum to 1 (5 * 0.188 + 0.06).
grammar = @pcsgrammar begin
0.188 : S = arg
0.188 : S = ""
0.188 : S = "<"
0.188 : S = ">"
0.188 : S = my_replace(S,S,S)
0.06 : S = S * S
end

# Input/output examples: each one binds the grammar's `arg` symbol to an input
# string and gives the output string the synthesized program should produce.
examples = [
IOExample(Dict(:arg => "a < 4 and a > 0"), "a 4 and a 0") # <- e0 with correct space
# IOExample(Dict(:arg => "<open and <close>"), "open and close") # <- e1
IOExample(Dict(:arg => "<<<"), "")
IOExample(Dict(:arg => "<Change> <string> to <a> number"), "Change string to a number")
]

# Build the guided-search iterator over `grammar` with start symbol `:S`.
iter = HerbSearch.GuidedSearchIterator(grammar, :S, examples, SymbolTable(grammar))
# Run probe with max_time = 40 (compared against `time()` differences, i.e. seconds)
# and iteration_size = 10000, timing and profiling the call.
# NOTE(review): `@profview` is not provided by any package imported in this file;
# presumably this relies on the VS Code Julia REPL or ProfileView being loaded — confirm.
@profview program = @time probe(examples, iter, 40, 10000)
# program = @time probe(examples, iter, 3600, 10000)

# Turn the returned rule node back into a Julia expression.
# NOTE(review): `probe` returns `nothing` when the time budget runs out, in which
# case this call presumably fails — confirm and guard if needed.
rulenode2expr(program, grammar)

106 changes: 90 additions & 16 deletions src/probe/probe_iterator.jl
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,27 @@
Stores a program together with its evaluation cost and the examples it solves correctly in one structure.
"""
struct ProgramCache
mutable struct ProgramCache
# The cached program; the only field taken into account by `==` and `hash`.
program::RuleNode
# Examples this program gets right; its length is used as the number of solved
# examples when computing rule fitness. Presumably holds example indices — confirm.
correct_examples::Vector{Int}
# Cost of the program. The struct is mutable so that `probe` can overwrite this
# field with a recomputed cost after the grammar probabilities have been updated.
cost::Int
end
# Two cache entries are equal when they wrap the same program; the cached cost
# and the solved examples are deliberately ignored.
function Base.:(==)(a::ProgramCache, b::ProgramCache)
    return a.program == b.program
end

# Hash consistently with `==` by hashing only the program.
# The two-argument form is the one Julia expects custom types to implement:
# `hash(x)` falls back to it, and containers (tuples, arrays, other structs)
# call it directly when hashing their elements. Defining only the one-argument
# form would leave nested `ProgramCache` values hashed by object identity.
Base.hash(a::ProgramCache, h::UInt) = hash(a.program, h)

"""
    select(partial_sols::Vector{ProgramCache}, all_selected_psols::Set{ProgramCache})

Selection strategy used by `probe`: forwards both arguments to
`HerbSearch.selectpsol_largest_subset` and returns its result.
"""
function select(partial_sols::Vector{ProgramCache}, all_selected_psols::Set{ProgramCache})
    return HerbSearch.selectpsol_largest_subset(partial_sols, all_selected_psols)
end
"""
    update!(grammar::ContextSensitiveGrammar, PSols_with_eval_cache::Vector{ProgramCache}, examples::Vector{<:IOExample})

Grammar update strategy used by `probe`: forwards all arguments to `update_grammar`
and returns its result.
"""
function update!(grammar::ContextSensitiveGrammar, PSols_with_eval_cache::Vector{ProgramCache}, examples::Vector{<:IOExample})
    return update_grammar(grammar, PSols_with_eval_cache, examples)
end

function probe(examples::Vector{<:IOExample}, iterator::ProgramIterator, select::Function, update!::Function, max_time::Int, iteration_size::Int)
function probe(examples::Vector{<:IOExample}, iterator::ProgramIterator, max_time::Int, iteration_size::Int)
start_time = time()
# store a set of all the results of evaluation programs
eval_cache = Set()
state = nothing
symboltable = SymbolTable(iterator.grammar)
# all partial solutions that were found so far
all_selected_psols = Set{RuleNode}()
all_selected_psols = Set{ProgramCache}()
# start next iteration while there is time left
while time() - start_time < max_time
i = 1
Expand Down Expand Up @@ -62,33 +69,100 @@ function probe(examples::Vector{<:IOExample}, iterator::ProgramIterator, select:
if next === nothing
return nothing
end
# select promising partial solutions that did not appear before
partial_sols = filter(x -> x.program ∉ all_selected_psols, select(psol_with_eval_cache))
# select promising partial solutions that did not appear before
# if (isempty(all_selected_psols))
# push!(all_selected_psols, psol_with_eval_cache...)
# end
partial_sols = filter(x -> x ∉ all_selected_psols, select(psol_with_eval_cache, all_selected_psols))
if !isempty(partial_sols)
push!(all_selected_psols, map(x -> x.program, partial_sols)...)
print(rulenode2expr(partial_sols[1].program, iterator.grammar))
push!(all_selected_psols, partial_sols...)
# update the probabilities if there are any promising partial solutions
update!(iterator.grammar, partial_sols, examples) # update probabilites
# restart iterator
eval_cache = Set()
state = nothing

#for loop to update all_selected_psols
new_all_selected = Set{ProgramCache}()
eErr0Re marked this conversation as resolved.
Show resolved Hide resolved
for prog_with_cache ∈ all_selected_psols
program = prog_with_cache.program
new_cost = calculate_program_cost(program, iterator.grammar)
prog_with_cache.cost = new_cost
# program_cache = ProgramCache(program, prog_with_cache.correct_examples, cost)
# push!(new_all_selected, program_cache)
end
# all_selected_psols = new_all_selected
end
# # update probabilites if any promising partial solutions
# if !isempty(partial_sols)
# update!(iterator.grammar, partial_sols, eval_cache) # update probabilites
# # restart iterator
# eval_cache = Set()
# state = nothing
# end
end

return nothing
end

"""
    update_grammar(grammar::ContextSensitiveGrammar, PSols_with_eval_cache::Vector{ProgramCache}, examples::Vector{<:IOExample})

Overwrite `grammar.log_probabilities` in place, based on the given partial solutions.

For every rule, the fitness is the largest fraction of `examples` solved by any
partial solution whose program contains that rule (0 if no partial solution uses it).
With `p_u = 1 / length(grammar.rules)`, each rule receives

    log2(p_u^(1 - fitness)) - log2(Z),    Z = Σ over all rules of p_u^(1 - fitness)

so that the resulting probabilities sum to one.

# Arguments
- `grammar`: the grammar whose `log_probabilities` are overwritten.
- `PSols_with_eval_cache`: partial solutions; only `program` and the length of
  `correct_examples` are read.
- `examples`: the full example set; only its length is used. When it is empty every
  rule gets fitness 0 (a uniform distribution) instead of a `NaN` probability.

# Returns
`nothing`. The updated values are emitted through `@debug` rather than printed.

NOTE(review): logarithms here are base 2 — confirm this matches the base the rest of
the code assumes for `grammar.log_probabilities`.
"""
function update_grammar(grammar::ContextSensitiveGrammar, PSols_with_eval_cache::Vector{ProgramCache}, examples::Vector{<:IOExample})
    rule_indices = eachindex(grammar.rules)
    n_examples = length(examples)
    # Loop-invariant: the uniform rule probability and its base-2 logarithm.
    p_uniform = 1 / length(grammar.rules)
    log_p_uniform = log(2, p_uniform)

    # Normalisation constant Z, kept as a Float64 from the start.
    normaliser = 0.0
    for rule_index in rule_indices
        # Largest number of examples solved by a partial solution that uses this rule.
        highest_correct_nr = 0
        for psol in PSols_with_eval_cache
            n_correct = length(psol.correct_examples)
            # Cheap comparison first: only walk the program tree when it could
            # actually raise the current maximum.
            if n_correct > highest_correct_nr && contains_rule(psol.program, rule_index)
                highest_correct_nr = n_correct
            end
        end
        # Guard against 0 / 0 when there are no examples.
        fitness = n_examples == 0 ? 0.0 : highest_correct_nr / n_examples

        normaliser += p_uniform^(1 - fitness)
        # log2(p_u^(1 - fitness)) == (1 - fitness) * log2(p_u); normalised below.
        grammar.log_probabilities[rule_index] = (1 - fitness) * log_p_uniform
    end

    # Subtract log2(Z) so that the probabilities sum to one.
    log_normaliser = log(2, normaliser)
    for rule_index in rule_indices
        grammar.log_probabilities[rule_index] -= log_normaliser
    end

    # Diagnostics for all rules (not just the first six); evaluated only when
    # debug logging is enabled.
    @debug(
        "Updated grammar rule probabilities",
        partial_solutions = map(x -> rulenode2expr(x.program, grammar), PSols_with_eval_cache),
        log_probabilities = grammar.log_probabilities,
        probabilities = exp2.(grammar.log_probabilities)
    )
    return nothing
end

# I will assume this works
"""
    contains_rule(program::RuleNode, rule_index::Int)

Return `true` if the node `program` itself, or any node reachable through its
`children`, has `ind == rule_index`; return `false` otherwise.
"""
function contains_rule(program::RuleNode, rule_index::Int)
    # The root already matches: no need to look at the subtrees.
    program.ind == rule_index && return true
    # Otherwise search the subtrees depth-first, stopping at the first hit.
    return any(subtree -> contains_rule(subtree, rule_index), program.children)
end



"""
selectpsol_largest_subset(partial_sols::Vector{ProgramCache})

This scheme selects a single cheapest program (first enumerated) that
satisfies the largest subset of examples encountered so far across all partial_sols.
"""
function selectpsol_largest_subset(partial_sols::Vector{ProgramCache})
function selectpsol_largest_subset( partial_sols::Vector{ProgramCache}, all_selected_psols::Set{ProgramCache})
if isempty(partial_sols)
return Vector{ProgramCache}()
end
push!(partial_sols, all_selected_psols...)
largest_subset_length = 0
cost = typemax(Int)
best_sol = partial_sols[begin]
Expand Down Expand Up @@ -124,7 +198,7 @@ function selectpsol_first_cheapest(partial_sols::Vector{ProgramCache})
end
end
# get the cheapest programs that satisfy unique subsets of examples
return values(mapping)
return collect(values(mapping))
end

"""
Expand All @@ -151,7 +225,7 @@ function selectpsol_all_cheapest(partial_sols::Vector{ProgramCache})
end
end
# get all cheapest programs that satisfy unique subsets of examples
return Iterators.flatten(values(mapping))
return collect(Iterators.flatten(values(mapping)))
end

@programiterator GuidedSearchIterator(
Expand Down
Loading