From 69b694a4a9e7d43529ee2f8d9a12535057bb244a Mon Sep 17 00:00:00 2001 From: Atharva Sehgal Date: Sat, 21 Sep 2024 23:21:26 +0000 Subject: [PATCH] separate out calling variable names to pass static analysis tests --- examples/example_w_llm.jl | 46 +++++++++++++++++++++++++++++++++++++++ src/LLMFunctions.jl | 25 +++++++++------------ 2 files changed, 56 insertions(+), 15 deletions(-) create mode 100644 examples/example_w_llm.jl diff --git a/examples/example_w_llm.jl b/examples/example_w_llm.jl new file mode 100644 index 000000000..f8fa8b5bf --- /dev/null +++ b/examples/example_w_llm.jl @@ -0,0 +1,46 @@ +# This isn't in the automated testing suite since it requires an LLM server running in the background. + +using LibraryAugmentedSymbolicRegression + +X = randn(Float32, 5, 100) +y = 2 * cos.(X[4, :]) + X[1, :] .^ 2 .- 2 + +llm_options = LibraryAugmentedSymbolicRegression.LLMOptions(; + active=true, + weights=LibraryAugmentedSymbolicRegression.LLMWeights(; + llm_mutate=0.01, llm_crossover=0.01, llm_gen_random=0.01 + ), + num_pareto_context=5, + prompt_evol=true, + prompt_concepts=true, + api_key="token-abc123", + model="meta-llama/Meta-Llama-3-8B-Instruct", + api_kwargs=Dict("url" => "http://localhost:11440/v1"), +) + +options = LibraryAugmentedSymbolicRegression.Options(; + binary_operators=[+, *, /, -], + unary_operators=[cos, exp], + populations=20, + llm_options=llm_options, +) + +## The rest of the code is the same as the example.jl file. 
+hall_of_fame = equation_search( + X, y; niterations=40, options=options, parallelism=:multithreading +) + +dominating = calculate_pareto_frontier(hall_of_fame) + +trees = [member.tree for member in dominating] + +tree = trees[end] +output, did_succeed = eval_tree_array(tree, X, options) + +for member in dominating + complexity = compute_complexity(member, options) + loss = member.loss + string = string_tree(member.tree, options) + + println("$(complexity)\t$(loss)\t$(string)") +end diff --git a/src/LLMFunctions.jl b/src/LLMFunctions.jl index 2557af19e..5286e5c8c 100644 --- a/src/LLMFunctions.jl +++ b/src/LLMFunctions.jl @@ -58,14 +58,7 @@ function convertDict(d)::NamedTuple end function get_vars(options::Options)::String - if !isnothing(options.llm_options) && !isnothing(options.llm_options.var_order) - variable_names = [ - options.llm_options.var_order[key] for - key in sort(collect(keys(options.llm_options.var_order))) - ] - else - variable_names = ["x", "y", "z", "k", "j", "l", "m", "n", "p", "a", "b"] - end + variable_names = get_variable_names(options.llm_options.var_order) return join(variable_names, ", ") end @@ -287,18 +280,20 @@ function tree_to_expr( end function tree_to_expr(tree::AbstractExpressionNode{T}, options)::String where {T<:DATA_TYPE} - variable_names = ["x", "y", "z", "k", "j", "l", "m", "n", "p", "a", "b"] - if !isnothing(options.llm_options.var_order) - variable_names = [ - options.llm_options.var_order[key] for - key in sort(collect(keys(options.llm_options.var_order))) - ] - end + variable_names = get_variable_names(options.llm_options.var_order) return string_tree( tree, options.operators; f_constant=sketch_const, variable_names=variable_names ) end +function get_variable_names(var_order::Dict)::Vector{String} + return [var_order[key] for key in sort(collect(keys(var_order)))] +end + +function get_variable_names(var_order::Nothing)::Vector{String} + return ["x", "y", "z", "k", "j", "l", "m", "n", "p", "a", "b"] +end + function 
handle_not_expr(::Type{T}, x, var_names)::Node{T} where {T<:DATA_TYPE} if x isa Real Node{T}(; val=convert(T, x)) # old: Node(T, 0, true, convert(T,x))