Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Speedup Ransac (Avoiding Copies) #292

Draft
wants to merge 2 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/albatross/src/models/conditional_gaussian.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ struct ConditionalFit {
class ConditionalGaussian : public ModelBase<ConditionalGaussian> {

public:
ConditionalGaussian(JointDistribution &&prior,
const MarginalDistribution &truth)
: prior_(std::move(prior)), truth_(truth) {}

ConditionalGaussian(const JointDistribution &prior,
const MarginalDistribution &truth)
: prior_(prior), truth_(truth) {}
Expand Down
40 changes: 22 additions & 18 deletions include/albatross/src/models/ransac_gp.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,53 +17,54 @@ namespace albatross {

template <typename GroupKey>
inline typename RansacFunctions<ConditionalFit, GroupKey>::FitterFunc
get_gp_ransac_fitter(const ConditionalGaussian &model,
get_gp_ransac_fitter(const std::shared_ptr<ConditionalGaussian> &model,
const GroupIndexer<GroupKey> &indexer) {

return [&, model, indexer](const std::vector<GroupKey> &groups) {
auto indices = indices_from_groups(indexer, groups);
return model.fit_from_indices(indices);
return model->fit_from_indices(indices);
};
}

template <typename IsValidCandidateMetric, typename GroupKey>
inline typename RansacFunctions<ConditionalFit, GroupKey>::IsValidCandidate
get_gp_ransac_is_valid_candidate(const ConditionalGaussian &model,
const GroupIndexer<GroupKey> &indexer,
const IsValidCandidateMetric &metric) {
get_gp_ransac_is_valid_candidate(
const std::shared_ptr<ConditionalGaussian> &model,
const GroupIndexer<GroupKey> &indexer,
const IsValidCandidateMetric &metric) {

return [&, model, indexer](const std::vector<GroupKey> &groups) {
const auto indices = indices_from_groups(indexer, groups);
const auto prior = model.get_prior(indices);
const auto truth = model.get_truth(indices);
const auto prior = model->get_prior(indices);
const auto truth = model->get_truth(indices);
return metric(prior, truth);
};
}

template <typename InlierMetricType, typename GroupKey>
inline typename RansacFunctions<ConditionalFit, GroupKey>::InlierMetric
get_gp_ransac_inlier_metric(const ConditionalGaussian &model,
get_gp_ransac_inlier_metric(const std::shared_ptr<ConditionalGaussian> &model,
const GroupIndexer<GroupKey> &indexer,
const InlierMetricType &metric) {

return [&, indexer, model](const GroupKey &group, const ConditionalFit &fit) {
const auto indices = indexer.at(group);
const auto pred = get_prediction_reference(model, fit, indices);
const auto truth = model.get_truth(indices);
const auto pred = get_prediction_reference(*model, fit, indices);
const auto truth = model->get_truth(indices);
return metric(pred, truth);
};
}

template <typename ConsensusMetric, typename GroupKey>
inline typename RansacFunctions<ConditionalFit, GroupKey>::ConsensusMetric
get_gp_ransac_consensus_metric(const ConditionalGaussian &model,
const GroupIndexer<GroupKey> &indexer,
const ConsensusMetric &metric) {
get_gp_ransac_consensus_metric(
const std::shared_ptr<ConditionalGaussian> &model,
const GroupIndexer<GroupKey> &indexer, const ConsensusMetric &metric) {

return [&, model, indexer](const std::vector<GroupKey> &groups) {
const auto indices = indices_from_groups(indexer, groups);
const auto prior = model.get_prior(indices);
const auto truth = model.get_truth(indices);
const auto prior = model->get_prior(indices);
const auto truth = model->get_truth(indices);
return metric(prior, truth);
};
}
Expand Down Expand Up @@ -109,17 +110,20 @@ struct AlwaysAcceptCandidateMetric {
};

template <typename InlierMetric, typename ConsensusMetric,
typename IsValidCandidateMetric, typename GroupKey>
typename IsValidCandidateMetric, typename GroupKey,
typename PriorDistribution>
inline RansacFunctions<ConditionalFit, GroupKey> get_gp_ransac_functions(
const JointDistribution &prior, const MarginalDistribution &truth,
PriorDistribution &&prior, const MarginalDistribution &truth,
const GroupIndexer<GroupKey> &indexer, const InlierMetric &inlier_metric,
const ConsensusMetric &consensus_metric,
const IsValidCandidateMetric &is_valid_candidate_metric) {

static_assert(is_prediction_metric<InlierMetric>::value,
"InlierMetric must be an PredictionMetric.");

const ConditionalGaussian model(prior, truth);
const std::shared_ptr<ConditionalGaussian> model =
std::make_shared<ConditionalGaussian>(
std::forward<PriorDistribution>(prior), truth);

const auto fitter = get_gp_ransac_fitter<GroupKey>(model, indexer);

Expand Down