Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Precise bounds calculation for nfixed_mat_mul #2112

Merged
merged 1 commit into from
Nov 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion doc/source/nfloat.rst
Original file line number Diff line number Diff line change
Expand Up @@ -467,11 +467,33 @@ intermediate results (including rounding errors) lie in `(-1,1)`.
indicate the offset in number of limbs between consecutive entries
and may be negative.

.. function:: void _nfixed_mat_mul_classical(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs)
.. function:: void _nfixed_mat_mul_classical_precise(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs)
void _nfixed_mat_mul_classical(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs)
void _nfixed_mat_mul_waksman(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs)
void _nfixed_mat_mul_strassen(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong cutoff, slong nlimbs)
void _nfixed_mat_mul(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs)

Matrix multiplication using various algorithms.
The *strassen* variant takes a *cutoff* parameter specifying where
to switch from basecase multiplication to Strassen multiplication.
The *classical_precise* version computes with one extra limb of
internal precision; this is only intended for testing purposes.

.. function:: void _nfixed_mat_mul_bound_classical(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs)
void _nfixed_mat_mul_bound_waksman(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs)
void _nfixed_mat_mul_bound_strassen(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong cutoff, slong nlimbs)
void _nfixed_mat_mul_bound(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs)
void _nfixed_complex_mat_mul_bound(double * bound, double * error, slong m, slong n, slong p, double A, double B, double C, double D, slong nlimbs)

For the respective matrix multiplication algorithm, computes bounds
for a size `m \times n \times p` product at precision *nlimbs*
given entrywise bounds *A* and *B*.

The *bound* output is set to a bound for the entries in all intermediate
variables of the computation. This must be < 1 to
ensure correctness, since all intermediate results (including
rounding errors) are required to lie in `(-1,1)`.
The *error* output is set to a bound for the
output error, measured in ulp.
The caller can assume that the computed bounds are nondecreasing
functions of *A* and *B*.

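For complex multiplication, the entries of the first matrix are of the form `a+bi` with `|a| \le A` and `|b| \le B`, and the entries of the second matrix are of the form `c+di` with `|c| \le C` and `|d| \le D`.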
8 changes: 8 additions & 0 deletions src/nfloat.h
Original file line number Diff line number Diff line change
Expand Up @@ -590,11 +590,19 @@ void _nfixed_dot_6(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ys
void _nfixed_dot_7(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len);
void _nfixed_dot_8(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len);

void _nfixed_mat_mul_classical_precise(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs);
void _nfixed_mat_mul_classical(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs);
void _nfixed_mat_mul_waksman(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs);
void _nfixed_mat_mul_strassen(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong cutoff, slong nlimbs);
void _nfixed_mat_mul(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs);

void _nfixed_mat_mul_bound_classical(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs);
void _nfixed_mat_mul_bound_waksman(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs);
void _nfixed_mat_mul_bound_strassen(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong cutoff, slong nlimbs);
void _nfixed_mat_mul_bound(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs);
void _nfixed_complex_mat_mul_bound(double * bound, double * error, slong m, slong n, slong p, double A, double B, double C, double D, slong nlimbs);


#ifdef __cplusplus
}
#endif
Expand Down
89 changes: 57 additions & 32 deletions src/nfloat/mat_mul.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
(at your option) any later version. See <https://www.gnu.org/licenses/>.
*/

#include "double_extras.h"
#include "mpn_extras.h"
#include "gr.h"
#include "gr_vec.h"
Expand Down Expand Up @@ -500,27 +501,39 @@ nfloat_mat_mul_fixed(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong max_e
if (Adelta > 10 * prec || Bdelta > 10 * prec)
return GR_UNABLE;

/*
To double check: for Waksman,
* The intermediate entries are bounded by 8n max(|A|,|B|)^2.
* The error, including error from converting
the input matrices, is bounded by 8n ulps.
*/
/* We must scale inputs to 2^(-pad_top) so that intermediate
entries satisfy |x| < 1. */
{
double Abound, Bbound, bound, error;

pad_top = 2;
Abound = Bbound = ldexp(1.0, -pad_top);
/* Option: improve accuracy by adding more trailing guard bits. */
/* pad_bot = 3 + FLINT_BIT_COUNT(n); */
pad_bot = 2;

pad_top = 3 + FLINT_BIT_COUNT(n);
pad_bot = 3 + FLINT_BIT_COUNT(n);
while (1)
{
Aexp = Amax + pad_top;
Bexp = Bmax + pad_top;
extra_bits = Adelta + Bdelta + pad_top + pad_bot;

extra_bits = Adelta + Bdelta + pad_top + pad_bot;
if (extra_bits >= max_extra_bits)
return GR_UNABLE;

if (extra_bits >= max_extra_bits)
return GR_UNABLE;
fbits = prec + extra_bits;
fnlimbs = (fbits + FLINT_BITS - 1) / FLINT_BITS;

Aexp = Amax + pad_top;
Bexp = Bmax + pad_top;
fbits = prec + extra_bits;
fnlimbs = (fbits + FLINT_BITS - 1) / FLINT_BITS;
_nfixed_mat_mul_bound(&bound, &error, A->r, n, B->c, Abound, Bbound, fnlimbs);

return _nfloat_mat_mul_fixed_given_exp(C, A, B, Aexp, Bexp, fnlimbs, ctx);
if (bound < 0.999)
return _nfloat_mat_mul_fixed_given_exp(C, A, B, Aexp, Bexp, fnlimbs, ctx);

pad_top++;
Abound *= 0.5;
Bbound *= 0.5;
}
}
}

static void
Expand Down Expand Up @@ -1389,27 +1402,39 @@ nfloat_complex_mat_mul_fixed(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slo
if (Adelta > 10 * prec || Bdelta > 10 * prec)
return GR_UNABLE;

/*
To double check: for Waksman,
* The intermediate entries are bounded by 8n max(|A|,|B|)^2.
* The error, including error from converting
the input matrices, is bounded by 8n ulps.
*/
/* We must scale inputs to 2^(-pad_top) so that intermediate
entries satisfy |x| < 1. */
{
double Abound, Bbound, bound, error;

pad_top = 2;
Abound = Bbound = ldexp(1.0, -pad_top);
/* Option: improve accuracy by adding more trailing guard bits. */
/* pad_bot = 3 + FLINT_BIT_COUNT(n); */
pad_bot = 2;

pad_top = 3 + FLINT_BIT_COUNT(n);
pad_bot = 3 + FLINT_BIT_COUNT(n);
while (1)
{
Aexp = Amax + pad_top;
Bexp = Bmax + pad_top;
extra_bits = Adelta + Bdelta + pad_top + pad_bot;

extra_bits = Adelta + Bdelta + pad_top + pad_bot;
if (extra_bits >= max_extra_bits)
return GR_UNABLE;

if (extra_bits >= max_extra_bits)
return GR_UNABLE;
fbits = prec + extra_bits;
fnlimbs = (fbits + FLINT_BITS - 1) / FLINT_BITS;

Aexp = Amax + pad_top;
Bexp = Bmax + pad_top;
fbits = prec + extra_bits;
fnlimbs = (fbits + FLINT_BITS - 1) / FLINT_BITS;
_nfixed_complex_mat_mul_bound(&bound, &error, A->r, n, B->c, Abound, Abound, Bbound, Bbound, fnlimbs);

return _nfloat_complex_mat_mul_fixed_given_exp(C, A, B, Aexp, Bexp, fnlimbs, ctx);
if (bound < 0.999)
return _nfloat_complex_mat_mul_fixed_given_exp(C, A, B, Aexp, Bexp, fnlimbs, ctx);

pad_top++;
Abound *= 0.5;
Bbound *= 0.5;
}
}
}

FLINT_FORCE_INLINE slong
Expand Down
Loading
Loading