Skip to content

Commit

Permalink
replace dist_skel with discrete_pmf function (#720)
Browse files Browse the repository at this point in the history
* replace dist_skel with discrete_pmf function

* expand documentation
  • Loading branch information
sbfnk authored Jul 17, 2024
1 parent a63c89a commit 3e06224
Show file tree
Hide file tree
Showing 11 changed files with 366 additions and 172 deletions.
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ Encoding: UTF-8
Language: en-GB
LazyData: true
Roxygen: list(markdown = TRUE)
RoxygenNote: 7.3.1
RoxygenNote: 7.3.2
NeedsCompilation: yes
SystemRequirements: GNU make
C++17
Expand Down
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,10 @@ importFrom(stats,pexp)
importFrom(stats,pgamma)
importFrom(stats,plnorm)
importFrom(stats,pnorm)
importFrom(stats,qexp)
importFrom(stats,qgamma)
importFrom(stats,qlnorm)
importFrom(stats,qnorm)
importFrom(stats,quantile)
importFrom(stats,quasipoisson)
importFrom(stats,rexp)
Expand Down
186 changes: 186 additions & 0 deletions R/deprecated.R
Original file line number Diff line number Diff line change
Expand Up @@ -447,3 +447,189 @@ rstan_opts <- function(object = NULL,
"stan_opts()"
)
}

