diff --git a/Project.toml b/Project.toml index 39ebba98..792ffb6a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJBase" uuid = "a7f614a8-145f-11e9-1d2a-a57a1082229d" authors = ["Anthony D. Blaom "] -version = "1.4.0" +version = "1.5.0" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" diff --git a/src/MLJBase.jl b/src/MLJBase.jl index 0d58635d..9842323d 100644 --- a/src/MLJBase.jl +++ b/src/MLJBase.jl @@ -248,7 +248,8 @@ export params # ------------------------------------------------------------------- # exports from this module, MLJBase -# computational_resources.jl: +# get/set global constants: +export default_logger export default_resource # one_dimensional_ranges.jl: diff --git a/src/init.jl b/src/init.jl index 947bdc03..d85fede1 100644 --- a/src/init.jl +++ b/src/init.jl @@ -3,6 +3,7 @@ function __init__() global DEFAULT_RESOURCE = Ref{AbstractResource}(CPU1()) global DEFAULT_SCITYPE_CHECK_LEVEL = Ref{Int}(1) global SHOW_COLOR = Ref{Bool}(true) + global DEFAULT_LOGGER = Ref{Any}(nothing) # for testing asynchronous training of learning networks: global TESTING = parse(Bool, get(ENV, "TEST_MLJBASE", "false")) diff --git a/src/machines.jl b/src/machines.jl index 1a3f5388..8f5aa438 100644 --- a/src/machines.jl +++ b/src/machines.jl @@ -1088,6 +1088,25 @@ function save(file::Union{String,IO}, mach::Machine) serialize(file, smach) end +const ERR_INVALID_DEFAULT_LOGGER = ArgumentError( + "You have attempted to save a machine to the default logger "* + "but `default_logger()` is currently `nothing`. "* + "Either specify an explicit logger, path or stream to save to, "* + "or use `default_logger(logger)` "* + "to change the default logger. " +) + +""" + MLJ.save(mach) + MLJBase.save(mach) + +Save the current machine as an artifact at the location associated with +[`default_logger`](@ref). 
+ +""" +MLJBase.save(mach::Machine) = MLJBase.save(default_logger(), mach) +MLJBase.save(::Nothing, ::Machine) = throw(ERR_INVALID_DEFAULT_LOGGER) + report_for_serialization(mach) = mach.report # NOTE. there is also a specialization of `report_for_serialization` for `Composite` diff --git a/src/resampling.jl b/src/resampling.jl index 68ecc040..dd317092 100644 --- a/src/resampling.jl +++ b/src/resampling.jl @@ -1,4 +1,4 @@ -# # TYPE ALIASES + # TYPE ALIASES const AbstractRow = Union{AbstractVector{<:Integer}, Colon} const TrainTestPair = Tuple{AbstractRow,AbstractRow} @@ -747,6 +747,64 @@ Base.show(io::IO, e::CompactPerformanceEvaluation) = print(io, "CompactPerformanceEvaluation$(_summary(e))") + +# =============================================================== +## USER CONTROL OF DEFAULT LOGGING + +const DOC_DEFAULT_LOGGER = + """ + + The default logger is used in calls to [`evaluate!`](@ref) and [`evaluate`](@ref), and + in the constructors `TunedModel` and `IteratedModel`, unless the `logger` keyword is + explicitly specified. + + !!! note + + Prior to MLJ v0.20.7 (and MLJBase 1.5) the default logger was always `nothing`. + +""" + +""" + default_logger() + +Return the current value of the default logger for use with supported machine learning +tracking platforms, such as [MLflow](https://mlflow.org/docs/latest/index.html). + +$DOC_DEFAULT_LOGGER + + When MLJBase is first loaded, the default logger is `nothing`. To reset the logger, see + below. + +""" +default_logger() = DEFAULT_LOGGER[] + +""" + default_logger(logger) + +Reset the default logger. + +# Example + +Suppose an [MLflow](https://mlflow.org/docs/latest/index.html) tracking service is running +on a local server at `http://127.0.0.1:5000`. Then in every `evaluate` call in which +`logger` is not specified, as in the example below, the performance evaluation is +automatically logged to the service. 
+ +```julia +using MLJ +logger = MLJFlow.Logger("http://127.0.0.1:5000/api") +default_logger(logger) + +X, y = make_moons() +model = ConstantClassifier() +evaluate(model, X, y, measures=[log_loss, accuracy]) +``` +""" +function default_logger(logger) + DEFAULT_LOGGER[] = logger +end + + # =============================================================== ## EVALUATION METHODS @@ -1068,7 +1126,8 @@ Although `evaluate!` is mutating, `mach.model` and `mach.args` are not mutated. `false` the `per_observation` field of the returned object is populated with `missing`s. Setting to `false` may reduce compute time and allocations. -- `logger` - a logger object (see [`MLJBase.log_evaluation`](@ref)) +- `logger=default_logger()` - a logger object for forwarding results to a machine learning + tracking platform; see [`default_logger`](@ref) for details. - `compact=false` - if `true`, the returned evaluation object excludes these fields: `fitted_params_per_fold`, `report_per_fold`, `train_test_rows`. @@ -1093,7 +1152,7 @@ function evaluate!( check_measure=true, per_observation=true, verbosity=1, - logger=nothing, + logger=default_logger(), compact=false, ) @@ -1544,7 +1603,7 @@ end acceleration=default_resource(), check_measure=true, per_observation=true, - logger=nothing, + logger=default_logger(), compact=false, ) @@ -1624,7 +1683,7 @@ function Resampler( repeats=1, cache=true, per_observation=true, - logger=nothing, + logger=default_logger(), compact=false, ) resampler = Resampler( diff --git a/test/resampling.jl b/test/resampling.jl index 33ee41ec..fbf26777 100644 --- a/test/resampling.jl +++ b/test/resampling.jl @@ -935,4 +935,21 @@ end end end +# DUMMY LOGGER + +struct DummyLogger end + +MLJBase.save(logger::DummyLogger, mach::Machine) = mach.model + +@testset "default logger" begin + @test isnothing(default_logger()) + model = ConstantClassifier() + mach = machine(model, make_moons(10)...) 
+ fit!(mach, verbosity=0) + @test_throws MLJBase.ERR_INVALID_DEFAULT_LOGGER MLJBase.save(mach) + default_logger(DummyLogger()) + @test default_logger() == DummyLogger() + @test MLJBase.save(mach) == model +end + true