Skip to content

Commit

Permalink
add a default_logger
Browse files Browse the repository at this point in the history
  • Loading branch information
ablaom committed May 27, 2024
1 parent 14441aa commit 31898d9
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 5 deletions.
1 change: 1 addition & 0 deletions src/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ function __init__()
global DEFAULT_RESOURCE = Ref{AbstractResource}(CPU1())
global DEFAULT_SCITYPE_CHECK_LEVEL = Ref{Int}(1)
global SHOW_COLOR = Ref{Bool}(true)
global DEFAULT_LOGGER = Ref{Any}(nothing)

# for testing asynchronous training of learning networks:
global TESTING = parse(Bool, get(ENV, "TEST_MLJBASE", "false"))
Expand Down
18 changes: 18 additions & 0 deletions src/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1088,6 +1088,24 @@ function save(file::Union{String,IO}, mach::Machine)
serialize(file, smach)
end

const ERR_INVALID_DEFAULT_LOGGER = ArgumentError(
"`default_logger()` is currently `nothing`. "*
"Either specify an explicit path or stream as "*
"target of the save, or use `default_logger(logger)` "*
"to change the default logger. "
)

"""
MLJ.save(mach)
MLJBase.save(mach)
Save the current machine as an artifact at the location associated with
`default_logger`](@ref).
"""
MLJBase.save(mach::Machine) = MLJBase.save(default_logger(), mach)
MLJBase.save(::Nothing, ::Machine) = throw(ERR_INVALID_DEFAULT_LOGGER)

report_for_serialization(mach) = mach.report

# NOTE. there is also a specialization of `report_for_serialization` for `Composite`
Expand Down
69 changes: 64 additions & 5 deletions src/resampling.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# # TYPE ALIASES
# TYPE ALIASES

const AbstractRow = Union{AbstractVector{<:Integer}, Colon}
const TrainTestPair = Tuple{AbstractRow,AbstractRow}
Expand Down Expand Up @@ -747,6 +747,64 @@ Base.show(io::IO, e::CompactPerformanceEvaluation) =
print(io, "CompactPerformanceEvaluation$(_summary(e))")



# ===============================================================
## USER CONTROL OF DEFAULT LOGGING

const DOC_DEFAULT_LOGGER =
"""
The default logger is used in calls to [`evaluate!`](@ref) and [`evaluate`](@ref), and
in the constructors `TunedModel` and `IteratedModel`, unless the `logger` keyword is
explicitly specified.
!!! note
In MLJ version prior to 0.21 the default logger is always `nothing`.
"""

"""
default_logger()
Return the current value of the default logger for use with supported machine learning
tracking platforms, such as [MLflow](https://mlflow.org/docs/latest/index.html).
$DOC_DEFAULT_LOGGER
When MLJBase is first loaded, the default logger is `nothing`. To reset the logger, see
beow.
"""
default_logger() = DEFAULT_LOGGER[]

"""
default_logger(logger)
Reset the default logger.
# Example
Suppose an [MLflow](https://mlflow.org/docs/latest/index.html) tracking service is running
on a local server at `http://127.0.0.1:500`. Then every in every `evaluate` call in which
`logger` is not specified, as in the example below, the peformance evaluation is
automatically logged to the service.
```julia-repl
using MLJ
logger = MLJFlow.Logger("http://127.0.0.1:5000/api")
default_logger(logger)
X, y = make_moons()
model = ConstantClassifier()
evaluate(model, X, y, measures=[log_loss, accuracy)])
"""
function default_logger(logger)
DEFAULT_LOGGER[] = logger
end


# ===============================================================
## EVALUATION METHODS

Expand Down Expand Up @@ -1068,7 +1126,8 @@ Although `evaluate!` is mutating, `mach.model` and `mach.args` are not mutated.
`false` the `per_observation` field of the returned object is populated with
`missing`s. Setting to `false` may reduce compute time and allocations.
- `logger` - a logger object (see [`MLJBase.log_evaluation`](@ref))
- `logger=default_logger()` - a logger object for forwarding results to a machine learning
tracking platform; see [`default_logger`](@ref) for details.
- `compact=false` - if `true`, the returned evaluation object excludes these fields:
`fitted_params_per_fold`, `report_per_fold`, `train_test_rows`.
Expand All @@ -1093,7 +1152,7 @@ function evaluate!(
check_measure=true,
per_observation=true,
verbosity=1,
logger=nothing,
logger=default_logger(),
compact=false,
)

Expand Down Expand Up @@ -1544,7 +1603,7 @@ end
acceleration=default_resource(),
check_measure=true,
per_observation=true,
logger=nothing,
logger=default_logger(),
compact=false,
)
Expand Down Expand Up @@ -1632,7 +1691,7 @@ function Resampler(
repeats=1,
cache=true,
per_observation=true,
logger=nothing,
logger=default_logger(),
compact=false,
)
resampler = Resampler(
Expand Down

0 comments on commit 31898d9

Please sign in to comment.