Skip to content

Commit

Permalink
Merge branch 'serenejiang-dev' into main
Browse files Browse the repository at this point in the history
  • Loading branch information
shihuang047 committed Apr 16, 2021
2 parents 00adc94 + 4725e2f commit 62bb31e
Show file tree
Hide file tree
Showing 4 changed files with 88 additions and 0 deletions.
Binary file added .DS_Store
Binary file not shown.
28 changes: 28 additions & 0 deletions code_method/code_method/cv_method.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ library(ranger) # faster random forest
###########################################
#### Lasso (tune lambda) ##################
###########################################
<<<<<<< HEAD
#' @title lasso_cv
#' @description Does a K-fold cross-validation for Lasso.
#' @param datx The input data matrix.
Expand All @@ -27,6 +28,8 @@ library(ranger) # faster random forest
#' @param lambda.grid The tuning range for the regularization parameter "lambda" of "lasso".
#' @return A list of Lasso model output including MSE, and selected features.

=======
>>>>>>> 36c91ddc2d7a7742a933ce6ac5bf93b876439dc5
lasso_cv = function(datx, y, seednum=31, family=family, ratio.training=0.8, fold.cv=10,
lambda.grid, lambda.choice='lambda.1se'){
# seednum: the seed number
Expand Down Expand Up @@ -70,6 +73,7 @@ lasso_cv = function(datx, y, seednum=31, family=family, ratio.training=0.8, fold
# reference 2 (extract final model and prediction with caret)
## https://topepo.github.io/caret/model-training-and-tuning.html

<<<<<<< HEAD
#' @title elnet_cv
#' @description Does a K-fold cross-validation for elastic-net.
#' @param datx The input data matrix.
Expand All @@ -81,6 +85,8 @@ lasso_cv = function(datx, y, seednum=31, family=family, ratio.training=0.8, fold
#' @param lambda.grid The tuning range for the regularization parameter "lambda" of the elastic net.
#' @param alpha.grid The tuning range for the elastic-net mixing parameter "alpha".
#' @return A list of elnet model output including MSE, and selected features.
=======
>>>>>>> 36c91ddc2d7a7742a933ce6ac5bf93b876439dc5

elnet_cv = function(datx, y, seednum=31, alpha.grid, lambda.grid, family=family,
ratio.training=0.8, fold.cv=10){
Expand Down Expand Up @@ -137,6 +143,7 @@ elnet_cv = function(datx, y, seednum=31, alpha.grid, lambda.grid, family=family,
## https://uc-r.github.io/random_forests
# OOB error is different from test error (see above website)

<<<<<<< HEAD
#' @title randomForest_cv
#' @description Does a K-fold cross-validation for random forests.
#' @param datx The input data matrix.
Expand All @@ -150,6 +157,8 @@ elnet_cv = function(datx, y, seednum=31, alpha.grid, lambda.grid, family=family,
#' @param pval_thr The threshold for the estimated p value of RF importance scores.
#' @param method.perm The permutation method for estimating the p value of RF importance scores.
#' @return A list of RF model output including the "mtry" grid, selected features, MSE, OOB, and p-values of selected features.
=======
>>>>>>> 36c91ddc2d7a7742a933ce6ac5bf93b876439dc5

randomForest_cv = function(datx, y, seednum=31, fold.cv=5, ratio.training=0.8, mtry.grid=10, num_trees=500,
pval_thr=0.05, method.perm='altmann'){
Expand All @@ -171,7 +180,11 @@ randomForest_cv = function(datx, y, seednum=31, fold.cv=5, ratio.training=0.8, m
# tune parameter with cross validation
hyper.grid <- expand.grid(mtry = mtry.grid, OOB_RMSE = 0)
for (i in 1:nrow(hyper.grid)){
<<<<<<< HEAD
model = ranger::ranger(y ~., data = train,
=======
model = ranger(y ~., data = train,
>>>>>>> 36c91ddc2d7a7742a933ce6ac5bf93b876439dc5
num.trees=500, mtry=hyper.grid$mtry[i],
seed=seednum, importance = 'permutation')
hyper.grid$OOB_RMSE[i] = sqrt(model$prediction.error)
Expand All @@ -181,12 +194,20 @@ randomForest_cv = function(datx, y, seednum=31, fold.cv=5, ratio.training=0.8, m

# permutation test on the tuned random forest model to obtain the chosen features
if (method.perm == 'altmann'){ # for all data types
<<<<<<< HEAD
rf.model <- ranger::ranger(y ~., data=test, num.trees = num_trees,
=======
rf.model <- ranger(y ~., data=test, num.trees = num_trees,
>>>>>>> 36c91ddc2d7a7742a933ce6ac5bf93b876439dc5
mtry = hyper.grid$mtry[position], importance = 'permutation')
table = as.data.frame(importance_pvalues(rf.model, method = "altmann",
formula = y ~ ., data = test))
} else if (method.perm == 'janitza'){ # for high dimensional data only
<<<<<<< HEAD
rf.model <- ranger::ranger(y ~., data=test, num.trees = num_trees,
=======
rf.model <- ranger(y ~., data=test, num.trees = num_trees,
>>>>>>> 36c91ddc2d7a7742a933ce6ac5bf93b876439dc5
mtry = hyper.grid$mtry[position], importance = 'impurity_corrected')
table = as.data.frame(importance_pvalues(rf.model, method = "janitza",
formula = y ~ ., data = test))
Expand Down Expand Up @@ -215,6 +236,7 @@ randomForest_cv = function(datx, y, seednum=31, fold.cv=5, ratio.training=0.8, m
dyn.load("../../code_Lin/cvs/cdmm.dll")
source("../../code_Lin/cvs/cdmm.R")

<<<<<<< HEAD
#' @title cons_lasso_cv
#' @description Does a cross-validation for the constrained Lasso (compLasso).
#' @param datx The input data matrix.
Expand All @@ -223,6 +245,8 @@ source("../../code_Lin/cvs/cdmm.R")
#' @param ratio.training The ratio of the whole data assigned for model training (default=0.8).
#' @return A list of compLasso model output including MSE, and selected features.

=======
>>>>>>> 36c91ddc2d7a7742a933ce6ac5bf93b876439dc5
cons_lasso_cv = function(datx, y, seednum, ratio.training=0.8){
set.seed(seednum)
z = datx
Expand All @@ -248,7 +272,11 @@ lasso_double_cv = function(datx, y, seednum=31, family=family, fold.cv=10, lambd
set.seed(seednum)

## double loop for cross-validation
<<<<<<< HEAD
flds <- caret::createFolds(y, k = fold.cv, list = TRUE, returnTrain = FALSE)
=======
flds <- createFolds(y, k = fold.cv, list = TRUE, returnTrain = FALSE)
>>>>>>> 36c91ddc2d7a7742a933ce6ac5bf93b876439dc5

## outer loop for estimating MSE
MSE = STAB = matrix(rep(0, fold.cv * length(lambda.grid)), nrow=fold.cv)
Expand Down
11 changes: 11 additions & 0 deletions code_method/code_method/cv_sim_apply.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
### apply different feature selection methods to simulated data #######################
#######################################################################################

<<<<<<< HEAD
#' library(FSA) # for se()
#' source('cv_method.R')
#' source('getStability.R')
Expand All @@ -20,6 +21,11 @@
#' @param pval_thr The threshold for the estimated p value of RF importance scores.
#' @param method.perm The permutation method for estimating the p value of RF importance scores.
#' @details
=======
# library(FSA) # for se()
# source('cv_method.R')
# source('getStability.R')
>>>>>>> 36c91ddc2d7a7742a933ce6ac5bf93b876439dc5

sim_evaluate_cv = function(sim_file, method, seednum=31, ratio.training=0.8, fold.cv=10, family='gaussian',
lambda.grid=exp(seq(-4, -2, 0.2)), alpha.grid=seq(0.1, 0.9, 0.1),
Expand Down Expand Up @@ -77,8 +83,13 @@ sim_evaluate_cv = function(sim_file, method, seednum=31, ratio.training=0.8, fol
}

# store results
<<<<<<< HEAD
FP = paste(mean(fp), '(', round(se(fp),2), ')') # FPR?
FN = paste(mean(fn), '(', round(se(fn),2), ')') # FNR?
=======
FP = paste(mean(fp), '(', round(se(fp),2), ')')
FN = paste(mean(fn), '(', round(se(fn),2), ')')
>>>>>>> 36c91ddc2d7a7742a933ce6ac5bf93b876439dc5
MSE = paste(round(mean(mse, na.rm=T),2), '(', round(se(mse, na.rm=T),2), ')')
Stab = round(getStability(stability.table)$stability, 2)

Expand Down
49 changes: 49 additions & 0 deletions code_method/code_method/getStability.R
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
<<<<<<< HEAD
#' ## source code: https://github.com/nogueirs/JMLR2018/blob/master/R/getStability.R
#' @title getStability
#' @description Nogueira's stability measure from a binary matrix representing
Expand Down Expand Up @@ -52,6 +53,11 @@
#' getStability(Z_alt)$stability # -0.00010001, very close to 0
#' @export
getStability <- function(X, alpha=0.05) {
=======
# source code: https://github.com/nogueirs/JMLR2018/blob/master/R/getStability.R

getStability <- function(X,alpha=0.05) {
>>>>>>> 36c91ddc2d7a7742a933ce6ac5bf93b876439dc5
## the input X is a binary matrix of size M*d where:
## M is the number of bootstrap replicates
## d is the total number of features
Expand Down Expand Up @@ -85,6 +91,49 @@ return(list("stability"=stability,"variance"=var_stab,"lower"=lower,"upper"=uppe
}


<<<<<<< HEAD
=======
# ################################
# ## extreme cases example #######
# ################################
# d = 2 # number of features
# M = 10 # number of bootstrap replicates

# ## case 1: when the stability index is undefined -- Z is all zeros or all ones (Nogueira2018 p.13)
# Z_all_missed = matrix(rep(0, M*d), nrow=M) # since K_bar = 0, thus SI undefined
# getStability(Z_all_missed)$stability

# Z_all_selected = matrix(rep(1, M*d), nrow=M) # since K_bar = d = 2, thus SI undefined
# getStability(Z_all_selected)$stability

# ## case 2: when the stability index reaches its maximum of 1 -- each column of Z is either all ones or all zeros (but Z is not only zeros or only ones) (Nogueira2018 p.13)
# # this was the case when we got the wrong SI almost 1 for soil datasets: since sampled dataset the same across all
# d = 10
# for (i in 1: (d-1)){
# d_ones = sample(seq(1, d, 1), i)
# Z_tmp = matrix(rep(0, M*d), nrow=M)
# Z_tmp[, d_ones] = 1 # since Sf = 0 for all features, thus SI = 1
# SI = getStability(Z_tmp)$stability
# print(i)
# print(paste('SI:', SI, sep=''))
# }

# ## case 3: when the stability index is near its minimum (appendix D: the minimum is -1/(M-1), which approaches 0 asymptotically as M goes to infinity)
# ## this happens when each feature column contains equal numbers of 0s and 1s
# d_alt = rep(c(0, 1), M/2)
# Z_alt = matrix(rep(d_alt, d), ncol=d)
# getStability(Z_alt)$stability # -0.1111111, which is - 1/(M-1)

# d_alt_2 = c(rep(0, M/2), rep(1, M/2))
# Z_alt_2 = matrix(rep(d_alt_2, d), ncol=d)
# getStability(Z_alt_2)$stability # -0.1111111, which is - 1/(M-1)

# # when M goes to infinity
# M = 10000
# d_alt = rep(c(0, 1), M/2)
# Z_alt = matrix(rep(d_alt, d), ncol=d)
# getStability(Z_alt)$stability # -0.00010001, very close to 0
>>>>>>> 36c91ddc2d7a7742a933ce6ac5bf93b876439dc5



Expand Down

0 comments on commit 62bb31e

Please sign in to comment.