diff --git a/doc/source/nfloat.rst b/doc/source/nfloat.rst
index 5acdaf55da..9ac02883cd 100644
--- a/doc/source/nfloat.rst
+++ b/doc/source/nfloat.rst
@@ -467,7 +467,8 @@ intermediate results (including rounding errors) lie in `(-1,1)`.
indicate the offset in number of limbs between consecutive entries
and may be negative.
-.. function:: void _nfixed_mat_mul_classical(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs)
+.. function:: void _nfixed_mat_mul_classical_precise(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs)
+ void _nfixed_mat_mul_classical(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs)
void _nfixed_mat_mul_waksman(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs)
void _nfixed_mat_mul_strassen(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong cutoff, slong nlimbs)
void _nfixed_mat_mul(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs)
@@ -475,3 +476,24 @@ intermediate results (including rounding errors) lie in `(-1,1)`.
Matrix multiplication using various algorithms.
The *strassen* variant takes a *cutoff* parameter specifying where
to switch from basecase multiplication to Strassen multiplication.
+ The *classical_precise* version computes with one extra limb of
+ internal precision; this is only intended for testing purposes.
+
+.. function:: void _nfixed_mat_mul_bound_classical(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs)
+ void _nfixed_mat_mul_bound_waksman(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs)
+ void _nfixed_mat_mul_bound_strassen(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong cutoff, slong nlimbs)
+ void _nfixed_mat_mul_bound(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs)
+ void _nfixed_complex_mat_mul_bound(double * bound, double * error, slong m, slong n, slong p, double A, double B, double C, double D, slong nlimbs)
+
+ For the respective matrix multiplication algorithm, computes bounds
+ for a size `m \times n \times p` product at precision *nlimbs*
+ given entrywise bounds *A* and *B*.
+
+ The *bound* output is set to a bound for the entries in all intermediate
+ variables of the computation. This should be < 1 to
+ ensure correctness. The *error* output is set to a bound for the
+ output error, measured in ulp.
+ The caller can assume that the computed bounds are nondecreasing
+ functions of *A* and *B*.
+
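+ For complex multiplication, *A* and *B* bound the real and imaginary parts of the entries of the first matrix (entries of the form `A+Bi`), and *C* and *D* those of the second matrix (`C+Di`).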
diff --git a/src/nfloat.h b/src/nfloat.h
index d73c32bbc4..c60c32e8b5 100644
--- a/src/nfloat.h
+++ b/src/nfloat.h
@@ -590,11 +590,19 @@ void _nfixed_dot_6(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ys
void _nfixed_dot_7(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len);
void _nfixed_dot_8(nn_ptr res, nn_srcptr x, slong xstride, nn_srcptr y, slong ystride, slong len);
+void _nfixed_mat_mul_classical_precise(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs);
void _nfixed_mat_mul_classical(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs);
void _nfixed_mat_mul_waksman(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs);
void _nfixed_mat_mul_strassen(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong cutoff, slong nlimbs);
void _nfixed_mat_mul(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs);
+void _nfixed_mat_mul_bound_classical(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs);
+void _nfixed_mat_mul_bound_waksman(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs);
+void _nfixed_mat_mul_bound_strassen(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong cutoff, slong nlimbs);
+void _nfixed_mat_mul_bound(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs);
+void _nfixed_complex_mat_mul_bound(double * bound, double * error, slong m, slong n, slong p, double A, double B, double C, double D, slong nlimbs);
+
+
#ifdef __cplusplus
}
#endif
diff --git a/src/nfloat/mat_mul.c b/src/nfloat/mat_mul.c
index 6643015a91..81d7c11faa 100644
--- a/src/nfloat/mat_mul.c
+++ b/src/nfloat/mat_mul.c
@@ -9,6 +9,7 @@
(at your option) any later version. See <https://www.gnu.org/licenses/>.
*/
+#include "double_extras.h"
#include "mpn_extras.h"
#include "gr.h"
#include "gr_vec.h"
@@ -500,27 +501,39 @@ nfloat_mat_mul_fixed(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slong max_e
if (Adelta > 10 * prec || Bdelta > 10 * prec)
return GR_UNABLE;
- /*
- To double check: for Waksman,
- * The intermediate entries are bounded by 8n max(|A|,|B|)^2.
- * The error, including error from converting
- the input matrices, is bounded by 8n ulps.
- */
+ /* We must scale inputs to 2^(-pad_top) so that intermediate
+ entries satisfy |x| < 1. */
+ {
+ double Abound, Bbound, bound, error;
+
+ pad_top = 2;
+ Abound = Bbound = ldexp(1.0, -pad_top);
+ /* Option: improve accuracy by adding more trailing guard bits. */
+ /* pad_bot = 3 + FLINT_BIT_COUNT(n); */
+ pad_bot = 2;
- pad_top = 3 + FLINT_BIT_COUNT(n);
- pad_bot = 3 + FLINT_BIT_COUNT(n);
+ while (1)
+ {
+ Aexp = Amax + pad_top;
+ Bexp = Bmax + pad_top;
+ extra_bits = Adelta + Bdelta + pad_top + pad_bot;
- extra_bits = Adelta + Bdelta + pad_top + pad_bot;
+ if (extra_bits >= max_extra_bits)
+ return GR_UNABLE;
- if (extra_bits >= max_extra_bits)
- return GR_UNABLE;
+ fbits = prec + extra_bits;
+ fnlimbs = (fbits + FLINT_BITS - 1) / FLINT_BITS;
- Aexp = Amax + pad_top;
- Bexp = Bmax + pad_top;
- fbits = prec + extra_bits;
- fnlimbs = (fbits + FLINT_BITS - 1) / FLINT_BITS;
+ _nfixed_mat_mul_bound(&bound, &error, A->r, n, B->c, Abound, Bbound, fnlimbs);
- return _nfloat_mat_mul_fixed_given_exp(C, A, B, Aexp, Bexp, fnlimbs, ctx);
+ if (bound < 0.999)
+ return _nfloat_mat_mul_fixed_given_exp(C, A, B, Aexp, Bexp, fnlimbs, ctx);
+
+ pad_top++;
+ Abound *= 0.5;
+ Bbound *= 0.5;
+ }
+ }
}
static void
@@ -1389,27 +1402,39 @@ nfloat_complex_mat_mul_fixed(gr_mat_t C, const gr_mat_t A, const gr_mat_t B, slo
if (Adelta > 10 * prec || Bdelta > 10 * prec)
return GR_UNABLE;
- /*
- To double check: for Waksman,
- * The intermediate entries are bounded by 8n max(|A|,|B|)^2.
- * The error, including error from converting
- the input matrices, is bounded by 8n ulps.
- */
+ /* We must scale inputs to 2^(-pad_top) so that intermediate
+ entries satisfy |x| < 1. */
+ {
+ double Abound, Bbound, bound, error;
+
+ pad_top = 2;
+ Abound = Bbound = ldexp(1.0, -pad_top);
+ /* Option: improve accuracy by adding more trailing guard bits. */
+ /* pad_bot = 3 + FLINT_BIT_COUNT(n); */
+ pad_bot = 2;
- pad_top = 3 + FLINT_BIT_COUNT(n);
- pad_bot = 3 + FLINT_BIT_COUNT(n);
+ while (1)
+ {
+ Aexp = Amax + pad_top;
+ Bexp = Bmax + pad_top;
+ extra_bits = Adelta + Bdelta + pad_top + pad_bot;
- extra_bits = Adelta + Bdelta + pad_top + pad_bot;
+ if (extra_bits >= max_extra_bits)
+ return GR_UNABLE;
- if (extra_bits >= max_extra_bits)
- return GR_UNABLE;
+ fbits = prec + extra_bits;
+ fnlimbs = (fbits + FLINT_BITS - 1) / FLINT_BITS;
- Aexp = Amax + pad_top;
- Bexp = Bmax + pad_top;
- fbits = prec + extra_bits;
- fnlimbs = (fbits + FLINT_BITS - 1) / FLINT_BITS;
+ _nfixed_complex_mat_mul_bound(&bound, &error, A->r, n, B->c, Abound, Abound, Bbound, Bbound, fnlimbs);
- return _nfloat_complex_mat_mul_fixed_given_exp(C, A, B, Aexp, Bexp, fnlimbs, ctx);
+ if (bound < 0.999)
+ return _nfloat_complex_mat_mul_fixed_given_exp(C, A, B, Aexp, Bexp, fnlimbs, ctx);
+
+ pad_top++;
+ Abound *= 0.5;
+ Bbound *= 0.5;
+ }
+ }
}
FLINT_FORCE_INLINE slong
diff --git a/src/nfloat/nfixed.c b/src/nfloat/nfixed.c
index 2c8eaef16e..597bc96a67 100644
--- a/src/nfloat/nfixed.c
+++ b/src/nfloat/nfixed.c
@@ -9,6 +9,7 @@
(at your option) any later version. See <https://www.gnu.org/licenses/>.
*/
+#include "double_extras.h"
#include "mpn_extras.h"
#include "gr.h"
#include "gr_vec.h"
@@ -758,6 +759,51 @@ _nfixed_mat_mul_classical(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n,
#undef C_ENTRY
}
+/* todo: optimize */
+/* A is (m x n), B is (n x p), C is (m x p). Each entry is widened with a
+   zero limb at index 1, multiplied classically at nlimbs + 1 limbs, then narrowed. */
+void
+_nfixed_mat_mul_classical_precise(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, slong nlimbs)
+{
+    const slong narrow = nlimbs + 1;   /* limbs per entry of A, B and C */
+    const slong wide = nlimbs + 2;     /* limbs per entry of the temporaries */
+    slong k, lenA = m * n, lenB = n * p, lenC = m * p;
+    nn_ptr buf, wA, wB, wC, dst;
+    nn_srcptr src;
+
+    buf = flint_malloc((lenA + lenB + lenC) * wide * sizeof(ulong));
+    wA = buf;
+    wB = wA + lenA * wide;
+    wC = wB + lenB * wide;
+
+    /* Widen A: limb 0 is copied unchanged, a zero limb is inserted at index 1. */
+    for (k = 0, src = A, dst = wA; k < lenA; k++, src += narrow, dst += wide)
+    {
+        dst[0] = src[0];
+        dst[1] = 0;
+        flint_mpn_copyi(dst + 2, src + 1, nlimbs);
+    }
+
+    /* Widen B in exactly the same way. */
+    for (k = 0, src = B, dst = wB; k < lenB; k++, src += narrow, dst += wide)
+    {
+        dst[0] = src[0];
+        dst[1] = 0;
+        flint_mpn_copyi(dst + 2, src + 1, nlimbs);
+    }
+
+    _nfixed_mat_mul_classical(wC, wA, wB, m, n, p, nlimbs + 1);
+
+    /* Narrow the product: the limb at index 1 of each wide entry is dropped. */
+    for (k = 0, src = wC, dst = C; k < lenC; k++, src += wide, dst += narrow)
+    {
+        dst[0] = src[0];
+        flint_mpn_copyi(dst + 1, src + 2, nlimbs);
+    }
+
+    flint_free(buf);
+}
+
/* compute c += (a1 + b1) * (a2 + b2) */
/* val0, val1, val2 are scratch space */
FLINT_FORCE_INLINE void
@@ -1225,6 +1271,7 @@ _nfixed_mat_mul_strassen2(_nfixed_mat_t C, const _nfixed_mat_t A, const _nfixed_
nn = FLINT_MIN(ar, ac);
nn = FLINT_MIN(nn, bc);
+ /* Important: if the cutoff handling changes, _nfixed_mat_mul_bound_strassen must change too. */
if (cutoff < 0)
cutoff = nfixed_mat_mul_strassen_cutoff(nn, ac & 1, nlimbs);
else
@@ -1392,6 +1439,7 @@ _nfixed_mat_mul(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, s
d = FLINT_MIN(m, n);
d = FLINT_MIN(d, p);
+ /* Important: if the cutoff handling changes, _nfixed_mat_mul_bound must change too. */
if (d > 10)
{
cutoff = nfixed_mat_mul_strassen_cutoff(d, n & 1, nlimbs);
@@ -1408,3 +1456,172 @@ _nfixed_mat_mul(nn_ptr C, nn_srcptr A, nn_srcptr B, slong m, slong n, slong p, s
else
_nfixed_mat_mul_classical(C, A, B, m, n, p, nlimbs);
}
+
+/*
+ Given an m x n x p matrix multiplication with inputs bounded
+ entrywise by A, B and nlimbs precision:
+
+ - Set *bound* to a bound for the entries in all intermediate
+ variables of the computation. This should be < 1 to
+ ensure correctness.
+ - Set *error* to a bound for the output error, measured in ulp.
+
+ The caller can assume that the bound is a nondecreasing function
+ of A and B.
+
+ IMPORTANT: when changing the algorithm in _nfixed_mat_mul, this
+ must be changed to correspond.
+*/
+
+void
+_nfixed_mat_mul_bound_classical(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs)
+{
+    /* Factor slightly above 1 applied to both outputs (presumably slack for double rounding). */
+    const double margin = 1.0 + 1e-6;
+    /* Error bound (in ulp) for one naive scalar multiplication. */
+    const double ulp_per_mul = (2 * nlimbs - 1);
+
+    /* Both outputs scale with the dot product length n; m and p are not used. */
+    *bound = (n * A * B) * margin;
+    *error = (n * ulp_per_mul) * margin;
+}
+
+void
+_nfixed_mat_mul_bound_waksman(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs)
+{
+    const double margin = 1.0 + 1e-6;
+    /* Error bound (in ulp) for one naive scalar multiplication */
+    const double ulp_per_mul = (2 * nlimbs - 1);
+    const slong half = n / 2;       /* integer division, as in the formulas below */
+    const double sum = A + B;       /* bound for a sum of one A entry and one B entry */
+    *bound = FLINT_MAX(sum, 6 * half * sum * sum + A * B) * margin;
+    *error = ((6 * half + 1) * ulp_per_mul + 5) * margin;
+}
+
+void
+_nfixed_mat_mul_bound_strassen(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong cutoff, slong nlimbs)
+{
+    /* Bounds for the Strassen variant; cutoff < 0 selects the tuned default cutoff. */
+    double fixup = 1.0 + 1e-6;
+    slong d;
+
+    /* Error bound (in ulp) for naive scalar multiplication, and for dot product */
+    double error_mul = (2 * nlimbs - 1);
+    double error_dot = n * error_mul;
+
+    d = FLINT_MIN(FLINT_MIN(m, n), p);
+    /* Same cutoff rule as the multiplication code: pass the parity n & 1, not n. */
+    if (cutoff < 0)
+        cutoff = nfixed_mat_mul_strassen_cutoff(d, n & 1, nlimbs);
+    else
+        cutoff = FLINT_MAX(cutoff, 2);
+
+    if (d < cutoff)
+    {
+        if (nfixed_mat_mul_use_waksman(d, nlimbs))
+            _nfixed_mat_mul_bound_waksman(bound, error, m, n, p, A, B, nlimbs);
+        else
+            _nfixed_mat_mul_bound_classical(bound, error, m, n, p, A, B, nlimbs);
+        return;
+    }
+    /* One Strassen level: bound the transformed blocks, recurse, then recombine. */
+    slong mm, nn, pp;
+    double bound_transformed_A, bound_transformed_B;
+    double bound_everything, bound_recursive, error_recursive;
+    double bound_recombination, error_recombination;
+    double ulp;
+
+    /* Bound for entries of transformed block matrices */
+    /* S1 = A22 + A12  <= 2A
+       S2 = A22 - A21  <= 2A
+       S3 = S2 + A12   <= 3A
+       S4 = S3 - A11   <= 3A, and similarly Ti for B */
+    bound_transformed_A = 3.0 * A;
+    bound_transformed_B = 3.0 * B;
+    bound_everything = FLINT_MAX(bound_transformed_A, bound_transformed_B);
+
+    /* Bound intermediate entries and errors for recursive multiplications. */
+    mm = m / 2;
+    nn = n / 2;
+    pp = p / 2;
+    _nfixed_mat_mul_bound_strassen(&bound_recursive, &error_recursive, mm, nn, pp, bound_transformed_A, bound_transformed_B, cutoff, nlimbs);
+    bound_everything = FLINT_MAX(bound_everything, bound_recursive);
+
+    /* Bound for recombinations. We don't use bound_recursive here,
+       because this can be a huge overestimate if the basecase
+       is nonclassical multiplication. Instead, we use the
+       theoretical bounds for the subproducts and add the
+       bound u = error_recursive (in the event the recursive
+       multiplications were not rounded down). */
+    /* P1 = S1 T1     <=  4 nn A B + u
+       P2 = S2 T2     <=  4 nn A B + u
+       P3 = S3 T3     <=  9 nn A B + u
+       P4 = A11 B11   <=    nn A B + u
+       P5 = A12 B21   <=    nn A B + u
+       P6 = S4 B12    <=  3 nn A B + u
+       P7 = A21 T4    <=  3 nn A B + u
+       U1 = P3 + P5   <= 10 nn A B + 2u
+       U2 = P1 - U1   <= 14 nn A B + 3u
+       U3 = U1 - P2   <= 14 nn A B + 3u
+       C11 = P4 + P5  <=  2 nn A B + 2u
+       C12 = U3 - P6  <= 17 nn A B + 4u
+       C21 = U2 - P7  <= 17 nn A B + 4u
+       C22 = P2 + U2  <= 18 nn A B + 4u
+    */
+
+    ulp = ldexp(1.0, FLINT_MAX(-128, -nlimbs * FLINT_BITS));
+
+    error_recombination = 4 * error_recursive;
+    bound_recombination = 18 * nn * A * B + error_recombination * ulp;
+
+    /* Bound for border corrections when m, n, and/or p is odd.
+       Todo: could be added conditionally. Assumes border
+       corrections use classical multiplication. */
+    bound_recombination += A * B;
+    bound_recombination = FLINT_MAX(bound_recombination, n * A * B);
+    error_recombination += error_mul;
+    error_recombination = FLINT_MAX(error_recombination, error_dot);
+
+    bound_everything = FLINT_MAX(bound_everything, bound_recombination);
+
+    *bound = bound_everything * fixup;
+    *error = error_recombination * fixup;
+}
+
+void
+_nfixed_mat_mul_bound(double * bound, double * error, slong m, slong n, slong p, double A, double B, slong nlimbs)
+{
+    /* Dispatch to the bound routine of the algorithm selected by the tests
+       below on (m, n, p, nlimbs). */
+    /* Smallest of the three dimensions. */
+    const slong dmin = FLINT_MIN(FLINT_MIN(m, n), p);
+
+    /* Strassen is considered only for dmin > 10, and only when the inner
+       dimension n reaches the cutoff. */
+    if (dmin > 10 && n >= nfixed_mat_mul_strassen_cutoff(dmin, n & 1, nlimbs))
+    {
+        /* cutoff = -1 makes the Strassen bound compute its own cutoff */
+        _nfixed_mat_mul_bound_strassen(bound, error, m, n, p, A, B, -1, nlimbs);
+    }
+    else if (nfixed_mat_mul_use_waksman(dmin, nlimbs))
+    {
+        _nfixed_mat_mul_bound_waksman(bound, error, m, n, p, A, B, nlimbs);
+    }
+    else
+    {
+        _nfixed_mat_mul_bound_classical(bound, error, m, n, p, A, B, nlimbs);
+    }
+}
+
+/* Karatsuba formula */
+/* (A C - B D) + ((A + B)(C + D) - A C - B D) i */
+void
+_nfixed_complex_mat_mul_bound(double * bound, double * error, slong m, slong n, slong p, double A, double B, double C, double D, slong nlimbs)
+{
+    double real_bound, real_error, margin = 1.0 + 1e-6;
+    const double lhs = A + B, rhs = C + D;   /* entrywise bounds for the sums A + B and C + D */
+
+    _nfixed_mat_mul_bound(&real_bound, &real_error, m, n, p, lhs, rhs, nlimbs);
+    *bound = FLINT_MAX(3.0 * real_bound, FLINT_MAX(lhs, rhs)) * margin;
+    *error = real_error * (3.0 * margin);
+}
diff --git a/src/nfloat/test/main.c b/src/nfloat/test/main.c
index 88f9a6c362..0054610e07 100644
--- a/src/nfloat/test/main.c
+++ b/src/nfloat/test/main.c
@@ -17,6 +17,9 @@
#include "t-mat_mul.c"
#include "t-nfixed_dot.c"
#include "t-nfixed_mat_mul.c"
+#include "t-nfixed_mat_mul_classical.c"
+#include "t-nfixed_mat_mul_strassen.c"
+#include "t-nfixed_mat_mul_waksman.c"
#include "t-nfloat.c"
#include "t-nfloat_complex.c"
@@ -30,6 +33,9 @@ test_struct tests[] =
TEST_FUNCTION(mat_mul),
TEST_FUNCTION(nfixed_dot),
TEST_FUNCTION(nfixed_mat_mul),
+ TEST_FUNCTION(nfixed_mat_mul_classical),
+ TEST_FUNCTION(nfixed_mat_mul_strassen),
+ TEST_FUNCTION(nfixed_mat_mul_waksman),
TEST_FUNCTION(nfloat),
TEST_FUNCTION(nfloat_complex),
};
diff --git a/src/nfloat/test/t-nfixed_mat_mul.c b/src/nfloat/test/t-nfixed_mat_mul.c
index 4ef678057d..54788e4f4b 100644
--- a/src/nfloat/test/t-nfixed_mat_mul.c
+++ b/src/nfloat/test/t-nfixed_mat_mul.c
@@ -10,6 +10,7 @@
*/
#include "test_helpers.h"
+#include "double_extras.h"
#include "fmpq.h"
#include "arf.h"
#include "gr_vec.h"
@@ -21,23 +22,35 @@ TEST_FUNCTION_START(nfixed_mat_mul, state)
slong iter, m, n, p, i, nlimbs;
nn_ptr A, B, C, D, t;
nn_ptr a;
- int which;
- slong MAXN = 12;
+ slong MAXN = 20;
slong MINLIMBS = 2;
slong MAXLIMBS = 12;
- for (iter = 0; iter < 10000 * flint_test_multiplier(); iter++)
+ for (iter = 0; iter < 1000 * flint_test_multiplier(); iter++)
{
- which = n_randint(state, 6);
-
m = 1 + n_randint(state, MAXN);
n = 1 + n_randint(state, MAXN);
p = 1 + n_randint(state, MAXN);
nlimbs = MINLIMBS + n_randint(state, MAXLIMBS - MINLIMBS + 1);
- ulong maxerr = 2 * (2 * nlimbs - 1) * n;
+ ulong maxerr;
+
+ int top;
+ double bound, error, classical_precise_error;
+
+ top = 1;
+ while (1)
+ {
+ _nfixed_mat_mul_bound(&bound, &error, m, n, p, ldexp(1.0, -top), ldexp(1.0, -top), nlimbs);
+ if (bound < 1.0)
+ break;
+ top++;
+ }
+
+ classical_precise_error = 1.01;
+ maxerr = (ulong) (error + classical_precise_error + 1.0);
A = flint_malloc((nlimbs + 1) * (m * n) * sizeof(ulong));
B = flint_malloc((nlimbs + 1) * (n * p) * sizeof(ulong));
@@ -50,7 +63,7 @@ TEST_FUNCTION_START(nfixed_mat_mul, state)
a = A + i * (nlimbs + 1);
a[0] = n_randint(state, 2);
flint_mpn_rrandom(a + 1, state, nlimbs);
- a[nlimbs] >>= 10;
+ a[nlimbs] >>= top;
}
for (i = 0; i < n * p; i++)
@@ -58,7 +71,7 @@ TEST_FUNCTION_START(nfixed_mat_mul, state)
a = B + i * (nlimbs + 1);
a[0] = n_randint(state, 2);
flint_mpn_rrandom(a + 1, state, nlimbs);
- a[nlimbs] >>= 10;
+ a[nlimbs] >>= top;
}
for (i = 0; i < m * p; i++)
@@ -72,12 +85,8 @@ TEST_FUNCTION_START(nfixed_mat_mul, state)
flint_mpn_rrandom(a + 1, state, nlimbs);
}
- _nfixed_mat_mul_classical(C, A, B, m, n, p, nlimbs);
-
- if (which == 0)
- _nfixed_mat_mul_waksman(D, A, B, m, n, p, nlimbs);
- else
- _nfixed_mat_mul_strassen(D, A, B, m, n, p, which, nlimbs);
+ _nfixed_mat_mul_classical_precise(C, A, B, m, n, p, nlimbs);
+ _nfixed_mat_mul(D, A, B, m, n, p, nlimbs);
for (i = 0; i < m * p; i++)
{
diff --git a/src/nfloat/test/t-nfixed_mat_mul_classical.c b/src/nfloat/test/t-nfixed_mat_mul_classical.c
new file mode 100644
index 0000000000..a6c86ab88c
--- /dev/null
+++ b/src/nfloat/test/t-nfixed_mat_mul_classical.c
@@ -0,0 +1,110 @@
+/*
+ Copyright (C) 2024 Fredrik Johansson
+
+ This file is part of FLINT.
+
+ FLINT is free software: you can redistribute it and/or modify it under
+ the terms of the GNU Lesser General Public License (LGPL) as published
+ by the Free Software Foundation; either version 3 of the License, or
+ (at your option) any later version. See <https://www.gnu.org/licenses/>.
+*/
+
+#include "test_helpers.h"
+#include "double_extras.h"
+#include "fmpq.h"
+#include "arf.h"
+#include "gr_vec.h"
+#include "gr_special.h"
+#include "nfloat.h"
+
+TEST_FUNCTION_START(nfixed_mat_mul_classical, state)
+{
+    /* Compare _nfixed_mat_mul_classical with the extra-limb reference, allowing
+       the error predicted by _nfixed_mat_mul_bound_classical. */
+    slong iter, m, n, p, i, nlimbs;
+    nn_ptr A, B, C, D, t, a;
+    slong MAXN = 20;
+    slong MINLIMBS = 2;
+    slong MAXLIMBS = 12;
+
+    for (iter = 0; iter < 1000 * flint_test_multiplier(); iter++)
+    {
+        m = 1 + n_randint(state, MAXN);
+        n = 1 + n_randint(state, MAXN);
+        p = 1 + n_randint(state, MAXN);
+
+        nlimbs = MINLIMBS + n_randint(state, MAXLIMBS - MINLIMBS + 1);
+
+        ulong maxerr;
+        int top;
+        double bound, error, classical_precise_error;
+        /* Grow top until inputs bounded by 2^-top give an intermediate bound < 1. */
+        top = 1;
+        while (1)
+        {
+            _nfixed_mat_mul_bound_classical(&bound, &error, m, n, p, ldexp(1.0, -top), ldexp(1.0, -top), nlimbs);
+            if (bound < 1.0)
+                break;
+            top++;
+        }
+        /* Tolerance: predicted error plus 1.01 ulp (presumably the reference's own error), rounded up. */
+        classical_precise_error = 1.01;
+        maxerr = (ulong) (error + classical_precise_error + 1.0);
+
+        A = flint_malloc((nlimbs + 1) * (m * n) * sizeof(ulong));
+        B = flint_malloc((nlimbs + 1) * (n * p) * sizeof(ulong));
+        C = flint_malloc((nlimbs + 1) * (m * p) * sizeof(ulong));
+        D = flint_malloc((nlimbs + 1) * (m * p) * sizeof(ulong));
+        t = flint_malloc((nlimbs + 1) * sizeof(ulong));
+
+        for (i = 0; i < m * n; i++)
+        {
+            a = A + i * (nlimbs + 1);
+            a[0] = n_randint(state, 2);
+            flint_mpn_rrandom(a + 1, state, nlimbs);
+            a[nlimbs] >>= top;   /* shrink the last limb by top bits (assumed most significant) */
+        }
+
+        for (i = 0; i < n * p; i++)
+        {
+            a = B + i * (nlimbs + 1);
+            a[0] = n_randint(state, 2);
+            flint_mpn_rrandom(a + 1, state, nlimbs);
+            a[nlimbs] >>= top;
+        }
+
+        for (i = 0; i < m * p; i++)
+        {
+            a = C + i * (nlimbs + 1);
+            a[0] = n_randint(state, 2);
+            flint_mpn_rrandom(a + 1, state, nlimbs);
+
+            a = D + i * (nlimbs + 1);
+            a[0] = n_randint(state, 2);
+            flint_mpn_rrandom(a + 1, state, nlimbs);
+        }
+
+        _nfixed_mat_mul_classical_precise(C, A, B, m, n, p, nlimbs);
+        _nfixed_mat_mul_classical(D, A, B, m, n, p, nlimbs);
+
+        for (i = 0; i < m * p; i++)
+        {
+            nfixed_sub(t, C + i * (nlimbs + 1), D + i * (nlimbs + 1), nlimbs);
+
+            if (!flint_mpn_zero_p(t + 2, nlimbs - 1) || t[1] > maxerr)
+            {
+                TEST_FUNCTION_FAIL("nlimbs = %wd, m = %wd, n = %wd, p = %wd, top = %d\n\nt = %{ulong*}, maxerr = %wu\n\nA = %{ulong*}\n\nB = %{ulong*}\n\nC = %{ulong*}\n\nD = %{ulong*}\n\n",
+                    nlimbs, m, n, p, top,
+                    t, nlimbs + 1, maxerr, A, m * n * (nlimbs + 1), B, n * p * (nlimbs + 1), C, m * p * (nlimbs + 1), D, m * p * (nlimbs + 1));
+            }
+        }
+
+        flint_free(A);
+        flint_free(B);
+        flint_free(C);
+        flint_free(D);
+        flint_free(t);   /* scratch difference buffer, allocated every iteration */
+    }
+
+    TEST_FUNCTION_END(state);
+}
\ No newline at end of file
diff --git a/src/nfloat/test/t-nfixed_mat_mul_strassen.c b/src/nfloat/test/t-nfixed_mat_mul_strassen.c
new file mode 100644
index 0000000000..3e6438db8b
--- /dev/null
+++ b/src/nfloat/test/t-nfixed_mat_mul_strassen.c
@@ -0,0 +1,113 @@
+/*
+ Copyright (C) 2024 Fredrik Johansson
+
+ This file is part of FLINT.
+
+ FLINT is free software: you can redistribute it and/or modify it under
+ the terms of the GNU Lesser General Public License (LGPL) as published
+ by the Free Software Foundation; either version 3 of the License, or
+ (at your option) any later version. See <https://www.gnu.org/licenses/>.
+*/
+
+#include "test_helpers.h"
+#include "double_extras.h"
+#include "fmpq.h"
+#include "arf.h"
+#include "gr_vec.h"
+#include "gr_special.h"
+#include "nfloat.h"
+
+TEST_FUNCTION_START(nfixed_mat_mul_strassen, state)
+{
+    /* Compare _nfixed_mat_mul_strassen (random cutoff) with the extra-limb
+       reference, allowing the error predicted by _nfixed_mat_mul_bound_strassen. */
+    slong iter, m, n, p, i, nlimbs;
+    nn_ptr A, B, C, D, t, a;
+    slong cutoff;
+    slong MAXN = 20;
+    slong MINLIMBS = 2;
+    slong MAXLIMBS = 12;
+
+    for (iter = 0; iter < 1000 * flint_test_multiplier(); iter++)
+    {
+        cutoff = n_randint(state, 6);
+
+        m = 1 + n_randint(state, MAXN);
+        n = 1 + n_randint(state, MAXN);
+        p = 1 + n_randint(state, MAXN);
+
+        nlimbs = MINLIMBS + n_randint(state, MAXLIMBS - MINLIMBS + 1);
+
+        ulong maxerr;
+        int top;
+        double bound, error, classical_precise_error;
+        /* Grow top until inputs bounded by 2^-top give an intermediate bound < 1. */
+        top = 1;
+        while (1)
+        {
+            _nfixed_mat_mul_bound_strassen(&bound, &error, m, n, p, ldexp(1.0, -top), ldexp(1.0, -top), cutoff, nlimbs);
+            if (bound < 1.0)
+                break;
+            top++;
+        }
+        /* Tolerance: predicted error plus 1.01 ulp (presumably the reference's own error), rounded up. */
+        classical_precise_error = 1.01;
+        maxerr = (ulong) (error + classical_precise_error + 1.0);
+
+        A = flint_malloc((nlimbs + 1) * (m * n) * sizeof(ulong));
+        B = flint_malloc((nlimbs + 1) * (n * p) * sizeof(ulong));
+        C = flint_malloc((nlimbs + 1) * (m * p) * sizeof(ulong));
+        D = flint_malloc((nlimbs + 1) * (m * p) * sizeof(ulong));
+        t = flint_malloc((nlimbs + 1) * sizeof(ulong));
+
+        for (i = 0; i < m * n; i++)
+        {
+            a = A + i * (nlimbs + 1);
+            a[0] = n_randint(state, 2);
+            flint_mpn_rrandom(a + 1, state, nlimbs);
+            a[nlimbs] >>= top;   /* shrink the last limb by top bits (assumed most significant) */
+        }
+
+        for (i = 0; i < n * p; i++)
+        {
+            a = B + i * (nlimbs + 1);
+            a[0] = n_randint(state, 2);
+            flint_mpn_rrandom(a + 1, state, nlimbs);
+            a[nlimbs] >>= top;
+        }
+
+        for (i = 0; i < m * p; i++)
+        {
+            a = C + i * (nlimbs + 1);
+            a[0] = n_randint(state, 2);
+            flint_mpn_rrandom(a + 1, state, nlimbs);
+
+            a = D + i * (nlimbs + 1);
+            a[0] = n_randint(state, 2);
+            flint_mpn_rrandom(a + 1, state, nlimbs);
+        }
+
+        _nfixed_mat_mul_classical_precise(C, A, B, m, n, p, nlimbs);   /* precise reference, matching the tolerance above */
+        _nfixed_mat_mul_strassen(D, A, B, m, n, p, cutoff, nlimbs);
+
+        for (i = 0; i < m * p; i++)
+        {
+            nfixed_sub(t, C + i * (nlimbs + 1), D + i * (nlimbs + 1), nlimbs);
+
+            if (!flint_mpn_zero_p(t + 2, nlimbs - 1) || t[1] > maxerr)
+            {
+                TEST_FUNCTION_FAIL("nlimbs = %wd, m = %wd, n = %wd, p = %wd\n\nt = %{ulong*}, maxerr = %wu\n\nA = %{ulong*}\n\nB = %{ulong*}\n\nC = %{ulong*}\n\nD = %{ulong*}\n\n",
+                    nlimbs, m, n, p,
+                    t, nlimbs + 1, maxerr, A, m * n * (nlimbs + 1), B, n * p * (nlimbs + 1), C, m * p * (nlimbs + 1), D, m * p * (nlimbs + 1));
+            }
+        }
+
+        flint_free(A);
+        flint_free(B);
+        flint_free(C);
+        flint_free(D);
+        flint_free(t);   /* scratch difference buffer, allocated every iteration */
+    }
+
+    TEST_FUNCTION_END(state);
+}
\ No newline at end of file
diff --git a/src/nfloat/test/t-nfixed_mat_mul_waksman.c b/src/nfloat/test/t-nfixed_mat_mul_waksman.c
new file mode 100644
index 0000000000..2a1495a43d
--- /dev/null
+++ b/src/nfloat/test/t-nfixed_mat_mul_waksman.c
@@ -0,0 +1,110 @@
+/*
+ Copyright (C) 2024 Fredrik Johansson
+
+ This file is part of FLINT.
+
+ FLINT is free software: you can redistribute it and/or modify it under
+ the terms of the GNU Lesser General Public License (LGPL) as published
+ by the Free Software Foundation; either version 3 of the License, or
+ (at your option) any later version. See <https://www.gnu.org/licenses/>.
+*/
+
+#include "test_helpers.h"
+#include "double_extras.h"
+#include "fmpq.h"
+#include "arf.h"
+#include "gr_vec.h"
+#include "gr_special.h"
+#include "nfloat.h"
+
+TEST_FUNCTION_START(nfixed_mat_mul_waksman, state)
+{
+    /* Compare _nfixed_mat_mul_waksman with the extra-limb reference, allowing
+       the error predicted by _nfixed_mat_mul_bound_waksman. */
+    slong iter, m, n, p, i, nlimbs;
+    nn_ptr A, B, C, D, t, a;
+    slong MAXN = 20;
+    slong MINLIMBS = 2;
+    slong MAXLIMBS = 12;
+
+    for (iter = 0; iter < 1000 * flint_test_multiplier(); iter++)
+    {
+        m = 1 + n_randint(state, MAXN);
+        n = 1 + n_randint(state, MAXN);
+        p = 1 + n_randint(state, MAXN);
+
+        nlimbs = MINLIMBS + n_randint(state, MAXLIMBS - MINLIMBS + 1);
+
+        ulong maxerr;
+        int top;
+        double bound, error, classical_precise_error;
+        /* Grow top until inputs bounded by 2^-top give an intermediate bound < 1. */
+        top = 1;
+        while (1)
+        {
+            _nfixed_mat_mul_bound_waksman(&bound, &error, m, n, p, ldexp(1.0, -top), ldexp(1.0, -top), nlimbs);
+            if (bound < 1.0)
+                break;
+            top++;
+        }
+        /* Tolerance: predicted error plus 1.01 ulp (presumably the reference's own error), rounded up. */
+        classical_precise_error = 1.01;
+        maxerr = (ulong) (error + classical_precise_error + 1.0);
+
+        A = flint_malloc((nlimbs + 1) * (m * n) * sizeof(ulong));
+        B = flint_malloc((nlimbs + 1) * (n * p) * sizeof(ulong));
+        C = flint_malloc((nlimbs + 1) * (m * p) * sizeof(ulong));
+        D = flint_malloc((nlimbs + 1) * (m * p) * sizeof(ulong));
+        t = flint_malloc((nlimbs + 1) * sizeof(ulong));
+
+        for (i = 0; i < m * n; i++)
+        {
+            a = A + i * (nlimbs + 1);
+            a[0] = n_randint(state, 2);
+            flint_mpn_rrandom(a + 1, state, nlimbs);
+            a[nlimbs] >>= top;   /* shrink the last limb by top bits (assumed most significant) */
+        }
+
+        for (i = 0; i < n * p; i++)
+        {
+            a = B + i * (nlimbs + 1);
+            a[0] = n_randint(state, 2);
+            flint_mpn_rrandom(a + 1, state, nlimbs);
+            a[nlimbs] >>= top;
+        }
+
+        for (i = 0; i < m * p; i++)
+        {
+            a = C + i * (nlimbs + 1);
+            a[0] = n_randint(state, 2);
+            flint_mpn_rrandom(a + 1, state, nlimbs);
+
+            a = D + i * (nlimbs + 1);
+            a[0] = n_randint(state, 2);
+            flint_mpn_rrandom(a + 1, state, nlimbs);
+        }
+
+        _nfixed_mat_mul_classical_precise(C, A, B, m, n, p, nlimbs);
+        _nfixed_mat_mul_waksman(D, A, B, m, n, p, nlimbs);
+
+        for (i = 0; i < m * p; i++)
+        {
+            nfixed_sub(t, C + i * (nlimbs + 1), D + i * (nlimbs + 1), nlimbs);
+
+            if (!flint_mpn_zero_p(t + 2, nlimbs - 1) || t[1] > maxerr)
+            {
+                TEST_FUNCTION_FAIL("nlimbs = %wd, m = %wd, n = %wd, p = %wd\n\nt = %{ulong*}, maxerr = %wu\n\nA = %{ulong*}\n\nB = %{ulong*}\n\nC = %{ulong*}\n\nD = %{ulong*}\n\n",
+                    nlimbs, m, n, p,
+                    t, nlimbs + 1, maxerr, A, m * n * (nlimbs + 1), B, n * p * (nlimbs + 1), C, m * p * (nlimbs + 1), D, m * p * (nlimbs + 1));
+            }
+        }
+
+        flint_free(A);
+        flint_free(B);
+        flint_free(C);
+        flint_free(D);
+        flint_free(t);   /* scratch difference buffer, allocated every iteration */
+    }
+
+    TEST_FUNCTION_END(state);
+}
\ No newline at end of file