Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Lagrange basis aurora speedup #31

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions libiop/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ target_link_libraries(test_gf256 iop gtest_main)
add_executable(test_lagrange tests/algebra/test_lagrange.cpp)
target_link_libraries(test_lagrange iop gtest_main)

add_executable(test_merkle_tree tests/algebra/test_merkle_tree.cpp)
add_executable(test_merkle_tree tests/bcs/test_merkle_tree.cpp)
target_link_libraries(test_merkle_tree iop gtest_main)

add_executable(test_linearized_polynomial tests/algebra/test_linearized_polynomial.cpp)
Expand Down Expand Up @@ -232,7 +232,7 @@ target_link_libraries(test_poseidon iop gtest_main)
add_executable(test_pow tests/snark/test_pow.cpp)
target_link_libraries(test_pow iop gtest_main)

add_executable(test_bcs_transformation tests/snark/test_bcs_transformation.cpp)
add_executable(test_bcs_transformation tests/bcs/test_bcs_transformation.cpp)
target_link_libraries(test_bcs_transformation iop gtest_main)

add_executable(test_serialization tests/snark/test_serialization.cpp)
Expand All @@ -248,4 +248,4 @@ add_executable(test_ligero_snark tests/snark/test_ligero_snark.cpp)
target_link_libraries(test_ligero_snark iop gtest_main)

add_executable(test_linking tests/snark/test_linking.cpp)
target_link_libraries(test_linking iop gtest_main)
target_link_libraries(test_linking iop gtest_main)
2 changes: 1 addition & 1 deletion libiop/profiling/instrument_aurora_snark.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ int main(int argc, const char * argv[])
instrument_aurora_snark<libff::alt_bn128_Fr, libff::alt_bn128_Fr>(
default_vals, ldt_reducer_soundness_type,
fri_soundness_type, optimize_localization);
}
}
break;
default:
throw std::invalid_argument("Field size not supported.");
Expand Down
10 changes: 7 additions & 3 deletions libiop/protocols/encoded/lincheck/basic_lincheck_aux.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ Basic R1CS multi lincheck virtual oracle

#include "libiop/relations/sparse_matrix.hpp"
#include "libiop/algebra/lagrange.hpp"
#include "libiop/algebra/polynomials/lagrange_polynomial.hpp"
#include "libiop/iop/iop.hpp"


