Skip to content

Commit

Permalink
Merge pull request #173 from MilesCranmer/fix-sum-speeds
Browse files Browse the repository at this point in the history
Improve aggregation speeds by using `eachindex` instead of `iterate`
  • Loading branch information
juliohm authored Aug 24, 2023
2 parents 7318c58 + a8f7f46 commit 0b3e74f
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/losses.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ include("losses/weighted.jl")
Return sum of `loss` values over the iterables `outputs` and `targets`.
"""
function sum(loss::SupervisedLoss, outputs, targets)
sum(loss(ŷ, y) for (ŷ, y) in zip(outputs, targets))
sum(i -> loss(outputs[i], targets[i]), eachindex(outputs, targets))
end

"""
Expand All @@ -46,7 +46,7 @@ The `weights` determine the importance of each observation. The option
`normalize` divides the result by the sum of the weights.
"""
function sum(loss::SupervisedLoss, outputs, targets, weights; normalize=true)
s = sum(w * loss(ŷ, y) for (ŷ, y, w) in zip(outputs, targets, weights))
s = sum(i -> weights[i] * loss(outputs[i], targets[i]), eachindex(outputs, targets, weights))
n = normalize ? sum(weights) : one(first(weights))
s / n
end
Expand All @@ -57,7 +57,7 @@ end
Return mean of `loss` values over the iterables `outputs` and `targets`.
"""
function mean(loss::SupervisedLoss, outputs, targets)
mean(loss(ŷ, y) for (ŷ, y) in zip(outputs, targets))
mean(i -> loss(outputs[i], targets[i]), eachindex(outputs, targets))
end

"""
Expand All @@ -68,7 +68,7 @@ The `weights` determine the importance of each observation. The option
`normalize` divides the result by the sum of the weights.
"""
function mean(loss::SupervisedLoss, outputs, targets, weights; normalize=true)
m = mean(w * loss(ŷ, y) for (ŷ, y, w) in zip(outputs, targets, weights))
m = mean(i -> weights[i] * loss(outputs[i], targets[i]), eachindex(outputs, targets, weights))
n = normalize ? sum(weights) : one(first(weights))
m / n
end

0 comments on commit 0b3e74f

Please sign in to comment.