From 362bc6a4d02c94f61e63dc988c33ae7eaf06be25 Mon Sep 17 00:00:00 2001 From: shinjaehyeok Date: Thu, 11 Jan 2024 00:13:19 -0800 Subject: [PATCH] Impart "update by average-style" methods from the new stcpCPP to stcpR6. --- NAMESPACE | 1 + R/stcp.R | 28 +++++++++ man/Stcp.Rd | 64 ++++++++++++++++++++ man/compute_baseline_for_sample_size.Rd | 37 ++++++++++++ src/baseline_e.h | 15 +++++ src/baseline_increment.h | 18 +++++- src/log_lr_e.h | 4 ++ src/mix_e.h | 10 ++++ src/stcp.h | 78 +++++++++++++++++++++++-- src/stcp_export.cpp | 18 ++++++ src/stcp_interface.h | 17 +++++- tests/testthat/test-sequential_test.R | 12 ++++ 12 files changed, 295 insertions(+), 7 deletions(-) create mode 100644 man/compute_baseline_for_sample_size.Rd diff --git a/NAMESPACE b/NAMESPACE index 1b50a3c..eee0c00 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -2,6 +2,7 @@ export(Stcp) export(compute_baseline) +export(compute_baseline_for_sample_size) export(generate_sub_B_fn) export(generate_sub_E_fn) export(generate_sub_G_fn) diff --git a/R/stcp.R b/R/stcp.R index 662cfbd..4f29436 100644 --- a/R/stcp.R +++ b/R/stcp.R @@ -473,6 +473,34 @@ Stcp <- R6::R6Class( #' @param xs A numeric vector of observations. updateAndReturnHistories = function(xs) { private$m_stcpCpp$updateAndReturnHistories(xs) + }, + #' @description + #' Update the log value and related fields by passing + #' a vector of averages and number of corresponding samples. + #' + #' @param x_bars A numeric vector of averages. + #' @param ns A numeric vector of sample sizes. + updateLogValuesByAvgs = function(x_bars, ns) { + private$m_stcpCpp$updateLogValuesByAvgs(x_bars, ns) + }, + #' @description + #' Update the log value and related fields by passing + #' a vector of averages and number of corresponding samples + #' until the log value is crossing the boundary. + #' + #' @param x_bars A numeric vector of averages. + #' @param ns A numeric vector of sample sizes. 
+ updateLogValuesUntilStopByAvgs = function(x_bars, ns) { + private$m_stcpCpp$updateLogValuesUntilStopByAvgs(x_bars, ns) + }, + #' @description + #' Update the log value and related fields then return updated log values by passing + #' a vector of averages and number of corresponding samples. + #' + #' @param x_bars A numeric vector of averages. + #' @param ns A numeric vector of sample sizes. + updateAndReturnHistoriesByAvgs = function(x_bars, ns) { + private$m_stcpCpp$updateAndReturnHistoriesByAvgs(x_bars, ns) } ), private = list( diff --git a/man/Stcp.Rd b/man/Stcp.Rd index 4830f9d..2f18353 100644 --- a/man/Stcp.Rd +++ b/man/Stcp.Rd @@ -80,6 +80,9 @@ stcp # or stcp$print() or print(stcp) \item \href{#method-Stcp-updateLogValues}{\code{Stcp$updateLogValues()}} \item \href{#method-Stcp-updateLogValuesUntilStop}{\code{Stcp$updateLogValuesUntilStop()}} \item \href{#method-Stcp-updateAndReturnHistories}{\code{Stcp$updateAndReturnHistories()}} +\item \href{#method-Stcp-updateLogValuesByAvgs}{\code{Stcp$updateLogValuesByAvgs()}} +\item \href{#method-Stcp-updateLogValuesUntilStopByAvgs}{\code{Stcp$updateLogValuesUntilStopByAvgs()}} +\item \href{#method-Stcp-updateAndReturnHistoriesByAvgs}{\code{Stcp$updateAndReturnHistoriesByAvgs()}} } } \if{html}{\out{
}} @@ -296,4 +299,65 @@ Update the log value and related fields then return updated log values by passin \if{html}{\out{}} } } +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-Stcp-updateLogValuesByAvgs}{}}} +\subsection{Method \code{updateLogValuesByAvgs()}}{ +Update the log value and related fields by passing +a vector of averages and number of corresponding samples. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{Stcp$updateLogValuesByAvgs(x_bars, ns)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{x_bars}}{A numeric vector of averages.} + +\item{\code{ns}}{A numeric vector of sample sizes.} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-Stcp-updateLogValuesUntilStopByAvgs}{}}} +\subsection{Method \code{updateLogValuesUntilStopByAvgs()}}{ +Update the log value and related fields by passing +a vector of averages and number of corresponding samples +until the log value is crossing the boundary. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{Stcp$updateLogValuesUntilStopByAvgs(x_bars, ns)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{x_bars}}{A numeric vector of averages.} + +\item{\code{ns}}{A numeric vector of sample sizes.} +} +\if{html}{\out{
}} +} +} +\if{html}{\out{
}} +\if{html}{\out{}} +\if{latex}{\out{\hypertarget{method-Stcp-updateAndReturnHistoriesByAvgs}{}}} +\subsection{Method \code{updateAndReturnHistoriesByAvgs()}}{ +Update the log value and related fields then return updated log values by passing +a vector of averages and number of corresponding samples. +\subsection{Usage}{ +\if{html}{\out{
}}\preformatted{Stcp$updateAndReturnHistoriesByAvgs(x_bars, ns)}\if{html}{\out{
}} +} + +\subsection{Arguments}{ +\if{html}{\out{
}} +\describe{ +\item{\code{x_bars}}{A numeric vector of averages.} + +\item{\code{ns}}{A numeric vector of sample sizes.} +} +\if{html}{\out{
}} +} +} } diff --git a/man/compute_baseline_for_sample_size.Rd b/man/compute_baseline_for_sample_size.Rd new file mode 100644 index 0000000..9487288 --- /dev/null +++ b/man/compute_baseline_for_sample_size.Rd @@ -0,0 +1,37 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/compute_baseline.R +\name{compute_baseline_for_sample_size} +\alias{compute_baseline_for_sample_size} +\title{Compute baseline parameters given target variance process bounds.} +\usage{ +compute_baseline_for_sample_size( + alpha, + v_upper, + v_lower, + psi_fn_list = generate_sub_G_fn(), + v_min = 1, + k_max = 200, + tol = 1e-06 +) +} +\arguments{ +\item{alpha}{ARL parameter in (0,1)} + +\item{v_upper}{Upper bound of the target variance process bound} + +\item{v_lower}{Lower bound of the target variance process bound.} + +\item{psi_fn_list}{A list of R functions that computes psi and psi_star functions. Can be generated by \code{generate_sub_G_fn()} or counterparts for sub_B and sub_E.} + +\item{v_min}{A lower bound of v function in the baseline process. Default is \code{1}.} + +\item{k_max}{Positive integer to determine the maximum number of baselines. Default is \code{200}.} + +\item{tol}{Tolerance of root-finding, positive numeric. Default is 1e-6.} +} +\value{ +A list of baseline parameters to build \code{ci_helper} +} +\description{ +Given target variance process bounds for confidence sequences, compute baseline parameters. 
+} diff --git a/src/baseline_e.h b/src/baseline_e.h index 3d9f461..ee6abed 100644 --- a/src/baseline_e.h +++ b/src/baseline_e.h @@ -43,6 +43,7 @@ namespace stcp double getLogValue() override { return m_log_value; } void reset() override { m_log_value = kNegInf; } virtual void updateLogValue(const double &x) override = 0; + virtual void updateLogValueByAvg(const double &x_bar, const double &n) override = 0; protected: double m_log_value; @@ -70,6 +71,10 @@ namespace stcp { this->m_log_value += this->m_base_obj.computeLogBaseValue(x); } + void updateLogValueByAvg(const double &x_bar, const double &n) override + { + this->m_log_value += this->m_base_obj.computeLogBaseValueByAvg(x_bar, n); + } }; template class SR : public BaselineE @@ -81,6 +86,11 @@ namespace stcp this->m_log_value = log(1 + exp(this->m_log_value)) + this->m_base_obj.computeLogBaseValue(x); } + void updateLogValueByAvg(const double &x_bar, const double &n) override + { + this->m_log_value = + log(1 + exp(this->m_log_value)) + this->m_base_obj.computeLogBaseValueByAvg(x_bar, n); + } }; template class CU : public BaselineE @@ -92,6 +102,11 @@ namespace stcp this->m_log_value = std::max(0.0, this->m_log_value) + this->m_base_obj.computeLogBaseValue(x); } + void updateLogValueByAvg(const double &x_bar, const double &n) override + { + this->m_log_value = + std::max(0.0, this->m_log_value) + this->m_base_obj.computeLogBaseValueByAvg(x_bar, n); + } }; } // End of namespace stcp #endif \ No newline at end of file diff --git a/src/baseline_increment.h b/src/baseline_increment.h index c9a1dae..d1e37b8 100644 --- a/src/baseline_increment.h +++ b/src/baseline_increment.h @@ -34,6 +34,10 @@ namespace stcp { } virtual double computeLogBaseValue(const double &x) override = 0; + // Exponential baseline can support a batch update + // by using x_bar = 1/n * sum_{i=1}^n x_i and n values as inputs. + // Note batch update should take n as a double rather than integer for generality. 
+ virtual double computeLogBaseValueByAvg(const double &x_bar, const double &n) override = 0; protected: double m_lambda{0}; @@ -61,6 +65,11 @@ namespace stcp { return m_lambda * x - m_lambda_times_mu_plus_psi; } + double computeLogBaseValueByAvg(const double &x_bar, const double &n) override + { + return n * Normal::computeLogBaseValue(x_bar); + } + protected: double m_mu{0.0}; @@ -120,6 +129,10 @@ namespace stcp throw std::runtime_error("Input must be either 0.0 or 1.0 or false or true."); } } + double computeLogBaseValueByAvg(const double &x_bar, const double &n) override + { + return n * (m_lambda * x_bar + m_log_base_val_x_zero); + } protected: double m_p{0.5}; @@ -145,7 +158,6 @@ namespace stcp } }; - // General bounded baseline increment class Bounded : public ExpBaselineIncrement { public: @@ -171,6 +183,10 @@ namespace stcp return log(1.0 + m_lambda * (x / m_mu - 1.0)); } + double computeLogBaseValueByAvg(const double &x_bar, const double &n) override + { + throw std::runtime_error("computeLogBaseValueByAvg cannot be used for the Bounded case."); + } protected: double m_mu{0.5}; diff --git a/src/log_lr_e.h b/src/log_lr_e.h index 85f1969..01ccfaa 100644 --- a/src/log_lr_e.h +++ b/src/log_lr_e.h @@ -42,6 +42,10 @@ namespace stcp double getLogValue() override { return m_log_value; } void reset() override { m_log_value = kNegInf; } virtual void updateLogValue(const double &x) override = 0; + void updateLogValueByAvg(const double &x_bar, const double &n) override + { + throw std::runtime_error("updateLogValueByAvg is not supported for LR based methods"); + } protected: double m_log_value; diff --git a/src/mix_e.h b/src/mix_e.h index a19a0d7..efba8e7 100644 --- a/src/mix_e.h +++ b/src/mix_e.h @@ -36,6 +36,7 @@ namespace stcp double getLogValue() override; void reset() override; void updateLogValue(const double &x) override; + void updateLogValueByAvg(const double &x_bar, const double &n) override; std::vector getWeights() { return m_weights; } std::vector 
getLogValues(); @@ -113,6 +114,15 @@ namespace stcp } } + template + inline void MixE::updateLogValueByAvg(const double &x_bar, const double &n) + { + for (auto &e_obj : m_e_objs) + { + e_obj.updateLogValueByAvg(x_bar, n); + } + } + template inline std::vector MixE::getLogValues() { diff --git a/src/stcp.h b/src/stcp.h index d560e23..c5d7335 100644 --- a/src/stcp.h +++ b/src/stcp.h @@ -45,8 +45,8 @@ namespace stcp double getThreshold() override { return m_threshold; }; bool isStopped() override { return m_is_stopped; }; - int getTime() override { return m_time; }; - int getStoppedTime() override { return m_stopped_time; } + double getTime() override { return m_time; }; + double getStoppedTime() override { return m_stopped_time; } void reset() override { @@ -60,15 +60,22 @@ namespace stcp void updateLogValues(const std::vector &xs) override; void updateLogValuesUntilStop(const std::vector &xs) override; + void updateLogValueByAvg(const double &x_bar, const double &n) override; + void updateLogValuesByAvgs(const std::vector &x_bars, const std::vector &ns) override; + void updateLogValuesUntilStopByAvgs(const std::vector &x_bars, const std::vector &ns) override; + double updateAndReturnHistory(const double &x) override; std::vector updateAndReturnHistories(const std::vector &xs) override; + double updateAndReturnHistoryByAvg(const double &x_bar, const double &n) override; + std::vector updateAndReturnHistoriesByAvgs(const std::vector &x_bars, const std::vector &ns) override; + protected: E m_e_obj{}; double m_threshold{log(1.0 / 0.05)}; // Default threshold ues alpha = 0.05. 
- int m_time{0}; + double m_time{0.0}; bool m_is_stopped{false}; - int m_stopped_time{0}; + double m_stopped_time{0.0}; }; // Public members @@ -108,12 +115,61 @@ namespace stcp } } template + inline void Stcp::updateLogValueByAvg(const double &x_bar, const double &n) + { + m_e_obj.updateLogValueByAvg(x_bar, n); + m_time += n; + if (this->getLogValue() > m_threshold) + { + if (!m_is_stopped) + { + // Record the first stopped time only. + m_stopped_time = m_time; + m_is_stopped = true; + } + } + } + template + inline void Stcp::updateLogValuesByAvgs(const std::vector &x_bars, const std::vector &ns) + { + if (x_bars.size() != ns.size()) + { + throw std::runtime_error("x_bars and ns do not have the same length."); + } + for (std::size_t i = 0; i < x_bars.size(); i++) + { + this->updateLogValueByAvg(x_bars[i], ns[i]); + } + } + template + inline void Stcp::updateLogValuesUntilStopByAvgs(const std::vector &x_bars, const std::vector &ns) + { + if (x_bars.size() != ns.size()) + { + throw std::runtime_error("x_bars and ns do not have the same length."); + } + for (std::size_t i = 0; i < x_bars.size(); i++) + { + this->updateLogValueByAvg(x_bars[i], ns[i]); + if (m_is_stopped) + { + break; + } + } + } + template inline double Stcp::updateAndReturnHistory(const double &x) { this->updateLogValue(x); return this->getLogValue(); } template + inline double Stcp::updateAndReturnHistoryByAvg(const double &x_bar, const double &n) + { + this->updateLogValueByAvg(x_bar, n); + return this->getLogValue(); + } + template inline std::vector Stcp::updateAndReturnHistories(const std::vector &xs) { std::vector log_values(xs.size()); @@ -123,5 +179,19 @@ namespace stcp } return log_values; } + template + inline std::vector Stcp::updateAndReturnHistoriesByAvgs(const std::vector &x_bars, const std::vector &ns) + { + if (x_bars.size() != ns.size()) + { + throw std::runtime_error("x_bars and ns do not have the same length."); + } + std::vector log_values(x_bars.size()); + for (std::size_t i = 0; 
i < x_bars.size(); i++) + { + log_values[i] = this->updateAndReturnHistoryByAvg(x_bars[i], ns[i]); + } + return log_values; + } } // End of namespace stcp #endif \ No newline at end of file diff --git a/src/stcp_export.cpp b/src/stcp_export.cpp index 6ebf116..59ad211 100644 --- a/src/stcp_export.cpp +++ b/src/stcp_export.cpp @@ -26,6 +26,9 @@ RCPP_MODULE(StcpMixESTNormalEx) { .method("updateLogValues", &Stcp>::updateLogValues) .method("updateLogValuesUntilStop", &Stcp>::updateLogValuesUntilStop) .method("updateAndReturnHistories", &Stcp>::updateAndReturnHistories) + .method("updateLogValuesByAvgs", &Stcp>::updateLogValuesByAvgs) + .method("updateLogValuesUntilStopByAvgs", &Stcp>::updateLogValuesUntilStopByAvgs) + .method("updateAndReturnHistoriesByAvgs", &Stcp>::updateAndReturnHistoriesByAvgs) ; @@ -56,6 +59,9 @@ RCPP_MODULE(StcpMixESRNormalEx) { .method("updateLogValues", &Stcp>::updateLogValues) .method("updateLogValuesUntilStop", &Stcp>::updateLogValuesUntilStop) .method("updateAndReturnHistories", &Stcp>::updateAndReturnHistories) + .method("updateLogValuesByAvgs", &Stcp>::updateLogValuesByAvgs) + .method("updateLogValuesUntilStopByAvgs", &Stcp>::updateLogValuesUntilStopByAvgs) + .method("updateAndReturnHistoriesByAvgs", &Stcp>::updateAndReturnHistoriesByAvgs) ; @@ -86,6 +92,9 @@ RCPP_MODULE(StcpMixECUNormalEx) { .method("updateLogValues", &Stcp>::updateLogValues) .method("updateLogValuesUntilStop", &Stcp>::updateLogValuesUntilStop) .method("updateAndReturnHistories", &Stcp>::updateAndReturnHistories) + .method("updateLogValuesByAvgs", &Stcp>::updateLogValuesByAvgs) + .method("updateLogValuesUntilStopByAvgs", &Stcp>::updateLogValuesUntilStopByAvgs) + .method("updateAndReturnHistoriesByAvgs", &Stcp>::updateAndReturnHistoriesByAvgs) ; @@ -116,6 +125,9 @@ RCPP_MODULE(StcpMixESTBerEx) { .method("updateLogValues", &Stcp>::updateLogValues) .method("updateLogValuesUntilStop", &Stcp>::updateLogValuesUntilStop) .method("updateAndReturnHistories", 
&Stcp>::updateAndReturnHistories) + .method("updateLogValuesByAvgs", &Stcp>::updateLogValuesByAvgs) + .method("updateLogValuesUntilStopByAvgs", &Stcp>::updateLogValuesUntilStopByAvgs) + .method("updateAndReturnHistoriesByAvgs", &Stcp>::updateAndReturnHistoriesByAvgs) ; @@ -146,6 +158,9 @@ RCPP_MODULE(StcpMixESRBerEx) { .method("updateLogValues", &Stcp>::updateLogValues) .method("updateLogValuesUntilStop", &Stcp>::updateLogValuesUntilStop) .method("updateAndReturnHistories", &Stcp>::updateAndReturnHistories) + .method("updateLogValuesByAvgs", &Stcp>::updateLogValuesByAvgs) + .method("updateLogValuesUntilStopByAvgs", &Stcp>::updateLogValuesUntilStopByAvgs) + .method("updateAndReturnHistoriesByAvgs", &Stcp>::updateAndReturnHistoriesByAvgs) ; @@ -176,6 +191,9 @@ RCPP_MODULE(StcpMixECUBerEx) { .method("updateLogValues", &Stcp>::updateLogValues) .method("updateLogValuesUntilStop", &Stcp>::updateLogValuesUntilStop) .method("updateAndReturnHistories", &Stcp>::updateAndReturnHistories) + .method("updateLogValuesByAvgs", &Stcp>::updateLogValuesByAvgs) + .method("updateLogValuesUntilStopByAvgs", &Stcp>::updateLogValuesUntilStopByAvgs) + .method("updateAndReturnHistoriesByAvgs", &Stcp>::updateAndReturnHistoriesByAvgs) ; diff --git a/src/stcp_interface.h b/src/stcp_interface.h index d4405ce..08bddb5 100644 --- a/src/stcp_interface.h +++ b/src/stcp_interface.h @@ -52,6 +52,11 @@ namespace stcp { public: virtual double computeLogBaseValue(const double &x) = 0; + // Exponential baseline can be computed + // by using x_bar = 1/n * sum_{i=1}^n x_i and n values as inputs. + // General baseline may be computed by s/v and v as inputs. + // This method must take n as a double rather than integer for generality. 
+ virtual double computeLogBaseValueByAvg(const double &x_bar, const double &n) = 0; virtual ~IBaselineIncrement() {} }; @@ -74,6 +79,7 @@ namespace stcp virtual double getLogValue() = 0; virtual void reset() = 0; virtual void updateLogValue(const double &x) = 0; + virtual void updateLogValueByAvg(const double &x_bar, const double &n) = 0; virtual ~IGeneralE() {} }; @@ -85,19 +91,26 @@ namespace stcp virtual double getThreshold() = 0; virtual bool isStopped() = 0; - virtual int getTime() = 0; - virtual int getStoppedTime() = 0; + virtual double getTime() = 0; // Use double for generality + virtual double getStoppedTime() = 0; // Use double for generality virtual void reset() = 0; virtual void updateLogValue(const double &x) = 0; virtual void updateLogValues(const std::vector &xs) = 0; virtual void updateLogValuesUntilStop(const std::vector &xs) = 0; + + virtual void updateLogValueByAvg(const double &x_bar, const double &n) = 0; + virtual void updateLogValuesByAvgs(const std::vector &x_bars, const std::vector &ns) = 0; + virtual void updateLogValuesUntilStopByAvgs(const std::vector &x_bars, const std::vector &ns) = 0; // For visualization, IStcp support update and return updated history virtual double updateAndReturnHistory(const double &x) = 0; virtual std::vector updateAndReturnHistories(const std::vector &xs) = 0; + virtual double updateAndReturnHistoryByAvg(const double &x_bar, const double &n) = 0; + virtual std::vector updateAndReturnHistoriesByAvgs(const std::vector &x_bars, const std::vector &ns) = 0; + virtual ~IStcp() {} }; } // End of namespace stcp diff --git a/tests/testthat/test-sequential_test.R b/tests/testthat/test-sequential_test.R index 73e3e7b..74ad4b0 100644 --- a/tests/testthat/test-sequential_test.R +++ b/tests/testthat/test-sequential_test.R @@ -79,4 +79,16 @@ test_that("Sequentail test for Normal - 2. 
Mixture", { expect_equal(stcp$getTime(), 2) expect_equal(stcp$isStopped(), TRUE) expect_equal(stcp$getStoppedTime(), 2) + + + x_bars <- c(2, 2) + ns <- c(2, 1) + expected_log_value_by_avgs <- expected_log_value[2:3] + stcp$reset() + updates <- stcp$updateAndReturnHistoriesByAvgs(x_bars, ns) + expect_equal(updates, expected_log_value_by_avgs) + expect_equal(stcp$getTime(), length(obs)) + expect_equal(stcp$isStopped(), TRUE) + expect_equal(stcp$getStoppedTime(), 2) + })