namespace libiop {

template<typename FieldT>
Expand Down Expand Up @@ -53,8 +53,12 @@ class multi_lincheck_virtual_oracle : public virtual_oracle<FieldT> {
* in the lagrange case, however some of them require minor refactors to the interface.
*/
const bool use_lagrange_ = false;
std::vector<FieldT> alpha_powers_;
lagrange_polynomial<FieldT> p_alpha_;
std::vector<FieldT> p_alpha_evals_;
std::vector<FieldT> p_alpha_ABC_evals_;
vanishing_polynomial<FieldT> variable_domain_vanishing_polynomial_;
vanishing_polynomial<FieldT> constraint_domain_vanishing_polynomial_;

std::shared_ptr<lagrange_cache<FieldT> > lagrange_coefficients_cache_;
public:
multi_lincheck_virtual_oracle(
Expand All @@ -79,4 +83,4 @@ class multi_lincheck_virtual_oracle : public virtual_oracle<FieldT> {

#include "libiop/protocols/encoded/lincheck/basic_lincheck_aux.tcc"

#endif // LIBIOP_PROTOCOLS_ENCODED_LINCHECK_BASIC_LINCHECK_AUX_HPP_
#endif // LIBIOP_PROTOCOLS_ENCODED_LINCHECK_BASIC_LINCHECK_AUX_HPP_
95 changes: 56 additions & 39 deletions libiop/protocols/encoded/lincheck/basic_lincheck_aux.tcc
Original file line number Diff line number Diff line change
Expand Up @@ -34,30 +34,16 @@ void multi_lincheck_virtual_oracle<FieldT>::set_challenge(const FieldT &alpha, c
}
this->r_Mz_ = r_Mz;

enter_block("multi_lincheck compute alpha powers");
/** Set alpha powers */
std::vector<FieldT> alpha_powers;
alpha_powers.reserve(this->constraint_domain_.num_elements());
// TODO: Make a method for this in algebra, that lowers the data dependency
FieldT cur = FieldT::one();
for (std::size_t i = 0; i < this->constraint_domain_.num_elements(); i++) {
alpha_powers.emplace_back(cur);
cur *= alpha;
}
leave_block("multi_lincheck compute alpha powers");
enter_block("multi_lincheck compute random polynomial evaluations");

enter_block("multi_lincheck compute p_alpha_prime");
/** This essentially places alpha powers into the correct spots,
* such that the zeroes when the |constraint domain| < summation domain
* are placed correctly. */
std::vector<FieldT> p_alpha_prime_over_summation_domain(
this->summation_domain_.num_elements(), FieldT::zero());
for (std::size_t i = 0; i < this->constraint_domain_.num_elements(); i++) {
const std::size_t element_index = this->summation_domain_.reindex_by_subset(
this->constraint_domain_.dimension(), i);
p_alpha_prime_over_summation_domain[element_index] = alpha_powers[i];
}
leave_block("multi_lincheck compute p_alpha_prime");
/* Set alpha polynomial, variable and constraint domain polynomials, and their evaluations */

this->p_alpha_ = lagrange_polynomial<FieldT>(alpha, this->constraint_domain_);
this->p_alpha_evals_ = this->p_alpha_.evaluations_over_field_subset(this->constraint_domain_);
this->variable_domain_vanishing_polynomial_ = vanishing_polynomial<FieldT>(this->variable_domain_);
this->constraint_domain_vanishing_polynomial_ = vanishing_polynomial<FieldT>(this->constraint_domain_);

leave_block("multi_lincheck compute random polynomial evaluations");

/* Set p_alpha_ABC_evals */
enter_block("multi_lincheck compute p_alpha_ABC");
Expand All @@ -79,22 +65,21 @@ void multi_lincheck_virtual_oracle<FieldT>::set_challenge(const FieldT &alpha, c
const std::size_t summation_index = this->summation_domain_.reindex_by_subset(
this->variable_domain_.dimension(), variable_index);
p_alpha_ABC_evals[summation_index] +=
this->r_Mz_[m_index] * term.coeff_ * alpha_powers[i];
this->r_Mz_[m_index] * term.coeff_ * this->p_alpha_evals_[i];
}
}
}
leave_block("multi_lincheck compute p_alpha_ABC");
// To use lagrange, the following IFFTs must also be moved to evaluated contents
if (this->use_lagrange_)
{
this->alpha_powers_ = alpha_powers;
this->p_alpha_ABC_evals_ = p_alpha_ABC_evals;
// this->alpha_powers_ = alpha_powers;
// this->p_alpha_ABC_evals_ = p_alpha_ABC_evals;
}
enter_block("multi_lincheck IFFT p_alphas");

this->p_alpha_ABC_ = polynomial<FieldT>(
IFFT_over_field_subset<FieldT>(p_alpha_ABC_evals, this->summation_domain_));
this->p_alpha_prime_ = polynomial<FieldT>(
IFFT_over_field_subset<FieldT>(p_alpha_prime_over_summation_domain, this->summation_domain_));
leave_block("multi_lincheck IFFT p_alphas");
}

Expand All @@ -108,9 +93,33 @@ std::shared_ptr<std::vector<FieldT>> multi_lincheck_virtual_oracle<FieldT>::eval
throw std::invalid_argument("multi_lincheck uses more constituent oracles than what was provided.");
}

/* p_{alpha}^1 in [BCRSVW18] */
std::vector<FieldT> p_alpha_prime_over_codeword_domain =
FFT_over_field_subset<FieldT>(this->p_alpha_prime_.coefficients(), this->codeword_domain_);
/* p_{alpha}^1 in [BCRSVW18], but now using the lagrange polynomial from
* [BCGGRS19] instead of powers of alpha. */
/* Compute p_alpha_prime. */
std::vector<FieldT> p_alpha_prime_over_codeword_domain;

