Skip to content

Commit

Permalink
seperate out calling variable names to pass static analysis tests
Browse files Browse the repository at this point in the history
  • Loading branch information
atharvas committed Sep 21, 2024
1 parent 94d4ba5 commit 69b694a
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 15 deletions.
46 changes: 46 additions & 0 deletions examples/example_w_llm.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# This isn't in the automated testing suite since it requires an LLM server running in the background.

using LibraryAugmentedSymbolicRegression

X = randn(Float32, 5, 100)
y = 2 * cos.(X[4, :]) + X[1, :] .^ 2 .- 2

llm_options = LibraryAugmentedSymbolicRegression.LLMOptions(;
active=true,
weights=LibraryAugmentedSymbolicRegression.LLMWeights(;
llm_mutate=0.01, llm_crossover=0.01, llm_gen_random=0.01
),
num_pareto_context=5,
prompt_evol=true,
prompt_concepts=true,
api_key="token-abc123",
model="meta-llama/Meta-Llama-3-8B-Instruct",
api_kwargs=Dict("url" => "http://localhost:11440/v1"),
)

options = LibraryAugmentedSymbolicRegression.Options(;
binary_operators=[+, *, /, -],
unary_operators=[cos, exp],
populations=20,
llm_options=llm_options,
)

## The rest of the code is the same as the example.jl file.
hall_of_fame = equation_search(
X, y; niterations=40, options=options, parallelism=:multithreading
)

dominating = calculate_pareto_frontier(hall_of_fame)

trees = [member.tree for member in dominating]

tree = trees[end]
output, did_succeed = eval_tree_array(tree, X, options)

for member in dominating
complexity = compute_complexity(member, options)
loss = member.loss
string = string_tree(member.tree, options)

println("$(complexity)\t$(loss)\t$(string)")
end
25 changes: 10 additions & 15 deletions src/LLMFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,14 +58,7 @@ function convertDict(d)::NamedTuple
end

function get_vars(options::Options)::String
if !isnothing(options.llm_options) && !isnothing(options.llm_options.var_order)
variable_names = [
options.llm_options.var_order[key] for
key in sort(collect(keys(options.llm_options.var_order)))
]
else
variable_names = ["x", "y", "z", "k", "j", "l", "m", "n", "p", "a", "b"]
end
variable_names = get_variable_names(options.llm_options.var_order)
return join(variable_names, ", ")
end

Expand Down Expand Up @@ -287,18 +280,20 @@ function tree_to_expr(
end

function tree_to_expr(tree::AbstractExpressionNode{T}, options)::String where {T<:DATA_TYPE}
variable_names = ["x", "y", "z", "k", "j", "l", "m", "n", "p", "a", "b"]
if !isnothing(options.llm_options.var_order)
variable_names = [
options.llm_options.var_order[key] for
key in sort(collect(keys(options.llm_options.var_order)))
]
end
variable_names = get_variable_names(options.llm_options.var_order)
return string_tree(
tree, options.operators; f_constant=sketch_const, variable_names=variable_names
)
end

function get_variable_names(var_order::Dict)::Vector{String}
return [var_order[key] for key in sort(collect(keys(var_order)))]
end

function get_variable_names(var_order::Nothing)::Vector{String}
return ["x", "y", "z", "k", "j", "l", "m", "n", "p", "a", "b"]
end

function handle_not_expr(::Type{T}, x, var_names)::Node{T} where {T<:DATA_TYPE}
if x isa Real
Node{T}(; val=convert(T, x)) # old: Node(T, 0, true, convert(T,x))
Expand Down

0 comments on commit 69b694a

Please sign in to comment.