rewrite w/ multiple dispatch #36

Open · wants to merge 4 commits into base: master
29 changes: 4 additions & 25 deletions .github/workflows/CI.yml
```diff
@@ -26,9 +26,7 @@ jobs:
       fail-fast: false
       matrix:
         test:
-          - "part1"
-          - "part2"
-          - "part3"
+          - "online"
         julia-version:
           - "1.10"
           - "1"
@@ -37,32 +35,13 @@
       include:
         - os: windows-latest
           julia-version: "1"
-          test: "part1"
-        - os: windows-latest
-          julia-version: "1"
-          test: "part2"
-        - os: windows-latest
-          julia-version: "1"
-          test: "part3"
+          test: "online"
         - os: macOS-latest
           julia-version: "1"
-          test: "part1"
-        - os: macOS-latest
-          julia-version: "1"
-          test: "part2"
-        - os: macOS-latest
-          julia-version: "1"
-          test: "part3"
+          test: "online"
         - os: ubuntu-latest
           julia-version: "~1.11.0-0"
-          test: "part1"
-        - os: ubuntu-latest
-          julia-version: "~1.11.0-0"
-          test: "part2"
-        - os: ubuntu-latest
-          julia-version: "~1.11.0-0"
-          test: "part3"
-
+          test: "online"
     steps:
       - uses: actions/checkout@v4
      - name: "Set up Julia"
```
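With the matrix collapsed from three `part*` groups to a single `online` group, the test harness only needs to dispatch on one name. A minimal sketch of such dispatch, assuming (this wiring is not shown in the PR) that CI exports the matrix value as an environment variable `LASR_TEST_GROUP` and that the online suite lives in `test/test_online.jl`:

```julia
# test/runtests.jl -- hypothetical harness; the PR's actual wiring may differ.
# CI would forward the matrix value, e.g.
#   env:
#     LASR_TEST_GROUP: ${{ matrix.test }}
test_group = get(ENV, "LASR_TEST_GROUP", "online")

if test_group == "online"
    include("test_online.jl")  # hypothetical file exercising the full search
else
    error("Unknown test group: $test_group")
end
```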
52 changes: 14 additions & 38 deletions Project.toml
```diff
@@ -1,79 +1,55 @@
 name = "LibraryAugmentedSymbolicRegression"
 uuid = "158930c3-947c-4174-974b-74b39e64a28f"
 authors = ["AryaGrayeli <[email protected]>", "AtharvaSehgal <[email protected]>", "MilesCranmer <[email protected]>"]
-version = "0.1.0"
+version = "0.2.0"
 
 [deps]
-ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
 Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
-DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
 Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
 DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
 JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
-LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
+Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
 LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
 MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
 Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
 Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
 ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
 PromptingTools = "670122d1-24a8-4d70-bfce-740807c42192"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+SymbolicRegression = "8254be44-1295-4e6a-a16d-46603ac705cb"
 TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
 
 [weakdeps]
-Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
 SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
 
 [extensions]
-LaSREnzymeExt = "Enzyme"
 LaSRJSON3Ext = "JSON3"
 LaSRSymbolicUtilsExt = "SymbolicUtils"
 
 [compat]
-ADTypes = "^1.4.0"
-Compat = "^4.2"
-ConstructionBase = "1.5.7"
+Compat = "^4.16"
+ConstructionBase = "1.0.0 - 1.5.6, 1.5.8 - 1"
 Dates = "1"
-DifferentiationInterface = "0.5, 0.6"
-DispatchDoctor = "0.4"
+DispatchDoctor = "^0.4.17"
 Distributed = "<0.0.1, 1"
-DynamicExpressions = "0.19.3, 1"
-DynamicQuantities = "0.10, 0.11, 0.12, 0.13, 0.14, 1"
-Enzyme = "0.12, 0.13"
+DynamicExpressions = "~1.8"
+DynamicQuantities = "1"
 JSON = "0.21"
 JSON3 = "1"
-LineSearches = "7"
+Logging = "1"
 LossFunctions = "0.10, 0.11, 1"
 MLJModelInterface = "~1.5, ~1.6, ~1.7, ~1.8, ~1.9, ~1.10, ~1.11"
 MacroTools = "0.4, 0.5"
-Optim = "~1.8, ~1.9, 1"
-PackageExtensionCompat = "1"
+Optim = "~1.8, ~1.9, ~1.10"
+PackageExtensionCompat = "1.0.2"
 Pkg = "<0.0.1, 1"
 PrecompileTools = "1"
 Printf = "<0.0.1, 1"
 ProgressBars = "~1.4, ~1.5"
-PromptingTools = "0.53, 0.54, 0.56"
+PromptingTools = "0.65 - 0.70"
 Random = "<0.0.1, 1"
 Reexport = "1"
 SpecialFunctions = "0.10.1, 1, 2"
 StatsBase = "0.33, 0.34"
-SymbolicUtils = "0.19, ^1.0.5, 2, 3"
+SymbolicRegression = "1"
 TOML = "<0.0.1, 1"
-julia = "1.10"
-
-[extras]
-Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
-JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
-SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
+julia = "1.10"
```
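The `ConstructionBase` entry is the subtle change here: `"1.0.0 - 1.5.6, 1.5.8 - 1"` is a union of two inclusive ranges that admits every 1.x release except 1.5.7 (presumably skipping a problematic release). The semantics can be checked with `Pkg.Versions.semver_spec`, the same parser Pkg applies to compat entries:

```julia
using Pkg.Versions: semver_spec

# A compat entry is a comma-separated union of ranges; "a - b" is inclusive.
spec = semver_spec("1.0.0 - 1.5.6, 1.5.8 - 1")

@assert v"1.5.6" in spec
@assert v"1.5.8" in spec
@assert !(v"1.5.7" in spec)  # the one excluded release
```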
34 changes: 17 additions & 17 deletions README.md
```diff
@@ -81,16 +81,16 @@ model = LaSRRegressor(
     binary_operators=[+, -, *],
     unary_operators=[cos],
     llm_options=LLMOptions(
-        active=true,
-        weights=LLMWeights(llm_mutate=0.1, llm_crossover=0.1, llm_gen_random=0.1),
-        prompt_evol=true,
-        prompt_concepts=true,
+        use_llm=true,
+        lasr_weights=LLMWeights(llm_mutate=0.1, llm_crossover=0.1, llm_gen_random=0.1),
+        use_concept_evolution=true,
+        use_concepts=true,
         api_key="token-abc123",
         prompts_dir="prompts/",
         llm_recorder_dir="lasr_runs/debug_0/",
         model="meta-llama/Meta-Llama-3-8B-Instruct",
         api_kwargs=Dict("url" => "http://localhost:11440/v1"),
-        var_order=Dict("a" => "angle", "b" => "bias"),
+        variable_names=Dict("a" => "angle", "b" => "bias"),
         llm_context="We believe the function to be a trigonometric function of the angle and a quadratic function of the bias.",
     )
 )
```
````diff
@@ -112,27 +112,27 @@ Other than `LLMOptions`, we have the same search options as SymbolicRegression.jl
 LaSR uses PromptingTools.jl for zero-shot prompting. If you wish to change the prompting options, pass an `LLMOptions` object to the `LaSRRegressor` constructor. The options available are:
 ```julia
 llm_options = LLMOptions(
-    active=true, # Whether to use LLM inference or not
-    weights=LLMWeights(llm_mutate=0.1, llm_crossover=0.1, llm_gen_random=0.1), # Probability of using LLM for mutation, crossover, and random generation
+    use_llm=true, # Whether to use LLM inference or not
+    lasr_weights=LLMWeights(llm_mutate=0.1, llm_crossover=0.1, llm_gen_random=0.1), # Probability of using the LLM for mutation, crossover, and random generation
     num_pareto_context=5, # Number of equations to sample from the Pareto frontier for summarization.
-    prompt_evol=true, # Whether to evolve natural language concepts through LLM calls.
-    prompt_concepts=true, # Whether to use natural language concepts in the search.
+    use_concept_evolution=true, # Whether to evolve natural language concepts through LLM calls.
+    use_concepts=true, # Whether to use natural language concepts in the search.
     api_key="token-abc123", # API key to an OpenAI-API-compatible server.
     model="meta-llama/Meta-Llama-3-8B-Instruct", # LLM model to use.
     api_kwargs=Dict("url" => "http://localhost:11440/v1"), # Keyword arguments passed to the server.
     http_kwargs=Dict("retries" => 3, "readtimeout" => 3600), # Keyword arguments passed to HTTP requests.
     prompts_dir="prompts/", # Directory to look for zero-shot prompts to the LLM.
     llm_recorder_dir="lasr_runs/debug_0/", # Directory to log LLM interactions.
     llm_context="", # Natural language concept to start with. You should also be able to initialize with a list of concepts.
-    var_order=nothing, # Dict(variable_name => new_name).
-    idea_threshold=30 # Number of concepts to keep track of.
+    variable_names=nothing, # Dict(variable_name => new_name).
+    max_concepts=30, # Number of concepts to keep track of.
     is_parametric=false, # This is a special flag to allow sampling parametric equations from LaSR. This won't be needed for most users.
 )
 ```
 
 ### Best Practices
 
-1. Always make sure you cannot find a satisfactory solution with `active=false` before using LLM guidance.
+1. Always make sure you cannot find a satisfactory solution with `use_llm=false` before using LLM guidance.
 1. Start with an OpenAI-compatible LLM server running on your local machine before moving on to paid services. There are many online resources for setting up a local LLM server [1](https://ollama.com/blog/openai-compatibility) [2](https://docs.vllm.ai/en/latest/getting_started/installation.html) [3](https://github.com/sgl-project/sglang?tab=readme-ov-file#backend-sglang-runtime-srt) [4](https://old.reddit.com/r/LocalLLaMA/comments/16y95hk/a_starter_guide_for_playing_with_your_own_local_ai/)
 1. If you are using an LLM, do a back-of-the-envelope calculation to estimate the cost of running it on your problem. Each iteration makes around 60k calls to the LLM, and with the default prompts (in `prompts/`) each call generates 250 to 1000 tokens, giving an upper bound of roughly 60M tokens per iteration at `p=1.00`. Hence, running at `p=0.01` for 40 iterations costs about 24M tokens for a full run.
````
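Spelled out, the estimate in the last item (all numbers are taken from the text above; `p` is the probability that a given operation is routed to the LLM):

```julia
calls_per_iteration = 60_000  # LLM call sites per iteration, i.e. at p = 1.00
tokens_per_call     = 1_000   # upper bound; default prompts generate 250-1000 tokens
p                   = 0.01    # fraction of operations actually sent to the LLM
iterations          = 40

total_tokens = calls_per_iteration * tokens_per_call * p * iterations
println(total_tokens)  # 2.4e7, i.e. about 24M tokens for a full run
```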

```diff
@@ -215,16 +215,16 @@ model = LaSRRegressor(
     binary_operators=[+, -, *],
     unary_operators=[cos],
     llm_options=LLMOptions(
-        active=true,
-        weights=LLMWeights(llm_mutate=0.1, llm_crossover=0.1, llm_gen_random=0.1),
-        prompt_evol=true,
-        prompt_concepts=true,
+        use_llm=true,
+        lasr_weights=LLMWeights(llm_mutate=0.1, llm_crossover=0.1, llm_gen_random=0.1),
+        use_concept_evolution=true,
+        use_concepts=true,
         api_key="token-abc123",
         prompts_dir="prompts/",
         llm_recorder_dir="lasr_runs/debug_0/",
         model="llama3.1:latest",
         api_kwargs=Dict("url" => "http://127.0.0.1:11434/v1"),
-        var_order=Dict("a" => "angle", "b" => "bias"),
+        variable_names=Dict("a" => "angle", "b" => "bias"),
         llm_context="We believe the function to be a trigonometric function of the angle and a quadratic function of the bias."
     )
 )
```
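Running a search with the model above follows the standard MLJ loop; a minimal sketch with synthetic data, assuming the Ollama server from the snippet is already serving `llama3.1:latest` (the dataset and target below are illustrative, not from the README):

```julia
using MLJ: machine, fit!, report

# Toy data matching the hint in `llm_context` above:
# trigonometric in the angle `a`, quadratic in the bias `b`.
X = (a = rand(500), b = rand(500))
y = @. 2 * cos(X.a * 23.5) - X.b * X.b

mach = machine(model, X, y)  # `model` as constructed above
fit!(mach)
report(mach)                 # best expressions found, by complexity
```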
14 changes: 0 additions & 14 deletions benchmark/Project.toml

This file was deleted.

67 changes: 0 additions & 67 deletions benchmark/analyze.py

This file was deleted.
