From d8c0654dbbe23d2b1f6b3f9b4097e91221e63162 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=BAlio=20Hoffimann?= Date: Thu, 27 Apr 2023 16:09:22 -0300 Subject: [PATCH] Refactor AggMode submodule and types (#162) * Get rid of AggMode.None * Formatting issues * Get rid of dimcheck * Simplify signatures in losses.jl * Accept iterables in vectorized loss * Refactor functor interface * Remove AggMode submodule * Rename test/aggmode.jl to test/agg.jl * Update docs --- Project.toml | 1 + docs/src/introduction/gettingstarted.md | 36 ++---- docs/src/user/aggregate.md | 135 +++------------------ src/LossFunctions.jl | 12 +- src/aggmode.jl | 111 ------------------ src/losses.jl | 148 +++++++++--------------- src/losses/other.jl | 9 +- src/losses/scaled.jl | 12 +- src/losses/weighted.jl | 24 ++-- test/{aggmode.jl => agg.jl} | 40 ++++--- test/core.jl | 9 +- test/runtests.jl | 2 +- 12 files changed, 126 insertions(+), 413 deletions(-) delete mode 100644 src/aggmode.jl rename test/{aggmode.jl => agg.jl} (64%) diff --git a/Project.toml b/Project.toml index 99d4dc8..9985bbc 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.9.0" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] CategoricalArrays = "0.10" diff --git a/docs/src/introduction/gettingstarted.md b/docs/src/introduction/gettingstarted.md index 62e1a7e..af485fa 100644 --- a/docs/src/introduction/gettingstarted.md +++ b/docs/src/introduction/gettingstarted.md @@ -78,8 +78,8 @@ julia> true_targets = [ 1, 0, -2]; julia> pred_outputs = [0.5, 2, -1]; -julia> value(L2DistLoss(), pred_outputs, true_targets) -3-element Array{Float64,1}: +julia> value.(L2DistLoss(), pred_outputs, true_targets) +3-element Vector{Float64}: 0.25 4.0 1.0 @@ -92,10 +92,10 @@ This will avoid allocating a temporary array and directly compute the result. ```julia-repl -julia> value(L2DistLoss(), pred_outputs, true_targets, AggMode.Sum()) +julia> sum(L2DistLoss(), pred_outputs, true_targets) 5.25 -julia> value(L2DistLoss(), pred_outputs, true_targets, AggMode.Mean()) +julia> mean(L2DistLoss(), pred_outputs, true_targets) 1.75 ``` @@ -105,33 +105,11 @@ each observation in the predicted outputs and so allow to give certain observations a stronger influence over the result. ```julia-repl -julia> value(L2DistLoss(), pred_outputs, true_targets, AggMode.WeightedSum([2,1,1])) +julia> sum(L2DistLoss(), pred_outputs, true_targets, [2,1,1], normalize=false) 5.5 -julia> value(L2DistLoss(), pred_outputs, true_targets, AggMode.WeightedMean([2,1,1])) -1.375 -``` - -All these function signatures of [`value`](@ref) also apply for -computing the derivatives using [`deriv`](@ref) and the second -derivatives using [`deriv2`](@ref). - -```julia-repl -julia> true_targets = [ 1, 0, -2]; - -julia> pred_outputs = [0.5, 2, -1]; - -julia> deriv(L2DistLoss(), pred_outputs, true_targets) -3-element Array{Float64,1}: - -1.0 - 4.0 - 2.0 - -julia> deriv2(L2DistLoss(), pred_outputs, true_targets) -3-element Array{Float64,1}: - 2.0 - 2.0 - 2.0 +julia> mean(L2DistLoss(), pred_outputs, true_targets, [2,1,1], normalize=false) +1.8333333333333333 ``` ## Getting Help diff --git a/docs/src/user/aggregate.md b/docs/src/user/aggregate.md index a0c46fe..dc035fe 100644 --- a/docs/src/user/aggregate.md +++ b/docs/src/user/aggregate.md @@ -34,13 +34,13 @@ say "naive", because it will not give us an acceptable performance. ```jldoctest -julia> value(L1DistLoss(), [2,5,-2], [1.,2,3]) +julia> value.(L1DistLoss(), [2,5,-2], [1.,2,3]) 3-element Vector{Float64}: 1.0 3.0 5.0 -julia> sum(value(L1DistLoss(), [2,5,-2], [1.,2,3])) # WARNING: Bad code +julia> sum(value.(L1DistLoss(), [2,5,-2], [1.,2,3])) # WARNING: Bad code 9.0 ``` @@ -53,52 +53,25 @@ that we don't need in the end and could avoid. For that reason we provide special methods that compute the common accumulations efficiently without allocating temporary -arrays. These methods can be invoked using an additional -parameter which specifies how the values should be accumulated / -averaged. The type of this parameter has to be a subtype of -`AggregateMode`. - -## Aggregation Modes - -Before we discuss these memory-efficient methods, let us briefly -introduce the available aggregation mode types. We provide a number -of different aggregation modes, all of which are contained within -the namespace `AggMode`. An instance of such type can then be -used as additional parameter to [`value`](@ref), [`deriv`](@ref), -and [`deriv2`](@ref), as we will see further down. - -It follows a list of available aggregation modes. Each of which with -a short description of what their effect would be when used as an -additional parameter to the functions mentioned above. - -```@docs -AggMode.None -AggMode.Sum -AggMode.Mean -AggMode.WeightedSum -AggMode.WeightedMean -``` - -## Unweighted Sum and Mean +arrays. -As hinted before, we provide special memory efficient methods for -computing the **sum** or the **mean** of the element-wise (or -broadcasted) results of [`value`](@ref), [`deriv`](@ref), and -[`deriv2`](@ref). These methods avoid the allocation of a -temporary array and instead compute the result directly. +```jldoctest +julia> sum(L1DistLoss(), [2,5,-2], [1.,2,3]) +9.0 -## Weighted Sum and Mean +julia> mean(L1DistLoss(), [2,5,-2], [1.,2,3]) +3.0 +``` Up to this point, all the averaging was performed in an unweighted manner. That means that each observation was treated as equal and had thus the same potential influence on the result. -In this sub-section we will consider the situations in which we +In the following we will consider situations in which we do want to explicitly specify the influence of each observation (i.e. we want to weigh them). When we say we "weigh" an observation, what it effectively boils down to is multiplying the -result for that observation (i.e. the computed loss or -derivative) with some number. This is done for every observation -individually. +result for that observation (i.e. the computed loss) with some number. +This is done for every observation individually. To get a better understand of what we are talking about, let us consider performing a weighting scheme manually. The following @@ -127,88 +100,10 @@ between the different weights. In the example above the second observation was thus considered twice as important as any of the other two observations. -In the case of multi-dimensional arrays the process isn't that -simple anymore. In such a scenario, computing the weighted sum -(or weighted mean) can be thought of as having an additional -step. First we either compute the sum or (unweighted) average for -each observation (which results in a vector), and then we compute -the weighted sum of all observations. - -The following code snipped demonstrates how to compute the -`AggMode.WeightedSum([2,1])` manually. This is **not** meant as -an example of how to do it, but simply to show what is happening -qualitatively. In this example we assume that we are working in a -multi-variable regression setting, in which our data set has four -observations with two target-variables each. - -```jldoctest weight -julia> targets = reshape(1:8, (2, 4)) ./ 8 -2×4 Matrix{Float64}: - 0.125 0.375 0.625 0.875 - 0.25 0.5 0.75 1.0 - -julia> outputs = reshape(1:2:16, (2, 4)) ./ 8 -2×4 Matrix{Float64}: - 0.125 0.625 1.125 1.625 - 0.375 0.875 1.375 1.875 - -julia> # WARNING: BAD CODE - ONLY FOR ILLUSTRATION - -julia> tmp = sum(value.(L1DistLoss(), outputs, targets), dims=2) -2×1 Matrix{Float64}: - 1.5 - 2.0 - -julia> sum(tmp .* [2, 1]) # weigh 1st observation twice as high -5.0 -``` - -To manually compute the result for `AggMode.WeightedMean([2,1])` -we follow a similar approach, but use the normalized weight -vector in the last step. - ```jldoctest weight -julia> using Statistics # for access to "mean" - -julia> # WARNING: BAD CODE - ONLY FOR ILLUSTRATION - -julia> tmp = mean(value.(L1DistLoss(), outputs, targets), dims=2) -2×1 Matrix{Float64}: - 0.375 - 0.5 - -julia> sum(tmp .* [0.6666, 0.3333]) # weigh 1st observation twice as high -0.416625 -``` - -Note that you can specify explicitly if you want to normalize the -weight vector. That option is supported for computing the -weighted sum, as well as for computing the weighted mean. See the -documentation for [`AggMode.WeightedSum`](@ref) and -[`AggMode.WeightedMean`](@ref) for more information. - -The code-snippets above are of course very inefficient, because -they allocate (multiple) temporary arrays. We only included them -to demonstrate what is happening in terms of desired result / -effect. For doing those computations efficiently we provide -special methods for [`value`](@ref), [`deriv`](@ref), -[`deriv2`](@ref) and their mutating counterparts. - -```jldoctest weight -julia> value(L1DistLoss(), [2,5,-2], [1.,2,3], AggMode.WeightedSum([1,2,1])) +julia> sum(L1DistLoss(), [2,5,-2], [1.,2,3], [1,2,1], normalize=false) 12.0 -julia> value(L1DistLoss(), [2,5,-2], [1.,2,3], AggMode.WeightedMean([1,2,1])) +julia> mean(L1DistLoss(), [2,5,-2], [1.,2,3], [1,2,1]) 1.0 -``` - -We also provide this functionality for [`deriv`](@ref) and -[`deriv2`](@ref) respectively. - -```jldoctest weight -julia> deriv(L2DistLoss(), [2,5,-2], [1.,2,3], AggMode.WeightedSum([1,2,1])) -4.0 - -julia> deriv(L2DistLoss(), [2,5,-2], [1.,2,3], AggMode.WeightedMean([1,2,1])) -0.3333333333333333 -``` +``` \ No newline at end of file diff --git a/src/LossFunctions.jl b/src/LossFunctions.jl index 57e7de3..d86d1d4 100644 --- a/src/LossFunctions.jl +++ b/src/LossFunctions.jl @@ -3,8 +3,8 @@ module LossFunctions using Markdown using CategoricalArrays: CategoricalValue -# aggregation mode -include("aggmode.jl") +import Base: sum +import Statistics: mean # trait functions include("traits.jl") @@ -31,9 +31,6 @@ export islipschitzcont, islocallylipschitzcont, isclipable, isclasscalibrated, issymmetric, - # relevant submodules - AggMode, - # margin-based losses ZeroOneLoss, LogitMarginLoss, @@ -68,6 +65,9 @@ export # meta losses ScaledLoss, - WeightedMarginLoss + WeightedMarginLoss, + + # reexport mean + mean end # module diff --git a/src/aggmode.jl b/src/aggmode.jl deleted file mode 100644 index eb66e19..0000000 --- a/src/aggmode.jl +++ /dev/null @@ -1,111 +0,0 @@ -""" -Baseclass for all aggregation modes. -""" -abstract type AggregateMode end - -""" - module AggMode - -Types for aggregation of multiple observations. - -- `AggMode.None()` -- `AggMode.Sum()` -- `AggMode.Mean()` -- `AggMode.WeightedSum(weights)` -- `AggMode.WeightedMean(weights)` -""" -module AggMode - using ..LossFunctions: AggregateMode - - """ - AggMode.None() - - Opt-out of aggregation. This is usually the default value. - Using `None` will cause the element-wise results to be returned. - """ - struct None <: AggregateMode end - - """ - AggMode.Sum() - - Causes the method to return the unweighted sum of the - elements instead of the individual elements. Can be used in - combination with `ObsDim`, in which case a vector will be - returned containing the sum for each observation (useful - mainly for multivariable regression). - """ - struct Sum <: AggregateMode end - - """ - AggMode.Mean() - - Causes the method to return the unweighted mean of the - elements instead of the individual elements. Can be used in - combination with `ObsDim`, in which case a vector will be - returned containing the mean for each observation (useful - mainly for multivariable regression). - """ - struct Mean <: AggregateMode end - - """ - AggMode.WeightedSum(weights; [normalize = false]) - - Causes the method to return the weighted sum of all - observations. The variable `weights` has to be a vector of - the same length as the number of observations. - If `normalize = true`, the values of the weight vector will - be normalized in such as way that they sum to one. - - # Arguments - - - `weights::AbstractVector`: Vector of weight values that - can be used to give certain observations a stronger - influence on the sum. - - - `normalize::Bool`: Boolean that specifies if the weight - vector should be transformed in such a way that it sums to - one (i.e. normalized). This will not mutate the weight - vector but instead happen on the fly during the - accumulation. - - Defaults to `false`. Setting it to `true` only really - makes sense in multivalue-regression, otherwise the result - will be the same as for [`WeightedMean`](@ref). - """ - struct WeightedSum{W<:AbstractVector} <: AggregateMode - weights::W - normalize::Bool - end - WeightedSum(weights::AbstractVector; normalize::Bool = false) = WeightedSum(weights, normalize) - - """ - AggMode.WeightedMean(weights; [normalize = true]) - - Causes the method to return the weighted mean of all - observations. The variable `weights` has to be a vector of - the same length as the number of observations. - If `normalize = true`, the values of the weight vector will - be normalized in such as way that they sum to one. - - # Arguments - - - `weights::AbstractVector`: Vector of weight values that can - be used to give certain observations a stronger influence - on the mean. - - - `normalize::Bool`: Boolean that specifies if the weight - vector should be transformed in such a way that it sums to - one (i.e. normalized). This will not mutate the weight - vector but instead happen on the fly during the - accumulation. - - Defaults to `true`. Setting it to `false` only really makes - sense in multivalue-regression, otherwise the result will - be the same as for [`WeightedSum`](@ref). - """ - struct WeightedMean{W<:AbstractVector} <: AggregateMode - weights::W - normalize::Bool - end - WeightedMean(weights::AbstractVector; normalize::Bool = true) = WeightedMean(weights, normalize) -end diff --git a/src/losses.jl b/src/losses.jl index 8f28245..b5350c7 100644 --- a/src/losses.jl +++ b/src/losses.jl @@ -1,5 +1,8 @@ -# broadcasting behavior -Broadcast.broadcastable(loss::SupervisedLoss) = Ref(loss) +# type alias to make code more readable +Scalar = Union{Number,CategoricalValue} + +# convenient functor interface +(loss::SupervisedLoss)(output::Scalar, target::Scalar) = value(loss, output, target) # fallback to unary evaluation value(loss::DistanceLoss, output::Number, target::Number) = value(loss, output - target) @@ -10,9 +13,13 @@ value(loss::MarginLoss, output::Number, target::Number) = value(loss, target * deriv(loss::MarginLoss, output::Number, target::Number) = target * deriv(loss, target * output) deriv2(loss::MarginLoss, output::Number, target::Number) = deriv2(loss, target * output) +# broadcasting behavior +Broadcast.broadcastable(loss::SupervisedLoss) = Ref(loss) + # ------------------ # AVAILABLE LOSSES # ------------------ + include("losses/distance.jl") include("losses/margin.jl") include("losses/other.jl") @@ -21,101 +28,50 @@ include("losses/other.jl") include("losses/scaled.jl") include("losses/weighted.jl") -# helper macro (for devs) -macro dimcheck(condition) - :(($(esc(condition))) || throw(DimensionMismatch("Dimensions of the parameters don't match: $($(string(condition)))"))) +# ---------------------- +# AGGREGATION BEHAVIOR +# ---------------------- + +""" + sum(loss, outputs, targets) + +Return sum of `loss` values over the iterables `outputs` and `targets`. +""" +function sum(loss::SupervisedLoss, outputs, targets) + sum(loss(ŷ, y) for (ŷ, y) in zip(outputs, targets)) end -# ------------------------------ -# DEFAULT AGGREGATION BEHAVIOR -# ------------------------------ -for FUN in (:value, :deriv, :deriv2) - @eval begin - # by default compute the element-wise result - @inline function ($FUN)( - loss::SupervisedLoss, - outputs::AbstractVector, - targets::AbstractVector) - ($FUN)(loss, outputs, targets, AggMode.None()) - end - - # ------------------- - # AGGREGATION: NONE - # ------------------- - @generated function ($FUN)( - loss::SupervisedLoss, - outputs::AbstractVector, - targets::AbstractVector, - ::AggMode.None) - quote - $(Expr(:meta, :inline)) - ($($FUN)).(loss, outputs, targets) - end - end - - # ------------------ - # AGGREGATION: SUM - # ------------------ - function ($FUN)( - loss::SupervisedLoss, - outputs::AbstractVector, - targets::AbstractVector, - ::AggMode.Sum) - @dimcheck length(outputs) == length(targets) - nobs = length(outputs) - f(i) = ($FUN)(loss, outputs[i], targets[i]) - sum(f, 1:nobs) - end - - # ------------------- - # AGGREGATION: MEAN - # ------------------- - function ($FUN)( - loss::SupervisedLoss, - outputs::AbstractVector, - targets::AbstractVector, - ::AggMode.Mean) - @dimcheck length(outputs) == length(targets) - nobs = length(outputs) - f(i) = ($FUN)(loss, outputs[i], targets[i]) - sum(f, 1:nobs) / nobs - end - - # --------------------------- - # AGGREGATION: WEIGHTED SUM - # --------------------------- - function ($FUN)( - loss::SupervisedLoss, - outputs::AbstractVector, - targets::AbstractVector, - agg::AggMode.WeightedSum) - @dimcheck length(outputs) == length(targets) - @dimcheck length(outputs) == length(agg.weights) - nobs = length(outputs) - wsum = sum(agg.weights) - denom = agg.normalize ? wsum : one(wsum) - f(i) = agg.weights[i] * ($FUN)(loss, outputs[i], targets[i]) - sum(f, 1:nobs) / denom - end - - # ---------------------------- - # AGGREGATION: WEIGHTED MEAN - # ---------------------------- - function ($FUN)( - loss::SupervisedLoss, - outputs::AbstractVector, - targets::AbstractVector, - agg::AggMode.WeightedMean) - @dimcheck length(outputs) == length(targets) - @dimcheck length(outputs) == length(agg.weights) - nobs = length(outputs) - wsum = sum(agg.weights) - denom = agg.normalize ? nobs * wsum : nobs * one(wsum) - f(i) = agg.weights[i] * ($FUN)(loss, outputs[i], targets[i]) - sum(f, 1:nobs) / denom - end - end +""" + sum(loss, outputs, targets, weights; normalize=true) + +Return sum of `loss` values over the iterables `outputs` and `targets`. +The `weights` determine the importance of each observation. The option +`normalize` divides the result by the sum of the weights. +""" +function sum(loss::SupervisedLoss, outputs, targets, weights; normalize=true) + s = sum(w * loss(ŷ, y) for (ŷ, y, w) in zip(outputs, targets, weights)) + n = normalize ? sum(weights) : one(first(weights)) + s / n end -# convenient functor interface -(loss::SupervisedLoss)(outputs::AbstractVector, targets::AbstractVector) = value(loss, outputs, targets) +""" + mean(loss, outputs, targets) + +Return mean of `loss` values over the iterables `outputs` and `targets`. +""" +function mean(loss::SupervisedLoss, outputs, targets) + mean(loss(ŷ, y) for (ŷ, y) in zip(outputs, targets)) +end + +""" + mean(loss, outputs, targets, weights; normalize=true) + +Return mean of `loss` values over the iterables `outputs` and `targets`. +The `weights` determine the importance of each observation. The option +`normalize` divides the result by the sum of the weights. +""" +function mean(loss::SupervisedLoss, outputs, targets, weights; normalize=true) + m = mean(w * loss(ŷ, y) for (ŷ, y, w) in zip(outputs, targets, weights)) + n = normalize ? sum(weights) : one(first(weights)) + m / n +end \ No newline at end of file diff --git a/src/losses/other.jl b/src/losses/other.jl index 6d4cad4..0c7bf73 100644 --- a/src/losses/other.jl +++ b/src/losses/other.jl @@ -12,17 +12,14 @@ struct MisclassLoss{R<:AbstractFloat} <: SupervisedLoss end MisclassLoss() = MisclassLoss{Float64}() -# type alias to make code more readable -NumberOrValue = Union{Number,CategoricalValue} - # return floating point to avoid big integers in aggregations value(::MisclassLoss{R}, agreement::Bool) where R = ifelse(agreement, zero(R), one(R)) deriv(::MisclassLoss{R}, agreement::Bool) where R = zero(R) deriv2(::MisclassLoss{R}, agreement::Bool) where R = zero(R) -value(loss::MisclassLoss, output::NumberOrValue, target::NumberOrValue) = value(loss, target == output) -deriv(loss::MisclassLoss, output::NumberOrValue, target::NumberOrValue) = deriv(loss, target == output) -deriv2(loss::MisclassLoss, output::NumberOrValue, target::NumberOrValue) = deriv2(loss, target == output) +value(loss::MisclassLoss, output::Scalar, target::Scalar) = value(loss, target == output) +deriv(loss::MisclassLoss, output::Scalar, target::Scalar) = deriv(loss, target == output) +deriv2(loss::MisclassLoss, output::Scalar, target::Scalar) = deriv2(loss, target == output) isminimizable(::MisclassLoss) = false isdifferentiable(::MisclassLoss) = false diff --git a/src/losses/scaled.jl b/src/losses/scaled.jl index 6ffd373..10f45c6 100644 --- a/src/losses/scaled.jl +++ b/src/losses/scaled.jl @@ -27,12 +27,12 @@ for FUN in (:value, :deriv, :deriv2) end end -for FUN in [:isminimizable, :isdifferentiable, :istwicedifferentiable, - :isconvex, :isstrictlyconvex, :isstronglyconvex, - :isnemitski, :isunivfishercons, :isfishercons, - :islipschitzcont, :islocallylipschitzcont, - :isclipable, :ismarginbased, :isclasscalibrated, - :isdistancebased, :issymmetric] +for FUN in (:isminimizable, :isdifferentiable, :istwicedifferentiable, + :isconvex, :isstrictlyconvex, :isstronglyconvex, + :isnemitski, :isunivfishercons, :isfishercons, + :islipschitzcont, :islocallylipschitzcont, + :isclipable, :ismarginbased, :isclasscalibrated, + :isdistancebased, :issymmetric) @eval ($FUN)(l::ScaledLoss) = ($FUN)(l.loss) end diff --git a/src/losses/weighted.jl b/src/losses/weighted.jl index 3cb5232..4a78f17 100644 --- a/src/losses/weighted.jl +++ b/src/losses/weighted.jl @@ -47,17 +47,15 @@ isclasscalibrated(l::WeightedMarginLoss{T,W}) where {T,W} = W == 0.5 && isclassc # TODO: Think about this semantic issymmetric(::WeightedMarginLoss) = false -for prop in [:isminimizable, :isdifferentiable, - :istwicedifferentiable, - :isconvex, :isstrictlyconvex, - :isstronglyconvex, :isnemitski, - :isunivfishercons, :isfishercons, - :islipschitzcont, :islocallylipschitzcont, - :isclipable, :ismarginbased, - :isdistancebased] - @eval ($prop)(l::WeightedMarginLoss) = ($prop)(l.loss) -end - -for prop_param in (:isdifferentiable, :istwicedifferentiable) - @eval ($prop_param)(l::WeightedMarginLoss, at) = ($prop_param)(l.loss, at) +for FUN in (:isminimizable, :isdifferentiable, :istwicedifferentiable, + :isconvex, :isstrictlyconvex, :isstronglyconvex, + :isnemitski, :isunivfishercons, :isfishercons, + :islipschitzcont, :islocallylipschitzcont, + :isclipable, :ismarginbased, + :isdistancebased) + @eval ($FUN)(l::WeightedMarginLoss) = ($FUN)(l.loss) +end + +for FUN in (:isdifferentiable, :istwicedifferentiable) + @eval ($FUN)(l::WeightedMarginLoss, at) = ($FUN)(l.loss, at) end diff --git a/test/aggmode.jl b/test/agg.jl similarity index 64% rename from test/aggmode.jl rename to test/agg.jl index d0e5325..9e54838 100644 --- a/test/aggmode.jl +++ b/test/agg.jl @@ -1,33 +1,27 @@ function test_vector_value(l, o, t) ref = [value(l, o[i], t[i]) for i in 1:length(o)] - @test @inferred(value(l, o, t, AggMode.None())) == ref - @test @inferred(value(l, o, t)) == ref - @test value.(l, o, t) == ref - @test @inferred(l(o, t)) == ref + v(l, o, t) = value.(l, o, t) + @test @inferred(v(l, o, t)) == ref n = length(ref) s = sum(ref) - @test @inferred(value(l, o, t, AggMode.Sum())) ≈ s - @test @inferred(value(l, o, t, AggMode.Mean())) ≈ s / n - ## Weighted Sum - @test @inferred(value(l, o, t, AggMode.WeightedSum(ones(n)))) ≈ s - @test @inferred(value(l, o, t, AggMode.WeightedSum(ones(n),normalize=true))) ≈ s / n - ## Weighted Mean - @test @inferred(value(l, o, t, AggMode.WeightedMean(ones(n)))) ≈ (s / n) / n - @test @inferred(value(l, o, t, AggMode.WeightedMean(ones(n),normalize=false))) ≈ s / n + @test @inferred(sum(l, o, t)) ≈ s + @test @inferred(mean(l, o, t)) ≈ s / n + @test @inferred(sum(l, o, t, ones(n), normalize=false)) ≈ s + @test @inferred(sum(l, o, t, ones(n), normalize=true)) ≈ s / n + @test @inferred(mean(l, o, t, ones(n), normalize=false)) ≈ s / n + @test @inferred(mean(l, o, t, ones(n), normalize=true)) ≈ (s / n) / n end function test_vector_deriv(l, o, t) ref = [deriv(l, o[i], t[i]) for i in 1:length(o)] - @test @inferred(deriv(l, o, t, AggMode.None())) == ref - @test @inferred(deriv(l, o, t)) == ref - @test deriv.(Ref(l), o, t) == ref + d(l, o, t) = deriv.(l, o, t) + @test @inferred(d(l, o, t)) == ref end function test_vector_deriv2(l, o, t) ref = [deriv2(l, o[i], t[i]) for i in 1:length(o)] - @test @inferred(deriv2(l, o, t, AggMode.None())) == ref - @test @inferred(deriv2(l, o, t)) == ref - @test deriv2.(Ref(l), o, t) == ref + d(l, o, t) = deriv2.(l, o, t) + @test @inferred(d(l, o, t)) == ref end @testset "Vectorized API" begin @@ -63,4 +57,14 @@ end end end end +end + +@testset "Aggregation with categorical values" begin + c = categorical(["Foo","Bar","Baz","Foo"]) + l = MisclassLoss() + @test sum(l, c, reverse(c)) == 2.0 + @test mean(l, c, reverse(c)) == 0.5 + @test sum(l, c, reverse(c), 2*ones(4), normalize=false) == 4.0 + @test mean(l, c, reverse(c), 2*ones(4), normalize=false) == 1.0 + @test mean(l, c, reverse(c), 2*ones(4), normalize=true) == 0.125 end \ No newline at end of file diff --git a/test/core.jl b/test/core.jl index fdf1e51..3dd83bf 100644 --- a/test/core.jl +++ b/test/core.jl @@ -427,14 +427,9 @@ end l = MisclassLoss() @test value(l, c[1], c[1]) == 0.0 @test value(l, c[1], c[2]) == 1.0 - @test value(l, c, reverse(c)) == [0.0, 1.0, 1.0, 0.0] - @test value(l, c, reverse(c), AggMode.Sum()) == 2.0 - @test value(l, c, reverse(c), AggMode.Mean()) == 0.5 - @test value(l, c, reverse(c), AggMode.WeightedSum(2*ones(4))) == 4.0 - @test value(l, c, reverse(c), AggMode.WeightedMean(2*ones(4),false)) == 1.0 - @test value(l, c, reverse(c), AggMode.WeightedMean(2*ones(4),true)) == 0.125 + @test value.(l, c, reverse(c)) == [0.0, 1.0, 1.0, 0.0] l = MisclassLoss{Float32}() @test value(l, c[1], c[1]) isa Float32 - @test value(l, c, c) isa Vector{Float32} + @test value.(l, c, c) isa Vector{Float32} end diff --git a/test/runtests.jl b/test/runtests.jl index e9f4fd6..3e07aa6 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,7 +10,7 @@ using Test tests = [ "core.jl", "props.jl", - "aggmode.jl" + "agg.jl" ] # for deterministic testing