From 053f3e42518ab9881aff7d6217b3d311bd9bb288 Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Thu, 11 Jan 2024 16:37:53 +0000 Subject: [PATCH] squash bugs highlighted by tests --- R/dist.R | 13 +++++++------ tests/testthat/test-dist_spec.R | 2 +- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/R/dist.R b/R/dist.R index 671653d86..edc2d3fd9 100644 --- a/R/dist.R +++ b/R/dist.R @@ -968,7 +968,6 @@ dist_spec <- function(distribution = c( params_sd <- c(mean = mean_sd, sd = sd_sd) } else if (distribution == "fixed") { params_mean <- mean - params_sd <- 0 } } if (length(pmf) > 0) { @@ -981,13 +980,16 @@ dist_spec <- function(distribution = c( distribution <- "nonparametric" parameters <- list(pmf = pmf) } else { + if (length(params_sd) == 0) { + params_sd <- rep(0, length(params_mean)) + } parameters <- lapply(seq_along(params_mean), function(id) { Normal(params_mean[id], params_sd[id]) }) - names(parameters) <- natural_parameters(distribution) + names(parameters) <- natural_params(distribution) parameters$max <- max } - return(new_dist_spec(distribution, parameters)) + return(new_dist_spec(parameters, distribution)) } #' Creates a delay distribution as the sum of two other delay distributions. @@ -1636,9 +1638,8 @@ Fixed <- function(value, max = Inf) { ##' pmf(c(0.1, 0.3, 0.2, 0.4)) ##' pmf(c(0.1, 0.3, 0.2, 0.1, 0.1)) pmf <- function(mass) { - return( - new_dist_spec(parameters = list(pmf = mass / sum(mass)), "nonparametric") - ) + params <- list(pmf = mass / sum(mass)) + return(new_dist_spec(params, "nonparametric")) } ##' Get the names of the natural parameters of a distribution diff --git a/tests/testthat/test-dist_spec.R b/tests/testthat/test-dist_spec.R index f5c2f1e1b..1c02badee 100644 --- a/tests/testthat/test-dist_spec.R +++ b/tests/testthat/test-dist_spec.R @@ -263,5 +263,5 @@ test_that("delay distributions can be specified in different ways", { }) test_that("deprecated functions are deprecated", { - expect_deprecated(dist_spec(params_mean = 1.6, params_sd = 0.6, max = 19)) + expect_deprecated(dist_spec(params_mean = c(1.6, 0.6), max = 19)) })