From a8f7f46dc55c2dcce71a769db8cec8e1b40a3176 Mon Sep 17 00:00:00 2001 From: Miles Cranmer Date: Thu, 24 Aug 2023 04:44:51 -0400 Subject: [PATCH] Improve aggregation speeds by summing function --- src/losses.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/losses.jl b/src/losses.jl index f988204..cac0945 100644 --- a/src/losses.jl +++ b/src/losses.jl @@ -35,7 +35,7 @@ include("losses/weighted.jl") Return sum of `loss` values over the iterables `outputs` and `targets`. """ function sum(loss::SupervisedLoss, outputs, targets) - sum(loss(ŷ, y) for (ŷ, y) in zip(outputs, targets)) + sum(i -> loss(outputs[i], targets[i]), eachindex(outputs, targets)) end """ @@ -46,7 +46,7 @@ The `weights` determine the importance of each observation. The option `normalize` divides the result by the sum of the weights. """ function sum(loss::SupervisedLoss, outputs, targets, weights; normalize=true) - s = sum(w * loss(ŷ, y) for (ŷ, y, w) in zip(outputs, targets, weights)) + s = sum(i -> weights[i] * loss(outputs[i], targets[i]), eachindex(outputs, targets, weights)) n = normalize ? sum(weights) : one(first(weights)) s / n end @@ -57,7 +57,7 @@ end Return mean of `loss` values over the iterables `outputs` and `targets`. """ function mean(loss::SupervisedLoss, outputs, targets) - mean(loss(ŷ, y) for (ŷ, y) in zip(outputs, targets)) + mean(i -> loss(outputs[i], targets[i]), eachindex(outputs, targets)) end """ @@ -68,7 +68,7 @@ The `weights` determine the importance of each observation. The option `normalize` divides the result by the sum of the weights. """ function mean(loss::SupervisedLoss, outputs, targets, weights; normalize=true) - m = mean(w * loss(ŷ, y) for (ŷ, y, w) in zip(outputs, targets, weights)) + m = mean(i -> weights[i] * loss(outputs[i], targets[i]), eachindex(outputs, targets, weights)) n = normalize ? sum(weights) : one(first(weights)) m / n end