/* If |variable_domain| > |constraint_domain|, we multiply the Lagrange sampled
polynomial (p_alpha_prime) by Z_{variable_domain}*Z_{constraint_domain}^-1*/
if (this->variable_domain_.num_elements() <= this->constraint_domain_.num_elements()){
p_alpha_prime_over_codeword_domain =
this->p_alpha_.evaluations_over_field_subset(this->codeword_domain_);
}else{
/* inverses of the evaluations of constraint domain polynomial */
std::vector<FieldT> constraint_domain_vanishing_polynomial_inverses;
std::vector<FieldT> variable_domain_vanishing_polynomial_evaluations;
p_alpha_prime_over_codeword_domain = this->p_alpha_.evaluations_over_field_subset(this->codeword_domain_);

variable_domain_vanishing_polynomial_evaluations = this->variable_domain_vanishing_polynomial_
.evaluations_over_field_subset(this->codeword_domain_);
constraint_domain_vanishing_polynomial_inverses = batch_inverse(this->constraint_domain_vanishing_polynomial_
.evaluations_over_field_subset(this->codeword_domain_));

for (int i = 0; i < variable_domain_vanishing_polynomial_evaluations.size(); i++){
p_alpha_prime_over_codeword_domain[i] *= variable_domain_vanishing_polynomial_evaluations[i]
* constraint_domain_vanishing_polynomial_inverses[i];
}

}

/* p_{alpha}^2 in [BCRSVW18] */
const std::vector<FieldT> p_alpha_ABC_over_codeword_domain =
Expand Down Expand Up @@ -150,23 +159,31 @@ FieldT multi_lincheck_virtual_oracle<FieldT>::evaluation_at_point(
const std::vector<FieldT> &constituent_oracle_evaluations) const
{
UNUSED(evaluation_position);
FieldT p_alpha_prime_X;
if (constituent_oracle_evaluations.size() != this->matrices_.size() + 1)
{
throw std::invalid_argument("multi_lincheck uses more constituent oracles than what was provided.");
}

FieldT p_alpha_prime_X = this->p_alpha_prime_.evaluation_at_point(evaluation_point);
/* If |variable_domain| > |constraint_domain|, we multiply the Lagrange sampled
polynomial (p_alpha_prime) by Z_{variable_domain}*Z_{constraint_domain}^-1.
This is done for a single point rather than across a domain.*/

if (this->variable_domain_.num_elements() < this->constraint_domain_.num_elements()){
p_alpha_prime_X = this->p_alpha_.evaluation_at_point(evaluation_point);
}
else{
p_alpha_prime_X = this->p_alpha_.evaluation_at_point(evaluation_point) *
this->variable_domain_vanishing_polynomial_.evaluation_at_point(evaluation_point) *
this->constraint_domain_vanishing_polynomial_.evaluation_at_point(evaluation_point).inverse() ;
}

FieldT p_alpha_ABC_X = this->p_alpha_ABC_.evaluation_at_point(evaluation_point);

if (this->use_lagrange_)
{
const std::vector<FieldT> lagrange_coefficients =
this->lagrange_coefficients_cache_->coefficients_for(evaluation_point);
for (size_t i = 0; i < this->constraint_domain_.num_elements(); ++i)
{
const std::size_t summation_index = this->summation_domain_.reindex_by_subset(
this->constraint_domain_.dimension(), i);
p_alpha_prime_X += lagrange_coefficients[summation_index] * this->alpha_powers_[i];
}
for (std::size_t i = 0; i < this->summation_domain_.num_elements(); ++i)
{
p_alpha_ABC_X += lagrange_coefficients[i] * this->p_alpha_ABC_evals_[i];
Expand All @@ -182,4 +199,4 @@ FieldT multi_lincheck_virtual_oracle<FieldT>::evaluation_at_point(
return (f_combined_Mz_x * p_alpha_prime_X - fz_X * p_alpha_ABC_X);
}

} // libiop
} // libiop
Loading