Rebase Inducing Points Using QR #476

Open · wants to merge 3 commits into master
385 changes: 295 additions & 90 deletions doc/src/sparse-gp-details.rst

Large diffs are not rendered by default.

23 changes: 23 additions & 0 deletions include/albatross/src/cereal/eigen.hpp
@@ -91,6 +91,29 @@ inline void load(Archive &archive,
v.indices() = indices;
}

template <class Archive, int SizeAtCompileTime, int MaxSizeAtCompileTime,
typename _StorageIndex>
inline void
save(Archive &archive,
const Eigen::PermutationMatrix<SizeAtCompileTime, MaxSizeAtCompileTime,
_StorageIndex> &v,
const std::uint32_t) {
archive(cereal::make_nvp("indices", v.indices()));
}

template <class Archive, int SizeAtCompileTime, int MaxSizeAtCompileTime,
typename _StorageIndex>
inline void
load(Archive &archive,
Eigen::PermutationMatrix<SizeAtCompileTime, MaxSizeAtCompileTime,
_StorageIndex> &v,
const std::uint32_t) {
typename Eigen::PermutationMatrix<SizeAtCompileTime, MaxSizeAtCompileTime,
_StorageIndex>::IndicesType indices;
archive(cereal::make_nvp("indices", indices));
v.indices() = indices;
}

template <typename Archive, typename _Scalar, int SizeAtCompileTime>
inline void serialize(Archive &archive,
Eigen::DiagonalMatrix<_Scalar, SizeAtCompileTime> &matrix,
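As a quick round-trip check of these new overloads, here is a minimal sketch (hypothetical usage, not part of this diff; it assumes cereal's JSON archive, the Eigen::PermutationMatrixX alias added in declarations.hpp below, and that these serializers are visible to cereal):

#include <cassert>
#include <sstream>
#include <cereal/archives/json.hpp>

void permutation_round_trip() {
  Eigen::PermutationMatrixX p(4);
  p.setIdentity();
  p.applyTranspositionOnTheRight(0, 2);  // make the permutation non-trivial

  std::stringstream ss;
  {
    cereal::JSONOutputArchive out(ss);
    out(cereal::make_nvp("perm", p));  // dispatches to save() above
  }
  Eigen::PermutationMatrixX q;
  {
    cereal::JSONInputArchive in(ss);
    in(cereal::make_nvp("perm", q));  // dispatches to load() above
  }
  assert(p.indices() == q.indices());
}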
16 changes: 8 additions & 8 deletions include/albatross/src/cereal/gp.hpp
@@ -41,8 +41,8 @@ inline void serialize(Archive &archive, Fit<SparseGPFit<FeatureType>> &fit,
archive(cereal::make_nvp("information", fit.information));
archive(cereal::make_nvp("train_covariance", fit.train_covariance));
archive(cereal::make_nvp("train_features", fit.train_features));
archive(cereal::make_nvp("sigma_R", fit.sigma_R));
archive(cereal::make_nvp("permutation_indices", fit.permutation_indices));
archive(cereal::make_nvp("R", fit.R));
archive(cereal::make_nvp("P", fit.P));
if (version > 1) {
archive(cereal::make_nvp("numerical_rank", fit.numerical_rank));
} else {
@@ -53,19 +53,19 @@

template <typename Archive, typename CovFunc, typename MeanFunc,
typename ImplType>
void save(Archive &archive,
const GaussianProcessBase<CovFunc, MeanFunc, ImplType> &gp,
const std::uint32_t) {
inline void save(Archive &archive,
const GaussianProcessBase<CovFunc, MeanFunc, ImplType> &gp,
const std::uint32_t) {
archive(cereal::make_nvp("name", gp.get_name()));
archive(cereal::make_nvp("params", gp.get_params()));
archive(cereal::make_nvp("insights", gp.insights));
}

template <typename Archive, typename CovFunc, typename MeanFunc,
typename ImplType>
void load(Archive &archive,
GaussianProcessBase<CovFunc, MeanFunc, ImplType> &gp,
const std::uint32_t version) {
inline void load(Archive &archive,
GaussianProcessBase<CovFunc, MeanFunc, ImplType> &gp,
const std::uint32_t version) {
if (version > 0) {
std::string model_name;
archive(cereal::make_nvp("name", model_name));
7 changes: 7 additions & 0 deletions include/albatross/src/core/declarations.hpp
@@ -21,6 +21,13 @@ template <typename... Ts> class variant;

using mapbox::util::variant;

/*
* Permutations
*/
namespace Eigen {
using PermutationMatrixX = PermutationMatrix<Dynamic, Dynamic, Index>;
}

namespace albatross {

/*
22 changes: 12 additions & 10 deletions include/albatross/src/linalg/qr_utils.hpp
@@ -25,29 +25,31 @@ get_R(const Eigen::ColPivHouseholderQR<Eigen::MatrixXd> &qr) {
.template triangularView<Eigen::Upper>();
}

inline Eigen::PermutationMatrixX
get_P(const Eigen::ColPivHouseholderQR<Eigen::MatrixXd> &qr) {
return Eigen::PermutationMatrixX(
qr.colsPermutation().indices().template cast<Eigen::Index>());
}

/*
* Computes R^-T P^T rhs given R and P from a QR decomposition.
*/
template <typename MatrixType, typename PermutationIndicesType>
template <typename MatrixType, typename PermutationScalar>
inline Eigen::MatrixXd
sqrt_solve(const Eigen::MatrixXd &R,
const PermutationIndicesType &permutation_indices,
const Eigen::PermutationMatrix<Eigen::Dynamic, Eigen::Dynamic,
PermutationScalar> &P,
const MatrixType &rhs) {

Eigen::MatrixXd sqrt(rhs.rows(), rhs.cols());
for (Eigen::Index i = 0; i < permutation_indices.size(); ++i) {
sqrt.row(i) = rhs.row(permutation_indices.coeff(i));
}
sqrt = R.template triangularView<Eigen::Upper>().transpose().solve(sqrt);
return sqrt;
return R.template triangularView<Eigen::Upper>().transpose().solve(
P.transpose() * rhs);
}

template <typename MatrixType>
inline Eigen::MatrixXd
sqrt_solve(const Eigen::ColPivHouseholderQR<Eigen::MatrixXd> &qr,
const MatrixType &rhs) {
const Eigen::MatrixXd R = get_R(qr);
return sqrt_solve(R, qr.colsPermutation().indices(), rhs);
return sqrt_solve(R, qr.colsPermutation(), rhs);
}

} // namespace albatross
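To make the contract concrete: for a matrix A with pivoted QR A P = Q R, sqrt_solve returns s = R^-T P^T b, so that s^T s = b^T (A^T A)^-1 b. A minimal sketch under that assumption (the include is illustrative; use whichever albatross header exposes these helpers):

#include <Eigen/Dense>
#include <iostream>
// ... plus the albatross header providing sqrt_solve (qr_utils.hpp above)

int main() {
  const Eigen::MatrixXd A = Eigen::MatrixXd::Random(10, 4);
  const Eigen::VectorXd b = Eigen::VectorXd::Random(4);

  const Eigen::ColPivHouseholderQR<Eigen::MatrixXd> qr(A);
  // s = R^-T P^T b from the pivoted QR of A.
  const Eigen::MatrixXd s = albatross::sqrt_solve(qr, b);

  // Since A^T A = P R^T R P^T, s^T s should match b^T (A^T A)^-1 b.
  const double direct = b.dot((A.transpose() * A).ldlt().solve(b));
  std::cout << s.col(0).squaredNorm() - direct << "\n";  // ~0
  return 0;
}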
11 changes: 6 additions & 5 deletions include/albatross/src/linalg/spqr_utils.hpp
@@ -19,19 +19,20 @@ using SparseMatrix = Eigen::SparseMatrix<double>;

using SPQR = Eigen::SPQR<SparseMatrix>;

using SparsePermutationMatrix =
Eigen::PermutationMatrix<Eigen::Dynamic, Eigen::Dynamic,
SPQR::StorageIndex>;

inline Eigen::MatrixXd get_R(const SPQR &qr) {
return qr.matrixR()
.topLeftCorner(qr.cols(), qr.cols())
.template triangularView<Eigen::Upper>();
}

inline Eigen::PermutationMatrixX get_P(const SPQR &qr) {
return Eigen::PermutationMatrixX(
qr.colsPermutation().indices().template cast<Eigen::Index>());
}

template <typename MatrixType>
inline Eigen::MatrixXd sqrt_solve(const SPQR &qr, const MatrixType &rhs) {
return sqrt_solve(get_R(qr), qr.colsPermutation().indices(), rhs);
return sqrt_solve(get_R(qr), get_P(qr), rhs);
}

// Matrices with any dimension smaller than this will use a special
138 changes: 97 additions & 41 deletions include/albatross/src/models/sparse_gp.hpp
@@ -97,21 +97,19 @@ template <typename FeatureType> struct Fit<SparseGPFit<FeatureType>> {

std::vector<FeatureType> train_features;
Eigen::SerializableLDLT train_covariance;
Eigen::MatrixXd sigma_R;
PermutationIndices permutation_indices;
Eigen::MatrixXd R;
Eigen::PermutationMatrixX P;
Eigen::VectorXd information;
Eigen::Index numerical_rank;

Fit(){};

Fit(const std::vector<FeatureType> &features_,
const Eigen::SerializableLDLT &train_covariance_,
const Eigen::MatrixXd &sigma_R_,
PermutationIndices &&permutation_indices_,
const Eigen::MatrixXd &R_, const Eigen::PermutationMatrixX &P_,
const Eigen::VectorXd &information_, Eigen::Index numerical_rank_)
: train_features(features_), train_covariance(train_covariance_),
sigma_R(sigma_R_), permutation_indices(std::move(permutation_indices_)),
information(information_), numerical_rank(numerical_rank_) {}
: train_features(features_), train_covariance(train_covariance_), R(R_),
P(P_), information(information_), numerical_rank(numerical_rank_) {}

void shift_mean(const Eigen::VectorXd &mean_shift) {
ALBATROSS_ASSERT(mean_shift.size() == information.size());
@@ -120,9 +118,8 @@ template <typename FeatureType> struct Fit<SparseGPFit<FeatureType>> {

bool operator==(const Fit<SparseGPFit<FeatureType>> &other) const {
return (train_features == other.train_features &&
train_covariance == other.train_covariance &&
sigma_R == other.sigma_R &&
permutation_indices == other.permutation_indices &&
train_covariance == other.train_covariance && R == other.R &&
P.indices() == other.P.indices() &&
information == other.information &&
numerical_rank == other.numerical_rank);
}
@@ -325,20 +322,17 @@ class SparseGaussianProcessRegression
compute_internal_components(old_fit.train_features, features, targets,
&A_ldlt, &K_uu_ldlt, &K_fu, &y);

const Eigen::Index n_old = old_fit.sigma_R.rows();
const Eigen::Index n_old = old_fit.R.rows();
const Eigen::Index n_new = A_ldlt.rows();
const Eigen::Index k = old_fit.sigma_R.cols();
const Eigen::Index k = old_fit.R.cols();
Eigen::MatrixXd B = Eigen::MatrixXd::Zero(n_old + n_new, k);

ALBATROSS_ASSERT(n_old == k);

// Form:
//   B = |R_old P_old^T| = |Q_1| R P^T
//       |A^{-1/2} K_fu|   |Q_2|
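// (Stacking works because B^T B = P_old R_old^T R_old P_old^T +
// K_uf A^-1 K_fu, the previous precision term plus the new data term,
// so a single QR of B yields the updated R and P without ever forming
// Sigma explicitly.)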
for (Eigen::Index i = 0; i < old_fit.permutation_indices.size(); ++i) {
const Eigen::Index &pi = old_fit.permutation_indices.coeff(i);
B.col(pi).topRows(i + 1) = old_fit.sigma_R.col(i).topRows(i + 1);
}
B.topRows(old_fit.P.rows()) = old_fit.R * old_fit.P.transpose();
B.bottomRows(n_new) = A_ldlt.sqrt_solve(K_fu);
const auto B_qr = QRImplementation::compute(B, Base::threads_.get());

@@ -347,13 +341,9 @@ class SparseGaussianProcessRegression
// |A^{-1/2} y |
ALBATROSS_ASSERT(old_fit.information.size() == n_old);
Eigen::VectorXd y_augmented(n_old + n_new);
for (Eigen::Index i = 0; i < old_fit.permutation_indices.size(); ++i) {
y_augmented[i] =
old_fit.information[old_fit.permutation_indices.coeff(i)];
}
y_augmented.topRows(n_old) =
old_fit.sigma_R.template triangularView<Eigen::Upper>() *
y_augmented.topRows(n_old);
old_fit.R.template triangularView<Eigen::Upper>() *
(old_fit.P.transpose() * old_fit.information);

y_augmented.bottomRows(n_new) = A_ldlt.sqrt_solve(y, Base::threads_.get());
const Eigen::VectorXd v = B_qr->solve(y_augmented);
Expand All @@ -365,10 +355,9 @@ class SparseGaussianProcessRegression
Eigen::VectorXd::Constant(B_qr->cols(), details::cSparseRNugget);
}
using FitType = Fit<SparseGPFit<InducingPointFeatureType>>;
return FitType(
old_fit.train_features, old_fit.train_covariance, R,
B_qr->colsPermutation().indices().template cast<Eigen::Index>(), v,
B_qr->rank());

return FitType(old_fit.train_features, old_fit.train_covariance, R,
get_P(*B_qr), v, B_qr->rank());
}

// Here we create the QR decomposition of:
@@ -415,10 +404,7 @@ class SparseGaussianProcessRegression
using InducingPointFeatureType = typename std::decay<decltype(u[0])>::type;

using FitType = Fit<SparseGPFit<InducingPointFeatureType>>;
return FitType(
u, K_uu_ldlt, get_R(*B_qr),
B_qr->colsPermutation().indices().template cast<Eigen::Index>(), v,
B_qr->rank());
return FitType(u, K_uu_ldlt, get_R(*B_qr), get_P(*B_qr), v, B_qr->rank());
}

template <typename FeatureType>
@@ -471,9 +457,8 @@ class SparseGaussianProcessRegression
const Eigen::MatrixXd sigma_inv_sqrt = C_ldlt.sqrt_solve(K_zz);
const auto B_qr = QRImplementation::compute(sigma_inv_sqrt, nullptr);

new_fit.permutation_indices =
B_qr->colsPermutation().indices().template cast<Eigen::Index>();
new_fit.sigma_R = get_R(*B_qr);
new_fit.P = get_P(*B_qr);
new_fit.R = get_R(*B_qr);
new_fit.numerical_rank = B_qr->rank();

return output;
@@ -519,8 +504,8 @@ class SparseGaussianProcessRegression
Q_sqrt.cwiseProduct(Q_sqrt).array().colwise().sum();
marginal_variance -= Q_diag;

const Eigen::MatrixXd S_sqrt = sqrt_solve(
sparse_gp_fit.sigma_R, sparse_gp_fit.permutation_indices, cross_cov);
const Eigen::MatrixXd S_sqrt =
sqrt_solve(sparse_gp_fit.R, sparse_gp_fit.P, cross_cov);
const Eigen::VectorXd S_diag =
S_sqrt.cwiseProduct(S_sqrt).array().colwise().sum();
marginal_variance += S_diag;
@@ -537,8 +522,8 @@ class SparseGaussianProcessRegression
this->covariance_function_(sparse_gp_fit.train_features, features);
const Eigen::MatrixXd prior_cov = this->covariance_function_(features);

const Eigen::MatrixXd S_sqrt = sqrt_solve(
sparse_gp_fit.sigma_R, sparse_gp_fit.permutation_indices, cross_cov);
const Eigen::MatrixXd S_sqrt =
sqrt_solve(sparse_gp_fit.R, sparse_gp_fit.P, cross_cov);

const Eigen::MatrixXd Q_sqrt =
sparse_gp_fit.train_covariance.sqrt_solve(cross_cov);
@@ -718,15 +703,86 @@ class SparseGaussianProcessRegression

// rebase_inducing_points takes a Sparse GP which was fit using some set of
// inducing points and creates a new fit relative to new inducing points.
//
// Note that this will NOT be equivalent to having fit the model with
// the new inducing points, since some information may have been lost
// in the process.
//
// For example, consider the extreme case where the first fit has no
// inducing points at all: all the information from the original
// observations will have been lost, and rebasing onto new inducing
// points just recovers the prior for those points.
//
// For implementation details see the online documentation.
//
// In summary this involves:
//   - Compute K_nn = cov(new, new)
//   - Compute K_pn = cov(prev, new)
//   - Compute A = L_pp^-1 K_pn
//   - Solve for Lhat_nn = chol(K_nn - A^T A)
//   - Solve for Q R P^T = |Lhat_nn            |
//                         |R_p P_p^T L_pp^-T A|
//   - Solve for L_nn = chol(K_nn)
//   - Solve for v_n = K_nn^-1 K_np v_p
//
template <typename ModelType, typename FeatureType, typename NewFeatureType>
template <typename CovFunc, typename MeanFunc, typename GrouperFunction,
          typename InducingPointStrategy, typename QRImplementation,
          typename FeatureType, typename NewFeatureType>
auto rebase_inducing_points(
const FitModel<ModelType, Fit<SparseGPFit<FeatureType>>> &fit_model,
const FitModel<SparseGaussianProcessRegression<
CovFunc, MeanFunc, GrouperFunction,
InducingPointStrategy, QRImplementation>,
Fit<SparseGPFit<FeatureType>>> &fit_model,
const std::vector<NewFeatureType> &new_inducing_points) {
return fit_model.get_model().fit_from_prediction(
new_inducing_points, fit_model.predict(new_inducing_points).joint());

const auto &cov = fit_model.get_model().get_covariance();
// Compute K_nn = cov(new, new)
const Eigen::MatrixXd K_nn =
cov(new_inducing_points, fit_model.get_model().threads_.get());

// Compute K_pn = cov(prev, new)
const Fit<SparseGPFit<FeatureType>> &prev_fit = fit_model.get_fit();
const auto &prev_inducing_points = prev_fit.train_features;
const Eigen::MatrixXd K_pn = cov(prev_inducing_points, new_inducing_points,
fit_model.get_model().threads_.get());
// A = L_pp^-1 K_pn
const Eigen::MatrixXd A = prev_fit.train_covariance.sqrt_solve(K_pn);
const Eigen::Index p = K_pn.rows();
const Eigen::Index n = K_nn.rows();
Eigen::MatrixXd B = Eigen::MatrixXd::Zero(n + p, n);

// B[upper] = R P^T L_pp^-T A
const auto LTiA = prev_fit.train_covariance.sqrt_transpose_solve(A);
B.topRows(p) = prev_fit.R.template triangularView<Eigen::Upper>() *
(prev_fit.P.transpose() * LTiA);

// B[lower] = chol(K_nn - A^T A)^T
Eigen::MatrixXd S_nn = K_nn - A.transpose() * A;
  // This Cholesky factorization is the step most likely to suffer numerical
  // instability because of the A^T A subtraction involved, so we add a nugget.
const double nugget =
fit_model.get_model().get_params()[details::inducing_nugget_name()].value;
assert(nugget >= 0);
S_nn.diagonal() += Eigen::VectorXd::Constant(S_nn.rows(), nugget);
B.bottomRows(n) = Eigen::SerializableLDLT(S_nn).sqrt_transpose();

const auto B_qr =
QRImplementation::compute(B, fit_model.get_model().threads_.get());

Fit<SparseGPFit<FeatureType>> new_fit;
new_fit.train_features = new_inducing_points;
new_fit.train_covariance = Eigen::SerializableLDLT(K_nn);
// v_n = K_nn^-1 K_np v_p
new_fit.information = new_fit.train_covariance.solve(
fit_model.predict(new_inducing_points).mean());
new_fit.P = get_P(*B_qr);
new_fit.R = get_R(*B_qr);
new_fit.numerical_rank = B_qr->rank();

return FitModel<
SparseGaussianProcessRegression<CovFunc, MeanFunc, GrouperFunction,
InducingPointStrategy, QRImplementation>,
Fit<SparseGPFit<FeatureType>>>(fit_model.get_model(), std::move(new_fit));
}
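For orientation, a minimal usage sketch of the new rebase path (hypothetical: sparse_gp_from_covariance, SquaredExponential, and the grouper/strategy/dataset names are illustrative, not part of this diff):

// Fit a sparse GP on an initial set of inducing points, then rebase.
auto model = sparse_gp_from_covariance(SquaredExponential(), grouper,
                                       strategy, "sparse_example");
const auto fit_model = model.fit(dataset);

std::vector<double> new_inducing_points = {0., 0.5, 1.};
// Cheaper than refitting from scratch, but not equivalent to a fresh
// fit: information the new inducing points cannot represent is lost.
const auto rebased = rebase_inducing_points(fit_model, new_inducing_points);
const auto pred = rebased.predict(test_features).joint();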

template <typename CovFunc, typename MeanFunc, typename GrouperFunction,
8 changes: 4 additions & 4 deletions tests/test_sparse_gp.cc
@@ -322,10 +322,10 @@ TYPED_TEST(SparseGaussianProcessTest, test_update) {
(updated_in_place_pred.covariance - full_pred.covariance).norm();

auto compute_sigma = [](const auto &fit_model) -> Eigen::MatrixXd {
const Eigen::Index n = fit_model.get_fit().sigma_R.cols();
Eigen::MatrixXd sigma = sqrt_solve(fit_model.get_fit().sigma_R,
fit_model.get_fit().permutation_indices,
Eigen::MatrixXd::Identity(n, n));
const Eigen::Index n = fit_model.get_fit().R.cols();
Eigen::MatrixXd sigma =
sqrt_solve(fit_model.get_fit().R, fit_model.get_fit().P,
Eigen::MatrixXd::Identity(n, n));
return sigma.transpose() * sigma;
};
