Skip to content

Commit

Permalink
Merge pull request #1743 from flintlib/dot
Browse files Browse the repository at this point in the history
Faster fmpz dot products
  • Loading branch information
fredrik-johansson authored Jan 25, 2024
2 parents f9fcca0 + 5f283d7 commit a176772
Show file tree
Hide file tree
Showing 43 changed files with 737 additions and 506 deletions.
2 changes: 1 addition & 1 deletion AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ Mathieu Gouttenoire

Primality testing for Gaussian integers.

github math-gout
github math-gout

Michael Abshoff

Expand Down
19 changes: 14 additions & 5 deletions doc/source/fmpz_vec.rst
Original file line number Diff line number Diff line change
Expand Up @@ -408,13 +408,22 @@ Gaussian content
Dot product
--------------------------------------------------------------------------------

.. function:: void _fmpz_vec_dot_general_naive(fmpz_t res, const fmpz_t initial, int subtract, const fmpz * a, const fmpz * b, int reverse, slong len)
void _fmpz_vec_dot_general(fmpz_t res, const fmpz_t initial, int subtract, const fmpz * a, const fmpz * b, int reverse, slong len)

Computes the dot product of the vectors *a* and *b*, setting
*res* to `s + (-1)^{subtract} \sum_{i=0}^{len-1} a_i b_i`.
The initial term *s* is given by *initial*; it is optional and can be
omitted by passing *NULL* (equivalently, `s = 0`).
The parameter *subtract* must be 0 or 1.
If the *reverse* flag is 1, the second vector is read in reverse order,
i.e. the sum becomes `\sum_{i=0}^{len-1} a_i b_{len-1-i}`.

Aliasing is allowed between ``res`` and ``initial`` but not
between ``res`` and the entries of ``a`` and ``b``.

The *naive* version is used for testing purposes.

.. function:: void _fmpz_vec_dot(fmpz_t res, const fmpz * vec1, const fmpz * vec2, slong len2)

Sets ``res`` to the dot product of ``(vec1, len2)`` and
``(vec2, len2)``.

.. function:: void _fmpz_vec_dot_ptr(fmpz_t res, const fmpz * vec1, fmpz ** const vec2, slong offset, slong len)

Sets ``res`` to the dot product of ``len`` values at ``vec1`` and the
``len`` values ``vec2[i] + offset`` for `0 \leq i < len`.
9 changes: 3 additions & 6 deletions src/arith/stirling1.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,15 @@
#include "arith.h"

/* Compute a single coefficient of a polynomial product: sets res to the
   coefficient of x^i in (poly1, len1) * (poly2, len2), i.e. the sum of
   poly1[i - top2 + j] * poly2[top2 - j] over 0 <= j <= top1 + top2 - i.
   NOTE(review): presumably requires 0 <= i <= len1 + len2 - 2 so that the
   sum has at least one term -- confirm against callers. */
static void
FLINT_FORCE_INLINE void
_fmpz_poly_mulmid_single(fmpz_t res, const fmpz * poly1, slong len1, const fmpz * poly2, slong len2, slong i)
{
slong j, top1, top2;
slong top1, top2;

/* top1, top2: largest indices into poly1 and poly2 that can contribute
   to coefficient i */
top1 = FLINT_MIN(len1 - 1, i);
top2 = FLINT_MIN(len2 - 1, i);

/* explicit form: the j = 0 term initialises res ... */
fmpz_mul(res, poly1 + i - top2, poly2 + top2);

/* ... and the remaining terms are accumulated one by one; poly1 is walked
   upwards while poly2 is walked downwards */
for (j = 1; j < top1 + top2 - i + 1; j++)
fmpz_addmul(res, poly1 + i - top2 + j, poly2 + top2 - j);
/* dot-product form of the same sum: no initial term (NULL), no
   subtraction (0), second operand reversed (1). poly2 + i - top1 is the
   lowest poly2 index used and top1 + top2 - i + 1 is the number of terms. */
_fmpz_vec_dot_general(res, NULL, 0, poly1 + i - top2, poly2 + i - top1, 1, top1 + top2 - i + 1);
}

