diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index 1f00962ea..c42a71349 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -192,7 +192,7 @@ import .CoreModule: erf, erfc, atanh_clip -import .UtilsModule: is_anonymous_function, recursive_merge, json3_write +import .UtilsModule: is_anonymous_function, recursive_merge, json3_write, get_base_type import .ComplexityModule: compute_complexity import .CheckConstraintsModule: check_constraints import .AdaptiveParsimonyModule: @@ -349,16 +349,16 @@ function equation_search( runtests::Bool=true, saved_state=nothing, return_state::Union{Bool,Nothing}=nothing, - loss_type::Type{Linit}=Nothing, + loss_type::Type{L}=Nothing, verbosity::Union{Integer,Nothing}=nothing, progress::Union{Bool,Nothing}=nothing, X_units::Union{AbstractVector,Nothing}=nothing, y_units=nothing, - v_dim_out::Val{dim_out}=Val(nothing), + v_dim_out::Val{DIM_OUT}=Val(nothing), # Deprecated: multithreaded=nothing, varMap=nothing, -) where {T<:DATA_TYPE,Linit,dim_out} +) where {T<:DATA_TYPE,L,DIM_OUT} if multithreaded !== nothing error( "`multithreaded` is deprecated. Use the `parallelism` argument instead. " * @@ -371,10 +371,6 @@ function equation_search( @assert length(weights) == length(y) weights = reshape(weights, size(y)) end - if T <: Complex && loss_type == Nothing - get_base_type(::Type{Complex{BT}}) where {BT} = BT - loss_type = get_base_type(T) - end datasets = construct_datasets( X, @@ -385,7 +381,7 @@ function equation_search( y_variable_names, X_units, y_units, - loss_type, + (T <: Complex && L === Nothing) ? get_base_type(T) : L, ) return equation_search( @@ -402,7 +398,7 @@ function equation_search( return_state=return_state, verbosity=verbosity, progress=progress, - v_dim_out=Val(dim_out), + v_dim_out=Val(DIM_OUT), ) end @@ -439,8 +435,8 @@ function equation_search( return_state::Union{Bool,Nothing}=nothing, verbosity::Union{Int,Nothing}=nothing, progress::Union{Bool,Nothing}=nothing, - v_dim_out::Val{dim_out}=Val(nothing), -) where {dim_out,T<:DATA_TYPE,L<:LOSS_TYPE,D<:Dataset{T,L}} + v_dim_out::Val{DIM_OUT}=Val(nothing), +) where {DIM_OUT,T<:DATA_TYPE,L<:LOSS_TYPE,D<:Dataset{T,L}} v_concurrency, concurrency = if parallelism in (:multithreading, "multithreading") (Val(:multithreading), :multithreading) elseif parallelism in (:multiprocessing, "multiprocessing") @@ -476,10 +472,10 @@ function equation_search( options.return_state end - v_dim_out = if dim_out === nothing + v_dim_out = if DIM_OUT === nothing length(datasets) > 1 ? Val(2) : Val(1) else - Val(dim_out) + Val(DIM_OUT) end _numprocs::Int = if numprocs === nothing && procs === nothing 4 @@ -565,8 +561,8 @@ function equation_search( end function _equation_search( - ::Val{parallelism}, - ::Val{dim_out}, + ::Val{PARALLELISM}, + ::Val{DIM_OUT}, datasets::Vector{D}, niterations::Int, options::Options, @@ -578,8 +574,8 @@ function _equation_search( saved_state, verbosity, progress, - ::Val{should_return_state}, -) where {T<:DATA_TYPE,L<:LOSS_TYPE,D<:Dataset{T,L},parallelism,should_return_state,dim_out} + ::Val{RETURN_STATE}, +) where {T<:DATA_TYPE,L<:LOSS_TYPE,D<:Dataset{T,L},PARALLELISM,RETURN_STATE,DIM_OUT} stdin_reader = watch_stream(stdin) if options.define_helper_functions @@ -588,10 +584,10 @@ function _equation_search( example_dataset = datasets[1] nout = size(datasets, 1) - @assert (nout == 1 || dim_out == 2) + @assert (nout == 1 || DIM_OUT == 2) if runtests - test_option_configuration(parallelism, datasets, saved_state, options) + test_option_configuration(PARALLELISM, datasets, saved_state, options) test_dataset_configuration(example_dataset, options, verbosity) end @@ -604,9 +600,9 @@ function _equation_search( end # Start a population on every process # Store the population, hall of fame - WorkerOutputType = if parallelism == :serial + WorkerOutputType = if PARALLELISM == :serial Tuple{Population{T,L},HallOfFame{T,L},RecordType,Float64} - elseif parallelism == :multiprocessing + elseif PARALLELISM == :multiprocessing Future else Task @@ -622,7 +618,7 @@ function _equation_search( # Initialize storage for workers tasks = [Task[] for j in 1:nout] # Set up a channel to send finished populations back to head node - channels = if parallelism == :multiprocessing + channels = if PARALLELISM == :multiprocessing [[RemoteChannel(1) for i in 1:(options.populations)] for j in 1:nout] else [[Channel(1) for i in 1:(options.populations)] for j in 1:nout] @@ -645,7 +641,7 @@ function _equation_search( ########################################################################## ### Distributed code: ########################################################################## - if parallelism == :multiprocessing + if PARALLELISM == :multiprocessing (procs, we_created_procs) = configure_workers(; procs, numprocs, @@ -680,7 +676,7 @@ function _equation_search( for j in 1:nout, i in 1:(options.populations) worker_idx = assign_next_worker!( - worker_assignment; out=j, pop=i, parallelism, procs + worker_assignment; out=j, pop=i, parallelism=PARALLELISM, procs ) saved_pop = load_saved_population(saved_state; out=j, pop=i) @@ -698,7 +694,7 @@ function _equation_search( begin (copy_pop, HallOfFame(options, T, L), RecordType(), 0.0) end, - parallelism = parallelism, + parallelism = PARALLELISM, worker_idx = worker_idx ) else @@ -720,7 +716,7 @@ function _equation_search( Float64(options.population_size), ) end, - parallelism = parallelism, + parallelism = PARALLELISM, worker_idx = worker_idx ) # This involves population_size evaluations, on the full dataset: @@ -740,7 +736,7 @@ function _equation_search( curmaxsize = curmaxsizes[j] @recorder record["out$(j)_pop$(i)"] = RecordType() worker_idx = assign_next_worker!( - worker_assignment; out=j, pop=i, parallelism, procs + worker_assignment; out=j, pop=i, parallelism=PARALLELISM, procs ) # TODO - why is this needed?? @@ -749,7 +745,7 @@ function _equation_search( last_pop = worker_output[j][i] updated_pop = @sr_spawner( begin - in_pop = if parallelism in (:multiprocessing, :multithreading) + in_pop = if PARALLELISM in (:multiprocessing, :multithreading) fetch(last_pop)[1] else last_pop[1] @@ -766,7 +762,7 @@ function _equation_search( running_search_statistics=c_rss, ) end, - parallelism = parallelism, + parallelism = PARALLELISM, worker_idx = worker_idx ) worker_output[j][i] = updated_pop @@ -789,7 +785,7 @@ function _equation_search( print_every_n_seconds = 5 equation_speed = Float32[] - if parallelism in (:multiprocessing, :multithreading) + if PARALLELISM in (:multiprocessing, :multithreading) for j in 1:nout, i in 1:(options.populations) # Start listening for each population to finish: t = @async put!(channels[j][i], fetch(worker_output[j][i])) @@ -817,14 +813,14 @@ function _equation_search( j, i = all_idx[kappa] # Check if error on population: - if parallelism in (:multiprocessing, :multithreading) + if PARALLELISM in (:multiprocessing, :multithreading) if istaskfailed(tasks[j][i]) fetch(tasks[j][i]) error("Task failed for population") end end # Non-blocking check if a population is ready: - population_ready = if parallelism in (:multiprocessing, :multithreading) + population_ready = if PARALLELISM in (:multiprocessing, :multithreading) # TODO: Implement type assertions based on parallelism. isready(channels[j][i]) else @@ -837,7 +833,7 @@ function _equation_search( start_work_monitor!(resource_monitor) # Take the fetch operation from the channel since its ready (cur_pop, best_seen, cur_record, cur_num_evals) = - if parallelism in (:multiprocessing, :multithreading) + if PARALLELISM in (:multiprocessing, :multithreading) take!(channels[j][i]) else worker_output[j][i] @@ -904,7 +900,7 @@ function _equation_search( break end worker_idx = assign_next_worker!( - worker_assignment; out=j, pop=i, parallelism, procs + worker_assignment; out=j, pop=i, parallelis=PARALLELISM, procs ) iteration = if is_recording(options) key = "out$(j)_pop$(i)" @@ -929,10 +925,10 @@ function _equation_search( running_search_statistics=c_rss, ) end, - parallelism = parallelism, + parallelism = PARALLELISM, worker_idx = worker_idx ) - if parallelism in (:multiprocessing, :multithreading) + if PARALLELISM in (:multiprocessing, :multithreading) tasks[j][i] = @async put!(channels[j][i], fetch(worker_output[j][i])) end @@ -950,7 +946,7 @@ function _equation_search( options, equation_speed, head_node_occupation, - parallelism, + PARALLELISM, ) end end @@ -990,7 +986,7 @@ function _equation_search( total_cycles, cycles_remaining, head_node_occupation, - parallelism, + PARALLELISM, width=options.terminal_width, ) end @@ -1014,9 +1010,9 @@ function _equation_search( close_reader!(stdin_reader) # Safely close all processes or threads - if parallelism == :multiprocessing + if PARALLELISM == :multiprocessing we_created_procs && rmprocs(procs) - elseif parallelism == :multithreading + elseif PARALLELISM == :multithreading for j in 1:nout, i in 1:(options.populations) wait(worker_output[j][i]) end @@ -1028,10 +1024,10 @@ function _equation_search( @recorder json3_write(record, options.recorder_file) - if should_return_state - return (returnPops, (dim_out == 1 ? only(hallOfFame) : hallOfFame)) + if RETURN_STATE + return (returnPops, (DIM_OUT == 1 ? only(hallOfFame) : hallOfFame)) else - return (dim_out == 1 ? only(hallOfFame) : hallOfFame) + return (DIM_OUT == 1 ? only(hallOfFame) : hallOfFame) end end diff --git a/src/Utils.jl b/src/Utils.jl index f0e474d97..6021f6eb8 100644 --- a/src/Utils.jl +++ b/src/Utils.jl @@ -30,6 +30,8 @@ recursive_merge(x::AbstractDict...) = merge(recursive_merge, x...) recursive_merge(x...) = x[end] recursive_merge() = error("Unexpected input.") +get_base_type(::Type{Complex{BT}}) where {BT} = BT + const subscripts = ('₀', '₁', '₂', '₃', '₄', '₅', '₆', '₇', '₈', '₉') function subscriptify(number::Integer) return join([subscripts[i + 1] for i in reverse(digits(number))])