From bda097ee6ab7144567990530c86aca58b24c94c8 Mon Sep 17 00:00:00 2001 From: MilesCranmer Date: Sun, 24 Dec 2023 01:34:28 +0000 Subject: [PATCH] Simplify maxsize calculation --- src/SearchUtils.jl | 20 ++++++++++++++++++++ src/SymbolicRegression.jl | 34 +++++++++++----------------------- 2 files changed, 31 insertions(+), 23 deletions(-) diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index 2624c7359..3074eb624 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -304,6 +304,26 @@ function load_saved_population(saved_state; out::Int, pop::Int) end load_saved_population(::Nothing; kws...) = nothing +""" + get_cur_maxsize(; options, total_cycles, cycles_remaining) + +For searches where the maxsize gradually increases, this function returns the +current maxsize. +""" +function get_cur_maxsize(; options::Options, total_cycles::Int, cycles_remaining::Int) + cycles_elapsed = total_cycles - cycles_remaining + fraction_elapsed = 1.0f0 * cycles_elapsed / total_cycles + in_warmup_period = fraction_elapsed <= options.warmup_maxsize_by + + if options.warmup_maxsize_by > 0 && in_warmup_period + return 3 + floor( + Int, (options.maxsize - 3) * fraction_elapsed / options.warmup_maxsize_by + ) + else + return options.maxsize + end +end + function construct_datasets( X, y, diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index b1ca2fd4b..4f1279e3b 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -231,7 +231,8 @@ import .SearchUtilsModule: init_dummy_pops, load_saved_hall_of_fame, load_saved_population, - construct_datasets + construct_datasets, + get_cur_maxsize include("deprecates.jl") include("Configure.jl") @@ -631,12 +632,6 @@ function _equation_search( record = RecordType() @recorder record["options"] = "$(options)" - curmaxsizes = if iszero(options.warmup_maxsize_by) - fill(options.maxsize, nout) - else - fill(convert(typeof(options.maxsize), 3), nout) - end - # Records the number of evaluations: # Real numbers indicate use of batching. num_evals = [[0.0 for i in 1:(options.populations)] for j in 1:nout] @@ -719,6 +714,12 @@ function _equation_search( end push!(allPops[j], new_pop) end + total_cycles = options.populations * niterations + cycles_remaining = [total_cycles for j in 1:nout] + curmaxsizes = [ + get_cur_maxsize(; options, total_cycles, cycles_remaining=cycles_remaining[j]) for + j in 1:nout + ] # 2. Start the cycle on every process: for j in 1:nout, i in 1:(options.populations) dataset = datasets[j] @@ -757,8 +758,6 @@ function _equation_search( verbosity > 0 && @info "Started!" start_time = time() - total_cycles = options.populations * niterations - cycles_remaining = [total_cycles for j in 1:nout] if progress #TODO: need to iterate this on the max cycles remaining! sum_cycle_remaining = sum(cycles_remaining) @@ -935,20 +934,9 @@ function _equation_search( tasks[j][i] = @async put!(channels[j][i], fetch(allPops[j][i])) end - cycles_elapsed = total_cycles - cycles_remaining[j] - if options.warmup_maxsize_by > 0 - fraction_elapsed = 1.0f0 * cycles_elapsed / total_cycles - if fraction_elapsed > options.warmup_maxsize_by - curmaxsizes[j] = options.maxsize - else - curmaxsizes[j] = - 3 + floor( - Int, - (options.maxsize - 3) * fraction_elapsed / - options.warmup_maxsize_by, - ) - end - end + curmaxsizes[j] = get_cur_maxsize(; + options, total_cycles, cycles_remaining=cycles_remaining[j] + ) stop_work_monitor!(resource_monitor) move_window!(all_running_search_statistics[j]) if progress