Skip to content

Commit

Permalink
Simplify maxsize calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
MilesCranmer committed Dec 24, 2023
1 parent 9a5ad57 commit bda097e
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 23 deletions.
20 changes: 20 additions & 0 deletions src/SearchUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,26 @@ function load_saved_population(saved_state; out::Int, pop::Int)
end
load_saved_population(::Nothing; kws...) = nothing

"""
get_cur_maxsize(; options, total_cycles, cycles_remaining)
For searches where the maxsize gradually increases, this function returns the
current maxsize.
"""
function get_cur_maxsize(; options::Options, total_cycles::Int, cycles_remaining::Int)
cycles_elapsed = total_cycles - cycles_remaining
fraction_elapsed = 1.0f0 * cycles_elapsed / total_cycles
in_warmup_period = fraction_elapsed <= options.warmup_maxsize_by

if options.warmup_maxsize_by > 0 && in_warmup_period
return 3 + floor(
Int, (options.maxsize - 3) * fraction_elapsed / options.warmup_maxsize_by
)
else
return options.maxsize
end
end

function construct_datasets(
X,
y,
Expand Down
34 changes: 11 additions & 23 deletions src/SymbolicRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,8 @@ import .SearchUtilsModule:
init_dummy_pops,
load_saved_hall_of_fame,
load_saved_population,
construct_datasets
construct_datasets,
get_cur_maxsize

include("deprecates.jl")
include("Configure.jl")
Expand Down Expand Up @@ -631,12 +632,6 @@ function _equation_search(
record = RecordType()
@recorder record["options"] = "$(options)"

curmaxsizes = if iszero(options.warmup_maxsize_by)
fill(options.maxsize, nout)
else
fill(convert(typeof(options.maxsize), 3), nout)
end

# Records the number of evaluations:
# Real numbers indicate use of batching.
num_evals = [[0.0 for i in 1:(options.populations)] for j in 1:nout]
Expand Down Expand Up @@ -719,6 +714,12 @@ function _equation_search(
end
push!(allPops[j], new_pop)
end
total_cycles = options.populations * niterations
cycles_remaining = [total_cycles for j in 1:nout]
curmaxsizes = [
get_cur_maxsize(; options, total_cycles, cycles_remaining=cycles_remaining[j]) for
j in 1:nout
]
# 2. Start the cycle on every process:
for j in 1:nout, i in 1:(options.populations)
dataset = datasets[j]
Expand Down Expand Up @@ -757,8 +758,6 @@ function _equation_search(

verbosity > 0 && @info "Started!"
start_time = time()
total_cycles = options.populations * niterations
cycles_remaining = [total_cycles for j in 1:nout]
if progress
#TODO: need to iterate this on the max cycles remaining!
sum_cycle_remaining = sum(cycles_remaining)
Expand Down Expand Up @@ -935,20 +934,9 @@ function _equation_search(
tasks[j][i] = @async put!(channels[j][i], fetch(allPops[j][i]))
end

cycles_elapsed = total_cycles - cycles_remaining[j]
if options.warmup_maxsize_by > 0
fraction_elapsed = 1.0f0 * cycles_elapsed / total_cycles
if fraction_elapsed > options.warmup_maxsize_by
curmaxsizes[j] = options.maxsize
else
curmaxsizes[j] =
3 + floor(
Int,
(options.maxsize - 3) * fraction_elapsed /
options.warmup_maxsize_by,
)
end
end
curmaxsizes[j] = get_cur_maxsize(;
options, total_cycles, cycles_remaining=cycles_remaining[j]
)
stop_work_monitor!(resource_monitor)
move_window!(all_running_search_statistics[j])
if progress
Expand Down

0 comments on commit bda097e

Please sign in to comment.