
Support passing extra data for loss function via MLJ interface #249

Closed

wants to merge 8 commits

Conversation

MilesCranmer (Owner) commented Aug 12, 2023

For example, say we create a custom loss function that compares both f(x) and f'(x) against data. We can access the values with dataset.y and the derivatives with dataset.extra.∂y. The .extra property lets you store an arbitrary NamedTuple whose fields you can access inside a custom loss function.

function derivative_loss(tree, dataset::Dataset{T,L}, options, idx) where {T,L}
    # Select from the batch indices, if given
    X = idx === nothing ? dataset.X : view(dataset.X, :, idx)

    # Evaluate both f(x) and f'(x), where f is defined by `tree`
    ŷ, ∂ŷ, completed = eval_grad_tree_array(tree, X, options; variable=true)

    # Bail out with an infinite loss if evaluation produced NaN/Inf
    !completed && return L(Inf)

    y = idx === nothing ? dataset.y : view(dataset.y, idx)
    ∂y = idx === nothing ? dataset.extra.∂y : view(dataset.extra.∂y, idx)

    mse_deriv = sum(i -> (∂ŷ[i] - ∂y[i])^2, eachindex(∂y)) / length(∂y)
    mse_value = sum(i -> (ŷ[i] - y[i])^2, eachindex(y)) / length(y)

    return mse_value + mse_deriv
end

Here, we have also taken advantage of mini-batching, using idx to subsample both dataset.y and dataset.extra.∂y.

You can now use this loss function by passing a NamedTuple for the w argument to machine, which normally takes a vector of per-observation weights. If you pass a vector, it is treated as weights, as usual; if you pass a NamedTuple, it is attached to the extra property of the Dataset.

e.g.,

    model = SRRegressor(;
        binary_operators=[+, -, *],
        unary_operators=[cos],
        loss_function=derivative_loss,
        enable_autodiff=true,
        batching=true,
        batch_size=25,
        niterations=100,
        early_stop_condition=1e-6,
    )
    mach = machine(model, X, y, (; ∂y=∂y))
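
For a self-contained run, hypothetical toy data for the snippet above might look like the following (not part of the PR; ground truth f(x) = cos(2x), and the names X, y, ∂y are illustrative):

    X = (; x=collect(range(-2.0, 2.0, length=256)))  # single-feature column table
    y = cos.(2 .* X.x)           # target values f(x)
    ∂y = -2 .* sin.(2 .* X.x)    # target derivatives f'(x)

    fit!(mach)  # after constructing `mach` as above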

MilesCranmer (Owner, Author) commented

@OkonSamuel @ablaom I was not sure whether there is a way to pass additional custom data to a machine, so I am currently simply allowing the user to pass a NamedTuple for the weights. What do you think?

github-actions bot (Contributor) commented Aug 12, 2023

Benchmark Results

| Benchmark | master | 0266146... | t[master]/t[0266146...] |
|:---|:---|:---|:---|
| search/multithreading | 22.4 ± 1.3 s | 23.2 ± 1.2 s | 0.969 |
| search/serial | 30.4 ± 0.31 s | 29.5 ± 0.076 s | 1.03 |
| utils/best_of_sample | 1.09 ± 0.38 μs | 0.892 ± 0.26 μs | 1.22 |
| utils/check_constraints_x10 | 12.5 ± 3.3 μs | 12.5 ± 3.3 μs | 1 |
| utils/compute_complexity_x10/Float64 | 2.26 ± 0.12 μs | 2.25 ± 0.12 μs | 1 |
| utils/compute_complexity_x10/Int64 | 2.25 ± 0.11 μs | 2.3 ± 0.12 μs | 0.978 |
| utils/compute_complexity_x10/nothing | 1.47 ± 0.12 μs | 1.48 ± 0.12 μs | 0.993 |
| utils/optimize_constants_x10 | 29.1 ± 6.8 ms | 28.4 ± 6 ms | 1.03 |
| time_to_load | 1.32 ± 0.0076 s | 1.35 ± 0.0069 s | 0.978 |

Benchmark Plots

A plot of the benchmark results has been uploaded as an artifact to the workflow run for this PR.
Go to "Actions"->"Benchmark a pull request"->[the most recent run]->"Artifacts" (at the bottom).

ablaom commented Aug 14, 2023

> I was not sure whether there is a way to pass additional custom data to a machine, so I am currently simply allowing the user to pass a NamedTuple for the weights. What do you think?

You want to provide ∂y as additional training data, right? Well, there's no strict requirement about the number of arguments to MLJModelInterface.fit, so you could just make ∂y an optional third (positional) argument. If you also want optional per-observation weights, then you have a problem, because both ∂y and w could have the same type, yes? Your suggestion should work, but you should understand one point that is not well documented:

Subsampling of training data in MLJ. When observations are subsampled by evaluate! (e.g., in cross-validation), each training data argument (X, y, w, and so forth) is subsampled only if it is an abstract vector, an abstract matrix (first dim is the observation index), or if istable(_) is true. In all other cases, the full object is used. For example, if w is a vector of per-observation weights, then in evaluate! it is subsampled along with X and y (say), but if w is a dict of class weights (in which case subsampling makes no sense), then w is not subsampled. If you pass a named tuple (i.e., (; ∂y=∂y)), then that should work, because it will be regarded as a table. (In general, subsampling might change the table type, but in this case I think it won't.)
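
To illustrate the named-tuple case, a quick check using the Tables.jl API (the data here is made up):

    using Tables

    w = (; ∂y=rand(100))  # a NamedTuple of equal-length vectors is a valid column table
    Tables.istable(w)     # true, so evaluate! will subsample it row-wise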

If per-observation weights are never going to be supported, then perhaps two signatures, MMI.fit(model, verb, X, y) and MMI.fit(model, verb, X, y, ∂y), are cleaner than your suggestion.
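
A sketch of what that could look like (MMI = MLJModelInterface; the _fit helper is hypothetical, standing in for the shared implementation):

    import MLJModelInterface as MMI
    using SymbolicRegression: SRRegressor

    _fit(model, verbosity, X, y, ∂y) = nothing  # hypothetical shared implementation

    # No derivative data supplied:
    MMI.fit(model::SRRegressor, verbosity, X, y) = _fit(model, verbosity, X, y, nothing)

    # ∂y supplied as an optional extra positional data argument:
    MMI.fit(model::SRRegressor, verbosity, X, y, ∂y) = _fit(model, verbosity, X, y, ∂y)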

Another possibility, which I quite like, is to insist that y and ∂y be passed as a two-column table or two-column matrix, and that predict also return both of these as a table or matrix. Then you could use a multi-target measure for out-of-sample evaluation. (Multi-target measures are coming soon.) And you could also support per-observation weights, which would have a different type, AbstractVector{<:Real}, that would never be the type of the two-column matrix or table.

tomaklutfu commented
@MilesCranmer this works for my use case. Thanks again for working on this quickly.

MilesCranmer (Owner, Author) commented

@ablaom it seems like fit_only! does not permit extra keywords?

function fit_only!(
    mach::Machine{<:Any,cache_data};
    rows=nothing,
    verbosity=1,
    force=false,
    composite=nothing,
) where cache_data

MilesCranmer (Owner, Author) commented Feb 19, 2024

I've tried a few different strategies, and it doesn't seem like there's a good way to let users pass arbitrary data (of any shape) for use in a custom loss function. I don't think this is necessarily a limitation; it is simply the point at which a high-level interface should not be used, as this level of customisation would break various assumptions anyway.

For now I think we need to close this @tomaklutfu; there doesn't seem to be any robust way to do it right now. I would recommend either:

  1. Declaring any extra data as a global constant and accessing it inside the custom loss (see the sketch below), or
  2. Using the low-level interface whenever you need to pass non-standard data formats or do very custom things.
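
A minimal sketch of option 1, adapting the derivative_loss example from above (the global name EXTRA_∂Y is illustrative, and ∂y is assumed to be defined at top level):

    # Hypothetical: store the extra data in a global constant...
    const EXTRA_∂Y = ∂y

    function derivative_loss(tree, dataset::Dataset{T,L}, options, idx) where {T,L}
        X = idx === nothing ? dataset.X : view(dataset.X, :, idx)
        ŷ, ∂ŷ, completed = eval_grad_tree_array(tree, X, options; variable=true)
        !completed && return L(Inf)
        y = idx === nothing ? dataset.y : view(dataset.y, idx)
        # ...and read it here instead of from `dataset.extra`:
        ∂y = idx === nothing ? EXTRA_∂Y : view(EXTRA_∂Y, idx)
        mse_deriv = sum(i -> (∂ŷ[i] - ∂y[i])^2, eachindex(∂y)) / length(∂y)
        mse_value = sum(i -> (ŷ[i] - y[i])^2, eachindex(y)) / length(y)
        return mse_value + mse_deriv
    end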

Cheers,
Miles

ablaom commented Feb 21, 2024

> @ablaom it seems like fit_only! does not permit extra keywords?

Correct. Custom kwargs to fit are not supported.

tomaklutfu commented Feb 23, 2024

Thanks @MilesCranmer. I used a custom loss function implemented as a callable struct with fields for the extra data. It worked without hurdles.
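
For reference, that pattern might look like the following sketch (not the exact code from the thread; the struct subtypes Function so it is accepted wherever a loss function is expected):

    struct DerivativeLoss{V<:AbstractVector} <: Function
        ∂y::V
    end

    function (loss::DerivativeLoss)(tree, dataset::Dataset{T,L}, options, idx) where {T,L}
        X = idx === nothing ? dataset.X : view(dataset.X, :, idx)
        ŷ, ∂ŷ, completed = eval_grad_tree_array(tree, X, options; variable=true)
        !completed && return L(Inf)
        y = idx === nothing ? dataset.y : view(dataset.y, idx)
        # The extra data travels on the struct rather than on the Dataset:
        ∂y = idx === nothing ? loss.∂y : view(loss.∂y, idx)
        mse_deriv = sum(i -> (∂ŷ[i] - ∂y[i])^2, eachindex(∂y)) / length(∂y)
        mse_value = sum(i -> (ŷ[i] - y[i])^2, eachindex(y)) / length(y)
        return mse_value + mse_deriv
    end

    # Usage: pass an instance as the loss, e.g. loss_function=DerivativeLoss(∂y)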
