Skip to content

Commit

Permalink
Merge pull request #376 from MilesCranmer/fix-distributed-logging
Browse files Browse the repository at this point in the history
Fix use of logger in distributed mode
  • Loading branch information
MilesCranmer authored Dec 3, 2024
2 parents 8356f48 + 788ce37 commit 36acbc3
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 4 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SymbolicRegression"
uuid = "8254be44-1295-4e6a-a16d-46603ac705cb"
authors = ["MilesCranmer <[email protected]>"]
version = "1.0.3"
version = "1.1.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
15 changes: 12 additions & 3 deletions src/Configure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,12 @@ function activate_env_on_workers(
end
end

function import_module_on_workers(procs, filename::String, verbosity)
function import_module_on_workers(
procs,
filename::String,
@nospecialize(worker_imports::Union{Vector{Symbol},Nothing}),
verbosity,
)
loaded_modules_head_worker = [k.name for (k, _) in Base.loaded_modules]

included_as_local = "SymbolicRegression" ∉ loaded_modules_head_worker
Expand All @@ -244,13 +249,16 @@ function import_module_on_workers(procs, filename::String, verbosity)
:Enzyme,
:LoopVectorization,
:SymbolicUtils,
:TensorBoardLogger,
:Zygote,
]
filter!(m -> String(m) ∈ loaded_modules_head_worker, relevant_extensions)
# HACK TODO – this workaround is very fragile. Likely need to submit a bug report
# to JuliaLang.

for ext in relevant_extensions
all_extensions = vcat(relevant_extensions, @something(worker_imports, Symbol[]))

for ext in all_extensions
push!(
expr.args,
quote
Expand Down Expand Up @@ -329,6 +337,7 @@ function configure_workers(;
numprocs::Int,
addprocs_function::Function,
options::AbstractOptions,
@nospecialize(worker_imports::Union{Vector{Symbol},Nothing}),
project_path,
file,
exeflags::Cmd,
Expand All @@ -343,7 +352,7 @@ function configure_workers(;
end

if we_created_procs
import_module_on_workers(procs, file, verbosity)
import_module_on_workers(procs, file, worker_imports, verbosity)
end

move_functions_to_workers(procs, options, example_dataset, verbosity)
Expand Down
5 changes: 5 additions & 0 deletions src/MLJInterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ function modelexpr(model_name::Symbol)
procs::Union{Vector{Int},Nothing} = nothing
addprocs_function::Union{Function,Nothing} = nothing
heap_size_hint_in_bytes::Union{Integer,Nothing} = nothing
worker_imports::Union{Vector{Symbol},Nothing} = nothing
logger::Union{AbstractSRLogger,Nothing} = nothing
runtests::Bool = true
run_id::Union{String,Nothing} = nothing
Expand Down Expand Up @@ -263,6 +264,7 @@ function _update(
procs=m.procs,
addprocs_function=m.addprocs_function,
heap_size_hint_in_bytes=m.heap_size_hint_in_bytes,
worker_imports=m.worker_imports,
runtests=m.runtests,
saved_state=(old_fitresult === nothing ? nothing : old_fitresult.state),
return_state=true,
Expand Down Expand Up @@ -651,6 +653,9 @@ function tag_with_docstring(model_name::Symbol, description::String, bottom_matt
is close to the recommended size. This is important for long-running distributed
jobs where each process has an independent memory, and can help avoid
out-of-memory errors. By default, this is set to `Sys.free_memory() / numprocs`.
- `worker_imports::Union{Vector{Symbol},Nothing}=nothing`: If you want to import
additional modules on each worker, pass them here as a vector of symbols.
By default some of the extensions will automatically be loaded when needed.
- `runtests::Bool=true`: Whether to run (quick) tests before starting the
search, to see if there will be any problems during the equation search
related to the host environment.
Expand Down
3 changes: 3 additions & 0 deletions src/SearchUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ struct RuntimeOptions{PARALLELISM,DIM_OUT,RETURN_STATE,LOGGER} <: AbstractRuntim
init_procs::Union{Vector{Int},Nothing}
addprocs_function::Function
exeflags::Cmd
worker_imports::Union{Vector{Symbol},Nothing}
runtests::Bool
verbosity::Int64
progress::Bool
Expand Down Expand Up @@ -89,6 +90,7 @@ end
procs::Union{Vector{Int},Nothing}=nothing,
addprocs_function::Union{Function,Nothing}=nothing,
heap_size_hint_in_bytes::Union{Integer,Nothing}=nothing,
worker_imports::Union{Vector{Symbol},Nothing}=nothing,
runtests::Bool=true,
return_state::VRS=nothing,
run_id::Union{String,Nothing}=nothing,
Expand Down Expand Up @@ -182,6 +184,7 @@ end
procs,
_addprocs_function,
exeflags,
worker_imports,
runtests,
_verbosity,
_progress,
Expand Down
6 changes: 6 additions & 0 deletions src/SymbolicRegression.jl
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,9 @@ which is useful for debugging and profiling.
is close to the recommended size. This is important for long-running distributed
jobs where each process has an independent memory, and can help avoid
out-of-memory errors. By default, this is set to `Sys.free_memory() / numprocs`.
- `worker_imports::Union{Vector{Symbol},Nothing}=nothing`: If you want to import
additional modules on each worker, pass them here as a vector of symbols.
By default some of the extensions will automatically be loaded when needed.
- `runtests::Bool=true`: Whether to run (quick) tests before starting the
search, to see if there will be any problems during the equation search
related to the host environment.
Expand Down Expand Up @@ -433,6 +436,7 @@ function equation_search(
procs::Union{Vector{Int},Nothing}=nothing,
addprocs_function::Union{Function,Nothing}=nothing,
heap_size_hint_in_bytes::Union{Integer,Nothing}=nothing,
worker_imports::Union{Vector{Symbol},Nothing}=nothing,
runtests::Bool=true,
saved_state=nothing,
return_state::Union{Bool,Nothing,Val}=nothing,
Expand Down Expand Up @@ -482,6 +486,7 @@ function equation_search(
procs=procs,
addprocs_function=addprocs_function,
heap_size_hint_in_bytes=heap_size_hint_in_bytes,
worker_imports=worker_imports,
runtests=runtests,
saved_state=saved_state,
return_state=return_state,
Expand Down Expand Up @@ -599,6 +604,7 @@ end
ropt.numprocs,
ropt.addprocs_function,
options,
worker_imports=ropt.worker_imports,
project_path=splitdir(Pkg.project().path)[1],
file=@__FILE__,
ropt.exeflags,
Expand Down
3 changes: 3 additions & 0 deletions test/test_logging.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
niterations,
populations,
logger,
parallelism=:multiprocessing,
# Test we can load extra packages:
worker_imports=[:LoggingExtras],
)

X = (a=rand(500), b=rand(500))
Expand Down

2 comments on commit 36acbc3

@MilesCranmer
Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/120631

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v1.1.0 -m "<description of version>" 36acbc33a3f2817d406abd07a8dc082b36c6057c
git push origin v1.1.0

Please sign in to comment.