Skip to content

Commit

Permalink
Add proper cross-validation groups
Browse files: browse the repository at this point in the history
  • Loading branch information
odunbar committed Oct 8, 2024
1 parent c340e48 commit f15d008
Show file tree
Hide file tree
Showing 3 changed files with 219 additions and 165 deletions.
70 changes: 38 additions & 32 deletions src/RandomFeature.jl
Original file line number Diff line number Diff line change
Expand Up @@ -431,8 +431,8 @@ function calculate_mean_cov_and_coeffs(
l::ForVM,
regularization::MorUSorD,
n_features::Int,
n_train::Int,
n_test::Int,
train_idx::VV,
test_idx::VV,
batch_sizes::Union{Dict{S, Int}, Nothing},
io_pairs::PairedDataContainer,
decomp_type::S,
Expand All @@ -444,6 +444,7 @@ function calculate_mean_cov_and_coeffs(
RFI <: RandomFeatureInterface,
RNG <: AbstractRNG,
ForVM <: Union{AbstractFloat, AbstractVecOrMat},
VV <: AbstractVector,
S <: AbstractString,
MorUSorD <: Union{Matrix, UniformScaling, Diagonal},
M <: AbstractMatrix{<:AbstractFloat},
Expand All @@ -452,16 +453,15 @@ function calculate_mean_cov_and_coeffs(
}

# split data into train/test
itrain = get_inputs(io_pairs)
otrain = get_outputs(io_pairs)
itrain = get_inputs(io_pairs)[:, train_idx]
otrain = get_outputs(io_pairs)[:, train_idx]
io_train_cost = PairedDataContainer(itrain, otrain)
# itest = get_inputs(io_pairs)[:, (n_train + 1):end]
# otest = get_outputs(io_pairs)[:, (n_train + 1):end]
itest = get_inputs(io_pairs)
otest = get_outputs(io_pairs)
itest = get_inputs(io_pairs)[:, test_idx]
otest = get_outputs(io_pairs)[:, test_idx]
input_dim = size(itrain, 1)
output_dim = size(otrain, 1)
n_test = size(itest, 2)

# build and fit the RF
rfm = RFM_from_hyperparameters(
rfi,
Expand Down Expand Up @@ -605,8 +605,8 @@ function estimate_mean_and_coeffnorm_covariance(
l::ForVM,
regularization::MorUSorD,
n_features::Int,
n_train::Int,
n_test::Int,
train_idx::VV,
test_idx::VV,
batch_sizes::Union{Dict{S, Int}, Nothing},
io_pairs::PairedDataContainer,
n_samples::Int,
Expand All @@ -618,20 +618,22 @@ function estimate_mean_and_coeffnorm_covariance(
RFI <: RandomFeatureInterface,
RNG <: AbstractRNG,
ForVM <: Union{AbstractFloat, AbstractVecOrMat},
VV <: AbstractVector,
S <: AbstractString,
MorUSorD <: Union{Matrix, UniformScaling, Diagonal},
}

output_dim = size(get_outputs(io_pairs), 1)

n_test = length(test_idx)

means = zeros(output_dim, n_samples, n_test)
mean_of_covs = zeros(output_dim, output_dim, n_test)
moc_tmp = similar(mean_of_covs)
mtmp = zeros(output_dim, n_test)
buffer = zeros(n_test, output_dim, n_features)
complexity = zeros(1, n_samples)
coeffl2norm = zeros(1, n_samples)
println("estimate cov with " * string(n_samples * repeats) * " iterations...")
println("estimate cov with " * string(n_samples) * " iterations...")

for i in ProgressBar(1:n_samples)
for j in 1:repeats
Expand All @@ -641,8 +643,8 @@ function estimate_mean_and_coeffnorm_covariance(
l,
regularization,
n_features,
n_train,
n_test,
train_idx,
test_idx,
batch_sizes,
io_pairs,
decomp_type,
Expand Down Expand Up @@ -708,8 +710,8 @@ function calculate_ensemble_mean_and_coeffnorm(
lvecormat::VorM,
regularization::MorUSorD,
n_features::Int,
n_train::Int,
n_test::Int,
train_idx::VV,
test_idx::VV,
batch_sizes::Union{Dict{S, Int}, Nothing},
io_pairs::PairedDataContainer,
decomp_type::S,
Expand All @@ -719,6 +721,7 @@ function calculate_ensemble_mean_and_coeffnorm(
RFI <: RandomFeatureInterface,
RNG <: AbstractRNG,
VorM <: AbstractVecOrMat,
VV <: AbstractVector,
S <: AbstractString,
MorUSorD <: Union{Matrix, UniformScaling, Diagonal},
}
Expand All @@ -729,7 +732,8 @@ function calculate_ensemble_mean_and_coeffnorm(
end
N_ens = size(lmat, 2)
output_dim = size(get_outputs(io_pairs), 1)

n_test = length(test_idx)

means = zeros(output_dim, N_ens, n_test)
mean_of_covs = zeros(output_dim, output_dim, n_test)
buffer = zeros(n_test, output_dim, n_features)
Expand All @@ -738,7 +742,7 @@ function calculate_ensemble_mean_and_coeffnorm(
moc_tmp = similar(mean_of_covs)
mtmp = zeros(output_dim, n_test)

println("calculating " * string(N_ens * repeats) * " ensemble members...")
println("calculating " * string(N_ens) * " ensemble members...")

for i in ProgressBar(1:N_ens)
for j in collect(1:repeats)
Expand All @@ -750,8 +754,8 @@ function calculate_ensemble_mean_and_coeffnorm(
l,
regularization,
n_features,
n_train,
n_test,
train_idx,
test_idx,
batch_sizes,
io_pairs,
decomp_type,
Expand Down Expand Up @@ -795,8 +799,8 @@ function estimate_mean_and_coeffnorm_covariance(
l::ForVM,
regularization::MorUSorD,
n_features::Int,
n_train::Int,
n_test::Int,
train_idx::VV,
test_idx::VV,
batch_sizes::Union{Dict{S, Int}, Nothing},
io_pairs::PairedDataContainer,
n_samples::Int,
Expand All @@ -808,13 +812,14 @@ function estimate_mean_and_coeffnorm_covariance(
RFI <: RandomFeatureInterface,
RNG <: AbstractRNG,
ForVM <: Union{AbstractFloat, AbstractVecOrMat},
VV <: AbstractVector,
S <: AbstractString,
MorUSorD <: Union{Matrix, UniformScaling, Diagonal},
}

output_dim = size(get_outputs(io_pairs), 1)

println("estimate cov with " * string(n_samples * repeats) * " iterations...")
n_test = length(test_idx)
println("estimate cov with " * string(n_samples) * " iterations...")

nthreads = Threads.nthreads()
rng_seed = randperm(rng, 10^5)[1] # dumb way to get a random integer in 1:10^5
Expand Down Expand Up @@ -843,8 +848,8 @@ function estimate_mean_and_coeffnorm_covariance(
l,
regularization,
n_features,
n_train,
n_test,
train_idx,
test_idx,
batch_sizes,
io_pairs,
decomp_type,
Expand Down Expand Up @@ -917,8 +922,8 @@ function calculate_ensemble_mean_and_coeffnorm(
lvecormat::VorM,
regularization::MorUSorD,
n_features::Int,
n_train::Int,
n_test::Int,
train_idx::VV,
test_idx::VV,
batch_sizes::Union{Dict{S, Int}, Nothing},
io_pairs::PairedDataContainer,
decomp_type::S,
Expand All @@ -928,6 +933,7 @@ function calculate_ensemble_mean_and_coeffnorm(
RFI <: RandomFeatureInterface,
RNG <: AbstractRNG,
VorM <: AbstractVecOrMat,
VV <: AbstractVector,
S <: AbstractString,
MorUSorD <: Union{Matrix, UniformScaling, Diagonal},
}
Expand All @@ -938,10 +944,10 @@ function calculate_ensemble_mean_and_coeffnorm(
end
N_ens = size(lmat, 2)
output_dim = size(get_outputs(io_pairs), 1)
n_test = length(test_idx)



println("calculating " * string(N_ens * repeats) * " ensemble members...")
println("calculating " * string(N_ens) * " ensemble members...")

nthreads = Threads.nthreads()
c_list = [zeros(1, N_ens) for i in 1:nthreads]
Expand All @@ -968,8 +974,8 @@ function calculate_ensemble_mean_and_coeffnorm(
l,
regularization,
n_features,
n_train,
n_test,
train_idx,
test_idx,
batch_sizes,
io_pairs,
decomp_type,
Expand Down
Loading

0 comments on commit f15d008

Please sign in to comment.