From 73d3a44f60c09722210a14ded82f9316db06b952 Mon Sep 17 00:00:00 2001 From: James Azam Date: Thu, 22 Aug 2024 16:48:31 +0100 Subject: [PATCH] Check tail of nonparametric PMFs (#752) * Add function to check pmf tail * Add test * Apply tail check to pmf * Add NEWS item * Improve documentation * Update NEWS.md Co-authored-by: Sam Abbott --------- Co-authored-by: Sam Abbott --- NEWS.md | 1 + R/checks.R | 24 ++++++++++++++++++++++++ R/dist_spec.R | 1 + man/check_sparse_pmf_tail.Rd | 23 +++++++++++++++++++++++ tests/testthat/test-checks.R | 5 +++++ 5 files changed, 54 insertions(+) create mode 100644 man/check_sparse_pmf_tail.Rd diff --git a/NEWS.md b/NEWS.md index 556bf50e4..3062b7423 100644 --- a/NEWS.md +++ b/NEWS.md @@ -11,6 +11,7 @@ - The interface for defining delay distributions has been generalised to also cater for continuous distributions - When defining probability distributions these can now be truncated using the `tolerance` argument - Ornstein-Uhlenbeck and 5 / 2 Matérn kernels have been added. By @sbfnk in # and reviewed by @. +- A warning is now thrown if nonparametric PMFs passed to delay options have consecutive tail values that are below a certain low threshold as these lead to loss in speed with little gain in accuracy. By @jamesmbaazam in #752 and reviewed by @seabbs. ## Bug fixes diff --git a/R/checks.R b/R/checks.R index cf2f6de6f..4ad82662a 100644 --- a/R/checks.R +++ b/R/checks.R @@ -107,3 +107,27 @@ check_stan_delay <- function(dist) { ) } } + +#' Check that PMF tail is not sparse +#' +#' @description Checks if the tail of a PMF vector has more than `span` +#' consecutive values smaller than `tol` and throws a warning if so. +#' @param pmf A probability mass function vector +#' @param span The number of consecutive indices in the tail to check +#' @param tol The value which to consider the tail as sparse +#' +#' @return Called for its side effects. +#' @keywords internal +check_sparse_pmf_tail <- function(pmf, span = 5, tol = 1e-6) { + if (all(pmf[(length(pmf) - span + 1):length(pmf)] < tol)) { + warning( + sprintf( + "The PMF tail has %s consecutive values smaller than %s.", + span, tol + ), + " This will drastically increase run time with very small increases ", + "in accuracy. Consider increasing the tail values of the PMF.", + call. = FALSE + ) + } +} diff --git a/R/dist_spec.R b/R/dist_spec.R index 8016c8f32..ca3c52764 100644 --- a/R/dist_spec.R +++ b/R/dist_spec.R @@ -903,6 +903,7 @@ Fixed <- function(value, ...) { #' NonParametric(c(0.1, 0.3, 0.2, 0.4)) #' NonParametric(c(0.1, 0.3, 0.2, 0.1, 0.1)) NonParametric <- function(pmf, ...) { + check_sparse_pmf_tail(pmf) params <- list(pmf = pmf / sum(pmf)) return(new_dist_spec(params, "nonparametric")) } diff --git a/man/check_sparse_pmf_tail.Rd b/man/check_sparse_pmf_tail.Rd new file mode 100644 index 000000000..6ba856cac --- /dev/null +++ b/man/check_sparse_pmf_tail.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/checks.R +\name{check_sparse_pmf_tail} +\alias{check_sparse_pmf_tail} +\title{Check that PMF tail is not sparse} +\usage{ +check_sparse_pmf_tail(pmf, span = 5, tol = 1e-06) +} +\arguments{ +\item{pmf}{A probability mass function vector} + +\item{span}{The number of consecutive indices in the tail to check} + +\item{tol}{The value which to consider the tail as sparse} +} +\value{ +Called for its side effects. +} +\description{ +Checks if the tail of a PMF vector has more than \code{span} +consecutive values smaller than \code{tol} and throws a warning if so. +} +\keyword{internal} diff --git a/tests/testthat/test-checks.R b/tests/testthat/test-checks.R index e9670a6fc..0fe68715a 100644 --- a/tests/testthat/test-checks.R +++ b/tests/testthat/test-checks.R @@ -141,3 +141,8 @@ test_that("check_reports_valid errors for bad 'secondary' specifications", { # Run tests test_col_specs(secondary_col_dt, model = "estimate_secondary") }) + +test_that("check_sparse_pmf_tail throws a warning as expected", { + pmf <- c(0.4, 0.30, 0.20, 0.05, 0.049995, 4.5e-06, rep(1e-7, 5)) + expect_warning(check_sparse_pmf_tail(pmf), "consecutive values smaller than") +})