Support passing extra data for loss function via MLJ interface #249

Closed · wants to merge 8 commits
1 change: 1 addition & 0 deletions docs/src/types.md
@@ -76,6 +76,7 @@ HallOfFame(options::Options, ::Type{T}, ::Type{L}) where {T<:DATA_TYPE,L<:LOSS_T
```@docs
Dataset
Dataset(X::AbstractMatrix{T}, y::Union{AbstractVector{T},Nothing}=nothing;
col::Int=1,
weights::Union{AbstractVector{T}, Nothing}=nothing,
variable_names::Union{Array{String, 1}, Nothing}=nothing,
y_variable_name::Union{String,Nothing}=nothing,
3 changes: 1 addition & 2 deletions src/Core.jl
@@ -7,8 +7,7 @@ include("OptionsStruct.jl")
include("Operators.jl")
include("Options.jl")

using .ProgramConstantsModule:
MAX_DEGREE, BATCH_DIM, FEATURE_DIM, RecordType, DATA_TYPE, LOSS_TYPE
using .ProgramConstantsModule: MAX_DEGREE, RecordType, DATA_TYPE, LOSS_TYPE
using .DatasetModule: Dataset
using .OptionsStructModule: Options, ComplexityMapping, MutationWeights, sample_mutation
using .OptionsModule: Options
34 changes: 14 additions & 20 deletions src/Dataset.jl
@@ -10,7 +10,7 @@ using DynamicQuantities:
DEFAULT_DIM_BASE_TYPE

using ..UtilsModule: subscriptify, get_base_type
using ..ProgramConstantsModule: BATCH_DIM, FEATURE_DIM, DATA_TYPE, LOSS_TYPE
using ..ProgramConstantsModule: DATA_TYPE, LOSS_TYPE
using ...InterfaceDynamicQuantitiesModule: get_si_units, get_sym_units

import ...deprecate_varmap
@@ -24,10 +24,13 @@ import ...deprecate_varmap
- `y::AbstractVector{T}`: The desired output values, with shape `(n,)`.
- `n::Int`: The number of samples.
- `nfeatures::Int`: The number of features.
- `col::Int`: For multi-target problems, this is the column of `y` which
we have selected in this dataset. For single-target problems, this is
always `1`.
- `weighted::Bool`: Whether the dataset is non-uniformly weighted.
- `weights::Union{AbstractVector{T},Nothing}`: If the dataset is weighted,
these specify the per-sample weight (with shape `(n,)`).
- `extra::NamedTuple`: Extra information to pass to a custom evaluation
- `extra::Union{NamedTuple,Base.Pairs,Nothing}`: Extra information to pass to a custom evaluation
function. Since this is an arbitrary named tuple, you can pass
any sort of extra data here.
- `avg_y`: The average value of `y` (weighted, if `weights` are passed).
@@ -56,7 +59,7 @@ mutable struct Dataset{
AX<:AbstractMatrix{T},
AY<:Union{AbstractVector{T},Nothing},
AW<:Union{AbstractVector{T},Nothing},
NT<:NamedTuple,
EX<:Union{NamedTuple,Base.Pairs,Nothing},
XU<:Union{AbstractVector{<:Quantity},Nothing},
YU<:Union{Quantity,Nothing},
XUS<:Union{AbstractVector{<:Quantity},Nothing},
@@ -66,9 +69,10 @@
y::AY
n::Int
nfeatures::Int
col::Int
weighted::Bool
weights::AW
extra::NT
extra::EX
avg_y::Union{T,Nothing}
use_baseline::Bool
baseline_loss::L
@@ -81,27 +85,16 @@
y_sym_units::YUS
end

"""
Dataset(X::AbstractMatrix{T}, y::Union{AbstractVector{T},Nothing}=nothing;
weights::Union{AbstractVector{T}, Nothing}=nothing,
variable_names::Union{Array{String, 1}, Nothing}=nothing,
y_variable_name::Union{String,Nothing}=nothing,
extra::NamedTuple=NamedTuple(),
loss_type::Type=Nothing,
X_units::Union{AbstractVector, Nothing}=nothing,
y_units=nothing,
) where {T<:DATA_TYPE}

Construct a dataset to pass between internal functions.
"""
"""Construct a dataset to pass between internal functions."""
function Dataset(
X::AbstractMatrix{T},
y::Union{AbstractVector{T},Nothing}=nothing;
col::Int=1,
weights::Union{AbstractVector{T},Nothing}=nothing,
variable_names::Union{Array{String,1},Nothing}=nothing,
display_variable_names=variable_names,
y_variable_name::Union{String,Nothing}=nothing,
extra::NamedTuple=NamedTuple(),
extra::Union{NamedTuple,Base.Pairs,Nothing}=nothing,
loss_type::Type{L}=Nothing,
X_units::Union{AbstractVector,Nothing}=nothing,
y_units=nothing,
@@ -113,8 +106,8 @@ function Dataset(
# Deprecation warning:
variable_names = deprecate_varmap(variable_names, varMap, :Dataset)

n = size(X, BATCH_DIM)
nfeatures = size(X, FEATURE_DIM)
n = size(X, 2)
nfeatures = size(X, 1)
weighted = weights !== nothing
variable_names = if variable_names === nothing
["x$(i)" for i in 1:nfeatures]
@@ -188,6 +181,7 @@ function Dataset(
y,
n,
nfeatures,
col,
weighted,
weights,
extra,
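For reference, a minimal sketch of the updated constructor; the arrays and the `∂y` name are placeholder data for illustration, not part of the diff:

```julia
using SymbolicRegression: Dataset

X = randn(Float64, 3, 100)   # (nfeatures, n) layout, matching size(X, 1) / size(X, 2) above
y = randn(Float64, 100)
∂y = randn(Float64, 100)     # arbitrary per-sample extra data

# `extra` is stored on the struct as-is, for later use as `dataset.extra`:
dataset = Dataset(X, y; extra=(∂y=∂y,))
@assert dataset.extra.∂y === ∂y
```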
41 changes: 33 additions & 8 deletions src/MLJInterface.jl
@@ -122,9 +122,25 @@ function MMI.update(
m::AbstractSRRegressor, verbosity, old_fitresult, old_cache, X, y, w=nothing
)
options = old_fitresult === nothing ? get_options(m) : old_fitresult.options
return _update(m, verbosity, old_fitresult, old_cache, X, y, w, options)
return _update(m, verbosity, old_fitresult, old_cache, X, y, w, options, nothing)
end
function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options)
function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options, extra)
if extra === nothing
if hasproperty(X, :data) || hasproperty(X, :extra)
@assert(
hasproperty(X, :data) &&
hasproperty(X, :extra) &&
length(propertynames(X)) == 2,
"If passing extra data to the MLJ interface of symblic regression, you must use the format " *
"(data=X, extra=extra), for the standard types of `X` and arbitrary named tuple `extra`."
)
return _update(
m, verbosity, old_fitresult, old_cache, X.data, y, w, options, X.extra
)
end
elseif !isempty(extra)
@info "Received extra arguments $(keys(extra)) which will be stored in the dataset."
end
# To speed up iterative fits, we cache the types:
types = if old_fitresult === nothing
(;
@@ -149,17 +165,13 @@
)
X_units_clean::types.X_units_clean = clean_units(X_units)
y_units_clean::types.y_units_clean = clean_units(y_units)
w_t::types.w_t = if w !== nothing && isa(m, MultitargetSRRegressor)
@assert(isa(w, AbstractVector) && ndims(w) == 1, "Unexpected input for `w`.")
repeat(w', size(y_t, 1))
else
w
end
w_t::types.w_t = validate_weights(size(y_t, 1), m, w)
search_state::types.state = equation_search(
X_t,
y_t;
niterations=m.niterations,
weights=w_t,
extra=extra,
variable_names=variable_names,
options=options,
parallelism=m.parallelism,
@@ -204,6 +216,19 @@ hof_eltype(::Type{H}) where {T,H<:HallOfFame{T}} = T
hof_eltype(::Type{V}) where {V<:Vector} = hof_eltype(eltype(V))
hof_eltype(h) = hof_eltype(typeof(h))

function validate_weights(num_cols, ::MultitargetSRRegressor, w::AbstractVector)
return repeat(w', num_cols)
end
function validate_weights(_, ::SRRegressor, w::AbstractVector)
return w
end
function validate_weights(_, _, ::Nothing)
return nothing
end
function validate_weights(_, _, _)
return error("Unexpected input for `w`. This should usually be a vector.")
end

function clean_units(units)
!isa(units, AbstractDimensions) && error("Unexpected units.")
iszero(units) && return nothing
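In usage terms, the branch above means an MLJ machine can carry extra data by wrapping the features in a two-field named tuple. A hedged sketch; the operator choices are illustrative, and `my_loss` stands for a custom loss that reads `dataset.extra` (such as `derivative_loss` in the test file below):

```julia
using MLJBase: machine, fit!

model = SRRegressor(; binary_operators=[+, -, *], loss_function=my_loss)
mach = machine(model, (data=X, extra=(∂y=∂y,)), y)
fit!(mach)
```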
2 changes: 0 additions & 2 deletions src/ProgramConstants.jl
@@ -1,8 +1,6 @@
module ProgramConstantsModule

const MAX_DEGREE = 2
const BATCH_DIM = 2
const FEATURE_DIM = 1
const RecordType = Dict{String,Any}

const DATA_TYPE = Number
3 changes: 3 additions & 0 deletions src/SearchUtils.jl
@@ -347,6 +347,7 @@ function construct_datasets(
X,
y,
weights,
extra,
variable_names,
display_variable_names,
y_variable_names,
@@ -359,7 +360,9 @@
Dataset(
X,
y[j, :];
col=j,
weights=(weights === nothing ? weights : weights[j, :]),
extra=extra,
variable_names=variable_names,
display_variable_names=display_variable_names,
y_variable_name=if y_variable_names === nothing
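Note that `extra` is passed whole to every per-target dataset rather than sliced per column, so a multi-target custom loss can use the new `col` field to select its slice. A hedged sketch using the documented three-argument (non-batched) loss signature, under the assumption that the shared data is stored as an `(ntargets, n)` matrix; `scales` is a hypothetical field name:

```julia
using SymbolicRegression: Dataset, Options, eval_tree_array

# Hypothetical multi-target loss: per-target, per-sample scales live in
# `dataset.extra.scales`, which is shared across all per-target datasets.
function scaled_loss(tree, dataset::Dataset{T,L}, options::Options) where {T,L}
    ŷ, completed = eval_tree_array(tree, dataset.X, options)
    !completed && return L(Inf)
    scale = view(dataset.extra.scales, dataset.col, :)  # this target's row
    return sum(i -> scale[i] * (ŷ[i] - dataset.y[i])^2, eachindex(ŷ)) / dataset.n
end
```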
14 changes: 12 additions & 2 deletions src/SymbolicRegression.jl
@@ -161,8 +161,6 @@ include("SearchUtils.jl")

using .CoreModule:
MAX_DEGREE,
BATCH_DIM,
FEATURE_DIM,
DATA_TYPE,
LOSS_TYPE,
RecordType,
@@ -259,6 +257,9 @@ which is useful for debugging and profiling.
More iterations will improve the results.
- `weights::Union{AbstractMatrix{T}, AbstractVector{T}, Nothing}=nothing`: Optionally
weight the loss for each `y` by this value (same shape as `y`).
- `extra::Union{NamedTuple, Base.Pairs, Nothing}=nothing`: Extra information to pass to a custom
evaluation function. Since this is an arbitrary named tuple, you can pass
any sort of extra data here.
- `options::Options=Options()`: The options for the search, such as
which operators to use, evolution hyperparameters, etc.
- `variable_names::Union{Vector{String}, Nothing}=nothing`: The names
@@ -337,6 +338,7 @@ function equation_search(
y::AbstractMatrix{T};
niterations::Int=10,
weights::Union{AbstractMatrix{T},AbstractVector{T},Nothing}=nothing,
extra::Union{NamedTuple,Base.Pairs,Nothing}=nothing,
options::Options=Options(),
variable_names::Union{AbstractVector{String},Nothing}=nothing,
display_variable_names::Union{AbstractVector{String},Nothing}=variable_names,
@@ -371,11 +373,19 @@
@assert length(weights) == length(y)
weights = reshape(weights, size(y))
end
if extra !== nothing && !isempty(extra)
if options.loss_function === nothing
error(
"You have passed `extra`, but have not provided a custom `loss_function` to use it.",
)
end
end

datasets = construct_datasets(
X,
y,
weights,
extra,
variable_names,
display_variable_names,
y_variable_names,
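Putting the direct interface together: a minimal sketch of calling `equation_search` with `extra`. The operator choices are illustrative; `derivative_loss` is the custom loss defined in the test file below, and a custom `loss_function` is required whenever `extra` is non-empty, per the new check above:

```julia
using SymbolicRegression: equation_search, Options

options = Options(;
    binary_operators=[+, -, *], unary_operators=[cos],
    loss_function=derivative_loss, enable_autodiff=true,
)
hall_of_fame = equation_search(
    X, reshape(y, 1, :);   # X is (nfeatures, n) here, unlike the row-major MLJ layout
    extra=(∂y=∂y,), options=options, niterations=10,
)
```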
2 changes: 2 additions & 0 deletions test/test_mixed.jl
@@ -55,6 +55,8 @@ for i in 0:5
println("with default hyperparameters, Float64 type, and turbo=true")
T = Float64
turbo = true
parallelism = :multithreading
numprocs = nothing
end
if i == 5
options = SymbolicRegression.Options(;
94 changes: 93 additions & 1 deletion test/test_mlj.jl
@@ -1,10 +1,18 @@
using SymbolicRegression: SymbolicRegression
using SymbolicRegression:
Node, SRRegressor, MultitargetSRRegressor, node_to_symbolic, symbolic_to_node
Node,
Dataset,
SRRegressor,
MultitargetSRRegressor,
node_to_symbolic,
symbolic_to_node,
eval_tree_array,
eval_grad_tree_array
using MLJTestInterface: MLJTestInterface as MTI
using MLJBase: machine, fit!, report, predict
using Test
using SymbolicUtils: SymbolicUtils
using Zygote
using Suppressor: @capture_err

macro quiet(ex)
@@ -175,3 +183,87 @@ end
end
@test occursin("Evaluation failed either due to", msg)
end

const WasEvaluated = Ref(false)
const HasWeights = Ref(false)
const WasEvaluatedLock = Threads.SpinLock()

# This tests both `.extra` and `idx`
function derivative_loss(tree, dataset::Dataset{T,L}, options, idx) where {T,L}
# Select from the batch indices, if given
X = idx === nothing ? dataset.X : view(dataset.X, :, idx)

ŷ, ∂ŷ, completed = eval_grad_tree_array(tree, X, options; variable=true)

!completed && return L(Inf)

y = idx === nothing ? dataset.y : view(dataset.y, idx)
∂y = idx === nothing ? dataset.extra.∂y : view(dataset.extra.∂y, idx)

mse_deriv = sum(i -> (∂ŷ[i] - ∂y[i])^2, eachindex(∂y)) / length(∂y)
mse_value = sum(i -> (ŷ[i] - y[i])^2, eachindex(y)) / length(y)

WasEvaluated[] || lock(WasEvaluatedLock) do
WasEvaluated[] = true
end
if dataset.weights !== nothing
HasWeights[] || lock(WasEvaluatedLock) do
HasWeights[] = true
end
end

return mse_value + mse_deriv
end

true_f(x) = x^3 / 3 - cos(x)
deriv_f(x) = x^2 + sin(x)

@testset "Test `extra` parameter" begin
X = reshape(0.0:0.32:10.0, :, 1)
y = true_f.(X[:, 1])
∂y = deriv_f.(X[:, 1])

model = SRRegressor(;
binary_operators=[+, -, *],
unary_operators=[cos],
loss_function=derivative_loss,
enable_autodiff=true,
batching=true,
batch_size=25,
niterations=100,
early_stop_condition=1e-6,
)
mach = machine(model, (data=X, extra=(∂y=∂y,)), y)
VERSION >= v"1.8" && @test_warn "experimental" fit!(mach)

@test WasEvaluated[]
@test predict(mach, X) ≈ y

WasEvaluated[] = false
HasWeights[] = false
# Try again but with weights parameter
w = ones(size(y))
model = SRRegressor(;
binary_operators=[+, -, *],
unary_operators=[cos],
loss_function=derivative_loss,
enable_autodiff=true,
batching=true,
batch_size=25,
niterations=100,
early_stop_condition=1e-6,
)
mach = machine(model, (data=X, extra=(; ∂y)), y, w)
VERSION >= v"1.8" && @test_warn "experimental" fit!(mach)
@test WasEvaluated[]
@test HasWeights[]
@test predict(mach, X) ≈ y

@testset "Test errors associated with `extra`" begin
# No loss function:
model = SRRegressor(; loss_function=nothing)
mach = machine(model, (data=X, extra=(∂y=∂y,)), y)
@test_throws ErrorException @quiet(fit!(mach; verbosity=0))
VERSION >= v"1.8" && @test_throws "You have passed" @quiet(fit!(mach))
end
end