Skip to content

Commit

Permalink
Merge pull request #979 from JuliaAI/default-logger
Browse files Browse the repository at this point in the history
Add a default_logger
  • Loading branch information
ablaom authored Jun 24, 2024
2 parents e7afc34 + 813ba0c commit 96f691c
Show file tree
Hide file tree
Showing 5 changed files with 103 additions and 6 deletions.
3 changes: 2 additions & 1 deletion src/MLJBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,8 @@ export params
# -------------------------------------------------------------------
# exports from this module, MLJBase

# computational_resources.jl:
# get/set global constants:
export default_logger
export default_resource

# one_dimensional_ranges.jl:
Expand Down
1 change: 1 addition & 0 deletions src/init.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ function __init__()
global DEFAULT_RESOURCE = Ref{AbstractResource}(CPU1())
global DEFAULT_SCITYPE_CHECK_LEVEL = Ref{Int}(1)
global SHOW_COLOR = Ref{Bool}(true)
global DEFAULT_LOGGER = Ref{Any}(nothing)

# for testing asynchronous training of learning networks:
global TESTING = parse(Bool, get(ENV, "TEST_MLJBASE", "false"))
Expand Down
19 changes: 19 additions & 0 deletions src/machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1088,6 +1088,25 @@ function save(file::Union{String,IO}, mach::Machine)
serialize(file, smach)
end

const ERR_INVALID_DEFAULT_LOGGER = ArgumentError(
"You have attempted to save a machine to the default logger "*
"but `default_logger()` is currently `nothing`. "*
"Either specify an explicit logger, path or stream to save to, "*
"or use `default_logger(logger)` "*
"to change the default logger. "
)

"""
MLJ.save(mach)
MLJBase.save(mach)
Save the current machine as an artifact at the location associated with
`default_logger`](@ref).
"""
MLJBase.save(mach::Machine) = MLJBase.save(default_logger(), mach)
MLJBase.save(::Nothing, ::Machine) = throw(ERR_INVALID_DEFAULT_LOGGER)

report_for_serialization(mach) = mach.report

# NOTE. there is also a specialization of `report_for_serialization` for `Composite`
Expand Down
69 changes: 64 additions & 5 deletions src/resampling.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# # TYPE ALIASES
# TYPE ALIASES

const AbstractRow = Union{AbstractVector{<:Integer}, Colon}
const TrainTestPair = Tuple{AbstractRow,AbstractRow}
Expand Down Expand Up @@ -747,6 +747,64 @@ Base.show(io::IO, e::CompactPerformanceEvaluation) =
print(io, "CompactPerformanceEvaluation$(_summary(e))")



# ===============================================================
## USER CONTROL OF DEFAULT LOGGING

const DOC_DEFAULT_LOGGER =
"""
The default logger is used in calls to [`evaluate!`](@ref) and [`evaluate`](@ref), and
in the constructors `TunedModel` and `IteratedModel`, unless the `logger` keyword is
explicitly specified.
!!! note
Prior to MLJ v0.20.7 (and MLJBase 1.5) the default logger was always `nothing`.
"""

"""
default_logger()
Return the current value of the default logger for use with supported machine learning
tracking platforms, such as [MLflow](https://mlflow.org/docs/latest/index.html).
$DOC_DEFAULT_LOGGER
When MLJBase is first loaded, the default logger is `nothing`. To reset the logger, see
beow.
"""
default_logger() = DEFAULT_LOGGER[]

"""
default_logger(logger)
Reset the default logger.
# Example
Suppose an [MLflow](https://mlflow.org/docs/latest/index.html) tracking service is running
on a local server at `http://127.0.0.1:500`. Then every in every `evaluate` call in which
`logger` is not specified, as in the example below, the peformance evaluation is
automatically logged to the service.
```julia-repl
using MLJ
logger = MLJFlow.Logger("http://127.0.0.1:5000/api")
default_logger(logger)
X, y = make_moons()
model = ConstantClassifier()
evaluate(model, X, y, measures=[log_loss, accuracy)])
"""
function default_logger(logger)
DEFAULT_LOGGER[] = logger
end


# ===============================================================
## EVALUATION METHODS

Expand Down Expand Up @@ -1068,7 +1126,8 @@ Although `evaluate!` is mutating, `mach.model` and `mach.args` are not mutated.
`false` the `per_observation` field of the returned object is populated with
`missing`s. Setting to `false` may reduce compute time and allocations.
- `logger` - a logger object (see [`MLJBase.log_evaluation`](@ref))
- `logger=default_logger()` - a logger object for forwarding results to a machine learning
tracking platform; see [`default_logger`](@ref) for details.
- `compact=false` - if `true`, the returned evaluation object excludes these fields:
`fitted_params_per_fold`, `report_per_fold`, `train_test_rows`.
Expand All @@ -1093,7 +1152,7 @@ function evaluate!(
check_measure=true,
per_observation=true,
verbosity=1,
logger=nothing,
logger=default_logger(),
compact=false,
)

Expand Down Expand Up @@ -1544,7 +1603,7 @@ end
acceleration=default_resource(),
check_measure=true,
per_observation=true,
logger=nothing,
logger=default_logger(),
compact=false,
)
Expand Down Expand Up @@ -1624,7 +1683,7 @@ function Resampler(
repeats=1,
cache=true,
per_observation=true,
logger=nothing,
logger=default_logger(),
compact=false,
)
resampler = Resampler(
Expand Down
17 changes: 17 additions & 0 deletions test/resampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -935,4 +935,21 @@ end
end
end

# DUMMY LOGGER

struct DummyLogger end

MLJBase.save(logger::DummyLogger, mach::Machine) = mach.model

@testset "default logger" begin
@test isnothing(default_logger())
model = ConstantClassifier()
mach = machine(model, make_moons(10)...)
fit!(mach, verbosity=0)
@test_throws MLJBase.ERR_INVALID_DEFAULT_LOGGER MLJBase.save(mach)
default_logger(DummyLogger())
@test default_logger() == DummyLogger()
@test MLJBase.save(mach) == model
end

true

0 comments on commit 96f691c

Please sign in to comment.