#' Distribution Skeleton
#'
#' @description `r lifecycle::badge("deprecated")`
#' This function acts as a skeleton for a truncated distribution defined by
#' model type, maximum value and model parameters. It is designed to be used
#' with the output from [get_dist()].
#'
#' @param n Numeric vector, number of samples to take (or days for the
#' probability density).
#'
#' @param dist Logical, defaults to `FALSE`. Should the probability density be
#' returned rather than a number of samples.
#'
#' @param cum Logical, defaults to `TRUE`. If `dist = TRUE` should the returned
#' distribution be cumulative.
#'
#' @param model Character string, defining the model to be used. Supported
#' options are exponential ("exp"), gamma ("gamma"), log normal
#' ("lognormal"), normal ("normal") and fixed ("fixed"). Sampling
#' (`dist = FALSE`) from "normal" and "fixed" is only available with
#' `discrete = TRUE`.
#'
#' @param discrete Logical, defaults to `FALSE`. Should the probability
#' distribution be discretised. In this case each entry of the probability
#' mass function corresponds to the 2-length interval ending at the entry
#' except for the first interval that covers (0, 1). That is, the probability
#' mass function is a vector where the first entry corresponds to the integral
#' over the (0,1] interval of the continuous distribution, the second entry
#' corresponds to the (0,2] interval, the third entry corresponds to the (1,
#' 3] interval etc.
#'
#' @param params A list of parameters values (by name) required for each model.
#' These are `rate` for the exponential model, `shape` and `rate` for the
#' gamma model, `meanlog` and `sdlog` for the log normal model, `mean` and
#' `sd` for the normal model and `value` for the fixed model.
#'
#' @param max_value Numeric, the maximum value to allow. Defaults to 120.
#' Samples outside of this range are resampled.
#'
#' @return A vector of samples or a probability distribution.
#' @export
#' @examples
#'
#' ## Exponential model
#' # sample
#' dist_skel(10, model = "exp", params = list(rate = 1))
#'
#' # cumulative prob density
#' dist_skel(1:10, model = "exp", dist = TRUE, params = list(rate = 1))
#'
#' # probability density
#' dist_skel(1:10,
#'   model = "exp", dist = TRUE,
#'   cum = FALSE, params = list(rate = 1)
#' )
#'
#' ## Gamma model
#' # sample
#' dist_skel(10, model = "gamma", params = list(shape = 1, rate = 0.5))
#'
#' # cumulative prob density
#' dist_skel(0:10,
#'   model = "gamma", dist = TRUE,
#'   params = list(shape = 1, rate = 0.5)
#' )
#'
#' # probability density
#' dist_skel(0:10,
#'   model = "gamma", dist = TRUE,
#'   cum = FALSE, params = list(shape = 2, rate = 0.5)
#' )
#'
#' ## Log normal model
#' # sample
#' dist_skel(10,
#'   model = "lognormal", params = list(meanlog = log(5), sdlog = log(2))
#' )
#'
#' # cumulative prob density
#' dist_skel(0:10,
#'   model = "lognormal", dist = TRUE,
#'   params = list(meanlog = log(5), sdlog = log(2))
#' )
#'
#' # probability density
#' dist_skel(0:10,
#'   model = "lognormal", dist = TRUE, cum = FALSE,
#'   params = list(meanlog = log(5), sdlog = log(2))
#' )
dist_skel <- function(n, dist = FALSE, cum = TRUE, model,
                      discrete = FALSE, params, max_value = 120) {
  lifecycle::deprecate_warn(
    "1.6.0", "dist_skel()"
  )

  # Fail early with a clear message instead of an "object not found" error
  # (or, worse, silently picking up an unrelated `updist`/`rdist` from the
  # search path) further down.
  supported <- c("exp", "gamma", "lognormal", "normal", "fixed")
  if (length(model) != 1L || !(model %in% supported)) {
    stop(
      "Unsupported model; `model` must be one of: ",
      paste0("\"", supported, "\"", collapse = ", "),
      call. = FALSE
    )
  }
  model <- as.character(model)

  ## define unnormalised support function (the untruncated CDF)
  updist <- switch(model,
    exp = function(n) {
      pexp(n, params[["rate"]])
    },
    gamma = function(n) {
      pgamma(n, params[["shape"]], params[["rate"]])
    },
    lognormal = function(n) {
      plnorm(n, params[["meanlog"]], params[["sdlog"]])
    },
    normal = function(n) {
      pnorm(n, params[["mean"]], params[["sd"]])
    },
    fixed = function(n) {
      as.integer(n > params[["value"]])
    }
  )

  if (discrete) {
    # Cumulative mass over 0, 1, ..., max_value, normalised so that the last
    # entry is exactly 1; the mass function is its first difference.
    cmf <- c(0, updist(1),
      updist(seq_len(max_value)) + updist(seq_len(max_value) + 1)
    ) /
      (updist(max_value) + updist(max_value + 1))
    pmf <- diff(cmf)
    rdist <- function(n) {
      sample(
        x = seq_len(max_value + 1) - 1, size = n, prob = pmf, replace = TRUE
      )
    }
    pdist <- function(n) {
      cmf[n + 1]
    }
    ddist <- function(n) {
      pmf[n + 1]
    }
  } else {
    pdist <- function(n) {
      updist(n) / updist(max_value + 1)
    }
    ddist <- function(n) {
      pdist(n + 1) - pdist(n)
    }
    # The sampler is only needed when `dist = FALSE`, so the unsupported
    # case errors lazily, on use, rather than on construction.
    rdist <- switch(model,
      exp = function(n) {
        rexp(n, params[["rate"]])
      },
      gamma = function(n) {
        rgamma(n, params[["shape"]], params[["rate"]])
      },
      lognormal = function(n) {
        rlnorm(n, params[["meanlog"]], params[["sdlog"]])
      },
      function(n) {
        stop(
          "Sampling from a continuous \"", model, "\" distribution is not ",
          "supported; use `discrete = TRUE` or `dist = TRUE`",
          call. = FALSE
        )
      }
    )
  }

  # Probability output: cumulative or per-interval mass evaluated at `n`.
  if (dist) {
    ret <- if (cum) pdist(n) else ddist(n)
    ret[ret > 1] <- NA_real_
    return(ret)
  }

  # Sampling output: draw, then redraw only the entries at or above
  # `max_value` until none remain. Requesting exactly `sum(over)` new draws
  # keeps the sample size a scalar, which `sample()` requires in the
  # discrete case.
  draws <- rdist(n)
  over <- !is.na(draws) & draws >= max_value
  while (any(over)) {
    draws[over] <- rdist(sum(over))
    over <- !is.na(draws) & draws >= max_value
  }

  as.integer(draws)
}
Loading

0 comments on commit 3e06224

Please sign in to comment.