Skip to content

Commit

Permalink
Merge branch 'as/allow_arbitrary_cols' of https://github.com/xKDR/Sur…
Browse files Browse the repository at this point in the history
…vey.jl into as/allow_arbitrary_cols
  • Loading branch information
smishr committed Mar 28, 2023
2 parents 2162a59 + a7bcb94 commit 362c67d
Showing 1 changed file with 26 additions and 16 deletions.
42 changes: 26 additions & 16 deletions src/bootstrap.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,27 +25,18 @@ replicates: 1000
function bootweights(design::SurveyDesign; replicates = 4000, rng = MersenneTwister(1234))
stratified = groupby(design.data, design.strata)
H = length(keys(stratified))
substrata_dfs = DataFrame[]
substrata_dfs = Vector{DataFrame}(undef, H)
for h = 1:H
substrata = DataFrame(stratified[h])
cluster_sorted = sort(substrata, design.cluster)
psus = unique(cluster_sorted[!, design.cluster])
npsus = [(count(==(i), cluster_sorted[!, design.cluster])) for i in psus]
nh = length(psus)
cluster_sorted_designcluster = cluster_sorted[!, design.cluster]
cluster_weights = cluster_sorted[!, design.weights]
for replicate = 1:replicates
randinds = rand(rng, 1:(nh), (nh - 1))
cluster_sorted[!, "replicate_"*string(replicate)] =
vcat(
[
fill((count(==(i), randinds)) * (nh / (nh - 1)), npsus[i]) for
i = 1:nh
]...,
) .* cluster_weights
end
push!(substrata_dfs, cluster_sorted)
# Perform the inner loop in a type-stable function to improve runtime.
_bootweights_cluster_sorted!(cluster_sorted, cluster_weights,
cluster_sorted_designcluster, replicates, rng)
substrata_dfs[h] = cluster_sorted
end
df = vcat(substrata_dfs...)
df = reduce(vcat, substrata_dfs)
return ReplicateDesign(
df,
design.cluster,
Expand All @@ -60,3 +51,22 @@ function bootweights(design::SurveyDesign; replicates = 4000, rng = MersenneTwis
[Symbol("replicate_"*string(replicate)) for replicate in 1:replicates]
)
end

function _bootweights_cluster_sorted!(cluster_sorted,
cluster_weights, cluster_sorted_designcluster, replicates, rng)

psus = unique(cluster_sorted_designcluster)
npsus = [count(==(i), cluster_sorted_designcluster) for i in psus]
nh = length(psus)
for replicate = 1:replicates
randinds = rand(rng, 1:(nh), (nh - 1))
cluster_sorted[!, "replicate_"*string(replicate)] =
reduce(vcat,
[
fill((count(==(i), randinds)) * (nh / (nh - 1)), npsus[i]) for
i = 1:nh
]
) .* cluster_weights
end
cluster_sorted
end

0 comments on commit 362c67d

Please sign in to comment.