diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml
index 5e1e61f3c..fbd8ce3ec 100644
--- a/.github/workflows/CI.yml
+++ b/.github/workflows/CI.yml
@@ -26,9 +26,7 @@ jobs:
       fail-fast: false
       matrix:
         test:
-          - "part1"
-          - "part2"
-          - "part3"
+          - "online"
         julia-version:
           - "1.10"
           - "1"
@@ -37,32 +35,13 @@
         include:
           - os: windows-latest
             julia-version: "1"
-            test: "part1"
-          - os: windows-latest
-            julia-version: "1"
-            test: "part2"
-          - os: windows-latest
-            julia-version: "1"
-            test: "part3"
+            test: "online"
           - os: macOS-latest
            julia-version: "1"
-            test: "part1"
-          - os: macOS-latest
-            julia-version: "1"
-            test: "part2"
-          - os: macOS-latest
-            julia-version: "1"
-            test: "part3"
-          - os: ubuntu-latest
-            julia-version: "~1.11.0-0"
-            test: "part1"
+            test: "online"
           - os: ubuntu-latest
             julia-version: "~1.11.0-0"
-            test: "part2"
-          - os: ubuntu-latest
-            julia-version: "~1.11.0-0"
-            test: "part3"
-
+            test: "online"
     steps:
       - uses: actions/checkout@v4
       - name: "Set up Julia"
diff --git a/Project.toml b/Project.toml
index 23d850c71..f3b12f7b4 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,20 +1,19 @@
 name = "LibraryAugmentedSymbolicRegression"
 uuid = "158930c3-947c-4174-974b-74b39e64a28f"
 authors = ["AryaGrayeli ", "AtharvaSehgal ", "MilesCranmer "]
-version = "0.1.0"
+version = "0.2.0"
 
 [deps]
 ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
 ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
-Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
-DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
 DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
 Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
 DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
 DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
 JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6"
 LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
+Logging = "56ddb016-857b-54e1-b83d-db4d58db5568"
 LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7"
 MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
 MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
@@ -22,58 +21,35 @@ Optim = "429524aa-4258-5aef-a3af-852621145aeb"
 PackageExtensionCompat = "65ce6f38-6b18-4e1d-a461-8949797d7930"
 Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
-Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
-ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568"
 PromptingTools = "670122d1-24a8-4d70-bfce-740807c42192"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
-SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
+SymbolicRegression = "8254be44-1295-4e6a-a16d-46603ac705cb"
 TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76"
 
-[weakdeps]
-Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
-JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
-SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
-
-[extensions]
-LaSREnzymeExt = "Enzyme"
-LaSRJSON3Ext = "JSON3"
-LaSRSymbolicUtilsExt = "SymbolicUtils"
-
 [compat]
 ADTypes = "^1.4.0"
-Compat = "^4.2"
-ConstructionBase = "1.5.7"
-Dates = "1"
-DifferentiationInterface = "0.5, 0.6"
-DispatchDoctor = "0.4"
+Compat = "^4.16"
+ConstructionBase = "1.0.0 - 1.5.6, 1.5.8 - 1"
+DispatchDoctor = "^0.4.17"
 Distributed = "<0.0.1, 1"
-DynamicExpressions = "0.19.3, 1"
-DynamicQuantities = "0.10, 0.11, 0.12, 0.13, 0.14, 1"
-Enzyme = "0.12, 0.13"
+DynamicExpressions = "~1.8"
+DynamicQuantities = "1"
 JSON = "0.21"
-JSON3 = "1" LineSearches = "7" +Logging = "1" LossFunctions = "0.10, 0.11, 1" MLJModelInterface = "~1.5, ~1.6, ~1.7, ~1.8, ~1.9, ~1.10, ~1.11" MacroTools = "0.4, 0.5" -Optim = "~1.8, ~1.9, 1" -PackageExtensionCompat = "1" +Optim = "~1.8, ~1.9, ~1.10" +PackageExtensionCompat = "1.0.2" Pkg = "<0.0.1, 1" PrecompileTools = "1" -Printf = "<0.0.1, 1" -ProgressBars = "~1.4, ~1.5" -PromptingTools = "0.53, 0.54, 0.56" +PromptingTools = "0.65 - 0.70" Random = "<0.0.1, 1" Reexport = "1" -SpecialFunctions = "0.10.1, 1, 2" StatsBase = "0.33, 0.34" -SymbolicUtils = "0.19, ^1.0.5, 2, 3" +SymbolicRegression = "1" TOML = "<0.0.1, 1" -julia = "1.10" - -[extras] -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" -JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" -SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" +julia = "1.10" \ No newline at end of file diff --git a/README.md b/README.md index a4044ce5a..326266308 100644 --- a/README.md +++ b/README.md @@ -81,16 +81,16 @@ model = LaSRRegressor( binary_operators=[+, -, *], unary_operators=[cos], llm_options=LLMOptions( - active=true, - weights=LLMWeights(llm_mutate=0.1, llm_crossover=0.1, llm_gen_random=0.1), - prompt_evol=true, - prompt_concepts=true, + use_llm=true, + lasr_weights=LLMWeights(llm_mutate=0.1, llm_crossover=0.1, llm_gen_random=0.1), + use_concept_evolution=true, + use_concepts=true, api_key="token-abc123", prompts_dir="prompts/", llm_recorder_dir="lasr_runs/debug_0/", model="meta-llama/Meta-Llama-3-8B-Instruct", api_kwargs=Dict("url" => "http://localhost:11440/v1"), - var_order=Dict("a" => "angle", "b" => "bias"), + variable_names=Dict("a" => "angle", "b" => "bias"), llm_context="We believe the function to be a trigonometric function of the angle and a quadratic function of the bias.", ) ) @@ -112,11 +112,11 @@ Other than `LLMOptions`, We have the same search options as SymbolicRegression.j LaSR uses PromptingTools.jl for zero shot prompting. If you wish to make changes to the prompting options, you can pass an `LLMOptions` object to the `LaSRRegressor` constructor. The options available are: ```julia llm_options = LLMOptions( - active=true, # Whether to use LLM inference or not - weights=LLMWeights(llm_mutate=0.1, llm_crossover=0.1, llm_gen_random=0.1), # Probability of using LLM for mutation, crossover, and random generation + use_llm=true, # Whether to use LLM inference or not + lasr_weights=LLMWeights(llm_mutate=0.1, llm_crossover=0.1, llm_gen_random=0.1), # Probability of using LLM for mutation, crossover, and random generation num_pareto_context=5, # Number of equations to sample from the Pareto frontier for summarization. - prompt_evol=true, # Whether to evolve natural language concepts through LLM calls. - prompt_concepts=true, # Whether to use natural language concepts in the search. + use_concept_evolution=true, # Whether to evolve natural language concepts through LLM calls. + use_concepts=true, # Whether to use natural language concepts in the search. api_key="token-abc123", # API key to OpenAI API compatible server. model="meta-llama/Meta-Llama-3-8B-Instruct", # LLM model to use. api_kwargs=Dict("url" => "http://localhost:11440/v1"), # Keyword arguments passed to server. @@ -124,15 +124,15 @@ llm_options = LLMOptions( prompts_dir="prompts/", # Directory to look for zero shot prompts to the LLM. llm_recorder_dir="lasr_runs/debug_0/", # Directory to log LLM interactions. llm_context="", # Natural language concept to start with. You should also be able to initialize with a list of concepts. 
@@ -112,11 +112,11 @@
 Other than `LLMOptions`, we have the same search options as SymbolicRegression.jl. LaSR uses PromptingTools.jl for zero-shot prompting. If you wish to make changes to the prompting options, you can pass an `LLMOptions` object to the `LaSRRegressor` constructor. The options available are:
 ```julia
 llm_options = LLMOptions(
-    active=true, # Whether to use LLM inference or not
-    weights=LLMWeights(llm_mutate=0.1, llm_crossover=0.1, llm_gen_random=0.1), # Probability of using LLM for mutation, crossover, and random generation
+    use_llm=true, # Whether to use LLM inference or not
+    lasr_weights=LLMWeights(llm_mutate=0.1, llm_crossover=0.1, llm_gen_random=0.1), # Probability of using LLM for mutation, crossover, and random generation
     num_pareto_context=5, # Number of equations to sample from the Pareto frontier for summarization.
-    prompt_evol=true, # Whether to evolve natural language concepts through LLM calls.
-    prompt_concepts=true, # Whether to use natural language concepts in the search.
+    use_concept_evolution=true, # Whether to evolve natural language concepts through LLM calls.
+    use_concepts=true, # Whether to use natural language concepts in the search.
     api_key="token-abc123", # API key to OpenAI API compatible server.
     model="meta-llama/Meta-Llama-3-8B-Instruct", # LLM model to use.
     api_kwargs=Dict("url" => "http://localhost:11440/v1"), # Keyword arguments passed to server.
@@ -124,15 +124,15 @@ llm_options = LLMOptions(
     prompts_dir="prompts/", # Directory to look for zero shot prompts to the LLM.
     llm_recorder_dir="lasr_runs/debug_0/", # Directory to log LLM interactions.
     llm_context="", # Natural language concept to start with. You should also be able to initialize with a list of concepts.
-    var_order=nothing, # Dict(variable_name => new_name).
-    idea_threshold=30 # Number of concepts to keep track of.
+    variable_names=nothing, # Dict(variable_name => new_name).
+    max_concepts=30 # Number of concepts to keep track of.
     is_parametric=false, # This is a special flag to allow sampling parametric equations from LaSR. This won't be needed for most users.
 )
 ```
 
 ### Best Practices
 
-1. Always make sure you cannot find a satisfactory solution with `active=false` before using LLM guidance.
+1. Always make sure you cannot find a satisfactory solution with `use_llm=false` before using LLM guidance.
 1. Start with an LLM OpenAI-compatible server running on your local machine before moving on to paid services. There are many online resources to set up a local LLM server [1](https://ollama.com/blog/openai-compatibility) [2](https://docs.vllm.ai/en/latest/getting_started/installation.html) [3](https://github.com/sgl-project/sglang?tab=readme-ov-file#backend-sglang-runtime-srt) [4](https://old.reddit.com/r/LocalLLaMA/comments/16y95hk/a_starter_guide_for_playing_with_your_own_local_ai/)
 1. If you are using an LLM, do a back-of-the-envelope calculation to estimate the cost of running it on your problem. Each iteration will make around 60k calls to the LLM. With the default prompts (in `prompts/`), each call usually requires generating 250 to 1000 tokens. This gives an upper bound of 60M tokens per iteration if `p=1.00`. Hence, running the model at `p=0.01` for 40 iterations will result in roughly 24M tokens for each problem.
 
@@ -215,16 +215,16 @@ model = LaSRRegressor(
     binary_operators=[+, -, *],
     unary_operators=[cos],
     llm_options=LLMOptions(
-        active=true,
-        weights=LLMWeights(llm_mutate=0.1, llm_crossover=0.1, llm_gen_random=0.1),
-        prompt_evol=true,
-        prompt_concepts=true,
+        use_llm=true,
+        lasr_weights=LLMWeights(llm_mutate=0.1, llm_crossover=0.1, llm_gen_random=0.1),
+        use_concept_evolution=true,
+        use_concepts=true,
         api_key="token-abc123",
         prompts_dir="prompts/",
         llm_recorder_dir="lasr_runs/debug_0/",
         model="llama3.1:latest",
         api_kwargs=Dict("url" => "http://127.0.0.1:11434/v1"),
-        var_order=Dict("a" => "angle", "b" => "bias"),
+        variable_names=Dict("a" => "angle", "b" => "bias"),
         llm_context="We believe the function to be a trigonometric function of the angle and a quadratic function of the bias."
     )
 )
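Two quick references for the README changes above. First, the `LLMOptions` field renames applied throughout this PR; second, a back-of-the-envelope check of the token estimate in the best-practices list. The call and token counts are the README's own figures, while `p` and the iteration count are the example values quoted there:

```julia
# LLMOptions field renames in this PR (old -> new):
#   active          -> use_llm
#   weights         -> lasr_weights
#   prompt_evol     -> use_concept_evolution
#   prompt_concepts -> use_concepts
#   var_order       -> variable_names
#   idea_threshold  -> max_concepts

# Token-cost sketch (illustrative; substitute your own settings):
calls_per_iteration = 60_000   # ~60k LLM calls per iteration
tokens_per_call = 1_000        # upper end of the 250-1000 token range
p = 0.01                       # fraction of operations routed to the LLM
iterations = 40
calls_per_iteration * tokens_per_call * p * iterations  # 2.4e7, i.e. ~24M tokens
```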
diff --git a/benchmark/Project.toml b/benchmark/Project.toml
deleted file mode 100644
index 6e47cf260..000000000
--- a/benchmark/Project.toml
+++ /dev/null
@@ -1,14 +0,0 @@
-[deps]
-BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
-Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e"
-CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
-DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
-DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b"
-LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
-Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
-
-[preferences.DynamicExpressions]
-instability_check = "disable"
-
-[preferences.LibraryAugmentedSymbolicRegression]
-instability_check = "disable"
diff --git a/benchmark/analyze.py b/benchmark/analyze.py
deleted file mode 100644
index 5e20e9297..000000000
--- a/benchmark/analyze.py
+++ /dev/null
@@ -1,67 +0,0 @@
-import pandas as pd
-import numpy as np
-
-
-class Node:
-    """Simple tree storing data at each node"""
-
-    def __init__(self):
-        self.data = {}
-        self.parent = None
-        self.children = []
-
-    def __repr__(self):
-        return f"Node with {len(self.children)} children: " + str(self.data)
-
-
-# Pre-process with:
-# Delete first two lines.
-# :%s/╎/|/gce
-# %g/^/exe "norm v$F|/[0-9]\hr|"
-f = open("prof_v5.txt", "r")
-lines = f.read().split("\n")
-nlines = len(lines)
-
-
-def collect_children(parent, start_line_idx):
-
-    for line_idx in range(start_line_idx, nlines):
-
-        l = lines[line_idx]
-        l = l.split("|")
-        indent = len(l) - 1
-        same_level = indent == parent.data["indent"]
-        if same_level:
-            break
-
-        is_child = indent == parent.data["indent"] + 1
-        too_nested = indent > 25
-        if is_child and not too_nested:
-            tokens = l[-1].split()
-            time = int(tokens[0])
-            info = " ".join(tokens[1:])
-            new_node = Node()
-            new_node.data = {"time": time, "info": info, "indent": indent}
-            new_node.parent = parent
-            collect_children(new_node, line_idx + 1)
-            new_node.children = sorted(new_node.children, key=lambda n: -n.data["time"])
-            parent.children.append(new_node)
-
-    return
-
-
-root = Node()
-root.data = {"time": 0, "info": "", "indent": 4}
-collect_children(root, 0)
-
-
-def go_to_level(node, levels):
-    for level in levels:
-        node = node.children[level]
-    return node
-
-
-# Walk through biggest functions:
-print(go_to_level(root, [0] * 13 + [1] + [0] * 4))
-print(go_to_level(root, [0] * 15))
-print(go_to_level(root, [0] * 13 + [1] + [0]))
diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl
deleted file mode 100644
index 71c1404af..000000000
--- a/benchmark/benchmarks.jl
+++ /dev/null
@@ -1,177 +0,0 @@
-using BenchmarkTools
-using LibraryAugmentedSymbolicRegression, BenchmarkTools, Random
-using LibraryAugmentedSymbolicRegression.AdaptiveParsimonyModule: RunningSearchStatistics
-using LibraryAugmentedSymbolicRegression.PopulationModule: best_of_sample
-using LibraryAugmentedSymbolicRegression.ConstantOptimizationModule: optimize_constants
-using LibraryAugmentedSymbolicRegression.CheckConstraintsModule: check_constraints
-using Bumper, LoopVectorization
-
-function create_search_benchmark()
-    suite = BenchmarkGroup()
-
-    n = 1000
-    T = Float32
-
-    extra_kws = NamedTuple()
-    if hasfield(Options, :turbo)
-        extra_kws = merge(extra_kws, (turbo=true,))
-    end
-    if hasfield(Options, :save_to_file)
-        extra_kws = merge(extra_kws, (save_to_file=false,))
-    end
-    if hasfield(Options, :define_helper_functions)
-        extra_kws = merge(extra_kws, (define_helper_functions=false,))
-    end
-    option_kws = (;
-        binary_operators=(+, -, /, *),
-        unary_operators=(exp, abs),
-        maxsize=30,
-        verbosity=0,
-        progress=false,
-        mutation_weights=MutationWeights(),
-        loss=(pred, target) -> (pred - target)^2,
-        extra_kws...,
-    )
-    if hasfield(MutationWeights, :swap_operands)
-        option_kws.mutation_weights.swap_operands = 0.0
-    end
-    if hasfield(MutationWeights, :form_connection)
-        option_kws.mutation_weights.form_connection = 0.0
-    end
-    if hasfield(MutationWeights, :break_connection)
-        option_kws.mutation_weights.break_connection = 0.0
-    end
-    seeds = 1:3
-    niterations = 30
-    # We create an equation that cannot be found exactly, so the search
-    # is more realistic.
-    eqn(x) = Float32(cos(2.13 * x[1]) + 0.5 * x[2] * abs(x[3])^0.9 - 0.3 * abs(x[4])^1.5)
-    all_options = Dict(
-        :serial =>
-            [Options(; seed=seed, deterministic=true, option_kws...) for seed in seeds],
-        :multithreading =>
-            [Options(; seed=seed, deterministic=false, option_kws...) for seed in seeds],
-    )
-    all_X = [rand(MersenneTwister(seed), T, 5, n) .* 10 .- 5 for seed in seeds]
-    all_y = [
-        [eqn(x) for x in eachcol(X)] .+ 0.1f0 .* randn(MersenneTwister(seed + 1), T, n) for
-        (X, seed) in zip(all_X, seeds)
-    ]
-
-    for parallelism in (:serial, :multithreading)
-        # TODO: Add determinism for other parallelisms
-        function f()
-            for (options, X, y) in zip(all_options[parallelism], all_X, all_y)
-                equation_search(X, y; options, parallelism, niterations)
-            end
-        end
-        f() # Warmup
-        samples = if parallelism == :serial
-            5
-        else
-            10
-        end
-        suite[parallelism] = @benchmarkable(
-            ($f)(), evals = 1, samples = samples, seconds = 2_000
-        )
-    end
-    return suite
-end
-
-function create_utils_benchmark()
-    suite = BenchmarkGroup()
-
-    options = Options(; unary_operators=[sin, cos], binary_operators=[+, -, *, /])
-
-    suite["best_of_sample"] = @benchmarkable(
-        best_of_sample(pop, rss, $options),
-        setup = (
-            nfeatures = 1;
-            dataset = Dataset(randn(nfeatures, 32), randn(32));
-            pop = Population(dataset; npop=100, nlength=20, options=$options, nfeatures);
-            rss = RunningSearchStatistics(; options=$options)
-        )
-    )
-
-    ntrees = 10
-    suite["optimize_constants_x10"] = @benchmarkable(
-        foreach(members) do member
-            optimize_constants(dataset, member, $options)
-        end,
-        seconds = 20,
-        setup = (
-            nfeatures = 1;
-            T = Float64;
-            dataset = Dataset(randn(nfeatures, 512), randn(512));
-            ntrees = $ntrees;
-            trees = [
-                gen_random_tree_fixed_size(20, $options, nfeatures, T) for i in 1:ntrees
-            ];
-            members = [
-                PopMember(dataset, tree, $options; deterministic=false) for tree in trees
-            ]
-        )
-    )
-
-    ntrees = 10
-    suite["compute_complexity_x10"] = let s = BenchmarkGroup()
-        for T in (Float64, Int, nothing)
-            options = Options(;
-                unary_operators=[sin, cos],
-                binary_operators=[+, -, *, /],
-                complexity_of_constants=T === nothing ? T : T(1),
-            )
-            s[T] = @benchmarkable(
-                foreach(trees) do tree
-                    compute_complexity(tree, $options)
-                end,
-                setup = (
-                    T = Float64;
-                    nfeatures = 3;
-                    trees = [
-                        gen_random_tree_fixed_size(20, $options, nfeatures, T) for
-                        i in 1:($ntrees)
-                    ]
-                )
-            )
-        end
-        s
-    end
-
-    ntrees = 10
-    options = Options(;
-        unary_operators=[sin, cos],
-        binary_operators=[+, -, *, /],
-        maxsize=30,
-        maxdepth=20,
-        nested_constraints=[
-            (+) => [(/) => 1, (+) => 2],
-            sin => [sin => 0, cos => 2],
-            cos => [sin => 0, cos => 0, (+) => 1, (-) => 1],
-        ],
-        constraints=[(+) => (-1, 10), (/) => (10, 10), sin => 12, cos => 5],
-    )
-    suite["check_constraints_x10"] = @benchmarkable(
-        foreach(trees) do tree
-            check_constraints(tree, $options, $options.maxsize)
-        end,
-        setup = (
-            T = Float64;
-            nfeatures = 3;
-            trees = [
-                gen_random_tree_fixed_size(20, $options, nfeatures, T) for i in 1:($ntrees)
-            ]
-        )
-    )
-
-    return suite
-end
-
-function create_benchmark()
-    suite = BenchmarkGroup()
-    suite["search"] = create_search_benchmark()
-    suite["utils"] = create_utils_benchmark()
-    return suite
-end
-
-const SUITE = create_benchmark()
diff --git a/benchmark/single_eval.jl b/benchmark/single_eval.jl
deleted file mode 100644
index d20db7310..000000000
--- a/benchmark/single_eval.jl
+++ /dev/null
@@ -1,28 +0,0 @@
-using BenchmarkTools
-using LibraryAugmentedSymbolicRegression
-
-nfeatures = 3
-X = randn(nfeatures, 200)
-options = Options(; binary_operators=(+, *, /, -), unary_operators=(cos, sin))
-
-x1 = Node("x1")
-x2 = Node("x2")
-x3 = Node("x3")
-
-# 48 nodes in this tree:
-tree = (
-    ((x2 + x2) * ((-0.5982493 / x1) / -0.54734415)) + (
-        sin(
-            cos(
-                sin(1.2926733 - 1.6606787) /
-                sin(((0.14577048 * x1) + ((0.111149654 + x1) - -0.8298334)) - -1.2071426),
-            ) * (cos(x3 - 2.3201916) + ((x1 - (x1 * x2)) / x2)),
-        ) / (0.14854191 - ((cos(x2) * -1.6047639) - 0.023943262))
-    )
-)
-
-function testfunc()
-    out = eval_tree_array(tree, X, options)
-    return nothing
-end
-@btime testfunc()
diff --git a/examples/example_w_llm.jl b/examples/example_w_llm.jl
index 6b8bc51a8..df07916bf 100644
--- a/examples/example_w_llm.jl
+++ b/examples/example_w_llm.jl
@@ -6,14 +6,14 @@ X = randn(Float32, 5, 100)
 y = 2 * cos.(X[4, :]) + X[1, :] .^ 2 .- 2
 
 llm_options = LibraryAugmentedSymbolicRegression.LLMOptions(;
-    active=true,
-    weights=LLMWeights(; llm_mutate=0.01, llm_crossover=0.01, llm_gen_random=0.01),
-    promtp_evol=true,
-    prompt_concepts=true,
+    use_llm=true,
+    lasr_weights=LLMWeights(; llm_mutate=0.01, llm_crossover=0.01, llm_gen_random=0.01),
+    use_concept_evolution=true,
+    use_concepts=true,
     api_key="token-abc123",
     model="meta-llama/Meta-Llama-3-8B-Instruct",
     api_kwargs=Dict("url" => "http://localhost:11440/v1"),
-    var_order=Dict("a" => "angle", "b" => "bias"),
+    variable_names=Dict("a" => "angle", "b" => "bias"),
 )
 
 options = LibraryAugmentedSymbolicRegression.Options(;
diff --git a/ext/LaSREnzymeExt.jl b/ext/LaSREnzymeExt.jl
deleted file mode 100644
index 1bba1a80f..000000000
--- a/ext/LaSREnzymeExt.jl
+++ /dev/null
@@ -1,61 +0,0 @@
-module LaSREnzymeExt
-
-using LibraryAugmentedSymbolicRegression.LossFunctionsModule: eval_loss
-using DynamicExpressions:
-    AbstractExpression,
-    AbstractExpressionNode,
-    get_scalar_constants,
-    set_scalar_constants!,
-    extract_gradient,
-    with_contents,
-    get_contents
-using ADTypes: AutoEnzyme
-using Enzyme: autodiff, Reverse, Active, Const, Duplicated
-
-import LibraryAugmentedSymbolicRegression.ConstantOptimizationModule: GradEvaluator
-
-# We prepare a copy of the tree and all arrays
-function GradEvaluator(f::F, backend::AE) where {F,AE<:AutoEnzyme}
-    storage_tree = copy(f.tree)
-    _, storage_refs = get_scalar_constants(storage_tree)
-    storage_dataset = deepcopy(f.dataset)
-    # TODO: It is super inefficient to deepcopy; how can we skip this
-    return GradEvaluator(f, backend, (; storage_tree, storage_refs, storage_dataset))
-end
-
-function evaluator(tree, dataset, options, idx, output)
-    output[] = eval_loss(tree, dataset, options; regularization=false, idx=idx)
-    return nothing
-end
-
-with_stacksize(f::F, n) where {F} = fetch(schedule(Task(f, n)))
-
-function (g::GradEvaluator{<:Any,<:AutoEnzyme})(_, G, x::AbstractVector{T}) where {T}
-    set_scalar_constants!(g.f.tree, x, g.f.refs)
-    set_scalar_constants!(g.extra.storage_tree, zero(x), g.extra.storage_refs)
-    fill!(g.extra.storage_dataset, 0)
-
-    output = [zero(T)]
-    doutput = [one(T)]
-
-    with_stacksize(32 * 1024 * 1024) do
-        autodiff(
-            Reverse,
-            evaluator,
-            Duplicated(g.f.tree, g.extra.storage_tree),
-            Duplicated(g.f.dataset, g.extra.storage_dataset),
-            Const(g.f.options),
-            Const(g.f.idx),
-            Duplicated(output, doutput),
-        )
-    end
-
-    if G !== nothing
-        # TODO: This is redundant since we already have the references.
-        # Should just be able to extract from the references directly.
-        G .= first(get_scalar_constants(g.extra.storage_tree))
-    end
-    return output[]
-end
-
-end
diff --git a/ext/LaSRJSON3Ext.jl b/ext/LaSRJSON3Ext.jl
deleted file mode 100644
index b44ceb21e..000000000
--- a/ext/LaSRJSON3Ext.jl
+++ /dev/null
@@ -1,12 +0,0 @@
-module LaSRJSON3Ext
-
-using JSON3: JSON3
-import LibraryAugmentedSymbolicRegression.UtilsModule: json3_write
-
-function json3_write(record, recorder_file)
-    open(recorder_file, "w") do io
-        JSON3.write(io, record; allow_inf=true)
-    end
-end
-
-end
diff --git a/ext/LaSRSymbolicUtilsExt.jl b/ext/LaSRSymbolicUtilsExt.jl
deleted file mode 100644
index 2a7f32cc7..000000000
--- a/ext/LaSRSymbolicUtilsExt.jl
+++ /dev/null
@@ -1,68 +0,0 @@
-module LaSRSymbolicUtilsExt
-
-using SymbolicUtils: Symbolic
-using LibraryAugmentedSymbolicRegression:
-    AbstractExpressionNode, AbstractExpression, Node, Options
-using LibraryAugmentedSymbolicRegression.MLJInterfaceModule:
-    AbstractSRRegressor, get_options
-using DynamicExpressions: get_tree, get_operators
-
-import LibraryAugmentedSymbolicRegression: node_to_symbolic, symbolic_to_node
-
-"""
-    node_to_symbolic(tree::AbstractExpressionNode, options::Options; kws...)
-
-Convert an expression to SymbolicUtils.jl form.
-"""
-function node_to_symbolic(
-    tree::Union{AbstractExpressionNode,AbstractExpression}, options::Options; kws...
-)
-    return node_to_symbolic(get_tree(tree), get_operators(tree, options); kws...)
-end
-function node_to_symbolic(
-    tree::Union{AbstractExpressionNode,AbstractExpression}, m::AbstractSRRegressor; kws...
-)
-    return node_to_symbolic(tree, get_options(m); kws...)
-end
-
-"""
-    symbolic_to_node(eqn::Symbolic, options::Options; kws...)
-
-Convert a SymbolicUtils.jl expression to LibraryAugmentedSymbolicRegression.jl's `Node` type.
-"""
-function symbolic_to_node(eqn::Symbolic, options::Options; kws...)
-    return symbolic_to_node(eqn, options.operators; kws...)
-end
-function symbolic_to_node(eqn::Symbolic, m::AbstractSRRegressor; kws...)
-    return symbolic_to_node(eqn, get_options(m); kws...)
-end
-
-function Base.convert(
-    ::Type{Symbolic},
-    tree::Union{AbstractExpressionNode,AbstractExpression},
-    options::Union{Options,Nothing}=nothing;
-    kws...,
-)
-    return convert(Symbolic, get_tree(tree), get_operators(tree, options); kws...)
-end
-function Base.convert(
-    ::Type{Symbolic},
-    tree::Union{AbstractExpressionNode,AbstractExpression},
-    m::AbstractSRRegressor;
-    kws...,
-)
-    return convert(Symbolic, tree, get_options(m); kws...)
-end
-
-function Base.convert(
-    ::Type{N}, x::Union{Number,Symbolic}, options::Options; kws...
-) where {N<:Union{AbstractExpressionNode,AbstractExpression}}
-    return convert(N, x, options.operators; kws...)
-end
-function Base.convert(
-    ::Type{N}, x::Union{Number,Symbolic}, m::AbstractSRRegressor; kws...
-) where {N<:Union{AbstractExpressionNode,AbstractExpression}}
-    return convert(N, x, get_options(m); kws...)
-end
-
-end
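The extension removed above defined the SymbolicUtils conversions documented in its docstrings. A minimal sketch of the intended call pattern, assuming a fitted `tree` and `options` are in scope, and that the equivalent functionality is now inherited from the SymbolicRegression.jl dependency added in Project.toml:

```julia
using SymbolicUtils  # loading this package is what activated the extension

sym = node_to_symbolic(tree, options)    # expression tree -> SymbolicUtils form
tree2 = symbolic_to_node(sym, options)   # ... and back to a `Node`
```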
diff --git a/src/AdaptiveParsimony.jl b/src/AdaptiveParsimony.jl
deleted file mode 100644
index b438aef4b..000000000
--- a/src/AdaptiveParsimony.jl
+++ /dev/null
@@ -1,97 +0,0 @@
-module AdaptiveParsimonyModule
-
-using ..CoreModule: Options, MAX_DEGREE
-
-"""
-    RunningSearchStatistics
-
-A struct to keep track of various running averages of the search and discovered
-equations, for use in adaptive losses and parsimony.
-
-# Fields
-
-- `window_size::Int`: After this many equations are seen, the frequencies are reduced
-  by 1, averaged over all complexities, each time a new equation is seen.
-- `frequencies::Vector{Float64}`: The number of equations seen at this complexity,
-  given by the index.
-- `normalized_frequencies::Vector{Float64}`: This is the same as `frequencies`, but
-  normalized to sum to 1.0. This is updated once in a while.
-"""
-struct RunningSearchStatistics
-    window_size::Int
-    frequencies::Vector{Float64}
-    normalized_frequencies::Vector{Float64} # Stores `frequencies`, but normalized (updated once in a while)
-end
-
-function RunningSearchStatistics(; options::Options, window_size::Int=100000)
-    maxsize = options.maxsize
-    actualMaxsize = maxsize + MAX_DEGREE
-    init_frequencies = ones(Float64, actualMaxsize)
-
-    return RunningSearchStatistics(
-        window_size, init_frequencies, copy(init_frequencies) / sum(init_frequencies)
-    )
-end
-
-"""
-    update_frequencies!(running_search_statistics::RunningSearchStatistics; size=nothing)
-
-Update the frequencies in `running_search_statistics` by adding 1 to the frequency
-for an equation at size `size`.
-"""
-@inline function update_frequencies!(
-    running_search_statistics::RunningSearchStatistics; size=nothing
-)
-    if 0 < size <= length(running_search_statistics.frequencies)
-        running_search_statistics.frequencies[size] += 1
-    end
-    return nothing
-end
-
-"""
-    move_window!(running_search_statistics::RunningSearchStatistics)
-
-Reduce `running_search_statistics.frequencies` until it sums to
-`window_size`.
-"""
-function move_window!(running_search_statistics::RunningSearchStatistics)
-    smallest_frequency_allowed = 1
-    max_loops = 1000
-
-    frequencies = running_search_statistics.frequencies
-    window_size = running_search_statistics.window_size
-
-    cur_size_frequency_complexities = sum(frequencies)
-    if cur_size_frequency_complexities > window_size
-        difference_in_size = cur_size_frequency_complexities - window_size
-        # We need frequencyComplexities to be positive, but also sum to a number.
-        num_loops = 0
-        # TODO: Clean this code up. Should not have to have
-        # loop catching.
-        while difference_in_size > 0
-            indices_to_subtract = findall(frequencies .> smallest_frequency_allowed)
-            num_remaining = size(indices_to_subtract, 1)
-            amount_to_subtract = min(
-                difference_in_size / num_remaining,
-                min(frequencies[indices_to_subtract]...) - smallest_frequency_allowed,
-            )
-            frequencies[indices_to_subtract] .-= amount_to_subtract
-            total_amount_to_subtract = amount_to_subtract * num_remaining
-            difference_in_size -= total_amount_to_subtract
-            num_loops += 1
-            if num_loops > max_loops || total_amount_to_subtract < 1e-6
-                # Sometimes, total_amount_to_subtract can be a very very small number.
-                break
-            end
-        end
-    end
-    return nothing
-end
-
-function normalize_frequencies!(running_search_statistics::RunningSearchStatistics)
-    running_search_statistics.normalized_frequencies .=
-        running_search_statistics.frequencies ./ sum(running_search_statistics.frequencies)
-    return nothing
-end
-
-end
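For reference, the life cycle of the `RunningSearchStatistics` type deleted here, pieced together from its docstrings above (the `options` value and the complexity `5` are placeholders):

```julia
stats = RunningSearchStatistics(; options=options, window_size=100_000)
update_frequencies!(stats; size=5)  # record one equation seen at complexity 5
move_window!(stats)                 # decay counts back toward `window_size`
normalize_frequencies!(stats)       # refresh `stats.normalized_frequencies`
```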
diff --git a/src/CheckConstraints.jl b/src/CheckConstraints.jl
deleted file mode 100644
index 7f6093631..000000000
--- a/src/CheckConstraints.jl
+++ /dev/null
@@ -1,109 +0,0 @@
-module CheckConstraintsModule
-
-using DynamicExpressions:
-    AbstractExpressionNode, AbstractExpression, get_tree, count_depth, tree_mapreduce
-using ..CoreModule: Options
-using ..ComplexityModule: compute_complexity, past_complexity_limit
-
-# Check if any binary operator are overly complex
-function flag_bin_operator_complexity(
-    tree::AbstractExpressionNode, op, cons, options::Options
-)::Bool
-    any(tree) do subtree
-        if subtree.degree == 2 && subtree.op == op
-            cons[1] > -1 &&
-                past_complexity_limit(subtree.l, options, cons[1]) &&
-                return true
-            cons[2] > -1 &&
-                past_complexity_limit(subtree.r, options, cons[2]) &&
-                return true
-        end
-        return false
-    end
-end
-
-"""
-Check if any unary operators are overly complex.
-This assumes you have already checked whether the constraint is > -1.
-"""
-function flag_una_operator_complexity(
-    tree::AbstractExpressionNode, op, cons, options::Options
-)::Bool
-    any(tree) do subtree
-        if subtree.degree == 1 && tree.op == op
-            past_complexity_limit(subtree.l, options, cons) && return true
-        end
-        return false
-    end
-end
-
-function count_max_nestedness(tree, degree, op)
-    # TODO: Update this to correctly share nodes
-    nestedness = tree_mapreduce(
-        t -> 0, # Leafs
-        t -> (t.degree == degree && t.op == op) ? 1 : 0, # Branches
-        (p, c...) -> p + max(c...), # Reduce
-        tree;
-        break_sharing=Val(true),
-    )
-    # Remove count of self:
-    is_self = tree.degree == degree && tree.op == op
-    return nestedness - (is_self ? 1 : 0)
-end
-
-"""Check if there are any illegal combinations of operators"""
-function flag_illegal_nests(tree::AbstractExpressionNode, options::Options)::Bool
-    # We search from the top first, then from child nodes at end.
-    (nested_constraints = options.nested_constraints) === nothing && return false
-    for (degree, op_idx, op_constraint) in nested_constraints
-        for (nested_degree, nested_op_idx, max_nestedness) in op_constraint
-            any(tree) do subtree
-                if subtree.degree == degree && subtree.op == op_idx
-                    nestedness = count_max_nestedness(subtree, nested_degree, nested_op_idx)
-                    return nestedness > max_nestedness
-                end
-                return false
-            end && return true
-        end
-    end
-    return false
-end
-
-"""Check if user-passed constraints are violated or not"""
-function check_constraints(
-    ex::AbstractExpression,
-    options::Options,
-    maxsize::Int,
-    cursize::Union{Int,Nothing}=nothing,
-)::Bool
-    tree = get_tree(ex)
-    return check_constraints(tree, options, maxsize, cursize)
-end
-function check_constraints(
-    tree::AbstractExpressionNode,
-    options::Options,
-    maxsize::Int,
-    cursize::Union{Int,Nothing}=nothing,
-)::Bool
-    ((cursize === nothing) ? compute_complexity(tree, options) : cursize) > maxsize &&
-        return false
-    count_depth(tree) > options.maxdepth && return false
-    for i in 1:(options.nbin)
-        cons = options.bin_constraints[i]
-        cons == (-1, -1) && continue
-        flag_bin_operator_complexity(tree, i, cons, options) && return false
-    end
-    for i in 1:(options.nuna)
-        cons = options.una_constraints[i]
-        cons == -1 && continue
-        flag_una_operator_complexity(tree, i, cons, options) && return false
-    end
-    flag_illegal_nests(tree, options) && return false
-    return true
-end
-
-check_constraints(
-    ex::Union{AbstractExpression,AbstractExpressionNode}, options::Options
-)::Bool = check_constraints(ex, options, options.maxsize)
-
-end
diff --git a/src/Complexity.jl b/src/Complexity.jl
deleted file mode 100644
index dccb05bd3..000000000
--- a/src/Complexity.jl
+++ /dev/null
@@ -1,60 +0,0 @@
-module ComplexityModule
-
-using DynamicExpressions:
-    AbstractExpression, AbstractExpressionNode, get_tree, count_nodes, tree_mapreduce
-using ..CoreModule: Options, ComplexityMapping
-
-function past_complexity_limit(
-    tree::Union{AbstractExpression,AbstractExpressionNode}, options::Options, limit
-)::Bool
-    return compute_complexity(tree, options) > limit
-end
-
-"""
-Compute the complexity of a tree.
-
-By default, this is the number of nodes in a tree.
-However, it could use the custom settings in options.complexity_mapping
-if these are defined.
-"""
-function compute_complexity(
-    tree::AbstractExpression, options::Options; break_sharing=Val(false)
-)
-    return compute_complexity(get_tree(tree), options; break_sharing)
-end
-function compute_complexity(
-    tree::AbstractExpressionNode, options::Options; break_sharing=Val(false)
-)::Int
-    if options.complexity_mapping.use
-        raw_complexity = _compute_complexity(
-            tree, options.complexity_mapping; break_sharing
-        )
-        return round(Int, raw_complexity)
-    else
-        return count_nodes(tree; break_sharing)
-    end
-end
-
-function _compute_complexity(
-    tree::AbstractExpressionNode, cmap::ComplexityMapping{CT}; break_sharing=Val(false)
-)::CT where {CT}
-    return tree_mapreduce(
-        let vc = cmap.variable_complexity, cc = cmap.constant_complexity
-            if vc isa AbstractVector
-                t -> t.constant ? cc : @inbounds(vc[t.feature])
-            else
-                t -> t.constant ? cc : vc
-            end
-        end,
-        let uc = cmap.unaop_complexities, bc = cmap.binop_complexities
-            t -> t.degree == 1 ? @inbounds(uc[t.op]) : @inbounds(bc[t.op])
-        end,
-        +,
-        tree,
-        CT;
-        break_sharing=break_sharing,
-        f_on_shared=(result, is_shared) -> is_shared ? result : zero(CT),
-    )
-end
-
-end
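As the `compute_complexity` docstring above notes, complexity defaults to a plain node count unless a custom mapping is configured. A small sketch in the same style as `benchmark/single_eval.jl` earlier in this diff (the operator set and tree are arbitrary):

```julia
options = Options(; binary_operators=(+, *), unary_operators=(cos,))
x1 = Node("x1")
tree = cos(x1 * 3.2) + 1.5           # 6 nodes: x1, 3.2, *, cos, 1.5, +
compute_complexity(tree, options)    # == 6 under the default node count

weighted = Options(;
    binary_operators=(+, *), unary_operators=(cos,), complexity_of_constants=2
)
compute_complexity(tree, weighted)   # == 8: each of the two constants now costs 2
```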
" * - "Please ensure that your operators return the same type as their inputs.", - ) - end - return nothing -end - -const TEST_INPUTS = collect(range(-100, 100; length=99)) - -function assert_operators_well_defined(T, options::Options) - test_input = if T <: Complex - (x -> convert(T, x)).(TEST_INPUTS .+ TEST_INPUTS .* im) - else - (x -> convert(T, x)).(TEST_INPUTS) - end - for x in test_input, y in test_input, op in options.operators.binops - test_operator(op, x, y) - end - for x in test_input, op in options.operators.unaops - test_operator(op, x) - end -end - -# Check for errors before they happen -function test_option_configuration( - parallelism, datasets::Vector{D}, options::Options, verbosity -) where {T,D<:Dataset{T}} - if options.deterministic && parallelism != :serial - error("Determinism is only guaranteed for serial mode.") - end - if parallelism == :multithreading && Threads.nthreads() == 1 - verbosity > 0 && - @warn "You are using multithreading mode, but only one thread is available. Try starting julia with `--threads=auto`." - end - if any(d -> d.X_units !== nothing || d.y_units !== nothing, datasets) && - options.dimensional_constraint_penalty === nothing - verbosity > 0 && - @warn "You are using dimensional constraints, but `dimensional_constraint_penalty` was not set. The default penalty of `1000.0` will be used." - end - - for op in (options.operators.binops..., options.operators.unaops...) - if is_anonymous_function(op) - throw( - AssertionError( - "Anonymous functions can't be used as operators for LibraryAugmentedSymbolicRegression.jl", - ), - ) - end - end - - assert_operators_well_defined(T, options) - - operator_intersection = intersect(options.operators.binops, options.operators.unaops) - if length(operator_intersection) > 0 - throw( - AssertionError( - "Your configuration is invalid - $(operator_intersection) appear in both the binary operators and unary operators.", - ), - ) - end -end - -# Check for errors before they happen -function test_dataset_configuration( - dataset::Dataset{T}, options::Options, verbosity -) where {T<:DATA_TYPE} - n = dataset.n - if n != size(dataset.X, 2) || - (dataset.y !== nothing && n != size(dataset.y::AbstractArray{T}, 1)) - throw( - AssertionError( - "Dataset dimensions are invalid. Make sure X is of shape [features, rows], y is of shape [rows] and if there are weights, they are of shape [rows].", - ), - ) - end - - if size(dataset.X, 2) > 10000 && !options.batching && verbosity > 0 - @info "Note: you are running with more than 10,000 datapoints. You should consider turning on batching (`options.batching`), and also if you need that many datapoints. Unless you have a large amount of noise (in which case you should smooth your dataset first), generally < 10,000 datapoints is enough to find a functional form." 
-    end
-
-    if !(typeof(options.elementwise_loss) <: SupervisedLoss) &&
-        dataset.weighted &&
-        !(3 in [m.nargs - 1 for m in methods(options.elementwise_loss)])
-        throw(
-            AssertionError(
-                "When you create a custom loss function, and are using weights, you need to define your loss function with three scalar arguments: f(prediction, target, weight).",
-            ),
-        )
-    end
-end
-
-""" Move custom operators and loss functions to workers, if undefined """
-function move_functions_to_workers(
-    procs, options::Options, dataset::Dataset{T}, verbosity
-) where {T}
-    # All the types of functions we need to move to workers:
-    function_sets = (
-        :unaops, :binops, :elementwise_loss, :early_stop_condition, :loss_function
-    )
-
-    for function_set in function_sets
-        if function_set == :unaops
-            ops = options.operators.unaops
-            example_inputs = (zero(T),)
-        elseif function_set == :binops
-            ops = options.operators.binops
-            example_inputs = (zero(T), zero(T))
-        elseif function_set == :elementwise_loss
-            if typeof(options.elementwise_loss) <: SupervisedLoss
-                continue
-            end
-            ops = (options.elementwise_loss,)
-            example_inputs = if dataset.weighted
-                (zero(T), zero(T), zero(T))
-            else
-                (zero(T), zero(T))
-            end
-        elseif function_set == :early_stop_condition
-            if !(typeof(options.early_stop_condition) <: Function)
-                continue
-            end
-            ops = (options.early_stop_condition,)
-            example_inputs = (zero(T), 0)
-        elseif function_set == :loss_function
-            if options.loss_function === nothing
-                continue
-            end
-            ops = (options.loss_function,)
-            example_inputs = (Node(T; val=zero(T)), dataset, options)
-        else
-            error("Invalid function set: $function_set")
-        end
-        for op in ops
-            try
-                test_function_on_workers(example_inputs, op, procs)
-            catch e
-                undefined_on_workers = isa(e.captured.ex, UndefVarError)
-                if undefined_on_workers
-                    copy_definition_to_workers(op, procs, options, verbosity)
-                else
-                    throw(e)
-                end
-            end
-            test_function_on_workers(example_inputs, op, procs)
-        end
-    end
-end
-
-function copy_definition_to_workers(op, procs, options::Options, verbosity)
-    name = nameof(op)
-    verbosity > 0 && @info "Copying definition of $op to workers..."
-    src_ms = methods(op).ms
-    # Thanks https://discourse.julialang.org/t/easy-way-to-send-custom-function-to-distributed-workers/22118/2
-    @everywhere procs @eval function $name end
-    for m in src_ms
-        @everywhere procs @eval $m
-    end
-    verbosity > 0 && @info "Finished!"
-    return nothing
-end
-
-function test_function_on_workers(example_inputs, op, procs)
-    futures = []
-    for proc in procs
-        push!(futures, @spawnat proc op(example_inputs...))
-    end
-    for future in futures
-        fetch(future)
-    end
-end
-
-function activate_env_on_workers(procs, project_path::String, options::Options, verbosity)
-    verbosity > 0 && @info "Activating environment on workers."
-    @everywhere procs begin
-        Base.MainInclude.eval(
-            quote
-                using Pkg
-                Pkg.activate($$project_path)
-            end,
-        )
-    end
-end
-
-function import_module_on_workers(procs, filename::String, options::Options, verbosity)
-    loaded_modules_head_worker = [k.name for (k, _) in Base.loaded_modules]
-
-    included_as_local = "LibraryAugmentedSymbolicRegression" ∉ loaded_modules_head_worker
-    expr = if included_as_local
-        quote
-            include($filename)
-            using .LibraryAugmentedSymbolicRegression
-        end
-    else
-        quote
-            using LibraryAugmentedSymbolicRegression
-        end
-    end
-
-    # Need to import any extension code, if loaded on head node
-    relevant_extensions = [
-        :Bumper,
-        :CUDA,
-        :ClusterManagers,
-        :Enzyme,
-        :LoopVectorization,
-        :SymbolicUtils,
-        :Zygote,
-    ]
-    filter!(m -> String(m) ∈ loaded_modules_head_worker, relevant_extensions)
-    # HACK TODO – this workaround is very fragile. Likely need to submit a bug report
-    # to JuliaLang.
-
-    for ext in relevant_extensions
-        push!(
-            expr.args,
-            quote
-                using $ext: $ext
-            end,
-        )
-    end
-
-    verbosity > 0 && if isempty(relevant_extensions)
-        @info "Importing LibraryAugmentedSymbolicRegression on workers."
-    else
-        @info "Importing LibraryAugmentedSymbolicRegression on workers as well as extensions $(join(relevant_extensions, ',' * ' '))."
-    end
-    @everywhere procs Core.eval(Core.Main, $expr)
-    verbosity > 0 && @info "Finished!"
-    return nothing
-end
-
-function test_module_on_workers(procs, options::Options, verbosity)
-    verbosity > 0 && @info "Testing module on workers..."
-    futures = []
-    for proc in procs
-        push!(
-            futures,
-            @spawnat proc LibraryAugmentedSymbolicRegression.gen_random_tree(
-                3, options, 5, TEST_TYPE
-            )
-        )
-    end
-    for future in futures
-        fetch(future)
-    end
-    verbosity > 0 && @info "Finished!"
-    return nothing
-end
-
-function test_entire_pipeline(
-    procs, dataset::Dataset{T}, options::Options, verbosity
-) where {T<:DATA_TYPE}
-    futures = []
-    verbosity > 0 && @info "Testing entire pipeline on workers..."
-    for proc in procs
-        push!(
-            futures,
-            @spawnat proc begin
-                tmp_pop = Population(
-                    dataset;
-                    population_size=20,
-                    nlength=3,
-                    options=options,
-                    nfeatures=dataset.nfeatures,
-                )
-                tmp_pop = s_r_cycle(
-                    dataset,
-                    tmp_pop,
-                    5,
-                    5,
-                    RunningSearchStatistics(; options=options);
-                    verbosity=verbosity,
-                    options=options,
-                    record=RecordType(),
-                )[1]
-                tmp_pop = optimize_and_simplify_population(
-                    dataset, tmp_pop, options, options.maxsize, RecordType()
-                )
-            end
-        )
-    end
-    for future in futures
-        fetch(future)
-    end
-    verbosity > 0 && @info "Finished!"
-    return nothing
-end
-
-function configure_workers(;
-    procs::Union{Vector{Int},Nothing},
-    numprocs::Int,
-    addprocs_function::Function,
-    options::Options,
-    project_path,
-    file,
-    exeflags::Cmd,
-    verbosity,
-    example_dataset::Dataset,
-    runtests::Bool,
-)
-    (procs, we_created_procs) = if procs === nothing
-        (addprocs_function(numprocs; lazy=false, exeflags), true)
-    else
-        (procs, false)
-    end
-
-    if we_created_procs
-        if VERSION < v"1.9.0"
-            # On newer Julia; environment is activated automatically
-            activate_env_on_workers(procs, project_path, options, verbosity)
-        end
-        import_module_on_workers(procs, file, options, verbosity)
-    end
-
-    move_functions_to_workers(procs, options, example_dataset, verbosity)
-
-    if runtests
-        test_module_on_workers(procs, options, verbosity)
-        test_entire_pipeline(procs, example_dataset, options, verbosity)
-    end
-
-    return (procs, we_created_procs)
-end
diff --git a/src/ConstantOptimization.jl b/src/ConstantOptimization.jl
deleted file mode 100644
index fe66b4f5d..000000000
--- a/src/ConstantOptimization.jl
+++ /dev/null
@@ -1,136 +0,0 @@
-module ConstantOptimizationModule
-
-using LineSearches: LineSearches
-using Optim: Optim
-using ADTypes: AbstractADType, AutoEnzyme
-using DifferentiationInterface: value_and_gradient
-using DynamicExpressions:
-    AbstractExpression,
-    Expression,
-    count_scalar_constants,
-    get_scalar_constants,
-    set_scalar_constants!,
-    extract_gradient
-using ..CoreModule: Options, Dataset, DATA_TYPE, LOSS_TYPE, specialized_options
-using ..UtilsModule: get_birth_order
-using ..LossFunctionsModule: eval_loss, loss_to_score, batch_sample
-using ..PopMemberModule: PopMember
-
-function optimize_constants(
-    dataset::Dataset{T,L}, member::P, options::Options
-)::Tuple{P,Float64} where {T<:DATA_TYPE,L<:LOSS_TYPE,P<:PopMember{T,L}}
-    if options.batching
-        dispatch_optimize_constants(
-            dataset, member, options, batch_sample(dataset, options)
-        )
-    else
-        dispatch_optimize_constants(dataset, member, options, nothing)
-    end
-end
-function dispatch_optimize_constants(
-    dataset::Dataset{T,L}, member::P, options::Options, idx
-) where {T<:DATA_TYPE,L<:LOSS_TYPE,P<:PopMember{T,L}}
-    nconst = count_constants_for_optimization(member.tree)
-    nconst == 0 && return (member, 0.0)
-    if nconst == 1 && !(T <: Complex)
-        algorithm = Optim.Newton(; linesearch=LineSearches.BackTracking())
-        return _optimize_constants(
-            dataset,
-            member,
-            specialized_options(options),
-            algorithm,
-            options.optimizer_options,
-            idx,
-        )
-    end
-    return _optimize_constants(
-        dataset,
-        member,
-        specialized_options(options),
-        # We use specialized options here due to Enzyme being
-        # more particular about dynamic dispatch
-        options.optimizer_algorithm,
-        options.optimizer_options,
-        idx,
-    )
-end
-
-"""How many constants will be optimized."""
-count_constants_for_optimization(ex::Expression) = count_scalar_constants(ex)
-
-function _optimize_constants(
-    dataset, member::P, options, algorithm, optimizer_options, idx
-)::Tuple{P,Float64} where {T,L,P<:PopMember{T,L}}
-    tree = member.tree
-    eval_fraction = options.batching ? (options.batch_size / dataset.n) : 1.0
-    x0, refs = get_scalar_constants(tree)
-    @assert count_constants_for_optimization(tree) == length(x0)
-    f = Evaluator(tree, refs, dataset, options, idx)
-    fg! = GradEvaluator(f, options.autodiff_backend)
-    obj = if algorithm isa Optim.Newton || options.autodiff_backend === nothing
-        f
-    else
-        Optim.only_fg!(fg!)
-    end
-    baseline = f(x0)
-    result = Optim.optimize(obj, x0, algorithm, optimizer_options)
-    num_evals = result.f_calls * eval_fraction
-    # Try other initial conditions:
-    for _ in 1:(options.optimizer_nrestarts)
-        eps = randn(T, size(x0)...)
-        xt = @. x0 * (T(1) + T(1//2) * eps)
-        tmpresult = Optim.optimize(obj, xt, algorithm, optimizer_options)
-        num_evals += tmpresult.f_calls * eval_fraction
-        # TODO: Does this need to take into account h_calls?
-
-        if tmpresult.minimum < result.minimum
-            result = tmpresult
-        end
-    end
-
-    if result.minimum < baseline
-        member.tree = tree
-        member.loss = f(result.minimizer; regularization=true)
-        member.score = loss_to_score(
-            member.loss, dataset.use_baseline, dataset.baseline_loss, member, options
-        )
-        member.birth = get_birth_order(; deterministic=options.deterministic)
-        num_evals += eval_fraction
-    else
-        set_scalar_constants!(member.tree, x0, refs)
-    end
-
-    return member, num_evals
-end
-
-struct Evaluator{N<:AbstractExpression,R,D<:Dataset,O<:Options,I} <: Function
-    tree::N
-    refs::R
-    dataset::D
-    options::O
-    idx::I
-end
-function (e::Evaluator)(x::AbstractVector; regularization=false)
-    set_scalar_constants!(e.tree, x, e.refs)
-    return eval_loss(e.tree, e.dataset, e.options; regularization, e.idx)
-end
-struct GradEvaluator{F<:Evaluator,AD<:Union{Nothing,AbstractADType},EX} <: Function
-    f::F
-    backend::AD
-    extra::EX
-end
-GradEvaluator(f::F, backend::AD) where {F,AD} = GradEvaluator(f, backend, nothing)
-
-function (g::GradEvaluator{<:Any,AD})(_, G, x::AbstractVector) where {AD}
-    AD isa AutoEnzyme && error("Please load the `Enzyme.jl` package.")
-    set_scalar_constants!(g.f.tree, x, g.f.refs)
-    (val, grad) = value_and_gradient(g.backend, g.f.tree) do tree
-        eval_loss(tree, g.f.dataset, g.f.options; regularization=false, idx=g.f.idx)
-    end
-    if G !== nothing && grad !== nothing
-        G .= extract_gradient(grad, g.f.tree)
-    end
-    return val
-end
-
-end
diff --git a/src/Core.jl b/src/Core.jl
index 63860b0f7..493bda8e5 100644
--- a/src/Core.jl
+++ b/src/Core.jl
@@ -3,44 +3,10 @@ module CoreModule
 function create_expression end
 
 include("Utils.jl")
-include("ProgramConstants.jl")
-include("Dataset.jl")
 include("MutationWeights.jl")
 include("LLMOptions.jl")
-include("OptionsStruct.jl")
-include("Operators.jl")
-include("Options.jl")
 
-using .ProgramConstantsModule:
-    MAX_DEGREE, BATCH_DIM, FEATURE_DIM, RecordType, DATA_TYPE, LOSS_TYPE
-using .DatasetModule: Dataset
-using .MutationWeightsModule: MutationWeights, sample_mutation
-using .LLMOptionsModule: LLMOptions, LLMWeights
-using .OptionsStructModule: Options, ComplexityMapping, specialized_options
-using .OptionsModule: Options, binopmap, unaopmap
-using .OperatorsModule:
-    plus,
-    sub,
-    mult,
-    square,
-    cube,
-    pow,
-    safe_pow,
-    safe_log,
-    safe_log2,
-    safe_log10,
-    safe_log1p,
-    safe_sqrt,
-    safe_acosh,
-    neg,
-    greater,
-    cond,
-    relu,
-    logical_or,
-    logical_and,
-    gamma,
-    erf,
-    erfc,
-    atanh_clip
+using .LaSRMutationWeightsModule: LLMMutationProbabilities, LaSRMutationWeights
+using .LLMOptionsModule: LLMOperationWeights, LLMOptions, LaSROptions
 
 end
diff --git a/src/Dataset.jl b/src/Dataset.jl
deleted file mode 100644
index 99c31ee3d..000000000
--- a/src/Dataset.jl
+++ /dev/null
@@ -1,275 +0,0 @@
-module DatasetModule
-
-using DynamicQuantities: Quantity
-
-using ..UtilsModule: subscriptify, get_base_type, @constfield
-using ..ProgramConstantsModule: BATCH_DIM, FEATURE_DIM, DATA_TYPE, LOSS_TYPE
-using ...InterfaceDynamicQuantitiesModule: get_si_units, get_sym_units
-
-import ...deprecate_varmap
-
-"""
-    Dataset{T<:DATA_TYPE,L<:LOSS_TYPE}
-
-# Fields
-
-- `X::AbstractMatrix{T}`: The input features, with shape `(nfeatures, n)`.
-- `y::AbstractVector{T}`: The desired output values, with shape `(n,)`.
-- `index::Int`: The index of the output feature corresponding to this
-  dataset, if any.
-- `n::Int`: The number of samples.
-- `nfeatures::Int`: The number of features.
-- `weighted::Bool`: Whether the dataset is non-uniformly weighted.
-- `weights::Union{AbstractVector{T},Nothing}`: If the dataset is weighted,
-  these specify the per-sample weight (with shape `(n,)`).
-- `extra::NamedTuple`: Extra information to pass to a custom evaluation
-  function. Since this is an arbitrary named tuple, you could pass
-  any sort of dataset you wish to here.
-- `avg_y`: The average value of `y` (weighted, if `weights` are passed).
-- `use_baseline`: Whether to use a baseline loss. This will be set to `false`
-  if the baseline loss is calculated to be `Inf`.
-- `baseline_loss`: The loss of a constant function which predicts the average
-  value of `y`. This is loss-dependent and should be updated with
-  `update_baseline_loss!`.
-- `variable_names::Array{String,1}`: The names of the features,
-  with shape `(nfeatures,)`.
-- `display_variable_names::Array{String,1}`: A version of `variable_names`
-  but for printing to the terminal (e.g., with unicode versions).
-- `y_variable_name::String`: The name of the output variable.
-- `X_units`: Unit information of `X`. When used, this is a vector
-  of `DynamicQuantities.Quantity{<:Any,<:Dimensions}` with shape `(nfeatures,)`.
-- `y_units`: Unit information of `y`. When used, this is a single
-  `DynamicQuantities.Quantity{<:Any,<:Dimensions}`.
-- `X_sym_units`: Unit information of `X`. When used, this is a vector
-  of `DynamicQuantities.Quantity{<:Any,<:SymbolicDimensions}` with shape `(nfeatures,)`.
-- `y_sym_units`: Unit information of `y`. When used, this is a single
-  `DynamicQuantities.Quantity{<:Any,<:SymbolicDimensions}`.
-"""
-mutable struct Dataset{
-    T<:DATA_TYPE,
-    L<:LOSS_TYPE,
-    AX<:AbstractMatrix{T},
-    AY<:Union{AbstractVector{T},Nothing},
-    AW<:Union{AbstractVector{T},Nothing},
-    NT<:NamedTuple,
-    XU<:Union{AbstractVector{<:Quantity},Nothing},
-    YU<:Union{Quantity,Nothing},
-    XUS<:Union{AbstractVector{<:Quantity},Nothing},
-    YUS<:Union{Quantity,Nothing},
-}
-    @constfield X::AX
-    @constfield y::AY
-    @constfield index::Int
-    @constfield n::Int
-    @constfield nfeatures::Int
-    @constfield weighted::Bool
-    @constfield weights::AW
-    @constfield extra::NT
-    @constfield avg_y::Union{T,Nothing}
-    use_baseline::Bool
-    baseline_loss::L
-    @constfield variable_names::Array{String,1}
-    @constfield display_variable_names::Array{String,1}
-    @constfield y_variable_name::String
-    @constfield X_units::XU
-    @constfield y_units::YU
-    @constfield X_sym_units::XUS
-    @constfield y_sym_units::YUS
-end
-
-"""
-    Dataset(X::AbstractMatrix{T},
-        y::Union{AbstractVector{T},Nothing}=nothing,
-        loss_type::Type=Nothing;
-        weights::Union{AbstractVector{T}, Nothing}=nothing,
-        variable_names::Union{Array{String, 1}, Nothing}=nothing,
-        y_variable_name::Union{String,Nothing}=nothing,
-        extra::NamedTuple=NamedTuple(),
-        X_units::Union{AbstractVector, Nothing}=nothing,
-        y_units=nothing,
-    ) where {T<:DATA_TYPE}
-
-Construct a dataset to pass between internal functions.
-""" -function Dataset( - X::AbstractMatrix{T}, - y::Union{AbstractVector{T},Nothing}=nothing, - loss_type::Type{L}=Nothing; - index::Int=1, - weights::Union{AbstractVector{T},Nothing}=nothing, - variable_names::Union{Array{String,1},Nothing}=nothing, - display_variable_names=variable_names, - y_variable_name::Union{String,Nothing}=nothing, - extra::NamedTuple=NamedTuple(), - X_units::Union{AbstractVector,Nothing}=nothing, - y_units=nothing, - # Deprecated: - varMap=nothing, - kws..., -) where {T<:DATA_TYPE,L} - Base.require_one_based_indexing(X) - y !== nothing && Base.require_one_based_indexing(y) - # Deprecation warning: - variable_names = deprecate_varmap(variable_names, varMap, :Dataset) - if haskey(kws, :loss_type) - Base.depwarn( - "The `loss_type` keyword argument is deprecated. Pass as an argument instead.", - :Dataset, - ) - return Dataset( - X, - y, - kws[:loss_type]; - index, - weights, - variable_names, - display_variable_names, - y_variable_name, - extra, - X_units, - y_units, - ) - end - - n = size(X, BATCH_DIM) - nfeatures = size(X, FEATURE_DIM) - weighted = weights !== nothing - variable_names = if variable_names === nothing - ["x$(i)" for i in 1:nfeatures] - else - variable_names - end - display_variable_names = if display_variable_names === nothing - ["x$(subscriptify(i))" for i in 1:nfeatures] - else - display_variable_names - end - - y_variable_name = if y_variable_name === nothing - ("y" ∉ variable_names) ? "y" : "target" - else - y_variable_name - end - avg_y = if y === nothing - nothing - else - if weighted - sum(y .* weights) / sum(weights) - else - sum(y) / n - end - end - out_loss_type = if L === Nothing - T <: Complex ? get_base_type(T) : T - else - L - end - - use_baseline = true - baseline = one(out_loss_type) - y_si_units = get_si_units(T, y_units) - y_sym_units = get_sym_units(T, y_units) - - # TODO: Refactor - # This basically just ensures that if the `y` units are set, - # then the `X` units are set as well. - X_si_units = let (_X = get_si_units(T, X_units)) - if _X === nothing && y_si_units !== nothing - get_si_units(T, [one(T) for _ in 1:nfeatures]) - else - _X - end - end - X_sym_units = let _X = get_sym_units(T, X_units) - if _X === nothing && y_sym_units !== nothing - get_sym_units(T, [one(T) for _ in 1:nfeatures]) - else - _X - end - end - - error_on_mismatched_size(nfeatures, X_si_units) - - return Dataset{ - T, - out_loss_type, - typeof(X), - typeof(y), - typeof(weights), - typeof(extra), - typeof(X_si_units), - typeof(y_si_units), - typeof(X_sym_units), - typeof(y_sym_units), - }( - X, - y, - index, - n, - nfeatures, - weighted, - weights, - extra, - avg_y, - use_baseline, - baseline, - variable_names, - display_variable_names, - y_variable_name, - X_si_units, - y_si_units, - X_sym_units, - y_sym_units, - ) -end -function Dataset( - X::AbstractMatrix, - y::Union{<:AbstractVector,Nothing}=nothing; - weights::Union{<:AbstractVector,Nothing}=nothing, - kws..., -) - T = promote_type( - eltype(X), - (y === nothing) ? eltype(X) : eltype(y), - (weights === nothing) ? eltype(X) : eltype(weights), - ) - X = Base.Fix1(convert, T).(X) - if y !== nothing - y = Base.Fix1(convert, T).(y) - end - if weights !== nothing - weights = Base.Fix1(convert, T).(weights) - end - return Dataset(X, y; weights=weights, kws...) 
-end
-
-function error_on_mismatched_size(_, ::Nothing)
-    return nothing
-end
-function error_on_mismatched_size(nfeatures, X_units::AbstractVector)
-    if nfeatures != length(X_units)
-        error(
-            "Number of features ($(nfeatures)) does not match number of units ($(length(X_units)))",
-        )
-    end
-    return nothing
-end
-
-function has_units(dataset::Dataset)
-    return dataset.X_units !== nothing || dataset.y_units !== nothing
-end
-
-# Used for Enzyme
-function Base.fill!(d::Dataset, val)
-    _fill!(d.X, val)
-    _fill!(d.y, val)
-    _fill!(d.weights, val)
-    _fill!(d.extra, val)
-    return d
-end
-_fill!(x::AbstractArray, val) = fill!(x, val)
-_fill!(x::NamedTuple, val) = foreach(v -> _fill!(v, val), values(x))
-_fill!(::Nothing, val) = nothing
-_fill!(x, val) = x
-
-end
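A minimal construction of the `Dataset` type deleted above, following its constructor signature (the data, shapes, and names are illustrative):

```julia
X = randn(Float64, 2, 32)                    # (nfeatures, n) layout
y = 2 .* cos.(X[1, :]) .+ X[2, :] .^ 2
dataset = Dataset(X, y; variable_names=["angle", "bias"])
(dataset.nfeatures, dataset.n)               # (2, 32)
```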
-for op in (:(Base.:*), :(Base.:/)) - @eval function $(op)(l::W, r::W) where {W<:WildcardQuantity} - l.violates && return l - r.violates && return r - return W($(op)(l.val, r.val), l.wildcard || r.wildcard, false) - end -end -for op in (:(Base.:+), :(Base.:-)) - @eval function $(op)(l::W, r::W) where {Q,W<:WildcardQuantity{Q}} - l.violates && return l - r.violates && return r - if same_dimensions(l, r) - return W($(op)(l.val, r.val), l.wildcard && r.wildcard, false) - elseif l.wildcard && r.wildcard - return W( - constructorof(Q)($(op)(ustrip(l), ustrip(r)), typeof(dimension(l))), - true, - false, - ) - elseif l.wildcard - return W($(op)(constructorof(Q)(ustrip(l), dimension(r)), r.val), false, false) - elseif r.wildcard - return W($(op)(l.val, constructorof(Q)(ustrip(r), dimension(l))), false, false) - else - return W(one(Q), false, true) - end - end -end -function Base.:^(l::W, r::W) where {Q,W<:WildcardQuantity{Q}} - l.violates && return l - r.violates && return r - if (has_no_dims(l.val) || l.wildcard) && (has_no_dims(r.val) || r.wildcard) - # Require both base and power to be dimensionless: - x = ustrip(l) - y = ustrip(r) - return W(safe_pow(x, y) * one(Q), false, false) - else - return W(one(Q), false, true) - end -end - -function Base.sqrt(l::W) where {W<:WildcardQuantity} - return l.violates ? l : W(safe_sqrt(l.val), l.wildcard, false) -end -function Base.cbrt(l::W) where {W<:WildcardQuantity} - return l.violates ? l : W(cbrt(l.val), l.wildcard, false) -end -function Base.abs(l::W) where {W<:WildcardQuantity} - return l.violates ? l : W(abs(l.val), l.wildcard, false) -end -function Base.inv(l::W) where {W<:WildcardQuantity} - return l.violates ? l : W(inv(l.val), l.wildcard, false) -end - -# Define dimensionally-aware evaluation routine: -@inline function deg0_eval( - x::AbstractVector{T}, - x_units::Vector{Q}, - t::AbstractExpressionNode{T}, - allow_wildcards::Bool, -) where {T,R,Q<:AbstractQuantity{T,R}} - if t.constant - return WildcardQuantity{Q}(Quantity(t.val, R), allow_wildcards, false) - else - return WildcardQuantity{Q}( - (@inbounds x[t.feature]) * (@inbounds x_units[t.feature]), false, false - ) - end -end -@inline function deg1_eval( - op::F, l::W -) where {F,T,Q<:AbstractQuantity{T},W<:WildcardQuantity{Q}} - l.violates && return l - !isfinite(l) && return W(one(Q), false, true) - - hasmethod(op, Tuple{W}) && @maybe_return_call(W, op, (l,)) - l.wildcard && return W(Quantity(op(ustrip(l))::T), false, false) - return W(one(Q), false, true) -end -@inline function deg2_eval( - op::F, l::W, r::W -) where {F,T,Q<:AbstractQuantity{T},W<:WildcardQuantity{Q}} - l.violates && return l - r.violates && return r - (!isfinite(l) || !isfinite(r)) && return W(one(Q), false, true) - hasmethod(op, Tuple{W,W}) && @maybe_return_call(W, op, (l, r)) - hasmethod(op, Tuple{T,W}) && l.wildcard && @maybe_return_call(W, op, (ustrip(l), r)) - hasmethod(op, Tuple{W,T}) && r.wildcard && @maybe_return_call(W, op, (l, ustrip(r))) - l.wildcard && - r.wildcard && - return W(Quantity(op(ustrip(l), ustrip(r))::T), false, false) - return W(one(Q), false, true) -end - -function violates_dimensional_constraints_dispatch( - tree::AbstractExpressionNode{T}, - x_units::Vector{Q}, - x::AbstractVector{T}, - operators, - allow_wildcards, -) where {T,Q<:AbstractQuantity{T}} - if tree.degree == 0 - return deg0_eval(x, x_units, tree, allow_wildcards)::WildcardQuantity{Q} - elseif tree.degree == 1 - l = violates_dimensional_constraints_dispatch( - tree.l, x_units, x, operators, allow_wildcards - ) - return 
deg1_eval((@inbounds operators.unaops[tree.op]), l)::WildcardQuantity{Q} - else - l = violates_dimensional_constraints_dispatch( - tree.l, x_units, x, operators, allow_wildcards - ) - r = violates_dimensional_constraints_dispatch( - tree.r, x_units, x, operators, allow_wildcards - ) - return deg2_eval((@inbounds operators.binops[tree.op]), l, r)::WildcardQuantity{Q} - end -end - -""" - violates_dimensional_constraints(tree::AbstractExpressionNode, dataset::Dataset, options::Options) - -Checks whether an expression violates dimensional constraints. -""" -function violates_dimensional_constraints( - tree::AbstractExpressionNode, dataset::Dataset, options::Options -) - X = dataset.X - return violates_dimensional_constraints( - tree, dataset.X_units, dataset.y_units, (@view X[:, 1]), options - ) -end -function violates_dimensional_constraints( - tree::AbstractExpression, dataset::Dataset, options::Options -) - return violates_dimensional_constraints(get_tree(tree), dataset, options) -end -function violates_dimensional_constraints( - tree::AbstractExpressionNode{T}, - X_units::AbstractVector{<:Quantity}, - y_units::Union{Quantity,Nothing}, - x::AbstractVector{T}, - options::Options, -) where {T} - allow_wildcards = !(options.dimensionless_constants_only) - dimensional_output = violates_dimensional_constraints_dispatch( - tree, X_units, x, options.operators, allow_wildcards - ) - # ^ Eventually do this with map_treereduce. However, right now it seems - # like we are passing around too many arguments, which slows things down. - violates = dimensional_output.violates - if y_units !== nothing - violates |= ( - !dimensional_output.wildcard && - dimension(dimensional_output) != dimension(y_units) - ) - end - return violates -end -function violates_dimensional_constraints( - ::AbstractExpressionNode{T}, ::Nothing, ::Quantity, ::AbstractVector{T}, ::Options -) where {T} - return error("This should never happen. 
Please submit a bug report.") -end -function violates_dimensional_constraints( - ::AbstractExpressionNode{T}, ::Nothing, ::Nothing, ::AbstractVector{T}, ::Options -) where {T} - return false -end - -end diff --git a/src/ExpressionBuilder.jl b/src/ExpressionBuilder.jl deleted file mode 100644 index bcad22575..000000000 --- a/src/ExpressionBuilder.jl +++ /dev/null @@ -1,290 +0,0 @@ -module ExpressionBuilderModule - -using DispatchDoctor: @unstable -using DynamicExpressions: - AbstractExpressionNode, - AbstractExpression, - Expression, - ParametricExpression, - ParametricNode, - constructorof, - get_tree, - get_contents, - get_metadata, - with_contents, - with_metadata, - count_scalar_constants, - eval_tree_array -using Random: default_rng, AbstractRNG -using StatsBase: StatsBase -using ..CoreModule: Options, Dataset, DATA_TYPE -using ..HallOfFameModule: HallOfFame -using ..LossFunctionsModule: maybe_getindex -using ..InterfaceDynamicExpressionsModule: expected_array_type -using ..PopulationModule: Population -using ..PopMemberModule: PopMember - -import DynamicExpressions: get_operators -import ..CoreModule: create_expression -import ..MutationFunctionsModule: - make_random_leaf, crossover_trees, mutate_constant, mutate_factor -import ..LossFunctionsModule: eval_tree_dispatch -import ..ConstantOptimizationModule: count_constants_for_optimization - -@unstable function create_expression( - t::T, options::Options, dataset::Dataset{T,L}, ::Val{embed}=Val(false) -) where {T,L,embed} - return create_expression( - constructorof(options.node_type)(; val=t), options, dataset, Val(embed) - ) -end -@unstable function create_expression( - t::AbstractExpressionNode{T}, - options::Options, - dataset::Dataset{T,L}, - ::Val{embed}=Val(false), -) where {T,L,embed} - return constructorof(options.expression_type)( - t; init_params(options, dataset, nothing, Val(embed))... - ) -end -function create_expression( - ex::AbstractExpression{T}, ::Options, ::Dataset{T,L}, ::Val{embed}=Val(false) -) where {T,L,embed} - return ex -end -@unstable function init_params( - options::Options, - dataset::Dataset{T,L}, - prototype::Union{Nothing,AbstractExpression}, - ::Val{embed}, -) where {T,L,embed} - consistency_checks(options, prototype) - return (; - operators=embed ? options.operators : nothing, - variable_names=embed ? dataset.variable_names : nothing, - extra_init_params( - options.expression_type, prototype, options, dataset, Val(embed) - )..., - ) -end -function extra_init_params( - ::Type{E}, - prototype::Union{Nothing,AbstractExpression}, - options::Options, - dataset::Dataset{T}, - ::Val{embed}, -) where {T,embed,E<:AbstractExpression} - return (;) -end -function extra_init_params( - ::Type{E}, - prototype::Union{Nothing,ParametricExpression}, - options::Options, - dataset::Dataset{T}, - ::Val{embed}, -) where {T,embed,E<:ParametricExpression} - num_params = options.expression_options.max_parameters - num_classes = length(unique(dataset.extra.classes)) - parameter_names = embed ? 
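The wildcard mechanism above is what lets a free constant absorb unknown units: `+` and `-` demand matching dimensions unless a wildcard is involved, while `^` demands a dimensionless base and power. A small sketch of the underlying dimension check using DynamicQuantities directly (values illustrative):

```julia
using DynamicQuantities: Quantity, dimension

x1 = Quantity(1.5; length=1)   # 1.5 m
x2 = Quantity(2.0; time=1)     # 2.0 s

dimension(x1) == dimension(x2)  # false: `x1 + x2` would set `violates = true`
# `c * x2` stays valid for any constant `c`, since constants enter as
# `WildcardQuantity`s with `wildcard = true`, and `*` propagates the
# wildcard instead of failing.
```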
["p$i" for i in 1:num_params] : nothing - _parameters = if prototype === nothing - randn(T, (num_params, num_classes)) - else - copy(get_metadata(prototype).parameters) - end - return (; parameters=_parameters, parameter_names) -end - -consistency_checks(::Options, prototype::Nothing) = nothing -function consistency_checks(options::Options, prototype) - if prototype === nothing - return nothing - end - @assert( - prototype isa options.expression_type, - "Need prototype to be of type $(options.expression_type), but got $(prototype)::$(typeof(prototype))" - ) - if prototype isa ParametricExpression - if prototype.metadata.parameter_names !== nothing - @assert( - length(prototype.metadata.parameter_names) == - options.expression_options.max_parameters, - "Mismatch between options.expression_options.max_parameters=$(options.expression_options.max_parameters) and prototype.metadata.parameter_names=$(prototype.metadata.parameter_names)" - ) - end - @assert size(prototype.metadata.parameters, 1) == - options.expression_options.max_parameters - end - return nothing -end - -@unstable begin - function embed_metadata( - ex::AbstractExpression, options::Options, dataset::Dataset{T,L} - ) where {T,L} - return with_metadata(ex; init_params(options, dataset, ex, Val(true))...) - end - function embed_metadata( - member::PopMember, options::Options, dataset::Dataset{T,L} - ) where {T,L} - return PopMember( - embed_metadata(member.tree, options, dataset), - member.score, - member.loss, - nothing; - member.ref, - member.parent, - deterministic=options.deterministic, - ) - end - function embed_metadata( - pop::Population, options::Options, dataset::Dataset{T,L} - ) where {T,L} - return Population( - map(member -> embed_metadata(member, options, dataset), pop.members) - ) - end - function embed_metadata( - hof::HallOfFame, options::Options, dataset::Dataset{T,L} - ) where {T,L} - return HallOfFame( - map(member -> embed_metadata(member, options, dataset), hof.members), hof.exists - ) - end - function embed_metadata( - vec::Vector{H}, options::Options, dataset::Dataset{T,L} - ) where {T,L,H<:Union{HallOfFame,Population,PopMember}} - return map(elem -> embed_metadata(elem, options, dataset), vec) - end -end - -"""Strips all metadata except for top-level information""" -function strip_metadata(ex::Expression, options::Options, dataset::Dataset{T,L}) where {T,L} - return with_metadata(ex; init_params(options, dataset, ex, Val(false))...) -end -function strip_metadata( - ex::ParametricExpression, options::Options, dataset::Dataset{T,L} -) where {T,L} - return with_metadata(ex; init_params(options, dataset, ex, Val(false))...) 
-end -function strip_metadata( - member::PopMember, options::Options, dataset::Dataset{T,L} -) where {T,L} - return PopMember( - strip_metadata(member.tree, options, dataset), - member.score, - member.loss, - nothing; - member.ref, - member.parent, - deterministic=options.deterministic, - ) -end -function strip_metadata( - pop::Population, options::Options, dataset::Dataset{T,L} -) where {T,L} - return Population(map(member -> strip_metadata(member, options, dataset), pop.members)) -end -function strip_metadata( - hof::HallOfFame, options::Options, dataset::Dataset{T,L} -) where {T,L} - return HallOfFame( - map(member -> strip_metadata(member, options, dataset), hof.members), hof.exists - ) -end - -function eval_tree_dispatch( - tree::ParametricExpression{T}, dataset::Dataset{T}, options::Options, idx -) where {T<:DATA_TYPE} - A = expected_array_type(dataset.X) - return eval_tree_array( - tree, - maybe_getindex(dataset.X, :, idx), - maybe_getindex(dataset.extra.classes, idx), - options.operators, - )::Tuple{A,Bool} -end - -function make_random_leaf( - nfeatures::Int, - ::Type{T}, - ::Type{N}, - rng::AbstractRNG=default_rng(), - options::Union{Options,Nothing}=nothing, -) where {T<:DATA_TYPE,N<:ParametricNode} - choice = rand(rng, 1:3) - if choice == 1 - return ParametricNode(; val=randn(rng, T)) - elseif choice == 2 - return ParametricNode(T; feature=rand(rng, 1:nfeatures)) - else - tree = ParametricNode{T}() - tree.val = zero(T) - tree.degree = 0 - tree.feature = 0 - tree.constant = false - tree.is_parameter = true - tree.parameter = rand( - rng, UInt16(1):UInt16(options.expression_options.max_parameters) - ) - return tree - end -end - -function crossover_trees( - ex1::ParametricExpression{T}, ex2::AbstractExpression{T}, rng::AbstractRNG=default_rng() -) where {T} - tree1 = get_contents(ex1) - tree2 = get_contents(ex2) - out1, out2 = crossover_trees(tree1, tree2, rng) - ex1 = with_contents(ex1, out1) - ex2 = with_contents(ex2, out2) - - # We also randomly share parameters - nparams1 = size(ex1.metadata.parameters, 1) - nparams2 = size(ex2.metadata.parameters, 1) - num_params_switch = min(nparams1, nparams2) - idx_to_switch = StatsBase.sample( - rng, 1:num_params_switch, num_params_switch; replace=false - ) - for param_idx in idx_to_switch - ex2_params = ex2.metadata.parameters[param_idx, :] - ex2.metadata.parameters[param_idx, :] .= ex1.metadata.parameters[param_idx, :] - ex1.metadata.parameters[param_idx, :] .= ex2_params - end - - return ex1, ex2 -end - -function count_constants_for_optimization(ex::ParametricExpression) - return count_scalar_constants(get_tree(ex)) + length(ex.metadata.parameters) -end - -function mutate_constant( - ex::ParametricExpression{T}, - temperature, - options::Options, - rng::AbstractRNG=default_rng(), -) where {T<:DATA_TYPE} - if rand(rng, Bool) - # Normal mutation of inner constant - tree = get_contents(ex) - return with_contents(ex, mutate_constant(tree, temperature, options, rng)) - else - # Mutate parameters - parameter_index = rand(rng, 1:(options.expression_options.max_parameters)) - # We mutate all the parameters at once - factor = mutate_factor(T, temperature, options, rng) - ex.metadata.parameters[parameter_index, :] .*= factor - return ex - end -end - -@unstable function get_operators(ex::AbstractExpression, options::Options) - return get_operators(ex, options.operators) -end -@unstable function get_operators(ex::AbstractExpressionNode, options::Options) - return get_operators(ex, options.operators) -end - -end diff --git a/src/HallOfFame.jl 
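The parameter-sharing step in the removed `crossover_trees` above swaps matching rows of the two `(max_parameters × num_classes)` parameter matrices. A standalone sketch of just that row swap, on toy matrices rather than the package API:

```julia
using StatsBase: sample
using Random: Xoshiro

rng = Xoshiro(0)
p1, p2 = randn(rng, 2, 3), randn(rng, 2, 3)  # 2 parameters × 3 classes each

# Choose rows without replacement, then exchange them between the parents:
for i in sample(rng, 1:2, 2; replace=false)
    p1[i, :], p2[i, :] = p2[i, :], p1[i, :]
end
```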
diff --git a/src/HallOfFame.jl b/src/HallOfFame.jl
deleted file mode 100644
index 19c52f933..000000000
--- a/src/HallOfFame.jl
+++ /dev/null
@@ -1,220 +0,0 @@
-module HallOfFameModule
-
-using DynamicExpressions: AbstractExpression, string_tree
-using ..UtilsModule: split_string
-using ..CoreModule:
-    MAX_DEGREE, Options, Dataset, DATA_TYPE, LOSS_TYPE, relu, create_expression
-using ..ComplexityModule: compute_complexity
-using ..PopMemberModule: PopMember
-using ..InterfaceDynamicExpressionsModule: format_dimensions, WILDCARD_UNIT_STRING
-using Printf: @sprintf
-
-"""
-    HallOfFame{T<:DATA_TYPE,L<:LOSS_TYPE}
-
-List of the best members seen all time in `.members`, with `.members[c]` being
-the best member seen at complexity c. Including only the members which actually
-have been set, you can run `.members[exists]`.
-
-# Fields
-
-- `members::Array{PopMember{T,L},1}`: List of the best members seen all time.
-    These are ordered by complexity, with `.members[1]` the member with complexity 1.
-- `exists::Array{Bool,1}`: Whether the member at the given complexity has been set.
-"""
-struct HallOfFame{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}}
-    members::Array{PopMember{T,L,N},1}
-    exists::Array{Bool,1} #Whether it has been set
-end
-function Base.show(io::IO, mime::MIME"text/plain", hof::HallOfFame{T,L,N}) where {T,L,N}
-    println(io, "HallOfFame{...}:")
-    for i in eachindex(hof.members, hof.exists)
-        s_member, s_exists = if hof.exists[i]
-            sprint((io, m) -> show(io, mime, m), hof.members[i]), "true"
-        else
-            "undef", "false"
-        end
-        println(io, " "^4 * ".exists[$i] = $s_exists")
-        print(io, " "^4 * ".members[$i] =")
-        splitted = split(strip(s_member), '\n')
-        if length(splitted) == 1
-            println(io, " " * s_member)
-        else
-            println(io)
-            foreach(line -> println(io, " "^8 * line), splitted)
-        end
-    end
-    return nothing
-end
-
-"""
-    HallOfFame(options::Options, dataset::Dataset{T,L}) where {T<:DATA_TYPE,L<:LOSS_TYPE}
-
-Create empty HallOfFame. The HallOfFame stores a list
-of `PopMember` objects in `.members`, which is enumerated
-by size (i.e., `.members[1]` is the constant solution).
-`.exists` is used to determine whether the particular member
-has been instantiated or not.
-
-Arguments:
-- `options`: Options containing specification about deterministic.
-- `dataset`: Dataset containing the input data.
-"""
-function HallOfFame(
-    options::Options, dataset::Dataset{T,L}
-) where {T<:DATA_TYPE,L<:LOSS_TYPE}
-    actualMaxsize = options.maxsize + MAX_DEGREE
-    base_tree = create_expression(zero(T), options, dataset)
-
-    return HallOfFame{T,L,typeof(base_tree)}(
-        [
-            PopMember(
-                copy(base_tree),
-                L(0),
-                L(Inf),
-                options;
-                parent=-1,
-                deterministic=options.deterministic,
-            ) for i in 1:actualMaxsize
-        ],
-        [false for i in 1:actualMaxsize],
-    )
-end
-
-function Base.copy(hof::HallOfFame)
-    return HallOfFame(
-        [copy(member) for member in hof.members], [exists for exists in hof.exists]
-    )
-end
-
-"""
-    calculate_pareto_frontier(hallOfFame::HallOfFame{T,L,P}) where {T<:DATA_TYPE,L<:LOSS_TYPE}
-"""
-function calculate_pareto_frontier(hallOfFame::HallOfFame{T,L,N}) where {T,L,N}
-    # TODO - remove dataset from args.
-    P = PopMember{T,L,N}
-    # Dominating pareto curve - must be better than all simpler equations
-    dominating = P[]
-    actualMaxsize = length(hallOfFame.members)
-    for size in 1:actualMaxsize
-        if !hallOfFame.exists[size]
-            continue
-        end
-        member = hallOfFame.members[size]
-        # We check if this member is better than all members which are smaller than it and
-        # also exist.
-        betterThanAllSmaller = true
-        for i in 1:(size - 1)
-            if !hallOfFame.exists[i]
-                continue
-            end
-            simpler_member = hallOfFame.members[i]
-            if member.loss >= simpler_member.loss
-                betterThanAllSmaller = false
-                break
-            end
-        end
-        if betterThanAllSmaller
-            push!(dominating, copy(member))
-        end
-    end
-    return dominating
-end
-
-function string_dominating_pareto_curve(
-    hallOfFame, dataset, options; width::Union{Integer,Nothing}=nothing
-)
-    twidth = (width === nothing) ? 100 : max(100, width::Integer)
-    output = ""
-    output *= "Hall of Fame:\n"
-    # TODO: Get user's terminal width.
-    output *= "-"^(twidth - 1) * "\n"
-    output *= @sprintf(
-        "%-10s %-8s %-8s %-8s\n", "Complexity", "Loss", "Score", "Equation"
-    )
-
-    formatted = format_hall_of_fame(hallOfFame, options)
-    for (tree, score, loss, complexity) in
-        zip(formatted.trees, formatted.scores, formatted.losses, formatted.complexities)
-        eqn_string = string_tree(
-            tree,
-            options;
-            display_variable_names=dataset.display_variable_names,
-            X_sym_units=dataset.X_sym_units,
-            y_sym_units=dataset.y_sym_units,
-            raw=false,
-        )
-        y_prefix = dataset.y_variable_name
-        unit_str = format_dimensions(dataset.y_sym_units)
-        y_prefix *= unit_str
-        if dataset.y_sym_units === nothing && dataset.X_sym_units !== nothing
-            y_prefix *= WILDCARD_UNIT_STRING
-        end
-        eqn_string = y_prefix * " = " * eqn_string
-        base_string_length = length(@sprintf("%-10d %-8.3e %8.3e ", 1, 1.0, 1.0))
-
-        dots = "..."
-        equation_width = (twidth - 1) - base_string_length - length(dots)
-
-        output *= @sprintf("%-10d %-8.3e %-8.3e ", complexity, loss, score)
-
-        split_eqn = split_string(eqn_string, equation_width)
-        print_pad = false
-        while length(split_eqn) > 1
-            cur_piece = popfirst!(split_eqn)
-            output *= " "^(print_pad * base_string_length) * cur_piece * dots * "\n"
-            print_pad = true
-        end
-        output *= " "^(print_pad * base_string_length) * split_eqn[1] * "\n"
-    end
-    output *= "-"^(twidth - 1)
-    return output
-end
-
-function format_hall_of_fame(hof::HallOfFame{T,L}, options) where {T,L}
-    dominating = calculate_pareto_frontier(hof)
-    foreach(dominating) do member
-        if member.loss < 0.0
-            throw(
-                DomainError(
-                    member.loss,
-                    "Your loss function must be non-negative. To do this, consider wrapping your loss inside an exponential, which will not affect the search (unless you are using annealing).",
-                ),
-            )
-        end
-    end
-
-    ZERO_POINT = eps(L)
-    cur_loss = typemax(L)
-    last_loss = cur_loss
-    last_complexity = 0
-
-    trees = [member.tree for member in dominating]
-    losses = [member.loss for member in dominating]
-    complexities = [compute_complexity(member, options) for member in dominating]
-    scores = Array{L}(undef, length(dominating))
-
-    for i in 1:length(dominating)
-        complexity = complexities[i]
-        cur_loss = losses[i]
-        delta_c = complexity - last_complexity
-        delta_l_mse = log(relu(cur_loss / last_loss) + ZERO_POINT)
-
-        scores[i] = relu(-delta_l_mse / delta_c)
-        last_loss = cur_loss
-        last_complexity = complexity
-    end
-    return (; trees, scores, losses, complexities)
-end
-function format_hall_of_fame(hof::AbstractVector{<:HallOfFame}, options)
-    outs = [format_hall_of_fame(h, options) for h in hof]
-    return (;
-        trees=[out.trees for out in outs],
-        scores=[out.scores for out in outs],
-        losses=[out.losses for out in outs],
-        complexities=[out.complexities for out in outs],
-    )
-end
-# TODO: Re-use this in `string_dominating_pareto_curve`
-
-end
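The `score` column produced by the removed `format_hall_of_fame` is the log-loss drop per unit of added complexity, clipped at zero. Extracted as a standalone formula with the same arithmetic (names renamed for clarity):

```julia
relu(x) = max(x, zero(x))

# score_i = relu(-log(relu(loss_i / loss_prev) + eps) / (complexity_i - complexity_prev))
function pareto_score(loss, last_loss, complexity, last_complexity; zero_point=eps(Float64))
    delta_l_mse = log(relu(loss / last_loss) + zero_point)
    return relu(-delta_l_mse / (complexity - last_complexity))
end

pareto_score(0.01, 0.1, 5, 3)  # ≈ 1.15: loss fell 10× at the cost of 2 complexity
```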
-""" -function DE.eval_tree_array( - tree::Union{AbstractExpressionNode,AbstractExpression}, - X::AbstractMatrix, - options::Options; - kws..., -) - A = expected_array_type(X) - return DE.eval_tree_array( - tree, - X, - DE.get_operators(tree, options); - turbo=options.turbo, - bumper=options.bumper, - kws..., - )::Tuple{A,Bool} -end -function DE.eval_tree_array( - tree::ParametricExpression, - X::AbstractMatrix, - classes::AbstractVector{<:Integer}, - options::Options; - kws..., -) - A = expected_array_type(X) - return DE.eval_tree_array( - tree, - X, - classes, - DE.get_operators(tree, options); - turbo=options.turbo, - bumper=options.bumper, - kws..., - )::Tuple{A,Bool} -end - -# Improve type inference by telling Julia the expected array returned -function expected_array_type(X::AbstractArray) - return typeof(similar(X, axes(X, 2))) -end - -""" - eval_diff_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::Options, direction::Int) - -Compute the forward derivative of an expression, using a similar -structure and optimization to eval_tree_array. `direction` is the index of a particular -variable in the expression. e.g., `direction=1` would indicate derivative with -respect to `x1`. - -# Arguments - -- `tree::Union{AbstractExpression,AbstractExpressionNode}`: The expression tree to evaluate. -- `X::AbstractArray`: The data matrix, with each column being a data point. -- `options::Options`: The options containing the operators used to create the `tree`. -- `direction::Int`: The index of the variable to take the derivative with respect to. - -# Returns - -- `(evaluation, derivative, complete)::Tuple{AbstractVector, AbstractVector, Bool}`: the normal evaluation, - the derivative, and whether the evaluation completed as normal (or encountered a nan or inf). -""" -function DE.eval_diff_tree_array( - tree::Union{AbstractExpression,AbstractExpressionNode}, - X::AbstractArray, - options::Options, - direction::Int, -) - A = expected_array_type(X) - # TODO: Add `AbstractExpression` implementation in `Expression.jl` - return DE.eval_diff_tree_array( - DE.get_tree(tree), X, DE.get_operators(tree, options), direction - )::Tuple{A,A,Bool} -end - -""" - eval_grad_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::Options; variable::Bool=false) - -Compute the forward-mode derivative of an expression, using a similar -structure and optimization to eval_tree_array. `variable` specifies whether -we should take derivatives with respect to features (i.e., `X`), or with respect -to every constant in the expression. - -# Arguments - -- `tree::Union{AbstractExpression,AbstractExpressionNode}`: The expression tree to evaluate. -- `X::AbstractArray`: The data matrix, with each column being a data point. -- `options::Options`: The options containing the operators used to create the `tree`. -- `variable::Bool`: Whether to take derivatives with respect to features (i.e., `X` - with `variable=true`), - or with respect to every constant in the expression (`variable=false`). - -# Returns - -- `(evaluation, gradient, complete)::Tuple{AbstractVector, AbstractArray, Bool}`: the normal evaluation, - the gradient, and whether the evaluation completed as normal (or encountered a nan or inf). -""" -function DE.eval_grad_tree_array( - tree::Union{AbstractExpression,AbstractExpressionNode}, - X::AbstractArray, - options::Options; - kws..., -) - A = expected_array_type(X) - M = typeof(X) # TODO: This won't work with StaticArrays! 
- return DE.eval_grad_tree_array( - tree, X, DE.get_operators(tree, options); kws... - )::Tuple{A,M,Bool} -end - -""" - differentiable_eval_tree_array(tree::AbstractExpressionNode, X::AbstractArray, options::Options) - -Evaluate an expression tree in a way that can be auto-differentiated. -""" -function DE.differentiable_eval_tree_array( - tree::Union{AbstractExpression,AbstractExpressionNode}, - X::AbstractArray, - options::Options, -) - A = expected_array_type(X) - # TODO: Add `AbstractExpression` implementation in `Expression.jl` - return DE.differentiable_eval_tree_array( - DE.get_tree(tree), X, DE.get_operators(tree, options) - )::Tuple{A,Bool} -end - -const WILDCARD_UNIT_STRING = "[?]" - -""" - string_tree(tree::AbstractExpressionNode, options::Options; kws...) - -Convert an equation to a string. - -# Arguments - -- `tree::AbstractExpressionNode`: The equation to convert to a string. -- `options::Options`: The options holding the definition of operators. -- `variable_names::Union{Array{String, 1}, Nothing}=nothing`: what variables - to print for each feature. -""" -@inline function DE.string_tree( - tree::Union{AbstractExpression,AbstractExpressionNode}, - options::Options; - raw::Bool=true, - X_sym_units=nothing, - y_sym_units=nothing, - variable_names=nothing, - display_variable_names=variable_names, - varMap=nothing, - kws..., -) - variable_names = deprecate_varmap(variable_names, varMap, :string_tree) - - if raw - tree = tree isa GraphNode ? convert(Node, tree) : tree - return DE.string_tree( - tree, - DE.get_operators(tree, options); - f_variable=string_variable_raw, - variable_names, - ) - end - - vprecision = vals[options.print_precision] - if X_sym_units !== nothing || y_sym_units !== nothing - return DE.string_tree( - tree, - DE.get_operators(tree, options); - f_variable=(feature, vname) -> string_variable(feature, vname, X_sym_units), - f_constant=let - unit_placeholder = - options.dimensionless_constants_only ? 
"" : WILDCARD_UNIT_STRING - (val,) -> string_constant(val, vprecision, unit_placeholder) - end, - variable_names=display_variable_names, - kws..., - ) - else - return DE.string_tree( - tree, - DE.get_operators(tree, options); - f_variable=string_variable, - f_constant=(val,) -> string_constant(val, vprecision, ""), - variable_names=display_variable_names, - kws..., - ) - end -end -const vals = ntuple(Val, 8192) -function string_variable_raw(feature, variable_names) - if variable_names === nothing || feature > length(variable_names) - return "x" * string(feature) - else - return variable_names[feature] - end -end -function string_variable(feature, variable_names, variable_units=nothing) - base = if variable_names === nothing || feature > length(variable_names) - "x" * subscriptify(feature) - else - variable_names[feature] - end - if variable_units !== nothing - base *= format_dimensions(variable_units[feature]) - end - return base -end -function string_constant(val, ::Val{precision}, unit_placeholder) where {precision} - if typeof(val) <: Real - return sprint_precision(val, Val(precision)) * unit_placeholder - else - return "(" * string(val) * ")" * unit_placeholder - end -end -function format_dimensions(::Nothing) - return "" -end -function format_dimensions(u) - if isone(ustrip(u)) - dim = dimension(u) - if iszero(dim) - return "" - else - return "[" * string(dim) * "]" - end - else - return "[" * string(u) * "]" - end -end -@generated function sprint_precision(x, ::Val{precision}) where {precision} - fmt_string = "%.$(precision)g" - return :(@sprintf($fmt_string, x)) -end - -""" - print_tree(tree::AbstractExpressionNode, options::Options; kws...) - -Print an equation - -# Arguments - -- `tree::AbstractExpressionNode`: The equation to convert to a string. -- `options::Options`: The options holding the definition of operators. -- `variable_names::Union{Array{String, 1}, Nothing}=nothing`: what variables - to print for each feature. -""" -function DE.print_tree( - tree::Union{AbstractExpression,AbstractExpressionNode}, options::Options; kws... -) - return DE.print_tree(tree, DE.get_operators(tree, options); kws...) -end -function DE.print_tree( - io::IO, tree::Union{AbstractExpression,AbstractExpressionNode}, options::Options; kws... -) - return DE.print_tree(io, tree, DE.get_operators(tree, options); kws...) -end - -""" - @extend_operators options - -Extends all operators defined in this options object to work on the -`AbstractExpressionNode` type. While by default this is already done for operators defined -in `Base` when you create an options and pass `define_helper_functions=true`, -this does not apply to the user-defined operators. Thus, to do so, you must -apply this macro to the operator enum in the same module you have the operators -defined. -""" -macro extend_operators(options) - operators = :($(options).operators) - type_requirements = Options - alias_operators = gensym("alias_operators") - return quote - if !isa($(options), $type_requirements) - error("You must pass an options type to `@extend_operators`.") - end - $alias_operators = $define_alias_operators($operators) - $(DE).@extend_operators $alias_operators - end |> esc -end -function define_alias_operators(operators) - # We undo some of the aliases so that the user doesn't need to use, e.g., - # `safe_pow(x1, 1.5)`. They can use `x1 ^ 1.5` instead. - constructor = isa(operators, OperatorEnum) ? 
OperatorEnum : GenericOperatorEnum - return constructor(; - binary_operators=inverse_binopmap.(operators.binops), - unary_operators=inverse_unaopmap.(operators.unaops), - define_helper_functions=false, - empty_old_operators=false, - ) -end - -function (tree::Union{AbstractExpression,AbstractExpressionNode})( - X, options::Options; kws... -) - return tree( - X, - DE.get_operators(tree, options); - turbo=options.turbo, - bumper=options.bumper, - kws..., - ) -end -function DE.EvaluationHelpersModule._grad_evaluator( - tree::Union{AbstractExpression,AbstractExpressionNode}, X, options::Options; kws... -) - return DE.EvaluationHelpersModule._grad_evaluator( - tree, X, DE.get_operators(tree, options); turbo=options.turbo, kws... - ) -end - -end diff --git a/src/InterfaceDynamicQuantities.jl b/src/InterfaceDynamicQuantities.jl deleted file mode 100644 index 34580a3cc..000000000 --- a/src/InterfaceDynamicQuantities.jl +++ /dev/null @@ -1,133 +0,0 @@ -module InterfaceDynamicQuantitiesModule - -using DispatchDoctor: @unstable -using DynamicQuantities: - UnionAbstractQuantity, - AbstractDimensions, - Dimensions, - SymbolicDimensions, - Quantity, - dimension, - uparse, - sym_uparse, - dim_type, - DEFAULT_DIM_BASE_TYPE - -""" - get_units(T, D, x, f) - -Gets unit information from a vector or scalar. The first two -types are the default numeric type and dimensions type, respectively. -The third argument is the value to get units from, and the fourth -argument is a function for parsing strings (in case a string is passed) -""" -function get_units(args...) - return error( - "Unit information must be passed as one of `AbstractDimensions`, `AbstractQuantity`, `AbstractString`, `Function`.", - ) -end -function get_units(_, _, ::Nothing, ::Function) - return nothing -end -function get_units(::Type{T}, ::Type{D}, x::AbstractString, f::Function) where {T,D} - isempty(x) && return one(Quantity{T,D}) - return convert(Quantity{T,D}, f(x)) -end -function get_units(::Type{T}, ::Type{D}, x::Quantity, ::Function) where {T,D} - return convert(Quantity{T,D}, x) -end -function get_units(::Type{T}, ::Type{D}, x::AbstractDimensions, ::Function) where {T,D} - return convert(Quantity{T,D}, Quantity(one(T), x)) -end -function get_units(::Type{T}, ::Type{D}, x::Real, ::Function) where {T,D} - return Quantity(convert(T, x), D)::Quantity{T,D} -end -function get_units(::Type{T}, ::Type{D}, x::AbstractVector, f::Function) where {T,D} - return Quantity{T,D}[get_units(T, D, xi, f) for xi in x] -end -# TODO: Allow for AbstractQuantity output here - -""" - get_si_units(::Type{T}, units) - -Gets the units with Dimensions{DEFAULT_DIM_BASE_TYPE} type from a vector or scalar. -""" -function get_si_units(::Type{T}, units) where {T} - return get_units(T, Dimensions{DEFAULT_DIM_BASE_TYPE}, units, uparse) -end - -""" - get_sym_units(::Type{T}, units) - -Gets the units with SymbolicDimensions{DEFAULT_DIM_BASE_TYPE} type from a vector or scalar. -""" -function get_sym_units(::Type{T}, units) where {T} - return get_units(T, SymbolicDimensions{DEFAULT_DIM_BASE_TYPE}, units, sym_uparse) -end - -""" - get_dimensions_type(A, default_dimensions) - -Recursively finds the dimension type from an array, or, -if no quantity is found, returns the default type. 
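These overloads now come from SymbolicRegression.jl / DynamicExpressions.jl, but the calling convention in the docstrings above is unchanged. A usage sketch (operators and data are illustrative; assumes the default helper functions are defined by `Options`):

```julia
using LibraryAugmentedSymbolicRegression: Options, Node, eval_tree_array

options = Options(; binary_operators=[+, *, -, /], unary_operators=[cos])
x1 = Node{Float64}(; feature=1)
tree = x1 * cos(x1 - 3.2)

X = randn(Float64, 2, 100)
output, complete = eval_tree_array(tree, X, options)
# `complete == false` signals a NaN/Inf was hit, so a large loss should be assigned.
```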
-""" -@unstable function get_dimensions_type(A::AbstractArray, default::Type{D}) where {D} - i = findfirst(a -> isa(a, UnionAbstractQuantity), A) - if i === nothing - return D - else - return typeof(dimension(A[i])) - end -end -function get_dimensions_type( - ::AbstractArray{Q}, default::Type -) where {Q<:UnionAbstractQuantity} - return dim_type(Q) -end -function get_dimensions_type(_, default::Type{D}) where {D} - return D -end - -# Shortcut for basic numeric types -function get_dimensions_type( - ::AbstractArray{ - <:Union{ - Bool, - Int8, - UInt8, - Int16, - UInt16, - Int32, - UInt32, - Int64, - UInt64, - Int128, - UInt128, - Float16, - Float32, - Float64, - BigFloat, - BigInt, - ComplexF16, - ComplexF32, - ComplexF64, - Complex{BigFloat}, - Rational{Int8}, - Rational{UInt8}, - Rational{Int16}, - Rational{UInt16}, - Rational{Int32}, - Rational{UInt32}, - Rational{Int64}, - Rational{UInt64}, - Rational{Int128}, - Rational{UInt128}, - Rational{BigInt}, - }, - }, - default::Type{D}, -) where {D} - return D -end - -end diff --git a/src/LLMFunctions.jl b/src/LLMFunctions.jl index 6e14dedf4..34a0d085a 100644 --- a/src/LLMFunctions.jl +++ b/src/LLMFunctions.jl @@ -20,164 +20,116 @@ using DynamicExpressions: string_tree, AbstractOperatorEnum using Compat: Returns, @inline -using ..CoreModule: Options, DATA_TYPE, binopmap, unaopmap, LLMOptions -using ..MutationFunctionsModule: gen_random_tree_fixed_size, random_node_and_parent - +using SymbolicRegression: + Options, DATA_TYPE, gen_random_tree_fixed_size, random_node_and_parent, AbstractOptions +using ..LLMOptionsModule: LaSROptions +using ..LLMUtilsModule: + llm_recorder, + load_prompt, + convertDict, + get_vars, + get_ops, + construct_prompt, + format_pareto, + sample_context +using ..ParseModule: render_expr, parse_expr using PromptingTools: SystemMessage, UserMessage, AIMessage, aigenerate, + render, CustomOpenAISchema, OllamaSchema, OpenAISchema using JSON: parse -"""LLM Recoder records the LLM calls for debugging purposes.""" -function llm_recorder(options::LLMOptions, expr::String, mode::String="debug") - if options.active - if !isdir(options.llm_recorder_dir) - mkdir(options.llm_recorder_dir) - end - recorder = open(joinpath(options.llm_recorder_dir, "llm_calls.txt"), "a") - write(recorder, string("[", mode, "] ", expr, "\n[/", mode, "]\n")) - close(recorder) - end -end - -function load_prompt(path::String)::String - # load prompt file - f = open(path, "r") - s = read(f, String) - close(f) - return s -end - -function convertDict(d)::NamedTuple - return (; Dict(Symbol(k) => v for (k, v) in d)...) 
diff --git a/src/LLMFunctions.jl b/src/LLMFunctions.jl
index 6e14dedf4..34a0d085a 100644
--- a/src/LLMFunctions.jl
+++ b/src/LLMFunctions.jl
@@ -20,164 +20,116 @@ using DynamicExpressions:
     string_tree,
     AbstractOperatorEnum
 using Compat: Returns, @inline
-using ..CoreModule: Options, DATA_TYPE, binopmap, unaopmap, LLMOptions
-using ..MutationFunctionsModule: gen_random_tree_fixed_size, random_node_and_parent
-
+using SymbolicRegression:
+    Options, DATA_TYPE, gen_random_tree_fixed_size, random_node_and_parent, AbstractOptions
+using ..LLMOptionsModule: LaSROptions
+using ..LLMUtilsModule:
+    llm_recorder,
+    load_prompt,
+    convertDict,
+    get_vars,
+    get_ops,
+    construct_prompt,
+    format_pareto,
+    sample_context
+using ..ParseModule: render_expr, parse_expr
 using PromptingTools:
     SystemMessage,
     UserMessage,
     AIMessage,
     aigenerate,
+    render,
     CustomOpenAISchema,
     OllamaSchema,
     OpenAISchema
 using JSON: parse
 
-"""LLM Recoder records the LLM calls for debugging purposes."""
-function llm_recorder(options::LLMOptions, expr::String, mode::String="debug")
-    if options.active
-        if !isdir(options.llm_recorder_dir)
-            mkdir(options.llm_recorder_dir)
-        end
-        recorder = open(joinpath(options.llm_recorder_dir, "llm_calls.txt"), "a")
-        write(recorder, string("[", mode, "] ", expr, "\n[/", mode, "]\n"))
-        close(recorder)
-    end
-end
-
-function load_prompt(path::String)::String
-    # load prompt file
-    f = open(path, "r")
-    s = read(f, String)
-    close(f)
-    return s
-end
-
-function convertDict(d)::NamedTuple
-    return (; Dict(Symbol(k) => v for (k, v) in d)...)
-end
-
-function get_vars(options::Options)::String
-    variable_names = get_variable_names(options.llm_options.var_order)
-    return join(variable_names, ", ")
-end
-
-function get_ops(options::Options)::String
-    binary_operators = map(v -> string(v), map(binopmap, options.operators.binops))
-    unary_operators = map(v -> string(v), map(unaopmap, options.operators.unaops))
-    # Binary Ops: +, *, -, /, safe_pow (^)
-    # Unary Ops: exp, safe_log, safe_sqrt, sin, cos
-    return replace(
-        replace(
-            "binary operators: " *
-            join(binary_operators, ", ") *
-            ", and unary operators: " *
-            join(unary_operators, ", "),
-            "safe_" => "",
-        ),
-        "pow" => "^",
+function llm_randomize_tree(
+    ex::AbstractExpression,
+    curmaxsize::Int,
+    options::AbstractOptions,
+    nfeatures::Int,
+    rng::AbstractRNG=default_rng(),
+)
+    tree = get_contents(ex)
+    context = nothing
+    ex = with_contents_for_mutation(
+        ex, llm_randomize_tree(tree, curmaxsize, options, nfeatures, rng), context
     )
+    return ex
+end
+function llm_randomize_tree(
+    ::AbstractExpressionNode{T},
+    curmaxsize::Int,
+    options::AbstractOptions,
+    nfeatures::Int,
+    rng::AbstractRNG=default_rng(),
+) where {T<:DATA_TYPE}
+    tree_size_to_generate = rand(rng, 1:curmaxsize)
+    return _gen_llm_random_tree(tree_size_to_generate, options, nfeatures, T)
 end
 
-"""
-Constructs a prompt by replacing the element_id_tag with the corresponding element in the element_list.
-If the element_list is longer than the number of occurrences of the element_id_tag, the missing elements are added after the last occurrence.
-If the element_list is shorter than the number of occurrences of the element_id_tag, the extra ids are removed.
-"""
-function construct_prompt(
-    user_prompt::String, element_list::Vector, element_id_tag::String
-)::String
-    # Split the user prompt into lines
-    lines = split(user_prompt, "\n")
-
-    # Filter lines that match the pattern "... : {{element_id_tag[1-9]}}
-    pattern = r"^.*: \{\{" * element_id_tag * r"\d+\}\}$"
-
-    # find all occurrences of the element_id_tag
-    n_occurrences = count(x -> occursin(pattern, x), lines)
-
-    # if n_occurrences is less than |element_list|, add the missing elements after the last occurrence
-    if n_occurrences < length(element_list)
-        last_occurrence = findlast(x -> occursin(pattern, x), lines)
-        @assert last_occurrence !== nothing "No occurrences of the element_id_tag found in the user prompt."
-        for i in reverse((n_occurrences + 1):length(element_list))
-            new_line = replace(lines[last_occurrence], string(n_occurrences) => string(i))
-            insert!(lines, last_occurrence + 1, new_line)
-        end
+function _gen_llm_random_tree(
+    node_count::Int, options::LaSROptions, nfeatures::Int, ::Type{T}
+)::AbstractExpressionNode{T} where {T<:DATA_TYPE}
+    if isnothing(options.idea_database)
+        assumptions = []
+    else
+        assumptions = sample_context(
+            options.idea_database,
+            min(options.num_pareto_context, length(options.idea_database)),
+            options.max_concepts,
+        )
     end
-
-    new_prompt = ""
-    idx = 1
-    for line in lines
-        # if the line matches the pattern
-        if occursin(pattern, line)
-            if idx > length(element_list)
-                continue
-            end
-            # replace the element_id_tag with the corresponding element
-            new_prompt *=
-                replace(line, r"\{\{" * element_id_tag * r"\d+\}\}" => element_list[idx]) *
-                "\n"
-            idx += 1
-        else
-            new_prompt *= line * "\n"
-        end
+    if options.llm_context != ""
+        pushfirst!(assumptions, options.llm_context)
     end
-    return new_prompt
-end
-
-function gen_llm_random_tree(
-    node_count::Int,
-    options::Options,
-    nfeatures::Int,
-    ::Type{T},
-    idea_database::Union{Vector{String},Nothing},
-)::AbstractExpressionNode{T} where {T<:DATA_TYPE}
-    # Note that this base tree is just a placeholder; it will be replaced.
-    N = 5
-    # LLM prompt
-    # conversation = [
-    #     SystemMessage(load_prompt(options.llm_options.prompts_dir * "gen_random_system.txt")),
-    #     UserMessage(load_prompt(options.llm_options.prompts_dir * "gen_random_user.txt"))]
-    assumptions = sample_context(
-        idea_database,
-        options.llm_options.num_pareto_context,
-        options.llm_options.idea_threshold,
-    )
-    if !options.llm_options.prompt_concepts
+    if !options.use_concepts
         assumptions = []
     end
     conversation = [
+        SystemMessage(load_prompt(options.prompts_dir * "gen_random_system.txt")),
         UserMessage(
-            load_prompt(options.llm_options.prompts_dir * "gen_random_system.txt") *
-            "\n" *
             construct_prompt(
-                load_prompt(options.llm_options.prompts_dir * "gen_random_user.txt"),
+                load_prompt(options.prompts_dir * "gen_random_user.txt"),
                 assumptions,
                 "assump",
             ),
         ),
     ]
-    llm_recorder(options.llm_options, conversation[1].content, "llm_input|gen_random")
+    rendered_msg = join(
+        [
+            x["content"] for x in render(
+                CustomOpenAISchema(),
+                conversation;
+                variables=get_vars(options),
+                operators=get_ops(options),
+                N=options.num_generated_equations,
+                no_system_message=false,
+            )
+        ],
+        "\n",
+    )
 
-    if options.llm_options.llm_context != ""
-        pushfirst!(assumptions, options.llm_options.llm_context)
-    end
+    llm_recorder(options.llm_options, rendered_msg, "llm_input|gen_random")
 
     msg = nothing
     try
         msg = aigenerate(
             CustomOpenAISchema(),
-            conversation; #OllamaSchema(), conversation;
+            conversation;
             variables=get_vars(options),
-            operators=get_ops(options),
-            N=N,
-            api_key=options.llm_options.api_key,
-            model=options.llm_options.model,
-            api_kwargs=convertDict(options.llm_options.api_kwargs),
-            http_kwargs=convertDict(options.llm_options.http_kwargs),
+            operators=get_ops(options),
+            N=options.num_generated_equations,
+            api_key=options.api_key,
+            model=options.model,
+            api_kwargs=convertDict(options.api_kwargs),
+            http_kwargs=convertDict(options.http_kwargs),
+            no_system_message=true,
+            verbose=false,
        )
    catch e
        llm_recorder(options.llm_options, "None " * string(e), "gen_random|failed")
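The `construct_prompt` helper (kept, but moved into `LLMUtilsModule`) resizes the `{{assump...}}` slots to match the element list, as its docstring above describes. An illustrative call:

```julia
user_prompt = """
Hypothesis: {{assump1}}
Current expression: {{expr}}
"""
construct_prompt(user_prompt, ["f is periodic", "f is odd"], "assump")
# The single "Hypothesis: {{assump1}}" line matches the slot pattern, so a
# "{{assump2}}" copy is inserted after it and both tags are substituted:
#   Hypothesis: f is periodic
#   Hypothesis: f is odd
#   Current expression: {{expr}}
```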
@@ -185,7 +137,7 @@ function gen_llm_random_tree(
     end
 
     llm_recorder(options.llm_options, string(msg.content), "llm_output|gen_random")
 
-    gen_tree_options = parse_msg_content(msg.content)
+    gen_tree_options = parse_msg_content(String(msg.content))
 
     N = min(size(gen_tree_options)[1], N)
 
@@ -196,7 +148,7 @@
     for i in 1:N
         l = rand(1:N)
-        t = expr_to_tree(
+        t = parse_expr(
             T,
             String(strip(gen_tree_options[l], [' ', '\n', '"', ',', '.', '[', ']'])),
             options,
@@ -204,16 +156,16 @@
         if t.val == 1 && t.constant
             continue
         end
 
-        llm_recorder(options.llm_options, tree_to_expr(t, options), "gen_random")
+        llm_recorder(options.llm_options, render_expr(t, options), "gen_random")
 
         return t
     end
 
-    out = expr_to_tree(
+    out = parse_expr(
         T, String(strip(gen_tree_options[1], [' ', '\n', '"', ',', '.', '[', ']'])), options
     )
-    llm_recorder(options.llm_options, tree_to_expr(out, options), "gen_random")
+    llm_recorder(options.llm_options, render_expr(out, options), "gen_random")
 
     if out.val == 1 && out.constant
         return gen_random_tree_fixed_size(node_count, options, nfeatures, T)
@@ -255,248 +207,46 @@ function crossover_trees(
     return tree1, tree2
 end
 
-function sketch_const(val)
-    does_not_need_brackets = (typeof(val) <: Union{Real,AbstractArray})
-
-    if does_not_need_brackets
-        if isinteger(val) && (abs(val) < 5) # don't abstract integer constants from -4 to 4, useful for exponents
-            string(val)
-        else
-            "C"
-        end
-    else
-        if isinteger(val) && (abs(val) < 5) # don't abstract integer constants from -4 to 4, useful for exponents
-            "(" * string(val) * ")"
-        else
-            "(C)"
-        end
-    end
-end
-
-function tree_to_expr(
-    ex::AbstractExpression{T}, options::Options
-)::String where {T<:DATA_TYPE}
-    return tree_to_expr(get_contents(ex), options)
-end
-
-function tree_to_expr(tree::AbstractExpressionNode{T}, options)::String where {T<:DATA_TYPE}
-    variable_names = get_variable_names(options.llm_options.var_order)
-    return string_tree(
-        tree, options.operators; f_constant=sketch_const, variable_names=variable_names
-    )
-end
-
-function get_variable_names(var_order::Dict)::Vector{String}
-    return [var_order[key] for key in sort(collect(keys(var_order)))]
-end
-
-function get_variable_names(var_order::Nothing)::Vector{String}
-    return ["x", "y", "z", "k", "j", "l", "m", "n", "p", "a", "b"]
-end
-
-function handle_not_expr(::Type{T}, x, var_names)::Node{T} where {T<:DATA_TYPE}
-    if x isa Real
-        Node{T}(; val=convert(T, x)) # old: Node(T, 0, true, convert(T,x))
-    elseif x isa Symbol
-        if x === :C # constant that got abstracted
-            Node{T}(; val=convert(T, 1)) # old: Node(T, 0, true, convert(T,1))
-        else
-            feature = findfirst(isequal(string(x)), var_names)
-            if isnothing(feature) # invalid var name, just assume its x0
-                feature = 1
-            end
-            Node{T}(; feature=feature) # old: Node(T, 0, false, nothing, feature)
-        end
-    else
-        Node{T}(; val=convert(T, 1)) # old: Node(T, 0, true, convert(T,1)) # return a constant being 0
-    end
-end
-
-function expr_to_tree_recurse(
-    ::Type{T}, node::Expr, op::AbstractOperatorEnum, var_names
-)::Node{T} where {T<:DATA_TYPE}
-    args = node.args
-    x = args[1]
-    degree = length(args)
-
-    if degree == 1
-        handle_not_expr(T, x, var_names)
-    elseif degree == 2
-        unary_operators = map(v -> string(v), map(unaopmap, op.unaops))
-        idx = findfirst(isequal(string(x)), unary_operators)
-        if isnothing(idx) # if not used operator, make it the first one
-            idx = findfirst(isequal("safe_" * string(x)), unary_operators)
-            if isnothing(idx)
-                idx = 1
-            end
-        end
-
-        left = if (args[2] isa Expr)
-            expr_to_tree_recurse(T, args[2], op, var_names)
-        else
-            handle_not_expr(T, args[2], var_names)
-        end
-
-        Node(; op=idx, l=left) # old: Node(1, false, nothing, 0, idx, left)
-    elseif degree == 3
-        if x === :^
-            x = :pow
-        end
-        binary_operators = map(v -> string(v), map(binopmap, op.binops))
-        idx = findfirst(isequal(string(x)), binary_operators)
-        if isnothing(idx) # if not used operator, make it the first one
-            idx = findfirst(isequal("safe_" * string(x)), binary_operators)
-            if isnothing(idx)
-                idx = 1
-            end
-        end
-
-        left = if (args[2] isa Expr)
-            expr_to_tree_recurse(T, args[2], op, var_names)
-        else
-            handle_not_expr(T, args[2], var_names)
-        end
-        right = if (args[3] isa Expr)
-            expr_to_tree_recurse(T, args[3], op, var_names)
-        else
-            handle_not_expr(T, args[3], var_names)
-        end
-
-        Node(; op=idx, l=left, r=right) # old: Node(2, false, nothing, 0, idx, left, right)
-    else
-        Node{T}(; val=convert(T, 1)) # old: Node(T, 0, true, convert(T,1)) # return a constant being 1
-    end
-end
-
-function expr_to_tree_run(::Type{T}, x::String, options)::Node{T} where {T<:DATA_TYPE}
-    try
-        expr = Meta.parse(x)
-        variable_names = ["x", "y", "z", "k", "j", "l", "m", "n", "p", "a", "b"]
-        if !isnothing(options.llm_options.var_order)
-            variable_names = [
-                options.llm_options.var_order[key] for
-                key in sort(collect(keys(options.llm_options.var_order)))
-            ]
-        end
-        if expr isa Expr
-            expr_to_tree_recurse(T, expr, options.operators, variable_names)
-        else
-            handle_not_expr(T, expr, variable_names)
-        end
-    catch
-        Node{T}(; val=convert(T, 1)) # old: Node(T, 0, true, convert(T,1)) # return a constant being 1
-    end
-end
-
-function expr_to_tree(::Type{T}, x::String, options) where {T<:DATA_TYPE}
-    if options.llm_options.is_parametric
-        out = ParametricNode{T}(expr_to_tree_run(T, x, options))
-    else
-        out = Node{T}(expr_to_tree_run(T, x, options))
-    end
-    return out
-end
-
-function format_pareto(dominating, options, num_pareto_context::Int)::Vector{String}
-    pareto = Vector{String}()
-    if !isnothing(dominating) && size(dominating)[1] > 0
-        idx = randperm(size(dominating)[1])
-        for i in 1:min(size(dominating)[1], num_pareto_context)
-            push!(pareto, tree_to_expr(dominating[idx[i]].tree, options))
-        end
-    end
-    while size(pareto)[1] < num_pareto_context
-        push!(pareto, "None")
-    end
-    return pareto
-end
-
-function sample_one_context(idea_database, idea_threshold)::String
-    if isnothing(idea_database)
-        return "None"
-    end
-
-    N = size(idea_database)[1]
-    if N == 0
-        return "None"
-    end
-
-    try
-        idea_database[rand(1:min(idea_threshold, N))]
-    catch e
-        "None"
-    end
-end
-
-function sample_context(idea_database, N, idea_threshold)::Vector{String}
-    assumptions = Vector{String}()
-    if isnothing(idea_database)
-        for _ in 1:N
-            push!(assumptions, "None")
-        end
-        return assumptions
-    end
-
-    if size(idea_database)[1] < N
-        for i in 1:(size(idea_database)[1])
-            push!(assumptions, idea_database[i])
-        end
-        for i in (size(idea_database)[1] + 1):N
-            push!(assumptions, "None")
-        end
-        return assumptions
-    end
-
-    while size(assumptions)[1] < N
-        chosen_idea = sample_one_context(idea_database, idea_threshold)
-        if chosen_idea in assumptions
-            continue
-        end
-        push!(assumptions, chosen_idea)
-    end
-    return assumptions
-end
-
-function prompt_evol(idea_database, options::Options)
+function concept_evolution(idea_database, options::LaSROptions)
     num_ideas = size(idea_database)[1]
-    if num_ideas <= options.llm_options.idea_threshold
+    if num_ideas <= options.max_concepts
         return nothing
     end
 
-    idea1 = idea_database[rand((options.llm_options.idea_threshold + 1):num_ideas)]
-    idea2 = idea_database[rand((options.llm_options.idea_threshold + 1):num_ideas)] # they could be same (should be allowed)
-    idea3 = idea_database[rand((options.llm_options.idea_threshold + 1):num_ideas)] # they could be same (should be allowed)
-    idea4 = idea_database[rand((options.llm_options.idea_threshold + 1):num_ideas)] # they could be same (should be allowed)
-    idea5 = idea_database[rand((options.llm_options.idea_threshold + 1):num_ideas)] # they could be same (should be allowed)
-
-    N = 5
-
-    # conversation = [
-    #     SystemMessage(load_prompt(options.llm_options.prompts_dir * "prompt_evol_system.txt")),
-    #     UserMessage(load_prompt(options.llm_options.prompts_dir * "prompt_evol_user.txt"))]
+    ideas = [idea_database[rand((options.max_concepts + 1):num_ideas)] for _ in 1:5]
     conversation = [
+        SystemMessage(load_prompt(options.prompts_dir * "prompt_evol_system.txt")),
         UserMessage(
-            load_prompt(options.llm_options.prompts_dir * "prompt_evol_system.txt") *
-            "\n" *
             construct_prompt(
-                load_prompt(options.llm_options.prompts_dir * "prompt_evol_user.txt"),
-                [idea1, idea2, idea3, idea4, idea5],
-                "idea",
+                load_prompt(options.prompts_dir * "prompt_evol_user.txt"), ideas, "idea"
             ),
         ),
     ]
-    llm_recorder(options.llm_options, conversation[1].content, "llm_input|ideas")
+    rendered_msg = join(
+        [
+            x["content"] for x in render(
+                CustomOpenAISchema(),
+                conversation;
+                variables=get_vars(options),
+                operators=get_ops(options),
+                N=options.num_generated_concepts,
+                no_system_message=false,
+            )
+        ],
+        "\n",
+    )
 
+    llm_recorder(options.llm_options, rendered_msg, "llm_input|ideas")
 
     msg = nothing
     try
         msg = aigenerate(
             CustomOpenAISchema(),
-            conversation; #OllamaSchema(), conversation;
-            N=N,
-            api_key=options.llm_options.api_key,
-            model=options.llm_options.model,
-            api_kwargs=convertDict(options.llm_options.api_kwargs),
-            http_kwargs=convertDict(options.llm_options.http_kwargs),
+            conversation;
+            N=options.num_generated_concepts,
+            api_key=options.api_key,
+            model=options.model,
+            api_kwargs=convertDict(options.api_kwargs),
+            http_kwargs=convertDict(options.http_kwargs),
        )
    catch e
        llm_recorder(options.llm_options, "None " * string(e), "ideas|failed")
@@ -504,7 +254,7 @@ function prompt_evol(idea_database, options::Options)
 
     llm_recorder(options.llm_options, string(msg.content), "llm_output|ideas")
 
-    idea_options = parse_msg_content(msg.content)
+    idea_options = parse_msg_content(String(msg.content))
 
     N = min(size(idea_options)[1], N)
 
@@ -584,30 +334,21 @@ # new method (for Llama since it follows directions better):
 end
 
-function update_idea_database(idea_database, dominating, worst_members, options::Options)
+function update_idea_database(dominating, worst_members, options::LaSROptions)
     # turn dominating pareto curve into ideas as strings
     if isnothing(dominating)
         return nothing
     end
-    op = options.operators
-    num_pareto_context = 5 # options.mutation_weights.num_pareto_context # must be 5 right now for prompts
+    gexpr = format_pareto(dominating, options, options.num_pareto_context)
+    bexpr = format_pareto(worst_members, options, options.num_pareto_context)
 
-    gexpr = format_pareto(dominating, options, num_pareto_context)
-    bexpr = format_pareto(worst_members, options, num_pareto_context)
-
-    N = 5
-
-    # conversation = [
-    #     SystemMessage(load_prompt(options.llm_options.prompts_dir * "extract_idea_system.txt")),
-    #     UserMessage(load_prompt(options.llm_options.prompts_dir * "extract_idea_user.txt"))]
     conversation = [
+        SystemMessage(load_prompt(options.prompts_dir * "extract_idea_system.txt")),
         UserMessage(
load_prompt(options.llm_options.prompts_dir * "extract_idea_system.txt") * - "\n" * construct_prompt( construct_prompt( - load_prompt(options.llm_options.prompts_dir * "extract_idea_user.txt"), + load_prompt(options.prompts_dir * "extract_idea_user.txt"), gexpr, "gexpr", ), @@ -616,36 +357,36 @@ function update_idea_database(idea_database, dominating, worst_members, options: ), ), ] - llm_recorder(options.llm_options, conversation[1].content, "llm_input|gen_random") + rendered_msg = join( + [ + x["content"] for x in render( + CustomOpenAISchema(), + conversation; + variables=get_vars(options), + operators=get_ops(options), + N=options.num_generated_concepts, + no_system_message=false, + ) + ], + "\n", + ) + + llm_recorder(options.llm_options, rendered_msg, "llm_input|gen_random") msg = nothing try - # msg = aigenerate(OpenAISchema(), conversation; #OllamaSchema(), conversation; - # variables=get_vars(options), - # operators=get_ops(options), - # N=N, - # gexpr1=gexpr[1], - # gexpr2=gexpr[2], - # gexpr3=gexpr[3], - # gexpr4=gexpr[4], - # gexpr5=gexpr[5], - # bexpr1=bexpr[1], - # bexpr2=bexpr[2], - # bexpr3=bexpr[3], - # bexpr4=bexpr[4], - # bexpr5=bexpr[5], - # model="gpt-3.5-turbo-0125" - # ) msg = aigenerate( CustomOpenAISchema(), - conversation; #OllamaSchema(), conversation; + conversation; variables=get_vars(options), operators=get_ops(options), - N=N, - api_key=options.llm_options.api_key, - model=options.llm_options.model, - api_kwargs=convertDict(options.llm_options.api_kwargs), - http_kwargs=convertDict(options.llm_options.http_kwargs), + N=options.num_generated_concepts, + api_key=options.api_key, + model=options.model, + api_kwargs=convertDict(options.api_kwargs), + http_kwargs=convertDict(options.http_kwargs), + no_system_message=true, + verbose=false, ) catch e llm_recorder(options.llm_options, "None " * string(e), "ideas|failed") @@ -654,7 +395,7 @@ function update_idea_database(idea_database, dominating, worst_members, options: llm_recorder(options.llm_options, string(msg.content), "llm_output|ideas") - idea_options = parse_msg_content(msg.content) + idea_options = parse_msg_content(String(msg.content)) N = min(size(idea_options)[1], N) @@ -668,7 +409,7 @@ function update_idea_database(idea_database, dominating, worst_members, options: chosen_idea1 = String(strip(idea_options[a], [' ', '\n', '"', ',', '.', '[', ']'])) llm_recorder(options.llm_options, chosen_idea1, "ideas") - pushfirst!(idea_database, chosen_idea1) + pushfirst!(options.idea_database, chosen_idea1) if N > 1 b = rand(1:(N - 1)) @@ -679,65 +420,70 @@ function update_idea_database(idea_database, dominating, worst_members, options: llm_recorder(options.llm_options, chosen_idea2, "ideas") - pushfirst!(idea_database, chosen_idea2) + pushfirst!(options.idea_database, chosen_idea2) end num_add = 2 for _ in 1:num_add - out = prompt_evol(idea_database, options) + out = concept_evolution(options.idea_database, options) if !isnothing(out) - pushfirst!(idea_database, out) + pushfirst!(options.idea_database, out) end end end -function llm_mutate_op( - ex::AbstractExpression{T}, options::Options, dominating, idea_database +function llm_mutate_tree( + ex::AbstractExpression{T}, options::LaSROptions )::AbstractExpression{T} where {T<:DATA_TYPE} tree = get_contents(ex) - ex = with_contents(ex, llm_mutate_op(tree, options, dominating, idea_database)) + ex = with_contents(ex, llm_mutate_tree(tree, options)) return ex end """LLM Mutation on a tree""" -function llm_mutate_op( - tree::AbstractExpressionNode{T}, options::Options, 
dominating, idea_database
+function llm_mutate_tree(
+    tree::AbstractExpressionNode{T}, options::LaSROptions
 )::AbstractExpressionNode{T} where {T<:DATA_TYPE}
-    expr = tree_to_expr(tree, options) # TODO: change global expr right now, could do it by subtree (weighted near root more)
-    N = 5
-    # LLM prompt
-    # TODO: we can use async map to do concurrent requests (useful for trying multiple prompts), see: https://github.com/svilupp/PromptingTools.jl?tab=readme-ov-file#asynchronous-execution
-
-    # conversation = [
-    #     SystemMessage(load_prompt(options.llm_options.prompts_dir * "mutate_system.txt")),
-    #     UserMessage(load_prompt(options.llm_options.prompts_dir * "mutate_user.txt"))]
-
-    assumptions = sample_context(
-        idea_database,
-        options.llm_options.num_pareto_context,
-        options.llm_options.idea_threshold,
-    )
-    pareto = format_pareto(dominating, options, options.llm_options.num_pareto_context)
-    if !options.llm_options.prompt_concepts
+    expr = render_expr(tree, options)
+
+    if isnothing(options.idea_database)
         assumptions = []
-        pareto = []
+    else
+        assumptions = sample_context(
+            options.idea_database, options.num_pareto_context, options.max_concepts
+        )
+    end
+
+    if !options.use_concepts
+        assumptions = []
+    end
+    if options.llm_context != ""
+        pushfirst!(assumptions, options.llm_context)
     end
+
     conversation = [
+        SystemMessage(load_prompt(options.prompts_dir * "mutate_system.txt")),
         UserMessage(
-            load_prompt(options.llm_options.prompts_dir * "mutate_system.txt") *
-            "\n" *
             construct_prompt(
-                load_prompt(options.llm_options.prompts_dir * "mutate_user.txt"),
-                assumptions,
-                "assump",
+                load_prompt(options.prompts_dir * "mutate_user.txt"), assumptions, "assump"
             ),
         ),
     ]
-    llm_recorder(options.llm_options, conversation[1].content, "llm_input|mutate")
+    rendered_msg = join(
+        [
+            x["content"] for x in render(
+                CustomOpenAISchema(),
+                conversation;
+                variables=get_vars(options),
+                operators=get_ops(options),
+                N=options.num_generated_equations,
+                no_system_message=false,
+            )
+        ],
+        "\n",
+    )

-    if options.llm_options.llm_context != ""
-        pushfirst!(assumptions, options.llm_options.llm_context)
-    end
+    llm_recorder(options.llm_options, rendered_msg, "llm_input|mutate")

     msg = nothing
     try
@@ -746,12 +492,14 @@ function llm_mutate_op(
             conversation; #OllamaSchema(), conversation;
             variables=get_vars(options),
             operators=get_ops(options),
-            N=N,
+            N=options.num_generated_equations,
             expr=expr,
-            api_key=options.llm_options.api_key,
-            model=options.llm_options.model,
-            api_kwargs=convertDict(options.llm_options.api_kwargs),
-            http_kwargs=convertDict(options.llm_options.http_kwargs),
+            api_key=options.api_key,
+            model=options.model,
+            api_kwargs=convertDict(options.api_kwargs),
+            http_kwargs=convertDict(options.http_kwargs),
+            no_system_message=true,
+            verbose=false,
         )
     catch e
         llm_recorder(options.llm_options, "None " * string(e), "mutate|failed")
@@ -761,7 +509,7 @@ function llm_mutate_op(

     llm_recorder(options.llm_options, string(msg.content), "llm_output|mutate")

-    mut_tree_options = parse_msg_content(msg.content)
+    mut_tree_options = parse_msg_content(String(msg.content))

-    N = min(size(mut_tree_options)[1], N)
+    N = min(size(mut_tree_options)[1], options.num_generated_equations)
@@ -772,7 +520,7 @@ function llm_mutate_op(

     for i in 1:N
         l = rand(1:N)
-        t = expr_to_tree(
+        t = parse_expr(
             T,
             String(strip(mut_tree_options[l], [' ', '\n', '"', ',', '.', '[', ']'])),
             options,
@@ -781,26 +529,26 @@ function llm_mutate_op(
             continue
         end

-        llm_recorder(options.llm_options, tree_to_expr(t, options), "mutate")
+        llm_recorder(options.llm_options, render_expr(t, options), "mutate")
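+        # Stop at the first randomly sampled candidate that parses into a valid tree.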
return t end - out = expr_to_tree( + out = parse_expr( T, String(strip(mut_tree_options[1], [' ', '\n', '"', ',', '.', '[', ']'])), options ) - llm_recorder(options.llm_options, tree_to_expr(out, options), "mutate") + llm_recorder(options.llm_options, render_expr(out, options), "mutate") return out end function llm_crossover_trees( - ex1::E, ex2::E, options::Options, dominating, idea_database + ex1::E, ex2::E, options::LaSROptions )::Tuple{E,E} where {T,E<:AbstractExpression{T}} tree1 = get_contents(ex1) tree2 = get_contents(ex2) - tree1, tree2 = llm_crossover_trees(tree1, tree2, options, dominating, idea_database) + tree1, tree2 = llm_crossover_trees(tree1, tree2, options) ex1 = with_contents(ex1, tree1) ex2 = with_contents(ex2, tree2) return ex1, ex2 @@ -808,66 +556,71 @@ end """LLM Crossover between two expressions""" function llm_crossover_trees( - tree1::AbstractExpressionNode{T}, - tree2::AbstractExpressionNode{T}, - options::Options, - dominating, - idea_database, + tree1::AbstractExpressionNode{T}, tree2::AbstractExpressionNode{T}, options::LaSROptions )::Tuple{AbstractExpressionNode{T},AbstractExpressionNode{T}} where {T<:DATA_TYPE} - expr1 = tree_to_expr(tree1, options) - expr2 = tree_to_expr(tree2, options) - N = 5 - - # LLM prompt - # conversation = [ - # SystemMessage(load_prompt(options.llm_options.prompts_dir * "crossover_system.txt")), - # UserMessage(load_prompt(options.llm_options.prompts_dir * "crossover_user.txt"))] - assumptions = sample_context( - idea_database, - options.llm_options.num_pareto_context, - options.llm_options.idea_threshold, - ) - pareto = format_pareto(dominating, options, options.llm_options.num_pareto_context) - if !options.llm_options.prompt_concepts + expr1 = render_expr(tree1, options) + expr2 = render_expr(tree2, options) + + if isnothing(options.idea_database) + assumptions = [] + else + assumptions = sample_context( + options.idea_database, + min(options.num_pareto_context, length(options.idea_database)), + options.max_concepts, + ) + end + + if !options.use_concepts assumptions = [] - pareto = [] + end + + if options.llm_context != "" + pushfirst!(assumptions, options.llm_context) end conversation = [ + SystemMessage(load_prompt(options.prompts_dir * "crossover_system.txt")), UserMessage( - load_prompt(options.llm_options.prompts_dir * "crossover_system.txt") * - "\n" * construct_prompt( - load_prompt(options.llm_options.prompts_dir * "crossover_user.txt"), + load_prompt(options.prompts_dir * "crossover_user.txt"), assumptions, "assump", ), ), ] + rendered_msg = join( + [ + x["content"] for x in render( + CustomOpenAISchema(), + conversation; + variables=get_vars(options), + operators=get_ops(options), + N=options.num_generated_equations, + no_system_message=false, + ) + ], + "\n", + ) - if options.llm_options.llm_context != "" - pushfirst!(assumptions, options.llm_options.llm_context) - end - - llm_recorder(options.llm_options, conversation[1].content, "llm_input|crossover") + llm_recorder(options.llm_options, rendered_msg, "llm_input|crossover") msg = nothing try msg = aigenerate( CustomOpenAISchema(), - conversation; #OllamaSchema(), conversation; + conversation; variables=get_vars(options), operators=get_ops(options), - N=N, - # pareto1=pareto[1], - # pareto2=pareto[2], - # pareto3=pareto[3], + N=options.num_generated_equations, expr1=expr1, expr2=expr2, - api_key=options.llm_options.api_key, - model=options.llm_options.model, - api_kwargs=convertDict(options.llm_options.api_kwargs), - 
http_kwargs=convertDict(options.llm_options.http_kwargs), + api_key=options.api_key, + model=options.model, + api_kwargs=convertDict(options.api_kwargs), + http_kwargs=convertDict(options.http_kwargs), + no_system_message=true, + verbose=false, ) catch e llm_recorder(options.llm_options, "None " * string(e), "crossover|failed") @@ -876,7 +629,7 @@ function llm_crossover_trees( llm_recorder(options.llm_options, string(msg.content), "llm_output|crossover") - cross_tree_options = parse_msg_content(msg.content) + cross_tree_options = parse_msg_content(String(msg.content)) cross_tree1 = nothing cross_tree2 = nothing @@ -889,20 +642,20 @@ function llm_crossover_trees( end if N == 1 - t = expr_to_tree( + t = parse_expr( T, String(strip(cross_tree_options[1], [' ', '\n', '"', ',', '.', '[', ']'])), options, ) - llm_recorder(options.llm_options, tree_to_expr(t, options), "crossover") + llm_recorder(options.llm_options, render_expr(t, options), "crossover") return t, tree2 end for i in 1:(2 * N) l = rand(1:N) - t = expr_to_tree( + t = parse_expr( T, String(strip(cross_tree_options[l], [' ', '\n', '"', ',', '.', '[', ']'])), options, @@ -920,7 +673,7 @@ function llm_crossover_trees( end if isnothing(cross_tree1) - cross_tree1 = expr_to_tree( + cross_tree1 = parse_expr( T, String(strip(cross_tree_options[1], [' ', '\n', '"', ',', '.', '[', ']'])), options, @@ -928,7 +681,7 @@ function llm_crossover_trees( end if isnothing(cross_tree2) - cross_tree2 = expr_to_tree( + cross_tree2 = parse_expr( T, String(strip(cross_tree_options[2], [' ', '\n', '"', ',', '.', '[', ']'])), options, @@ -936,7 +689,7 @@ function llm_crossover_trees( end recording_str = - tree_to_expr(cross_tree1, options) * " && " * tree_to_expr(cross_tree2, options) + render_expr(cross_tree1, options) * " && " * render_expr(cross_tree2, options) llm_recorder(options.llm_options, recording_str, "crossover") return cross_tree1, cross_tree2 diff --git a/src/LLMOptions.jl b/src/LLMOptions.jl index 82915219e..a73d56778 100644 --- a/src/LLMOptions.jl +++ b/src/LLMOptions.jl @@ -1,26 +1,23 @@ module LLMOptionsModule +using DispatchDoctor: @unstable using StatsBase: StatsBase using Base: isvalid +using SymbolicRegression +using ..LaSRMutationWeightsModule: LaSRMutationWeights """ - LLMWeights(;kws...) + LLMOperationWeights(;kws...) -Defines the probability of different LLM-based mutation operations. Follows the same -pattern as MutationWeights. These weights will be normalized to sum to 1.0 after initialization. +Defines the probability of different LLM-based mutation operations. +NOTE: The LLM operations can be significantly slower than their symbolic counterparts, +so higher probabilities will result in slower operations. By default, we set all probs to 0.0. +The maximum value for these parameters is 1.0 (100% of the time). # Arguments -- `llm_mutate::Float64`: Probability of calling LLM version of mutation. - The LLM operations are significantly slower than their symbolic counterparts, - so higher probabilities will result in slower operations. - `llm_crossover::Float64`: Probability of calling LLM version of crossover. - Same limitation as llm_mutate. -- `llm_gen_random::Float64`: Probability of calling LLM version of gen_random. - Same limitation as llm_mutate. """ -Base.@kwdef mutable struct LLMWeights - llm_mutate::Float64 = 0.0 +Base.@kwdef mutable struct LLMOperationWeights llm_crossover::Float64 = 0.0 - llm_gen_random::Float64 = 0.0 end """ @@ -29,60 +26,77 @@ end This defines how to call the LLM inference functions. 
LLM inference is managed by PromptingTools.jl; this module serves as the entry point
to define new options for LLM inference.

# Arguments
-- `active::Bool`: Whether to use LLM inference or not.
-- `weights::LLMWeights`: Weights for different LLM operations.
+- `use_llm::Bool`: Whether to use LLM inference or not. (default: false)
+- `use_concepts::Bool`: Whether to summarize programs into concepts and use the concepts to guide the search. (default: false)
+   NOTE: If `use_llm` is false, then `use_concepts` will be ignored.
+- `use_concept_evolution::Bool`: Whether to evolve the concepts after every iteration. (default: false)
+   NOTE: If `use_concepts` is false, then `use_concept_evolution` will be ignored.
+- `lasr_mutation_weights::LaSRMutationWeights`: Weights for the LLM-based mutation operations.
+- `llm_operation_weights::LLMOperationWeights`: Weights for the LLM-based operations (e.g., crossover).
 - `num_pareto_context::Int64`: Number of equations to sample from the Pareto frontier.
-- `prompt_concepts::Bool`: Use natural language concepts in the LLM prompts.
-- `prompt_evol::Bool`: Evolve natural language concepts through successive LLM
-  calls.
-- api_key::String: OpenAI API key. Required.
-- model::String: OpenAI model to use. Required.
+- api_key::AbstractString: OpenAI API key. Required.
+- model::AbstractString: OpenAI model to use. Required.
 - api_kwargs::Dict: Additional keyword arguments to pass to the OpenAI API.
-    - url::String: URL to send the request to. Required.
+    - url::AbstractString: URL to send the request to. Required.
     - max_tokens::Int: Maximum number of tokens to generate. (default: 1000)
 - http_kwargs::Dict: Additional keyword arguments for the HTTP request.
     - retries::Int: Number of retries to attempt. (default: 3)
     - readtimeout::Int: Read timeout for the HTTP request (in seconds; default is 1 hour).
-- `llm_recorder_dir::String`: File to save LLM logs to. Useful for debugging.
-- `llm_context::AbstractString`: Context string for LLM.
-- `var_order::Union{Dict,Nothing}`: Variable order for LLM. (default: nothing)
-- `idea_threshold::UInt32`: Number of concepts to keep track of. (default: 30)
+- `llm_recorder_dir::AbstractString`: Directory to save LLM logs to. Useful for debugging.
+- `llm_context::AbstractString`: Context string for the LLM.
+- `variable_names::Union{Dict,Nothing}`: Variable order for the LLM. (default: nothing)
+- `max_concepts::Int64`: Number of concepts to keep track of. (default: 30)
"""
Base.@kwdef mutable struct LLMOptions
-    active::Bool = false
-    weights::LLMWeights = LLMWeights()
-    num_pareto_context::Int64 = 0
-    prompt_concepts::Bool = false
-    prompt_evol::Bool = false
-    api_key::String = ""
-    model::String = ""
+    # LaSR Ablation Modifiers
+    use_llm::Bool = false
+    use_concepts::Bool = false
+    use_concept_evolution::Bool = false
+    lasr_mutation_weights::LaSRMutationWeights = LaSRMutationWeights()
+    llm_operation_weights::LLMOperationWeights = LLMOperationWeights()
+
+    # LaSR Performance Modifiers
+    num_pareto_context::Int64 = 5
+    num_generated_equations::Int64 = 5
+    num_generated_concepts::Int64 = 5
+    max_concepts::Int64 = 30
+    # This is a cheeky hack to avoid dealing with parametric types in LLMFunctions.jl.
+    # TODO (high priority): rectify this.
+    is_parametric::Bool = false
+    llm_context::AbstractString = ""
+
+    # LaSR Bookkeeping Utilities
+    # llm_logger::Union{SymbolicRegression.AbstractSRLogger, Nothing} = nothing
+    llm_recorder_dir::AbstractString = "lasr_runs/"
+    variable_names::Union{Dict,Nothing} = nothing
+    prompts_dir::AbstractString = "prompts/"
+    idea_database::Vector{AbstractString} = []
+
+    # LaSR LLM API Options
+    api_key::AbstractString = ""
+    model::AbstractString = ""
     api_kwargs::Dict = Dict("max_tokens" => 1000)
     http_kwargs::Dict = Dict("retries" => 3, "readtimeout" => 3600)
-    llm_recorder_dir::String = "lasr_runs/"
-    prompts_dir::String = "prompts/"
-    llm_context::AbstractString = ""
-    var_order::Union{Dict,Nothing} = nothing
-    idea_threshold::UInt32 = 30
-    is_parametric::Bool = false
 end

-const llm_mutations = fieldnames(LLMWeights)
+const llm_mutations = fieldnames(LLMOperationWeights)
 const v_llm_mutations = Symbol[llm_mutations...]

 # Validate some options are set correctly.
 """Validate some options are set correctly.
 Specifically, need to check
-- If `active` is true, then `api_key` and `model` must be set.
-- If `active` is true, then `api_kwargs` must have a `url` key and it must be a valid URL.
-- If `active` is true, then `llm_recorder_dir` must be a valid directory.
+- If `use_llm` is true, then `api_key` and `model` must be set.
+- If `use_llm` is true, then `api_kwargs` must have a `url` key and it must be a valid URL.
+- If `use_llm` is true, then `llm_recorder_dir` must be a valid directory.
 """
 function validate_llm_options(options::LLMOptions)
-    if options.active
+    if options.use_llm
         if options.api_key == ""
-            throw(ArgumentError("api_key must be set if LLM is active."))
+            throw(ArgumentError("api_key must be set if `use_llm` is true."))
         end
         if options.model == ""
-            throw(ArgumentError("model must be set if LLM is active."))
+            throw(ArgumentError("model must be set if `use_llm` is true."))
         end
         if !haskey(options.api_kwargs, "url")
             throw(ArgumentError("api_kwargs must have a 'url' key."))
@@ -93,28 +107,63 @@ function validate_llm_options(options::LLMOptions)
         end
     end
 end

-# """Sample LLM mutation, given the weightings."""
-# function sample_llm_mutation(w::LLMWeights)
-#     weights = convert(Vector, w)
-#     return StatsBase.sample(v_llm, StatsBase.Weights(weights))
-# end
+"""
+    LaSROptions(;kws...)

-end # module
+This defines the options for the LibraryAugmentedSymbolicRegression module. It is a composite
+type that contains both the `LLMOptions` and the `SymbolicRegression.Options`.

+# Arguments
+- `llm_options::LLMOptions`: Options for the LLM inference.
+- `sr_options::SymbolicRegression.Options`: Options for the SymbolicRegression module.
+
+# Example
+```julia
+llm_options = LLMOptions(;
+    ...
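+    # e.g. use_llm=true, model=..., api_key=..., api_kwargs=Dict("url" => ...)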
+) + +options = Options(; + binary_operators = (+, *, -, /, ^), + unary_operators = (cos, log), + nested_constraints = [(^) => [(^) => 0, cos => 0, log => 0], (/) => [(/) => 1], (cos) => [cos => 0, log => 0], log => [log => 0, cos => 0, (^) => 0]], + constraints = [(^) => (3, 1), log => 5, cos => 7], + populations=20, +) -# sample invocation following: -# python -m experiments.main --use_llm --use_prompt_evol --model "meta-llama/Meta-Llama-3-8B-Instruct" --api_key "vllm_api.key" --model_url "http://localhost:11440/v1" --exp_idx 0 --dataset_path FeynmanEquations.csv --start_idx 0 -# options = LLMOptions( -# active=true, -# weights=LLMWeights(llm_mutate=0.5, llm_crossover=0.3, llm_gen_random=0.2), -# num_pareto_context=5, -# prompt_evol=true, -# prompt_concepts=true, -# api_key="vllm_api.key", -# model="meta-llama/Meta-Llama-3-8B-Instruct", -# api_kwargs=Dict("url" => "http://localhost:11440/v1"), -# http_kwargs=Dict("retries" => 3, "readtimeout" => 3600), -# llm_recorder_dir="lasr_runs/", -# llm_context="", -# var_order=nothing, -# idea_threshold=30 -# ) +lasr_options = LaSROptions(llm_options, options) +``` + +""" +struct LaSROptions{O<:SymbolicRegression.Options} <: SymbolicRegression.AbstractOptions + llm_options::LLMOptions + sr_options::O +end +const LLM_OPTIONS_KEYS = fieldnames(LLMOptions) + +# Constructor with both sets of parameters: +@unstable function LaSROptions(; kws...) + llm_options_keys = filter(k -> k in LLM_OPTIONS_KEYS, keys(kws)) + llm_options = LLMOptions(; + NamedTuple(llm_options_keys .=> Tuple(kws[k] for k in llm_options_keys))... + ) + sr_options_keys = filter(k -> !(k in LLM_OPTIONS_KEYS), keys(kws)) + sr_options = SymbolicRegression.Options(; + NamedTuple(sr_options_keys .=> Tuple(kws[k] for k in sr_options_keys))... + ) + return LaSROptions(llm_options, sr_options) +end + +# Make all `Options` available while also making `llm_options` accessible +function Base.getproperty(options::LaSROptions, k::Symbol) + if k in LLM_OPTIONS_KEYS + return getproperty(getfield(options, :llm_options), k) + else + return getproperty(getfield(options, :sr_options), k) + end +end + +function Base.propertynames(options::LaSROptions) + return (LLM_OPTIONS_KEYS..., fieldnames(SymbolicRegression.Options)...) 
+end
+
+end # module
diff --git a/src/LLMUtils.jl b/src/LLMUtils.jl
new file mode 100644
index 000000000..e7dafc50b
--- /dev/null
+++ b/src/LLMUtils.jl
@@ -0,0 +1,181 @@
+module LLMUtilsModule
+
+using Random: rand, randperm
+using DynamicExpressions:
+    Node,
+    AbstractExpressionNode,
+    AbstractExpression,
+    ParametricExpression,
+    ParametricNode,
+    AbstractNode,
+    NodeSampler,
+    get_contents,
+    with_contents,
+    constructorof,
+    copy_node,
+    set_node!,
+    count_nodes,
+    has_constants,
+    has_operators,
+    string_tree,
+    AbstractOperatorEnum
+using SymbolicRegression: DATA_TYPE
+using ..LLMOptionsModule: LaSROptions
+using ..ParseModule: render_expr
+using JSON: parse
+
+"""LLM recorder: records LLM calls for debugging purposes."""
+function llm_recorder(options::LaSROptions, expr::String, mode::String="debug")
+    if options.use_llm
+        if !isdir(options.llm_recorder_dir)
+            mkdir(options.llm_recorder_dir)
+        end
+        recorder = open(joinpath(options.llm_recorder_dir, "llm_calls.txt"), "a")
+        write(recorder, string("[", mode, "]\n", expr, "\n[/", mode, "]\n"))
+        close(recorder)
+    end
+end
+
+function load_prompt(path::String)::String
+    # Load a prompt template from disk, stripping surrounding whitespace.
+    f = open(path, "r")
+    s = read(f, String)
+    s = strip(s)
+    close(f)
+    return s
+end
+
+function convertDict(d)::NamedTuple
+    return (; Dict(Symbol(k) => v for (k, v) in d)...)
+end
+
+function get_vars(options::LaSROptions)::String
+    variable_names = get_variable_names(options.variable_names)
+    return join(variable_names, ", ")
+end
+
+function get_ops(options::LaSROptions)::String
+    binary_operators = map(v -> string(v), options.operators.binops)
+    unary_operators = map(v -> string(v), options.operators.unaops)
+    # Binary Ops: +, *, -, /, safe_pow (^)
+    # Unary Ops: exp, safe_log, safe_sqrt, sin, cos
+    return replace(
+        replace(
+            "binary operators: " *
+            join(binary_operators, ", ") *
+            ", and unary operators: " *
+            join(unary_operators, ", "),
+            "safe_" => "",
+        ),
+        "pow" => "^",
+    )
+end
+
+"""
+Constructs a prompt by replacing each `element_id_tag` placeholder with the corresponding element of `element_list`.
+If `element_list` is longer than the number of occurrences of the placeholder, the missing elements are inserted after the last occurrence.
+If `element_list` is shorter than the number of occurrences of the placeholder, the extra placeholder lines are removed.
+"""
+function construct_prompt(
+    user_prompt::String, element_list::Vector, element_id_tag::String
+)::String
+    # Split the user prompt into lines
+    lines = split(user_prompt, "\n")
+
+    # Match lines of the form "... : {{element_id_tagN}}", where N is a number
+    pattern = r"^.*: \{\{" * element_id_tag * r"\d+\}\}$"
+
+    # find all occurrences of the element_id_tag
+    n_occurrences = count(x -> occursin(pattern, x), lines)
+
+    # if n_occurrences is less than |element_list|, add the missing elements after the last occurrence
+    if n_occurrences < length(element_list)
+        last_occurrence = findlast(x -> occursin(pattern, x), lines)
+        @assert last_occurrence !== nothing "No occurrences of the element_id_tag found in the user prompt."
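+        # Insert in reverse order so each new line can be spliced in directly after
+        # `last_occurrence` while the ids still end up ascending in the final prompt.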
+ for i in reverse((n_occurrences + 1):length(element_list)) + new_line = replace(lines[last_occurrence], string(n_occurrences) => string(i)) + insert!(lines, last_occurrence + 1, new_line) + end + end + + new_prompt = "" + idx = 1 + for line in lines + # if the line matches the pattern + if occursin(pattern, line) + if idx > length(element_list) + continue + end + # replace the element_id_tag with the corresponding element + new_prompt *= + replace(line, r"\{\{" * element_id_tag * r"\d+\}\}" => element_list[idx]) * + "\n" + idx += 1 + else + new_prompt *= line * "\n" + end + end + return new_prompt +end + +function format_pareto(dominating, options, num_pareto_context::Int)::Vector{String} + pareto = Vector{String}() + if !isnothing(dominating) && size(dominating)[1] > 0 + idx = randperm(size(dominating)[1]) + for i in 1:min(size(dominating)[1], num_pareto_context) + push!(pareto, render_expr(dominating[idx[i]].tree, options)) + end + end + while size(pareto)[1] < num_pareto_context + push!(pareto, "None") + end + return pareto +end + +function sample_one_context(idea_database, max_concepts)::String + if isnothing(idea_database) + return "None" + end + + N = size(idea_database)[1] + if N == 0 + return "None" + end + + try + idea_database[rand(1:min(max_concepts, N))] + catch e + "None" + end +end + +function sample_context(idea_database, N, max_concepts)::Vector{String} + assumptions = Vector{String}() + if isnothing(idea_database) + for _ in 1:N + push!(assumptions, "None") + end + return assumptions + end + + if size(idea_database)[1] < N + for i in 1:(size(idea_database)[1]) + push!(assumptions, idea_database[i]) + end + for i in (size(idea_database)[1] + 1):N + push!(assumptions, "None") + end + return assumptions + end + + while size(assumptions)[1] < N + chosen_idea = sample_one_context(idea_database, max_concepts) + if chosen_idea in assumptions + continue + end + push!(assumptions, chosen_idea) + end + return assumptions +end + +end diff --git a/src/LibraryAugmentedSymbolicRegression.jl b/src/LibraryAugmentedSymbolicRegression.jl index 27e99b5bb..b100b5de2 100644 --- a/src/LibraryAugmentedSymbolicRegression.jl +++ b/src/LibraryAugmentedSymbolicRegression.jl @@ -7,151 +7,36 @@ export Population, Options, Dataset, MutationWeights, - LLMWeights, - LLMOptions, Node, - GraphNode, - ParametricNode, - Expression, - ParametricExpression, - StructuredExpression, - NodeSampler, - AbstractExpression, - AbstractExpressionNode, LaSRRegressor, MultitargetLaSRRegressor, - LOSS_TYPE, - DATA_TYPE, - - #Functions: - equation_search, - s_r_cycle, - calculate_pareto_frontier, - count_nodes, - compute_complexity, - @parse_expression, - parse_expression, - print_tree, - string_tree, - eval_tree_array, - eval_diff_tree_array, - eval_grad_tree_array, - differentiable_eval_tree_array, - set_node!, - copy_node, - node_to_symbolic, - node_type, - symbolic_to_node, - simplify_tree!, - tree_mapreduce, - combine_operators, - gen_random_tree, - gen_random_tree_fixed_size, - @extend_operators, - get_tree, - get_contents, - get_metadata, - - #Operators - plus, - sub, - mult, - square, - cube, - pow, - safe_pow, - safe_log, - safe_log2, - safe_log10, - safe_log1p, - safe_acosh, - safe_sqrt, - neg, - greater, - cond, - relu, - logical_or, - logical_and, - - # special operators - gamma, - erf, - erfc, - atanh_clip + # Options: + LLMOperationWeights, + LLMOptions, + LaSROptions, + LLMMutationProbabilities, + LaSRMutationWeights, + # Functions: + llm_randomize_tree, + llm_crossover_trees, + llm_mutate_tree, + 
crossover_trees, + concept_evolution, + update_idea_database, + mutate!, + crossover_generation, + # Utilities: + render_expr, + parse_expr, + parse_msg_content, + llm_recorder, + construct_prompt using Distributed -using Printf: @printf, @sprintf using PackageExtensionCompat: @require_extensions using Pkg: Pkg using TOML: parsefile -using Random: seed!, shuffle! using Reexport -using DynamicExpressions: - Node, - GraphNode, - ParametricNode, - Expression, - ParametricExpression, - StructuredExpression, - NodeSampler, - AbstractExpression, - AbstractExpressionNode, - @parse_expression, - parse_expression, - copy_node, - set_node!, - string_tree, - print_tree, - count_nodes, - get_constants, - get_scalar_constants, - set_constants!, - set_scalar_constants!, - index_constants, - NodeIndex, - eval_tree_array, - differentiable_eval_tree_array, - eval_diff_tree_array, - eval_grad_tree_array, - node_to_symbolic, - symbolic_to_node, - combine_operators, - simplify_tree!, - tree_mapreduce, - set_default_variable_names!, - node_type, - get_tree, - get_contents, - get_metadata -using DynamicExpressions: with_type_parameters -@reexport using LossFunctions: - MarginLoss, - DistanceLoss, - SupervisedLoss, - ZeroOneLoss, - LogitMarginLoss, - PerceptronLoss, - HingeLoss, - L1HingeLoss, - L2HingeLoss, - SmoothedL1HingeLoss, - ModifiedHuberLoss, - L2MarginLoss, - ExpLoss, - SigmoidLoss, - DWDMarginLoss, - LPDistLoss, - L1DistLoss, - L2DistLoss, - PeriodicLoss, - HuberLoss, - EpsilonInsLoss, - L1EpsilonInsLoss, - L2EpsilonInsLoss, - LogitDistLoss, - QuantileLoss, - LogCoshLoss - # https://discourse.julialang.org/t/how-to-find-out-the-version-of-a-package-from-its-module/37755/15 const PACKAGE_VERSION = try root = pkgdir(@__MODULE__) @@ -166,739 +51,65 @@ catch VersionNumber(0, 0, 0) end -function deprecate_varmap(variable_names, varMap, func_name) - if varMap !== nothing - Base.depwarn("`varMap` is deprecated; use `variable_names` instead", func_name) - @assert variable_names === nothing "Cannot pass both `varMap` and `variable_names`" - variable_names = varMap - end - return variable_names -end - using DispatchDoctor: @stable +@reexport using SymbolicRegression +using .SymbolicRegression: + @recorder, @sr_spawner, AbstractSearchState, AbstractRuntimeOptions @stable default_mode = "disable" begin include("Utils.jl") - include("InterfaceDynamicQuantities.jl") - include("Core.jl") - include("InterfaceDynamicExpressions.jl") - include("Recorder.jl") - include("Complexity.jl") - include("DimensionalAnalysis.jl") - include("CheckConstraints.jl") - include("AdaptiveParsimony.jl") - include("MutationFunctions.jl") + include("Parse.jl") + include("MutationWeights.jl") + include("LLMOptions.jl") + include("LLMUtils.jl") include("LLMFunctions.jl") - include("LossFunctions.jl") - include("PopMember.jl") - include("ConstantOptimization.jl") - include("Population.jl") - include("HallOfFame.jl") include("Mutate.jl") - include("RegularizedEvolution.jl") - include("SingleIteration.jl") - include("ProgressBars.jl") - include("Migration.jl") - include("SearchUtils.jl") - include("ExpressionBuilder.jl") + include("Core.jl") end using .CoreModule: - MAX_DEGREE, - BATCH_DIM, - FEATURE_DIM, - DATA_TYPE, - LOSS_TYPE, - RecordType, - Dataset, - Options, - MutationWeights, + LLMOperationWeights, LLMOptions, - LLMWeights, - plus, - sub, - mult, - square, - cube, - pow, - safe_pow, - safe_log, - safe_log2, - safe_log10, - safe_log1p, - safe_sqrt, - safe_acosh, - neg, - greater, - cond, - relu, - logical_or, - logical_and, - gamma, - 
erf, - erfc, - atanh_clip, - create_expression -using .UtilsModule: is_anonymous_function, recursive_merge, json3_write, @ignore -using .ComplexityModule: compute_complexity -using .CheckConstraintsModule: check_constraints -using .AdaptiveParsimonyModule: - RunningSearchStatistics, update_frequencies!, move_window!, normalize_frequencies! -using .MutationFunctionsModule: - gen_random_tree, - gen_random_tree_fixed_size, - random_node, - random_node_and_parent, - crossover_trees -using .LLMFunctionsModule: update_idea_database, llm_recorder - -using .InterfaceDynamicExpressionsModule: @extend_operators -using .LossFunctionsModule: eval_loss, score_func, update_baseline_loss! -using .PopMemberModule: PopMember, reset_birth! -using .PopulationModule: Population, best_sub_pop, record_population, best_of_sample -using .HallOfFameModule: - HallOfFame, calculate_pareto_frontier, string_dominating_pareto_curve -using .SingleIterationModule: s_r_cycle, optimize_and_simplify_population -using .ProgressBarsModule: WrappedProgressBar -using .RecorderModule: @recorder, find_iteration_from_record -using .MigrationModule: migrate! -using .SearchUtilsModule: - SearchState, - RuntimeOptions, - WorkerAssignments, - DefaultWorkerOutputType, - assign_next_worker!, - get_worker_output_type, - extract_from_worker, - @sr_spawner, - StdinReader, - watch_stream, - close_reader!, - check_for_user_quit, - check_for_loss_threshold, - check_for_timeout, - check_max_evals, - ResourceMonitor, - record_channel_state!, - estimate_work_fraction, - update_progress_bar!, - print_search_state, - init_dummy_pops, - load_saved_hall_of_fame, - load_saved_population, - construct_datasets, - save_to_file, - get_cur_maxsize, - update_hall_of_fame! -using .ExpressionBuilderModule: embed_metadata, strip_metadata + LaSROptions, + LLMMutationProbabilities, + LaSRMutationWeights -@stable default_mode = "disable" begin - include("deprecates.jl") - include("Configure.jl") -end +using .UtilsModule: is_anonymous_function, recursive_merge, json3_write, @ignore +using .LLMFunctionsModule: + llm_randomize_tree, + llm_mutate_tree, + crossover_trees, + llm_crossover_trees, + concept_evolution, + parse_msg_content, + update_idea_database +using .LLMUtilsModule: construct_prompt, llm_recorder +using .ParseModule: render_expr, parse_expr +using .MutateModule: mutate!, crossover_generation """ - equation_search(X, y[; kws...]) - -Perform a distributed equation search for functions `f_i` which -describe the mapping `f_i(X[:, j]) ≈ y[i, j]`. Options are -configured using LibraryAugmentedSymbolicRegression.Options(...), -which should be passed as a keyword argument to options. -One can turn off parallelism with `numprocs=0`, -which is useful for debugging and profiling. - -# Arguments -- `X::AbstractMatrix{T}`: The input dataset to predict `y` from. - The first dimension is features, the second dimension is rows. -- `y::Union{AbstractMatrix{T}, AbstractVector{T}}`: The values to predict. The first dimension - is the output feature to predict with each equation, and the - second dimension is rows. -- `niterations::Int=10`: The number of iterations to perform the search. - More iterations will improve the results. -- `weights::Union{AbstractMatrix{T}, AbstractVector{T}, Nothing}=nothing`: Optionally - weight the loss for each `y` by this value (same shape as `y`). -- `options::Options=Options()`: The options for the search, such as - which operators to use, evolution hyperparameters, etc. 
-- `variable_names::Union{Vector{String}, Nothing}=nothing`: The names - of each feature in `X`, which will be used during printing of equations. -- `display_variable_names::Union{Vector{String}, Nothing}=variable_names`: Names - to use when printing expressions during the search, but not when saving - to an equation file. -- `y_variable_names::Union{String,AbstractVector{String},Nothing}=nothing`: The - names of each output feature in `y`, which will be used during printing - of equations. -- `parallelism=:multithreading`: What parallelism mode to use. - The options are `:multithreading`, `:multiprocessing`, and `:serial`. - By default, multithreading will be used. Multithreading uses less memory, - but multiprocessing can handle multi-node compute. If using `:multithreading` - mode, the number of threads available to julia are used. If using - `:multiprocessing`, `numprocs` processes will be created dynamically if - `procs` is unset. If you have already allocated processes, pass them - to the `procs` argument and they will be used. - You may also pass a string instead of a symbol, like `"multithreading"`. -- `numprocs::Union{Int, Nothing}=nothing`: The number of processes to use, - if you want `equation_search` to set this up automatically. By default - this will be `4`, but can be any number (you should pick a number <= - the number of cores available). -- `procs::Union{Vector{Int}, Nothing}=nothing`: If you have set up - a distributed run manually with `procs = addprocs()` and `@everywhere`, - pass the `procs` to this keyword argument. -- `addprocs_function::Union{Function, Nothing}=nothing`: If using multiprocessing - (`parallelism=:multithreading`), and are not passing `procs` manually, - then they will be allocated dynamically using `addprocs`. However, - you may also pass a custom function to use instead of `addprocs`. - This function should take a single positional argument, - which is the number of processes to use, as well as the `lazy` keyword argument. - For example, if set up on a slurm cluster, you could pass - `addprocs_function = addprocs_slurm`, which will set up slurm processes. -- `heap_size_hint_in_bytes::Union{Int,Nothing}=nothing`: On Julia 1.9+, you may set the `--heap-size-hint` - flag on Julia processes, recommending garbage collection once a process - is close to the recommended size. This is important for long-running distributed - jobs where each process has an independent memory, and can help avoid - out-of-memory errors. By default, this is set to `Sys.free_memory() / numprocs`. -- `runtests::Bool=true`: Whether to run (quick) tests before starting the - search, to see if there will be any problems during the equation search - related to the host environment. -- `saved_state=nothing`: If you have already - run `equation_search` and want to resume it, pass the state here. - To get this to work, you need to have set return_state=true, - which will cause `equation_search` to return the state. The second - element of the state is the regular return value with the hall of fame. - Note that you cannot change the operators or dataset, but most other options - should be changeable. -- `return_state::Union{Bool, Nothing}=nothing`: Whether to return the - state of the search for warm starts. By default this is false. -- `loss_type::Type=Nothing`: If you would like to use a different type - for the loss than for the data you passed, specify the type here. - Note that if you pass complex data `::Complex{L}`, then the loss - type will automatically be set to `L`. 
-- `verbosity`: Whether to print debugging statements or not.
-- `progress`: Whether to use a progress bar output. Only available for
-  single target output.
-- `X_units::Union{AbstractVector,Nothing}=nothing`: The units of the dataset,
-  to be used for dimensional constraints. For example, if `X_units=["kg", "m"]`,
-  then the first feature will have units of kilograms, and the second will
-  have units of meters.
-- `y_units=nothing`: The units of the output, to be used for dimensional constraints.
-  If `y` is a matrix, then this can be a vector of units, in which case
-  each element corresponds to each output feature.
-
-# Returns
-- `hallOfFame::HallOfFame`: The best equations seen during the search.
-  hallOfFame.members gives an array of `PopMember` objects, which
-  have their tree (equation) stored in `.tree`. Their score (loss)
-  is given in `.score`. The array of `PopMember` objects
-  is enumerated by size from `1` to `options.maxsize`.
+@TODO: Modularize _main_search_loop! function so that I don't have to change the
+entire function to accommodate prompt evolution.
"""
-function equation_search(
-    X::AbstractMatrix{T},
-    y::AbstractMatrix{T};
-    niterations::Int=10,
-    weights::Union{AbstractMatrix{T},AbstractVector{T},Nothing}=nothing,
-    options::Options=Options(),
-    variable_names::Union{AbstractVector{String},Nothing}=nothing,
-    display_variable_names::Union{AbstractVector{String},Nothing}=variable_names,
-    y_variable_names::Union{String,AbstractVector{String},Nothing}=nothing,
-    parallelism=:multithreading,
-    numprocs::Union{Int,Nothing}=nothing,
-    procs::Union{Vector{Int},Nothing}=nothing,
-    addprocs_function::Union{Function,Nothing}=nothing,
-    heap_size_hint_in_bytes::Union{Integer,Nothing}=nothing,
-    runtests::Bool=true,
-    saved_state=nothing,
-    return_state::Union{Bool,Nothing,Val}=nothing,
-    loss_type::Type{L}=Nothing,
-    verbosity::Union{Integer,Nothing}=nothing,
-    progress::Union{Bool,Nothing}=nothing,
-    X_units::Union{AbstractVector,Nothing}=nothing,
-    y_units=nothing,
-    extra::NamedTuple=NamedTuple(),
-    v_dim_out::Val{DIM_OUT}=Val(nothing),
-    # Deprecated:
-    multithreaded=nothing,
-    varMap=nothing,
-) where {T<:DATA_TYPE,L,DIM_OUT}
-    if multithreaded !== nothing
-        error(
-            "`multithreaded` is deprecated. Use the `parallelism` argument instead. " *
-            "Choose one of :multithreaded, :multiprocessing, or :serial.",
-        )
-    end
-    variable_names = deprecate_varmap(variable_names, varMap, :equation_search)
-
-    if weights !== nothing
-        @assert length(weights) == length(y)
-        weights = reshape(weights, size(y))
-    end
-
-    datasets = construct_datasets(
-        X,
-        y,
-        weights,
-        variable_names,
-        display_variable_names,
-        y_variable_names,
-        X_units,
-        y_units,
-        extra,
-        L,
-    )
-
-    return equation_search(
-        datasets;
-        niterations=niterations,
-        options=options,
-        parallelism=parallelism,
-        numprocs=numprocs,
-        procs=procs,
-        addprocs_function=addprocs_function,
-        heap_size_hint_in_bytes=heap_size_hint_in_bytes,
-        runtests=runtests,
-        saved_state=saved_state,
-        return_state=return_state,
-        verbosity=verbosity,
-        progress=progress,
-        v_dim_out=Val(DIM_OUT),
-    )
-end
-
-function equation_search(
-    X::AbstractMatrix{T1}, y::AbstractMatrix{T2}; kw...
-) where {T1<:DATA_TYPE,T2<:DATA_TYPE}
-    U = promote_type(T1, T2)
-    return equation_search(
-        convert(AbstractMatrix{U}, X), convert(AbstractMatrix{U}, y); kw...
-    )
-end
-
-function equation_search(
-    X::AbstractMatrix{T1}, y::AbstractVector{T2}; kw...
-
-) where {T1<:DATA_TYPE,T2<:DATA_TYPE} - return equation_search(X, reshape(y, (1, size(y, 1))); kw..., v_dim_out=Val(1)) -end - -function equation_search(dataset::Dataset; kws...) - return equation_search([dataset]; kws..., v_dim_out=Val(1)) -end - -function equation_search( - datasets::Vector{D}; - niterations::Int=10, - options::Options=Options(), - parallelism=:multithreading, - numprocs::Union{Int,Nothing}=nothing, - procs::Union{Vector{Int},Nothing}=nothing, - addprocs_function::Union{Function,Nothing}=nothing, - heap_size_hint_in_bytes::Union{Integer,Nothing}=nothing, - runtests::Bool=true, - saved_state=nothing, - return_state::Union{Bool,Nothing,Val}=nothing, - verbosity::Union{Int,Nothing}=nothing, - progress::Union{Bool,Nothing}=nothing, - v_dim_out::Val{DIM_OUT}=Val(nothing), -) where {DIM_OUT,T<:DATA_TYPE,L<:LOSS_TYPE,D<:Dataset{T,L}} - concurrency = if parallelism in (:multithreading, "multithreading") - :multithreading - elseif parallelism in (:multiprocessing, "multiprocessing") - :multiprocessing - elseif parallelism in (:serial, "serial") - :serial - else - error( - "Invalid parallelism mode: $parallelism. " * - "You must choose one of :multithreading, :multiprocessing, or :serial.", - ) - :serial - end - not_distributed = concurrency in (:multithreading, :serial) - not_distributed && - procs !== nothing && - error( - "`procs` should not be set when using `parallelism=$(parallelism)`. Please use `:multiprocessing`.", - ) - not_distributed && - numprocs !== nothing && - error( - "`numprocs` should not be set when using `parallelism=$(parallelism)`. Please use `:multiprocessing`.", - ) - - _return_state = if return_state isa Val - first(typeof(return_state).parameters) - else - if options.return_state === Val(nothing) - return_state === nothing ? false : return_state - else - @assert( - return_state === nothing, - "You cannot set `return_state` in both the `Options` and in the passed arguments." - ) - first(typeof(options.return_state).parameters) - end - end - - dim_out = if DIM_OUT === nothing - length(datasets) > 1 ? 2 : 1 - else - DIM_OUT - end - _numprocs::Int = if numprocs === nothing - if procs === nothing - 4 - else - length(procs) - end - else - if procs === nothing - numprocs - else - @assert length(procs) == numprocs - numprocs - end - end - - _verbosity = if verbosity === nothing && options.verbosity === nothing - 1 - elseif verbosity === nothing && options.verbosity !== nothing - options.verbosity - elseif verbosity !== nothing && options.verbosity === nothing - verbosity - else - error( - "You cannot set `verbosity` in both the search parameters `Options` and the call to `equation_search`.", - ) - 1 - end - _progress::Bool = if progress === nothing && options.progress === nothing - (_verbosity > 0) && length(datasets) == 1 - elseif progress === nothing && options.progress !== nothing - options.progress - elseif progress !== nothing && options.progress === nothing - progress - else - error( - "You cannot set `progress` in both the search parameters `Options` and the call to `equation_search`.", - ) - false - end - - _addprocs_function = addprocs_function === nothing ? 
addprocs : addprocs_function - - exeflags = if VERSION >= v"1.9" && concurrency == :multiprocessing - heap_size_hint_in_megabytes = floor( - Int, ( - if heap_size_hint_in_bytes === nothing - (Sys.free_memory() / _numprocs) - else - heap_size_hint_in_bytes - end - ) / 1024^2 - ) - _verbosity > 0 && - heap_size_hint_in_bytes === nothing && - @info "Automatically setting `--heap-size-hint=$(heap_size_hint_in_megabytes)M` on each Julia process. You can configure this with the `heap_size_hint_in_bytes` parameter." - - `--heap-size=$(heap_size_hint_in_megabytes)M` - else - `` - end - - # Underscores here mean that we have mutated the variable - return _equation_search( - datasets, - RuntimeOptions(; - niterations=niterations, - total_cycles=options.populations * niterations, - numprocs=_numprocs, - init_procs=procs, - addprocs_function=_addprocs_function, - exeflags=exeflags, - runtests=runtests, - verbosity=_verbosity, - progress=_progress, - parallelism=Val(concurrency), - dim_out=Val(dim_out), - return_state=Val(_return_state), - ), - options, - saved_state, - ) -end - -@noinline function _equation_search( - datasets::Vector{D}, ropt::RuntimeOptions, options::Options, saved_state -) where {D<:Dataset} - # PROMPT EVOLUTION - idea_database_all = [Vector{String}() for j in 1:length(datasets)] - - _validate_options(datasets, ropt, options) - state = _create_workers(datasets, ropt, options) - _initialize_search!(state, datasets, ropt, options, saved_state, idea_database_all) - _warmup_search!(state, datasets, ropt, options, idea_database_all) - _main_search_loop!(state, datasets, ropt, options, idea_database_all) - _tear_down!(state, ropt, options) - return _format_output(state, datasets, ropt, options) -end - -function _validate_options( - datasets::Vector{D}, ropt::RuntimeOptions, options::Options -) where {T,L,D<:Dataset{T,L}} - example_dataset = first(datasets) - nout = length(datasets) - @assert nout >= 1 - @assert (nout == 1 || ropt.dim_out == 2) - @assert options.populations >= 1 - if ropt.progress - @assert(nout == 1, "You cannot display a progress bar for multi-output searches.") - @assert(ropt.verbosity > 0, "You cannot display a progress bar with `verbosity=0`.") - end - if options.node_type <: GraphNode && ropt.verbosity > 0 - @warn "The `GraphNode` interface and mutation operators are experimental and will change in future versions." - end - if ropt.runtests - test_option_configuration(ropt.parallelism, datasets, options, ropt.verbosity) - test_dataset_configuration(example_dataset, options, ropt.verbosity) - end - for dataset in datasets - update_baseline_loss!(dataset, options) - end - if options.define_helper_functions - set_default_variable_names!(first(datasets).variable_names) - end - if options.seed !== nothing - seed!(options.seed) - end - return nothing -end -@stable default_mode = "disable" function _create_workers( - datasets::Vector{D}, ropt::RuntimeOptions, options::Options -) where {T,L,D<:Dataset{T,L}} - stdin_reader = watch_stream(stdin) - - record = RecordType() - @recorder record["options"] = "$(options)" - - nout = length(datasets) - example_dataset = first(datasets) - example_ex = create_expression(zero(T), options, example_dataset) - NT = typeof(example_ex) - PopType = Population{T,L,NT} - HallOfFameType = HallOfFame{T,L,NT} - WorkerOutputType = get_worker_output_type( - Val(ropt.parallelism), PopType, HallOfFameType - ) - ChannelType = ropt.parallelism == :multiprocessing ? 
RemoteChannel : Channel - - # Pointers to populations on each worker: - worker_output = Vector{WorkerOutputType}[WorkerOutputType[] for j in 1:nout] - # Initialize storage for workers - tasks = [Task[] for j in 1:nout] - # Set up a channel to send finished populations back to head node - channels = [[ChannelType(1) for i in 1:(options.populations)] for j in 1:nout] - (procs, we_created_procs) = if ropt.parallelism == :multiprocessing - configure_workers(; - procs=ropt.init_procs, - ropt.numprocs, - ropt.addprocs_function, - options, - project_path=splitdir(Pkg.project().path)[1], - file=@__FILE__, - ropt.exeflags, - ropt.verbosity, - example_dataset, - ropt.runtests, - ) - else - Int[], false - end - # Get the next worker process to give a job: - worker_assignment = WorkerAssignments() - # Randomly order which order to check populations: - # This is done so that we do work on all nout equally. - task_order = [(j, i) for j in 1:nout for i in 1:(options.populations)] - shuffle!(task_order) - - # Persistent storage of last-saved population for final return: - last_pops = init_dummy_pops(options.populations, datasets, options) - # Best 10 members from each population for migration: - best_sub_pops = init_dummy_pops(options.populations, datasets, options) - # TODO: Should really be one per population too. - all_running_search_statistics = [ - RunningSearchStatistics(; options=options) for j in 1:nout - ] - # Records the number of evaluations: - # Real numbers indicate use of batching. - num_evals = [[0.0 for i in 1:(options.populations)] for j in 1:nout] - - halls_of_fame = Vector{HallOfFameType}(undef, nout) - - cycles_remaining = [ropt.total_cycles for j in 1:nout] - cur_maxsizes = [ - get_cur_maxsize(; options, ropt.total_cycles, cycles_remaining=cycles_remaining[j]) - for j in 1:nout - ] - - return SearchState{T,L,typeof(example_ex),WorkerOutputType,ChannelType}(; - procs=procs, - we_created_procs=we_created_procs, - worker_output=worker_output, - tasks=tasks, - channels=channels, - worker_assignment=worker_assignment, - task_order=task_order, - halls_of_fame=halls_of_fame, - last_pops=last_pops, - best_sub_pops=best_sub_pops, - all_running_search_statistics=all_running_search_statistics, - num_evals=num_evals, - cycles_remaining=cycles_remaining, - cur_maxsizes=cur_maxsizes, - stdin_reader=stdin_reader, - record=Ref(record), - ) -end -function _initialize_search!( - state::SearchState{T,L,N}, - datasets, - ropt::RuntimeOptions, - options::Options, - saved_state, - idea_database_all::Vector{Vector{String}}, -) where {T,L,N} - nout = length(datasets) - - init_hall_of_fame = load_saved_hall_of_fame(saved_state) - if init_hall_of_fame === nothing - for j in 1:nout - state.halls_of_fame[j] = HallOfFame(options, datasets[j]) - end - else - # Recompute losses for the hall of fame, in - # case the dataset changed: - for j in eachindex(init_hall_of_fame, datasets, state.halls_of_fame) - hof = strip_metadata(init_hall_of_fame[j], options, datasets[j]) - for member in hof.members[hof.exists] - score, result_loss = score_func(datasets[j], member, options) - member.score = score - member.loss = result_loss - end - state.halls_of_fame[j] = hof - end - end - - for j in 1:nout, i in 1:(options.populations) - worker_idx = assign_next_worker!( - state.worker_assignment; out=j, pop=i, parallelism=ropt.parallelism, state.procs - ) - saved_pop = load_saved_population(saved_state; out=j, pop=i) - new_pop = - if saved_pop !== nothing && length(saved_pop.members) == options.population_size - _saved_pop = 
strip_metadata(saved_pop, options, datasets[j]) - ## Update losses: - for member in _saved_pop.members - score, result_loss = score_func(datasets[j], member, options) - member.score = score - member.loss = result_loss - end - copy_pop = copy(_saved_pop) - @sr_spawner( - begin - (copy_pop, HallOfFame(options, datasets[j]), RecordType(), 0.0) - end, - parallelism = ropt.parallelism, - worker_idx = worker_idx - ) - else - if saved_pop !== nothing && ropt.verbosity > 0 - @warn "Recreating population (output=$(j), population=$(i)), as the saved one doesn't have the correct number of members." - end - @sr_spawner( - begin - ( - Population( - datasets[j]; - population_size=options.population_size, - nlength=3, - options=options, - nfeatures=datasets[j].nfeatures, - idea_database=idea_database_all[j], - ), - HallOfFame(options, datasets[j]), - RecordType(), - Float64(options.population_size), - ) - end, - parallelism = ropt.parallelism, - worker_idx = worker_idx - ) - # This involves population_size evaluations, on the full dataset: - end - push!(state.worker_output[j], new_pop) - end - return nothing -end -function _warmup_search!( - state::SearchState{T,L,N}, - datasets, - ropt::RuntimeOptions, - options::Options, - idea_database_all, -) where {T,L,N} - nout = length(datasets) - for j in 1:nout, i in 1:(options.populations) - dataset = datasets[j] - running_search_statistics = state.all_running_search_statistics[j] - cur_maxsize = state.cur_maxsizes[j] - @recorder state.record[]["out$(j)_pop$(i)"] = RecordType() - worker_idx = assign_next_worker!( - state.worker_assignment; out=j, pop=i, parallelism=ropt.parallelism, state.procs - ) - - # TODO - why is this needed?? - # Multi-threaded doesn't like to fetch within a new task: - c_rss = deepcopy(running_search_statistics) - last_pop = state.worker_output[j][i] - updated_pop = @sr_spawner( - begin - in_pop = first( - extract_from_worker(last_pop, Population{T,L,N}, HallOfFame{T,L,N}) - ) - _dispatch_s_r_cycle( - in_pop, - dataset, - options; - pop=i, - out=j, - iteration=0, - ropt.verbosity, - cur_maxsize, - running_search_statistics=c_rss, - idea_database=idea_database_all[j], - )::DefaultWorkerOutputType{Population{T,L,N},HallOfFame{T,L,N}} - end, - parallelism = ropt.parallelism, - worker_idx = worker_idx - ) - state.worker_output[j][i] = updated_pop - end - return nothing -end function _main_search_loop!( - state::SearchState{T,L,N}, + state::AbstractSearchState{T,L,N}, datasets, - ropt::RuntimeOptions, - options::Options, - idea_database_all, + ropt::AbstractRuntimeOptions, + options::LaSROptions, ) where {T,L,N} ropt.verbosity > 0 && @info "Started!" nout = length(datasets) start_time = time() - if ropt.progress + progress_bar = if ropt.progress #TODO: need to iterate this on the max cycles remaining! 
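+        # Progress is counted in population-cycle units; the total is
+        # `ropt.niterations * options.populations` cycles.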
sum_cycle_remaining = sum(state.cycles_remaining)
-        progress_bar = WrappedProgressBar(
-            1:sum_cycle_remaining; width=options.terminal_width
+        WrappedProgressBar(
+            sum_cycle_remaining, ropt.niterations; barlen=options.terminal_width
         )
+    else
+        nothing
     end
+
     last_print_time = time()
     last_speed_recording_time = time()
     num_evals_last = sum(sum, state.num_evals)
@@ -933,7 +144,6 @@ function _main_search_loop!(
             end
             # nout, populations:
             j, i = state.task_order[kappa]
-            idea_database = idea_database_all[j]

             # Check if error on population:
             if ropt.parallelism in (:multiprocessing, :multithreading)
@@ -977,17 +187,10 @@
                 dataset = datasets[j]
                 cur_maxsize = state.cur_maxsizes[j]

-                worst_member = nothing
                 for member in cur_pop.members
-                    if worst_member == nothing || worst_member.loss < member.loss
-                        worst_member = member
-                    end
                     size = compute_complexity(member, options)
                     update_frequencies!(state.all_running_search_statistics[j]; size)
                 end
-                if worst_member != nothing && worst_member.loss > 100 # if the worst of population is good then thats still good to keep
-                    push!(worst_members, worst_member)
-                end

                 #! format: off
                 update_hall_of_fame!(state.halls_of_fame[j], cur_pop.members, options)
                 update_hall_of_fame!(state.halls_of_fame[j], best_seen.members[best_seen.exists], options)
@@ -995,14 +198,30 @@

                 # Dominating pareto curve - must be better than all simpler equations
                 dominating = calculate_pareto_frontier(state.halls_of_fame[j])
-                if options.llm_options.active &&
-                    options.llm_options.prompt_evol &&
+
+                worst_member = nothing
+                for member in cur_pop.members
+                    if worst_member === nothing || member.loss > worst_member.loss
+                        worst_member = member
+                    end
+                end
+
+                if worst_member !== nothing && worst_member.loss > dominating[end].loss
+                    push!(worst_members, worst_member)
+                end
+
+                if options.use_llm &&
+                    options.use_concept_evolution &&
                     (n_iterations % options.populations == 0)
-                    update_idea_database(idea_database, dominating, worst_members, options)
+                    update_idea_database(dominating, worst_members, options)
+                    state.idea_database = options.idea_database
                 end

                 if options.save_to_file
-                    save_to_file(dominating, nout, j, dataset, options)
+                    save_to_file(dominating, nout, j, dataset, options, ropt)
                 end
                 ###################################################################
                 # Migration #######################################################
@@ -1051,8 +270,6 @@
                             ropt.verbosity,
                             cur_maxsize,
                             running_search_statistics=c_rss,
-                            dominating=dominating,
-                            idea_database=idea_database,
                         )
                     end,
                     parallelism = ropt.parallelism,
@@ -1064,11 +281,12 @@
                     )
                 end

+                total_cycles = ropt.niterations * options.populations
                 state.cur_maxsizes[j] = get_cur_maxsize(;
-                    options, ropt.total_cycles, cycles_remaining=state.cycles_remaining[j]
+                    options, total_cycles, cycles_remaining=state.cycles_remaining[j]
                 )
                 move_window!(state.all_running_search_statistics[j])
-                if ropt.progress
+                if !isnothing(progress_bar)
                     head_node_occupation = estimate_work_fraction(resource_monitor)
                     update_progress_bar!(
                         progress_bar,
@@ -1080,6 +298,9 @@
                         ropt.parallelism,
                     )
                 end
+                if ropt.logger !== nothing
+                    logging_callback!(ropt.logger; state, datasets, ropt, options)
+                end
             end
             yield()
@@ -1109,12 +330,13 @@

     # Dominating pareto curve - must be better than all simpler equations
     head_node_occupation = estimate_work_fraction(resource_monitor)
+    total_cycles =
ropt.niterations * options.populations print_search_state( state.halls_of_fame, datasets; options, equation_speed, - ropt.total_cycles, + total_cycles, state.cycles_remaining, head_node_occupation, parallelism=ropt.parallelism, @@ -1137,91 +359,11 @@ function _main_search_loop!( end ################################################################ end - llm_recorder( - options.llm_options, string(div(n_iterations, options.populations)), "n_iterations" - ) - return nothing -end -function _tear_down!(state::SearchState, ropt::RuntimeOptions, options::Options) - close_reader!(state.stdin_reader) - # Safely close all processes or threads - if ropt.parallelism == :multiprocessing - state.we_created_procs && rmprocs(state.procs) - elseif ropt.parallelism == :multithreading - nout = length(state.worker_output) - for j in 1:nout, i in eachindex(state.worker_output[j]) - wait(state.worker_output[j][i]) - end + if !isnothing(progress_bar) + finish!(progress_bar) end - @recorder json3_write(state.record[], options.recorder_file) return nothing end -function _format_output( - state::SearchState, datasets, ropt::RuntimeOptions, options::Options -) - nout = length(datasets) - out_hof = if ropt.dim_out == 1 - embed_metadata(only(state.halls_of_fame), options, only(datasets)) - else - map(j -> embed_metadata(state.halls_of_fame[j], options, datasets[j]), 1:nout) - end - if ropt.return_state - return ( - map(j -> embed_metadata(state.last_pops[j], options, datasets[j]), 1:nout), - out_hof, - ) - else - return out_hof - end -end - -@stable default_mode = "disable" function _dispatch_s_r_cycle( - in_pop::Population{T,L,N}, - dataset::Dataset, - options::Options; - pop::Int, - out::Int, - iteration::Int, - verbosity, - cur_maxsize::Int, - running_search_statistics, - dominating=nothing, - idea_database=nothing, -) where {T,L,N} - record = RecordType() - @recorder record["out$(out)_pop$(pop)"] = RecordType( - "iteration$(iteration)" => record_population(in_pop, options) - ) - num_evals = 0.0 - normalize_frequencies!(running_search_statistics) - out_pop, best_seen, evals_from_cycle = s_r_cycle( - dataset, - in_pop, - options.ncycles_per_iteration, - cur_maxsize, - running_search_statistics; - verbosity=verbosity, - options=options, - record=record, - dominating=dominating, - idea_database=idea_database, - ) - num_evals += evals_from_cycle - out_pop, evals_from_optimize = optimize_and_simplify_population( - dataset, out_pop, options, cur_maxsize, record - ) - num_evals += evals_from_optimize - if options.batching - for i_member in 1:(options.maxsize + MAX_DEGREE) - score, result_loss = score_func(dataset, best_seen.members[i_member], options) - best_seen.members[i_member].score = score - best_seen.members[i_member].loss = result_loss - num_evals += 1 - end - end - return (out_pop, best_seen, record, num_evals) -end - include("MLJInterface.jl") using .MLJInterfaceModule: LaSRRegressor, MultitargetLaSRRegressor diff --git a/src/LossFunctions.jl b/src/LossFunctions.jl deleted file mode 100644 index a84218879..000000000 --- a/src/LossFunctions.jl +++ /dev/null @@ -1,240 +0,0 @@ -module LossFunctionsModule - -using StatsBase: StatsBase -using DynamicExpressions: - AbstractExpression, AbstractExpressionNode, get_tree, eval_tree_array -using LossFunctions: LossFunctions -using LossFunctions: SupervisedLoss -using ..InterfaceDynamicExpressionsModule: expected_array_type -using ..CoreModule: Options, Dataset, create_expression, DATA_TYPE, LOSS_TYPE -using ..ComplexityModule: compute_complexity -using 
..DimensionalAnalysisModule: violates_dimensional_constraints - -function _loss( - x::AbstractArray{T}, y::AbstractArray{T}, loss::LT -) where {T<:DATA_TYPE,LT<:Union{Function,SupervisedLoss}} - if loss isa SupervisedLoss - return LossFunctions.mean(loss, x, y) - else - l(i) = loss(x[i], y[i]) - return LossFunctions.mean(l, eachindex(x)) - end -end - -function _weighted_loss( - x::AbstractArray{T}, y::AbstractArray{T}, w::AbstractArray{T}, loss::LT -) where {T<:DATA_TYPE,LT<:Union{Function,SupervisedLoss}} - if loss isa SupervisedLoss - return sum(loss, x, y, w; normalize=true) - else - l(i) = loss(x[i], y[i], w[i]) - return sum(l, eachindex(x)) / sum(w) - end -end - -"""If any of the indices are `nothing`, just return.""" -@inline function maybe_getindex(v, i...) - if any(==(nothing), i) - return v - else - return getindex(v, i...) - end -end - -function eval_tree_dispatch( - tree::Union{AbstractExpression{T},AbstractExpressionNode{T}}, - dataset::Dataset{T}, - options::Options, - idx, -) where {T<:DATA_TYPE} - A = expected_array_type(dataset.X) - return eval_tree_array(tree, maybe_getindex(dataset.X, :, idx), options)::Tuple{A,Bool} -end - -# Evaluate the loss of a particular expression on the input dataset. -function _eval_loss( - tree::Union{AbstractExpression{T},AbstractExpressionNode{T}}, - dataset::Dataset{T,L}, - options::Options, - regularization::Bool, - idx, -)::L where {T<:DATA_TYPE,L<:LOSS_TYPE} - (prediction, completion) = eval_tree_dispatch(tree, dataset, options, idx) - if !completion - return L(Inf) - end - - loss_val = if dataset.weighted - _weighted_loss( - prediction, - maybe_getindex(dataset.y, idx), - maybe_getindex(dataset.weights, idx), - options.elementwise_loss, - ) - else - _loss(prediction, maybe_getindex(dataset.y, idx), options.elementwise_loss) - end - - if regularization - loss_val += dimensional_regularization(tree, dataset, options) - end - - return loss_val -end - -# This evaluates function F: -function evaluator( - f::F, tree::AbstractExpressionNode{T}, dataset::Dataset{T,L}, options::Options, idx -)::L where {T<:DATA_TYPE,L<:LOSS_TYPE,F} - if hasmethod(f, typeof((tree, dataset, options, idx))) - # If user defines method that accepts batching indices: - return f(tree, dataset, options, idx) - elseif options.batching - error( - "User-defined loss function must accept batching indices if `options.batching == true`. " * - "For example, `f(tree, dataset, options, idx)`, where `idx` " * - "is `nothing` if full dataset is to be used, " * - "and a vector of indices otherwise.", - ) - else - return f(tree, dataset, options) - end -end - -# Evaluate the loss of a particular expression on the input dataset. -function eval_loss( - tree::Union{AbstractExpression{T},AbstractExpressionNode{T}}, - dataset::Dataset{T,L}, - options::Options; - regularization::Bool=true, - idx=nothing, -)::L where {T<:DATA_TYPE,L<:LOSS_TYPE} - loss_val = if options.loss_function === nothing - _eval_loss(tree, dataset, options, regularization, idx) - else - f = options.loss_function::Function - evaluator(f, get_tree(tree), dataset, options, idx) - end - - return loss_val -end - -function eval_loss_batched( - tree::Union{AbstractExpression{T},AbstractExpressionNode{T}}, - dataset::Dataset{T,L}, - options::Options; - regularization::Bool=true, - idx=nothing, -)::L where {T<:DATA_TYPE,L<:LOSS_TYPE} - _idx = idx === nothing ? 
batch_sample(dataset, options) : idx - return eval_loss(tree, dataset, options; regularization=regularization, idx=_idx) -end - -function batch_sample(dataset, options) - return StatsBase.sample(1:(dataset.n), options.batch_size; replace=true)::Vector{Int} -end - -# Just so we can pass either PopMember or Node here: -get_tree_from_member(t::Union{AbstractExpression,AbstractExpressionNode}) = t -get_tree_from_member(m) = m.tree -# Beware: this is a circular dependency situation... -# PopMember is using losses, but then we also want -# losses to use the PopMember's cached complexity for trees. -# TODO! - -# Compute a score which includes a complexity penalty in the loss -function loss_to_score( - loss::L, - use_baseline::Bool, - baseline::L, - member, - options::Options, - complexity::Union{Int,Nothing}=nothing, -)::L where {L<:LOSS_TYPE} - # TODO: Come up with a more general normalization scheme. - normalization = if baseline >= L(0.01) && use_baseline - baseline - else - L(0.01) - end - loss_val = loss / normalization - size = complexity === nothing ? compute_complexity(member, options) : complexity - parsimony_term = size * options.parsimony - loss_val += L(parsimony_term) - - return loss_val -end - -# Score an equation -function score_func( - dataset::Dataset{T,L}, member, options::Options; complexity::Union{Int,Nothing}=nothing -)::Tuple{L,L} where {T<:DATA_TYPE,L<:LOSS_TYPE} - result_loss = eval_loss(get_tree_from_member(member), dataset, options) - score = loss_to_score( - result_loss, - dataset.use_baseline, - dataset.baseline_loss, - member, - options, - complexity, - ) - return score, result_loss -end - -# Score an equation with a small batch -function score_func_batched( - dataset::Dataset{T,L}, - member, - options::Options; - complexity::Union{Int,Nothing}=nothing, - idx=nothing, -)::Tuple{L,L} where {T<:DATA_TYPE,L<:LOSS_TYPE} - result_loss = eval_loss_batched(get_tree_from_member(member), dataset, options; idx=idx) - score = loss_to_score( - result_loss, - dataset.use_baseline, - dataset.baseline_loss, - member, - options, - complexity, - ) - return score, result_loss -end - -""" - update_baseline_loss!(dataset::Dataset{T,L}, options::Options) where {T<:DATA_TYPE,L<:LOSS_TYPE} - -Update the baseline loss of the dataset using the loss function specified in `options`. -""" -function update_baseline_loss!( - dataset::Dataset{T,L}, options::Options -) where {T<:DATA_TYPE,L<:LOSS_TYPE} - example_tree = create_expression(zero(T), options, dataset) - # constructorof(options.node_type)(T; val=dataset.avg_y) - # TODO: It could be that the loss function is not defined for this example type? 
- baseline_loss = eval_loss(example_tree, dataset, options) - if isfinite(baseline_loss) - dataset.baseline_loss = baseline_loss - dataset.use_baseline = true - else - dataset.baseline_loss = one(L) - dataset.use_baseline = false - end - return nothing -end - -function dimensional_regularization( - tree::Union{AbstractExpression{T},AbstractExpressionNode{T}}, - dataset::Dataset{T,L}, - options::Options, -) where {T<:DATA_TYPE,L<:LOSS_TYPE} - if !violates_dimensional_constraints(tree, dataset, options) - return zero(L) - elseif options.dimensional_constraint_penalty === nothing - return L(1000) - else - return L(options.dimensional_constraint_penalty::Float32) - end -end - -end diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index c86c3da40..4a782183d 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -1,7 +1,7 @@ module MLJInterfaceModule - using Optim: Optim using LineSearches: LineSearches +using Logging: AbstractLogger using MLJModelInterface: MLJModelInterface as MMI using ADTypes: AbstractADType using DynamicExpressions: @@ -23,834 +23,26 @@ using DynamicQuantities: ustrip, dimension using LossFunctions: SupervisedLoss -using Compat: allequal, stack -using ..InterfaceDynamicQuantitiesModule: get_dimensions_type -using ..CoreModule: Options, Dataset, MutationWeights, LLMOptions, LOSS_TYPE -using ..CoreModule.OptionsModule: DEFAULT_OPTIONS, OPTION_DESCRIPTIONS -using ..ComplexityModule: compute_complexity -using ..HallOfFameModule: HallOfFame, format_hall_of_fame -using ..UtilsModule: subscriptify, @ignore - -import ..equation_search - -abstract type AbstractSRRegressor <: MMI.Deterministic end +using SymbolicRegression.InterfaceDynamicQuantitiesModule: get_dimensions_type +using SymbolicRegression.CoreModule: + Options, Dataset, AbstractMutationWeights, MutationWeights, LOSS_TYPE, ComplexityMapping +using SymbolicRegression.CoreModule.OptionsModule: DEFAULT_OPTIONS, OPTION_DESCRIPTIONS +using SymbolicRegression.ComplexityModule: compute_complexity +using SymbolicRegression.HallOfFameModule: HallOfFame, format_hall_of_fame +using SymbolicRegression.UtilsModule: subscriptify, @ignore +using SymbolicRegression.LoggingModule: AbstractSRLogger +using SymbolicRegression.MLJInterfaceModule: modelexpr, AbstractSRRegressor # For static analysis tools: @ignore mutable struct LaSRRegressor <: AbstractSRRegressor selection_method::Function end + @ignore mutable struct MultitargetLaSRRegressor <: AbstractSRRegressor selection_method::Function end -# TODO: To reduce code re-use, we could forward these defaults from -# `equation_search`, similar to what we do for `Options`. 
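The deleted `LossFunctionsModule` above enforced a contract on custom objectives: when `options.batching == true`, a user-supplied `loss_function` must accept a fourth `idx` argument, which is `nothing` when the full dataset should be used and a vector of row indices otherwise (this is exactly what the `evaluator` error message spells out). A minimal sketch of a conforming objective, with illustrative names and a plain mean-squared-error body, might look like:

```julia
using SymbolicRegression: eval_tree_array

# Sketch of a batching-aware custom objective (illustrative, not this package's API).
# `idx` is `nothing` for the full dataset, or a vector of row indices for a batch.
function my_batched_loss(tree, dataset, options, idx)
    X = idx === nothing ? dataset.X : dataset.X[:, idx]
    y = idx === nothing ? dataset.y : dataset.y[idx]
    prediction, completed = eval_tree_array(tree, X, options)
    completed || return convert(eltype(y), Inf)  # invalid expression: worst possible loss
    return sum(abs2, prediction .- y) / length(y)
end
```

Defining the four-argument method is what `evaluator` probes with `hasmethod`, so the same function serves both full-batch and mini-batch evaluation.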
- -"""Generate an `LaSRRegressor` struct containing all the fields in `Options`.""" -function modelexpr(model_name::Symbol) - struct_def = :(Base.@kwdef mutable struct $(model_name){D<:AbstractDimensions,L} <: - AbstractSRRegressor - niterations::Int = 10 - parallelism::Symbol = :multithreading - numprocs::Union{Int,Nothing} = nothing - procs::Union{Vector{Int},Nothing} = nothing - addprocs_function::Union{Function,Nothing} = nothing - heap_size_hint_in_bytes::Union{Integer,Nothing} = nothing - runtests::Bool = true - loss_type::L = Nothing - selection_method::Function = choose_best - dimensions_type::Type{D} = SymbolicDimensions{DEFAULT_DIM_BASE_TYPE} - end) - # TODO: store `procs` from initial run if parallelism is `:multiprocessing` - fields = last(last(struct_def.args).args).args - - # Add everything from `Options` constructor directly to struct: - for (i, option) in enumerate(DEFAULT_OPTIONS) - insert!(fields, i, Expr(:(=), option.args...)) - end - - # We also need to create the `get_options` function, based on this: - constructor = :(Options(;)) - constructor_fields = last(constructor.args).args - for option in DEFAULT_OPTIONS - symb = getsymb(first(option.args)) - push!(constructor_fields, Expr(:kw, symb, Expr(:(.), :m, Core.QuoteNode(symb)))) - end - - return quote - $struct_def - function get_options(m::$(model_name)) - return $constructor - end - end -end -function getsymb(ex::Symbol) - return ex -end -function getsymb(ex::Expr) - for arg in ex.args - isa(arg, Symbol) && return arg - s = getsymb(arg) - isa(s, Symbol) && return s - end - return nothing -end - -"""Get an equivalent `Options()` object for a particular regressor.""" -function get_options(::AbstractSRRegressor) end - eval(modelexpr(:LaSRRegressor)) eval(modelexpr(:MultitargetLaSRRegressor)) -# Cleaning already taken care of by `Options` and `equation_search` -function full_report( - m::AbstractSRRegressor, fitresult; v_with_strings::Val{with_strings}=Val(true) -) where {with_strings} - _, hof = fitresult.state - # TODO: Adjust baseline loss - formatted = format_hall_of_fame(hof, fitresult.options) - equation_strings = if with_strings - get_equation_strings_for( - m, formatted.trees, fitresult.options, fitresult.variable_names - ) - else - nothing - end - best_idx = dispatch_selection_for( - m, formatted.trees, formatted.losses, formatted.scores, formatted.complexities - ) - return (; - best_idx=best_idx, - equations=formatted.trees, - equation_strings=equation_strings, - losses=formatted.losses, - complexities=formatted.complexities, - scores=formatted.scores, - ) -end - -MMI.clean!(::AbstractSRRegressor) = "" - -# TODO: Enable `verbosity` being passed to `equation_search` -function MMI.fit(m::AbstractSRRegressor, verbosity, X, y, w=nothing) - return MMI.update(m, verbosity, nothing, nothing, X, y, w) -end -function MMI.update( - m::AbstractSRRegressor, verbosity, old_fitresult, old_cache, X, y, w=nothing -) - options = old_fitresult === nothing ? 
get_options(m) : old_fitresult.options - return _update(m, verbosity, old_fitresult, old_cache, X, y, w, options, nothing) -end -function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options, classes) - if isnothing(classes) && MMI.istable(X) && haskey(X, :classes) - if !(X isa NamedTuple) - error("Classes can only be specified with named tuples.") - end - new_X = Base.structdiff(X, (; X.classes)) - new_classes = X.classes - return _update( - m, verbosity, old_fitresult, old_cache, new_X, y, w, options, new_classes - ) - end - if !isnothing(old_fitresult) - @assert( - old_fitresult.has_classes == !isnothing(classes), - "If the first fit used classes, the second fit must also use classes." - ) - end - # To speed up iterative fits, we cache the types: - types = if isnothing(old_fitresult) - (; - T=Any, - X_t=Any, - y_t=Any, - w_t=Any, - state=Any, - X_units=Any, - y_units=Any, - X_units_clean=Any, - y_units_clean=Any, - ) - else - old_fitresult.types - end - X_t::types.X_t, variable_names, X_units::types.X_units = get_matrix_and_info( - X, m.dimensions_type - ) - y_t::types.y_t, y_variable_names, y_units::types.y_units = format_input_for( - m, y, m.dimensions_type - ) - X_units_clean::types.X_units_clean = clean_units(X_units) - y_units_clean::types.y_units_clean = clean_units(y_units) - w_t::types.w_t = if w !== nothing && isa(m, MultitargetLaSRRegressor) - @assert(isa(w, AbstractVector) && ndims(w) == 1, "Unexpected input for `w`.") - repeat(w', size(y_t, 1)) - else - w - end - search_state::types.state = equation_search( - X_t, - y_t; - niterations=m.niterations, - weights=w_t, - variable_names=variable_names, - options=options, - parallelism=m.parallelism, - numprocs=m.numprocs, - procs=m.procs, - addprocs_function=m.addprocs_function, - heap_size_hint_in_bytes=m.heap_size_hint_in_bytes, - runtests=m.runtests, - saved_state=(old_fitresult === nothing ? nothing : old_fitresult.state), - return_state=true, - loss_type=m.loss_type, - X_units=X_units_clean, - y_units=y_units_clean, - verbosity=verbosity, - extra=isnothing(classes) ? (;) : (; classes), - # Help out with inference: - v_dim_out=isa(m, LaSRRegressor) ? Val(1) : Val(2), - ) - fitresult = (; - state=search_state, - num_targets=isa(m, LaSRRegressor) ? 1 : size(y_t, 1), - options=options, - variable_names=variable_names, - y_variable_names=y_variable_names, - y_is_table=MMI.istable(y), - has_classes=!isnothing(classes), - X_units=X_units_clean, - y_units=y_units_clean, - types=( - T=hof_eltype(search_state[2]), - X_t=typeof(X_t), - y_t=typeof(y_t), - w_t=typeof(w_t), - state=typeof(search_state), - X_units=typeof(X_units), - y_units=typeof(y_units), - X_units_clean=typeof(X_units_clean), - y_units_clean=typeof(y_units_clean), - ), - )::(old_fitresult === nothing ? Any : typeof(old_fitresult)) - return (fitresult, nothing, full_report(m, fitresult)) -end -hof_eltype(::Type{H}) where {T,H<:HallOfFame{T}} = T -hof_eltype(::Type{V}) where {V<:Vector} = hof_eltype(eltype(V)) -hof_eltype(h) = hof_eltype(typeof(h)) - -function clean_units(units) - !isa(units, AbstractDimensions) && error("Unexpected units.") - iszero(units) && return nothing - return units -end -function clean_units(units::Vector) - !all(Base.Fix2(isa, AbstractDimensions), units) && error("Unexpected units.") - all(iszero, units) && return nothing - return units -end - -function get_matrix_and_info(X, ::Type{D}) where {D} - sch = MMI.istable(X) ? 
MMI.schema(X) : nothing - Xm_t = MMI.matrix(X; transpose=true) - colnames = if sch === nothing - [map(i -> "x$(subscriptify(i))", axes(Xm_t, 1))...] - else - [string.(sch.names)...] - end - D_promoted = get_dimensions_type(Xm_t, D) - Xm_t_strip, X_units = unwrap_units_single(Xm_t, D_promoted) - return Xm_t_strip, colnames, X_units -end - -function format_input_for(::LaSRRegressor, y, ::Type{D}) where {D} - @assert( - !(MMI.istable(y) || (length(size(y)) == 2 && size(y, 2) > 1)), - "For multi-output regression, please use `MultitargetLaSRRegressor`." - ) - y_t = vec(y) - colnames = nothing - D_promoted = get_dimensions_type(y_t, D) - y_t_strip, y_units = unwrap_units_single(y_t, D_promoted) - return y_t_strip, colnames, y_units -end -function format_input_for(::MultitargetLaSRRegressor, y, ::Type{D}) where {D} - @assert( - MMI.istable(y) || (length(size(y)) == 2 && size(y, 2) > 1), - "For single-output regression, please use `LaSRRegressor`." - ) - return get_matrix_and_info(y, D) -end -function validate_variable_names(variable_names, fitresult) - @assert( - variable_names == fitresult.variable_names, - "Variable names do not match fitted regressor." - ) - return nothing -end -function validate_units(X_units, old_X_units) - @assert( - all(X_units .== old_X_units), - "Units of new data do not match units of fitted regressor." - ) - return nothing -end - -# TODO: Test whether this conversion poses any issues in data normalization... -function dimension_with_fallback(q::UnionAbstractQuantity{T}, ::Type{D}) where {T,D} - return dimension(convert(Quantity{T,D}, q))::D -end -function dimension_with_fallback(_, ::Type{D}) where {D} - return D() -end -function prediction_warn() - @warn "Evaluation failed either due to NaNs detected or due to unfinished search. Using 0s for prediction." -end - -wrap_units(v, ::Nothing, ::Integer) = v -wrap_units(v, ::Nothing, ::Nothing) = v -wrap_units(v, y_units, i::Integer) = (yi -> Quantity(yi, y_units[i])).(v) -wrap_units(v, y_units, ::Nothing) = (yi -> Quantity(yi, y_units)).(v) - -function prediction_fallback(::Type{T}, ::LaSRRegressor, Xnew_t, fitresult, _) where {T} - prediction_warn() - out = fill!(similar(Xnew_t, T, axes(Xnew_t, 2)), zero(T)) - return wrap_units(out, fitresult.y_units, nothing) -end -function prediction_fallback( - ::Type{T}, ::MultitargetLaSRRegressor, Xnew_t, fitresult, prototype -) where {T} - prediction_warn() - out_cols = [ - wrap_units( - fill!(similar(Xnew_t, T, axes(Xnew_t, 2)), zero(T)), fitresult.y_units, i - ) for i in 1:(fitresult.num_targets) - ] - out_matrix = hcat(out_cols...) - if !fitresult.y_is_table - return out_matrix - else - return MMI.table(out_matrix; names=fitresult.y_variable_names, prototype) - end -end - -compat_ustrip(A::QuantityArray) = ustrip(A) -compat_ustrip(A) = ustrip.(A) - -""" - unwrap_units_single(::AbstractArray, ::Type{<:AbstractDimensions}) - -Remove units from some features in a matrix, and return, as a tuple, -(1) the matrix with stripped units, and (2) the dimensions for those features. 
-""" -function unwrap_units_single(A::AbstractMatrix, ::Type{D}) where {D} - dims = D[dimension_with_fallback(first(row), D) for row in eachrow(A)] - @inbounds for (i, row) in enumerate(eachrow(A)) - all(xi -> dimension_with_fallback(xi, D) == dims[i], row) || - error("Inconsistent units in feature $i of matrix.") - end - return stack(compat_ustrip, eachrow(A); dims=1)::AbstractMatrix, dims -end -function unwrap_units_single(v::AbstractVector, ::Type{D}) where {D} - dims = dimension_with_fallback(first(v), D) - all(xi -> dimension_with_fallback(xi, D) == dims, v) || - error("Inconsistent units in vector.") - return compat_ustrip(v)::AbstractVector, dims -end - -function MMI.fitted_params(m::AbstractSRRegressor, fitresult) - report = full_report(m, fitresult) - return (; - best_idx=report.best_idx, - equations=report.equations, - equation_strings=report.equation_strings, - ) -end - -function eval_tree_mlj( - tree::AbstractExpression, - X_t, - classes, - m::AbstractSRRegressor, - ::Type{T}, - fitresult, - i, - prototype, -) where {T} - out, completed = if isnothing(classes) - eval_tree_array(tree, X_t, fitresult.options) - else - eval_tree_array(tree, X_t, classes, fitresult.options) - end - if completed - return wrap_units(out, fitresult.y_units, i) - else - return prediction_fallback(T, m, X_t, fitresult, prototype) - end -end - -function MMI.predict( - m::M, fitresult, Xnew; idx=nothing, classes=nothing -) where {M<:AbstractSRRegressor} - return _predict(m, fitresult, Xnew, idx, classes) -end -function _predict(m::M, fitresult, Xnew, idx, classes) where {M<:AbstractSRRegressor} - if Xnew isa NamedTuple && (haskey(Xnew, :idx) || haskey(Xnew, :data)) - @assert( - haskey(Xnew, :idx) && haskey(Xnew, :data) && length(keys(Xnew)) == 2, - "If specifying an equation index during prediction, you must use a named tuple with keys `idx` and `data`." - ) - return _predict(m, fitresult, Xnew.data, Xnew.idx, classes) - end - if isnothing(classes) && MMI.istable(Xnew) && haskey(Xnew, :classes) - if !(Xnew isa NamedTuple) - error("Classes can only be specified with named tuples.") - end - Xnew2 = Base.structdiff(Xnew, (; Xnew.classes)) - return _predict(m, fitresult, Xnew2, idx, Xnew.classes) - end - - if fitresult.has_classes - @assert( - !isnothing(classes), - "Classes must be specified if the model was fit with classes." - ) - end - - params = full_report(m, fitresult; v_with_strings=Val(false)) - prototype = MMI.istable(Xnew) ? Xnew : nothing - Xnew_t, variable_names, X_units = get_matrix_and_info(Xnew, m.dimensions_type) - T = promote_type(eltype(Xnew_t), fitresult.types.T) - - if isempty(params.equations) || any(isempty, params.equations) - @warn "Equations not found. Returning 0s for prediction." - return prediction_fallback(T, m, Xnew_t, fitresult, prototype) - end - - X_units_clean = clean_units(X_units) - validate_variable_names(variable_names, fitresult) - validate_units(X_units_clean, fitresult.X_units) - - idx = idx === nothing ? 
params.best_idx : idx - - if M <: LaSRRegressor - return eval_tree_mlj( - params.equations[idx], Xnew_t, classes, m, T, fitresult, nothing, prototype - ) - elseif M <: MultitargetLaSRRegressor - outs = [ - eval_tree_mlj( - params.equations[i][idx[i]], Xnew_t, classes, m, T, fitresult, i, prototype - ) for i in eachindex(idx, params.equations) - ] - out_matrix = reduce(hcat, outs) - if !fitresult.y_is_table - return out_matrix - else - return MMI.table(out_matrix; names=fitresult.y_variable_names, prototype) - end - end -end - -function get_equation_strings_for(::LaSRRegressor, trees, options, variable_names) - return (t -> string_tree(t, options; variable_names=variable_names)).(trees) -end -function get_equation_strings_for( - ::MultitargetLaSRRegressor, trees, options, variable_names -) - return [ - (t -> string_tree(t, options; variable_names=variable_names)).(ts) for ts in trees - ] -end - -function choose_best(; trees, losses::Vector{L}, scores, complexities) where {L<:LOSS_TYPE} - # Same as in PySR: - # https://github.com/MilesCranmer/PySR/blob/e74b8ad46b163c799908b3aa4d851cf8457c79ef/pysr/sr.py#L2318-L2332 - # threshold = 1.5 * minimum_loss - # Then, we get max score of those below the threshold. - threshold = 1.5 * minimum(losses) - return argmax([ - (losses[i] <= threshold) ? scores[i] : typemin(L) for i in eachindex(losses) - ]) -end - -function dispatch_selection_for(m::LaSRRegressor, trees, losses, scores, complexities)::Int - length(trees) == 0 && return 0 - return m.selection_method(; - trees=trees, losses=losses, scores=scores, complexities=complexities - ) -end -function dispatch_selection_for( - m::MultitargetLaSRRegressor, trees, losses, scores, complexities -) - any(t -> length(t) == 0, trees) && return fill(0, length(trees)) - return [ - m.selection_method(; - trees=trees[i], losses=losses[i], scores=scores[i], complexities=complexities[i] - ) for i in eachindex(trees) - ] -end - -MMI.metadata_pkg( - AbstractSRRegressor; - name="LibraryAugmentedSymbolicRegression", - uuid="8254be44-1295-4e6a-a16d-46603ac705cb", - url="https://github.com/MilesCranmer/LibraryAugmentedSymbolicRegression.jl", - julia=true, - license="Apache-2.0", - is_wrapper=false, -) - -const input_scitype = Union{ - MMI.Table(MMI.Continuous), - AbstractMatrix{<:MMI.Continuous}, - MMI.Table(MMI.Continuous, MMI.Count), -} - -# TODO: Allow for Count data, and coerce it into Continuous as needed. -MMI.metadata_model( - LaSRRegressor; - input_scitype, - target_scitype=AbstractVector{<:MMI.Continuous}, - supports_weights=true, - reports_feature_importances=false, - load_path="LibraryAugmentedSymbolicRegression.MLJInterfaceModule.LaSRRegressor", - human_name="Symbolic Regression via Evolutionary Search", -) -MMI.metadata_model( - MultitargetLaSRRegressor; - input_scitype, - target_scitype=Union{MMI.Table(MMI.Continuous),AbstractMatrix{<:MMI.Continuous}}, - supports_weights=true, - reports_feature_importances=false, - load_path="LibraryAugmentedSymbolicRegression.MLJInterfaceModule.MultitargetLaSRRegressor", - human_name="Multi-Target Symbolic Regression via Evolutionary Search", -) - -function tag_with_docstring(model_name::Symbol, description::String, bottom_matter::String) - docstring = """$(MMI.doc_header(eval(model_name))) - - $(description) - - # Hyper-parameters - """ - - # TODO: These ones are copied (or written) manually: - append_arguments = """- `niterations::Int=10`: The number of iterations to perform the search. - More iterations will improve the results. 
- - `parallelism=:multithreading`: What parallelism mode to use. - The options are `:multithreading`, `:multiprocessing`, and `:serial`. - By default, multithreading will be used. Multithreading uses less memory, - but multiprocessing can handle multi-node compute. If using `:multithreading` - mode, the number of threads available to julia is used. If using - `:multiprocessing`, `numprocs` processes will be created dynamically if - `procs` is unset. If you have already allocated processes, pass them - to the `procs` argument and they will be used. - You may also pass a string instead of a symbol, like `"multithreading"`. - `numprocs::Union{Int, Nothing}=nothing`: The number of processes to use, - if you want `equation_search` to set this up automatically. By default - this will be `4`, but can be any number (you should pick a number <= - the number of cores available). - `procs::Union{Vector{Int}, Nothing}=nothing`: If you have set up - a distributed run manually with `procs = addprocs()` and `@everywhere`, - pass the `procs` to this keyword argument. - `addprocs_function::Union{Function, Nothing}=nothing`: If using multiprocessing - (`parallelism=:multiprocessing`), and you are not passing `procs` manually, - then they will be allocated dynamically using `addprocs`. However, - you may also pass a custom function to use instead of `addprocs`. - This function should take a single positional argument, - which is the number of processes to use, as well as the `lazy` keyword argument. - For example, if set up on a slurm cluster, you could pass - `addprocs_function = addprocs_slurm`, which will set up slurm processes. - `heap_size_hint_in_bytes::Union{Int,Nothing}=nothing`: On Julia 1.9+, you may set the `--heap-size-hint` - flag on Julia processes, recommending garbage collection once a process - is close to the recommended size. This is important for long-running distributed - jobs where each process has an independent memory, and can help avoid - out-of-memory errors. By default, this is set to `Sys.free_memory() / numprocs`. - `runtests::Bool=true`: Whether to run (quick) tests before starting the - search, to see if there will be any problems during the equation search - related to the host environment. - `loss_type::Type=Nothing`: If you would like to use a different type - for the loss than for the data you passed, specify the type here. - Note that if you pass complex data `::Complex{L}`, then the loss - type will automatically be set to `L`. - `selection_method::Function`: Function to select an expression from - the Pareto frontier for use in `predict`. - See `LibraryAugmentedSymbolicRegression.MLJInterfaceModule.choose_best` for an example. - This function should return a single integer specifying - the index of the expression to use. By default, this maximizes - the score (a pound-for-pound rating) of expressions reaching the threshold - of 1.5x the minimum loss. To override this at prediction time, you can pass - a named tuple with keys `data` and `idx` to `predict`. See the Operations - section for details. - `dimensions_type::AbstractDimensions`: The type of dimensions to use when storing - the units of the data. By default this is `DynamicQuantities.SymbolicDimensions`. - """ - - bottom = """ - # Operations - - - `predict(mach, Xnew)`: Return predictions of the target given features `Xnew`, which - should have same scitype as `X` above.
The expression used for prediction is defined - by the `selection_method` function, which can be seen by viewing `report(mach).best_idx`. - - `predict(mach, (data=Xnew, idx=i))`: Return predictions of the target given features - `Xnew`, which should have same scitype as `X` above. By passing a named tuple with keys - `data` and `idx`, you are able to specify the equation you wish to evaluate in `idx`. - - $(bottom_matter) - """ - - # Remove common indentation: - docstring = replace(docstring, r"^ " => "") - extra_arguments = replace(append_arguments, r"^ " => "") - bottom = replace(bottom, r"^ " => "") - - # Add parameter descriptions: - docstring = docstring * OPTION_DESCRIPTIONS - docstring = docstring * extra_arguments - docstring = docstring * bottom - return quote - @doc $docstring $model_name - end -end - -#https://arxiv.org/abs/2305.01582 -eval( - tag_with_docstring( - :LaSRRegressor, - replace( - """ - Single-target Symbolic Regression regressor (`LaSRRegressor`) searches - for symbolic expressions that predict a single target variable from - a set of input variables. All data is assumed to be `Continuous`. - The search is performed using an evolutionary algorithm. - This algorithm is described in the paper - https://arxiv.org/abs/2305.01582. - - # Training data - - In MLJ or MLJBase, bind an instance `model` to data with - - mach = machine(model, X, y) - - OR - - mach = machine(model, X, y, w) - - Here: - - - `X` is any table of input features (eg, a `DataFrame`) whose columns are of scitype - `Continuous`; check column scitypes with `schema(X)`. Variable names in discovered - expressions will be taken from the column names of `X`, if available. Units in columns - of `X` (use `DynamicQuantities` for units) will trigger dimensional analysis to be used. - - - `y` is the target, which can be any `AbstractVector` whose element scitype is - `Continuous`; check the scitype with `scitype(y)`. Units in `y` (use `DynamicQuantities` - for units) will trigger dimensional analysis to be used. - - - `w` is the observation weights which can either be `nothing` (default) or an - `AbstractVector` whose element scitype is `Count` or `Continuous`. - - Train the machine using `fit!(mach)`, inspect the discovered expressions with - `report(mach)`, and predict on new data with `predict(mach, Xnew)`. - Note that unlike other regressors, symbolic regression stores a list of - trained models. The model chosen from this list is defined by the function - `selection_method` keyword argument, which by default balances accuracy - and complexity. You can override this at prediction time by passing a named - tuple with keys `data` and `idx`. - - """, - r"^ " => "", - ), - replace( - """ - # Fitted parameters - - The fields of `fitted_params(mach)` are: - - - `best_idx::Int`: The index of the best expression in the Pareto frontier, - as determined by the `selection_method` function. Override in `predict` by passing - a named tuple with keys `data` and `idx`. - - `equations::Vector{Node{T}}`: The expressions discovered by the search, represented - in a dominating Pareto frontier (i.e., the best expressions found for - each complexity). `T` is equal to the element type - of the passed data. - - `equation_strings::Vector{String}`: The expressions discovered by the search, - represented as strings for easy inspection. - - # Report - - The fields of `report(mach)` are: - - - `best_idx::Int`: The index of the best expression in the Pareto frontier, - as determined by the `selection_method` function. 
Override in `predict` by passing - a named tuple with keys `data` and `idx`. - - `equations::Vector{Node{T}}`: The expressions discovered by the search, represented - in a dominating Pareto frontier (i.e., the best expressions found for - each complexity). - - `equation_strings::Vector{String}`: The expressions discovered by the search, - represented as strings for easy inspection. - - `complexities::Vector{Int}`: The complexity of each expression in the Pareto frontier. - - `losses::Vector{L}`: The loss of each expression in the Pareto frontier, according - to the loss function specified in the model. The type `L` is the loss type, which - is usually the same as the element type of data passed (i.e., `T`), but can differ - if complex data types are passed. - - `scores::Vector{L}`: A metric which considers both the complexity and loss of an expression, - equal to the change in the log-loss divided by the change in complexity, relative to - the previous expression along the Pareto frontier. A larger score aims to indicate - an expression is more likely to be the true expression generating the data, but - this is very problem-dependent and generally several other factors should be considered. - - # Examples - - ```julia - using MLJ - LaSRRegressor = @load LaSRRegressor pkg=LibraryAugmentedSymbolicRegression - X, y = @load_boston - model = LaSRRegressor(binary_operators=[+, -, *], unary_operators=[exp], niterations=100) - mach = machine(model, X, y) - fit!(mach) - y_hat = predict(mach, X) - # View the equation used: - r = report(mach) - println("Equation used:", r.equation_strings[r.best_idx]) - ``` - - With units and variable names: - - ```julia - using MLJ - using DynamicQuantities - SRegressor = @load LaSRRegressor pkg=LibraryAugmentedSymbolicRegression - - X = (; x1=rand(32) .* us"km/h", x2=rand(32) .* us"km") - y = @. X.x2 / X.x1 + 0.5us"h" - model = LaSRRegressor(binary_operators=[+, -, *, /]) - mach = machine(model, X, y) - fit!(mach) - y_hat = predict(mach, X) - # View the equation used: - r = report(mach) - println("Equation used:", r.equation_strings[r.best_idx]) - ``` - - See also [`MultitargetLaSRRegressor`](@ref). - """, - r"^ " => "", - ), - ), -) -eval( - tag_with_docstring( - :MultitargetLaSRRegressor, - replace( - """ - Multi-target Symbolic Regression regressor (`MultitargetLaSRRegressor`) - conducts several searches for expressions that predict each target variable - from a set of input variables. All data is assumed to be `Continuous`. - The search is performed using an evolutionary algorithm. - This algorithm is described in the paper - https://arxiv.org/abs/2305.01582. - - # Training data - In MLJ or MLJBase, bind an instance `model` to data with - - mach = machine(model, X, y) - - OR - - mach = machine(model, X, y, w) - - Here: - - - `X` is any table of input features (eg, a `DataFrame`) whose columns are of scitype - `Continuous`; check column scitypes with `schema(X)`. Variable names in discovered - expressions will be taken from the column names of `X`, if available. Units in columns - of `X` (use `DynamicQuantities` for units) will trigger dimensional analysis to be used. - - - `y` is the target, which can be any table of target variables whose element - scitype is `Continuous`; check the scitype with `schema(y)`. Units in columns of - `y` (use `DynamicQuantities` for units) will trigger dimensional analysis to be used. 
- - - `w` is the observation weights which can either be `nothing` (default) or an - `AbstractVector` whose element scitype is `Count` or `Continuous`. The same - weights are used for all targets. - - Train the machine using `fit!(mach)`, inspect the discovered expressions with - `report(mach)`, and predict on new data with `predict(mach, Xnew)`. - Note that unlike other regressors, symbolic regression stores a list of lists of - trained models. The models chosen from each of these lists is defined by the function - `selection_method` keyword argument, which by default balances accuracy - and complexity. You can override this at prediction time by passing a named - tuple with keys `data` and `idx`. - - """, - r"^ " => "", - ), - replace( - """ - # Fitted parameters - - The fields of `fitted_params(mach)` are: - - - `best_idx::Vector{Int}`: The index of the best expression in each Pareto frontier, - as determined by the `selection_method` function. Override in `predict` by passing - a named tuple with keys `data` and `idx`. - - `equations::Vector{Vector{Node{T}}}`: The expressions discovered by the search, represented - in a dominating Pareto frontier (i.e., the best expressions found for - each complexity). The outer vector is indexed by target variable, and the inner - vector is ordered by increasing complexity. `T` is equal to the element type - of the passed data. - - `equation_strings::Vector{Vector{String}}`: The expressions discovered by the search, - represented as strings for easy inspection. - - # Report - - The fields of `report(mach)` are: - - - `best_idx::Vector{Int}`: The index of the best expression in each Pareto frontier, - as determined by the `selection_method` function. Override in `predict` by passing - a named tuple with keys `data` and `idx`. - - `equations::Vector{Vector{Node{T}}}`: The expressions discovered by the search, represented - in a dominating Pareto frontier (i.e., the best expressions found for - each complexity). The outer vector is indexed by target variable, and the inner - vector is ordered by increasing complexity. - - `equation_strings::Vector{Vector{String}}`: The expressions discovered by the search, - represented as strings for easy inspection. - - `complexities::Vector{Vector{Int}}`: The complexity of each expression in each Pareto frontier. - - `losses::Vector{Vector{L}}`: The loss of each expression in each Pareto frontier, according - to the loss function specified in the model. The type `L` is the loss type, which - is usually the same as the element type of data passed (i.e., `T`), but can differ - if complex data types are passed. - - `scores::Vector{Vector{L}}`: A metric which considers both the complexity and loss of an expression, - equal to the change in the log-loss divided by the change in complexity, relative to - the previous expression along the Pareto frontier. A larger score aims to indicate - an expression is more likely to be the true expression generating the data, but - this is very problem-dependent and generally several other factors should be considered. - - # Examples - - ```julia - using MLJ - MultitargetLaSRRegressor = @load MultitargetLaSRRegressor pkg=LibraryAugmentedSymbolicRegression - X = (a=rand(100), b=rand(100), c=rand(100)) - Y = (y1=(@. cos(X.c) * 2.1 - 0.9), y2=(@. 
X.a * X.b + X.c)) - model = MultitargetLaSRRegressor(binary_operators=[+, -, *], unary_operators=[exp], niterations=100) - mach = machine(model, X, Y) - fit!(mach) - y_hat = predict(mach, X) - # View the equations used: - r = report(mach) - for (output_index, (eq, i)) in enumerate(zip(r.equation_strings, r.best_idx)) - println("Equation used for ", output_index, ": ", eq[i]) - end - ``` - - See also [`LaSRRegressor`](@ref). - """, - r"^ " => "", - ), - ), -) - end diff --git a/src/Migration.jl b/src/Migration.jl deleted file mode 100644 index daab9255f..000000000 --- a/src/Migration.jl +++ /dev/null @@ -1,40 +0,0 @@ -module MigrationModule - -using StatsBase: StatsBase -using ..CoreModule: Options -using ..PopulationModule: Population -using ..PopMemberModule: PopMember, reset_birth! -using ..UtilsModule: poisson_sample - -""" - migrate!(migration::Pair{Population{T,L},Population{T,L}}, options::Options; frac::AbstractFloat) - -Migrate a fraction of the population from one population to the other, creating copies -to do so. The original migrant population is not modified. Pass with, e.g., -`migrate!(migration_candidates => destination, options; frac=0.1)` -""" -function migrate!( - migration::Pair{Vector{PM},P}, options::Options; frac::AbstractFloat -) where {T,L,N,PM<:PopMember{T,L,N},P<:Population{T,L,N}} - base_pop = migration.second - population_size = length(base_pop.members) - mean_number_replaced = population_size * frac - num_replace = poisson_sample(mean_number_replaced) - - migrant_candidates = migration.first - - # Ensure `replace=true` is a valid setting: - num_replace = min(num_replace, length(migrant_candidates)) - num_replace = min(num_replace, population_size) - - locations = StatsBase.sample(1:population_size, num_replace; replace=true) - migrants = StatsBase.sample(migrant_candidates, num_replace; replace=true) - - for (i, migrant) in zip(locations, migrants) - base_pop.members[i] = copy(migrant) - reset_birth!(base_pop.members[i]; options.deterministic) - end - return nothing -end - -end diff --git a/src/Mutate.jl b/src/Mutate.jl index 496b8a337..e08b7ffc6 100644 --- a/src/Mutate.jl +++ b/src/Mutate.jl @@ -1,475 +1,74 @@ module MutateModule -using DispatchDoctor: @unstable -using DynamicExpressions: - AbstractExpressionNode, - AbstractExpression, - ParametricExpression, - with_contents, - get_tree, - preserve_sharing, - copy_node, - count_scalar_constants, - simplify_tree!, - combine_operators -using ..CoreModule: Options, MutationWeights, Dataset, RecordType, sample_mutation -using ..ComplexityModule: compute_complexity -using ..LossFunctionsModule: score_func, score_func_batched -using ..CheckConstraintsModule: check_constraints -using ..AdaptiveParsimonyModule: RunningSearchStatistics -using ..PopMemberModule: PopMember -using ..MutationFunctionsModule: - gen_random_tree_fixed_size, - mutate_constant, - mutate_operator, - swap_operands, - append_random_op, - prepend_random_op, - insert_random_op, - delete_random_op!, - crossover_trees, - form_random_connection!, - break_random_connection! 
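The rewritten `MutateModule` below hooks the LLM operations into SymbolicRegression.jl's `mutate!` by dispatching on a `Val` of the mutation's name, so each mutation symbol sampled from the weights resolves to its own method. A minimal, self-contained sketch of that dispatch pattern (hypothetical toy names, not the library's API):

```julia
# Val-keyed dispatch: each mutation name gets its own method.
mutate(tree::String, ::Val{:mutate_constant}) = tree * "  [perturbed a constant]"
mutate(tree::String, ::Val{:llm_mutate})      = tree * "  [rewritten by the LLM]"

# A sampled Symbol is lifted to a compile-time Val, which selects the method:
apply_mutation(tree, choice::Symbol) = mutate(tree, Val(choice))

apply_mutation("x1 + x2", :llm_mutate)  # -> "x1 + x2  [rewritten by the LLM]"
```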
+using SymbolicRegression +using .SymbolicRegression: @recorder +using ..LLMOptionsModule: LaSROptions using ..LLMFunctionsModule: - llm_mutate_op, llm_crossover_trees, tree_to_expr, gen_llm_random_tree, llm_recorder - -using ..ConstantOptimizationModule: optimize_constants -using ..RecorderModule: @recorder - -function check_constant(tree::AbstractExpressionNode)::Bool - return (tree.degree == 0) && tree.constant -end - -function check_constant(tree::AbstractExpression)::Bool - return check_constant(get_tree(tree)) + llm_mutate_tree, llm_recorder, llm_crossover_trees, llm_randomize_tree + +function mutate!( + tree::N, + ::P, + ::Val{:llm_mutate}, + ::SymbolicRegression.AbstractMutationWeights, + options::LaSROptions; + recorder::SymbolicRegression.RecordType, + curmaxsize, + nfeatures, + kws..., +) where {T,N<:SymbolicRegression.AbstractExpression{T},P<:SymbolicRegression.PopMember} + tree = llm_mutate_tree(tree, options) + @recorder recorder["type"] = "llm_mutate" + return MutationResult{N,P}(; tree=tree) end -function condition_mutation_weights!( - weights::MutationWeights, member::PopMember, options::Options, curmaxsize::Int -) - tree = get_tree(member.tree) - if !preserve_sharing(typeof(member.tree)) - weights.form_connection = 0.0 - weights.break_connection = 0.0 - end - if tree.degree == 0 - # If equation is too small, don't delete operators - # or simplify - weights.mutate_operator = 0.0 - weights.swap_operands = 0.0 - weights.delete_node = 0.0 - weights.simplify = 0.0 - if !tree.constant - weights.optimize = 0.0 - weights.mutate_constant = 0.0 - end - return nothing - end - - if !any(node -> node.degree == 2, tree) - # swap is implemented only for binary ops - weights.swap_operands = 0.0 - end - - if !(member.tree isa ParametricExpression) # TODO: HACK - #More constants => more likely to do constant mutation - let n_constants = count_scalar_constants(member.tree) - weights.mutate_constant *= min(8, n_constants) / 8.0 - end - end - complexity = compute_complexity(member, options) - - if complexity >= curmaxsize - # If equation is too big, don't add new operators - weights.add_node = 0.0 - weights.insert_node = 0.0 - end - - if !options.should_simplify - weights.simplify = 0.0 - end - - return nothing +function mutate!( + tree::N, + ::P, + ::Val{:llm_randomize}, + ::SymbolicRegression.AbstractMutationWeights, + options::LaSROptions; + recorder::SymbolicRegression.RecordType, + curmaxsize, + nfeatures, + kws..., +) where {T,N<:SymbolicRegression.AbstractExpression{T},P<:SymbolicRegression.PopMember} + tree = llm_randomize_tree(tree, curmaxsize, options, nfeatures) + @recorder recorder["type"] = "llm_randomize" + return MutationResult{N,P}(; tree=tree) end -# Go through one simulated options.annealing mutation cycle -# exp(-delta/T) defines probability of accepting a change -@unstable function next_generation( - dataset::D, - member::P, - temperature, - curmaxsize::Int, - running_search_statistics::RunningSearchStatistics, - options::Options; - tmp_recorder::RecordType, - dominating=nothing, - idea_database=nothing, -)::Tuple{ - P,Bool,Float64 -} where {T,L,D<:Dataset{T,L},N<:AbstractExpression{T},P<:PopMember{T,L,N}} - parent_ref = member.ref - mutation_accepted = false - num_evals = 0.0 - - #TODO - reconsider this - beforeScore, beforeLoss = if options.batching - num_evals += (options.batch_size / dataset.n) - score_func_batched(dataset, member, options) - else - member.score, member.loss - end - - nfeatures = dataset.nfeatures - - weights = copy(options.mutation_weights) - - 
condition_mutation_weights!(weights, member, options, curmaxsize) - - mutation_choice = sample_mutation(weights) - - successful_mutation = false - #TODO: Currently we dont take this \/ into account - is_success_always_possible = true - attempts = 0 - max_attempts = 10 - - ############################################# - # Mutations - ############################################# - local tree - if options.llm_options.active && (rand() < options.llm_options.weights.llm_mutate) - tree = copy_node(member.tree) - if check_constant(tree) - tree = with_contents( - tree, gen_random_tree_fixed_size(rand(1:curmaxsize), options, nfeatures, T) +"""Generate a generation via crossover of two members.""" +function crossover_generation( + member1::P, member2::P, dataset::D, curmaxsize::Int, options::LaSROptions +)::Tuple{P,P,Bool,Float64} where {T,L,D<:Dataset{T,L},N,P<:PopMember{T,L,N}} + llm_skip = false + if options.use_llm && (rand() < options.llm_operation_weights.llm_crossover) + tree1 = member1.tree + tree2 = member2.tree + + # add simplification for crossover + tree1 = simplify_tree!(tree1, options.operators) + tree1 = combine_operators(tree1, options.operators) + tree2 = simplify_tree!(tree2, options.operators) + tree2 = combine_operators(tree2, options.operators) + + crossover_accepted = false + nfeatures = dataset.nfeatures + + if check_constant(tree1) + tree1 = with_contents( + tree1, gen_random_tree_fixed_size(rand(1:curmaxsize), options, nfeatures, T) ) end - tree = llm_mutate_op(tree, options, dominating, idea_database) - tree = simplify_tree!(tree, options.operators) - tree = combine_operators(tree, options.operators) - @recorder tmp_recorder["type"] = "llm_mutate" - - successful_mutation = - (!check_constant(tree)) && check_constraints(tree, options, curmaxsize) - - if successful_mutation - llm_recorder(options.llm_options, tree_to_expr(tree, options), "mutate") - else - llm_recorder(options.llm_options, tree_to_expr(tree, options), "mutate|failed") - end - end - - while (!successful_mutation) && attempts < max_attempts - tree = copy_node(member.tree) - successful_mutation = true - if mutation_choice == :mutate_constant - tree = mutate_constant(tree, temperature, options) - @recorder tmp_recorder["type"] = "constant" - is_success_always_possible = true - # Mutating a constant shouldn't invalidate an already-valid function - elseif mutation_choice == :mutate_operator - tree = mutate_operator(tree, options) - @recorder tmp_recorder["type"] = "operator" - is_success_always_possible = true - # Can always mutate to the same operator - - elseif mutation_choice == :swap_operands - tree = swap_operands(tree) - @recorder tmp_recorder["type"] = "swap_operands" - is_success_always_possible = true - - elseif mutation_choice == :add_node - if rand() < 0.5 - tree = append_random_op(tree, options, nfeatures) - @recorder tmp_recorder["type"] = "append_op" - else - tree = prepend_random_op(tree, options, nfeatures) - @recorder tmp_recorder["type"] = "prepend_op" - end - is_success_always_possible = false - # Can potentially have a situation without success - elseif mutation_choice == :insert_node - tree = insert_random_op(tree, options, nfeatures) - @recorder tmp_recorder["type"] = "insert_op" - is_success_always_possible = false - elseif mutation_choice == :delete_node - tree = delete_random_op!(tree, options, nfeatures) - @recorder tmp_recorder["type"] = "delete_op" - is_success_always_possible = true - elseif mutation_choice == :simplify - @assert options.should_simplify - simplify_tree!(tree, 
options.operators) - tree = combine_operators(tree, options.operators) - @recorder tmp_recorder["type"] = "partial_simplify" - mutation_accepted = true - is_success_always_possible = true - return ( - PopMember( - tree, - beforeScore, - beforeLoss, - options; - parent=parent_ref, - deterministic=options.deterministic, - ), - mutation_accepted, - num_evals, + if check_constant(tree2) + tree2 = with_contents( + tree2, gen_random_tree_fixed_size(rand(1:curmaxsize), options, nfeatures, T) ) - # Simplification shouldn't hurt complexity; unless some non-symmetric constraint - # to commutative operator... - elseif mutation_choice == :randomize - # We select a random size, though the generated tree - # may have fewer nodes than we request. - tree_size_to_generate = rand(1:curmaxsize) - if options.llm_options.active && - (rand() < options.llm_options.weights.llm_gen_random) - tree = with_contents( - tree, - combine_operators( - simplify_tree!( - gen_llm_random_tree( - tree_size_to_generate, options, nfeatures, T, idea_database - ), - options.operators, - ), - options.operators, - ), - ) - @recorder tmp_recorder["type"] = "regenerate_llm" - - is_success_always_possible = false - - if check_constant(tree) # don't allow constant outputs - tree = with_contents( - tree, - gen_random_tree_fixed_size( - tree_size_to_generate, options, nfeatures, T - ), - ) - is_success_always_possible = true - end - else - tree = with_contents( - tree, - gen_random_tree_fixed_size( - tree_size_to_generate, options, nfeatures, T - ), - ) - @recorder tmp_recorder["type"] = "regenerate" - - is_success_always_possible = true - end - elseif mutation_choice == :optimize - cur_member = PopMember( - tree, - beforeScore, - beforeLoss, - options, - compute_complexity(member, options); - parent=parent_ref, - deterministic=options.deterministic, - ) - cur_member, new_num_evals = optimize_constants(dataset, cur_member, options) - num_evals += new_num_evals - @recorder tmp_recorder["type"] = "optimize" - mutation_accepted = true - is_success_always_possible = true - return (cur_member, mutation_accepted, num_evals) - elseif mutation_choice == :do_nothing - @recorder begin - tmp_recorder["type"] = "identity" - tmp_recorder["result"] = "accept" - tmp_recorder["reason"] = "identity" - end - mutation_accepted = true - is_success_always_possible = true - return ( - PopMember( - tree, - beforeScore, - beforeLoss, - options, - compute_complexity(member, options); - parent=parent_ref, - deterministic=options.deterministic, - ), - mutation_accepted, - num_evals, - ) - elseif mutation_choice == :form_connection - tree = form_random_connection!(tree) - @recorder tmp_recorder["type"] = "form_connection" - is_success_always_possible = true - elseif mutation_choice == :break_connection - tree = break_random_connection!(tree) - @recorder tmp_recorder["type"] = "break_connection" - is_success_always_possible = true - else - error("Unknown mutation choice: $mutation_choice") end - successful_mutation = - successful_mutation && check_constraints(tree, options, curmaxsize) - - attempts += 1 - end - ############################################# - tree::AbstractExpression - - if !successful_mutation - @recorder begin - tmp_recorder["result"] = "reject" - tmp_recorder["reason"] = "failed_constraint_check" - end - mutation_accepted = false - return ( - PopMember( - copy_node(member.tree), - beforeScore, - beforeLoss, - options, - compute_complexity(member, options); - parent=parent_ref, - deterministic=options.deterministic, - ), - mutation_accepted, - 
num_evals, - ) - end - - if options.batching - afterScore, afterLoss = score_func_batched(dataset, tree, options) - num_evals += (options.batch_size / dataset.n) - else - afterScore, afterLoss = score_func(dataset, tree, options) - num_evals += 1 - end - - if isnan(afterScore) - @recorder begin - tmp_recorder["result"] = "reject" - tmp_recorder["reason"] = "nan_loss" - end - mutation_accepted = false - return ( - PopMember( - copy_node(member.tree), - beforeScore, - beforeLoss, - options, - compute_complexity(member, options); - parent=parent_ref, - deterministic=options.deterministic, - ), - mutation_accepted, - num_evals, - ) - end - - probChange = 1.0 - if options.annealing - delta = afterScore - beforeScore - probChange *= exp(-delta / (temperature * options.alpha)) - end - newSize = -1 - if options.use_frequency - oldSize = compute_complexity(member, options) - newSize = compute_complexity(tree, options) - old_frequency = if (0 < oldSize <= options.maxsize) - running_search_statistics.normalized_frequencies[oldSize] - else - 1e-6 - end - new_frequency = if (0 < newSize <= options.maxsize) - running_search_statistics.normalized_frequencies[newSize] - else - 1e-6 - end - probChange *= old_frequency / new_frequency - end - - if probChange < rand() - @recorder begin - tmp_recorder["result"] = "reject" - tmp_recorder["reason"] = "annealing_or_frequency" - end - mutation_accepted = false - return ( - PopMember( - copy_node(member.tree), - beforeScore, - beforeLoss, - options, - compute_complexity(member, options); - parent=parent_ref, - deterministic=options.deterministic, - ), - mutation_accepted, - num_evals, - ) - else - @recorder begin - tmp_recorder["result"] = "accept" - tmp_recorder["reason"] = "pass" - end - mutation_accepted = true - return ( - PopMember( - tree, - afterScore, - afterLoss, - options, - newSize; - parent=parent_ref, - deterministic=options.deterministic, - ), - mutation_accepted, - num_evals, - ) - end -end - -"""Generate a generation via crossover of two members.""" -@unstable function crossover_generation( - member1::P, - member2::P, - dataset::D, - curmaxsize::Int, - options::Options; - dominating=nothing, - idea_database=nothing, -)::Tuple{P,P,Bool,Float64} where {T,L,D<:Dataset{T,L},N,P<:PopMember{T,L,N}} - tree1 = member1.tree - tree2 = member2.tree - - # add simplification for crossover - tree1 = simplify_tree!(tree1, options.operators) - tree1 = combine_operators(tree1, options.operators) - tree2 = simplify_tree!(tree2, options.operators) - tree2 = combine_operators(tree2, options.operators) - - crossover_accepted = false - nfeatures = dataset.nfeatures - - if check_constant(tree1) - tree1 = with_contents( - tree1, gen_random_tree_fixed_size(rand(1:curmaxsize), options, nfeatures, T) - ) - end - if check_constant(tree2) - tree2 = with_contents( - tree2, gen_random_tree_fixed_size(rand(1:curmaxsize), options, nfeatures, T) - ) - end - - child_tree1 = nothing - child_tree2 = nothing - llm_skip = false - if options.llm_options.active && (rand() < options.llm_options.weights.llm_crossover) - child_tree1, child_tree2 = llm_crossover_trees( - tree1, tree2, options, dominating, idea_database - ) + child_tree1, child_tree2 = llm_crossover_trees(tree1, tree2, options) child_tree1 = simplify_tree!(child_tree1, options.operators) child_tree1 = combine_operators(child_tree1, options.operators) @@ -487,83 +86,25 @@ end if successful_crossover recorder_str = - tree_to_expr(child_tree1, options) * + render_expr(child_tree1, options) * " && " * - tree_to_expr(child_tree2, 
options) + render_expr(child_tree2, options) llm_recorder(options.llm_options, recorder_str, "crossover") llm_skip = true else recorder_str = - tree_to_expr(child_tree1, options) * + render_expr(child_tree1, options) * " && " * - tree_to_expr(child_tree2, options) + render_expr(child_tree2, options) llm_recorder(options.llm_options, recorder_str, "crossover|failed") child_tree1, child_tree2 = crossover_trees(tree1, tree2) end - else - child_tree1, child_tree2 = crossover_trees(tree1, tree2) end - - # We breed these until constraints are no longer violated: - num_tries = 1 - max_tries = 10 - num_evals = 0.0 - afterSize1 = -1 - afterSize2 = -1 - while !llm_skip - afterSize1 = compute_complexity(child_tree1, options) - afterSize2 = compute_complexity(child_tree2, options) - # Both trees satisfy constraints - if check_constraints(child_tree1, options, curmaxsize, afterSize1) && - check_constraints(child_tree2, options, curmaxsize, afterSize2) - break - end - if num_tries > max_tries - crossover_accepted = false - return member1, member2, crossover_accepted, num_evals # Fail. - end - child_tree1, child_tree2 = crossover_trees(tree1, tree2) - num_tries += 1 - end - if options.batching - afterScore1, afterLoss1 = score_func_batched( - dataset, child_tree1, options; complexity=afterSize1 + if !llm_skip + return crossover_generation( + member1, member2, dataset, curmaxsize, options.sr_options ) - afterScore2, afterLoss2 = score_func_batched( - dataset, child_tree2, options; complexity=afterSize2 - ) - num_evals += 2 * (options.batch_size / dataset.n) - else - afterScore1, afterLoss1 = score_func( - dataset, child_tree1, options; complexity=afterSize1 - ) - afterScore2, afterLoss2 = score_func( - dataset, child_tree2, options; complexity=afterSize2 - ) - num_evals += options.batch_size / dataset.n end - - baby1 = PopMember( - child_tree1, - afterScore1, - afterLoss1, - options, - afterSize1; - parent=member1.ref, - deterministic=options.deterministic, - )::P - baby2 = PopMember( - child_tree2, - afterScore2, - afterLoss2, - options, - afterSize2; - parent=member2.ref, - deterministic=options.deterministic, - )::P - - crossover_accepted = true - return baby1, baby2, crossover_accepted, num_evals end end diff --git a/src/MutationFunctions.jl b/src/MutationFunctions.jl deleted file mode 100644 index 31534054f..000000000 --- a/src/MutationFunctions.jl +++ /dev/null @@ -1,441 +0,0 @@ -module MutationFunctionsModule - -using Random: default_rng, AbstractRNG -using DynamicExpressions: - AbstractExpressionNode, - AbstractExpression, - AbstractNode, - NodeSampler, - get_contents, - with_contents, - constructorof, - copy_node, - set_node!, - count_nodes, - has_constants, - has_operators -using Compat: Returns, @inline -using ..CoreModule: Options, DATA_TYPE - -""" - random_node(tree::AbstractNode; filter::F=Returns(true)) - -Return a random node from the tree. You may optionally -filter the nodes matching some condition before sampling. 
-""" -function random_node( - tree::AbstractNode, rng::AbstractRNG=default_rng(); filter::F=Returns(true) -) where {F<:Function} - Base.depwarn( - "Instead of `random_node(tree, filter)`, use `rand(NodeSampler(; tree, filter))`", - :random_node, - ) - return rand(rng, NodeSampler(; tree, filter)) -end - -"""Swap operands in binary operator for ops like pow and divide""" -function swap_operands(ex::AbstractExpression, rng::AbstractRNG=default_rng()) - tree = get_contents(ex) - ex = with_contents(ex, swap_operands(tree, rng)) - return ex -end -function swap_operands(tree::AbstractNode, rng::AbstractRNG=default_rng()) - if !any(node -> node.degree == 2, tree) - return tree - end - node = rand(rng, NodeSampler(; tree, filter=t -> t.degree == 2)) - node.l, node.r = node.r, node.l - return tree -end - -"""Randomly convert an operator into another one (binary->binary; unary->unary)""" -function mutate_operator( - ex::AbstractExpression{T}, options::Options, rng::AbstractRNG=default_rng() -) where {T<:DATA_TYPE} - tree = get_contents(ex) - ex = with_contents(ex, mutate_operator(tree, options, rng)) - return ex -end -function mutate_operator( - tree::AbstractExpressionNode{T}, options::Options, rng::AbstractRNG=default_rng() -) where {T} - if !(has_operators(tree)) - return tree - end - node = rand(rng, NodeSampler(; tree, filter=t -> t.degree != 0)) - if node.degree == 1 - node.op = rand(rng, 1:(options.nuna)) - else - node.op = rand(rng, 1:(options.nbin)) - end - return tree -end - -"""Randomly perturb a constant""" -function mutate_constant( - ex::AbstractExpression{T}, temperature, options::Options, rng::AbstractRNG=default_rng() -) where {T<:DATA_TYPE} - tree = get_contents(ex) - ex = with_contents(ex, mutate_constant(tree, temperature, options, rng)) - return ex -end -function mutate_constant( - tree::AbstractExpressionNode{T}, - temperature, - options::Options, - rng::AbstractRNG=default_rng(), -) where {T<:DATA_TYPE} - # T is between 0 and 1. - - if !(has_constants(tree)) - return tree - end - node = rand(rng, NodeSampler(; tree, filter=t -> (t.degree == 0 && t.constant))) - - node.val *= mutate_factor(T, temperature, options, rng) - - return tree -end - -function mutate_factor(::Type{T}, temperature, options, rng) where {T<:DATA_TYPE} - bottom = 1//10 - maxChange = options.perturbation_factor * temperature + 1 + bottom - factor = T(maxChange^rand(rng, T)) - makeConstBigger = rand(rng, Bool) - - factor = makeConstBigger ? factor : 1 / factor - - if rand(rng) > options.probability_negate_constant - factor *= -1 - end - return factor -end - -# TODO: Shouldn't we add a mutate_feature here? 
- -"""Add a random unary/binary operation to the end of a tree""" -function append_random_op( - ex::AbstractExpression{T}, - options::Options, - nfeatures::Int, - rng::AbstractRNG=default_rng(); - makeNewBinOp::Union{Bool,Nothing}=nothing, -) where {T<:DATA_TYPE} - tree = get_contents(ex) - ex = with_contents(ex, append_random_op(tree, options, nfeatures, rng; makeNewBinOp)) - return ex -end -function append_random_op( - tree::AbstractExpressionNode{T}, - options::Options, - nfeatures::Int, - rng::AbstractRNG=default_rng(); - makeNewBinOp::Union{Bool,Nothing}=nothing, -) where {T<:DATA_TYPE} - node = rand(rng, NodeSampler(; tree, filter=t -> t.degree == 0)) - - if makeNewBinOp === nothing - choice = rand(rng) - makeNewBinOp = choice < options.nbin / (options.nuna + options.nbin) - end - - if makeNewBinOp - newnode = constructorof(typeof(tree))(; - op=rand(rng, 1:(options.nbin)), - l=make_random_leaf(nfeatures, T, typeof(tree), rng, options), - r=make_random_leaf(nfeatures, T, typeof(tree), rng, options), - ) - else - newnode = constructorof(typeof(tree))(; - op=rand(rng, 1:(options.nuna)), - l=make_random_leaf(nfeatures, T, typeof(tree), rng, options), - ) - end - - set_node!(node, newnode) - - return tree -end - -"""Insert random node""" -function insert_random_op( - ex::AbstractExpression{T}, - options::Options, - nfeatures::Int, - rng::AbstractRNG=default_rng(), -) where {T<:DATA_TYPE} - tree = get_contents(ex) - ex = with_contents(ex, insert_random_op(tree, options, nfeatures, rng)) - return ex -end -function insert_random_op( - tree::AbstractExpressionNode{T}, - options::Options, - nfeatures::Int, - rng::AbstractRNG=default_rng(), -) where {T<:DATA_TYPE} - node = rand(rng, NodeSampler(; tree)) - choice = rand(rng) - makeNewBinOp = choice < options.nbin / (options.nuna + options.nbin) - left = copy_node(node) - - if makeNewBinOp - right = make_random_leaf(nfeatures, T, typeof(tree), rng, options) - newnode = constructorof(typeof(tree))(; - op=rand(rng, 1:(options.nbin)), l=left, r=right - ) - else - newnode = constructorof(typeof(tree))(; op=rand(rng, 1:(options.nuna)), l=left) - end - set_node!(node, newnode) - return tree -end - -"""Add random node to the top of a tree""" -function prepend_random_op( - ex::AbstractExpression{T}, - options::Options, - nfeatures::Int, - rng::AbstractRNG=default_rng(), -) where {T<:DATA_TYPE} - tree = get_contents(ex) - ex = with_contents(ex, prepend_random_op(tree, options, nfeatures, rng)) - return ex -end -function prepend_random_op( - tree::AbstractExpressionNode{T}, - options::Options, - nfeatures::Int, - rng::AbstractRNG=default_rng(), -) where {T<:DATA_TYPE} - node = tree - choice = rand(rng) - makeNewBinOp = choice < options.nbin / (options.nuna + options.nbin) - left = copy_node(tree) - - if makeNewBinOp - right = make_random_leaf(nfeatures, T, typeof(tree), rng, options) - newnode = constructorof(typeof(tree))(; - op=rand(rng, 1:(options.nbin)), l=left, r=right - ) - else - newnode = constructorof(typeof(tree))(; op=rand(rng, 1:(options.nuna)), l=left) - end - set_node!(node, newnode) - return node -end - -function make_random_leaf( - nfeatures::Int, - ::Type{T}, - ::Type{N}, - rng::AbstractRNG=default_rng(), - ::Union{Options,Nothing}=nothing, -) where {T<:DATA_TYPE,N<:AbstractExpressionNode} - if rand(rng, Bool) - return constructorof(N)(; val=randn(rng, T)) - else - return constructorof(N)(T; feature=rand(rng, 1:nfeatures)) - end -end - -"""Return a random node from the tree with parent, and side ('n' for no parent)""" -function 
random_node_and_parent(tree::AbstractNode, rng::AbstractRNG=default_rng()) - if tree.degree == 0 - return tree, tree, 'n' - end - parent = rand(rng, NodeSampler(; tree, filter=t -> t.degree != 0)) - if parent.degree == 1 || rand(rng, Bool) - return (parent.l, parent, 'l') - else - return (parent.r, parent, 'r') - end -end - -"""Select a random node, and splice it out of the tree.""" -function delete_random_op!( - ex::AbstractExpression{T}, - options::Options, - nfeatures::Int, - rng::AbstractRNG=default_rng(), -) where {T<:DATA_TYPE} - tree = get_contents(ex) - ex = with_contents(ex, delete_random_op!(tree, options, nfeatures, rng)) - return ex -end -function delete_random_op!( - tree::AbstractExpressionNode{T}, - options::Options, - nfeatures::Int, - rng::AbstractRNG=default_rng(), -) where {T<:DATA_TYPE} - node, parent, side = random_node_and_parent(tree, rng) - isroot = side == 'n' - - if node.degree == 0 - # Replace with new constant - newnode = make_random_leaf(nfeatures, T, typeof(tree), rng, options) - set_node!(node, newnode) - elseif node.degree == 1 - # Join one of the children with the parent - if isroot - return node.l - elseif parent.l == node - parent.l = node.l - else - parent.r = node.l - end - else - # Join one of the children with the parent - if rand(rng, Bool) - if isroot - return node.l - elseif parent.l == node - parent.l = node.l - else - parent.r = node.l - end - else - if isroot - return node.r - elseif parent.l == node - parent.l = node.r - else - parent.r = node.r - end - end - end - return tree -end - -"""Create a random equation by appending random operators""" -function gen_random_tree( - length::Int, options::Options, nfeatures::Int, ::Type{T}, rng::AbstractRNG=default_rng() -) where {T<:DATA_TYPE} - # Note that this base tree is just a placeholder; it will be replaced. - tree = constructorof(options.node_type)(T; val=convert(T, 1)) - for i in 1:length - # TODO: This can be larger number of nodes than length. - tree = append_random_op(tree, options, nfeatures, rng) - end - return tree -end - -function gen_random_tree_fixed_size( - node_count::Int, - options::Options, - nfeatures::Int, - ::Type{T}, - rng::AbstractRNG=default_rng(), -) where {T<:DATA_TYPE} - tree = make_random_leaf(nfeatures, T, options.node_type, rng, options) - cur_size = count_nodes(tree) - while cur_size < node_count - if cur_size == node_count - 1 # only unary operator allowed. - options.nuna == 0 && break # We will go over the requested amount, so we must break. - tree = append_random_op(tree, options, nfeatures, rng; makeNewBinOp=false) - else - tree = append_random_op(tree, options, nfeatures, rng) - end - cur_size = count_nodes(tree) - end - return tree -end - -function crossover_trees( - ex1::E, ex2::E, rng::AbstractRNG=default_rng() -) where {T,E<:AbstractExpression{T}} - tree1 = get_contents(ex1) - tree2 = get_contents(ex2) - out1, out2 = crossover_trees(tree1, tree2, rng) - ex1 = with_contents(ex1, out1) - ex2 = with_contents(ex2, out2) - return ex1, ex2 -end - -"""Crossover between two expressions""" -function crossover_trees( - tree1::N, tree2::N, rng::AbstractRNG=default_rng() -) where {T,N<:AbstractExpressionNode{T}} - tree1 = copy_node(tree1) - tree2 = copy_node(tree2) - - node1, parent1, side1 = random_node_and_parent(tree1, rng) - node2, parent2, side2 = random_node_and_parent(tree2, rng) - - node1 = copy_node(node1) - - if side1 == 'l' - parent1.l = copy_node(node2) - # tree1 now contains this. 
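# Note: `node1` was copied before either tree is modified, so grafting a copy
# of `node2` into `tree1` here cannot alias the subtree grafted into `tree2`
# below.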
-    elseif side1 == 'r'
-        parent1.r = copy_node(node2)
-        # tree1 now contains this.
-    else # 'n'
-        # This means that there is no parent2.
-        tree1 = copy_node(node2)
-    end
-
-    if side2 == 'l'
-        parent2.l = node1
-    elseif side2 == 'r'
-        parent2.r = node1
-    else # 'n'
-        tree2 = node1
-    end
-    return tree1, tree2
-end
-
-function get_two_nodes_without_loop(tree::AbstractNode, rng::AbstractRNG; max_attempts=10)
-    for _ in 1:max_attempts
-        parent = rand(rng, NodeSampler(; tree, filter=t -> t.degree != 0))
-        new_child = rand(rng, NodeSampler(; tree, filter=t -> t !== tree))
-
-        would_form_loop = any(t -> t === parent, new_child)
-        if !would_form_loop
-            return (parent, new_child, false)
-        end
-    end
-    return (tree, tree, true)
-end
-
-function form_random_connection!(ex::AbstractExpression, rng::AbstractRNG=default_rng())
-    tree = get_contents(ex)
-    return with_contents(ex, form_random_connection!(tree, rng))
-end
-function form_random_connection!(tree::AbstractNode, rng::AbstractRNG=default_rng())
-    if length(tree) < 5
-        return tree
-    end
-
-    parent, new_child, would_form_loop = get_two_nodes_without_loop(tree, rng)
-
-    if would_form_loop
-        return tree
-    end
-
-    # Set one of the children to be this new child:
-    if parent.degree == 1 || rand(rng, Bool)
-        parent.l = new_child
-    else
-        parent.r = new_child
-    end
-    return tree
-end
-
-function break_random_connection!(ex::AbstractExpression, rng::AbstractRNG=default_rng())
-    tree = get_contents(ex)
-    return with_contents(ex, break_random_connection!(tree, rng))
-end
-function break_random_connection!(tree::AbstractNode, rng::AbstractRNG=default_rng())
-    tree.degree == 0 && return tree
-    parent = rand(rng, NodeSampler(; tree, filter=t -> t.degree != 0))
-    if parent.degree == 1 || rand(rng, Bool)
-        parent.l = copy(parent.l)
-    else
-        parent.r = copy(parent.r)
-    end
-    return tree
-end
-
-end
diff --git a/src/MutationWeights.jl b/src/MutationWeights.jl
index 1f3f7369f..3edc95d77 100644
--- a/src/MutationWeights.jl
+++ b/src/MutationWeights.jl
@@ -1,66 +1,73 @@
-module MutationWeightsModule
+module LaSRMutationWeightsModule

-using StatsBase: StatsBase
+using DispatchDoctor: @unstable
+using Base
+using SymbolicRegression

 """
-    MutationWeights(;kws...)
+    LLMMutationProbabilities(;kws...)
+
+Defines the probability of different LLM-based mutation operations.
+NOTE:
+    - These must sum to at most 1.0; the leftover probability mass is shared
+      among the standard symbolic mutations.
+    - The LLM operations can be significantly slower than their symbolic
+      counterparts, so higher probabilities will result in slower operations.
+    - By default, all probabilities are set to 0.0.
-This defines how often different mutations occur. These weightings
-will be normalized to sum to 1.0 after initialization.

 # Arguments
-- `mutate_constant::Float64`: How often to mutate a constant.
-- `mutate_operator::Float64`: How often to mutate an operator.
-- `swap_operands::Float64`: How often to swap the operands of a binary operator.
-- `add_node::Float64`: How often to append a node to the tree.
-- `insert_node::Float64`: How often to insert a node into the tree.
-- `delete_node::Float64`: How often to delete a node from the tree.
-- `simplify::Float64`: How often to simplify the tree.
-- `randomize::Float64`: How often to create a random tree.
-- `do_nothing::Float64`: How often to do nothing.
-- `optimize::Float64`: How often to optimize the constants in the tree, as a mutation.
-  Note that this is different from `optimizer_probability`, which is
-  performed at the end of an iteration for all individuals.
-- `form_connection::Float64`: **Only used for `GraphNode`, not regular `Node`**.
-  Otherwise, this will automatically be set to 0.0. How often to form a
-  connection between two nodes.
-- `break_connection::Float64`: **Only used for `GraphNode`, not regular `Node`**.
-  Otherwise, this will automatically be set to 0.0. How often to break a
-  connection between two nodes.
+- `llm_mutate::Float64`: Probability of calling the LLM version of mutation.
+- `llm_randomize::Float64`: Probability of calling the LLM version of randomize.
+
+TODO: Implement more prompts so we can make specialized mutation operators like
+llm_mutate_const, llm_mutate_operation.
 """
-Base.@kwdef mutable struct MutationWeights
-    mutate_constant::Float64 = 0.048
-    mutate_operator::Float64 = 0.47
-    swap_operands::Float64 = 0.1
-    add_node::Float64 = 0.79
-    insert_node::Float64 = 5.1
-    delete_node::Float64 = 1.7
-    simplify::Float64 = 0.0020
-    randomize::Float64 = 0.00023
-    do_nothing::Float64 = 0.21
-    optimize::Float64 = 0.0
-    form_connection::Float64 = 0.5
-    break_connection::Float64 = 0.1
+Base.@kwdef mutable struct LLMMutationProbabilities
+    llm_mutate::Float64 = 0.0
+    llm_randomize::Float64 = 0.0
 end

-const mutations = fieldnames(MutationWeights)
-const v_mutations = Symbol[mutations...]
 """
+    LaSRMutationWeights{W<:SymbolicRegression.MutationWeights}(sr_weights::W, llm_weights::LLMMutationProbabilities)
+
+Defines the composite weights for all the mutation operations in the LaSR module.
+"""
+mutable struct LaSRMutationWeights{W<:SymbolicRegression.MutationWeights} <:
+               SymbolicRegression.AbstractMutationWeights
+    sr_weights::W
+    llm_weights::LLMMutationProbabilities
+end
+const LLM_MUTATION_WEIGHTS_KEYS = fieldnames(LLMMutationProbabilities)
+
+@unstable function LaSRMutationWeights(; kws...)
+    sr_weights_keys = filter(k -> !(k in LLM_MUTATION_WEIGHTS_KEYS), keys(kws))
+    sr_weights = SymbolicRegression.MutationWeights(;
+        NamedTuple(sr_weights_keys .=> Tuple(kws[k] for k in sr_weights_keys))...
+    )
+    sr_weights_vec = [getfield(sr_weights, f) for f in fieldnames(typeof(sr_weights))]
+
+    llm_weights_keys = filter(k -> k in LLM_MUTATION_WEIGHTS_KEYS, keys(kws))
+    llm_weights = LLMMutationProbabilities(;
+        NamedTuple(llm_weights_keys .=> Tuple(kws[k] for k in llm_weights_keys))...
+    )
+    llm_weights_vec = [getfield(llm_weights, f) for f in fieldnames(typeof(llm_weights))]
+
+    norm_sr_weights = SymbolicRegression.MutationWeights(
+        sr_weights_vec * (1 - sum(llm_weights_vec))...
+    )
+    norm_llm_weights = LLMMutationProbabilities(llm_weights_vec * sum(sr_weights_vec)...)
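# Why the two scalings above (annotation; let S = sum(sr_weights_vec) and
# L = sum(llm_weights_vec)): the combined weight vector is
# [sr .* (1 - L); llm .* S], whose total is S*(1 - L) + L*S = S. Once the
# sampler normalizes by that total, each LLM operation keeps its requested
# absolute probability llm[j], and the remaining 1 - L is split among the
# symbolic mutations in proportion to their usual weights. A usage sketch:
#
#     w = LaSRMutationWeights(; mutate_constant=0.05, llm_mutate=0.1)
#     w.llm_mutate  # forwarded to the (rescaled) llm_weights field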
+ + return LaSRMutationWeights(norm_sr_weights, norm_llm_weights) +end -# For some reason it's much faster to write out the fields explicitly: -let contents = [Expr(:., :w, QuoteNode(field)) for field in mutations] - @eval begin - function Base.convert(::Type{Vector}, w::MutationWeights)::Vector{Float64} - return $(Expr(:vect, contents...)) - end - function Base.copy(w::MutationWeights) - return $(Expr(:call, :MutationWeights, contents...)) - end +function Base.getproperty(weights::LaSRMutationWeights, k::Symbol) + if k in LLM_MUTATION_WEIGHTS_KEYS + return getproperty(getfield(weights, :llm_weights), k) + else + return getproperty(getfield(weights, :sr_weights), k) end end -"""Sample a mutation, given the weightings.""" -function sample_mutation(w::MutationWeights) - weights = convert(Vector, w) - return StatsBase.sample(v_mutations, StatsBase.Weights(weights)) +function Base.propertynames(weights::LaSRMutationWeights) + return (LLM_MUTATION_WEIGHTS_KEYS..., SymbolicRegression.MUTATION_WEIGHTS_KEYS...) end end diff --git a/src/Operators.jl b/src/Operators.jl deleted file mode 100644 index e7b99ea10..000000000 --- a/src/Operators.jl +++ /dev/null @@ -1,117 +0,0 @@ -module OperatorsModule - -using DynamicExpressions: DynamicExpressions as DE -using SpecialFunctions: SpecialFunctions -using DynamicQuantities: UnionAbstractQuantity -using SpecialFunctions: erf, erfc -using Base: @deprecate -using ..ProgramConstantsModule: DATA_TYPE -using ...UtilsModule: @ignore -#TODO - actually add these operators to the module! - -# TODO: Should this be limited to AbstractFloat instead? -function gamma(x::T)::T where {T<:DATA_TYPE} - out = SpecialFunctions.gamma(x) - return isinf(out) ? T(NaN) : out -end -gamma(x) = SpecialFunctions.gamma(x) - -atanh_clip(x) = atanh(mod(x + oneunit(x), oneunit(x) + oneunit(x)) - oneunit(x)) * one(x) -# == atanh((x + 1) % 2 - 1) - -# Implicitly defined: -#binary: mod -#unary: exp, abs, log1p, sin, cos, tan, sinh, cosh, tanh, asin, acos, atan, asinh, acosh, atanh, erf, erfc, gamma, relu, round, floor, ceil, round, sign. - -# Use some fast operators from https://github.com/JuliaLang/julia/blob/81597635c4ad1e8c2e1c5753fda4ec0e7397543f/base/fastmath.jl -# Define allowed operators. Any julia operator can also be used. -# TODO: Add all of these operators to the precompilation. -# TODO: Since simplification is done in DynamicExpressions.jl, are these names correct anymore? -function safe_pow(x::T, y::T)::T where {T<:Union{AbstractFloat,UnionAbstractQuantity}} - if isinteger(y) - y < zero(y) && iszero(x) && return T(NaN) - else - y > zero(y) && x < zero(x) && return T(NaN) - y < zero(y) && x <= zero(x) && return T(NaN) - end - return x^y -end -function safe_log(x::T)::T where {T<:AbstractFloat} - x <= zero(x) && return T(NaN) - return log(x) -end -function safe_log2(x::T)::T where {T<:AbstractFloat} - x <= zero(x) && return T(NaN) - return log2(x) -end -function safe_log10(x::T)::T where {T<:AbstractFloat} - x <= zero(x) && return T(NaN) - return log10(x) -end -function safe_log1p(x::T)::T where {T<:AbstractFloat} - x <= -oneunit(x) && return T(NaN) - return log1p(x) -end -function safe_acosh(x::T)::T where {T<:AbstractFloat} - x < oneunit(x) && return T(NaN) - return acosh(x) -end -function safe_sqrt(x::T)::T where {T<:AbstractFloat} - x < zero(x) && return T(NaN) - return sqrt(x) -end -# TODO: Should the above be made more generic, for, e.g., compatibility with units? - -# Do not change the names of these operators, as -# they have special use in simplifications and printing. 
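# Edge cases implied by the guards above (illustrative, not from the original
# source):
#   safe_log(-1.0)       -> NaN   # x <= 0
#   safe_sqrt(-2.0)      -> NaN   # x < 0
#   safe_pow(-1.0, 0.5)  -> NaN   # negative base, non-integer exponent
#   safe_pow(0.0, -2.0)  -> NaN   # zero base, negative integer exponent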
-square(x) = x * x -cube(x) = x * x * x -plus(x, y) = x + y -sub(x, y) = x - y -mult(x, y) = x * y -# Generics (for SIMD) -safe_pow(x, y) = x^y -safe_log(x) = log(x) -safe_log2(x) = log2(x) -safe_log10(x) = log10(x) -safe_log1p(x) = log1p(x) -safe_acosh(x) = acosh(x) -safe_sqrt(x) = sqrt(x) - -function neg(x) - return -x -end -function greater(x, y) - return (x > y) * one(x) -end -function cond(x, y) - return (x > zero(x)) * y -end -function relu(x) - return (x > zero(x)) * x -end -function logical_or(x, y) - return ((x > zero(x)) | (y > zero(y))) * one(x) -end -function logical_and(x, y) - return ((x > zero(x)) & (y > zero(y))) * one(x) -end - -# Strings -DE.get_op_name(::typeof(safe_pow)) = "^" -DE.get_op_name(::typeof(safe_log)) = "log" -DE.get_op_name(::typeof(safe_log2)) = "log2" -DE.get_op_name(::typeof(safe_log10)) = "log10" -DE.get_op_name(::typeof(safe_log1p)) = "log1p" -DE.get_op_name(::typeof(safe_acosh)) = "acosh" -DE.get_op_name(::typeof(safe_sqrt)) = "sqrt" - -# Deprecated operations: -@deprecate pow(x, y) safe_pow(x, y) -@deprecate pow_abs(x, y) safe_pow(x, y) - -# For static analysis tools: -@ignore pow(x, y) = safe_pow(x, y) -@ignore pow_abs(x, y) = safe_pow(x, y) - -end diff --git a/src/Options.jl b/src/Options.jl deleted file mode 100644 index 5268b5955..000000000 --- a/src/Options.jl +++ /dev/null @@ -1,829 +0,0 @@ -module OptionsModule - -using DispatchDoctor: @unstable -using Optim: Optim -using Dates: Dates -using StatsBase: StatsBase -using DynamicExpressions: OperatorEnum, Node, Expression, default_node_type -using ADTypes: AbstractADType, ADTypes -using LossFunctions: L2DistLoss, SupervisedLoss -using Optim: Optim -using LineSearches: LineSearches -#TODO - eventually move some of these -# into the SR call itself, rather than -# passing huge options at once. -using ..OperatorsModule: - plus, - pow, - safe_pow, - mult, - sub, - safe_log, - safe_log10, - safe_log2, - safe_log1p, - safe_sqrt, - safe_acosh, - atanh_clip -using ..MutationWeightsModule: MutationWeights, mutations -using ..LLMOptionsModule: LLMOptions, validate_llm_options -import ..OptionsStructModule: Options -using ..OptionsStructModule: ComplexityMapping, operator_specialization -using ..UtilsModule: max_ops, @save_kwargs, @ignore - -"""Build constraints on operator-level complexity from a user-passed dict.""" -@unstable function build_constraints(; - una_constraints, - bin_constraints, - @nospecialize(unary_operators), - @nospecialize(binary_operators) -)::Tuple{Vector{Int},Vector{Tuple{Int,Int}}} - # Expect format ((*)=>(-1, 3)), etc. - # TODO: Need to disable simplification if (*, -, +, /) are constrained? - # Or, just quit simplification is constraints violated. - - is_una_constraints_already_done = una_constraints isa Vector{Int} - _una_constraints1 = if una_constraints isa Array && !is_una_constraints_already_done - Dict(una_constraints) - else - una_constraints - end - _una_constraints2 = if _una_constraints1 === nothing - fill(-1, length(unary_operators)) - elseif !is_una_constraints_already_done - [ - haskey(_una_constraints1, op) ? 
_una_constraints1[op]::Int : -1 for - op in unary_operators - ] - else - _una_constraints1 - end - - is_bin_constraints_already_done = bin_constraints isa Vector{Tuple{Int,Int}} - _bin_constraints1 = if bin_constraints isa Array && !is_bin_constraints_already_done - Dict(bin_constraints) - else - bin_constraints - end - _bin_constraints2 = if _bin_constraints1 === nothing - fill((-1, -1), length(binary_operators)) - elseif !is_bin_constraints_already_done - [ - if haskey(_bin_constraints1, op) - _bin_constraints1[op]::Tuple{Int,Int} - else - (-1, -1) - end for op in binary_operators - ] - else - _bin_constraints1 - end - - return _una_constraints2, _bin_constraints2 -end - -@unstable function build_nested_constraints(; - @nospecialize(binary_operators), @nospecialize(unary_operators), nested_constraints -) - nested_constraints === nothing && return nested_constraints - # Check that intersection of binary operators and unary operators is empty: - for op in binary_operators - if op ∈ unary_operators - error( - "Operator $(op) is both a binary and unary operator. " * - "You can't use nested constraints.", - ) - end - end - - # Convert to dict: - _nested_constraints = if nested_constraints isa Dict - nested_constraints - else - # Convert to dict: - nested_constraints = Dict( - [cons[1] => Dict(cons[2]...) for cons in nested_constraints]... - ) - end - for (op, nested_constraint) in _nested_constraints - if !(op ∈ binary_operators || op ∈ unary_operators) - error("Operator $(op) is not in the operator set.") - end - for (nested_op, max_nesting) in nested_constraint - if !(nested_op ∈ binary_operators || nested_op ∈ unary_operators) - error("Operator $(nested_op) is not in the operator set.") - end - @assert nested_op ∈ binary_operators || nested_op ∈ unary_operators - @assert max_nesting >= -1 && typeof(max_nesting) <: Int - end - end - - # Lastly, we clean it up into a dict of (degree,op_idx) => max_nesting. 
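# For example (illustrative, assuming `sin` is the 1st and `cos` the 2nd entry
# of `unary_operators`), `nested_constraints = [sin => [cos => 0]]` becomes
# `[(1, 1, [(1, 2, 0)])]`, read as
# `(degree, op_idx, [(nested_degree, nested_op_idx, max_nesting), ...])`
# with degree 1 for unary and 2 for binary operators.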
- return [ - let (degree, idx) = if op ∈ binary_operators - 2, findfirst(isequal(op), binary_operators)::Int - else - 1, findfirst(isequal(op), unary_operators)::Int - end, - new_max_nesting_dict = [ - let (nested_degree, nested_idx) = if nested_op ∈ binary_operators - 2, findfirst(isequal(nested_op), binary_operators)::Int - else - 1, findfirst(isequal(nested_op), unary_operators)::Int - end - (nested_degree, nested_idx, max_nesting) - end for (nested_op, max_nesting) in nested_constraint - ] - - (degree, idx, new_max_nesting_dict) - end for (op, nested_constraint) in _nested_constraints - ] -end - -function binopmap(op::F) where {F} - if op == plus - return + - elseif op == mult - return * - elseif op == sub - return - - elseif op == div - return / - elseif op == ^ - return safe_pow - elseif op == pow - return safe_pow - end - return op -end -function inverse_binopmap(op::F) where {F} - if op == safe_pow - return ^ - end - return op -end - -function unaopmap(op::F) where {F} - if op == log - return safe_log - elseif op == log10 - return safe_log10 - elseif op == log2 - return safe_log2 - elseif op == log1p - return safe_log1p - elseif op == sqrt - return safe_sqrt - elseif op == acosh - return safe_acosh - elseif op == atanh - return atanh_clip - end - return op -end -function inverse_unaopmap(op::F) where {F} - if op == safe_log - return log - elseif op == safe_log10 - return log10 - elseif op == safe_log2 - return log2 - elseif op == safe_log1p - return log1p - elseif op == safe_sqrt - return sqrt - elseif op == safe_acosh - return acosh - elseif op == atanh_clip - return atanh - end - return op -end - -create_mutation_weights(w::MutationWeights) = w -create_mutation_weights(w::NamedTuple) = MutationWeights(; w...) - -create_llm_options(w::LLMOptions) = w -create_llm_options(w::NamedTuple) = LLMOptions(; w...) - -const deprecated_options_mapping = Base.ImmutableDict( - :mutationWeights => :mutation_weights, - :hofMigration => :hof_migration, - :shouldOptimizeConstants => :should_optimize_constants, - :hofFile => :output_file, - :perturbationFactor => :perturbation_factor, - :batchSize => :batch_size, - :crossoverProbability => :crossover_probability, - :warmupMaxsizeBy => :warmup_maxsize_by, - :useFrequency => :use_frequency, - :useFrequencyInTournament => :use_frequency_in_tournament, - :ncyclesperiteration => :ncycles_per_iteration, - :fractionReplaced => :fraction_replaced, - :fractionReplacedHof => :fraction_replaced_hof, - :probNegate => :probability_negate_constant, - :optimize_probability => :optimizer_probability, - :probPickFirst => :tournament_selection_p, - :earlyStopCondition => :early_stop_condition, - :stateReturn => :deprecated_return_state, - :return_state => :deprecated_return_state, - :enable_autodiff => :deprecated_enable_autodiff, - :ns => :tournament_selection_n, - :loss => :elementwise_loss, -) - -# For static analysis tools: -@ignore const DEFAULT_OPTIONS = () - -const OPTION_DESCRIPTIONS = """- `binary_operators`: Vector of binary operators (functions) to use. - Each operator should be defined for two input scalars, - and one output scalar. All operators - need to be defined over the entire real line (excluding infinity - these - are stopped before they are input), or return `NaN` where not defined. - For speed, define it so it takes two reals - of the same type as input, and outputs the same type. For the SymbolicUtils - simplification backend, you will need to define a generic method of the - operator so it takes arbitrary types. 
-- `unary_operators`: Same, but for - unary operators (one input scalar, gives an output scalar). -- `constraints`: Array of pairs specifying size constraints - for each operator. The constraints for a binary operator should be a 2-tuple - (e.g., `(-1, -1)`) and the constraints for a unary operator should be an `Int`. - A size constraint is a limit to the size of the subtree - in each argument of an operator. e.g., `[(^)=>(-1, 3)]` means that the - `^` operator can have arbitrary size (`-1`) in its left argument, - but a maximum size of `3` in its right argument. Default is - no constraints. -- `batching`: Whether to evolve based on small mini-batches of data, - rather than the entire dataset. -- `batch_size`: What batch size to use if using batching. -- `elementwise_loss`: What elementwise loss function to use. Can be one of - the following losses, or any other loss of type - `SupervisedLoss`. You can also pass a function that takes - a scalar target (left argument), and scalar predicted (right - argument), and returns a scalar. This will be averaged - over the predicted data. If weights are supplied, your - function should take a third argument for the weight scalar. - Included losses: - Regression: - - `LPDistLoss{P}()`, - - `L1DistLoss()`, - - `L2DistLoss()` (mean square), - - `LogitDistLoss()`, - - `HuberLoss(d)`, - - `L1EpsilonInsLoss(ϵ)`, - - `L2EpsilonInsLoss(ϵ)`, - - `PeriodicLoss(c)`, - - `QuantileLoss(τ)`, - Classification: - - `ZeroOneLoss()`, - - `PerceptronLoss()`, - - `L1HingeLoss()`, - - `SmoothedL1HingeLoss(γ)`, - - `ModifiedHuberLoss()`, - - `L2MarginLoss()`, - - `ExpLoss()`, - - `SigmoidLoss()`, - - `DWDMarginLoss(q)`. -- `loss_function`: Alternatively, you may redefine the loss used - as any function of `tree::AbstractExpressionNode{T}`, `dataset::Dataset{T}`, - and `options::Options`, so long as you output a non-negative - scalar of type `T`. This is useful if you want to use a loss - that takes into account derivatives, or correlations across - the dataset. This also means you could use a custom evaluation - for a particular expression. If you are using - `batching=true`, then your function should - accept a fourth argument `idx`, which is either `nothing` - (indicating that the full dataset should be used), or a vector - of indices to use for the batch. - For example, - - function my_loss(tree, dataset::Dataset{T,L}, options)::L where {T,L} - prediction, flag = eval_tree_array(tree, dataset.X, options) - if !flag - return L(Inf) - end - return sum((prediction .- dataset.y) .^ 2) / dataset.n - end - -- `expression_type::Type{E}=Expression`: The type of expression to use. - For example, `Expression`. -- `node_type::Type{N}=default_node_type(Expression)`: The type of node to use for the search. - For example, `Node` or `GraphNode`. The default is computed by `default_node_type(expression_type)`. -- `populations`: How many populations of equations to use. -- `population_size`: How many equations in each population. -- `ncycles_per_iteration`: How many generations to consider per iteration. -- `tournament_selection_n`: Number of expressions considered in each tournament. -- `tournament_selection_p`: The fittest expression in a tournament is to be - selected with probability `p`, the next fittest with probability `p*(1-p)`, - and so forth. -- `topn`: Number of equations to return to the host process, and to - consider for the hall of fame. -- `complexity_of_operators`: What complexity should be assigned to each operator, - and the occurrence of a constant or variable. 
By default, this is 1 - for all operators. Can be a real number as well, in which case - the complexity of an expression will be rounded to the nearest integer. - Input this in the form of, e.g., [(^) => 3, sin => 2]. -- `complexity_of_constants`: What complexity should be assigned to use of a constant. - By default, this is 1. -- `complexity_of_variables`: What complexity should be assigned to use of a variable, - which can also be a vector indicating different per-variable complexity. - By default, this is 1. -- `alpha`: The probability of accepting an equation mutation - during regularized evolution is given by exp(-delta_loss/(alpha * T)), - where T goes from 1 to 0. Thus, alpha=infinite is the same as no annealing. -- `maxsize`: Maximum size of equations during the search. -- `maxdepth`: Maximum depth of equations during the search, by default - this is set equal to the maxsize. -- `parsimony`: A multiplicative factor for how much complexity is - punished. -- `dimensional_constraint_penalty`: An additive factor if the dimensional - constraint is violated. -- `dimensionless_constants_only`: Whether to only allow dimensionless - constants. -- `use_frequency`: Whether to use a parsimony that adapts to the - relative proportion of equations at each complexity; this will - ensure that there are a balanced number of equations considered - for every complexity. -- `use_frequency_in_tournament`: Whether to use the adaptive parsimony described - above inside the score, rather than just at the mutation accept/reject stage. -- `adaptive_parsimony_scaling`: How much to scale the adaptive parsimony term - in the loss. Increase this if the search is spending too much time - optimizing the most complex equations. -- `turbo`: Whether to use `LoopVectorization.@turbo` to evaluate expressions. - This can be significantly faster, but is only compatible with certain - operators. *Experimental!* -- `bumper`: Whether to use Bumper.jl for faster evaluation. *Experimental!* -- `migration`: Whether to migrate equations between processes. -- `hof_migration`: Whether to migrate equations from the hall of fame - to processes. -- `fraction_replaced`: What fraction of each population to replace with - migrated equations at the end of each cycle. -- `fraction_replaced_hof`: What fraction to replace with hall of fame - equations at the end of each cycle. -- `should_simplify`: Whether to simplify equations. If you - pass a custom objective, this will be set to `false`. -- `should_optimize_constants`: Whether to use an optimization algorithm - to periodically optimize constants in equations. -- `optimizer_algorithm`: Select algorithm to use for optimizing constants. Default - is `Optim.BFGS(linesearch=LineSearches.BackTracking())`. -- `optimizer_nrestarts`: How many different random starting positions to consider - for optimization of constants. -- `optimizer_probability`: Probability of performing optimization of constants at - the end of a given iteration. -- `optimizer_iterations`: How many optimization iterations to perform. This gets - passed to `Optim.Options` as `iterations`. The default is 8. -- `optimizer_f_calls_limit`: How many function calls to allow during optimization. - This gets passed to `Optim.Options` as `f_calls_limit`. The default is - `10_000`. -- `optimizer_options`: General options for the constant optimization. For details - we refer to the documentation on `Optim.Options` from the `Optim.jl` package. - Options can be provided here as `NamedTuple`, e.g. `(iterations=16,)`, as a - `Dict`, e.g. 
Dict(:x_tol => 1.0e-32,), or as an `Optim.Options` instance. -- `autodiff_backend`: The backend to use for differentiation, which should be - an instance of `AbstractADType` (see `DifferentiationInterface.jl`). - Default is `nothing`, which means `Optim.jl` will estimate gradients (likely - with finite differences). You can also pass a symbolic version of the backend - type, such as `:Zygote` for Zygote, `:Enzyme`, etc. Most backends will not - work, and many will never work due to incompatibilities, though support for some - is gradually being added. -- `output_file`: What file to store equations to, as a backup. -- `perturbation_factor`: When mutating a constant, either - multiply or divide by (1+perturbation_factor)^(rand()+1). -- `probability_negate_constant`: Probability of negating a constant in the equation - when mutating it. -- `mutation_weights`: Relative probabilities of the mutations. The struct - `MutationWeights` should be passed to these options. - See its documentation on `MutationWeights` for the different weights. -- `llm_options`: Options for LLM inference. Managed through struct - `LLMOptions`. See its documentation for more details. -- `crossover_probability`: Probability of performing crossover. -- `annealing`: Whether to use simulated annealing. -- `warmup_maxsize_by`: Whether to slowly increase the max size from 5 up to - `maxsize`. If nonzero, specifies the fraction through the search - at which the maxsize should be reached. -- `verbosity`: Whether to print debugging statements or - not. -- `print_precision`: How many digits to print when printing - equations. By default, this is 5. -- `save_to_file`: Whether to save equations to a file during the search. -- `bin_constraints`: See `constraints`. This is the same, but specified for binary - operators only (for example, if you have an operator that is both a binary - and unary operator). -- `una_constraints`: Likewise, for unary operators. -- `seed`: What random seed to use. `nothing` uses no seed. -- `progress`: Whether to use a progress bar output (`verbosity` will - have no effect). -- `early_stop_condition`: Float - whether to stop early if the mean loss gets below this value. - Function - a function taking (loss, complexity) as arguments and returning true or false. -- `timeout_in_seconds`: Float64 - the time in seconds after which to exit (as an alternative to the number of iterations). -- `max_evals`: Int (or Nothing) - the maximum number of evaluations of expressions to perform. -- `skip_mutation_failures`: Whether to simply skip over mutations that fail or are rejected, rather than to replace the mutated - expression with the original expression and proceed normally. -- `nested_constraints`: Specifies how many times a combination of operators can be nested. For example, - `[sin => [cos => 0], cos => [cos => 2]]` specifies that `cos` may never appear within a `sin`, - but `sin` can be nested with itself an unlimited number of times. The second term specifies that `cos` - can be nested up to 2 times within a `cos`, so that `cos(cos(cos(x)))` is allowed (as well as any combination - of `+` or `-` within it), but `cos(cos(cos(cos(x))))` is not allowed. When an operator is not specified, - it is assumed that it can be nested an unlimited number of times. This requires that there is no operator - which is used both in the unary operators and the binary operators (e.g., `-` could be both subtract, and negation). 
- For binary operators, both arguments are treated the same way, and the max of each argument is constrained. -- `deterministic`: Use a global counter for the birth time, rather than calls to `time()`. This gives - perfect resolution, and is therefore deterministic. However, it is not thread safe, and must be used - in serial mode. -- `define_helper_functions`: Whether to define helper functions - for constructing and evaluating trees. -""" - -""" - Options(;kws...) - -Construct options for `equation_search` and other functions. -The current arguments have been tuned using the median values from -https://github.com/MilesCranmer/PySR/discussions/115. - -# Arguments -$(OPTION_DESCRIPTIONS) -""" -@unstable @save_kwargs DEFAULT_OPTIONS function Options(; - binary_operators=Function[+, -, /, *], - unary_operators=Function[], - constraints=nothing, - elementwise_loss::Union{Function,SupervisedLoss,Nothing}=nothing, - loss_function::Union{Function,Nothing}=nothing, - tournament_selection_n::Integer=12, #1 sampled from every tournament_selection_n per mutation - tournament_selection_p::Real=0.86, - topn::Integer=12, #samples to return per population - complexity_of_operators=nothing, - complexity_of_constants::Union{Nothing,Real}=nothing, - complexity_of_variables::Union{Nothing,Real,AbstractVector}=nothing, - parsimony::Real=0.0032, - dimensional_constraint_penalty::Union{Nothing,Real}=nothing, - dimensionless_constants_only::Bool=false, - alpha::Real=0.100000, - maxsize::Integer=20, - maxdepth::Union{Nothing,Integer}=nothing, - turbo::Bool=false, - bumper::Bool=false, - migration::Bool=true, - hof_migration::Bool=true, - should_simplify::Union{Nothing,Bool}=nothing, - should_optimize_constants::Bool=true, - output_file::Union{Nothing,AbstractString}=nothing, - expression_type::Type=Expression, - node_type::Type=default_node_type(expression_type), - expression_options::NamedTuple=NamedTuple(), - populations::Integer=15, - perturbation_factor::Real=0.076, - annealing::Bool=false, - batching::Bool=false, - batch_size::Integer=50, - mutation_weights::Union{MutationWeights,AbstractVector,NamedTuple}=MutationWeights(), - llm_options::LLMOptions=LLMOptions(), - crossover_probability::Real=0.066, - warmup_maxsize_by::Real=0.0, - use_frequency::Bool=true, - use_frequency_in_tournament::Bool=true, - adaptive_parsimony_scaling::Real=20.0, - population_size::Integer=33, - ncycles_per_iteration::Integer=550, - fraction_replaced::Real=0.00036, - fraction_replaced_hof::Real=0.035, - verbosity::Union{Integer,Nothing}=nothing, - print_precision::Integer=5, - save_to_file::Bool=true, - probability_negate_constant::Real=0.01, - seed=nothing, - bin_constraints=nothing, - una_constraints=nothing, - progress::Union{Bool,Nothing}=nothing, - terminal_width::Union{Nothing,Integer}=nothing, - optimizer_algorithm::Union{AbstractString,Optim.AbstractOptimizer}=Optim.BFGS(; - linesearch=LineSearches.BackTracking() - ), - optimizer_nrestarts::Integer=2, - optimizer_probability::Real=0.14, - optimizer_iterations::Union{Nothing,Integer}=nothing, - optimizer_f_calls_limit::Union{Nothing,Integer}=nothing, - optimizer_options::Union{Dict,NamedTuple,Optim.Options,Nothing}=nothing, - autodiff_backend::Union{AbstractADType,Symbol,Nothing}=nothing, - use_recorder::Bool=false, - recorder_file::AbstractString="pysr_recorder.json", - early_stop_condition::Union{Function,Real,Nothing}=nothing, - timeout_in_seconds::Union{Nothing,Real}=nothing, - max_evals::Union{Nothing,Integer}=nothing, - skip_mutation_failures::Bool=true, - 
nested_constraints=nothing, - deterministic::Bool=false, - # Not search options; just construction options: - define_helper_functions::Bool=true, - deprecated_return_state=nothing, - # Deprecated args: - fast_cycle::Bool=false, - npopulations::Union{Nothing,Integer}=nothing, - npop::Union{Nothing,Integer}=nothing, - kws..., -) - for k in keys(kws) - !haskey(deprecated_options_mapping, k) && error("Unknown keyword argument: $k") - new_key = deprecated_options_mapping[k] - if startswith(string(new_key), "deprecated_") - Base.depwarn("The keyword argument `$(k)` is deprecated.", :Options) - if string(new_key) != "deprecated_return_state" - # This one we actually want to use - continue - end - else - Base.depwarn( - "The keyword argument `$(k)` is deprecated. Use `$(new_key)` instead.", - :Options, - ) - end - # Now, set the new key to the old value: - #! format: off - k == :hofMigration && (hof_migration = kws[k]; true) && continue - k == :shouldOptimizeConstants && (should_optimize_constants = kws[k]; true) && continue - k == :hofFile && (output_file = kws[k]; true) && continue - k == :perturbationFactor && (perturbation_factor = kws[k]; true) && continue - k == :batchSize && (batch_size = kws[k]; true) && continue - k == :crossoverProbability && (crossover_probability = kws[k]; true) && continue - k == :warmupMaxsizeBy && (warmup_maxsize_by = kws[k]; true) && continue - k == :useFrequency && (use_frequency = kws[k]; true) && continue - k == :useFrequencyInTournament && (use_frequency_in_tournament = kws[k]; true) && continue - k == :ncyclesperiteration && (ncycles_per_iteration = kws[k]; true) && continue - k == :fractionReplaced && (fraction_replaced = kws[k]; true) && continue - k == :fractionReplacedHof && (fraction_replaced_hof = kws[k]; true) && continue - k == :probNegate && (probability_negate_constant = kws[k]; true) && continue - k == :optimize_probability && (optimizer_probability = kws[k]; true) && continue - k == :probPickFirst && (tournament_selection_p = kws[k]; true) && continue - k == :earlyStopCondition && (early_stop_condition = kws[k]; true) && continue - k == :return_state && (deprecated_return_state = kws[k]; true) && continue - k == :stateReturn && (deprecated_return_state = kws[k]; true) && continue - k == :enable_autodiff && continue - k == :ns && (tournament_selection_n = kws[k]; true) && continue - k == :loss && (elementwise_loss = kws[k]; true) && continue - if k == :llm_options - llm_options = kws[k] - continue - end - if k == :mutationWeights - if typeof(kws[k]) <: AbstractVector - _mutation_weights = kws[k] - if length(_mutation_weights) < length(mutations) - # Pad with zeros: - _mutation_weights = vcat( - _mutation_weights, - zeros(length(mutations) - length(_mutation_weights)) - ) - end - mutation_weights = MutationWeights(_mutation_weights...) - else - mutation_weights = kws[k] - end - continue - end - #! format: on - error( - "Unknown deprecated keyword argument: $k. Please update `Options(;)` to transfer this key.", - ) - end - fast_cycle && Base.depwarn("`fast_cycle` is deprecated and has no effect.", :Options) - if npop !== nothing - Base.depwarn("`npop` is deprecated. Use `population_size` instead.", :Options) - population_size = npop - end - if npopulations !== nothing - Base.depwarn("`npopulations` is deprecated. 
Use `populations` instead.", :Options) - populations = npopulations - end - if optimizer_algorithm isa AbstractString - Base.depwarn( - "The `optimizer_algorithm` argument should be an `AbstractOptimizer`, not a string.", - :Options, - ) - optimizer_algorithm = if optimizer_algorithm == "NelderMead" - Optim.NelderMead(; linesearch=LineSearches.BackTracking()) - else - Optim.BFGS(; linesearch=LineSearches.BackTracking()) - end - end - - if elementwise_loss === nothing - elementwise_loss = L2DistLoss() - else - if loss_function !== nothing - error("You cannot specify both `elementwise_loss` and `loss_function`.") - end - end - - if should_simplify === nothing - should_simplify = ( - loss_function === nothing && - nested_constraints === nothing && - constraints === nothing && - bin_constraints === nothing && - una_constraints === nothing - ) - end - - is_testing = parse(Bool, get(ENV, "SYMBOLIC_REGRESSION_IS_TESTING", "false")) - - if output_file === nothing - # "%Y-%m-%d_%H%M%S.%f" - date_time_str = Dates.format(Dates.now(), "yyyy-mm-dd_HHMMSS.sss") - output_file = "hall_of_fame_" * date_time_str * ".csv" - if is_testing - tmpdir = mktempdir() - output_file = joinpath(tmpdir, output_file) - end - end - - @assert maxsize > 3 - @assert warmup_maxsize_by >= 0.0f0 - @assert length(unary_operators) <= max_ops - @assert length(binary_operators) <= max_ops - - # Make sure nested_constraints contains functions within our operator set: - _nested_constraints = build_nested_constraints(; - binary_operators, unary_operators, nested_constraints - ) - - if typeof(constraints) <: Tuple - constraints = collect(constraints) - end - if constraints !== nothing - @assert bin_constraints === nothing - @assert una_constraints === nothing - # TODO: This is redundant with the checks in equation_search - for op in binary_operators - @assert !(op in unary_operators) - end - for op in unary_operators - @assert !(op in binary_operators) - end - bin_constraints = constraints - una_constraints = constraints - end - - _una_constraints, _bin_constraints = build_constraints(; - una_constraints, bin_constraints, unary_operators, binary_operators - ) - - complexity_mapping = ComplexityMapping( - complexity_of_operators, - complexity_of_variables, - complexity_of_constants, - binary_operators, - unary_operators, - ) - - if maxdepth === nothing - maxdepth = maxsize - end - - if define_helper_functions - # We call here so that mapped operators, like ^ - # are correctly overloaded, rather than overloading - # operators like "safe_pow", etc. - OperatorEnum(; - binary_operators=binary_operators, - unary_operators=unary_operators, - define_helper_functions=true, - empty_old_operators=true, - ) - end - - binary_operators = map(binopmap, binary_operators) - unary_operators = map(unaopmap, unary_operators) - - operators = OperatorEnum(; - binary_operators=binary_operators, - unary_operators=unary_operators, - define_helper_functions=define_helper_functions, - empty_old_operators=false, - ) - - early_stop_condition = if typeof(early_stop_condition) <: Real - # Need to make explicit copy here for this to work: - stopping_point = Float64(early_stop_condition) - (loss, complexity) -> loss < stopping_point - else - early_stop_condition - end - - # Parse optimizer options - if !isa(optimizer_options, Optim.Options) - optimizer_iterations = isnothing(optimizer_iterations) ? 
8 : optimizer_iterations - optimizer_f_calls_limit = if isnothing(optimizer_f_calls_limit) - 10_000 - else - optimizer_f_calls_limit - end - extra_kws = hasfield(Optim.Options, :show_warnings) ? (; show_warnings=false) : () - optimizer_options = Optim.Options(; - iterations=optimizer_iterations, - f_calls_limit=optimizer_f_calls_limit, - extra_kws..., - (isnothing(optimizer_options) ? () : optimizer_options)..., - ) - else - @assert optimizer_iterations === nothing && optimizer_f_calls_limit === nothing - end - if hasfield(Optim.Options, :show_warnings) && optimizer_options.show_warnings - @warn "Optimizer warnings are turned on. This might result in a lot of warnings being printed from NaNs, as these are common during symbolic regression" - end - - set_mutation_weights = create_mutation_weights(mutation_weights) - set_llm_options = create_llm_options(llm_options) - validate_llm_options(set_llm_options) - - @assert print_precision > 0 - - _autodiff_backend = if autodiff_backend isa Union{Nothing,AbstractADType} - autodiff_backend - else - ADTypes.Auto(autodiff_backend) - end - - options = Options{ - typeof(complexity_mapping), - operator_specialization(typeof(operators)), - node_type, - expression_type, - typeof(expression_options), - turbo, - bumper, - deprecated_return_state, - typeof(_autodiff_backend), - }( - operators, - _bin_constraints, - _una_constraints, - complexity_mapping, - tournament_selection_n, - tournament_selection_p, - parsimony, - dimensional_constraint_penalty, - dimensionless_constants_only, - alpha, - maxsize, - maxdepth, - Val(turbo), - Val(bumper), - migration, - hof_migration, - should_simplify, - should_optimize_constants, - output_file, - populations, - perturbation_factor, - annealing, - batching, - batch_size, - set_mutation_weights, - set_llm_options, - crossover_probability, - warmup_maxsize_by, - use_frequency, - use_frequency_in_tournament, - adaptive_parsimony_scaling, - population_size, - ncycles_per_iteration, - fraction_replaced, - fraction_replaced_hof, - topn, - verbosity, - print_precision, - save_to_file, - probability_negate_constant, - length(unary_operators), - length(binary_operators), - seed, - elementwise_loss, - loss_function, - node_type, - expression_type, - expression_options, - progress, - terminal_width, - optimizer_algorithm, - optimizer_probability, - optimizer_nrestarts, - optimizer_options, - _autodiff_backend, - recorder_file, - tournament_selection_p, - early_stop_condition, - Val(deprecated_return_state), - timeout_in_seconds, - max_evals, - skip_mutation_failures, - _nested_constraints, - deterministic, - define_helper_functions, - use_recorder, - ) - - return options -end - -end diff --git a/src/OptionsStruct.jl b/src/OptionsStruct.jl deleted file mode 100644 index c0bb73130..000000000 --- a/src/OptionsStruct.jl +++ /dev/null @@ -1,243 +0,0 @@ -module OptionsStructModule - -using DispatchDoctor: @unstable -using Optim: Optim -using DynamicExpressions: - AbstractOperatorEnum, AbstractExpressionNode, AbstractExpression, OperatorEnum -using LossFunctions: SupervisedLoss - -import ..MutationWeightsModule: MutationWeights -import ..LLMOptionsModule: LLMOptions - -""" -This struct defines how complexity is calculated. - -# Fields -- `use`: Shortcut indicating whether we use custom complexities, - or just use 1 for everything. -- `binop_complexities`: Complexity of each binary operator. -- `unaop_complexities`: Complexity of each unary operator. -- `variable_complexity`: Complexity of using a variable. 
-- `constant_complexity`: Complexity of using a constant. -""" -struct ComplexityMapping{T<:Real,VC<:Union{T,AbstractVector{T}}} - use::Bool - binop_complexities::Vector{T} - unaop_complexities::Vector{T} - variable_complexity::VC - constant_complexity::T -end - -Base.eltype(::ComplexityMapping{T}) where {T} = T - -"""Promote type when defining complexity mapping.""" -function ComplexityMapping(; - binop_complexities::Vector{T1}, - unaop_complexities::Vector{T2}, - variable_complexity::Union{T3,AbstractVector{T3}}, - constant_complexity::T4, -) where {T1<:Real,T2<:Real,T3<:Real,T4<:Real} - T = promote_type(T1, T2, T3, T4) - vc = map(T, variable_complexity) - return ComplexityMapping{T,typeof(vc)}( - true, - map(T, binop_complexities), - map(T, unaop_complexities), - vc, - T(constant_complexity), - ) -end - -function ComplexityMapping( - ::Nothing, ::Nothing, ::Nothing, binary_operators, unary_operators -) - # If no customization provided, then we simply - # turn off the complexity mapping - use = false - return ComplexityMapping{Int,Int}(use, zeros(Int, 0), zeros(Int, 0), 0, 0) -end -function ComplexityMapping( - complexity_of_operators, - complexity_of_variables, - complexity_of_constants, - binary_operators, - unary_operators, -) - _complexity_of_operators = if complexity_of_operators === nothing - Dict{Function,Int64}() - else - # Convert to dict: - Dict(complexity_of_operators) - end - - VAR_T = if (complexity_of_variables !== nothing) - if complexity_of_variables isa AbstractVector - eltype(complexity_of_variables) - else - typeof(complexity_of_variables) - end - else - Int - end - CONST_T = if (complexity_of_constants !== nothing) - typeof(complexity_of_constants) - else - Int - end - OP_T = eltype(_complexity_of_operators).parameters[2] - - T = promote_type(VAR_T, CONST_T, OP_T) - - # If not in dict, then just set it to 1. - binop_complexities = T[ - (haskey(_complexity_of_operators, op) ? _complexity_of_operators[op] : one(T)) # - for op in binary_operators - ] - unaop_complexities = T[ - (haskey(_complexity_of_operators, op) ? 
_complexity_of_operators[op] : one(T)) # - for op in unary_operators - ] - - variable_complexity = if complexity_of_variables !== nothing - map(T, complexity_of_variables) - else - one(T) - end - constant_complexity = if complexity_of_constants !== nothing - map(T, complexity_of_constants) - else - one(T) - end - - return ComplexityMapping(; - binop_complexities, unaop_complexities, variable_complexity, constant_complexity - ) -end - -# Controls level of specialization we compile -function operator_specialization end -if VERSION >= v"1.10.0-DEV.0" - @eval operator_specialization(::Type{<:OperatorEnum}) = OperatorEnum -else - @eval operator_specialization(O::Type{<:OperatorEnum}) = O -end - -struct Options{ - CM<:ComplexityMapping, - OP<:AbstractOperatorEnum, - N<:AbstractExpressionNode, - E<:AbstractExpression, - EO<:NamedTuple, - _turbo, - _bumper, - _return_state, - AD, -} - operators::OP - bin_constraints::Vector{Tuple{Int,Int}} - una_constraints::Vector{Int} - complexity_mapping::CM - tournament_selection_n::Int - tournament_selection_p::Float32 - parsimony::Float32 - dimensional_constraint_penalty::Union{Float32,Nothing} - dimensionless_constants_only::Bool - alpha::Float32 - maxsize::Int - maxdepth::Int - turbo::Val{_turbo} - bumper::Val{_bumper} - migration::Bool - hof_migration::Bool - should_simplify::Bool - should_optimize_constants::Bool - output_file::String - populations::Int - perturbation_factor::Float32 - annealing::Bool - batching::Bool - batch_size::Int - mutation_weights::MutationWeights - llm_options::LLMOptions - crossover_probability::Float32 - warmup_maxsize_by::Float32 - use_frequency::Bool - use_frequency_in_tournament::Bool - adaptive_parsimony_scaling::Float64 - population_size::Int - ncycles_per_iteration::Int - fraction_replaced::Float32 - fraction_replaced_hof::Float32 - topn::Int - verbosity::Union{Int,Nothing} - print_precision::Int - save_to_file::Bool - probability_negate_constant::Float32 - nuna::Int - nbin::Int - seed::Union{Int,Nothing} - elementwise_loss::Union{SupervisedLoss,Function} - loss_function::Union{Nothing,Function} - node_type::Type{N} - expression_type::Type{E} - expression_options::EO - progress::Union{Bool,Nothing} - terminal_width::Union{Int,Nothing} - optimizer_algorithm::Optim.AbstractOptimizer - optimizer_probability::Float32 - optimizer_nrestarts::Int - optimizer_options::Optim.Options - autodiff_backend::AD - recorder_file::String - prob_pick_first::Float32 - early_stop_condition::Union{Function,Nothing} - return_state::Val{_return_state} - timeout_in_seconds::Union{Float64,Nothing} - max_evals::Union{Int,Nothing} - skip_mutation_failures::Bool - nested_constraints::Union{Vector{Tuple{Int,Int,Vector{Tuple{Int,Int,Int}}}},Nothing} - deterministic::Bool - define_helper_functions::Bool - use_recorder::Bool -end - -function Base.print(io::IO, options::Options) - return print( - io, - "Options(" * - "binops=$(options.operators.binops), " * - "unaops=$(options.operators.unaops), " - # Fill in remaining fields automatically: - * - join( - [ - if fieldname in (:optimizer_options, :mutation_weights, :llm_options) - "$(fieldname)=..." 
- else - "$(fieldname)=$(getfield(options, fieldname))" - end for - fieldname in fieldnames(Options) if fieldname ∉ [:operators, :nuna, :nbin] - ], - ", ", - ) * - ")", - ) -end -Base.show(io::IO, ::MIME"text/plain", options::Options) = Base.print(io, options) - -@unstable function specialized_options(options::Options) - return _specialized_options(options) -end -@generated function _specialized_options(options::O) where {O<:Options} - # Return an options struct with concrete operators - type_parameters = O.parameters - fields = Any[:(getfield(options, $(QuoteNode(k)))) for k in fieldnames(O)] - quote - operators = getfield(options, :operators) - Options{$(type_parameters[1]),typeof(operators),$(type_parameters[3:end]...)}( - $(fields...) - ) - end -end - -end diff --git a/src/Parse.jl b/src/Parse.jl new file mode 100644 index 000000000..74cc87121 --- /dev/null +++ b/src/Parse.jl @@ -0,0 +1,141 @@ +module ParseModule + +using DispatchDoctor: @unstable +using DynamicExpressions +using .DynamicExpressions.NodeModule: Node +using SymbolicRegression: AbstractOptions, DATA_TYPE + +""" + parse_expr(expr_str::String, options) -> AbstractExpressionNode + +Given a string (e.g., from `string_tree`) and an options object (containing +operators, variable naming conventions, etc.), reconstruct an +AbstractExpressionNode. +""" +function parse_expr( + expr_str::String, options::AbstractOptions, ::Type{T} +) where {T<:DATA_TYPE} + parsed = Meta.parse(expr_str) + return _parse_expr(parsed, options, T) +end + +function _make_constant_node(val::Number, ::Type{T}) where {T<:DATA_TYPE} + return Node{T}(; val=float(val)) +end + +function _make_variable_node(sym::Symbol, ::Type{T}) where {T<:DATA_TYPE} + local idx = parse(Int, String(sym)[2:end]) # e.g. x5 => 5 + return Node{T}(; feature=idx) +end + +function _make_call_node(ex::Expr, options::AbstractOptions, ::Type{T}) where {T<:DATA_TYPE} + op_sym = ex.args[1] + argexprs = ex.args[2:end] + children = map(child_ex -> _parse_expr(child_ex, options, T), argexprs) + op_idx = _find_operator_index(op_sym, options) + if length(children) == 1 + return Node{T}(; op=op_idx, l=children[1]) + elseif length(children) == 2 + return Node{T}(; op=op_idx, l=children[1], r=children[2]) + else + error("Operator with $(length(children)) children not supported.") + end +end + +function _render_function( + fn::F, function_str_map::Dict{S,S} +)::String where {F<:Function,S<:AbstractString} + fn_str = replace(string(fn), "safe_" => "") + if haskey(function_str_map, fn_str) + return function_str_map[fn_str] + end + return fn_str +end + +@unstable function _find_operator_index(op_sym, options::AbstractOptions) + function_str_map = Dict("pow" => "^") + binops = map((x) -> _render_function(x, function_str_map), options.operators.binops) + unaops = map((x) -> _render_function(x, function_str_map), options.operators.unaops) + + for (i, opfunc) in pairs(binops) + if opfunc == string(op_sym) + return UInt8(i) + end + end + + for (i, opfunc) in pairs(unaops) + if opfunc == string(op_sym) + return UInt8(i) + end + end + + return error("Unrecognized operator symbol: $op_sym") +end + +@unstable function _parse_expr(ex, options::AbstractOptions, ::Type{T}) where {T<:DATA_TYPE} + if ex isa Number + return _make_constant_node(ex, T) + elseif ex isa Symbol + return _make_variable_node(ex, T) + elseif ex isa Expr + if ex.head === :call + return _make_call_node(ex, options, T) + elseif ex.head === :negative + # If we see something like -(3.14), + # parse it as (0 - 3.14). 
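+ # (Rewriting negation as binary subtraction means only the binary - operator + # must be registered with the options; no separate unary negation is needed.)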
+ return _parse_expr(Expr(:call, :-, 0, ex.args[1]), options, T) + else + error("Unsupported expression head: $(ex.head)") + end + else + error("Unsupported expression: $(ex)") + end +end + +@unstable function _sketch_const(val) + does_not_need_brackets = (typeof(val) <: Union{Real,AbstractArray}) + + if does_not_need_brackets + if isinteger(val) && (abs(val) < 5) # don't abstract integer constants from -4 to 4, useful for exponents + string(val) + else + "C" + end + else + if isinteger(val) && (abs(val) < 5) # don't abstract integer constants from -4 to 4, useful for exponents + "(" * string(val) * ")" + else + "(C)" + end + end +end + +""" + render_expr(ex::AbstractExpression{T}, options::AbstractOptions) -> String + +Given an AbstractExpression and an options object, return a string representation +of the expression. Specifically, replace constants with "C" and variables with +"x", "y", "z", etc., or the prespecified variable names. +""" +function render_expr( + ex::AbstractExpression{T}, options::AbstractOptions +)::String where {T<:DATA_TYPE} + return render_expr(get_contents(ex), options) +end + +function render_expr(tree::AbstractExpressionNode{T}, options)::String where {T<:DATA_TYPE} + variable_names = get_variable_names(options.variable_names) + return string_tree( + tree, options.operators; f_constant=_sketch_const, variable_names=variable_names + ) +end + +function get_variable_names(variable_names::Dict)::Vector{String} + return [variable_names[key] for key in sort(collect(keys(variable_names)))] +end + +function get_variable_names(variable_names::Nothing)::Vector{String} + return ["x", "y", "z", "k", "j", "l", "m", "n", "p", "a", "b"] +end + +end diff --git a/src/PopMember.jl b/src/PopMember.jl deleted file mode 100644 index 84f29f451..000000000 --- a/src/PopMember.jl +++ /dev/null @@ -1,166 +0,0 @@ -module PopMemberModule - -using DispatchDoctor: @unstable -using DynamicExpressions: AbstractExpression, AbstractExpressionNode, string_tree -using ..CoreModule: Options, Dataset, DATA_TYPE, LOSS_TYPE, create_expression -import ..ComplexityModule: compute_complexity -using ..UtilsModule: get_birth_order -using ..LossFunctionsModule: score_func - -# Define a member of the population by equation, score, and age -mutable struct PopMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} - tree::N - score::L # Includes complexity penalty, normalization - loss::L # Raw loss - birth::Int - complexity::Int - - # For recording history: - ref::Int - parent::Int -end -function Base.setproperty!(member::PopMember, field::Symbol, value) - field == :complexity && throw( - error("Don't set `.complexity` directly. Use `recompute_complexity!` instead.") - ) - field == :tree && setfield!(member, :complexity, -1) - return setfield!(member, field, value) -end -@unstable @inline function Base.getproperty(member::PopMember, field::Symbol) - field == :complexity && throw( - error("Don't access `.complexity` directly. Use `compute_complexity` instead.") - ) - return getfield(member, field) -end -function Base.show(io::IO, p::PopMember{T,L,N}) where {T,L,N} - shower(x) = sprint(show, x) - print(io, "PopMember(") - print(io, "tree = (", string_tree(p.tree), "), ") - print(io, "loss = ", shower(p.loss), ", ") - print(io, "score = ", shower(p.score)) - print(io, ")") - return nothing -end - -generate_reference() = abs(rand(Int)) - -""" - PopMember(t::AbstractExpression{T}, score::L, loss::L) - -Create a population member with a birth date at the current time.
-The type of the `Node` may be different from the type of the score -and loss. - -# Arguments - -- `t::AbstractExpression{T}`: The tree for the population member. -- `score::L`: The score (normalized to a baseline, and offset by a complexity penalty) -- `loss::L`: The raw loss to assign. -""" -function PopMember( - t::AbstractExpression{T}, - score::L, - loss::L, - options::Union{Options,Nothing}=nothing, - complexity::Union{Int,Nothing}=nothing; - ref::Int=-1, - parent::Int=-1, - deterministic=nothing, -) where {T<:DATA_TYPE,L<:LOSS_TYPE} - if ref == -1 - ref = generate_reference() - end - if !(deterministic isa Bool) - throw( - ArgumentError( - "You must declare `deterministic` as `true` or `false`, it cannot be left undefined.", - ), - ) - end - complexity = complexity === nothing ? -1 : complexity - return PopMember{T,L,typeof(t)}( - t, - score, - loss, - get_birth_order(; deterministic=deterministic), - complexity, - ref, - parent, - ) -end - -""" - PopMember( - dataset::Dataset{T,L}, - t::AbstractExpression{T}, - options::Options - ) - -Create a population member with a birth date at the current time. -Automatically compute the score for this tree. - -# Arguments - -- `dataset::Dataset{T,L}`: The dataset to evaluate the tree on. -- `t::AbstractExpression{T}`: The tree for the population member. -- `options::Options`: What options to use. -""" -function PopMember( - dataset::Dataset{T,L}, - tree::Union{AbstractExpressionNode{T},AbstractExpression{T}}, - options::Options, - complexity::Union{Int,Nothing}=nothing; - ref::Int=-1, - parent::Int=-1, - deterministic=nothing, -) where {T<:DATA_TYPE,L<:LOSS_TYPE} - ex = create_expression(tree, options, dataset) - set_complexity = complexity === nothing ? compute_complexity(ex, options) : complexity - @assert set_complexity != -1 - score, loss = score_func(dataset, ex, options; complexity=set_complexity) - return PopMember( - ex, - score, - loss, - options, - set_complexity; - ref=ref, - parent=parent, - deterministic=deterministic, - ) -end - -function Base.copy(p::P) where {P<:PopMember} - tree = copy(p.tree) - score = copy(p.score) - loss = copy(p.loss) - birth = copy(p.birth) - complexity = copy(getfield(p, :complexity)) - ref = copy(p.ref) - parent = copy(p.parent) - return P(tree, score, loss, birth, complexity, ref, parent) -end - -function reset_birth!(p::PopMember; deterministic::Bool) - p.birth = get_birth_order(; deterministic) - return p -end - -# Can read off complexity directly from pop members -function compute_complexity( - member::PopMember, options::Options; break_sharing=Val(false) -)::Int - complexity = getfield(member, :complexity) - complexity == -1 && return recompute_complexity!(member, options; break_sharing) - # TODO: Turn this into a warning, and then return normal compute_complexity instead. 
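- # (A value of -1 is the sentinel that `setproperty!` assigns whenever `tree` - # changes, so any other cached value is still valid and can be returned.)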
- return complexity -end -function recompute_complexity!( - member::PopMember, options::Options; break_sharing=Val(false) -)::Int - complexity = compute_complexity(member.tree, options; break_sharing) - setfield!(member, :complexity, complexity) - return complexity -end - -end diff --git a/src/Population.jl b/src/Population.jl deleted file mode 100644 index 75d0b75c2..000000000 --- a/src/Population.jl +++ /dev/null @@ -1,236 +0,0 @@ -module PopulationModule - -using StatsBase: StatsBase -using DispatchDoctor: @unstable -using DynamicExpressions: AbstractExpression, string_tree -using ..CoreModule: Options, Dataset, RecordType, DATA_TYPE, LOSS_TYPE -using ..ComplexityModule: compute_complexity -using ..LossFunctionsModule: score_func, update_baseline_loss! -using ..AdaptiveParsimonyModule: RunningSearchStatistics -using ..MutationFunctionsModule: gen_random_tree -using ..LLMFunctionsModule: gen_llm_random_tree -using ..PopMemberModule: PopMember -using ..UtilsModule: bottomk_fast, argmin_fast, PerThreadCache -# A list of members of the population, with easy constructors, -# which allow for random generation of new populations -struct Population{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} - members::Array{PopMember{T,L,N},1} - n::Int -end -""" - Population(pop::Array{PopMember{T,L}, 1}) - -Create population from list of PopMembers. -""" -function Population(pop::Vector{<:PopMember}) - return Population(pop, size(pop, 1)) -end - -@unstable function gen_random_tree_pop( - nlength::Int, - options::Options, - nfeatures::Int, - ::Type{T}, - idea_database::Union{Vector{String},Nothing}, -) where {T<:DATA_TYPE} - if options.llm_options.active && (rand() < options.llm_options.weights.llm_gen_random) - gen_llm_random_tree(nlength, options, nfeatures, T, idea_database) - else - gen_random_tree(nlength, options, nfeatures, T) - end -end - -""" - Population(dataset::Dataset{T,L}; - population_size, nlength::Int=3, options::Options, - nfeatures::Int, idea_database::Vector{String}) - -Create a random population, using either the LLM or the RNG to generate each tree, and score it on the dataset. -""" -@unstable function Population( - dataset::Dataset{T,L}; - options::Options, - population_size=nothing, - nlength::Int=3, - nfeatures::Int, - npop=nothing, - idea_database::Union{Vector{String},Nothing}=nothing, -) where {T,L} - @assert (population_size !== nothing) ⊻ (npop !== nothing) - population_size = if npop === nothing - population_size - else - npop - end - return Population( - [ - PopMember( - dataset, - gen_random_tree_pop(nlength, options, nfeatures, T, idea_database), - options; - parent=-1, - deterministic=options.deterministic, - ) for _ in 1:population_size - ], - population_size, - ) -end -""" - Population(X::AbstractMatrix{T}, y::AbstractVector{T}; - population_size, nlength::Int=3, - options::Options, nfeatures::Int, - loss_type::Type=Nothing) - -Create a random population and score it on the dataset.
-""" -@unstable function Population( - X::AbstractMatrix{T}, - y::AbstractVector{T}; - population_size=nothing, - nlength::Int=3, - options::Options, - nfeatures::Int, - loss_type::Type{L}=Nothing, - npop=nothing, -) where {T<:DATA_TYPE,L} - @assert (population_size !== nothing) ⊻ (npop !== nothing) - population_size = if npop === nothing - population_size - else - npop - end - dataset = Dataset(X, y, L) - update_baseline_loss!(dataset, options) - return Population( - dataset; population_size=population_size, options=options, nfeatures=nfeatures - ) -end - -function Base.copy(pop::P)::P where {T,L,N,P<:Population{T,L,N}} - copied_members = Vector{PopMember{T,L,N}}(undef, pop.n) - Threads.@threads for i in 1:(pop.n) - copied_members[i] = copy(pop.members[i]) - end - return Population(copied_members) -end - -# Sample random members of the population, and make a new one -function sample_pop(pop::P, options::Options)::P where {P<:Population} - return Population( - StatsBase.sample(pop.members, options.tournament_selection_n; replace=false) - ) -end - -# Sample the population, and get the best member from that sample -function best_of_sample( - pop::Population{T,L,N}, - running_search_statistics::RunningSearchStatistics, - options::Options, -) where {T,L,N} - sample = sample_pop(pop, options) - return _best_of_sample( - sample.members, running_search_statistics, options - )::PopMember{T,L,N} -end -function _best_of_sample( - members::Vector{P}, running_search_statistics::RunningSearchStatistics, options::Options -) where {T,L,P<:PopMember{T,L}} - p = options.tournament_selection_p - n = length(members) # == tournament_selection_n - scores = Vector{L}(undef, n) - if options.use_frequency_in_tournament - # Score based on frequency of that size occurring. - # In the end, all sizes should be just as common in the population. - adaptive_parsimony_scaling = L(options.adaptive_parsimony_scaling) - # e.g., for 100% occupied at one size, exp(-20*1) = 2.061153622438558e-9; which seems like a good punishment for dominating the population. 
- - for i in 1:n - member = members[i] - size = compute_complexity(member, options) - frequency = if (0 < size <= options.maxsize) - L(running_search_statistics.normalized_frequencies[size]) - else - L(0) - end - scores[i] = member.score * exp(adaptive_parsimony_scaling * frequency) - end - else - map!(member -> member.score, scores, members) - end - - chosen_idx = if p == 1.0 - argmin_fast(scores) - else - # First, decide what place we take (usually 1st place wins): - tournament_winner = StatsBase.sample(get_tournament_selection_weights(options)) - # Then, find the member that won that place, given - # their fitness: - if tournament_winner == 1 - argmin_fast(scores) - else - bottomk_fast(scores, tournament_winner)[2][end] - end - end - return members[chosen_idx] -end - -const CACHED_WEIGHTS = - let init_k = collect(0:5), - init_prob_each = 0.5f0 * (1 - 0.5f0) .^ init_k, - test_weights = StatsBase.Weights(init_prob_each, sum(init_prob_each)) - - PerThreadCache{Dict{Tuple{Int,Float32},typeof(test_weights)}}() - end - -@unstable function get_tournament_selection_weights(@nospecialize(options::Options)) - n = options.tournament_selection_n - p = options.tournament_selection_p - # Computing the weights for the tournament becomes quite expensive, so we cache them here. - return get!(CACHED_WEIGHTS, (n, p)) do - k = collect(0:(n - 1)) - prob_each = p * ((1 - p) .^ k) - - return StatsBase.Weights(prob_each, sum(prob_each)) - end - end -end - -function finalize_scores( - dataset::Dataset{T,L}, pop::P, options::Options -)::Tuple{P,Float64} where {T,L,P<:Population{T,L}} - need_recalculate = options.batching - num_evals = 0.0 - if need_recalculate - for member in 1:(pop.n) - score, loss = score_func(dataset, pop.members[member], options) - pop.members[member].score = score - pop.members[member].loss = loss - end - num_evals += pop.n - end - return (pop, num_evals) -end - -# Return the best `topn` members -function best_sub_pop(pop::P; topn::Int=10)::P where {P<:Population} - best_idx = sortperm([pop.members[member].score for member in 1:(pop.n)]) - return Population(pop.members[best_idx[1:topn]]) -end - -function record_population(pop::Population, options::Options)::RecordType - return RecordType( - "population" => [ - RecordType( - "tree" => string_tree(member.tree, options), - "loss" => member.loss, - "score" => member.score, - "complexity" => compute_complexity(member, options), - "birth" => member.birth, - "ref" => member.ref, - "parent" => member.parent, - ) for member in pop.members - ], - "time" => time(), - ) -end - -end diff --git a/src/ProgramConstants.jl b/src/ProgramConstants.jl deleted file mode 100644 index 607ce08b2..000000000 --- a/src/ProgramConstants.jl +++ /dev/null @@ -1,11 +0,0 @@ -module ProgramConstantsModule - -const MAX_DEGREE = 2 -const BATCH_DIM = 2 -const FEATURE_DIM = 1 -const RecordType = Dict{String,Any} - -const DATA_TYPE = Number -const LOSS_TYPE = Real - -end diff --git a/src/ProgressBars.jl b/src/ProgressBars.jl deleted file mode 100644 index 5a1f3fe6e..000000000 --- a/src/ProgressBars.jl +++ /dev/null @@ -1,36 +0,0 @@ -module ProgressBarsModule - -using ProgressBars: ProgressBar, set_multiline_postfix - -# Simple wrapper for a progress bar which stores its own state -mutable struct WrappedProgressBar - bar::ProgressBar - state::Union{Int,Nothing} - cycle::Union{Int,Nothing} - - function WrappedProgressBar(args...; kwargs...)
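- # When the test suite sets SYMBOLIC_REGRESSION_TEST, route bar output to - # devnull so progress rendering does not clutter the test logs.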
- if haskey(ENV, "SYMBOLIC_REGRESSION_TEST") && - ENV["SYMBOLIC_REGRESSION_TEST"] == "true" - output_stream = devnull - return new(ProgressBar(args...; output_stream, kwargs...), nothing, nothing) - end - return new(ProgressBar(args...; kwargs...), nothing, nothing) - end -end - -"""Iterate a progress bar without needing to store cycle/state externally.""" -function manually_iterate!(pbar::WrappedProgressBar) - cur_cycle = pbar.cycle - if cur_cycle === nothing - pbar.cycle, pbar.state = iterate(pbar.bar) - else - pbar.cycle, pbar.state = iterate(pbar.bar, pbar.state) - end - return nothing -end - -function set_multiline_postfix!(t::WrappedProgressBar, postfix::AbstractString) - return set_multiline_postfix(t.bar, postfix) -end - -end diff --git a/src/README.md b/src/README.md deleted file mode 100644 index 9e661e555..000000000 --- a/src/README.md +++ /dev/null @@ -1,2 +0,0 @@ -If you are looking for the main loop, start with `function _equation_search` in `SymbolicRegression.jl`. You can proceed from there. -All functions are imported at the top using `import {filename}Module` syntax, which should help you navigate the codebase. diff --git a/src/Recorder.jl b/src/Recorder.jl deleted file mode 100644 index a25ac0e78..000000000 --- a/src/Recorder.jl +++ /dev/null @@ -1,22 +0,0 @@ -module RecorderModule - -using ..CoreModule: RecordType - -"Assumes that `options` holds the user options::Options" -macro recorder(ex) - quote - if $(esc(:options)).use_recorder - $(esc(ex)) - end - end -end - -function find_iteration_from_record(key::String, record::RecordType) - iteration = 0 - while haskey(record[key], "iteration$(iteration)") - iteration += 1 - end - return iteration - 1 -end - -end diff --git a/src/RegularizedEvolution.jl b/src/RegularizedEvolution.jl deleted file mode 100644 index b2a08c990..000000000 --- a/src/RegularizedEvolution.jl +++ /dev/null @@ -1,121 +0,0 @@ -module RegularizedEvolutionModule - -using DynamicExpressions: string_tree -using ..CoreModule: Options, Dataset, RecordType, DATA_TYPE, LOSS_TYPE -using ..PopMemberModule: PopMember -using ..PopulationModule: Population, best_of_sample -using ..AdaptiveParsimonyModule: RunningSearchStatistics -using ..MutateModule: next_generation, crossover_generation -using ..RecorderModule: @recorder -using ..UtilsModule: argmin_fast - -# Pass through the population several times, replacing the oldest -# with the fittest of a small subsample -function reg_evol_cycle( - dataset::Dataset{T,L}, - pop::P, - temperature, - curmaxsize::Int, - running_search_statistics::RunningSearchStatistics, - options::Options, - record::RecordType; - dominating=nothing, - idea_database=nothing, -)::Tuple{P,Float64} where {T<:DATA_TYPE,L<:LOSS_TYPE,P<:Population{T,L}} - # Batch over each subsample. Can give 15% improvement in speed; probably more so for large pops, - # but it is ultimately a different algorithm than regularized evolution, and might not be - # as good.
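- # Each call runs ceil(pop.n / tournament_selection_n) tournaments, so on - # average each member of the population is sampled about once per cycle.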
- if options.crossover_probability > 0.0 - @recorder error("You cannot have the recorder on when using crossover") - end - - num_evals = 0.0 - n_evol_cycles = ceil(Int, pop.n / options.tournament_selection_n) - - for i in 1:n_evol_cycles - if rand() > options.crossover_probability - allstar = best_of_sample(pop, running_search_statistics, options) - mutation_recorder = RecordType() - baby, mutation_accepted, tmp_num_evals = next_generation( - dataset, - allstar, - temperature, - curmaxsize, - running_search_statistics, - options; - tmp_recorder=mutation_recorder, - dominating=dominating, - idea_database=idea_database, - ) - num_evals += tmp_num_evals - - if !mutation_accepted && options.skip_mutation_failures - # Skip this mutation rather than replacing oldest member with unchanged member - continue - end - - oldest = argmin_fast([pop.members[member].birth for member in 1:(pop.n)]) - - @recorder begin - if !haskey(record, "mutations") - record["mutations"] = RecordType() - end - for member in [allstar, baby, pop.members[oldest]] - if !haskey(record["mutations"], "$(member.ref)") - record["mutations"]["$(member.ref)"] = RecordType( - "events" => Vector{RecordType}(), - "tree" => string_tree(member.tree, options), - "score" => member.score, - "loss" => member.loss, - "parent" => member.parent, - ) - end - end - mutate_event = RecordType( - "type" => "mutate", - "time" => time(), - "child" => baby.ref, - "mutation" => mutation_recorder, - ) - death_event = RecordType("type" => "death", "time" => time()) - - # Put in random key rather than vector; otherwise there are collisions! - push!(record["mutations"]["$(allstar.ref)"]["events"], mutate_event) - push!( - record["mutations"]["$(pop.members[oldest].ref)"]["events"], death_event - ) - end - - pop.members[oldest] = baby - - else # Crossover - allstar1 = best_of_sample(pop, running_search_statistics, options) - allstar2 = best_of_sample(pop, running_search_statistics, options) - - baby1, baby2, crossover_accepted, tmp_num_evals = crossover_generation( - allstar1, - allstar2, - dataset, - curmaxsize, - options; - dominating=dominating, - idea_database=idea_database, - ) - num_evals += tmp_num_evals - - if !crossover_accepted && options.skip_mutation_failures - continue - end - - # Replace old members with new ones: - oldest = argmin_fast([pop.members[member].birth for member in 1:(pop.n)]) - pop.members[oldest] = baby1 - oldest = argmin_fast([pop.members[member].birth for member in 1:(pop.n)]) - pop.members[oldest] = baby2 - end - end - - return (pop, num_evals) -end - -end diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl deleted file mode 100644 index ff57e0eeb..000000000 --- a/src/SearchUtils.jl +++ /dev/null @@ -1,522 +0,0 @@ -"""Functions to help with the main loop of LibraryAugmentedSymbolicRegression.jl. - -This includes: process management, stdin reading, checking for early stops.""" -module SearchUtilsModule - -using Printf: @printf, @sprintf -using Distributed: Distributed, @spawnat, Future, procs -using StatsBase: mean -using DispatchDoctor: @unstable - -using DynamicExpressions: AbstractExpression, string_tree -using ..UtilsModule: subscriptify -using ..CoreModule: Dataset, Options, MAX_DEGREE, RecordType -using ..ComplexityModule: compute_complexity -using ..PopulationModule: Population -using ..PopMemberModule: PopMember -using ..HallOfFameModule: HallOfFame, string_dominating_pareto_curve -using ..ProgressBarsModule: WrappedProgressBar, set_multiline_postfix!, manually_iterate! 
-using ..AdaptiveParsimonyModule: RunningSearchStatistics - -""" - RuntimeOptions{PARALLELISM,DIM_OUT,RETURN_STATE} - -Parameters for a search that are passed to `equation_search` directly, -rather than set within `Options`. This is to differentiate between -parameters that relate to processing and the duration of the search, -and parameters dealing with the search hyperparameters themselves. -""" -Base.@kwdef struct RuntimeOptions{PARALLELISM,DIM_OUT,RETURN_STATE} - niterations::Int64 - total_cycles::Int64 - numprocs::Int64 - init_procs::Union{Vector{Int},Nothing} - addprocs_function::Function - exeflags::Cmd - runtests::Bool - verbosity::Int64 - progress::Bool - parallelism::Val{PARALLELISM} - dim_out::Val{DIM_OUT} - return_state::Val{RETURN_STATE} -end -@unstable @inline function Base.getproperty( - roptions::RuntimeOptions{P,D,R}, name::Symbol -) where {P,D,R} - if name == :parallelism - return P - elseif name == :dim_out - return D - elseif name == :return_state - return R - else - getfield(roptions, name) - end -end -function Base.propertynames(roptions::RuntimeOptions) - return (Base.fieldnames(typeof(roptions))..., :parallelism, :dim_out, :return_state) -end - -"""A simple dictionary to track worker allocations.""" -const WorkerAssignments = Dict{Tuple{Int,Int},Int} - -function next_worker(worker_assignment::WorkerAssignments, procs::Vector{Int})::Int - job_counts = Dict(proc => 0 for proc in procs) - for (key, value) in worker_assignment - @assert haskey(job_counts, value) - job_counts[value] += 1 - end - least_busy_worker = reduce( - (proc1, proc2) -> (job_counts[proc1] <= job_counts[proc2] ? proc1 : proc2), procs - ) - return least_busy_worker -end - -function assign_next_worker!( - worker_assignment::WorkerAssignments; pop, out, parallelism, procs -)::Int - if parallelism == :multiprocessing - worker_idx = next_worker(worker_assignment, procs) - worker_assignment[(out, pop)] = worker_idx - return worker_idx - else - return 0 - end -end - -const DefaultWorkerOutputType{P,H} = Tuple{P,H,RecordType,Float64} - -function get_worker_output_type( - ::Val{PARALLELISM}, ::Type{PopType}, ::Type{HallOfFameType} -) where {PARALLELISM,PopType,HallOfFameType} - if PARALLELISM == :serial - DefaultWorkerOutputType{PopType,HallOfFameType} - elseif PARALLELISM == :multiprocessing - Future - else - Task - end -end - -#! format: off -extract_from_worker(p::DefaultWorkerOutputType, _, _) = p -extract_from_worker(f::Future, ::Type{P}, ::Type{H}) where {P,H} = fetch(f)::DefaultWorkerOutputType{P,H} -extract_from_worker(t::Task, ::Type{P}, ::Type{H}) where {P,H} = fetch(t)::DefaultWorkerOutputType{P,H} -#! format: on - -macro sr_spawner(expr, kws...)
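- # Illustrative call shape (not present verbatim in the sources): - # @sr_spawner(evolve(pop), parallelism=:multithreading, worker_idx=i) - # which expands to a plain call, a @spawnat, or a Threads.@spawn.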
- # Extract parallelism and worker_idx parameters from kws - @assert length(kws) == 2 - @assert all(ex -> ex.head == :(=), kws) - @assert any(ex -> ex.args[1] == :parallelism, kws) - @assert any(ex -> ex.args[1] == :worker_idx, kws) - parallelism = kws[findfirst(ex -> ex.args[1] == :parallelism, kws)::Int].args[2] - worker_idx = kws[findfirst(ex -> ex.args[1] == :worker_idx, kws)::Int].args[2] - return quote - if $(parallelism) == :serial - $(expr) - elseif $(parallelism) == :multiprocessing - @spawnat($(worker_idx), $(expr)) - elseif $(parallelism) == :multithreading - Threads.@spawn($(expr)) - else - error("Invalid parallel type ", string($(parallelism)), ".") - end - end |> esc -end - -@unstable function init_dummy_pops( - npops::Int, datasets::Vector{D}, options::Options -) where {T,L,D<:Dataset{T,L}} - prototype = Population( - first(datasets); - population_size=1, - options=options, - nfeatures=first(datasets).nfeatures, - ) - # ^ Due to occasional inference issue, we manually specify the return type - return [ - typeof(prototype)[ - if (i == 1 && j == 1) - prototype - else - Population( - datasets[j]; - population_size=1, - options=options, - nfeatures=datasets[j].nfeatures, - ) - end for i in 1:npops - ] for j in 1:length(datasets) - ] -end - -struct StdinReader{ST} - can_read_user_input::Bool - stream::ST -end - -"""Start watching stream (like stdin) for user input.""" -function watch_stream(stream) - can_read_user_input = isreadable(stream) - - can_read_user_input && try - Base.start_reading(stream) - bytes = bytesavailable(stream) - if bytes > 0 - # Clear out initial data - read(stream, bytes) - end - catch err - if isa(err, MethodError) - can_read_user_input = false - else - throw(err) - end - end - return StdinReader(can_read_user_input, stream) -end - -"""Close the stdin reader and stop reading.""" -function close_reader!(reader::StdinReader) - if reader.can_read_user_input - Base.stop_reading(reader.stream) - end -end - -"""Check if the user typed 'q' and <enter>, or <ctl-c>.""" -function check_for_user_quit(reader::StdinReader)::Bool - if reader.can_read_user_input - bytes = bytesavailable(reader.stream) - if bytes > 0 - # Read: - data = read(reader.stream, bytes) - control_c = 0x03 - quit = 0x71 - if length(data) > 1 && (data[end] == control_c || data[end - 1] == quit) - return true - end - end - end - return false -end - -function check_for_loss_threshold(halls_of_fame, options::Options)::Bool - return _check_for_loss_threshold(halls_of_fame, options.early_stop_condition, options) -end - -function _check_for_loss_threshold(_, ::Nothing, ::Options) - return false -end -function _check_for_loss_threshold(halls_of_fame, f::F, options::Options) where {F} - return all(halls_of_fame) do hof - any(hof.members[hof.exists]) do member - f(member.loss, compute_complexity(member, options))::Bool - end - end -end - -function check_for_timeout(start_time::Float64, options::Options)::Bool - return options.timeout_in_seconds !== nothing && - time() - start_time > options.timeout_in_seconds::Float64 -end - -function check_max_evals(num_evals, options::Options)::Bool - return options.max_evals !== nothing && options.max_evals::Int <= sum(sum, num_evals) -end - -""" -This struct is used to monitor resources. - -Whenever we check a channel, we record if it was empty or not. -This gives us a measure for how much of a bottleneck there is -at the head worker.
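-The reported estimate is simply the mean of the most recent `window_size` -recordings (see `estimate_work_fraction` below).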
-""" -Base.@kwdef mutable struct ResourceMonitor - population_ready::Vector{Bool} = Bool[] - max_recordings::Int - start_reporting_at::Int - window_size::Int -end - -function record_channel_state!(monitor::ResourceMonitor, state) - push!(monitor.population_ready, state) - if length(monitor.population_ready) > monitor.max_recordings - popfirst!(monitor.population_ready) - end - return nothing -end - -function estimate_work_fraction(monitor::ResourceMonitor)::Float64 - if length(monitor.population_ready) <= monitor.start_reporting_at - return 0.0 # Can't estimate from only one interval, due to JIT. - end - return mean(monitor.population_ready[(end - (monitor.window_size - 1)):end]) -end - -function get_load_string(; head_node_occupation::Float64, parallelism=:serial) - if parallelism == :serial || head_node_occupation == 0.0 - return "" - end - return "" - ## TODO: Debug why populations are always ready - # out = @sprintf("Head worker occupation: %.1f%%", head_node_occupation * 100) - - # raise_usage_warning = head_node_occupation > 0.4 - # if raise_usage_warning - # out *= "." - # out *= " This is high, and will prevent efficient resource usage." - # out *= " Increase `ncycles_per_iteration` to reduce load on head worker." - # end - - # out *= "\n" - # return out -end - -function update_progress_bar!( - progress_bar::WrappedProgressBar, - hall_of_fame::HallOfFame{T,L}, - dataset::Dataset{T,L}, - options::Options, - equation_speed::Vector{Float32}, - head_node_occupation::Float64, - parallelism=:serial, -) where {T,L} - equation_strings = string_dominating_pareto_curve( - hall_of_fame, dataset, options; width=progress_bar.bar.width - ) - # TODO - include command about "q" here. - load_string = if length(equation_speed) > 0 - average_speed = sum(equation_speed) / length(equation_speed) - @sprintf( - "Expressions evaluated per second: %-5.2e. ", - round(average_speed, sigdigits=3) - ) - else - @sprintf("Expressions evaluated per second: [.....]. ") - end - load_string *= get_load_string(; head_node_occupation, parallelism) - load_string *= @sprintf("Press 'q' and then to stop execution early.\n") - equation_strings = load_string * equation_strings - set_multiline_postfix!(progress_bar, equation_strings) - manually_iterate!(progress_bar) - return nothing -end - -function print_search_state( - hall_of_fames, - datasets; - options::Options, - equation_speed::Vector{Float32}, - total_cycles::Int, - cycles_remaining::Vector{Int}, - head_node_occupation::Float64, - parallelism=:serial, - width::Union{Integer,Nothing}=nothing, -) - twidth = (width === nothing) ? 
100 : max(100, width::Integer) - nout = length(datasets) - average_speed = sum(equation_speed) / length(equation_speed) - - @printf("\n") - @printf("Expressions evaluated per second: %.3e\n", round(average_speed, sigdigits=3)) - load_string = get_load_string(; head_node_occupation, parallelism) - print(load_string) - cycles_elapsed = total_cycles * nout - sum(cycles_remaining) - @printf( - "Progress: %d / %d total iterations (%.3f%%)\n", - cycles_elapsed, - total_cycles * nout, - 100.0 * cycles_elapsed / total_cycles / nout - ) - - print("="^twidth * "\n") - for (j, (hall_of_fame, dataset)) in enumerate(zip(hall_of_fames, datasets)) - if nout > 1 - @printf("Best equations for output %d\n", j) - end - equation_strings = string_dominating_pareto_curve( - hall_of_fame, dataset, options; width=width - ) - print(equation_strings * "\n") - print("="^twidth * "\n") - end - return print("Press 'q' and then <enter> to stop execution early.\n") -end - -function load_saved_hall_of_fame(saved_state) - hall_of_fame = saved_state[2] - hall_of_fame = if isa(hall_of_fame, HallOfFame) - [hall_of_fame] - else - hall_of_fame - end - return [copy(hof) for hof in hall_of_fame] -end -load_saved_hall_of_fame(::Nothing)::Nothing = nothing - -function get_population( - pops::Vector{Vector{P}}; out::Int, pop::Int -)::P where {P<:Population} - return pops[out][pop] -end -function get_population(pops::Matrix{P}; out::Int, pop::Int)::P where {P<:Population} - return pops[out, pop] -end -function load_saved_population(saved_state; out::Int, pop::Int) - saved_pop = get_population(saved_state[1]; out=out, pop=pop) - return copy(saved_pop) -end -load_saved_population(::Nothing; kws...) = nothing - -""" - SearchState{PopType,HallOfFameType,WorkerOutputType,ChannelType} - -The state of a search, including the populations, worker outputs, tasks, and -channels. This is used to manage the search and keep track of runtime variables -in a single struct.
-""" -Base.@kwdef struct SearchState{T,L,N<:AbstractExpression{T},WorkerOutputType,ChannelType} - procs::Vector{Int} - we_created_procs::Bool - worker_output::Vector{Vector{WorkerOutputType}} - tasks::Vector{Vector{Task}} - channels::Vector{Vector{ChannelType}} - worker_assignment::WorkerAssignments - task_order::Vector{Tuple{Int,Int}} - halls_of_fame::Vector{HallOfFame{T,L,N}} - last_pops::Vector{Vector{Population{T,L,N}}} - best_sub_pops::Vector{Vector{Population{T,L,N}}} - all_running_search_statistics::Vector{RunningSearchStatistics} - num_evals::Vector{Vector{Float64}} - cycles_remaining::Vector{Int} - cur_maxsizes::Vector{Int} - stdin_reader::StdinReader - record::Base.RefValue{RecordType} -end - -function save_to_file( - dominating, nout::Integer, j::Integer, dataset::Dataset{T,L}, options::Options -) where {T,L} - output_file = options.output_file - if nout > 1 - output_file = output_file * ".out$j" - end - dominating_n = length(dominating) - - complexities = Vector{Int}(undef, dominating_n) - losses = Vector{L}(undef, dominating_n) - strings = Vector{String}(undef, dominating_n) - - Threads.@threads for i in 1:dominating_n - member = dominating[i] - complexities[i] = compute_complexity(member, options) - losses[i] = member.loss - strings[i] = string_tree( - member.tree, options; variable_names=dataset.variable_names - ) - end - - s = let - tmp_io = IOBuffer() - - println(tmp_io, "Complexity,Loss,Equation") - for i in 1:dominating_n - println(tmp_io, "$(complexities[i]),$(losses[i]),\"$(strings[i])\"") - end - - String(take!(tmp_io)) - end - - # Write file twice in case exit in middle of filewrite - for out_file in (output_file, output_file * ".bkup") - open(out_file, "w") do io - write(io, s) - end - end - return nothing -end - -""" - get_cur_maxsize(; options, total_cycles, cycles_remaining) - -For searches where the maxsize gradually increases, this function returns the -current maxsize. -""" -function get_cur_maxsize(; options::Options, total_cycles::Int, cycles_remaining::Int) - cycles_elapsed = total_cycles - cycles_remaining - fraction_elapsed = 1.0f0 * cycles_elapsed / total_cycles - in_warmup_period = fraction_elapsed <= options.warmup_maxsize_by - - if options.warmup_maxsize_by > 0 && in_warmup_period - return 3 + floor( - Int, (options.maxsize - 3) * fraction_elapsed / options.warmup_maxsize_by - ) - else - return options.maxsize - end -end - -function construct_datasets( - X, - y, - weights, - variable_names, - display_variable_names, - y_variable_names, - X_units, - y_units, - extra, - ::Type{L}, -) where {L} - nout = size(y, 1) - return [ - Dataset( - X, - y[j, :], - L; - index=j, - weights=(weights === nothing ? weights : weights[j, :]), - variable_names=variable_names, - display_variable_names=display_variable_names, - y_variable_name=if y_variable_names === nothing - if nout > 1 - "y$(subscriptify(j))" - else - if variable_names === nothing || "y" ∉ variable_names - "y" - else - "target" - end - end - elseif isa(y_variable_names, AbstractVector) - y_variable_names[j] - else - y_variable_names - end, - X_units=X_units, - y_units=isa(y_units, AbstractVector) ? 
y_units[j] : y_units, - extra=extra, - ) for j in 1:nout - ] -end - -function update_hall_of_fame!( - hall_of_fame::HallOfFame, members::Vector{PM}, options::Options -) where {PM<:PopMember} - for member in members - size = compute_complexity(member, options) - valid_size = 0 < size < options.maxsize + MAX_DEGREE - if !valid_size - continue - end - not_filled = !hall_of_fame.exists[size] - better_than_current = member.score < hall_of_fame.members[size].score - if not_filled || better_than_current - hall_of_fame.members[size] = copy(member) - hall_of_fame.exists[size] = true - end - end -end - -end diff --git a/src/SingleIteration.jl b/src/SingleIteration.jl deleted file mode 100644 index ae1b3dad2..000000000 --- a/src/SingleIteration.jl +++ /dev/null @@ -1,174 +0,0 @@ -module SingleIterationModule - -using ADTypes: AutoEnzyme -using DynamicExpressions: AbstractExpression, string_tree, simplify_tree!, combine_operators -using ..UtilsModule: @threads_if -using ..CoreModule: Options, Dataset, RecordType, create_expression -using ..ComplexityModule: compute_complexity -using ..PopMemberModule: generate_reference -using ..PopulationModule: Population, finalize_scores -using ..HallOfFameModule: HallOfFame -using ..AdaptiveParsimonyModule: RunningSearchStatistics -using ..RegularizedEvolutionModule: reg_evol_cycle -using ..LossFunctionsModule: score_func_batched, batch_sample -using ..ConstantOptimizationModule: optimize_constants -using ..RecorderModule: @recorder - -# Cycle through regularized evolution many times, -# printing the fittest equation every 10% through -function s_r_cycle( - dataset::D, - pop::P, - ncycles::Int, - curmaxsize::Int, - running_search_statistics::RunningSearchStatistics; - verbosity::Int=0, - options::Options, - record::RecordType, - dominating=nothing, - idea_database=nothing, -)::Tuple{ - P,HallOfFame{T,L,N},Float64 -} where {T,L,D<:Dataset{T,L},N<:AbstractExpression{T},P<:Population{T,L,N}} - max_temp = 1.0 - min_temp = 0.0 - if !options.annealing - min_temp = max_temp - end - all_temperatures = LinRange(max_temp, min_temp, ncycles) - best_examples_seen = HallOfFame(options, dataset) - num_evals = 0.0 - - # For evaluating on a fixed batch (for batching) - idx = options.batching ? batch_sample(dataset, options) : Int[] - example_tree = create_expression(zero(T), options, dataset) - loss_cache = [(oid=example_tree, score=zero(L)) for member in pop.members] - first_loop = true - - for temperature in all_temperatures - pop, tmp_num_evals = reg_evol_cycle( - dataset, - pop, - temperature, - curmaxsize, - running_search_statistics, - options, - record; - dominating=dominating, - idea_database=idea_database, - ) - num_evals += tmp_num_evals - for (i, member) in enumerate(pop.members) - size = compute_complexity(member, options) - score = if options.batching - oid = member.tree - if loss_cache[i].oid != oid || first_loop - # Evaluate on fixed batch so that we can more accurately - # compare expressions with a batched loss (though the batch - # changes each iteration, and we evaluate on full-batch outside, - # so this is not biased). - _score, _ = score_func_batched( - dataset, member, options; complexity=size, idx=idx - ) - loss_cache[i] = (oid=copy(oid), score=_score) - _score - else - # Already evaluated this particular expression, so just use - # the cached score - loss_cache[i].score - end - else - member.score - end - # TODO: Note that this per-population hall of fame only uses the batched - # loss, and is therefore inaccurate. 
Therefore, some expressions - # may be lost if a very small batch size is used. - # - Could have different batch size for different things (smaller for constant opt) - # - Could just recompute losses here (expensive) - # - Average over a few batches - # - Store multiple expressions in hall of fame - if 0 < size <= options.maxsize && ( - !best_examples_seen.exists[size] || - score < best_examples_seen.members[size].score - ) - best_examples_seen.exists[size] = true - best_examples_seen.members[size] = copy(member) - end - end - first_loop = false - end - - return (pop, best_examples_seen, num_evals) -end - -function optimize_and_simplify_population( - dataset::D, pop::P, options::Options, curmaxsize::Int, record::RecordType -)::Tuple{P,Float64} where {T,L,D<:Dataset{T,L},P<:Population{T,L}} - array_num_evals = zeros(Float64, pop.n) - do_optimization = rand(pop.n) .< options.optimizer_probability - # Note: we have to turn off this threading loop due to Enzyme, since we need - # to manually allocate a new task with a larger stack for Enzyme. - should_thread = !(options.deterministic) && !(isa(options.autodiff_backend, AutoEnzyme)) - @threads_if should_thread for j in 1:(pop.n) - if options.should_simplify - tree = pop.members[j].tree - tree = simplify_tree!(tree, options.operators) - tree = combine_operators(tree, options.operators) - pop.members[j].tree = tree - end - if options.should_optimize_constants && do_optimization[j] - # TODO: Might want to do full batch optimization here? - pop.members[j], array_num_evals[j] = optimize_constants( - dataset, pop.members[j], options - ) - end - end - num_evals = sum(array_num_evals) - pop, tmp_num_evals = finalize_scores(dataset, pop, options) - num_evals += tmp_num_evals - - # Now, we create new references for every member, - # and optionally record which operations occurred. - for j in 1:(pop.n) - old_ref = pop.members[j].ref - new_ref = generate_reference() - pop.members[j].parent = old_ref - pop.members[j].ref = new_ref - - @recorder begin - # Same structure as in RegularizedEvolution.jl, - # except we assume that the record already exists. - @assert haskey(record, "mutations") - member = pop.members[j] - if !haskey(record["mutations"], "$(member.ref)") - record["mutations"]["$(member.ref)"] = RecordType( - "events" => Vector{RecordType}(), - "tree" => string_tree(member.tree, options), - "score" => member.score, - "loss" => member.loss, - "parent" => member.parent, - ) - end - optimize_and_simplify_event = RecordType( - "type" => "tuning", - "time" => time(), - "child" => new_ref, - "mutation" => RecordType( - "type" => - if (do_optimization[j] && options.should_optimize_constants) - "simplification_and_optimization" - else - "simplification" - end, - ), - ) - death_event = RecordType("type" => "death", "time" => time()) - - push!(record["mutations"]["$(old_ref)"]["events"], optimize_and_simplify_event) - push!(record["mutations"]["$(old_ref)"]["events"], death_event) - end - end - return (pop, num_evals) -end - -end diff --git a/src/Utils.jl b/src/Utils.jl index a667b6987..744a19fae 100644 --- a/src/Utils.jl +++ b/src/Utils.jl @@ -1,7 +1,6 @@ """Useful functions to be used throughout the library.""" module UtilsModule -using Printf: @printf using MacroTools: splitdef macro ignore(args...)
end diff --git a/src/deprecates.jl b/src/deprecates.jl deleted file mode 100644 index 54816a408..000000000 --- a/src/deprecates.jl +++ /dev/null @@ -1,97 +0,0 @@ -using Base: @deprecate - -import .HallOfFameModule: calculate_pareto_frontier -import .MutationFunctionsModule: gen_random_tree, gen_random_tree_fixed_size - -@deprecate( - gen_random_tree(length::Int, options::Options, nfeatures::Int, t::Type), - gen_random_tree(length, options, nfeatures, t) -) -@deprecate( - gen_random_tree_fixed_size(node_count::Int, options::Options, nfeatures::Int, t::Type), - gen_random_tree_fixed_size(node_count, options, nfeatures, t) -) - -@deprecate( - calculate_pareto_frontier(X, y, hallOfFame, options; weights=nothing, varMap=nothing), - calculate_pareto_frontier(hallOfFame) -) -@deprecate( - calculate_pareto_frontier(dataset, hallOfFame, options), - calculate_pareto_frontier(hallOfFame) -) - -@deprecate( - EquationSearch(X::AbstractMatrix{T1}, y::AbstractMatrix{T2}; kw...) where {T1,T2}, - equation_search(X, y; kw...) -) - -@deprecate( - EquationSearch(X::AbstractMatrix{T1}, y::AbstractVector{T2}; kw...) where {T1,T2}, - equation_search(X, y; kw...) -) - -@deprecate(EquationSearch(dataset::Dataset; kws...), equation_search(dataset; kws...),) - -@deprecate( - EquationSearch( - X::AbstractMatrix{T}, - y::AbstractMatrix{T}; - niterations::Int=10, - weights::Union{AbstractMatrix{T},AbstractVector{T},Nothing}=nothing, - variable_names::Union{Vector{String},Nothing}=nothing, - options::Options=Options(), - parallelism=:multithreading, - numprocs::Union{Int,Nothing}=nothing, - procs::Union{Vector{Int},Nothing}=nothing, - addprocs_function::Union{Function,Nothing}=nothing, - runtests::Bool=true, - saved_state=nothing, - loss_type::Type=Nothing, - # Deprecated: - multithreaded=nothing, - varMap=nothing, - ) where {T<:DATA_TYPE}, - equation_search( - X, - y; - niterations, - weights, - variable_names, - options, - parallelism, - numprocs, - procs, - addprocs_function, - runtests, - saved_state, - loss_type, - multithreaded, - varMap, - ) -) - -@deprecate( - EquationSearch( - datasets::Vector{D}; - niterations::Int=10, - options::Options=Options(), - parallelism=:multithreading, - numprocs::Union{Int,Nothing}=nothing, - procs::Union{Vector{Int},Nothing}=nothing, - addprocs_function::Union{Function,Nothing}=nothing, - runtests::Bool=true, - saved_state=nothing, - ) where {T<:DATA_TYPE,L<:LOSS_TYPE,D<:Dataset{T,L}}, - equation_search( - datasets; - niterations, - options, - parallelism, - numprocs, - procs, - addprocs_function, - runtests, - saved_state, - ) -) diff --git a/src/precompile.jl b/src/precompile.jl index df736dac0..5b7e68115 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -1,5 +1,4 @@ using PrecompileTools: @compile_workload, @setup_workload - macro maybe_setup_workload(mode, ex) precompile_ex = Expr( :macrocall, Symbol("@setup_workload"), LineNumberNode(@__LINE__), ex @@ -41,11 +40,12 @@ function do_precompilation(::Val{mode}) where {mode} X = randn(T, 3, N) y = start ? randn(T, N) : randn(T, nout, N) @maybe_compile_workload mode begin - options = LibraryAugmentedSymbolicRegression.Options(; + options = LibraryAugmentedSymbolicRegression.LaSROptions(; binary_operators=[+, *, /, -, ^], unary_operators=[sin, cos, exp, log, sqrt, abs], populations=3, population_size=start ? 50 : 12, + tournament_selection_n=6, ncycles_per_iteration=start ? 
30 : 1, - mutation_weights=MutationWeights(; - mutate_constant=1.0, @@ -57,13 +57,12 @@ function do_precompilation(::Val{mode}) where {mode} simplify=1.0, randomize=1.0, do_nothing=1.0, - optimize=PRECOMPILE_OPTIMIZATION ? 1.0 : 0.0, + optimize=1.0, ), fraction_replaced=0.2, fraction_replaced_hof=0.2, define_helper_functions=false, - optimizer_probability=PRECOMPILE_OPTIMIZATION ? 0.05 : 0.0, - should_optimize_constants=PRECOMPILE_OPTIMIZATION, + optimizer_probability=0.05, save_to_file=false, ) state = equation_search( diff --git a/src/scripts/apply_deprecates.py b/src/scripts/apply_deprecates.py deleted file mode 100644 index d3ded6053..000000000 --- a/src/scripts/apply_deprecates.py +++ /dev/null @@ -1,32 +0,0 @@ -from glob import glob -import re - -# Use src/Deprecates.jl to replace all deprecated functions -# in entire codebase, excluding src/Deprecates.jl - -# First, we build the library: -with open("src/Deprecates.jl", "r") as f: - library = {} - for line in f.read().split("\n"): - # If doesn't start with `@deprecate`, skip: - if not line.startswith("@deprecate"): - continue - - # Each line is in the format: - # @deprecate <function_name> <replacement> - function_name = line.split(" ")[1] - replacement = line.split(" ")[2] - library[function_name] = replacement - -# Now, we replace all deprecated functions in src/*.jl: -for fname in glob("**/*.jl"): - if fname == "src/Deprecates.jl": - continue - with open(fname, "r") as f: - contents = f.read() - for function_name in library: - contents = re.sub( - r"\b" + function_name + r"\b", library[function_name], contents - ) - with open(fname, "w") as f: - f.write(contents) diff --git a/src/scripts/fromfile_to_modules.sh b/src/scripts/fromfile_to_modules.sh deleted file mode 100755 index 016fc92d3..000000000 --- a/src/scripts/fromfile_to_modules.sh +++ /dev/null @@ -1,19 +0,0 @@ -#!/bin/bash -# Requires vim-stream (https://github.com/MilesCranmer/vim-stream) - -# The user passes files as arguments: -FILES=$@ - -# Loop through files: -for file in $FILES; do - base=$(basename ${file%.*}) - cat $file | vims -t '%g/^@from/s/@from "\(.\{-}\)\.jl" import/import ..\1Module:/g' -e 'using FromFile' 'dd' -s "Omodule ${base}Module\" 'Go\end' | sed "s/^ $//g" > tmp.jl - mv tmp.jl $file -done - - -# Changes to make: -# - Run this file on everything. -# - Change `import ..` to `import .` in SymbolicRegression.jl -# - Rename modules that have the same name as existing variables: - # - All files are mapped to _{file}.jl, and modules as well.
diff --git a/test/Project.toml b/test/Project.toml index fb83b7c8d..e2fe8e96c 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,32 +1,23 @@ [deps] -Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" -Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e" -DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" -DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" -Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" -DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b" -DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821" -Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" -LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" -MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" -MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" -MLJTestInterface = "72560011-54dd-4dc2-94f3-c5de45b75ecd" -Optim = "429524aa-4258-5aef-a3af-852621145aeb" -Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" -Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" -SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" -Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" LibraryAugmentedSymbolicRegression = "158930c3-947c-4174-974b-74b39e64a28f" -SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a" TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" +JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SymbolicRegression = "8254be44-1295-4e6a-a16d-46603ac705cb" +DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b" +Optim = "429524aa-4258-5aef-a3af-852621145aeb" +LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" + +[compat] +Aqua = "0.8" +Random = "1.11.0" +SymbolicRegression = "1" +Optim = "1.10.0" +LineSearches = "7.3.0" [preferences.LibraryAugmentedSymbolicRegression] instability_check = "error" \ No newline at end of file diff --git a/test/manual_distributed.jl b/test/manual_distributed.jl deleted file mode 100644 index 2bfe6aeeb..000000000 --- a/test/manual_distributed.jl +++ /dev/null @@ -1,31 +0,0 @@ -include("test_params.jl") - -using Distributed -procs = addprocs(2) -using Test, Pkg -project_path = splitdir(Pkg.project().path)[1] -@everywhere procs begin - Base.MainInclude.eval( - quote - using Pkg - Pkg.activate($$project_path) - end, - ) -end -@everywhere using LibraryAugmentedSymbolicRegression -@everywhere _inv(x::Float32)::Float32 = 1.0f0 / x -X = rand(Float32, 5, 100) .+ 1 -y = 1.2f0 .+ 2 ./ X[3, :] - -options = LibraryAugmentedSymbolicRegression.Options(; - default_params..., binary_operators=(+, *), unary_operators=(_inv,), populations=8 -) -hallOfFame = equation_search( - X, y; niterations=8, options=options, parallelism=:multiprocessing, procs=procs -) -rmprocs(procs) - -dominating = calculate_pareto_frontier(hallOfFame) -best = dominating[end] -# Test the score -@test best.loss < maximum_residual / 10 diff --git a/test/runtests.jl b/test/runtests.jl index 0a484650d..844639de9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,180 +2,26 @@ using TestItems: @testitem using TestItemRunner: @run_package_tests ENV["SYMBOLIC_REGRESSION_TEST"] = "true" -tags_to_run = let t =
-tags_to_run = let t = get(ENV, "SYMBOLIC_REGRESSION_TEST_SUITE", "part1,part2,part3")
+tags_to_run = let t = get(ENV, "SYMBOLIC_REGRESSION_TEST_SUITE", "online,offline")
     t = split(t, ",")
     t = map(Symbol, t)
     t
 end
-
 @eval @run_package_tests filter = ti -> !isdisjoint(ti.tags, $tags_to_run) verbose = true
 
-# TODO: This is a very slow test
-@testitem "Test custom operators and additional types" tags = [:part2] begin
-    include("test_operators.jl")
-end
-
-@testitem "Test tree construction and scoring" tags = [:part3] begin
-    include("test_tree_construction.jl")
-end
-
-include("test_graph_nodes.jl")
-
-@testitem "Test SymbolicUtils interface" tags = [:part1] begin
-    include("test_symbolic_utils.jl")
-end
-
-@testitem "Test constraints interface" tags = [:part2] begin
-    include("test_constraints.jl")
-end
-
-@testitem "Test custom losses" tags = [:part1] begin
-    include("test_losses.jl")
-end
-
-@testitem "Test derivatives" tags = [:part2] begin
-    include("test_derivatives.jl")
-end
-include("test_expression_derivatives.jl")
-
-@testitem "Test simplification" tags = [:part3] begin
-    include("test_simplification.jl")
-end
-
-@testitem "Test printing" tags = [:part1] begin
-    include("test_print.jl")
-end
-
-@testitem "Test validity of expression evaluation" tags = [:part2] begin
-    include("test_evaluation.jl")
-end
-
-@testitem "Test turbo mode with NaN" tags = [:part3] begin
-    include("test_turbo_nan.jl")
-end
-
-@testitem "Test validity of integer expression evaluation" tags = [:part1] begin
-    include("test_integer_evaluation.jl")
-end
-
-@testitem "Test tournament selection" tags = [:part2] begin
-    include("test_prob_pick_first.jl")
-end
-
-@testitem "Test crossover mutation" tags = [:part3] begin
-    include("test_crossover.jl")
-end
-
-# TODO: This is another very slow test
-@testitem "Test NaN detection in evaluator" tags = [:part1] begin
-    include("test_nan_detection.jl")
-end
-
-@testitem "Test nested constraint checking" tags = [:part2] begin
-    include("test_nested_constraints.jl")
-end
-
-@testitem "Test complexity evaluation" tags = [:part3] begin
-    include("test_complexity.jl")
-end
-
-@testitem "Test options" tags = [:part1] begin
-    include("test_options.jl")
-end
-
-@testitem "Test hash of tree" tags = [:part2] begin
-    include("test_hash.jl")
-end
-
-@testitem "Test migration" tags = [:part3] begin
-    include("test_migration.jl")
-end
-
-@testitem "Test deprecated options" tags = [:part1] begin
-    include("test_deprecation.jl")
-end
-
-@testitem "Test optimization mutation" tags = [:part2] begin
-    include("test_optimizer_mutation.jl")
-end
-
-@testitem "Test RunningSearchStatistics" tags = [:part3] begin
-    include("test_search_statistics.jl")
-end
-
-@testitem "Test utils" tags = [:part1] begin
-    include("test_utils.jl")
-end
-
-include("test_units.jl")
-
-@testitem "Dataset" tags = [:part3] begin
-    include("test_dataset.jl")
-end
-
-include("test_mixed.jl")
-
-@testitem "Testing fast-cycle and custom variable names" tags = [:part2] begin
-    include("test_fast_cycle.jl")
-end
-
-@testitem "Testing whether we can stop based on clock time." tags = [:part3] begin
-    include("test_stop_on_clock.jl")
-end
-
-@testitem "Running README example." tags = [:part1] begin
-    ENV["SYMBOLIC_REGRESSION_IS_TESTING"] = "true"
-    include("../example.jl")
+@testitem "Test expression parser" tags = [:online] begin
+    include("test_lasr_parser.jl")
 end
 
-# TODO: This is the slowest test.
-@testitem "Running parameterized function example." tags = [:part2] begin
-    ENV["SYMBOLIC_REGRESSION_IS_TESTING"] = "true"
-    include("../examples/parameterized_function.jl")
-end
-
-@testitem "Testing whether the recorder works." tags = [:part3] begin
-    include("test_recorder.jl")
-end
-
-@testitem "Testing whether deterministic mode works." tags = [:part1] begin
-    include("test_deterministic.jl")
-end
-
-@testitem "Testing whether early stop criteria works." tags = [:part2] begin
-    include("test_early_stop.jl")
-end
-
-include("test_mlj.jl")
-
-@testitem "Testing whether we can move operators to workers." tags = [:part1] begin
-    include("test_custom_operators_multiprocessing.jl")
-end
-
-@testitem "Test whether the precompilation script works." tags = [:part2] begin
+@testitem "Test whether the precompilation script works." tags = [:online] begin
     include("test_precompilation.jl")
 end
 
-@testitem "Test whether custom objectives work." tags = [:part3] begin
-    include("test_custom_objectives.jl")
-end
-
-@testitem "Test abstract numbers" tags = [:part1] begin
-    include("test_abstract_numbers.jl")
-end
-
-include("test_pretty_printing.jl")
-include("test_expression_builder.jl")
-
-@testitem "Aqua tests" tags = [:part2, :aqua] begin
+@testitem "Aqua tests" tags = [:online, :aqua] begin
     include("test_aqua.jl")
 end
 
-@testitem "JET tests" tags = [:part1, :jet] begin
+@testitem "JET tests" tags = [:online, :jet] begin
     test_jet_file = joinpath((@__DIR__), "test_jet.jl")
     run(`$(Base.julia_cmd()) --startup-file=no $test_jet_file`)
 end
-
-@testitem "LLM Integration tests" tags = [:part3, :llm] begin
-    include("test_lasr_integration.jl")
-end
diff --git a/test/test_abstract_numbers.jl b/test/test_abstract_numbers.jl
deleted file mode 100644
index 373cea5a7..000000000
--- a/test/test_abstract_numbers.jl
+++ /dev/null
@@ -1,37 +0,0 @@
-using LibraryAugmentedSymbolicRegression
-using Random
-include("test_params.jl")
-
-get_base_type(::Type{<:Complex{BT}}) where {BT} = BT
-
-for T in (ComplexF16, ComplexF32, ComplexF64)
-    L = get_base_type(T)
-    @testset "Test search with $T type" begin
-        X = randn(MersenneTwister(0), T, 1, 100)
-        y = @. (2 - 0.5im) * cos((1 + 1im) * X[1, :]) |> T
-
-        early_stop(loss::L, c) where {L} = ((loss <= L(1e-2)) && (c <= 15))
-
-        options = LibraryAugmentedSymbolicRegression.Options(;
-            binary_operators=[+, *, -, /],
-            unary_operators=[cos],
-            populations=20,
-            early_stop_condition=early_stop,
-            elementwise_loss=(prediction, target) -> abs2(prediction - target),
-        )
-
-        dataset = Dataset(X, y, L)
-        hof = if T == ComplexF16
-            equation_search([dataset]; options=options, niterations=1_000_000_000)
-        else
-            # Should automatically find correct type:
-            equation_search(X, y; options=options, niterations=1_000_000_000)
-        end
-
-        dominating = calculate_pareto_frontier(hof)
-        @test typeof(dominating[end].loss) == L
-        output, _ = eval_tree_array(dominating[end].tree, X, options)
-        @test typeof(output) <: AbstractArray{T}
-        @test sum(abs2, output .- y) / length(output) <= L(1e-2)
-    end
-end
diff --git a/test/test_aqua.jl b/test/test_aqua.jl
index 6c8ef5442..5ed4a216d 100644
--- a/test/test_aqua.jl
+++ b/test/test_aqua.jl
@@ -3,4 +3,4 @@ using Aqua
 
 Aqua.test_all(LibraryAugmentedSymbolicRegression; ambiguities=false)
 
-VERSION >= v"1.9" && Aqua.test_ambiguities(LibraryAugmentedSymbolicRegression)
+Aqua.test_ambiguities(LibraryAugmentedSymbolicRegression)
diff --git a/test/test_backwards_compat.jl b/test/test_backwards_compat.jl
new file mode 100644
index 000000000..34c30db7a
--- /dev/null
+++ b/test/test_backwards_compat.jl
@@ -0,0 +1,15 @@
+# These are some "stress tests" copied over from the SymbolicRegression.jl package.
+# They are meant to test the backwards compatibility of the LaSROptions struct.
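+#
+# Below is a minimal, illustrative sketch of one such test (an assumption about what
+# the copied-over suite will look like, not the suite itself). It assumes LaSROptions
+# accepts SymbolicRegression.jl-style keyword arguments and forwards them unchanged.
+using LibraryAugmentedSymbolicRegression: LaSROptions
+using Test
+
+options = LaSROptions(;
+    binary_operators=(+, *), unary_operators=(cos,), populations=4, maxsize=16
+)
+@test length(options.operators.binops) == 2
+@test length(options.operators.unaops) == 1
+@test options.maxsize == 16
diff --git a/test/test_complexity.jl b/test/test_complexity.jl
deleted file mode 100644
index dc63aa302..000000000
--- a/test/test_complexity.jl
+++ /dev/null
@@ -1,54 +0,0 @@
-println("Testing custom complexities.")
-using LibraryAugmentedSymbolicRegression, Test
-
-x1, x2, x3 = Node("x1"), Node("x2"), Node("x3")
-
-# First, test regular complexities:
-function make_options(; kw...)
-    return Options(; binary_operators=(+, -, *, /, ^), unary_operators=(cos, sin), kw...)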
-end -options = make_options() -@extend_operators options -tree = sin((x1 + x2 + x3)^2.3) -@test compute_complexity(tree, options) == 8 - -options = make_options(; complexity_of_operators=[sin => 3]) -@test compute_complexity(tree, options) == 10 -options = make_options(; complexity_of_operators=[sin => 3, (+) => 2]) -@test compute_complexity(tree, options) == 12 - -# Real numbers: -options = make_options(; complexity_of_operators=[sin => 3, (+) => 2, (^) => 3.2]) -@test compute_complexity(tree, options) == round(Int, 12 + (3.2 - 1)) - -# Now, test other things, like variables and constants: -options = make_options(; - complexity_of_operators=[sin => 3, (+) => 2], complexity_of_variables=2 -) -@test compute_complexity(tree, options) == 12 + 3 * 1 -options = make_options(; - complexity_of_operators=[sin => 3, (+) => 2], - complexity_of_variables=2, - complexity_of_constants=2, -) -@test compute_complexity(tree, options) == 12 + 3 * 1 + 1 -options = make_options(; - complexity_of_operators=[sin => 3, (+) => 2], - complexity_of_variables=2, - complexity_of_constants=2.6, -) -@test compute_complexity(tree, options) == 12 + 3 * 1 + 1 + 1 - -# Custom variables -options = make_options(; - complexity_of_variables=[1, 2, 3], complexity_of_operators=[(+) => 5, (*) => 2] -) -x1, x2, x3 = [Node{Float64}(; feature=i) for i in 1:3] -tree = x1 + x2 * x3 -@test compute_complexity(tree, options) == 1 + 5 + 2 + 2 + 3 -options = make_options(; - complexity_of_variables=2, complexity_of_operators=[(+) => 5, (*) => 2] -) -@test compute_complexity(tree, options) == 2 + 5 + 2 + 2 + 2 - -println("Passed.") diff --git a/test/test_constraints.jl b/test/test_constraints.jl deleted file mode 100644 index 90ed581f1..000000000 --- a/test/test_constraints.jl +++ /dev/null @@ -1,57 +0,0 @@ -using DynamicExpressions: count_depth -using LibraryAugmentedSymbolicRegression -using LibraryAugmentedSymbolicRegression: check_constraints -include("test_params.jl") - -_inv(x) = 1 / x -options = Options(; - default_params..., - binary_operators=(+, *, ^, /, greater), - unary_operators=(_inv,), - constraints=(_inv => 4,), - populations=4, -) -@extend_operators options -tree = Node(5, (^)(Node(; val=3.0) * Node(1, Node("x1")), 2.0), Node(; val=-1.2)) -violating_tree = Node(1, tree) - -@test check_constraints(tree, options) == true -@test check_constraints(violating_tree, options) == false - -# Test complexity constraints: -options = Options(; binary_operators=(+, *), maxsize=5) -@extend_operators options -x1, x2, x3 = [Node(; feature=i) for i in 1:3] -tree = x1 + x2 * x3 -violating_tree = 5.1 * tree -@test check_constraints(tree, options) == true -@test check_constraints(violating_tree, options) == false - -# Also test for custom complexities: -options = Options(; binary_operators=(+, *), maxsize=5, complexity_of_operators=[(*) => 3]) -@test check_constraints(tree, options) == false -options = Options(; binary_operators=(+, *), maxsize=5, complexity_of_operators=[(*) => 0]) -@test check_constraints(violating_tree, options) == true - -# Test for depth constraints: -options = Options(; - binary_operators=(+, *), unary_operators=(cos,), maxsize=100, maxdepth=3 -) -@extend_operators options -x1, x2, x3 = [Node(; feature=i) for i in 1:3] - -tree = (x1 + x2) + (x3 + x1) -@test count_depth(tree) == 3 -@test check_constraints(tree, options) == true - -tree = (x1 + x2) + (x3 + x1) * x1 -@test count_depth(tree) == 4 -@test check_constraints(tree, options) == false - -tree = cos(cos(x1)) -@test count_depth(tree) == 3 -@test 
check_constraints(tree, options) == true - -tree = cos(cos(cos(x1))) -@test count_depth(tree) == 4 -@test check_constraints(tree, options) == false diff --git a/test/test_crossover.jl b/test/test_crossover.jl deleted file mode 100644 index 2f72fc471..000000000 --- a/test/test_crossover.jl +++ /dev/null @@ -1,47 +0,0 @@ -println("Testing crossover function.") -using LibraryAugmentedSymbolicRegression -using LibraryAugmentedSymbolicRegression: crossover_trees -include("test_params.jl") - -options = LibraryAugmentedSymbolicRegression.Options(; - default_params..., - binary_operators=(+, *, /, -), - unary_operators=(cos, exp), - populations=8, -) -tree1 = cos(Node("x1")) + (3.0f0 + Node("x2")) -tree2 = exp(Node("x1") - Node("x2") * Node("x2")) + 10.0f0 * Node("x3") - -# See if we can observe operators flipping sides: -cos_flip_to_tree2 = false -exp_flip_to_tree1 = false -swapped_cos_with_exp = false -for i in 1:1000 - child_tree1, child_tree2 = crossover_trees(tree1, tree2) - if occursin("cos", repr(child_tree2)) - # Moved cosine to tree2 - global cos_flip_to_tree2 = true - end - if occursin("exp", repr(child_tree1)) - # Moved exp to tree1 - global exp_flip_to_tree1 = true - end - if occursin("cos", repr(child_tree2)) && occursin("exp", repr(child_tree1)) - global swapped_cos_with_exp = true - # Moved exp with cos - @assert !occursin("cos", repr(child_tree1)) - @assert !occursin("exp", repr(child_tree2)) - end - - # Check that exact same operators, variables, numbers before and after: - rep_tree_final = sort([a for a in repr(child_tree1) * repr(child_tree2)]) - rep_tree_final = strip(String(rep_tree_final), ['(', ')', ' ']) - rep_tree_initial = sort([a for a in repr(tree1) * repr(tree2)]) - rep_tree_initial = strip(String(rep_tree_initial), ['(', ')', ' ']) - @test rep_tree_final == rep_tree_initial -end - -@test cos_flip_to_tree2 -@test exp_flip_to_tree1 -@test swapped_cos_with_exp -println("Passed.") diff --git a/test/test_custom_objectives.jl b/test/test_custom_objectives.jl deleted file mode 100644 index 9b3a43760..000000000 --- a/test/test_custom_objectives.jl +++ /dev/null @@ -1,52 +0,0 @@ -using LibraryAugmentedSymbolicRegression -include("test_params.jl") - -def = quote - function my_custom_loss( - tree::$(AbstractExpressionNode){T}, dataset::$(Dataset){T}, options::$(Options) - ) where {T} - # We multiply the tree by 2.0: - tree = $(Node)(1, tree, $(Node)(T; val=2.0)) - out, completed = $(eval_tree_array)(tree, dataset.X, options) - if !completed - return T(Inf) - end - return sum(abs, out .- dataset.y) - end -end - -# TODO: Required for workers as they assume the function is defined in the Main module -if (@__MODULE__) != Core.Main - Core.eval(Core.Main, def) - eval(:(using Main: my_custom_loss)) -else - eval(def) -end - -options = Options(; - binary_operators=[*, /, +, -], - unary_operators=[cos, sin], - loss_function=my_custom_loss, - elementwise_loss=nothing, - maxsize=10, - early_stop_condition=1e-10, - adaptive_parsimony_scaling=100.0, - mutation_weights=MutationWeights(; optimize=0.01), -) - -@test options.should_simplify == false - -X = rand(2, 100) .* 10 -y = X[1, :] .+ X[2, :] - -# The best tree should be 0.5 * (x1 + x2), since the custom loss function -# multiplies the tree by 2.0. 
- -hall_of_fame = equation_search( - X, y; niterations=100, options=options, parallelism=:multiprocessing, numprocs=1 -) -dominating = calculate_pareto_frontier(hall_of_fame) - -testX = rand(2, 100) .* 10 -expected_y = 0.5 .* (testX[1, :] .+ testX[2, :]) -@test eval_tree_array(dominating[end].tree, testX, options)[1] ≈ expected_y atol = 1e-5 diff --git a/test/test_custom_operators.jl b/test/test_custom_operators.jl deleted file mode 100644 index d0d6d4af0..000000000 --- a/test/test_custom_operators.jl +++ /dev/null @@ -1,76 +0,0 @@ -using LibraryAugmentedSymbolicRegression -using Random - -# Test that we can work with custom operators: -function op1(x::T, y::T)::T where {T<:Real} - return x + y -end -function op2(x::T, y::T)::T where {T<:Real} - return x^2 + 1 / ((y)^2 + 0.1) -end -function op3(x::T)::T where {T<:Real} - return sin(x) + cos(x) -end -local operators, tree -options = Options(; binary_operators=[op1, op2], unary_operators=[op3]) -@extend_operators options -x1 = Node(; feature=1) -x2 = Node(; feature=2) -tree = op1(op2(x1, x2), op3(x1)) -@test repr(tree) == "op1(op2(x1, x2), op3(x1))" -# Test evaluation: -X = randn(MersenneTwister(0), Float32, 2, 10); -@test tree(X, options) ≈ ((x1, x2) -> op1(op2(x1, x2), op3(x1))).(X[1, :], X[2, :]) - -# Now, test that we can work with operators defined in modules -module A - -using LibraryAugmentedSymbolicRegression -using Random - -function my_func_a(x::T, y::T) where {T<:Real} - return x^2 * y -end - -function my_func_b(x::T) where {T<:Real} - return x^3 -end - -options = Options(; binary_operators=[my_func_a], unary_operators=[my_func_b]) -@extend_operators options - -function create_and_eval_tree() - x1 = Node(Float64; feature=1) - x2 = Node(Float64; feature=2) - c1 = Node(Float64; val=0.2) - tree = my_func_a(my_func_a(x2, 0.2), my_func_b(x1)) - func = (x1, x2) -> my_func_a(my_func_a(x2, 0.2), my_func_b(x1)) - X = randn(MersenneTwister(0), 2, 20) - return tree(X, options), func.(X[1, :], X[2, :]) -end - -end - -using .A: create_and_eval_tree -prediction, truth = create_and_eval_tree() -@test prediction ≈ truth - -# Now, test that we can work with operators defined in other modules -module B - -my_func_c(x::T, y::T) where {T<:Real} = x * y + T(0.3) -my_func_d(x::T) where {T<:Real} = x / (abs(x)^T(0.2) + 0.1) - -end - -using .B: my_func_c, my_func_d -options = Options(; binary_operators=[my_func_c], unary_operators=[my_func_d]) -@extend_operators options - -x1 = Node(Float64; feature=1) -x2 = Node(Float64; feature=2) -c1 = Node(Float64; val=0.2) -tree = my_func_c(my_func_c(x2, 0.2), my_func_d(x1)) -func = (x1, x2) -> my_func_c(my_func_c(x2, 0.2), my_func_d(x1)) -X = randn(MersenneTwister(0), 2, 20) -@test tree(X, options) ≈ func.(X[1, :], X[2, :]) diff --git a/test/test_custom_operators_multiprocessing.jl b/test/test_custom_operators_multiprocessing.jl deleted file mode 100644 index 1cb481f88..000000000 --- a/test/test_custom_operators_multiprocessing.jl +++ /dev/null @@ -1,45 +0,0 @@ -using LibraryAugmentedSymbolicRegression - -defs = quote - _plus(x, y) = x + y - _mult(x, y) = x * y - _div(x, y) = x / y - _min(x, y) = x - y - _cos(x) = cos(x) - _exp(x) = exp(x) - early_stop(loss, c) = ((loss <= 1e-10) && (c <= 10)) - my_loss(x, y, w) = abs(x - y)^2 * w -end - -# This is needed as workers are initialized in `Core.Main`! 
-if (@__MODULE__) != Core.Main - Core.eval(Core.Main, defs) - eval(:(using Main: _plus, _mult, _div, _min, _cos, _exp, early_stop, my_loss)) -else - eval(defs) -end - -X = randn(Float32, 5, 100) -y = _mult.(2, _cos.(X[4, :])) + _mult.(X[1, :], X[1, :]) - -options = LibraryAugmentedSymbolicRegression.Options(; - binary_operators=(_plus, _mult, _div, _min), - unary_operators=(_cos, _exp), - populations=20, - early_stop_condition=early_stop, - elementwise_loss=my_loss, -) - -hof = equation_search( - X, - y; - weights=ones(Float32, 100), - options=options, - niterations=1_000_000_000, - numprocs=2, - parallelism=:multiprocessing, -) - -@test any( - early_stop(member.loss, count_nodes(member.tree)) for member in hof.members[hof.exists] -) diff --git a/test/test_dataset.jl b/test/test_dataset.jl deleted file mode 100644 index ededef6c3..000000000 --- a/test/test_dataset.jl +++ /dev/null @@ -1,16 +0,0 @@ -using LibraryAugmentedSymbolicRegression -using DispatchDoctor: allow_unstable - -@testset "Dataset construction" begin - # Promotion of types: - dataset = Dataset(randn(3, 32), randn(Float32, 32); weights=randn(Float32, 32)) - @test typeof(dataset.y) == Array{Float64,1} - @test typeof(dataset.weights) == Array{Float64,1} -end - -@testset "With deprecated kwarg" begin - dataset = allow_unstable() do - Dataset(randn(ComplexF32, 3, 32), randn(ComplexF32, 32); loss_type=Float64) - end - @test dataset isa Dataset{ComplexF32,Float64} -end diff --git a/test/test_deprecation.jl b/test/test_deprecation.jl deleted file mode 100644 index 036925c0c..000000000 --- a/test/test_deprecation.jl +++ /dev/null @@ -1,17 +0,0 @@ -using LibraryAugmentedSymbolicRegression - -# Deprecated kwargs should still work: -options = Options(; - mutationWeights=MutationWeights(; mutate_constant=0.0), - fractionReplacedHof=0.01f0, - shouldOptimizeConstants=true, - loss=L2DistLoss(), -) - -@test options.mutation_weights.mutate_constant == 0.0 -@test options.fraction_replaced_hof == 0.01f0 -@test options.should_optimize_constants == true -@test options.elementwise_loss == L2DistLoss() - -options = Options(; mutationWeights=[1.0 for i in 1:8]) -@test options.mutation_weights.add_node == 1.0 diff --git a/test/test_derivatives.jl b/test/test_derivatives.jl deleted file mode 100644 index 8ee4d1cc9..000000000 --- a/test/test_derivatives.jl +++ /dev/null @@ -1,149 +0,0 @@ -using LibraryAugmentedSymbolicRegression -using LibraryAugmentedSymbolicRegression: eval_diff_tree_array, eval_grad_tree_array -using Random -using Zygote -using LinearAlgebra - -seed = 0 -pow_abs2(x::T, y::T) where {T<:Real} = abs(x)^y -custom_cos(x::T) where {T<:Real} = cos(x)^2 - -equation1(x1, x2, x3) = x1 + x2 + x3 + 3.2 -equation2(x1, x2, x3) = pow_abs2(x1, x2) + x3 + custom_cos(1.0 + x3) + 3.0 / x1 -function equation3(x1, x2, x3) - return ( - ((x2 + x2) * ((-0.5982493 / pow_abs2(x1, x2)) / -0.54734415)) + ( - sin( - custom_cos( - sin(1.2926733 - 1.6606787) / sin( - ((0.14577048 * x1) + ((0.111149654 + x1) - -0.8298334)) - -1.2071426 - ), - ) * (custom_cos(x3 - 2.3201916) + ((x1 - (x1 * x2)) / x2)), - ) / (0.14854191 - ((custom_cos(x2) * -1.6047639) - 0.023943262)) - ) - ) -end - -nx1 = Node("x1") -nx2 = Node("x2") -nx3 = Node("x3") - -# Equations to test gradients on: - -function array_test(ar1, ar2; rtol=0.1) - return isapprox(ar1, ar2; rtol=rtol) -end - -for type in [Float16, Float32, Float64] - println("Testing derivatives with respect to variables, with type=$(type).") - rng = MersenneTwister(seed) - nfeatures = 3 - N = 100 - - X = rand(rng, type, 
nfeatures, N) * 5 - - options = Options(; - binary_operators=(+, *, -, /, pow_abs2), unary_operators=(custom_cos, exp, sin) - ) - @extend_operators options - - for j in 1:3 - equation = [equation1, equation2, equation3][j] - if type == Float16 && j == 3 - # Numerical precision hurts this comparison too much - continue - end - - tree = convert(Node{type}, equation(nx1, nx2, nx3)) - predicted_output = eval_tree_array(tree, X, options)[1] - true_output = equation.([X[i, :] for i in 1:nfeatures]...) - true_output = convert(AbstractArray{type}, true_output) - - # First, check if the predictions are approximately equal: - @test array_test(predicted_output, true_output) - - true_grad = gradient( - (x1, x2, x3) -> sum(equation.(x1, x2, x3)), [X[i, :] for i in 1:nfeatures]... - ) - # Convert tuple of vectors to matrix: - true_grad = reduce(hcat, true_grad)' - predicted_grad = eval_grad_tree_array(tree, X, options; variable=true)[2] - predicted_grad2 = - reduce( - hcat, [eval_diff_tree_array(tree, X, options, i)[2] for i in 1:nfeatures] - )' - - # Print largest difference between predicted_grad, true_grad: - @test array_test(predicted_grad, true_grad) - @test array_test(predicted_grad2, true_grad) - - # Make sure that the array_test actually works: - @test !array_test(predicted_grad .* 0, true_grad) - @test !array_test(predicted_grad2 .* 0, true_grad) - end - println("Done.") - println("Testing derivatives with respect to constants, with type=$(type).") - - # Test gradient with respect to constants: - equation4(x1, x2, x3) = 3.2f0 * x1 - # The gradient should be: (C * x1) => x1 is gradient with respect to C. - tree = equation4(nx1, nx2, nx3) - tree = convert(Node{type}, tree) - predicted_grad = eval_grad_tree_array(tree, X, options; variable=false)[2] - @test array_test(predicted_grad[1, :], X[1, :]) - - # More complex expression: - const_value = 2.1f0 - const_value2 = -3.2f0 - - function equation5(x1, x2, x3) - return pow_abs2(x1, x2) + x3 + custom_cos(const_value + x3) + const_value2 / x1 - end - function equation5_with_const(c1, c2, x1, x2, x3) - return pow_abs2(x1, x2) + x3 + custom_cos(c1 + x3) + c2 / x1 - end - - tree = equation5(nx1, nx2, nx3) - tree = convert(Node{type}, tree) - - # Use zygote to explicitly find the gradient: - true_grad = gradient( - (c1, c2, x1, x2, x3) -> sum(equation5_with_const.(c1, c2, x1, x2, x3)), - fill(const_value, N), - fill(const_value2, N), - [X[i, :] for i in 1:nfeatures]..., - )[1:2] - true_grad = reduce(hcat, true_grad)' - predicted_grad = eval_grad_tree_array(tree, X, options; variable=false)[2] - - @test array_test(predicted_grad, true_grad) - println("Done.") -end - -println("Testing NodeIndex.") - -using LibraryAugmentedSymbolicRegression: get_scalar_constants, NodeIndex, index_constants - -options = Options(; - binary_operators=(+, *, -, /, pow_abs2), unary_operators=(custom_cos, exp, sin) -) -@extend_operators options -tree = equation3(nx1, nx2, nx3) - -"""Check whether the ordering of constant_list is the same as the ordering of node_index.""" -function check_tree( - tree::AbstractExpressionNode, node_index::NodeIndex, constant_list::AbstractVector -) - if tree.degree == 0 - (!tree.constant) || tree.val == constant_list[node_index.val::UInt16] - elseif tree.degree == 1 - check_tree(tree.l, node_index.l, constant_list) - else - check_tree(tree.l, node_index.l, constant_list) && - check_tree(tree.r, node_index.r, constant_list) - end -end - -@test check_tree(tree, index_constants(tree), first(get_scalar_constants(tree))) - -println("Done.") diff --git 
a/test/test_deterministic.jl b/test/test_deterministic.jl deleted file mode 100644 index deee1e8f7..000000000 --- a/test/test_deterministic.jl +++ /dev/null @@ -1,33 +0,0 @@ -using LibraryAugmentedSymbolicRegression -using Random - -X = 2 .* randn(MersenneTwister(0), Float32, 2, 1000) -y = 3 * cos.(X[2, :]) + X[1, :] .^ 2 .- 2 - -options = LibraryAugmentedSymbolicRegression.Options(; - binary_operators=(+, *, /, -), - unary_operators=(cos,), - crossover_probability=0.0, # required for recording, as not set up to track crossovers. - max_evals=10000, - deterministic=true, - seed=0, - verbosity=0, - progress=false, -) - -all_outputs = [] -for i in 1:2 - hall_of_fame = equation_search( - X, - y; - niterations=5, - options=options, - parallelism=:serial, - v_dim_out=Val(1), - return_state=Val(false), - ) - dominating = calculate_pareto_frontier(hall_of_fame) - push!(all_outputs, dominating[end].tree) -end - -@test string(all_outputs[1]) == string(all_outputs[2]) diff --git a/test/test_early_stop.jl b/test/test_early_stop.jl deleted file mode 100644 index 8f64689b3..000000000 --- a/test/test_early_stop.jl +++ /dev/null @@ -1,19 +0,0 @@ -using LibraryAugmentedSymbolicRegression - -X = randn(Float32, 5, 100) -y = 2 * cos.(X[4, :]) + X[1, :] .^ 2 - -early_stop(loss, c) = ((loss <= 1e-10) && (c <= 10)) - -options = LibraryAugmentedSymbolicRegression.Options(; - binary_operators=(+, *, /, -), - unary_operators=(cos, exp), - populations=20, - early_stop_condition=early_stop, -) - -hof = equation_search(X, y; options=options, niterations=1_000_000_000) - -@test any( - early_stop(member.loss, count_nodes(member.tree)) for member in hof.members[hof.exists] -) diff --git a/test/test_evaluation.jl b/test/test_evaluation.jl deleted file mode 100644 index 0c38f0e58..000000000 --- a/test/test_evaluation.jl +++ /dev/null @@ -1,74 +0,0 @@ -using LibraryAugmentedSymbolicRegression -using Random -include("test_params.jl") - -# Test simple evaluations: -options = Options(; - default_params..., binary_operators=(+, *, /, -), unary_operators=(cos, sin) -) - -# Here, we unittest the fast function evaluation scheme -# We need to trigger all possible fused functions, with all their logic. -# These are as follows: - -## We fuse (and compile) the following: -## - op(op2(x, y)), where x, y, z are constants or variables. -## - op(op2(x)), where x is a constant or variable. -## - op(x), for any x. -## We fuse (and compile) the following: -## - op(x, y), where x, y are constants or variables. -## - op(x, y), where x is a constant or variable but y is not. -## - op(x, y), where y is a constant or variable but x is not. 
-## - op(x, y), for any x or y -for fnc in [ - # deg2_l0_r0_eval - (x1, x2, x3) -> x1 * x2, - (x1, x2, x3) -> x1 * 3.0f0, - (x1, x2, x3) -> 3.0f0 * x2, - (((x1, x2, x3) -> 3.0f0 * 6.0f0), ((x1, x2, x3) -> Node(; val=3.0f0) * 6.0f0)), - # deg2_l0_eval - (x1, x2, x3) -> x1 * sin(x2), - (x1, x2, x3) -> 3.0f0 * sin(x2), - - # deg2_r0_eval - (x1, x2, x3) -> sin(x1) * x2, - (x1, x2, x3) -> sin(x1) * 3.0f0, - - # deg1_l2_ll0_lr0_eval - (x1, x2, x3) -> cos(x1 * x2), - (x1, x2, x3) -> cos(x1 * 3.0f0), - (x1, x2, x3) -> cos(3.0f0 * x2), - ( - ((x1, x2, x3) -> cos(3.0f0 * -0.5f0)), - ((x1, x2, x3) -> cos(Node(; val=3.0f0) * -0.5f0)), - ), - - # deg1_l1_ll0_eval - (x1, x2, x3) -> cos(sin(x1)), - (((x1, x2, x3) -> cos(sin(3.0f0))), ((x1, x2, x3) -> cos(sin(Node(; val=3.0f0))))), - - # everything else: - (x1, x2, x3) -> (sin(cos(sin(cos(x1) * x3) * 3.0f0) * -0.5f0) + 2.0f0) * 5.0f0, -] - - # check if fnc is tuple - if typeof(fnc) <: Tuple - realfnc = fnc[1] - nodefnc = fnc[2] - else - realfnc = fnc - nodefnc = fnc - end - - global tree = nodefnc(Node("x1"), Node("x2"), Node("x3")) - - N = 100 - nfeatures = 3 - X = randn(MersenneTwister(0), Float32, nfeatures, N) - - test_y = eval_tree_array(tree, X, options)[1] - true_y = realfnc.(X[1, :], X[2, :], X[3, :]) - - zero_tolerance = 1e-6 - @test all(abs.(test_y .- true_y) / N .< zero_tolerance) -end diff --git a/test/test_expression_builder.jl b/test/test_expression_builder.jl deleted file mode 100644 index 9de5299db..000000000 --- a/test/test_expression_builder.jl +++ /dev/null @@ -1,66 +0,0 @@ -# This file tests particular functionality of ExpressionBuilderModule -@testitem "ParametricExpression" tags = [:part3] begin - using LibraryAugmentedSymbolicRegression - using LibraryAugmentedSymbolicRegression.ExpressionBuilderModule: - strip_metadata, embed_metadata, init_params - - options = Options() - ex = parse_expression( - :(x1 * p1); - expression_type=ParametricExpression, - operators=options.operators, - parameters=ones(2, 1) * 3, - parameter_names=["p1", "p2"], - variable_names=["x1"], - ) - X = ones(1, 1) * 2 - y = ones(1) - dataset = Dataset(X, y; extra=(; classes=[1])) - - @test ex isa ParametricExpression - @test ex(dataset.X, dataset.extra.classes) ≈ ones(1, 1) * 6 - - # Mistake in that we gave the wrong options! - @test_throws( - AssertionError( - "Need prototype to be of type $(options.expression_type), but got $(ex)::$(typeof(ex))", - ), - init_params(options, dataset, ex, Val(true)) - ) - - options = Options(; - expression_type=ParametricExpression, expression_options=(; max_parameters=2) - ) - - # Mistake in that we also gave the wrong number of parameter names! 
- pop!(ex.metadata.parameter_names) - @test_throws( - AssertionError( - "Mismatch between options.expression_options.max_parameters=$(options.expression_options.max_parameters) and prototype.metadata.parameter_names=$(ex.metadata.parameter_names)", - ), - init_params(options, dataset, ex, Val(true)) - ) - # So, we fix it: - push!(ex.metadata.parameter_names, "p2") - - @test ex.metadata.parameter_names == ["p1", "p2"] - @test keys(init_params(options, dataset, ex, Val(true))) == - (:operators, :variable_names, :parameters, :parameter_names) - - @test sprint(show, ex) == "x1 * p1" - stripped_ex = strip_metadata(ex, options, dataset) - # Stripping the metadata means that operations like `show` - # do not know what binary operator to use: - @test sprint(show, stripped_ex) == "binary_operator[4](x1, p1)" - - # However, it's important that parametric expressions are still parametric: - @test stripped_ex isa ParametricExpression - # And, that they still have the right parameters: - @test haskey(getfield(stripped_ex.metadata, :_data), :parameters) - @test stripped_ex.metadata.parameters ≈ ones(2, 1) * 3 - - # Now, test that we can embed metadata back in: - embedded_ex = embed_metadata(stripped_ex, options, dataset) - @test embedded_ex isa ParametricExpression - @test ex == embedded_ex -end diff --git a/test/test_expression_derivatives.jl b/test/test_expression_derivatives.jl deleted file mode 100644 index e9d2a0a89..000000000 --- a/test/test_expression_derivatives.jl +++ /dev/null @@ -1,143 +0,0 @@ -@testitem "Test derivatives" tags = [:part1] begin - using LibraryAugmentedSymbolicRegression - using Zygote: Zygote - using Random: MersenneTwister - - ex = @parse_expression( - x * x - cos(2.5 * y), - unary_operators = [cos], - binary_operators = [*, -, +], - variable_names = [:x, :y] - ) - - rng = MersenneTwister(0) - X = rand(rng, 2, 32) - - (δy,) = Zygote.gradient(X) do X - x = @view X[1, :] - y = @view X[2, :] - - sum(i -> x[i] * x[i] - cos(2.5 * y[i]), eachindex(x)) - end - δy_hat = ex'(X) - - @test δy ≈ δy_hat - - options2 = Options(; unary_operators=[sin], binary_operators=[+, *, -]) - (δy2,) = Zygote.gradient(X) do X - x = @view X[1, :] - y = @view X[2, :] - - sum(i -> (x[i] + x[i]) * sin(2.5 + y[i]), eachindex(x)) - end - δy2_hat = ex'(X, options2) - - @test δy2 ≈ δy2_hat -end - -@testitem "Test derivatives during optimization" tags = [:part1] begin - using LibraryAugmentedSymbolicRegression - using LibraryAugmentedSymbolicRegression.ConstantOptimizationModule: - Evaluator, GradEvaluator - using DynamicExpressions - using Zygote: Zygote - using Random: MersenneTwister - using DifferentiationInterface: value_and_gradient - - rng = MersenneTwister(0) - X = rand(rng, 2, 32) - y = @. X[1, :] * X[1, :] - cos(2.6 * X[2, :]) - dataset = Dataset(X, y) - - options = Options(; - unary_operators=[cos], binary_operators=[+, *, -], autodiff_backend=:Zygote - ) - - ex = @parse_expression( - x * x - cos(2.5 * y), operators = options.operators, variable_names = [:x, :y] - ) - f = Evaluator(ex, last(get_scalar_constants(ex)), dataset, options, nothing) - fg! 
= GradEvaluator(f, options.autodiff_backend) - - @test f(first(get_scalar_constants(ex))) isa Float64 - - x = first(get_scalar_constants(ex)) - G = zero(x) - fg!(nothing, G, x) - @test G[] != 0 -end - -@testitem "Test derivatives of parametric expression during optimization" tags = [:part3] begin - using LibraryAugmentedSymbolicRegression - using LibraryAugmentedSymbolicRegression.ConstantOptimizationModule: - Evaluator, GradEvaluator, optimize_constants, specialized_options - using DynamicExpressions - using Zygote: Zygote - using Random: MersenneTwister - using DifferentiationInterface: value_and_gradient, AutoZygote, AutoEnzyme - enzyme_compatible = VERSION >= v"1.10.0" && VERSION < v"1.11.0-DEV.0" - @static if enzyme_compatible - using Enzyme: Enzyme - end - - rng = MersenneTwister(0) - X = rand(rng, 2, 32) - true_params = [0.5 2.0] - init_params = [0.1 0.2] - init_constants = [2.5, -0.5] - classes = rand(rng, 1:2, 32) - y = [ - X[1, i] * X[1, i] - cos(2.6 * X[2, i] - 0.2) + true_params[1, classes[i]] for - i in 1:32 - ] - - dataset = Dataset(X, y; extra=(; classes)) - - (true_val, (true_d_params, true_d_constants)) = - value_and_gradient(AutoZygote(), (init_params, init_constants)) do (params, c) - pred = [ - X[1, i] * X[1, i] - cos(c[1] * X[2, i] + c[2]) + params[1, classes[i]] for - i in 1:32 - ] - sum(abs2, pred .- y) / length(y) - end - - options = Options(; - unary_operators=[cos], binary_operators=[+, *, -], autodiff_backend=:Zygote - ) - - ex = @parse_expression( - x * x - cos(2.5 * y + -0.5) + p1, - operators = options.operators, - expression_type = ParametricExpression, - variable_names = ["x", "y"], - extra_metadata = (parameter_names=["p1"], parameters=init_params) - ) - - function test_backend(ex, @nospecialize(backend); allow_failure=false) - x0, refs = get_scalar_constants(ex) - G = zero(x0) - - f = Evaluator(ex, refs, dataset, specialized_options(options), nothing) - fg! = GradEvaluator(f, backend) - - @test f(x0) ≈ true_val - - try - val = fg!(nothing, G, x0) - @test val ≈ true_val - @test G ≈ vcat(true_d_constants[:], true_d_params[:]) - catch e - if allow_failure - @warn "Expected failure" e - else - rethrow(e) - end - end - end - - test_backend(ex, AutoZygote(); allow_failure=false) - @static if enzyme_compatible - test_backend(ex, AutoEnzyme(); allow_failure=true) - end -end diff --git a/test/test_fast_cycle.jl b/test/test_fast_cycle.jl deleted file mode 100644 index 91c7bd147..000000000 --- a/test/test_fast_cycle.jl +++ /dev/null @@ -1,70 +0,0 @@ -using LibraryAugmentedSymbolicRegression -using Random -include("test_params.jl") - -options = LibraryAugmentedSymbolicRegression.Options(; - default_params..., - binary_operators=(+, *), - unary_operators=(cos,), - populations=4, - constraints=((*) => (-1, 10), cos => (5)), - fast_cycle=true, - skip_mutation_failures=true, -) -X = randn(MersenneTwister(0), Float32, 5, 100) -y = 2 * cos.(X[4, :]) .- X[2, :] -variable_names = ["t1", "t2", "t3", "t4", "t5"] -state, hall_of_fame = equation_search( - X, y; variable_names=variable_names, niterations=2, options=options, return_state=true -) -dominating = calculate_pareto_frontier(hall_of_fame) - -best = dominating[end] - -# Test the score -@test best.loss < maximum_residual / 10 - -# Do search again, but with saved state: -# We do 0 iterations to make sure the state is used. 
-println("Passed.") -println("Testing whether state saving works.") -new_state, new_hall_of_fame = equation_search( - X, - y; - variable_names=variable_names, - niterations=0, - options=options, - saved_state=(deepcopy(state), deepcopy(hall_of_fame)), - return_state=true, -) - -dominating = calculate_pareto_frontier(new_hall_of_fame) -best = dominating[end] -print_tree(best.tree, options) -@test best.loss < maximum_residual / 10 - -println("Testing whether state saving works with changed loss function.") -previous_loss = best.loss -new_loss(x, y) = sum(abs2, x - y) * 0.1 -options = LibraryAugmentedSymbolicRegression.Options(; - default_params..., - binary_operators=(+, *), - unary_operators=(cos,), - populations=4, - constraints=((*) => (-1, 10), cos => (5)), - fast_cycle=true, - skip_mutation_failures=true, - elementwise_loss=new_loss, -) -state, hall_of_fame = equation_search( - X, - y; - variable_names=variable_names, - niterations=0, - options=options, - saved_state=(state, hall_of_fame), - return_state=true, -) -dominating = calculate_pareto_frontier(hall_of_fame) -best = dominating[end] -@test best.loss ≈ previous_loss * 0.1 diff --git a/test/test_graph_nodes.jl b/test/test_graph_nodes.jl deleted file mode 100644 index 4f3ebb2fa..000000000 --- a/test/test_graph_nodes.jl +++ /dev/null @@ -1,135 +0,0 @@ -@testitem "GraphNode evaluation" tags = [:part1] begin - using LibraryAugmentedSymbolicRegression - - options = Options(; - binary_operators=[+, -, *, /], unary_operators=[cos, sin], maxsize=30 - ) - - x1, x2, x3 = [GraphNode(Float64; feature=i) for i in 1:3] - - base_tree = cos(x1 - 3.2) * x2 - x3 * copy(x3) - tree = sin(base_tree) + base_tree - - X = randn(3, 50) - z = @. cos(X[1, :] - 3.2) * X[2, :] - X[3, :] * X[3, :] - y = @. sin(z) + z - dataset = Dataset(X, y) - - tree(dataset.X, options) - - eval_tree_array(tree, dataset.X, options) -end - -@testitem "GraphNode complexity" tags = [:part1] begin - using LibraryAugmentedSymbolicRegression - - options = Options(; - binary_operators=[+, -, *, /], unary_operators=[cos, sin], maxsize=30 - ) - x1, x2, x3 = [GraphNode(Float64; feature=i) for i in 1:3] - - base_tree = cos(x1 - 3.2) * x2 - x3 * copy(x3) - tree = sin(base_tree) + base_tree - - @test compute_complexity(tree, options) == 12 - @test compute_complexity(tree, options; break_sharing=Val(true)) == 22 -end - -@testitem "GraphNode population" tags = [:part1] begin - using LibraryAugmentedSymbolicRegression - - options = Options(; - binary_operators=[+, -, *, /], - unary_operators=[cos, sin], - maxsize=30, - node_type=GraphNode, - ) - - X = randn(3, 50) - z = @. cos(X[1, :] - 3.2) * X[2, :] - X[3, :] * X[3, :] - y = @. sin(z) + z - dataset = Dataset(X, y) - - pop = Population(dataset; options, nlength=3, nfeatures=3, population_size=100) - @test pop isa Population{T,T,<:Expression{T,<:GraphNode{T}}} where {T} - - # Seems to not work yet: - # equation_search([dataset]; niterations=10, options) -end - -@testitem "GraphNode break connection mutation" tags = [:part1] begin - using LibraryAugmentedSymbolicRegression - using LibraryAugmentedSymbolicRegression.MutationFunctionsModule: - break_random_connection! 
- using Random: MersenneTwister - - options = Options(; - binary_operators=[+, -, *, /], - unary_operators=[cos, sin], - maxsize=30, - node_type=GraphNode, - ) - - x1, x2, x3 = [GraphNode(Float64; feature=i) for i in 1:3] - base_tree = cos(x1 - 3.2) * x2 - tree = sin(base_tree) + base_tree - - ex = Expression(tree; operators=options.operators, variable_names=["x1", "x2", "x3"]) - - s = strip(sprint(print_tree, ex)) - @test s == "sin(cos(x1 - 3.2) * x2) + {(cos(x1 - 3.2) * x2)}" - - rng = MersenneTwister(0) - expressions = [copy(ex) for _ in 1:1000] - expressions = [break_random_connection!(ex, rng) for ex in expressions] - strings = [strip(sprint(print_tree, ex)) for ex in expressions] - strings = unique(strings) - @test Set(strings) == Set([ - "sin(cos(x1 - 3.2) * x2) + {(cos(x1 - 3.2) * x2)}", - "sin(cos(x1 - 3.2) * x2) + (cos(x1 - 3.2) * x2)", - ]) - # Either it breaks the connection or not -end - -@testitem "GraphNode form connection mutation" tags = [:part1] begin - using LibraryAugmentedSymbolicRegression - using LibraryAugmentedSymbolicRegression.MutationFunctionsModule: - form_random_connection! - using Random: MersenneTwister - - options = Options(; - binary_operators=[+, -, *, /], - unary_operators=[cos, sin], - maxsize=30, - node_type=GraphNode, - ) - - x1, x2 = [GraphNode{Float64}(; feature=i) for i in 1:2] - - tree = cos(x1 * x2 + 1.5) - ex = Expression(tree; operators=options.operators, variable_names=["x1", "x2"]) - rng = MersenneTwister(0) - expressions = [copy(ex) for _ in 1:3_000] - expressions = [form_random_connection!(ex, rng) for ex in expressions] - strings = [strip(sprint(print_tree, ex)) for ex in expressions] - strings = sort(unique(strings); by=length) - - # All possible connections that can be made - @test Set(strings) == Set([ - "cos(x1)", - "cos(x2)", - "cos(1.5)", - "cos(x1 * x2)", - "cos(x2 + 1.5)", - "cos(x1 + 1.5)", - "cos(1.5 + {1.5})", - "cos((x1 * x2) + 1.5)", - "cos((x1 * x2) + {x2})", - "cos((x1 * x2) + {x1})", - "cos((x2 * {x2}) + 1.5)", - "cos((x1 * {x1}) + 1.5)", - "cos((x1 * 1.5) + {1.5})", - "cos((1.5 * x2) + {1.5})", - "cos((x1 * x2) + {(x1 * x2)})", - ]) -end diff --git a/test/test_hash.jl b/test/test_hash.jl deleted file mode 100644 index a8c86db4b..000000000 --- a/test/test_hash.jl +++ /dev/null @@ -1,7 +0,0 @@ -using LibraryAugmentedSymbolicRegression - -options = Options(; binary_operators=(+, *, ^, /, greater), unary_operators=(cos,)) -@extend_operators options -tree = Node(3, (^)(Node(; val=3.0) * Node(1, Node("x1")), 2.0), Node(; val=-1.2)) -x = hash(tree) -@test typeof(x) == UInt diff --git a/test/test_integer_evaluation.jl b/test/test_integer_evaluation.jl deleted file mode 100644 index bd313ff0d..000000000 --- a/test/test_integer_evaluation.jl +++ /dev/null @@ -1,23 +0,0 @@ -using LibraryAugmentedSymbolicRegression -using Random -include("test_params.jl") - -# Test evaluation on integer-based trees. 
-options = Options(;
-    default_params..., binary_operators=(+, *, /, -), unary_operators=(square,)
-)
-
-nodefnc(x1, x2, x3) = x2 * x3 + Int32(2) - square(x1)
-
-x1, x2, x3 = Node("x1"), Node("x2"), Node("x3")
-tree = nodefnc(x1, x2, x3)
-
-tree = convert(Node{Int32}, tree)
-X = Int32.(rand(MersenneTwister(0), -5:5, 3, 100))
-
-true_out = nodefnc.(X[1, :], X[2, :], X[3, :])
-@test eltype(true_out) == Int32
-out, flag = eval_tree_array(tree, X, options)
-@test flag
-@test isapprox(out, true_out)
-@test eltype(out) == Int32
diff --git a/test/test_jet.jl b/test/test_jet.jl
index 720a11aa2..a0ecd5154 100644
--- a/test/test_jet.jl
+++ b/test/test_jet.jl
@@ -1,7 +1,7 @@
+# Copied over from SymbolicRegression.jl
 if !(VERSION >= v"1.10.0" && VERSION < v"1.11.0-DEV.0")
     exit(0)
 end
-# TODO: Check why is breaking on 1.11.0
 
 dir = mktempdir()
diff --git a/test/test_lasr_integration.jl b/test/test_lasr_integration.jl
deleted file mode 100644
index f40cfbc52..000000000
--- a/test/test_lasr_integration.jl
+++ /dev/null
@@ -1,31 +0,0 @@
-using LibraryAugmentedSymbolicRegression: LLMOptions, Options
-
-# test that we can partially specify LLMOptions
-op1 = LLMOptions(; active=false)
-@test op1.active == false
-
-# test that we can fully specify LLMOptions
-op2 = LLMOptions(;
-    active=true,
-    weights=LLMWeights(; llm_mutate=0.5, llm_crossover=0.3, llm_gen_random=0.2),
-    num_pareto_context=5,
-    prompt_evol=true,
-    prompt_concepts=true,
-    api_key="vllm_api.key",
-    model="modelx",
-    api_kwargs=Dict("url" => "http://localhost:11440/v1"),
-    http_kwargs=Dict("retries" => 3, "readtimeout" => 3600),
-    llm_recorder_dir="test/",
-    llm_context="test",
-    var_order=nothing,
-    idea_threshold=30,
-)
-@test op2.active == true
-
-# test that we can pass LLMOptions to Options
-llm_opt = LLMOptions(; active=false)
-op = Options(;
-    optimizer_options=(iterations=16, f_calls_limit=100, x_tol=1e-16), llm_options=llm_opt
-)
-@test isa(op.llm_options, LLMOptions)
-println("Passed.")
diff --git a/test/test_lasr_llmgenerate.jl b/test/test_lasr_llmgenerate.jl
new file mode 100644
index 000000000..d12a832cb
--- /dev/null
+++ b/test/test_lasr_llmgenerate.jl
@@ -0,0 +1,52 @@
+# using LibraryAugmentedSymbolicRegression: LLMOptions, Options
+# """
+# Test that the wrappers around the LLM generation commands process the data properly.
+
+# Instead of importing aigenerate from the PromptingTools package, we use a custom
+# aigenerate function that returns a predetermined string and a predetermined success value.
+
+# Each string-parsing function is then required to convert the predetermined generation
+# string into the correct parsed value.
+# """
+
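+# A minimal sketch of that mocking strategy. Everything below is an illustrative
+# assumption: `mock_aigenerate` and its canned reply are stand-ins, not package API.
+using Test
+
+# Stand-in for PromptingTools.aigenerate: returns a predetermined string and a
+# predetermined success value, so parsing can be tested without an LLM backend.
+mock_aigenerate(prompt; kwargs...) = (; content="sin(x1) + 3.2 * x2", success=true)
+
+response = mock_aigenerate("Suggest a new expression")
+@test response.success
+@test occursin("x1", response.content)
+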
+# """ + +# # # test that we can partially specify LLMOptions +# # op1 = LLMOptions(; use_llm=false) +# # @test op1.use_llm == false + +# # # test that we can fully specify LLMOptions +# # op2 = LLMOptions(; +# # use_llm=true, +# # lasr_weights=LLMWeights(; llm_mutate=0.5, llm_crossover=0.3, llm_gen_random=0.2), +# # num_pareto_context=5, +# # use_concept_evolution=true, +# # use_concepts=true, +# # api_key="vllm_api.key", +# # model="modelx", +# # api_kwargs=Dict("url" => "http://localhost:11440/v1"), +# # http_kwargs=Dict("retries" => 3, "readtimeout" => 3600), +# # llm_recorder_dir="test/", +# # llm_context="test", +# # variable_names=nothing, +# # max_concepts=30, +# # ) +# # @test op2.use_llm == true + +# # # test that we can pass LLMOptions to Options +# # llm_opt = LLMOptions(; use_llm=false) +# # op = Options(; +# # optimizer_options=(iterations=16, f_calls_limit=100, x_tol=1e-16), llm_options=llm_opt +# # ) +# # @test isa(op.llm_options, LLMOptions) +# # println("Passed.") diff --git a/test/test_lasr_parser.jl b/test/test_lasr_parser.jl new file mode 100644 index 000000000..0ecf51a27 --- /dev/null +++ b/test/test_lasr_parser.jl @@ -0,0 +1,33 @@ +# LaSR needs a parser to convert LLM-generated expression strings into DynamicExpressions compatible trees. +# These are round trip tests to ensure that the parser is working correctly. +println("Testing LaSR expression parser") + +using Random: MersenneTwister +using LibraryAugmentedSymbolicRegression: + LaSROptions, string_tree, parse_expr, render_expr, gen_random_tree +include("test_params.jl") +options = LaSROptions(; + default_params..., binary_operators=[+, *, ^, -], unary_operators=[sin, cos, exp] +) + +rng = MersenneTwister(314159) + +for depth in [5, 10] + for nvar in [5, 10] + random_trees = [gen_random_tree(depth, options, nvar, Float32, rng) for _ in 1:1e3] + data = rand(Float32, nvar, 1000) + + for tree in random_trees + output = tree(data, options.operators) + if any(isnan.(output)) + continue + end + str_tree = string_tree(tree, options) + @test str_tree == String(strip(str_tree, [' ', '\n', '"', ',', '.', '[', ']'])) + expr_tree = parse_expr(str_tree, options, Float32) + expr_output = expr_tree(data, options.operators) + @test isapprox(expr_output, output) + end + end +end +println("Passed.") diff --git a/test/test_losses.jl b/test/test_losses.jl deleted file mode 100644 index f6c4cbc60..000000000 --- a/test/test_losses.jl +++ /dev/null @@ -1,58 +0,0 @@ -using LibraryAugmentedSymbolicRegression -using LibraryAugmentedSymbolicRegression: eval_loss -using Random -include("test_params.jl") - -_loss = LibraryAugmentedSymbolicRegression.LossFunctionsModule._loss -_weighted_loss = LibraryAugmentedSymbolicRegression.LossFunctionsModule._weighted_loss - -customloss(x, y) = abs(x - y)^2.5 -customloss(x, y, w) = w * (abs(x - y)^2.5) -testl1(x, y) = abs(x - y) -testl1(x, y, w) = abs(x - y) * w - -for (loss_fnc, evaluator) in [(L1DistLoss(), testl1), (customloss, customloss)] - local options = Options(; - default_params..., - binary_operators=(+, *, -, /), - unary_operators=(cos, exp), - populations=4, - elementwise_loss=loss_fnc, - ) - x = randn(MersenneTwister(0), Float32, 100) - y = randn(MersenneTwister(1), Float32, 100) - w = abs.(randn(MersenneTwister(2), Float32, 100)) - @test abs(_loss(x, y, options.elementwise_loss) - sum(evaluator.(x, y)) / length(x)) < - 1e-6 - @test abs( - _weighted_loss(x, y, w, options.elementwise_loss) - - sum(evaluator.(x, y, w)) / sum(w), - ) < 1e-6 -end - -function custom_objective_batched( - 
tree::AbstractExpressionNode{T}, dataset::Dataset{T,L}, options, ::Nothing -) where {T,L} - return one(T) -end -function custom_objective_batched( - tree::AbstractExpressionNode{T}, dataset::Dataset{T,L}, options, idx -) where {T,L} - return sum(dataset.X[:, idx]) -end -let options = Options(; binary_operators=[+, *], loss_function=custom_objective_batched), - d = Dataset(randn(3, 10), randn(10)) - - @test eval_loss(Node(; val=1.0), d, options) === 1.0 - @test eval_loss(Node(; val=1.0), d, options; idx=[1, 2]) == sum(d.X[:, [1, 2]]) -end - -custom_objective_bad_batched(tree, dataset, options) = sum(dataset.X) - -let options = Options(; - binary_operators=[+, *], loss_function=custom_objective_bad_batched, batching=true - ), - d = Dataset(randn(3, 10), randn(10)) - - @test_throws ErrorException eval_loss(Node(; val=1.0), d, options; idx=[1, 2]) -end diff --git a/test/test_migration.jl b/test/test_migration.jl deleted file mode 100644 index 785115f38..000000000 --- a/test/test_migration.jl +++ /dev/null @@ -1,32 +0,0 @@ -using LibraryAugmentedSymbolicRegression -using LibraryAugmentedSymbolicRegression: strip_metadata -using DynamicExpressions: get_tree -using Test -using Random: seed! - -seed!(0) - -X = randn(5, 100); -y = X[2, :] .* 3.2 .+ X[3, :] .+ 2.0; - -options = Options(); -population1 = Population( - X, y; population_size=100, options=options, nfeatures=5, nlength=10 -) -dataset = Dataset(X, y) - -tree = Node(1, Node(; val=1.0), Node(; feature=2) * 3.2) - -@test !(hash(tree) in [hash(p.tree) for p in population1.members]) - -ex = @parse_expression($tree, operators = options.operators, variable_names = [:x1, :x2],) -ex = strip_metadata(ex, options, dataset) - -LibraryAugmentedSymbolicRegression.MigrationModule.migrate!( - [PopMember(ex, 0.0, Inf, options; deterministic=false)] => population1, - options; - frac=0.5, -) - -# Now we see that the tree is in the population: -@test tree in [get_tree(p.tree) for p in population1.members] diff --git a/test/test_mixed.jl b/test/test_mixed.jl deleted file mode 100644 index a2d047166..000000000 --- a/test/test_mixed.jl +++ /dev/null @@ -1,39 +0,0 @@ -@testitem "Search with batching & weighted & serial & progress bar & warmup & BFGS" tags = [ - :part1 -] begin - include("test_mixed_utils.jl") - test_mixed(0, true, true, :serial) -end - -@testitem "Search with multiprocessing & batching & multi-output & use_frequency & string-specified parallelism" tags = [ - :part2 -] begin - include("test_mixed_utils.jl") - test_mixed(1, true, false, :multiprocessing) -end - -@testitem "Search with multi-threading & default settings" tags = [:part3] begin - include("test_mixed_utils.jl") - test_mixed(2, false, true, :multithreading) -end - -@testitem "Search with multi-threading & weighted & crossover & use_frequency_in_tournament & bumper" tags = [ - :part1 -] begin - include("test_mixed_utils.jl") - test_mixed(3, false, false, :multithreading) -end - -@testitem "Search with multi-threading & crossover & skip mutation failures & both frequencies options & Float16 type" tags = [ - :part2 -] begin - include("test_mixed_utils.jl") - test_mixed(4, false, false, :multithreading) -end - -@testitem "Search with multiprocessing & default hyperparameters & Float64 type & turbo" tags = [ - :part3 -] begin - include("test_mixed_utils.jl") - test_mixed(5, false, false, :multiprocessing) -end diff --git a/test/test_mixed_utils.jl b/test/test_mixed_utils.jl deleted file mode 100644 index 31fa5e6c3..000000000 --- a/test/test_mixed_utils.jl +++ /dev/null @@ -1,144 +0,0 @@ 
-using LibraryAugmentedSymbolicRegression -using LibraryAugmentedSymbolicRegression: string_tree -using Random, Bumper, LoopVectorization - -include("test_params.jl") - -function test_mixed(i, batching::Bool, weighted::Bool, parallelism) - progress = false - warmup_maxsize_by = 0.0f0 - optimizer_algorithm = "NelderMead" - multi = false - tournament_selection_p = 1.0 - crossover_probability = 0.0f0 - skip_mutation_failures = false - use_frequency = false - use_frequency_in_tournament = false - turbo = false - bumper = false - T = Float32 - - niterations = 2 - if i == 0 - progress = true #Also try the progress bar. - warmup_maxsize_by = 0.5f0 #Smaller maxsize at first, build up slowly - optimizer_algorithm = "BFGS" - tournament_selection_p = 0.8 - elseif i == 1 - multi = true - use_frequency = true - elseif i == 3 - crossover_probability = 0.02f0 - use_frequency_in_tournament = true - bumper = true - elseif i == 4 - crossover_probability = 0.02f0 - skip_mutation_failures = true - use_frequency = true - use_frequency_in_tournament = true - T = Float16 - elseif i == 5 - T = Float64 - turbo = true - niterations = 5 - end - - numprocs = parallelism == :multiprocessing ? 2 : nothing - - options = if i == 5 - LibraryAugmentedSymbolicRegression.Options(; - unary_operators=(cos,), - batching=batching, - parsimony=0.0f0, # Required for scoring - early_stop_condition=1e-6, - ) - else - LibraryAugmentedSymbolicRegression.Options(; - default_params..., - binary_operators=(+, *), - unary_operators=(cos,), - populations=4, - batching=batching, - crossover_probability=crossover_probability, - skip_mutation_failures=skip_mutation_failures, - seed=0, - progress=progress, - warmup_maxsize_by=warmup_maxsize_by, - optimizer_algorithm=optimizer_algorithm, - tournament_selection_p=tournament_selection_p, - parsimony=0.0f0, - use_frequency=use_frequency, - use_frequency_in_tournament=use_frequency_in_tournament, - turbo=turbo, - bumper=bumper, - early_stop_condition=1e-6, - ) - end - - X = randn(MersenneTwister(0), T, 5, 100) - - (y, hallOfFame, dominating) = if weighted - mask = rand(100) .> 0.5 - weights = map(x -> convert(T, x), mask) - # Completely different function superimposed - need - # to use correct weights to figure it out! 
- y = (2 .* cos.(X[4, :])) .* weights .+ (1 .- weights) .* (5 .* X[2, :]) - hallOfFame = equation_search( - X, - y; - weights=weights, - niterations=niterations, - options=options, - parallelism=parallelism, - numprocs=numprocs, - ) - dominating = [calculate_pareto_frontier(hallOfFame)] - - (y, hallOfFame, dominating) - else - y = 2 * cos.(X[4, :]) - niterations = niterations - if multi - # Copy the same output twice; make sure we can find it twice - y = repeat(y, 1, 2) - y = transpose(y) - niterations = 20 - end - hallOfFame = equation_search( - X, - y; - niterations=niterations, - options=options, - parallelism=parallelism, - numprocs=numprocs, - ) - dominating = if multi - [calculate_pareto_frontier(hallOfFame[j]) for j in 1:2] - else - [calculate_pareto_frontier(hallOfFame)] - end - - (y, hallOfFame, dominating) - end - - # For brevity, always assume multi-output in this test: - for dom in dominating - @test length(dom) > 0 - best = dom[end] - # Assert we created the correct type of trees: - @test node_type(typeof(best.tree)) == Node{T} - - # Test the score - @test best.loss < maximum_residual - # Test the actual equation found: - testX = randn(MersenneTwister(1), T, 5, 100) - true_y = 2 * cos.(testX[4, :]) - predicted_y, flag = eval_tree_array(best.tree, testX, options) - - @test flag - @test sum(abs, true_y .- predicted_y) < maximum_residual - # eval evaluates inside global - end - - return println("Passed.") -end diff --git a/test/test_mlj.jl b/test/test_mlj.jl deleted file mode 100644 index 62f5d6ea8..000000000 --- a/test/test_mlj.jl +++ /dev/null @@ -1,207 +0,0 @@ -@testitem "Generic interface tests" tags = [:part1] begin - using LibraryAugmentedSymbolicRegression - using MLJTestInterface: MLJTestInterface as MTI - include("test_params.jl") - - failures, summary = MTI.test( - [LaSRRegressor], MTI.make_regression()...; mod=@__MODULE__, verbosity=0, throw=true - ) - @test isempty(failures) - - X = randn(100, 3) - Y = @. cos(X^2) * 3.2 - 0.5 - (X, Y) = MTI.table.((X, Y)) - w = ones(100) - failures, summary = MTI.test( - [MultitargetLaSRRegressor], X, Y, w; mod=@__MODULE__, verbosity=0, throw=true - ) - @test isempty(failures) -end - -@testitem "Variable names - single outputs" tags = [:part3] begin - using LibraryAugmentedSymbolicRegression - using LibraryAugmentedSymbolicRegression: Node - using MLJBase - using SymbolicUtils - using Random: MersenneTwister - - include("test_params.jl") - - stop_kws = (; early_stop_condition=(loss, complexity) -> loss < 1e-5) - - rng = MersenneTwister(0) - X = (a=rand(rng, 32), b=rand(rng, 32)) - y = X.a .^ 2.1 - # We also make sure the deprecated npop and npopulations still work: - model = LaSRRegressor(; niterations=10, npop=1000, npopulations=15, stop_kws...) 
- mach = machine(model, X, y) - fit!(mach) - rep = report(mach) - @test occursin("a", rep.equation_strings[rep.best_idx]) - ypred_good = predict(mach, X) - @test sum(abs2, predict(mach, X) .- y) / length(y) < 1e-5 - - # Check that we can choose the equation - ypred_same = predict(mach, (data=X, idx=rep.best_idx)) - @test ypred_good == ypred_same - - ypred_bad = predict(mach, (data=X, idx=1)) - @test ypred_good != ypred_bad - - # Smoke test SymbolicUtils - eqn = node_to_symbolic(rep.equations[rep.best_idx], model) - n = symbolic_to_node(eqn, model) - eqn2 = convert(SymbolicUtils.Symbolic, n, model) - n2 = convert(Node, eqn2, model) -end - -@testitem "Variable names - multiple outputs" tags = [:part1] begin - using LibraryAugmentedSymbolicRegression - using MLJBase - using Random: MersenneTwister - - include("test_params.jl") - - stop_kws = (; early_stop_condition=(loss, complexity) -> loss < 1e-5) - - rng = MersenneTwister(0) - X = (a=rand(rng, 32), b=rand(rng, 32)) - y = X.a .^ 2.1 - model = MultitargetLaSRRegressor(; niterations=10, stop_kws...) - mach = machine(model, X, reduce(hcat, [reshape(y, :, 1) for i in 1:3])) - fit!(mach) - rep = report(mach) - @test all( - eq -> occursin("a", eq), [rep.equation_strings[i][rep.best_idx[i]] for i in 1:3] - ) - ypred_good = predict(mach, X) - - # Test that we can choose the equation - ypred_same = predict(mach, (data=X, idx=rep.best_idx)) - @test ypred_good == ypred_same - - ypred_bad = predict(mach, (data=X, idx=[1, 1, 1])) - @test ypred_good != ypred_bad - - ypred_mixed = predict(mach, (data=X, idx=[rep.best_idx[1], 1, rep.best_idx[3]])) - @test ypred_mixed == hcat(ypred_good[:, 1], ypred_bad[:, 2], ypred_good[:, 3]) - - @test_throws AssertionError predict(mach, (data=X,)) - VERSION >= v"1.8" && - @test_throws "If specifying an equation index during" predict(mach, (data=X,)) - VERSION >= v"1.8" && - @test_throws "If specifying an equation index during" predict(mach, (X=X, idx=1)) -end - -@testitem "Variable names - named outputs" tags = [:part1] begin - using LibraryAugmentedSymbolicRegression - using MLJBase - using Random: MersenneTwister - - include("test_params.jl") - - stop_kws = (; early_stop_condition=(loss, complexity) -> loss < 1e-5) - - rng = MersenneTwister(0) - X = (b1=randn(rng, 32), b2=randn(rng, 32)) - Y = (c1=X.b1 .* X.b2, c2=X.b1 .+ X.b2) - w = ones(32) - model = MultitargetLaSRRegressor(; niterations=10, stop_kws...) - mach = machine(model, X, Y, w) - fit!(mach) - test_outs = predict(mach, X) - @test isempty(setdiff((:c1, :c2), keys(test_outs))) - @test_throws AssertionError predict(mach, (a1=randn(32), b2=randn(32))) - VERSION >= v"1.8" && @test_throws "Variable names do not match fitted" predict( - mach, (b1=randn(32), a2=randn(32)) - ) -end - -@testitem "Good predictions" tags = [:part1] begin - using LibraryAugmentedSymbolicRegression - using MLJBase - using Random: MersenneTwister - - include("test_params.jl") - - stop_kws = (; early_stop_condition=(loss, complexity) -> loss < 1e-5) - - rng = MersenneTwister(0) - X = randn(rng, 100, 3) - Y = X - model = MultitargetLaSRRegressor(; niterations=10, stop_kws...) 
- mach = machine(model, X, Y) - fit!(mach) - @test sum(abs2, predict(mach, X) .- Y) / length(X) < 1e-6 -end - -@testitem "Helpful errors" tags = [:part3] begin - using LibraryAugmentedSymbolicRegression - using MLJBase - using Random: MersenneTwister - - include("test_params.jl") - - model = MultitargetLaSRRegressor() - rng = MersenneTwister(0) - mach = machine(model, randn(rng, 32, 3), randn(rng, 32); scitype_check_level=0) - @test_throws AssertionError @quiet(fit!(mach)) - VERSION >= v"1.8" && - @test_throws "For single-output regression, please" @quiet(fit!(mach)) - - model = LaSRRegressor() - rng = MersenneTwister(0) - mach = machine(model, randn(rng, 32, 3), randn(rng, 32, 2); scitype_check_level=0) - @test_throws AssertionError @quiet(fit!(mach)) - VERSION >= v"1.8" && - @test_throws "For multi-output regression, please" @quiet(fit!(mach)) - - model = LaSRRegressor(; verbosity=0) - rng = MersenneTwister(0) - mach = machine(model, randn(rng, 32, 3), randn(rng, 32)) - @test_throws ErrorException @quiet(fit!(mach; verbosity=0)) -end - -@testitem "Unfinished search" tags = [:part3] begin - using LibraryAugmentedSymbolicRegression - using MLJBase - using Suppressor - using Random: MersenneTwister - - model = LaSRRegressor(; timeout_in_seconds=1e-10) - rng = MersenneTwister(0) - mach = machine(model, randn(rng, 32, 3), randn(rng, 32)) - fit!(mach) - # Ensure that the hall of fame is empty: - _, hof = mach.fitresult.state - hof.exists .= false - # Recompute the report: - mach.report[:fit] = LibraryAugmentedSymbolicRegression.MLJInterfaceModule.full_report( - model, mach.fitresult - ) - @test report(mach).best_idx == 0 - @test predict(mach, randn(32, 3)) == zeros(32) - msg = @capture_err begin - predict(mach, randn(32, 3)) - end - @test occursin("Evaluation failed either due to", msg) - - model = MultitargetLaSRRegressor(; timeout_in_seconds=1e-10) - rng = MersenneTwister(0) - mach = machine(model, randn(rng, 32, 3), randn(rng, 32, 3)) - fit!(mach) - # Ensure that the hall of fame is empty: - _, hofs = mach.fitresult.state - foreach(hofs) do hof - hof.exists .= false - end - mach.report[:fit] = LibraryAugmentedSymbolicRegression.MLJInterfaceModule.full_report( - model, mach.fitresult - ) - @test report(mach).best_idx == [0, 0, 0] - @test predict(mach, randn(32, 3)) == zeros(32, 3) - msg = @capture_err begin - predict(mach, randn(32, 3)) - end - @test occursin("Evaluation failed either due to", msg) -end diff --git a/test/test_nan_detection.jl b/test/test_nan_detection.jl deleted file mode 100644 index c1a104d1a..000000000 --- a/test/test_nan_detection.jl +++ /dev/null @@ -1,51 +0,0 @@ -println("Testing NaN detection.") -using LibraryAugmentedSymbolicRegression -using LoopVectorization - -for T in [Float16, Float32, Float64], turbo in [true, false] - T == Float16 && turbo && continue - local options, tree, X - - options = Options(; - binary_operators=(+, *, /, -, ^), unary_operators=(cos, sin, exp, sqrt), turbo=turbo - ) - @extend_operators options - # Creating a NaN via computation. - tree = exp(exp(exp(exp(Node(T; feature=1) + 1)))) - tree = convert(Node{T}, tree) - X = fill(T(100), 1, 10) - output, flag = eval_tree_array(tree, X, options) - @test !flag - - # Creating a NaN/Inf via division by constant zero. 
- tree = cos(Node(T; feature=1) / zero(T)) - tree = convert(Node{T}, tree) - output, flag = eval_tree_array(tree, X, options) - @test !flag - - # Creating a NaN via sqrt(-1): - tree = safe_sqrt(Node(T; feature=1) - 1) - tree = convert(Node{T}, tree) - X = fill(T(0), 1, 10) - output, flag = eval_tree_array(tree, X, options) - @test !flag - - # Creating a NaN via pow(-1, 0.5): - tree = (^)(Node(T; feature=1) - 1, 0.5) - tree = convert(Node{T}, tree) - X = fill(T(0), 1, 10) - output, flag = eval_tree_array(tree, X, options) - @test !flag - - # Having NaN/Inf constants: - tree = cos(Node(T; feature=1) + T(Inf)) - tree = convert(Node{T}, tree) - output, flag = eval_tree_array(tree, X, options) - @test !flag - tree = cos(Node(T; feature=1) + T(NaN)) - tree = convert(Node{T}, tree) - output, flag = eval_tree_array(tree, X, options) - @test !flag -end - -println("Passed.") diff --git a/test/test_nested_constraints.jl b/test/test_nested_constraints.jl deleted file mode 100644 index c4527b3ed..000000000 --- a/test/test_nested_constraints.jl +++ /dev/null @@ -1,66 +0,0 @@ -println("Test operator nesting and flagging.") -using LibraryAugmentedSymbolicRegression - -function create_options(nested_constraints) - return Options(; - binary_operators=(+, *, /, -), - unary_operators=(cos, exp), - nested_constraints=nested_constraints, - ) -end - -options = create_options(nothing) - # Count max nests: -tree = cos(exp(exp(exp(exp(Node("x1")))))) -degree_of_exp = 1 -index_of_exp = findfirst(isequal(exp), options.operators.unaops) -@test 4 == LibraryAugmentedSymbolicRegression.CheckConstraintsModule.count_max_nestedness( - tree, degree_of_exp, index_of_exp -) - -tree = cos(exp(Node("x1")) + exp(exp(exp(exp(Node("x1")))))) -@test 4 == LibraryAugmentedSymbolicRegression.CheckConstraintsModule.count_max_nestedness( - tree, degree_of_exp, index_of_exp -) - -degree_of_plus = 2 -index_of_plus = findfirst(isequal(+), options.operators.binops) -tree = cos(exp(Node("x1")) + exp(exp(Node("x1") + exp(exp(exp(Node("x1"))))))) -@test 2 == LibraryAugmentedSymbolicRegression.CheckConstraintsModule.count_max_nestedness( - tree, degree_of_plus, index_of_plus -) - -# Test checking for illegal nests: -x1 = Node("x1") -options = create_options(nothing) -tree = cos(cos(x1)) + cos(x1) + exp(cos(x1)) -@test !LibraryAugmentedSymbolicRegression.CheckConstraintsModule.flag_illegal_nests( - tree, options -) - -options = create_options([cos => [cos => 0]]) -@test LibraryAugmentedSymbolicRegression.CheckConstraintsModule.flag_illegal_nests( - tree, options -) - -options = create_options([cos => [cos => 1]]) -@test !LibraryAugmentedSymbolicRegression.CheckConstraintsModule.flag_illegal_nests( - tree, options -) - -options = create_options([cos => [exp => 0]]) -@test !LibraryAugmentedSymbolicRegression.CheckConstraintsModule.flag_illegal_nests( - tree, options -) - -options = create_options([exp => [cos => 0]]) -@test LibraryAugmentedSymbolicRegression.CheckConstraintsModule.flag_illegal_nests( - tree, options -) - -options = create_options([(+) => [(+) => 0]]) -@test LibraryAugmentedSymbolicRegression.CheckConstraintsModule.flag_illegal_nests( - tree, options -) - -println("Passed.") diff --git a/test/test_operators.jl b/test/test_operators.jl deleted file mode 100644 index 384d4c668..000000000 --- a/test/test_operators.jl +++ /dev/null @@ -1,152 +0,0 @@ -using LibraryAugmentedSymbolicRegression -using LibraryAugmentedSymbolicRegression: - plus, - sub, - mult, - square, - cube, - safe_pow, - safe_log, - safe_log2, - safe_log10, - safe_sqrt,
- safe_acosh, - neg, - greater, - cond, - relu, - logical_or, - logical_and, - gamma -using Random: MersenneTwister -using Suppressor: @capture_err -using LoopVectorization -include("test_params.jl") - -@testset "Generic operator tests" begin - types_to_test = [Float16, Float32, Float64, BigFloat] - for T in types_to_test - val = T(0.5) - val2 = T(3.2) - @test abs(safe_log(val) - log(val)) < 1e-6 - @test isnan(safe_log(-val)) - @test abs(safe_log2(val) - log2(val)) < 1e-6 - @test isnan(safe_log2(-val)) - @test abs(safe_log10(val) - log10(val)) < 1e-6 - @test isnan(safe_log10(-val)) - @test abs(safe_log1p(val) - log1p(val)) < 1e-6 - @test abs(safe_acosh(val2) - acosh(val2)) < 1e-6 - @test isnan(safe_acosh(-val2)) - @test neg(-val) == val - @test safe_sqrt(val) == sqrt(val) - @test isnan(safe_sqrt(-val)) - @test mult(val, val2) == val * val2 - @test plus(val, val2) == val + val2 - @test sub(val, val2) == val - val2 - @test square(val) == val * val - @test cube(val) == val * val * val - @test isnan(safe_pow(T(0.0), -T(1.0))) - @test isnan(safe_pow(-val, val2)) - @test all(isnan.([safe_pow(-val, -val2), safe_pow(T(0.0), -val2)])) - @test abs(safe_pow(val, val2) - val^val2) < 1e-6 - @test abs(safe_pow(val, -val2) - val^(-val2)) < 1e-6 - @test !isnan(safe_pow(T(-1.0), T(2.0))) - @test isnan(safe_pow(T(-1.0), T(2.1))) - @test isnan(safe_log(zero(T))) - @test isnan(safe_log2(zero(T))) - @test isnan(safe_log10(zero(T))) - @test isnan(safe_log1p(T(-2.0))) - @test greater(val, val2) == T(0.0) - @test greater(val2, val) == T(1.0) - @test relu(-val) == T(0.0) - @test relu(val) == val - @test logical_or(val, val2) == T(1.0) - @test logical_or(T(0.0), val2) == T(1.0) - @test logical_and(T(0.0), val2) == T(0.0) - - @inferred cond(val, val2) - @test cond(val, val2) == val2 - @test cond(-val, val2) == zero(T) - end -end - -@testset "Test built-in operators pass validation" begin - types_to_test = [Float16, Float32, Float64, BigFloat] - options = Options(; - binary_operators=[plus, sub, mult, /, ^, greater, logical_or, logical_and, cond], - unary_operators=[ - square, cube, log, log2, log10, log1p, sqrt, atanh, acosh, neg, relu - ], - ) - for T in types_to_test - @test_nowarn LibraryAugmentedSymbolicRegression.assert_operators_well_defined( - T, options - ) - end -end - -@testset "Test built-in operators pass validation for complex numbers" begin - types_to_test = [ComplexF16, ComplexF32, ComplexF64] - options = Options(; - binary_operators=[plus, sub, mult, /, ^], - unary_operators=[square, cube, log, log2, log10, log1p, sqrt, acosh, neg], - ) - for T in types_to_test - @test_nowarn LibraryAugmentedSymbolicRegression.assert_operators_well_defined( - T, options - ) - end -end - -@testset "Test incompatibilities are caught" begin - options = Options(; binary_operators=[greater]) - @test_throws ErrorException LibraryAugmentedSymbolicRegression.assert_operators_well_defined( - ComplexF64, options - ) - VERSION >= v"1.8" && - @test_throws "complex plane" LibraryAugmentedSymbolicRegression.assert_operators_well_defined( - ComplexF64, options - ) -end - -@testset "Operators which return the wrong type should fail" begin - my_bad_op(x) = 1.0f0 - options = Options(; binary_operators=[], unary_operators=[my_bad_op]) - @test_throws ErrorException LibraryAugmentedSymbolicRegression.assert_operators_well_defined( - Float64, options - ) - VERSION >= v"1.8" && - @test_throws "returned an output of type" LibraryAugmentedSymbolicRegression.assert_operators_well_defined( - Float64, options - ) - @test_nowarn 
LibraryAugmentedSymbolicRegression.assert_operators_well_defined( - Float32, options - ) -end - -@testset "Turbo mode should be the same" begin - binary_operators = [plus, sub, mult, /, ^, greater, logical_or, logical_and, cond] - unary_operators = [square, cube, log, log2, log10, log1p, sqrt, atanh, acosh, neg, relu] - options = Options(; binary_operators, unary_operators) - for T in (Float32, Float64), - index_bin in 1:length(binary_operators), - index_una in 1:length(unary_operators) - - x1, x2 = Node(T; feature=1), Node(T; feature=2) - tree = Node(index_bin, x1, Node(index_una, x2)) - X = rand(MersenneTwister(0), T, 2, 20) - for seed in 1:20 - Xpart = X[:, [seed]] - y, completed = eval_tree_array(tree, Xpart, options) - completed || continue - local y_turbo - # We capture any warnings about the LoopVectorization not working - eval_warnings = @capture_err begin - y_turbo, _ = eval_tree_array(tree, Xpart, options; turbo=true) - end - test_info(@test y[1] ≈ y_turbo[1] && eval_warnings == "") do - @info T tree X[:, seed] y y_turbo eval_warnings - end - end - end -end diff --git a/test/test_optimizer_mutation.jl b/test/test_optimizer_mutation.jl deleted file mode 100644 index a5bad0e3c..000000000 --- a/test/test_optimizer_mutation.jl +++ /dev/null @@ -1,42 +0,0 @@ -using LibraryAugmentedSymbolicRegression -using LibraryAugmentedSymbolicRegression: LibraryAugmentedSymbolicRegression -using LibraryAugmentedSymbolicRegression: Dataset, RunningSearchStatistics, RecordType -using Optim: Optim -using LibraryAugmentedSymbolicRegression.MutateModule: next_generation -using DynamicExpressions: get_scalar_constants - -mutation_weights = (; optimize=1e30) # We also test whether a named tuple works. -options = Options(; - binary_operators=(+, -, *), - unary_operators=(sin,), - mutation_weights=mutation_weights, - optimizer_options=Optim.Options(), -) - -X = randn(5, 100) -y = sin.(X[1, :] .* 2.1 .+ 0.8) .+ X[2, :] .^ 2 -dataset = Dataset(X, y) - -x1 = Node(Float64; feature=1) -x2 = Node(Float64; feature=2) -tree = sin(x1 * 1.9 + 0.2) + x2 * x2 - -member = PopMember(dataset, tree, options; deterministic=false) -temperature = 1.0 -maxsize = 20 - -new_member, _, _ = next_generation( - dataset, - member, - temperature, - maxsize, - RunningSearchStatistics(; options=options), - options; - tmp_recorder=RecordType(), -) - -resultant_constants, refs = get_scalar_constants(new_member.tree) -for k in [0.0, 0.2, 0.5, 1.0] - @test sin(resultant_constants[1] * k + resultant_constants[2]) ≈ sin(2.1 * k + 0.8) atol = - 1e-3 -end diff --git a/test/test_options.jl b/test/test_options.jl deleted file mode 100644 index c60f94c15..000000000 --- a/test/test_options.jl +++ /dev/null @@ -1,15 +0,0 @@ -using LibraryAugmentedSymbolicRegression -using Optim: Optim - -# testing types -op = Options(; optimizer_options=(iterations=16, f_calls_limit=100, x_tol=1e-16)); -@test isa(op.optimizer_options, Optim.Options) - -op = Options(; - optimizer_options=Dict(:iterations => 32, :g_calls_limit => 50, :f_tol => 1e-16) -); -@test isa(op.optimizer_options, Optim.Options) - -optim_op = Optim.Options(; iterations=16) -op = Options(; optimizer_options=optim_op); -@test isa(op.optimizer_options, Optim.Options) diff --git a/test/test_params.jl b/test/test_params.jl index a9c4871bd..6ada277f5 100644 --- a/test/test_params.jl +++ b/test/test_params.jl @@ -30,7 +30,6 @@ const default_params = ( hof_migration=true, fraction_replaced_hof=0.1f0, should_optimize_constants=true, - output_file=nothing, perturbation_factor=1.000000f0, annealing=true, 
batching=false, diff --git a/test/test_pretty_printing.jl b/test/test_pretty_printing.jl deleted file mode 100644 index 3d7f334a9..000000000 --- a/test/test_pretty_printing.jl +++ /dev/null @@ -1,111 +0,0 @@ -@testitem "pretty print member" tags = [:part3] begin - using LibraryAugmentedSymbolicRegression - - options = Options(; binary_operators=[+, ^]) - - ex = @parse_expression(x^2.0 + 1.5, binary_operators = [+, ^], variable_names = [:x]) - shower(x) = sprint((io, e) -> show(io, MIME"text/plain"(), e), x) - s = shower(ex) - @test s == "(x ^ 2.0) + 1.5" - - X = [1.0 2.0 3.0] - y = [2.0, 3.0, 4.0] - dataset = Dataset(X, y) - member = PopMember(dataset, ex, options; deterministic=false) - member.score = 1.0 - @test member isa PopMember{Float64,Float64,<:Expression{Float64,Node{Float64}}} - s_member = shower(member) - @test s_member == "PopMember(tree = ((x ^ 2.0) + 1.5), loss = 16.25, score = 1.0)" - - # New options shouldn't change this - options = Options(; binary_operators=[-, /]) - s_member = shower(member) - @test s_member == "PopMember(tree = ((x ^ 2.0) + 1.5), loss = 16.25, score = 1.0)" -end - -@testitem "pretty print hall of fame" tags = [:part1] begin - using LibraryAugmentedSymbolicRegression - using LibraryAugmentedSymbolicRegression: embed_metadata - using LibraryAugmentedSymbolicRegression.CoreModule: safe_pow - - options = Options(; binary_operators=[+, safe_pow], maxsize=7) - - ex = @parse_expression( - $safe_pow(x, 2.0) + 1.5, binary_operators = [+, safe_pow], variable_names = [:x] - ) - shower(x) = sprint((io, e) -> show(io, MIME"text/plain"(), e), x) - s = shower(ex) - @test s == "(x ^ 2.0) + 1.5" - - X = [1.0 2.0 3.0] - y = [2.0, 3.0, 4.0] - dataset = Dataset(X, y) - member = PopMember(dataset, ex, options; deterministic=false) - member.score = 1.0 - @test member isa PopMember{Float64,Float64,<:Expression{Float64,Node{Float64}}} - - hof = HallOfFame(options, dataset) - hof = embed_metadata(hof, options, dataset) - hof.members[5] = member - hof.exists[5] = true - s_hof = strip(shower(hof)) - true_s = "HallOfFame{...}: - .exists[1] = false - .members[1] = undef - .exists[2] = false - .members[2] = undef - .exists[3] = false - .members[3] = undef - .exists[4] = false - .members[4] = undef - .exists[5] = true - .members[5] = PopMember(tree = ((x ^ 2.0) + 1.5), loss = 16.25, score = 1.0) - .exists[6] = false - .members[6] = undef - .exists[7] = false - .members[7] = undef - .exists[8] = false - .members[8] = undef - .exists[9] = false - .members[9] = undef" - - @test s_hof == true_s -end - -@testitem "pretty print expression" tags = [:part2] begin - using LibraryAugmentedSymbolicRegression - using Suppressor: @capture_out - - options = Options(; binary_operators=[+, -, *, /], unary_operators=[cos]) - ex = @parse_expression( - cos(x) + y * y, operators = options.operators, variable_names = [:x, :y] - ) - - s = sprint((io, ex) -> print_tree(io, ex, options), ex) - @test strip(s) == "cos(x) + (y * y)" - - s = @capture_out begin - print_tree(ex, options) - end - @test strip(s) == "cos(x) + (y * y)" - - # Works with the tree itself too - s = @capture_out begin - print_tree(get_tree(ex), options) - end - @test strip(s) == "cos(x1) + (x2 * x2)" - s = sprint((io, ex) -> print_tree(io, ex, options), get_tree(ex)) - @test strip(s) == "cos(x1) + (x2 * x2)" - - # Updating options won't change printout, UNLESS - # we pass the options. 
- options = Options(; binary_operators=[/, *, -, +], unary_operators=[sin]) - - s = @capture_out begin - print_tree(ex) - end - @test strip(s) == "cos(x) + (y * y)" - - s = sprint((io, ex) -> print_tree(io, ex, options), ex) - @test strip(s) == "sin(x) / (y - y)" -end diff --git a/test/test_print.jl b/test/test_print.jl deleted file mode 100644 index 2c3b34b29..000000000 --- a/test/test_print.jl +++ /dev/null @@ -1,56 +0,0 @@ -using LibraryAugmentedSymbolicRegression -using LibraryAugmentedSymbolicRegression.UtilsModule: split_string - -include("test_params.jl") - -## Test Base.print -options = Options(; - default_params..., binary_operators=(+, *, /, -), unary_operators=(cos, sin) -) - -f = (x1, x2, x3) -> (sin(cos(sin(cos(x1) * x3) * 3.0) * -0.5) + 2.0) * 5.0 - -tree = f(Node("x1"), Node("x2"), Node("x3")) - -s = repr(tree) -true_s = "(sin(cos(sin(cos(x1) * x3) * 3.0) * -0.5) + 2.0) * 5.0" - -@test s == true_s - -equation_search( - randn(Float32, 3, 10), - randn(Float32, 10); - options=options, - variable_names=["v1", "v2", "v3"], - niterations=0, - parallelism=:multithreading, -) - -s = repr(tree) -true_s = "(sin(cos(sin(cos(v1) * v3) * 3.0) * -0.5) + 2.0) * 5.0" -@test s == true_s - -for unaop in [safe_log, safe_log2, safe_log10, safe_log1p, safe_sqrt, safe_acosh] - opts = Options(; - default_params..., binary_operators=(+, *, /, -), unary_operators=(unaop,) - ) - minitree = Node(1, Node("x1")) - @test string_tree(minitree, opts) == replace(string(unaop), "safe_" => "") * "(x1)" -end - -for binop in [safe_pow, ^] - opts = Options(; - default_params..., binary_operators=(+, *, /, -, binop), unary_operators=(cos,) - ) - minitree = Node(5, Node("x1"), Node("x2")) - @test string_tree(minitree, opts) == "x1 ^ x2" -end - -@testset "Test splitting of strings" begin - @test split_string("abcdefgh", 3) == ["abc", "def", "gh"] - @test split_string("abcdefgh", 100) == ["abcdefgh"] - @test split_string("⋅", 1) == ["⋅"] - @test split_string("⋅⋅", 1) == ["⋅", "⋅"] - @test split_string("⋅⋅⋅⋅", 2) == ["⋅⋅", "⋅⋅"] - @test split_string("ραβγ", 2) == ["ρα", "βγ"] -end diff --git a/test/test_prob_pick_first.jl b/test/test_prob_pick_first.jl deleted file mode 100644 index 9b21b7a06..000000000 --- a/test/test_prob_pick_first.jl +++ /dev/null @@ -1,52 +0,0 @@ -println("Testing whether tournament_selection_p works.") -using LibraryAugmentedSymbolicRegression -using DynamicExpressions: with_type_parameters, @parse_expression -using Test -include("test_params.jl") - -n = 10 - -options = Options(; - default_params..., - binary_operators=(+, -, *, /), - unary_operators=(cos, sin), - tournament_selection_p=0.999, - tournament_selection_n=n, -) - -for reverse in [false, true] - T = Float32 - - # Generate members with scores from 0 to 1: - members = [ - let - ex = @parse_expression( - x1 * 3.2, operators = options.operators, variable_names = [:x1], - ) - score = Float32(i - 1) / (n - 1) - if reverse - score = 1 - score - end - test_loss = 1.0f0 # (arbitrary for this test) - PopMember(ex, score, test_loss, options; deterministic=false) - end for i in 1:n - ] - - pop = Population(members) - - dummy_running_stats = LibraryAugmentedSymbolicRegression.AdaptiveParsimonyModule.RunningSearchStatistics(; - options=options - ) - best_pop_member = [ - LibraryAugmentedSymbolicRegression.best_of_sample( - pop, dummy_running_stats, options - ).score for j in 1:100 - ] - - mean_value = sum(best_pop_member) / length(best_pop_member) - - # Make sure average score is small - @test mean_value < 0.1 -end - -println("Passed.") diff --git a/test/test_recorder.jl
b/test/test_recorder.jl deleted file mode 100644 index 054454fe1..000000000 --- a/test/test_recorder.jl +++ /dev/null @@ -1,52 +0,0 @@ -using LibraryAugmentedSymbolicRegression -using LibraryAugmentedSymbolicRegression.UtilsModule: recursive_merge -using JSON3 -include("test_params.jl") - -base_dir = mktempdir() -recorder_file = joinpath(base_dir, "pysr_recorder.json") -X = 2 .* randn(Float32, 2, 1000) -y = 3 * cos.(X[2, :]) + X[1, :] .^ 2 .- 2 - -options = LibraryAugmentedSymbolicRegression.Options(; - binary_operators=(+, *, /, -), - unary_operators=(cos,), - use_recorder=true, - recorder_file=recorder_file, - crossover_probability=0.0, # required for recording, as not set up to track crossovers. - populations=2, - population_size=100, - maxsize=20, - complexity_of_operators=[cos => 2], -) - -hall_of_fame = equation_search( - X, y; niterations=5, options=options, parallelism=:multithreading -) - -data = open(options.recorder_file, "r") do io - JSON3.read(io; allow_inf=true) -end - -@test haskey(data, :options) -@test haskey(data, :out1_pop1) -@test haskey(data, :out1_pop2) -@test haskey(data, :mutations) - -# Test that "Options" is part of the string in `data.options`: -@test contains(data.options, "Options") -@test length(data.mutations) > 1000 - -# Check whether 10 random elements have the right properties: -for (i, key) in enumerate(keys(data.mutations)) - @test haskey(data.mutations[key], :events) - @test haskey(data.mutations[key], :score) - @test haskey(data.mutations[key], :tree) - @test haskey(data.mutations[key], :loss) - @test haskey(data.mutations[key], :parent) - if i > 10 - break - end -end - -@test_throws ErrorException recursive_merge() diff --git a/test/test_search_statistics.jl b/test/test_search_statistics.jl deleted file mode 100644 index 770803427..000000000 --- a/test/test_search_statistics.jl +++ /dev/null @@ -1,40 +0,0 @@ -using LibraryAugmentedSymbolicRegression -using LibraryAugmentedSymbolicRegression.AdaptiveParsimonyModule: - RunningSearchStatistics, update_frequencies!, move_window!, normalize_frequencies! -using Random - -options = Options() - -statistics = RunningSearchStatistics(; options=options, window_size=500) - -for i in 1:1000 - update_frequencies!(statistics; size=rand(MersenneTwister(i), 1:10)) -end - -normalize_frequencies!(statistics) - -@test sum(statistics.frequencies) == 1022 -@test sum(statistics.normalized_frequencies) ≈ 1.0 -@test statistics.normalized_frequencies[5] > statistics.normalized_frequencies[15] - -move_window!(statistics) - -@test sum(statistics.frequencies) ≈ 500.0 - -normalize_frequencies!(statistics) - -@test sum(statistics.normalized_frequencies[1:5]) > - sum(statistics.normalized_frequencies[10:15]) - -for i in 1:500 - update_frequencies!(statistics; size=rand(MersenneTwister(i), 10:15)) -end - -move_window!(statistics) - -@test sum(statistics.frequencies) ≈ 500.0 - -normalize_frequencies!(statistics) - -@test sum(statistics.normalized_frequencies[1:5]) < - sum(statistics.normalized_frequencies[10:15]) diff --git a/test/test_simplification.jl b/test/test_simplification.jl deleted file mode 100644 index fa6764c4a..000000000 --- a/test/test_simplification.jl +++ /dev/null @@ -1,86 +0,0 @@ -include("test_params.jl") -using LibraryAugmentedSymbolicRegression, Test -using SymbolicUtils: simplify, Symbolic -using DynamicExpressions.OperatorEnumConstructionModule: empty_all_globals! -#! format: off -using Base: ≈; using Random: MersenneTwister -#! 
format: on -# ^ Can't end line with ≈ due to JuliaSyntax.jl bug - -function Base.:≈(a::String, b::String) - a = replace(a, r"\s+" => "") - b = replace(b, r"\s+" => "") - return a == b -end - -empty_all_globals!() - -binary_operators = (+, -, /, *) - -index_of_mult = [i for (i, op) in enumerate(binary_operators) if op == *][1] - -options = Options(; binary_operators=binary_operators) -@test options.should_simplify # Default is true - -tree = Node("x1") + Node("x1") - -# Should simplify to 2*x1: -eqn = convert(Symbolic, tree, options) -eqn2 = simplify(eqn) -# Should correctly simplify to 2 x1: -# (although it might use 2(x1^1)) -@test occursin("2", "$(repr(eqn2)[1])") - -# Let's convert back the simplified version. -# This should remove the ^ operator: -tree = convert(Node, eqn2, options) -# Make sure one of the nodes is now 2.0: -@test (tree.l.constant ? tree.l : tree.r).val == 2 -# Make sure the other node is x1: -@test (!tree.l.constant ? tree.l : tree.r).feature == 1 - -# Finally, let's try converting a product, and ensure -# that SymbolicUtils does not convert it to a power: -tree = Node("x1") * Node("x1") -eqn = convert(Symbolic, tree, options) -@test repr(eqn) ≈ "x1*x1" -# Test converting back: -tree_copy = convert(Node, eqn, options) -@test repr(tree_copy) ≈ "x1 * x1" - -# Let's test a much more complex function, -# with custom operators, and unary operators: -x1, x2, x3 = Node("x1"), Node("x2"), Node("x3") -pow_abs2(x, y) = abs(x)^y -custom_cos2(x) = cos(x)^2 - -options = Options(; - binary_operators=(+, *, -, /, pow_abs2), unary_operators=(custom_cos2, exp, sin) -) -@extend_operators options -tree = ( - ((x2 + x2) * ((-0.5982493 / pow_abs2(x1, x2)) / -0.54734415)) + ( - sin( - custom_cos2( - sin(1.2926733 - 1.6606787) / - sin(((0.14577048 * x1) + ((0.111149654 + x1) - -0.8298334)) - -1.2071426), - ) * (custom_cos2(x3 - 2.3201916) + ((x1 - (x1 * x2)) / x2)), - ) / (0.14854191 - ((custom_cos2(x2) * -1.6047639) - 0.023943262)) - ) -) -# We use `index_functions` to avoid converting the custom operators into the primitives. 
-eqn = convert(Symbolic, tree, options; index_functions=true) - -tree_copy = convert(Node, eqn, options) -tree_copy2 = convert(Node, simplify(eqn), options) -# Too difficult to check the representation, so we check by evaluation: -N = 100 -X = rand(MersenneTwister(0), 3, N) .+ 0.1 -output1, flag1 = eval_tree_array(tree, X, options) -output2, flag2 = eval_tree_array(tree_copy, X, options) -output3, flag3 = eval_tree_array(tree_copy2, X, options) - -@test isapprox(output1, output2, atol=1e-4 * sqrt(N)) -# Simplified equation may give a different answer due to rounding errors, -# so we weaken the requirement: -@test isapprox(output1, output3, atol=1e-2 * sqrt(N)) diff --git a/test/test_stop_on_clock.jl b/test/test_stop_on_clock.jl deleted file mode 100644 index 67beb7e0a..000000000 --- a/test/test_stop_on_clock.jl +++ /dev/null @@ -1,25 +0,0 @@ -using LibraryAugmentedSymbolicRegression -using Random -using Distributed: rmprocs -include("test_params.jl") - -X = randn(MersenneTwister(0), Float32, 5, 100) -y = 2 * cos.(X[4, :]) - -# Ensure is precompiled: -options = Options(; - default_params..., - population_size=10, - ncycles_per_iteration=100, - maxsize=15, - timeout_in_seconds=1, -) -equation_search(X, y; niterations=1, options=options, parallelism=:serial) - -# Ensure nothing might prevent slow checking of the clock: -rmprocs() -GC.gc(true) # full=true -start_time = time() -equation_search(X, y; niterations=10000000, options=options, parallelism=:serial) -end_time = time() -@test end_time - start_time < 100 diff --git a/test/test_symbolic_utils.jl b/test/test_symbolic_utils.jl deleted file mode 100644 index 408c4e65e..000000000 --- a/test/test_symbolic_utils.jl +++ /dev/null @@ -1,20 +0,0 @@ -using SymbolicUtils -using LibraryAugmentedSymbolicRegression -include("test_params.jl") - -_inv(x) = 1 / x -options = Options(; - default_params..., - binary_operators=(+, *, ^, /, greater), - unary_operators=(_inv,), - constraints=(_inv => 4,), - populations=4, -) -@extend_operators options -tree = Node(5, (^)(Node(; val=3.0) * Node(1, Node("x1")), 2.0), Node(; val=-1.2)) - -eqn = node_to_symbolic(tree, options; variable_names=["energy"], index_functions=true) -@test string(eqn) == "greater(safe_pow(3.0_inv(energy), 2.0), -1.2)" - -tree2 = symbolic_to_node(eqn, options; variable_names=["energy"]) -@test string_tree(tree, options) == string_tree(tree2, options) diff --git a/test/test_tree_construction.jl b/test/test_tree_construction.jl deleted file mode 100644 index 49af73d46..000000000 --- a/test/test_tree_construction.jl +++ /dev/null @@ -1,111 +0,0 @@ -using LibraryAugmentedSymbolicRegression -using Random -using LibraryAugmentedSymbolicRegression: eval_loss, score_func, Dataset -using ForwardDiff -include("test_params.jl") - -x1 = 2.0 - -# Initialize functions in Base.... -for unaop in [cos, exp, safe_log, safe_log2, safe_log10, safe_sqrt, relu, gamma, safe_acosh] - for binop in [sub] - function make_options(; kw...) - return Options(; - default_params..., - binary_operators=(+, *, ^, /, binop), - unary_operators=(unaop, abs), - populations=4, - verbosity=(unaop == gamma) ? 
0 : Int(1e9), - kw..., - ) - end - options = make_options() - @extend_operators options - - # for unaop in - f_true = (x,) -> binop(abs(3.0 * unaop(x))^2.0, -1.2) - - # binop at outside: - const_tree = Node( - 5, (^)(Node(2, Node(; val=3.0) * Node(1, Node("x1"))), 2.0), Node(; val=-1.2) - ) - const_tree_bad = Node( - 5, (^)(Node(2, Node(; val=3.0) * Node(1, Node("x1"))), 2.1), Node(; val=-1.3) - ) - n = count_nodes(const_tree) - - true_result = f_true(x1) - - result = eval(Meta.parse(string_tree(const_tree, make_options()))) - - # Test Basics - @test n == 9 - @test result == true_result - - types_to_test = [Float32, Float64, BigFloat] - if unaop == cos - # Other unary operators produce numbers too large - # to do meaningful tests - types_to_test = [Float16, types_to_test...] - end - for T in types_to_test - if T == Float16 || unaop == gamma - zero_tolerance = 3e-2 - else - zero_tolerance = 1e-6 - end - - tree = convert(Node{T}, const_tree) - tree_bad = convert(Node{T}, const_tree_bad) - - Random.seed!(0) - N = 100 - if unaop in [safe_log, safe_log2, safe_log10, safe_acosh, safe_sqrt] - X = T.(rand(MersenneTwister(0), 5, N) / 3) - else - X = T.(randn(MersenneTwister(0), 5, N) / 3) - end - X = X + sign.(X) * T(0.1) - if unaop == safe_acosh - X = X .+ T(1.0) - end - - y = T.(f_true.(X[1, :])) - dataset = Dataset(X, y) - test_y, complete = eval_tree_array(tree, X, make_options()) - test_y2, complete2 = differentiable_eval_tree_array(tree, X, make_options()) - - # Test Evaluation - @test complete == true - @test all(abs.(test_y .- y) / N .< zero_tolerance) - @test complete2 == true - @test all(abs.(test_y2 .- y) / N .< zero_tolerance) - - # Test loss: - @test abs(eval_loss(tree, dataset, make_options())) < zero_tolerance - @test eval_loss(tree, dataset, make_options()) == - score_func(dataset, tree, make_options())[2] - - #Test Scoring - @test abs(score_func(dataset, tree, make_options(; parsimony=0.0))[1]) < - zero_tolerance - @test score_func(dataset, tree, make_options(; parsimony=1.0))[1] > 1.0 - @test score_func(dataset, tree, make_options())[1] < - score_func(dataset, tree_bad, make_options())[1] - - dataset_with_larger_baseline = deepcopy(dataset) - dataset_with_larger_baseline.baseline_loss = one(T) * 10 - @test score_func(dataset_with_larger_baseline, tree_bad, make_options())[1] < - score_func(dataset, tree_bad, make_options())[1] - - # Test gradients: - df_true = x -> ForwardDiff.derivative(f_true, x) - dy = T.(df_true.(X[1, :])) - test_dy = ForwardDiff.gradient( - _x -> sum(differentiable_eval_tree_array(tree, _x, make_options())[1]), X - ) - test_dy = test_dy[1, 1:end] - @test all(abs.(test_dy .- dy) / N .< zero_tolerance) - end - end -end diff --git a/test/test_turbo_nan.jl b/test/test_turbo_nan.jl deleted file mode 100644 index e57e73a0a..000000000 --- a/test/test_turbo_nan.jl +++ /dev/null @@ -1,34 +0,0 @@ -using LibraryAugmentedSymbolicRegression -using LoopVectorization - -bad_op(x::T) where {T} = (x >= 0) ? x : T(0) - -options = Options(; - unary_operators=(sin, exp, sqrt, bad_op), - binary_operators=(+, *), - turbo=true, - nested_constraints=[sin => [sin => 0], exp => [exp => 0]], - maxsize=30, - npopulations=40, - # ^ Leave as deprecated param, just to test. 
- parsimony=0.01, -) - -tree = Node(3, Node(1, Node(; val=-π / 2))) - -# Should still be safe against domain errors: -try - tree([0.0]', options) - @test true -catch e - @test false -end - -tree = Node(3, Node(1, Node(; feature=1))) - -try - tree([-π / 2]', options) - @test true -catch e - @test false -end diff --git a/test/test_units.jl b/test/test_units.jl deleted file mode 100644 index 6046f8d21..000000000 --- a/test/test_units.jl +++ /dev/null @@ -1,488 +0,0 @@ -@testitem "Dimensional analysis" tags = [:part3] begin - using LibraryAugmentedSymbolicRegression - using LibraryAugmentedSymbolicRegression.InterfaceDynamicQuantitiesModule: get_units - using LibraryAugmentedSymbolicRegression.DimensionalAnalysisModule: - violates_dimensional_constraints - using DynamicQuantities - using DynamicQuantities: DEFAULT_DIM_BASE_TYPE - - X = randn(3, 100) - y = @. cos(X[3, :] * 2.1 - 0.2) + 0.5 - - custom_op(x, y) = x + y - options = Options(; - binary_operators=[-, *, /, custom_op, ^], - unary_operators=[cos, cbrt, sqrt, abs, inv], - ) - @extend_operators options - - (x1, x2, x3) = (i -> Node(Float64; feature=i)).(1:3) - - D = Dimensions{DEFAULT_DIM_BASE_TYPE} - SD = SymbolicDimensions{DEFAULT_DIM_BASE_TYPE} - - @test get_units(Float64, D, [u"m", "1", "kg"], uparse) == - [Quantity(1.0; length=1), Quantity(1.0), Quantity(1.0; mass=1)] - @test get_units(Float64, SD, [us"m", "1", "kg"], sym_uparse) == [ - Quantity(1.0, SymbolicDimensions; m=1), - Quantity(1.0, SymbolicDimensions), - Quantity(1.0, SymbolicDimensions; kg=1), - ] - # Various input types: - @test get_units(Float64, SD, [us"m", 1.5, SD()], sym_uparse) == [ - Quantity(1.0, SymbolicDimensions; m=1), - Quantity(1.5, SymbolicDimensions), - Quantity(1.0, SymbolicDimensions), - ] - @test get_units(Float64, SD, [""], sym_uparse) == [Quantity(1.0, SymbolicDimensions)] - # Bad unit types: - @test_throws ErrorException get_units(Float64, D, (; X=[1, 2]), uparse) - - # Dataset creation: - dataset = Dataset(X, y; X_units=[u"m", u"1", u"kg"], y_units=u"1") - @test dataset.X_units == [Quantity(1.0; length=1), Quantity(1.0), Quantity(1.0; mass=1)] - @test dataset.X_sym_units == [ - Quantity(1.0, SymbolicDimensions; m=1), - Quantity(1.0, SymbolicDimensions), - Quantity(1.0, SymbolicDimensions; kg=1), - ] - @test dataset.y_sym_units == Quantity(1.0, SymbolicDimensions) - @test dataset.y_units == Quantity(1.0) - - violates(tree) = violates_dimensional_constraints(tree, dataset, options) - - good_expressions = [ - Node(; val=3.2), - 3.2 * x1 / x1, - 1.0 * (3.2 * x1 - x2 * x1), - 3.2 * x1 - x2, - cos(3.2 * x1), - cos(0.9 * x1 - 0.5 * x2), - 1.0 * (x1 - 0.5 * (x3 * (cos(0.9 * x1 - 0.5 * x2) - 1.2))), - 1.0 * (custom_op(x1, x1)), - 1.0 * (custom_op(x1, 2.1 * x3)), - 1.0 * (custom_op(custom_op(x1, 2.1 * x3), x1)), - 1.0 * (custom_op(custom_op(x1, 2.1 * x3), 0.9 * x1)), - x2, - 1.0 * x1, - 1.0 * x3, - (1.0 * x1)^(Node(; val=3.2)), - 1.0 * (cbrt(x3 * x3 * x3) - x3), - 1.0 * (sqrt(x3 * x3) - x3), - 1.0 * (sqrt(abs(x3) * abs(x3)) - x3), - inv(x2), - 1.0 * inv(x1), - x3 * inv(x3), - ] - bad_expressions = [ - x1, - x3, - x1 - x3, - 1.0 * cos(x1), - 1.0 * cos(x1 - 0.5 * x2), - 1.0 * (x1 - (x3 * (cos(0.9 * x1 - 0.5 * x2) - 1.2))), - 1.0 * custom_op(x1, x3), - 1.0 * custom_op(custom_op(x1, 2.1 * x3), x3), - 1.0 * cos(0.8606301 / x1) / cos(custom_op(cos(x1), 3.2263336)), - 1.0 * (x1^(Node(; val=3.2))), - 1.0 * ((1.0 * x1)^x1), - 1.0 * (cbrt(x3 * x3) - x3), - 1.0 * (sqrt(abs(x3)) - x3), - inv(x3), - inv(x1), - x1 * inv(x3), - ] - - for expr in good_expressions - @test 
!violates(expr) || @show expr - end - for expr in bad_expressions - @test violates(expr) || @show expr - end -end - -@testitem "Search with dimensional constraints" tags = [:part3] begin - using LibraryAugmentedSymbolicRegression - using LibraryAugmentedSymbolicRegression.DimensionalAnalysisModule: - violates_dimensional_constraints - using Random: MersenneTwister - - rng = MersenneTwister(0) - X = rand(rng, 1, 128) .* 20 - y = @. cos(X[1, :]) + X[1, :] - dataset = Dataset(X, y; X_units=["kg"], y_units="1") - custom_op(x, y) = x + y - options = Options(; - binary_operators=[-, *, /, custom_op], - unary_operators=[cos], - early_stop_condition=(loss, complexity) -> (loss < 1e-7 && complexity <= 8), - ) - @extend_operators options - - hof = equation_search(dataset; niterations=1000, options) - - # Solutions should be like cos([cons] * X[1]) + [cons]*X[1] - dominating = calculate_pareto_frontier(hof) - best_expr = first(filter(m::PopMember -> m.loss < 1e-7, dominating)).tree - - @test !violates_dimensional_constraints(best_expr, dataset, options) - x1 = Node(Float64; feature=1) - @test compute_complexity(best_expr, options) >= - compute_complexity(custom_op(cos(1 * x1), 1 * x1), options) - - # Check that every cos(...) which contains x1 also has complexity greater than 1: - has_cos(tree) = - any(get_tree(tree)) do t - t.degree == 1 && options.operators.unaops[t.op] == cos - end - valid_trees = [ - !has_cos(member.tree) || any( - t -> - t.degree == 1 && - options.operators.unaops[t.op] == cos && - Node(Float64; feature=1) in t && - compute_complexity(t, options) > 1, - get_tree(member.tree), - ) for member in dominating - ] - @test all(valid_trees) - @test length(valid_trees) > 0 -end - -@testitem "Operator compatibility" tags = [:part3] begin - using LibraryAugmentedSymbolicRegression - using DynamicQuantities - - ## square cube plus sub mult greater cond relu logical_or logical_and safe_pow atanh_clip - # Want to ensure these operators perform correctly in the context of units - @test square(1.0u"m") == 1.0u"m^2" - @test cube(1.0u"m") == 1.0u"m^3" - @test plus(1.0u"m", 1.0u"m") == 2.0u"m" - @test_throws DimensionError plus(1.0u"m", 1.0u"s") - @test sub(1.0u"m", 1.0u"m") == 0.0u"m" - @test_throws DimensionError sub(1.0u"m", 1.0u"s") - @test mult(1.0u"m", 1.0u"m") == 1.0u"m^2" - @test mult(1.0u"m", 1.0u"s") == 1.0u"m*s" - @test greater(1.1u"m", 1.0u"m") == true - @test greater(0.9u"m", 1.0u"m") == false - @test typeof(greater(1.1u"m", 1.0u"m")) === typeof(1.0u"m") - @test_throws DimensionError greater(1.0u"m", 1.0u"s") - @test cond(0.1u"m", 1.5u"m") == 1.5u"m" - @test cond(-0.1u"m", 1.5u"m") == 0.0u"m" - @test cond(-0.1u"s", 1.5u"m") == 0.0u"m" - @test relu(0.1u"m") == 0.1u"m" - @test relu(-0.1u"m") == 0.0u"m" - @test logical_or(0.1u"m", 0.0u"m") == 1.0 - @test logical_or(-0.1u"m", 0.0u"m") == 0.0 - @test logical_or(-0.5u"m", 1.0u"m") == 1.0 - @test logical_or(-0.2u"m", -0.2u"m") == 0.0 - @test logical_and(0.1u"m", 0.0u"m") == 0.0 - @test logical_and(0.1u"s", 0.0u"m") == 0.0 - @test logical_and(-0.1u"m", 0.0u"m") == 0.0 - @test logical_and(-0.5u"m", 1.0u"m") == 0.0 - @test logical_and(-0.2u"s", -0.2u"m") == 0.0 - @test logical_and(0.2u"s", 0.2u"m") == 1.0 - @test safe_pow(4.0u"m", 0.5u"1") == 2.0u"m^0.5" - @test isnan(safe_pow(-4.0u"m", 0.5u"1")) - @test typeof(safe_pow(-4.0u"m", 0.5u"1")) === typeof(1.0u"m") - @inferred safe_pow(4.0u"m", 0.5u"1") - @test_throws DimensionError safe_pow(1.0u"m", 1.0u"m") - @test atanh_clip(0.5u"1") == atanh(0.5) - @test atanh_clip(2.5u"1") == atanh(0.5) - @test_throws
DimensionError atanh_clip(1.0u"m") -end - -@testitem "Search with dimensional constraints on output" tags = [:part3] begin - using LibraryAugmentedSymbolicRegression - using MLJBase: MLJBase as MLJ - using DynamicQuantities - using Random: MersenneTwister - - include("utils.jl") - - custom_op(x, y) = x + y - options = Options(; - binary_operators=[-, *, /, custom_op], - unary_operators=[cos], - early_stop_condition=(loss, complexity) -> (loss < 1e-7 && complexity == 3), - ) - @extend_operators options - - rng = MersenneTwister(0) - X = randn(rng, 2, 128) - X[2, :] .= X[1, :] - y = X[1, :] .^ 2 - - # The search should find that y=X[2]^2 is the best, - # due to the dimensionality constraint: - hof = EquationSearch(X, y; options, X_units=["kg", "m"], y_units="m^2") - - # Solution should be x2 * x2 - dominating = calculate_pareto_frontier(hof) - best = get_tree(first(filter(m::PopMember -> m.loss < 1e-7, dominating)).tree) - - x2 = Node(Float64; feature=2) - - if compute_complexity(best, options) == 3 - @test best.degree == 2 - @test best.l == x2 - @test best.r == x2 - else - @warn "Complexity of best solution is not 3; search with units might have failed" - end - - rng = MersenneTwister(0) - X = randn(rng, 2, 128) - y = @. cbrt(X[1, :]) .+ sqrt(abs(X[2, :])) - options2 = Options(; - binary_operators=[+, *], - unary_operators=[sqrt, cbrt, abs], - early_stop_condition=(loss, complexity) -> (loss < 1e-7 && complexity == 6), - ) - hof = EquationSearch(X, y; options=options2, X_units=["kg^3", "kg^2"], y_units="kg") - - dominating = calculate_pareto_frontier(hof) - best = first(filter(m::PopMember -> m.loss < 1e-7, dominating)).tree - @test compute_complexity(best, options2) == 6 - @test any(get_tree(best)) do t - t.degree == 1 && options2.operators.unaops[t.op] == cbrt - end - @test any(get_tree(best)) do t - t.degree == 1 && options2.operators.unaops[t.op] == safe_sqrt - end - - @testset "With MLJ" begin - for as_quantity_array in (false, true) - model = LaSRRegressor(; - binary_operators=[+, *], - unary_operators=[sqrt, cbrt, abs], - early_stop_condition=(loss, complexity) -> (loss < 1e-7 && complexity <= 6), - ) - X = if as_quantity_array - (; x1=randn(128) .* u"kg^3", x2=QuantityArray(randn(128) .* u"kg^2")) - else - (; x1=randn(128) .* u"kg^3", x2=randn(128) .* u"kg^2") - end - y = (@. cbrt(ustrip(X.x1)) + sqrt(abs(ustrip(X.x2)))) .* u"kg" - mach = MLJ.machine(model, X, y) - MLJ.fit!(mach) - report = MLJ.report(mach) - best_idx = findfirst(report.losses .< 1e-7)::Int - @test report.complexities[best_idx] <= 6 - @test any(get_tree(report.equations[best_idx])) do t - t.degree == 1 && t.op == 2 # cbrt - end - @test any(get_tree(report.equations[best_idx])) do t - t.degree == 1 && t.op == 1 # safe_sqrt - end - - # Prediction should have same units: - ypred = MLJ.predict(mach; rows=1:3) - @test dimension(ypred[begin]) == dimension(y[begin]) - end - - # Multiple outputs, and with RealQuantity - model = MultitargetLaSRRegressor(; - binary_operators=[+, *], - unary_operators=[sqrt, cbrt, abs], - early_stop_condition=(loss, complexity) -> (loss < 1e-7 && complexity <= 8), - ) - X = (; x1=randn(128), x2=randn(128)) - y = (; - a=(@. 
cbrt(ustrip(X.x1)) + sqrt(abs(ustrip(X.x2)))) .* RealQuantity(u"kg"), - b=X.x1, - ) - @test typeof(y.a) <: AbstractArray{<:RealQuantity} - mach = MLJ.machine(model, X, y) - MLJ.fit!(mach) - report = MLJ.report(mach) - @test minimum(report.losses[1]) < 1e-7 - @test minimum(report.losses[2]) < 1e-7 - - # Repeat with second run: - mach.model.niterations = 0 - MLJ.fit!(mach) - report = MLJ.report(mach) - @test minimum(report.losses[1]) < 1e-7 - @test minimum(report.losses[2]) < 1e-7 - - # Prediction should have same units: - ypred = MLJ.predict(mach; rows=1:3) - @test dimension(ypred.a[begin]) == dimension(y.a[begin]) - @test typeof(dimension(ypred.a[begin])) == typeof(dimension(y.a[begin])) - # TODO: Should return same quantity as input - @test typeof(ypred.a[begin]) <: Quantity - @test typeof(y.a[begin]) <: RealQuantity - VERSION >= v"1.8" && - @eval @test(typeof(ypred.b[begin]) == typeof(y.b[begin]), broken = true) - end -end - -@testitem "Should error on mismatched units" tags = [:part3] begin - using LibraryAugmentedSymbolicRegression - using DynamicQuantities - - X = randn(11, 50) - y = randn(50) - VERSION >= v"1.8.0" && - @test_throws("Number of features", Dataset(X, y; X_units=["m", "1"], y_units="kg")) -end - -@testitem "Should print units" tags = [:part3] begin - using LibraryAugmentedSymbolicRegression - using DynamicQuantities - - X = randn(5, 64) - y = randn(64) - dataset = Dataset(X, y; X_units=["m^3", "km/s", "kg", "1", "1"], y_units="kg") - x1, x2, x3, x4, x5 = [Node(Float64; feature=i) for i in 1:5] - options = Options(; binary_operators=[+, -, *, /], unary_operators=[cos, sin]) - tree = 1.0 * (x1 + x2 * x3 * 5.32) - cos(1.5 * (x1 - 0.5)) - - @test string_tree(tree, options) == - "(1.0 * (x1 + ((x2 * x3) * 5.32))) - cos(1.5 * (x1 - 0.5))" - @test string_tree(tree, options; raw=false) == - "(1 * (x₁ + ((x₂ * x₃) * 5.32))) - cos(1.5 * (x₁ - 0.5))" - @test string_tree( - tree, options; raw=false, display_variable_names=dataset.display_variable_names - ) == "(1 * (x₁ + ((x₂ * x₃) * 5.32))) - cos(1.5 * (x₁ - 0.5))" - @test string_tree( - tree, - options; - raw=false, - display_variable_names=dataset.display_variable_names, - X_sym_units=dataset.X_sym_units, - y_sym_units=dataset.y_sym_units, - ) == - "(1[?] * (x₁[m³] + ((x₂[s⁻¹ km] * x₃[kg]) * 5.32[?]))) - cos(1.5[?] 
* (x₁[m³] - 0.5[?]))" - - @test string_tree( - x5 * 3.2, - options; - raw=false, - display_variable_names=dataset.display_variable_names, - X_sym_units=dataset.X_sym_units, - y_sym_units=dataset.y_sym_units, - ) == "x₅ * 3.2[?]" - - # Should print numeric factor in unit if given: - dataset2 = Dataset(X, y; X_units=[1.5, 1.9, 2.0, 3.0, 5.0u"m"], y_units="kg") - @test string_tree( - x5 * 3.2, - options; - raw=false, - display_variable_names=dataset2.display_variable_names, - X_sym_units=dataset2.X_sym_units, - y_sym_units=dataset2.y_sym_units, - ) == "x₅[5.0 m] * 3.2[?]" - - # With dimensionless_constants_only, it will not print the [?]: - options = Options(; - binary_operators=[+, -, *, /], - unary_operators=[cos, sin], - dimensionless_constants_only=true, - ) - @test string_tree( - x5 * 3.2, - options; - raw=false, - display_variable_names=dataset2.display_variable_names, - X_sym_units=dataset2.X_sym_units, - y_sym_units=dataset2.y_sym_units, - ) == "x₅[5.0 m] * 3.2" -end - -@testitem "Dimensionless constants" tags = [:part3] begin - using LibraryAugmentedSymbolicRegression - using LibraryAugmentedSymbolicRegression.DimensionalAnalysisModule: - violates_dimensional_constraints - using DynamicQuantities - - include("utils.jl") - - options = Options(; - binary_operators=[+, -, *, /, square, cube], - unary_operators=[cos, sin], - dimensionless_constants_only=true, - ) - X = randn(5, 64) - y = randn(64) - dataset = Dataset(X, y; X_units=["m^3", "km/s", "kg", "hr", "1"], y_units="kg") - x1, x2, x3, x4, x5 = [Node(Float64; feature=i) for i in 1:5] - - dimensionally_valid_equations = [ - 1.5 * x1 / (cube(x2) * cube(x4)) * x3, x3, (square(x3) / x3) + x3 - ] - for tree in dimensionally_valid_equations - onfail(@test !violates_dimensional_constraints(tree, dataset, options)) do - @warn "Failed on" tree - end - end - dimensionally_invalid_equations = [Node(Float64; val=1.5), 1.5 * x1, x3 - 1.0 * x1] - for tree in dimensionally_invalid_equations - onfail(@test violates_dimensional_constraints(tree, dataset, options)) do - @warn "Failed on" tree - end - end - # But, all of these would be fine if we allow dimensionless constants: - let - options = Options(; binary_operators=[+, -, *, /], unary_operators=[cos, sin]) - for tree in dimensionally_invalid_equations - onfail(@test !violates_dimensional_constraints(tree, dataset, options)) do - @warn "Failed on" tree - end - end - end -end - -@testitem "Miscellaneous tests of unit interface" tags = [:part3] begin - using LibraryAugmentedSymbolicRegression - using DynamicQuantities - using LibraryAugmentedSymbolicRegression.DimensionalAnalysisModule: - @maybe_return_call, WildcardQuantity - using LibraryAugmentedSymbolicRegression.MLJInterfaceModule: unwrap_units_single - using LibraryAugmentedSymbolicRegression.InterfaceDynamicQuantitiesModule: - get_dimensions_type - using MLJModelInterface: MLJModelInterface as MMI - - function test_return_call(op::Function, w...) - @maybe_return_call(typeof(first(w)), op, w) - return nothing - end - - x = WildcardQuantity{typeof(u"m")}(u"m/s", true, false) - - # Valid input returns as expected - @test ustrip(test_return_call(+, x, x)) == 2.0 - - # Regular errors are thrown - thrower(_...) = error("") - @test_throws ErrorException test_return_call(thrower, 1.0, 1.0) - - # But method errors are safely caught - @test test_return_call(+, 1.0, "1.0") === nothing - - # Edge case - ## First, what happens if we just pass some data with quantities, - ## and some without? 
- data = (a=randn(3), b=fill(us"m", 3), c=fill(u"m/s", 3)) - Xm_t = MMI.matrix(data; transpose=true) - @test typeof(Xm_t) <: Matrix{<:Quantity} - _, test_dims = unwrap_units_single(Xm_t, Dimensions) - @test test_dims == dimension.([u"1", u"m", u"m/s"]) - @test test_dims != dimension.([u"m", u"m", u"m"]) - @inferred unwrap_units_single(Xm_t, Dimensions) - - ## Now, we force promotion to generic `Number` type: - data = (a=Number[randn(3)...], b=fill(us"m", 3), c=fill(u"m/s", 3)) - Xm_t = MMI.matrix(data; transpose=true) - @test typeof(Xm_t) === Matrix{Number} - _, test_dims = unwrap_units_single(Xm_t, Dimensions) - @test test_dims == dimension.([u"1", u"m", u"m/s"]) - @test_skip @inferred unwrap_units_single(Xm_t, Dimensions) - - # Another edge case - ## Should be able to pull it out from array: - @test get_dimensions_type(Number[1.0, us"1"], Dimensions) <: SymbolicDimensions - @test get_dimensions_type(Number[1.0, 1.0], Dimensions) <: Dimensions -end diff --git a/test/test_utils.jl b/test/test_utils.jl deleted file mode 100644 index d6155c5f7..000000000 --- a/test/test_utils.jl +++ /dev/null @@ -1,29 +0,0 @@ -using LibraryAugmentedSymbolicRegression -using LibraryAugmentedSymbolicRegression.UtilsModule: - findmin_fast, argmin_fast, bottomk_fast, is_anonymous_function -using Random - -function simple_bottomk(x, k) - idx = sortperm(x)[1:k] - return x[idx], idx -end - -array_options = [ - (n=n, seed=seed, T=T) for n in (1, 5, 20, 50, 100, 1000), seed in 1:10, - T in (Float32, Float64, Int) -] - -@testset "argmin_fast" begin - for opt in array_options - x = rand(MersenneTwister(opt.seed), opt.T, opt.n) .* 2 .- 1 - @test findmin_fast(x) == findmin(x) - @test argmin_fast(x) == argmin(x) - end -end -@testset "bottomk_fast" begin - for opt in array_options, k in (1, 2, 3, 5, 10, 20, 50, 100) - k > opt.n && continue - x = rand(MersenneTwister(opt.seed), opt.T, opt.n) .* 2 .- 1 - @test bottomk_fast(x, k) == simple_bottomk(x, k) - end -end diff --git a/test/user_defined_operator.jl b/test/user_defined_operator.jl deleted file mode 100644 index f96f4d994..000000000 --- a/test/user_defined_operator.jl +++ /dev/null @@ -1,19 +0,0 @@ -using LibraryAugmentedSymbolicRegression, Test -include("test_params.jl") - -_inv(x::Float32)::Float32 = 1.0f0 / x -X = rand(Float32, 5, 100) .+ 1 -y = 1.2f0 .+ 2 ./ X[3, :] - -options = LibraryAugmentedSymbolicRegression.Options(; - default_params..., binary_operators=(+, *), unary_operators=(_inv,), populations=8 -) -hallOfFame = equation_search( - X, y; niterations=8, options=options, numprocs=2, parallelism=:multiprocessing -) - -dominating = calculate_pareto_frontier(X, y, hallOfFame, options) - -best = dominating[end] -# Test the score -@test best.loss < maximum_residual / 10
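---
Note (a minimal sketch, not part of the patch): the removed user_defined_operator.jl above was the shortest end-to-end demo of the search API in the old suite. Below is a standalone version of the same pattern, assuming the `Options`/`equation_search` keyword API exercised by these deleted tests, and using the one-argument `calculate_pareto_frontier(hof)` form seen in the other removed tests (the four-argument call above is the older signature):

using LibraryAugmentedSymbolicRegression

# A user-defined operator is a plain scalar Julia function; a concrete,
# type-stable method is enough for the search to evaluate it.
_inv(x::Float32)::Float32 = 1.0f0 / x

X = rand(Float32, 5, 100) .+ 1   # 5 features x 100 samples
y = 1.2f0 .+ 2 ./ X[3, :]        # target depends on the inverse of feature 3

options = Options(; binary_operators=(+, *), unary_operators=(_inv,), populations=8)
hof = equation_search(X, y; niterations=8, options=options, parallelism=:multithreading)

# The Pareto frontier keeps the best member at each complexity level;
# the last entry is the most complex and typically the lowest-loss one.
dominating = calculate_pareto_frontier(hof)
best = dominating[end]
println("best loss: ", best.loss)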