diff --git a/include/albatross/src/models/conditional_gaussian.hpp b/include/albatross/src/models/conditional_gaussian.hpp index 4d3d12a5..a445a530 100644 --- a/include/albatross/src/models/conditional_gaussian.hpp +++ b/include/albatross/src/models/conditional_gaussian.hpp @@ -23,6 +23,10 @@ struct ConditionalFit { class ConditionalGaussian : public ModelBase { public: + ConditionalGaussian(JointDistribution &&prior, + const MarginalDistribution &truth) + : prior_(std::move(prior)), truth_(truth) {} + ConditionalGaussian(const JointDistribution &prior, const MarginalDistribution &truth) : prior_(prior), truth_(truth) {} diff --git a/include/albatross/src/models/ransac_gp.hpp b/include/albatross/src/models/ransac_gp.hpp index 8c55fc81..44848446 100644 --- a/include/albatross/src/models/ransac_gp.hpp +++ b/include/albatross/src/models/ransac_gp.hpp @@ -17,53 +17,54 @@ namespace albatross { template inline typename RansacFunctions::FitterFunc -get_gp_ransac_fitter(const ConditionalGaussian &model, +get_gp_ransac_fitter(const std::shared_ptr &model, const GroupIndexer &indexer) { return [&, model, indexer](const std::vector &groups) { auto indices = indices_from_groups(indexer, groups); - return model.fit_from_indices(indices); + return model->fit_from_indices(indices); }; } template inline typename RansacFunctions::IsValidCandidate -get_gp_ransac_is_valid_candidate(const ConditionalGaussian &model, - const GroupIndexer &indexer, - const IsValidCandidateMetric &metric) { +get_gp_ransac_is_valid_candidate( + const std::shared_ptr &model, + const GroupIndexer &indexer, + const IsValidCandidateMetric &metric) { return [&, model, indexer](const std::vector &groups) { const auto indices = indices_from_groups(indexer, groups); - const auto prior = model.get_prior(indices); - const auto truth = model.get_truth(indices); + const auto prior = model->get_prior(indices); + const auto truth = model->get_truth(indices); return metric(prior, truth); }; } template inline typename RansacFunctions::InlierMetric -get_gp_ransac_inlier_metric(const ConditionalGaussian &model, +get_gp_ransac_inlier_metric(const std::shared_ptr &model, const GroupIndexer &indexer, const InlierMetricType &metric) { return [&, indexer, model](const GroupKey &group, const ConditionalFit &fit) { const auto indices = indexer.at(group); - const auto pred = get_prediction_reference(model, fit, indices); - const auto truth = model.get_truth(indices); + const auto pred = get_prediction_reference(*model, fit, indices); + const auto truth = model->get_truth(indices); return metric(pred, truth); }; } template inline typename RansacFunctions::ConsensusMetric -get_gp_ransac_consensus_metric(const ConditionalGaussian &model, - const GroupIndexer &indexer, - const ConsensusMetric &metric) { +get_gp_ransac_consensus_metric( + const std::shared_ptr &model, + const GroupIndexer &indexer, const ConsensusMetric &metric) { return [&, model, indexer](const std::vector &groups) { const auto indices = indices_from_groups(indexer, groups); - const auto prior = model.get_prior(indices); - const auto truth = model.get_truth(indices); + const auto prior = model->get_prior(indices); + const auto truth = model->get_truth(indices); return metric(prior, truth); }; } @@ -109,9 +110,10 @@ struct AlwaysAcceptCandidateMetric { }; template + typename IsValidCandidateMetric, typename GroupKey, + typename PriorDistribution> inline RansacFunctions get_gp_ransac_functions( - const JointDistribution &prior, const MarginalDistribution &truth, + PriorDistribution &&prior, const MarginalDistribution &truth, const GroupIndexer &indexer, const InlierMetric &inlier_metric, const ConsensusMetric &consensus_metric, const IsValidCandidateMetric &is_valid_candidate_metric) { @@ -119,7 +121,9 @@ inline RansacFunctions get_gp_ransac_functions( static_assert(is_prediction_metric::value, "InlierMetric must be an PredictionMetric."); - const ConditionalGaussian model(prior, truth); + const std::shared_ptr model = + std::make_shared( + std::forward(prior), truth); const auto fitter = get_gp_ransac_fitter(model, indexer);