diff --git a/Project.toml b/Project.toml index 935a7ae..9d19ee3 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "MLJFlow" uuid = "7b7b8358-b45c-48ea-a8ef-7ca328ad328f" authors = ["Jose Esparza "] -version = "0.1.1" +version = "0.2.0" [deps] MLFlowClient = "64a0f543-368b-4a9a-827a-e71edb2a0b83" @@ -10,7 +10,7 @@ MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" [compat] MLFlowClient = "0.4.4" -MLJBase = "0.21.14" +MLJBase = "1" MLJModelInterface = "1.9.1" julia = "1.6" @@ -18,7 +18,8 @@ julia = "1.6" MLFlowClient = "64a0f543-368b-4a9a-827a-e71edb2a0b83" MLJDecisionTreeInterface = "c6f25543-311c-4c74-83dc-3ea6d1015661" MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7" +StatisticalMeasures = "a19d573c-0a75-4610-95b3-7071388c7541" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test", "MLFlowClient", "MLJModels", "MLJDecisionTreeInterface"] +test = ["Test", "MLFlowClient", "MLJModels", "MLJDecisionTreeInterface", "StatisticalMeasures"] diff --git a/src/MLJFlow.jl b/src/MLJFlow.jl index 82ec796..a434b6d 100644 --- a/src/MLJFlow.jl +++ b/src/MLJFlow.jl @@ -1,7 +1,6 @@ module MLJFlow -using MLJBase: info, name, Model, - Machine +using MLJBase: Model, Machine, name using MLJModelInterface: flat_params using MLFlowClient: MLFlow, logparam, logmetric, createrun, MLFlowRun, updaterun, diff --git a/src/base.jl b/src/base.jl index 2505703..880b595 100644 --- a/src/base.jl +++ b/src/base.jl @@ -3,7 +3,10 @@ function log_evaluation(logger::MLFlowLogger, performance_evaluation) artifact_location=logger.artifact_location) run = createrun(logger.service, experiment; tags=[ - Dict("key" => "resampling", "value" => string(performance_evaluation.resampling)), + Dict( + "key" => "resampling", + "value" => string(performance_evaluation.resampling) + ), Dict("key" => "repeats", "value" => string(performance_evaluation.repeats)), Dict("key" => "model type", "value" => name(performance_evaluation.model)), ] diff --git a/src/service.jl b/src/service.jl index cd4a12b..b9c90f7 100644 --- a/src/service.jl +++ b/src/service.jl @@ -18,20 +18,57 @@ function logmodelparams(service::MLFlow, run::MLFlowRun, model::Model) end end +const MLFLOW_CHAR_SET = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789_-. /" + +""" + good_name(measure) + +**Private method.** + +Returns a string representation of `measure` that can be used as a valid name in +MLflow. Includes the value of the first hyperparameter, if there is one. + +```julia +julia> good_name(macro_f1score) +"MulticlassFScore-beta_1.0" + +""" +function good_name(measure) + name = string(measure) + name = replace(name, ", …" => "") + name = replace(name, " = " => "_") + name = replace(name, "()" => "") + name = replace(name, ")" => "") + map(collect(name)) do char + char in ['(', ','] && return '-' + char == '=' && return '_' + char in MLFLOW_CHAR_SET && return char + " " + end |> join +end + """ logmachinemeasures(service::MLFlow, run::MLFlowRun, model::Model) Extracts the parameters of a model and logs them to the MLFlow server. # Arguments -- `service::MLFlow`: An MLFlow service. See [MLFlowClient.jl](https://juliaai.github.io/MLFlowClient.jl/dev/reference/#MLFlowClient.MLFlow) -- `run::MLFlowRun`: An MLFlow run. See [MLFlowClient.jl](https://juliaai.github.io/MLFlowClient.jl/dev/reference/#MLFlowClient.MLFlowRun) + +- `service::MLFlow`: An MLFlow service. See + [MLFlowClient.jl](https://juliaai.github.io/MLFlowClient.jl/dev/reference/#MLFlowClient.MLFlow) + +- `run::MLFlowRun`: An MLFlow run. See + [MLFlowClient.jl](https://juliaai.github.io/MLFlowClient.jl/dev/reference/#MLFlowClient.MLFlowRun) + - `measures`: A vector of measures. + - `measurements`: A vector of measurements. + """ function logmachinemeasures(service::MLFlow, run::MLFlowRun, measures, measurements) - measure_names = measures .|> info .|> x -> x.name + measure_names = measures .|> good_name for (name, value) in zip(measure_names, measurements) logmetric(service, run, name, value) end diff --git a/test/runtests.jl b/test/runtests.jl index 20f209c..ce5912d 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,6 +6,9 @@ using MLJBase using MLJModels using MLFlowClient using MLJModelInterface +using StatisticalMeasures include("base.jl") include("types.jl") +include("service.jl") + diff --git a/test/service.jl b/test/service.jl new file mode 100644 index 0000000..0c0da6b --- /dev/null +++ b/test/service.jl @@ -0,0 +1,7 @@ +@testset "good_name" begin + @test MLJFlow.good_name(rms) == "RootMeanSquaredError" + @test MLJFlow.good_name(macro_f1score) == "MulticlassFScore-beta_1.0" + @test MLJFlow.good_name(log_score) == "LogScore-tol_2.22045e-16" +end + +true