#define MAX_BASECASE 16
Expand Down
9 changes: 3 additions & 6 deletions src/fmpq_poly/exp_series.c
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ _fmpq_poly_exp_series_basecase_deriv(fmpz * B, fmpz_t Bden,
const fmpz * Aprime, const fmpz_t Aden, slong Alen, slong n)
{
fmpz_t t, u;
slong j, k;
slong k;

Alen = FLINT_MIN(Alen, n);

Expand All @@ -55,11 +55,8 @@ _fmpq_poly_exp_series_basecase_deriv(fmpz * B, fmpz_t Bden,

for (k = 1; k < n; k++)
{
fmpz_mul(t, Aprime, B + k - 1);

for (j = 2; j < FLINT_MIN(Alen, k + 1); j++)
fmpz_addmul(t, Aprime + j - 1, B + k - j);

slong l = FLINT_MIN(Alen - 1, k);
_fmpz_vec_dot_general(t, NULL, 0, Aprime, B + k - l, 1, l);
fmpz_mul_ui(u, Aden, k);
fmpz_divexact(B + k, t, u);
}
Expand Down
22 changes: 7 additions & 15 deletions src/fmpq_poly/power_sums.c
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
(at your option) any later version. See <https://www.gnu.org/licenses/>.
*/

#include "fmpz_vec.h"
#include "fmpz_poly.h"
#include "fmpq.h"
#include "fmpq_poly.h"
Expand Down Expand Up @@ -61,29 +62,20 @@ _fmpq_poly_power_sums(fmpz * res, fmpz_t rden, const fmpz * poly, slong len,

for (k = 1; k < FLINT_MIN(n, len); k++)
{
fmpz_mul_ui(res + k, poly + len - 1 - k, k);
fmpz_mul_si(res + k, poly + len - 1 - k, -k);
fmpz_mul(res + k, res + k, rden);

for (i = 1; i < k - 1; i++)
fmpz_mul(res + i, res + i, poly + len - 1);
for (i = 1; i < k; i++)
fmpz_addmul(res + k, poly + len - 1 - k + i, res + i);
fmpz_neg(res + k, res + k);
_fmpz_vec_scalar_mul_fmpz(res + 1, res + 1, k - 2, poly + len - 1);
_fmpz_vec_dot_general(res + k, res + k, 1, poly + len - 1 - k + 1, res + 1, 0, k - 1);
fmpz_mul(rden, rden, poly + len - 1);
}

for (k = len; k < n; k++)
{
fmpz_zero(res + k);
for (i = k - len + 1; i < k - 1; i++)
fmpz_mul(res + i, res + i, poly + len - 1);
for (i = k - len + 1; i < k; i++)
fmpz_addmul(res + k, poly + len - 1 - k + i, res + i);
fmpz_neg(res + k, res + k);
_fmpz_vec_scalar_mul_fmpz(res + k - len + 1, res + k - len + 1, len - 2, poly + len - 1);
_fmpz_vec_dot_general(res + k, NULL, 1, poly, res + k - len + 1, 0, len - 1);
}

for (i = n - len + 1; i < n - 1; i++)
fmpz_mul(res + i, res + i, poly + len - 1);
_fmpz_vec_scalar_mul_fmpz(res + n - len + 1, res + n - len + 1, len - 2, poly + len - 1);
fmpz_one(rden);

for (i = n - len; i > 0; i--)
Expand Down
13 changes: 9 additions & 4 deletions src/fmpq_poly/power_sums_to_poly.c
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

#include "ulong_extras.h"
#include "fmpz.h"
#include "fmpz_vec.h"
#include "fmpz_poly.h"
#include "fmpq_poly.h"

Expand All @@ -30,12 +31,16 @@ _fmpq_poly_power_sums_to_poly(fmpz * res, const fmpz * poly, const fmpz_t den,
fmpz_one(f);
for (k = 1; k <= d; k++)
{
if(k < len)
if (k < len)
{
fmpz_mul(res + d - k, poly + k, f);
_fmpz_vec_dot_general(res + d - k, res + d - k, 0, res + d - k + 1, poly + 1, 0, k - 1);

}
else
fmpz_zero(res + d - k);
for (i = 1; i < FLINT_MIN(k, len); i++)
fmpz_addmul(res + d - k, res + d - k + i, poly + i);
{
_fmpz_vec_dot_general(res + d - k, NULL, 0, res + d - k + 1, poly + 1, 0, len - 1);
}

a = n_gcd(FLINT_ABS(fmpz_fdiv_ui(res + d - k, k)), k);
fmpz_divexact_ui(res + d - k, res + d - k, a);
Expand Down
13 changes: 3 additions & 10 deletions src/fmpq_poly/revert_series_lagrange_fast.c
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ void
_fmpq_poly_revert_series_lagrange_fast(fmpz * Qinv, fmpz_t den,
const fmpz * Q, const fmpz_t Qden, slong Qlen, slong n)
{
slong i, j, k, m;
slong i, j, m;
fmpz *R, *Rden, *S, *T, *dens, *tmp;
fmpz_t Sden, Tden, t;
fmpz_t Sden, Tden;

if (Qlen <= 2)
{
Expand All @@ -65,7 +65,6 @@ _fmpq_poly_revert_series_lagrange_fast(fmpz * Qinv, fmpz_t den,

m = n_sqrt(n);

fmpz_init(t);
dens = _fmpz_vec_init(n);
R = _fmpz_vec_init((n - 1) * m);
S = _fmpz_vec_init(n - 1);
Expand Down Expand Up @@ -103,12 +102,7 @@ _fmpq_poly_revert_series_lagrange_fast(fmpz * Qinv, fmpz_t den,

for (j = 1; j < m && i + j < n; j++)
{
fmpz_mul(t, S + 0, Ri(j) + i + j - 1);

for (k = 1; k <= i + j - 1; k++)
fmpz_addmul(t, S + k, Ri(j) + i + j - 1 - k);

fmpz_set(Qinv + i + j, t);
_fmpz_vec_dot_general(Qinv + i + j, NULL, 0, S, Ri(j), 1, i + j);
fmpz_mul(dens + i + j, Sden, Rdeni(j));
fmpz_mul_ui(dens + i + j, dens + i + j, i + j);
}
Expand All @@ -126,7 +120,6 @@ _fmpq_poly_revert_series_lagrange_fast(fmpz * Qinv, fmpz_t den,
_set_vec(Qinv, den, Qinv, dens, n);
_fmpq_poly_canonicalise(Qinv, den, n);

fmpz_clear(t);
_fmpz_vec_clear(dens, n);
_fmpz_vec_clear(R, (n - 1) * m);
_fmpz_vec_clear(S, n - 1);
Expand Down
1 change: 1 addition & 0 deletions src/fmpq_poly/sin_cos_series.c
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ _fmpq_poly_sin_cos_series_basecase_can(fmpz * S, fmpz_t Sden,
fmpz_zero(t);
fmpz_zero(u);

/* todo: precompute A[j] * j, use dot products */
for (j = 1; j < FLINT_MIN(Alen, k + 1); j++)
{
fmpz_mul_ui(v, A + j, j);
Expand Down
3 changes: 1 addition & 2 deletions src/fmpz_mat.h
Original file line number Diff line number Diff line change
Expand Up @@ -174,8 +174,7 @@ void fmpz_mat_mul_classical(fmpz_mat_t C, const fmpz_mat_t A,

void fmpz_mat_mul_strassen(fmpz_mat_t C, const fmpz_mat_t A, const fmpz_mat_t B);

void fmpz_mat_mul_classical_inline(fmpz_mat_t C, const fmpz_mat_t A,
const fmpz_mat_t B);
#define fmpz_mat_mul_classical_inline _Pragma("GCC error \"'fmpz_mat_mul_classical_inline' is deprecated. Use 'fmpz_mat_mul_classical' instead.\"")

void _fmpz_mat_mul_fft(fmpz_mat_t C,
const fmpz_mat_t A, slong abits,
Expand Down
77 changes: 5 additions & 72 deletions src/fmpz_mat/charpoly.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
#include "fmpz_vec.h"
#include "fmpz_mat.h"
#include "fmpz_poly.h"
#include "gr.h"
#include "gr_mat.h"

/*
Assumes that \code{mat} is an $n \times n$ matrix and sets \code{(cp,n+1)}
Expand All @@ -27,78 +29,9 @@

/* Writes n + 1 coefficients to cp, computed from the square matrix mat,
   where n = mat->r. The body below contains both an explicit fmpz
   implementation and a call that delegates to the generic
   _gr_mat_charpoly_berkowitz over an fmpz context. */
void _fmpz_mat_charpoly_berkowitz(fmpz *cp, const fmpz_mat_t mat)
{
const slong n = mat->r;

if (n == 0)
{
/* empty matrix: the result is the constant polynomial 1 */
fmpz_one(cp);
}
else if (n == 1)
{
/* 1 x 1 matrix: cp = x - mat[0][0] */
fmpz_neg(cp + 0, fmpz_mat_entry(mat, 0, 0));
fmpz_one(cp + 1);
}
else
{
slong i, j, k, t;
fmpz *a, *A, *s;

/* a holds n * n scratch entries, addressed as rows of length n;
   A points at the last n of those entries */
a = _fmpz_vec_init(n * n);
A = a + (n - 1) * n;

/* start from cp[0] = -mat[0][0], all other entries zero */
_fmpz_vec_zero(cp, n + 1);
fmpz_neg(cp + 0, fmpz_mat_entry(mat, 0, 0));

for (t = 1; t < n; t++)
{
/* row 0 of a := entries 0..t of column t of mat */
for (i = 0; i <= t; i++)
{
fmpz_set(a + 0 * n + i, fmpz_mat_entry(mat, i, t));
}

fmpz_set(A + 0, fmpz_mat_entry(mat, t, t));

/* row k of a := (rows/columns 0..t of mat) times row k - 1 of a;
   A[k] records entry t of the new row */
for (k = 1; k < t; k++)
{
for (i = 0; i <= t; i++)
{
s = a + k * n + i;
fmpz_zero(s);
for (j = 0; j <= t; j++)
{
fmpz_addmul(s, fmpz_mat_entry(mat, i, j), a + (k - 1) * n + j);
}
}
fmpz_set(A + k, a + k * n + t);
}

/* A[t] := row t of mat (columns 0..t) dotted with row t - 1 of a */
fmpz_zero(A + t);
for (j = 0; j <= t; j++)
{
fmpz_addmul(A + t, fmpz_mat_entry(mat, t, j), a + (t - 1) * n + j);
}

/* fold A into cp in place:
   cp[k] -= sum_{j < k} A[j] * cp[k - j - 1], then cp[k] -= A[k] */
for (k = 0; k <= t; k++)
{
for (j = 0; j < k; j++)
{
fmpz_submul(cp + k, A + j, cp + (k - j - 1));
}
fmpz_sub(cp + k, cp + k, A + k);
}
}

/* Shift all coefficients up by one */
for (i = n; i > 0; i--)
{
fmpz_swap(cp + i, cp + (i - 1));
}
/* the vacated slot 0 receives 1 */
fmpz_one(cp + 0);

/* reverse the n + 1 coefficients in place */
_fmpz_poly_reverse(cp, cp, n + 1, n + 1);

/* release the scratch entries */
_fmpz_vec_clear(a, n * n);
}
/* generic-ring path: initialise an fmpz context and delegate to
   _gr_mat_charpoly_berkowitz, with the status wrapped in GR_MUST_SUCCEED.
   NOTE(review): the cast assumes fmpz_mat_t can be viewed as a
   gr_mat_struct -- confirm the layouts match.
   NOTE(review): no gr_ctx_clear call is visible here -- confirm none is
   required for the fmpz context. */
gr_ctx_t ctx;
gr_ctx_init_fmpz(ctx);
GR_MUST_SUCCEED(_gr_mat_charpoly_berkowitz(cp, (const gr_mat_struct *) mat, ctx));
}

void fmpz_mat_charpoly_berkowitz(fmpz_poly_t cp, const fmpz_mat_t mat)
Expand Down
2 changes: 1 addition & 1 deletion src/fmpz_mat/mul.c
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,6 @@ fmpz_mat_mul(fmpz_mat_t C, const fmpz_mat_t A, const fmpz_mat_t B)
else if (abits >= 500 && bbits >= 500 && dim >= 8) /* tuning param */
fmpz_mat_mul_strassen(C, A, B);
else
fmpz_mat_mul_classical_inline(C, A, B);
fmpz_mat_mul_classical(C, A, B);
}
}
Loading

0 comments on commit a176772

Please sign in to comment.