Skip to content

Commit

Permalink
Impart "update by average-style" methods from the new stcpCPP to stcpR6.
Browse files Browse the repository at this point in the history
  • Loading branch information
shinjaehyeok committed Jan 11, 2024
1 parent 3fd4ab8 commit 362bc6a
Show file tree
Hide file tree
Showing 12 changed files with 295 additions and 7 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

export(Stcp)
export(compute_baseline)
export(compute_baseline_for_sample_size)
export(generate_sub_B_fn)
export(generate_sub_E_fn)
export(generate_sub_G_fn)
Expand Down
28 changes: 28 additions & 0 deletions R/stcp.R
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,34 @@ Stcp <- R6::R6Class(
#' @param xs A numeric vector of observations.
updateAndReturnHistories = function(xs) {
private$m_stcpCpp$updateAndReturnHistories(xs)
},
#' @description
#' Update the log value and related fields by passing
#' a vector of averages and number of corresponding samples.
#'
#' @param x_bars A numeric vector of averages.
#' @param ns A numeric vector of sample sizes.
updateLogValuesByAvgs = function(x_bars, ns) {
private$m_stcpCpp$updateLogValuesByAvgs(x_bars, ns)
},
#' @description
#' Update the log value and related fields by passing
#' a vector of averages and number of corresponding samples
#' until the log value is crossing the boundary.
#'
#' @param x_bars A numeric vector of averages.
#' @param ns A numeric vector of sample sizes.
updateLogValuesUntilStopByAvgs = function(x_bars, ns) {
private$m_stcpCpp$updateLogValuesUntilStopByAvgs(x_bars, ns)
},
#' @description
#' Update the log value and related fields then return updated log values
#' a vector of averages and number of corresponding samples.
#'
#' @param x_bars A numeric vector of averages.
#' @param ns A numeric vector of sample sizes.
updateAndReturnHistoriesByAvgs = function(x_bars, ns) {
private$m_stcpCpp$updateAndReturnHistoriesByAvgs(x_bars, ns)
}
),
private = list(
Expand Down
64 changes: 64 additions & 0 deletions man/Stcp.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

37 changes: 37 additions & 0 deletions man/compute_baseline_for_sample_size.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 15 additions & 0 deletions src/baseline_e.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ namespace stcp
double getLogValue() override { return m_log_value; }
void reset() override { m_log_value = kNegInf; }
virtual void updateLogValue(const double &x) override = 0;
virtual void updateLogValueByAvg(const double &x_bar, const double &n) override = 0;

protected:
double m_log_value;
Expand Down Expand Up @@ -70,6 +71,10 @@ namespace stcp
{
this->m_log_value += this->m_base_obj.computeLogBaseValue(x);
}
void updateLogValueByAvg(const double &x_bar, const double &n) override
{
this->m_log_value += this->m_base_obj.computeLogBaseValueByAvg(x_bar, n);
}
};
template <typename L>
class SR : public BaselineE<L>
Expand All @@ -81,6 +86,11 @@ namespace stcp
this->m_log_value =
log(1 + exp(this->m_log_value)) + this->m_base_obj.computeLogBaseValue(x);
}
void updateLogValueByAvg(const double &x_bar, const double &n) override
{
this->m_log_value =
log(1 + exp(this->m_log_value)) + this->m_base_obj.computeLogBaseValueByAvg(x_bar, n);
}
};
template <typename L>
class CU : public BaselineE<L>
Expand All @@ -92,6 +102,11 @@ namespace stcp
this->m_log_value =
std::max(0.0, this->m_log_value) + this->m_base_obj.computeLogBaseValue(x);
}
void updateLogValueByAvg(const double &x_bar, const double &n) override
{
this->m_log_value =
std::max(0.0, this->m_log_value) + this->m_base_obj.computeLogBaseValueByAvg(x_bar, n);
}
};
} // End of namespace stcp
#endif
18 changes: 17 additions & 1 deletion src/baseline_increment.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ namespace stcp
{
}
virtual double computeLogBaseValue(const double &x) override = 0;
// Exponential baseline can support a batch update
// by using x_bar = 1/n * sum_{i=1}^n x_i and n values as inputs.
// Note batch update should take n as a double rather than integer for generality.
virtual double computeLogBaseValueByAvg(const double &x_bar, const double &n) override = 0;

protected:
double m_lambda{0};
Expand Down Expand Up @@ -61,6 +65,11 @@ namespace stcp
{
return m_lambda * x - m_lambda_times_mu_plus_psi;
}
double computeLogBaseValueByAvg(const double &x_bar, const double &n) override
{
return n * Normal::computeLogBaseValue(x_bar);
}


protected:
double m_mu{0.0};
Expand Down Expand Up @@ -120,6 +129,10 @@ namespace stcp
throw std::runtime_error("Input must be either 0.0 or 1.0 or false or true.");
}
}
double computeLogBaseValueByAvg(const double &x_bar, const double &n) override
{
return n * (m_lambda * x_bar + m_log_base_val_x_zero);
}

protected:
double m_p{0.5};
Expand All @@ -145,7 +158,6 @@ namespace stcp
}
};

// General bounded baseline increment
class Bounded : public ExpBaselineIncrement
{
public:
Expand All @@ -171,6 +183,10 @@ namespace stcp

return log(1.0 + m_lambda * (x / m_mu - 1.0));
}
double computeLogBaseValueByAvg(const double &x_bar, const double &n) override
{
throw std::runtime_error("computeLogBaseValueByAvg cannot be used for the Bounded case.");
}

protected:
double m_mu{0.5};
Expand Down
4 changes: 4 additions & 0 deletions src/log_lr_e.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ namespace stcp
double getLogValue() override { return m_log_value; }
void reset() override { m_log_value = kNegInf; }
virtual void updateLogValue(const double &x) override = 0;
void updateLogValueByAvg(const double &x_bar, const double &n) override
{
throw std::runtime_error("updateLogValueByAvg is not supported for LR based methods");
}

protected:
double m_log_value;
Expand Down
10 changes: 10 additions & 0 deletions src/mix_e.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ namespace stcp
double getLogValue() override;
void reset() override;
void updateLogValue(const double &x) override;
void updateLogValueByAvg(const double &x_bar, const double &n) override;

std::vector<double> getWeights() { return m_weights; }
std::vector<double> getLogValues();
Expand Down Expand Up @@ -113,6 +114,15 @@ namespace stcp
}
}

template <typename E>
inline void MixE<E>::updateLogValueByAvg(const double &x_bar, const double &n)
{
for (auto &e_obj : m_e_objs)
{
e_obj.updateLogValueByAvg(x_bar, n);
}
}

template <typename E>
inline std::vector<double> MixE<E>::getLogValues()
{
Expand Down
Loading

0 comments on commit 362bc6a

Please sign in to comment.