Skip to content

Commit

Permalink
allow VS() for multiple subgroups
Browse files Browse the repository at this point in the history
  • Loading branch information
zhizuio committed Sep 30, 2024
1 parent a493265 commit e6951dd
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 65 deletions.
4 changes: 2 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
Package: BayesSurvive
Title: Bayesian Survival Models for High-Dimensional Data
Version: 0.0.5
Date: 2024-09-27
Date: 2024-09-30
Authors@R: c(person("Zhi", "Zhao", role=c("aut","cre"), email = "[email protected]"),
person("Waldir", "Leoncio", role=c("aut")),
person("Katrin", "Madjar", role=c("aut")),
Expand All @@ -22,7 +22,7 @@ Imports: Rcpp, ggplot2, GGally, mvtnorm, survival, riskRegression,
Suggests: knitr, testthat
LazyData: true
NeedsCompilation: yes
Packaged: 2024-09-27 19:57:03 UTC; zhiz
Packaged: 2024-09-30 12:12:15 UTC; zhiz
Author: Zhi Zhao [aut, cre],
Waldir Leoncio [aut],
Katrin Madjar [aut],
Expand Down
143 changes: 85 additions & 58 deletions R/VS.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@
#' or a list consisting of matrices and arrays
#' @param method variable selection method to choose from
#' \code{c("CI", "SNC", "MPM", "FDR")}. Default is "FDR"
#' @param threshold SNC threshold value (default 0.5) or the Bayesian expected false
#' discovery rate threshold (default 0.05)
#' @param subgroup index of the subgroup for visualizing posterior coefficients
#' @param threshold SNC threshold value (default 0.5) or the Bayesian expected
#' false discovery rate threshold (default 0.05)
#' @param subgroup index(es) of subgroup(s) for visualizing variable selection
#'
#' @return A boolean vector of selected (= TRUE) and rejected (= FALSE)
#' variables
#' variables for one group or a list for multiple groups
#'
#' @references
#' Lee KH, Chakraborty S, Sun J (2015). Survival prediction and variable
Expand Down Expand Up @@ -59,7 +59,7 @@
#' # run Bayesian Cox with graph-structured priors
#' fit <- BayesSurvive(
#' survObj = dataset, hyperpar = hyperparPooled,
#' initial = initial, nIter = 100
#' initial = initial, nIter = 100, burnin = 50
#' )
#' # show variable selection
#' VS(fit, method = "FDR")
Expand All @@ -74,52 +74,48 @@ VS <- function(x, method = "FDR", threshold = NA, subgroup = 1) {
if (any(!sapply(x, function(xx) is.array(xx)))) {
stop("The list input has to consist of matrices and/or arrays!")
}
# If the input is a matrix or array, reformat it to be an list
if (!is.list(x)) {
x <- list(x)
}
}
# If the input is a matrix or array, reformat it to be an list
if (is.array(x)) {
x <- list(x)
}
if (!method %in% c("CI", "SNC", "MPM", "FDR")) { # "SNC-BIC",
stop("'method' should be one of c('CI', 'SNC', 'MPM', 'FDR')!")
}
if (inherits(x, "BayesSurvive")) {
if (x$input$S < max(subgroup)) {
stop("Argument 'subgroup' has subscript out of bounds!")
}
}

if (method == "CI") {
# for an input from "BayesSurvive"
if (inherits(x, "BayesSurvive")) {
betas <- coef.BayesSurvive(x,
type = "mean", CI = 95,
subgroup = subgroup
)
ret <- (betas$CI.lower > 0) | (betas$CI.upper < 0)
ret <- rep(list(NULL), length(subgroup))

for (l in seq_len(length(subgroup))) {
betas <- coef.BayesSurvive(x,
type = "mean", CI = 95,
subgroup = subgroup[l]
)
ret[[l]] <- (betas$CI.lower > 0) | (betas$CI.upper < 0)
}
} else {
# for an output from a matrix or array or list
if (!is.list(x)) {
betas <- list()
# if (is.matrix(x)) {
# betas$CI.lower <- apply(x, 2, quantile, 0.025)
# betas$CI.upper <- apply(x, 2, quantile, 0.975)
betas <- rep(list(NULL), length(x))
ret <- rep(list(NULL), length(x))
for (l in seq_len(length(x))) {
# if (length(dim(x[[l]])) == 2) {
# betas[[l]]$CI.lower <- apply(x[[l]], 2, quantile, 0.025)
# betas[[l]]$CI.upper <- apply(x[[l]], 2, quantile, 0.975)
# }
# if (is.array(x)) {
# betas$CI.lower <- apply(x, c(2,3), quantile, 0.025)
# betas$CI.upper <- apply(x, c(2,3), quantile, 0.975)
# if (length(dim(x[[l]])) == 3) {
# betas[[l]]$CI.lower <- apply(x[[l]], c(2,3), quantile, 0.025)
# betas[[l]]$CI.upper <- apply(x[[l]], c(2,3), quantile, 0.975)
# }
# ret <- (betas$CI.lower > 0) | (betas$CI.upper < 0)
} else {
betas <- rep(list(NULL), length(x))
ret <- rep(list(NULL), length(x))
for (l in seq_len(length(x))) {
# if (length(dim(x[[l]])) == 2) {
# betas[[l]]$CI.lower <- apply(x[[l]], 2, quantile, 0.025)
# betas[[l]]$CI.upper <- apply(x[[l]], 2, quantile, 0.975)
# }
# if (length(dim(x[[l]])) == 3) {
# betas[[l]]$CI.lower <- apply(x[[l]], c(2,3), quantile, 0.025)
# betas[[l]]$CI.upper <- apply(x[[l]], c(2,3), quantile, 0.975)
# }
betas[[l]]$CI.lower <- apply(x[[l]], seq_len(length(dim(x[[l]])))[-1], quantile, 0.025)
betas[[l]]$CI.upper <- apply(x[[l]], seq_len(length(dim(x[[l]])))[-1], quantile, 0.975)
ret[[l]] <- (betas[[l]]$CI.lower > 0) | (betas[[l]]$CI.upper < 0)
}
betas[[l]]$CI.lower <- apply(x[[l]], seq_len(length(dim(x[[l]])))[-1], quantile, 0.025)
betas[[l]]$CI.upper <- apply(x[[l]], seq_len(length(dim(x[[l]])))[-1], quantile, 0.975)
ret[[l]] <- (betas[[l]]$CI.lower > 0) | (betas[[l]]$CI.upper < 0)
}
}
}
Expand All @@ -130,16 +126,20 @@ VS <- function(x, method = "FDR", threshold = NA, subgroup = 1) {
}

if (inherits(x, "BayesSurvive")) {
if (x$input$S > 1 || !x$input$MRF.G) {
x$output$beta.p <- x$output$beta.p[[subgroup]]
}
beta_p <- x$output$beta.p[-(1:(x$input$burnin / x$input$thin + 1)), ]

ret <- rep(FALSE, NCOL(beta_p))

for (j in seq_len(NCOL(beta_p))) {
if (sum(abs(beta_p[, j]) > sd(beta_p[, j])) / NROW(beta_p) > threshold) {
ret[j] <- TRUE
ret <- rep(list(NULL), length(subgroup))

for (l in seq_len(length(subgroup))) {
if (x$input$S > 1 || !x$input$MRF.G) {
x$output$beta.p <- x$output$beta.p[[subgroup[l]]]
}
beta_p <- x$output$beta.p[-(1:(x$input$burnin / x$input$thin + 1)), ]

ret[[l]] <- rep(FALSE, NCOL(beta_p))

for (j in seq_len(NCOL(beta_p))) {
if (sum(abs(beta_p[, j]) > sd(beta_p[, j])) / NROW(beta_p) > threshold) {
ret[[l]][j] <- TRUE
}
}
}
} else {
Expand Down Expand Up @@ -175,10 +175,14 @@ VS <- function(x, method = "FDR", threshold = NA, subgroup = 1) {

if (method == "MPM") {
if (inherits(x, "BayesSurvive")) {
if (x$input$S > 1 || !x$input$MRF.G) {
x$output$gamma.margin <- x$output$gamma.margin[[subgroup]]
ret <- rep(list(NULL), length(subgroup))

for (l in seq_len(length(subgroup))) {
if (x$input$S > 1 || !x$input$MRF.G) {
x$output$gamma.margin <- x$output$gamma.margin[[subgroup[l]]]
}
ret[[l]] <- (x$output$gamma.margin >= 0.5)
}
ret <- (x$output$gamma.margin >= 0.5)
} else {
ret <- rep(list(NULL), length(x))

Expand Down Expand Up @@ -206,21 +210,39 @@ VS <- function(x, method = "FDR", threshold = NA, subgroup = 1) {
threshold <- 0.05
}
if (inherits(x, "BayesSurvive")) {
if (x$input$S > 1 || !x$input$MRF.G) {
x$output$gamma.margin <- x$output$gamma.margin[[subgroup]]
ret <- gammas <- rep(list(NULL), length(subgroup))
gammas_vec <- NULL
# save all mPIPs into a vector
for (l in seq_len(length(subgroup))) {
if (x$input$S > 1 || !x$input$MRF.G) {
gamma.hat <- x$output$gamma.margin[[subgroup[l]]]
} else {
gamma.hat <- x$output$gamma.margin
}
gammas_vec <- c(gammas_vec, gamma.hat)
}
gammas <- x$output$gamma.margin

sorted_gammas <- sort(gammas, decreasing = TRUE)
sorted_gammas <- sort(gammas_vec, decreasing = TRUE)
# computing the fdr
fdr <- cumsum((1 - sorted_gammas)) / seq_len(length(sorted_gammas))
# determine index of the largest fdr less than threshold
if (min(fdr) >= threshold) {
ret <- rep(FALSE, length(gammas))
ret <- rep(FALSE, length(gammas_vec))
} else {
thecut.index <- max(which(fdr < threshold))
gammas_threshold <- sorted_gammas[thecut.index]
ret <- (gammas > gammas_threshold)
ret_vec <- (gammas_vec > gammas_threshold)
}

# reformat the results into a list
for (l in seq_len(length(subgroup))) {
if (x$input$S > 1 || !x$input$MRF.G) {
gamma.hat <- x$output$gamma.margin[[subgroup[l]]]
} else {
gamma.hat <- x$output$gamma.margin
}
ret[[l]] <- ret_vec[1:length(gamma.hat)]
ret_vec <- ret_vec[-c(1:length(gamma.hat))]
}
} else {
ret <- gammas <- rep(list(NULL), length(x))
Expand Down Expand Up @@ -267,6 +289,11 @@ VS <- function(x, method = "FDR", threshold = NA, subgroup = 1) {
}
}
}

# unlist an one-component list
if (length(ret) == 1) {
ret <- unlist(ret)
}

ret
}
10 changes: 5 additions & 5 deletions man/VS.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit e6951dd

Please sign in to comment.