diff --git a/Project.toml b/Project.toml index 741e83ea..9177b6e0 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SymbolicRegression" uuid = "8254be44-1295-4e6a-a16d-46603ac705cb" authors = ["MilesCranmer "] -version = "1.0.3" +version = "1.1.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/Configure.jl b/src/Configure.jl index 1151c4cb..a16280f4 100644 --- a/src/Configure.jl +++ b/src/Configure.jl @@ -221,7 +221,12 @@ function activate_env_on_workers( end end -function import_module_on_workers(procs, filename::String, verbosity) +function import_module_on_workers( + procs, + filename::String, + @nospecialize(worker_imports::Union{Vector{Symbol},Nothing}), + verbosity, +) loaded_modules_head_worker = [k.name for (k, _) in Base.loaded_modules] included_as_local = "SymbolicRegression" ∉ loaded_modules_head_worker @@ -244,13 +249,16 @@ function import_module_on_workers(procs, filename::String, verbosity) :Enzyme, :LoopVectorization, :SymbolicUtils, + :TensorBoardLogger, :Zygote, ] filter!(m -> String(m) ∈ loaded_modules_head_worker, relevant_extensions) # HACK TODO – this workaround is very fragile. Likely need to submit a bug report # to JuliaLang. 
- for ext in relevant_extensions + all_extensions = vcat(relevant_extensions, @something(worker_imports, Symbol[])) + + for ext in all_extensions push!( expr.args, quote @@ -329,6 +337,7 @@ function configure_workers(; numprocs::Int, addprocs_function::Function, options::AbstractOptions, + @nospecialize(worker_imports::Union{Vector{Symbol},Nothing}), project_path, file, exeflags::Cmd, @@ -343,7 +352,7 @@ function configure_workers(; end if we_created_procs - import_module_on_workers(procs, file, verbosity) + import_module_on_workers(procs, file, worker_imports, verbosity) end move_functions_to_workers(procs, options, example_dataset, verbosity) diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index d270054c..3a8d0a1d 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -58,6 +58,7 @@ function modelexpr(model_name::Symbol) procs::Union{Vector{Int},Nothing} = nothing addprocs_function::Union{Function,Nothing} = nothing heap_size_hint_in_bytes::Union{Integer,Nothing} = nothing + worker_imports::Union{Vector{Symbol},Nothing} = nothing logger::Union{AbstractSRLogger,Nothing} = nothing runtests::Bool = true run_id::Union{String,Nothing} = nothing @@ -263,6 +264,7 @@ function _update( procs=m.procs, addprocs_function=m.addprocs_function, heap_size_hint_in_bytes=m.heap_size_hint_in_bytes, + worker_imports=m.worker_imports, runtests=m.runtests, saved_state=(old_fitresult === nothing ? nothing : old_fitresult.state), return_state=true, @@ -651,6 +653,9 @@ function tag_with_docstring(model_name::Symbol, description::String, bottom_matt is close to the recommended size. This is important for long-running distributed jobs where each process has an independent memory, and can help avoid out-of-memory errors. By default, this is set to `Sys.free_memory() / numprocs`. + - `worker_imports::Union{Vector{Symbol},Nothing}=nothing`: If you want to import + additional modules on each worker, pass them here as a vector of symbols. 
+ By default, some of the extensions will automatically be loaded when needed. - `runtests::Bool=true`: Whether to run (quick) tests before starting the search, to see if there will be any problems during the equation search related to the host environment. diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index abcb971e..89ca03ee 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -55,6 +55,7 @@ struct RuntimeOptions{PARALLELISM,DIM_OUT,RETURN_STATE,LOGGER} <: AbstractRuntim init_procs::Union{Vector{Int},Nothing} addprocs_function::Function exeflags::Cmd + worker_imports::Union{Vector{Symbol},Nothing} runtests::Bool verbosity::Int64 progress::Bool @@ -89,6 +90,7 @@ end procs::Union{Vector{Int},Nothing}=nothing, addprocs_function::Union{Function,Nothing}=nothing, heap_size_hint_in_bytes::Union{Integer,Nothing}=nothing, + worker_imports::Union{Vector{Symbol},Nothing}=nothing, runtests::Bool=true, return_state::VRS=nothing, run_id::Union{String,Nothing}=nothing, @@ -182,6 +184,7 @@ end procs, _addprocs_function, exeflags, + worker_imports, runtests, _verbosity, _progress, diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index b843f7d2..124fcaad 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -382,6 +382,9 @@ which is useful for debugging and profiling. is close to the recommended size. This is important for long-running distributed jobs where each process has an independent memory, and can help avoid out-of-memory errors. By default, this is set to `Sys.free_memory() / numprocs`. +- `worker_imports::Union{Vector{Symbol},Nothing}=nothing`: If you want to import + additional modules on each worker, pass them here as a vector of symbols. + By default, some of the extensions will automatically be loaded when needed. - `runtests::Bool=true`: Whether to run (quick) tests before starting the search, to see if there will be any problems during the equation search related to the host environment. 
@@ -433,6 +436,7 @@ function equation_search( procs::Union{Vector{Int},Nothing}=nothing, addprocs_function::Union{Function,Nothing}=nothing, heap_size_hint_in_bytes::Union{Integer,Nothing}=nothing, + worker_imports::Union{Vector{Symbol},Nothing}=nothing, runtests::Bool=true, saved_state=nothing, return_state::Union{Bool,Nothing,Val}=nothing, @@ -482,6 +486,7 @@ function equation_search( procs=procs, addprocs_function=addprocs_function, heap_size_hint_in_bytes=heap_size_hint_in_bytes, + worker_imports=worker_imports, runtests=runtests, saved_state=saved_state, return_state=return_state, @@ -599,6 +604,7 @@ end ropt.numprocs, ropt.addprocs_function, options, + worker_imports=ropt.worker_imports, project_path=splitdir(Pkg.project().path)[1], file=@__FILE__, ropt.exeflags, diff --git a/test/test_logging.jl b/test/test_logging.jl index 6385f11c..33d87c6d 100644 --- a/test/test_logging.jl +++ b/test/test_logging.jl @@ -23,6 +23,9 @@ niterations, populations, logger, + parallelism=:multiprocessing, + # Test we can load extra packages: + worker_imports=[:LoggingExtras], ) X = (a=rand(500), b=rand(500))