From d0002b326ac30167f0494a91f07c99bb5317347e Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 23 Apr 2024 16:09:33 +1200 Subject: [PATCH] make test of iterator(...) more robust --- .../one_dimensional_range_methods.jl | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/test/hyperparam/one_dimensional_range_methods.jl b/test/hyperparam/one_dimensional_range_methods.jl index 386cadf4..2ffd8101 100644 --- a/test/hyperparam/one_dimensional_range_methods.jl +++ b/test/hyperparam/one_dimensional_range_methods.jl @@ -233,31 +233,31 @@ end @testset "NominalSampler" begin r = range(Char, :(model.dummy), values=collect("cab")) + N = 10000 - @testset "probability vector specified" begin - s = MLJBase.sampler(r, [0.1, 0.2, 0.7]) - rng = StableRNG(600) - dict = Dist.countmap(rand(rng,s, 1000)) - c, a, b = map(x -> dict[x], collect("cab")) - @test a == 201 && b == 714 && c == 85 + # to compute half-width of 95% confidence intervals, for counts of a Bernoulli process + # with probability `p`, sampled `N` times: + halfwidth(p, N) = 1.96*sqrt(p*(1 - p))*sqrt(N) - rng = StableRNG(89); - dict = Dist.countmap(rand(rng, s, 1000)) - c, a, b = map(x -> dict[x], collect("cab")) - @test a == 173 && b == 733 && c == 94 + @testset "probability vector specified" begin + p = Dict('c'=>0.1, 'a'=>0.2, 'b'=>0.7) + rng = StableRNG(660) + s = MLJBase.sampler(r, [p[class] for class in "cab"]) + counts = Dist.countmap(rand(rng,s, N)) + for class in "abc" + μ = p[class]*N + @test abs(counts[class] - μ) < halfwidth(p[class], N) + end end @testset "probability vector unspecified (uniform)" begin s = MLJBase.sampler(r) - rng = StableRNG(55) - dict = Dist.countmap(rand(rng,s, 1000)) - c, a, b = map(x -> dict[x], collect("cab")) - @test a == 361 && b == 335 && c == 304 - - rng = StableRNG(550) - dict = Dist.countmap(rand(rng, s, 1000)) - c, a, b = map(x -> dict[x], collect("cab")) - @test a == 332 && b == 356 && c == 312 + rng = StableRNG(660) + counts = Dist.countmap(rand(rng,s, N)) + for class in "abc" + μ = N/3 + @test abs(counts[class] - μ) < halfwidth(1/3, N) + end end end