/* spv.c - "small prime vector" functions for arithmetic on vectors of
residues modulo a single small prime
Copyright 2005 Dave Newman.
The SP Library is free software; you can redistribute it and/or modify
it under the terms of the GNU Lesser General Public License as published by
the Free Software Foundation; either version 2.1 of the License, or (at your
option) any later version.
The SP Library is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public
License for more details.
You should have received a copy of the GNU Lesser General Public License
along with the SP Library; see the file COPYING.LIB. If not, write to
the Free Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
MA 02111-1307, USA.
*/
#include <string.h> /* for memset */
#include "sp.h"
/* Routines for vectors of integers modulo r common small prime
*
* These are low-overhead routines that don't do memory allocation,
* other than for temporary variables. Unless otherwise specified, any
* of the input pointers can be equal. */
/* r = x */
void
spv_set (spv_t r, spv_t x, spv_size_t len)
{
#ifdef HAVE_MEMMOVE
/* memmove doesn't rely on the assertion below */
memmove (r, x, len * sizeof (sp_t));
#else
spv_size_t i;
ASSERT (r >= x + len || x >= r);
for (i = 0; i < len; i++)
r[i] = x[i];
#endif
}
/* r = [y, y, ... ] */
void
spv_set_sp (spv_t r, sp_t y, spv_size_t len)
{
spv_size_t i;
for (i = 0; i < len; i++)
r[i] = y;
}
void
spv_set_zero (spv_t r, spv_size_t len)
{
memset (r, 0, len * sizeof (sp_t));
}
int
spv_cmp (spv_t x, spv_t y, spv_size_t len)
{
spv_size_t i;
for (i = 0; i < len; i++)
if (x[i] != y[i])
return 1;
return 0;
}
/* r = x + y */
void
spv_add (spv_t r, spv_t x, spv_t y, spv_size_t len, sp_t m)
{
spv_size_t i;
ASSERT (r >= x + len || x >= r);
ASSERT (r >= y + len || y >= r);
for (i = 0; i < len; i++)
r[i] = sp_add (x[i], y[i], m);
}
/* r = [x[0] + y, x[1] + y, ... ] */
void
spv_add_sp (spv_t r, spv_t x, sp_t c, spv_size_t len, sp_t m)
{
spv_size_t i;
for (i = 0; i < len; i++)
r[i] = sp_add (x[i], c, m);
}
/* r = x - y */
void
spv_sub (spv_t r, spv_t x, spv_t y, spv_size_t len, sp_t m)
{
spv_size_t i;
ASSERT (r >= x + len || x >= r);
ASSERT (r >= y + len || y >= r);
for (i = 0; i < len; i++)
r[i] = sp_sub (x[i], y[i], m);
}
/* r = [x[0] - y, x[1] - y, ... ] */
void
spv_sub_sp (spv_t r, spv_t x, sp_t c, spv_size_t len, sp_t m)
{
spv_size_t i;
for (i = 0; i < len; i++)
r[i] = sp_sub (x[i], c, m);
}
/* r = [-x[0], -x[1], ... ] */
void
spv_neg (spv_t r, spv_t x, spv_size_t len, sp_t m)
{
spv_size_t i;
for (i = 0; i < len; i++)
r[i] = sp_sub (0, x[i], m);
}
/* Pointwise multiplication
* r = [x[0] * y[0], x[1] * y[1], ... ] */
void
spv_pwmul (spv_t r, spv_t x, spv_t y, spv_size_t len, sp_t m, sp_t d)
{
spv_size_t i;
ASSERT (r >= x + len || x >= r);
ASSERT (r >= y + len || y >= r);
for (i = 0; i < len; i++)
r[i] = sp_mul (x[i], y[i], m, d);
}
/* dst = src * y */
void
spv_mul_sp (spv_t r, spv_t x, sp_t c, spv_size_t len, sp_t m, sp_t d)
{
spv_size_t i;
ASSERT (r >= x + len || x >= r);
for (i = 0; i < len; i++)
r[i] = sp_mul (x[i], c, m, d);
}
#if 0
/* r += x * y */
void
spv_addmul_sp (spv_t r, spv_t x, sp_t c, spv_size_t len, sp_t m, sp_t d)
{
spv_size_t i;
sp_t t;
ASSERT (r >= x + len || x >= r);
for (i = 0; i < len; i++)
{
t = sp_mul (x[i], c, m, d);
r[i] = sp_add (r[i], t, m);
}
}
/* r -= x * y */
void
spv_submul_sp (spv_t r, spv_t x, sp_t c, spv_size_t len, sp_t m, sp_t d)
{
spv_size_t i;
sp_t t;
ASSERT (r >= x + len || x >= r);
for (i = 0; i < len; i++)
{
t = sp_mul (x[i], c, m, d);
r[i] = sp_sub (x[i], t, m);
}
}
/* r = x * y by grammar-school polynomial multiplication
*
* r must be distinct from both x and y
* x_len > 0, y_len > 0 */
void
spv_mul_basecase (spv_t r, spv_t x, spv_t y, spv_size_t x_len,
spv_size_t y_len, sp_t m, sp_t d)
{
spv_size_t i;
ASSERT (r >= x + x_len || x >= r + x_len + y_len - 1);
ASSERT (r >= y + y_len || y >= r + x_len + y_len - 1);
ASSERT (x_len > 0);
ASSERT (y_len > 0);
if (x_len > y_len)
{
spv_mul_sp (r, x, y[0], x_len, m, d);
spv_set_sp (r + x_len, 0, y_len - 1);
for (i = 1; i < y_len; i++)
spv_addmul_sp (r + i, x, y[i], x_len, m, d);
}
else
{
spv_mul_sp (r, y, x[0], y_len, m, d);
spv_set_sp (r + y_len, 0, x_len - 1);
for (i = 1; i < x_len; i++)
spv_addmul_sp (r + i, y, x[i], y_len, m, d);
}
}
/* dst = src1 * src2 by karatsuba multiplication
*
* dst must be distinct from both src1 and src2
* src1 and src2 have the same len > 0
* t is a temporary array which must be at least M(len) large, where
*
* M(1)=0, M(K) = max(3*l-1,2*l-2+M(l)) <= 2*K-1 where l = ceil(K/2).
*
* Code adapted from the mpz_t karatsuba in gmp-ecm, in which multiplies
* are substantially more expensive than additions. This is also true for
* sp's, but to a lesser extent, so it might be the case that minimising the
* total number of operations is more important than minimising the number
* of muls. */
void
spv_mul_karatsuba (spv_t r, spv_t x, spv_t y, spv_t t, spv_size_t len,
sp_t m, sp_t d)
{
spv_size_t i, k, l;
spv_t z;
ASSERT (r >= x + len || x >= r + 2 * len - 1);
ASSERT (r >= y + len || y >= r + 2 * len - 1);
ASSERT (len > 0);
/* FIXME: add assertions for t */
if (len == 1)
{
r[0] = sp_mul (x[0], y[0], m, d);
return;
}
if (len == 2)
{
t[0] = sp_add (x[0], x[1], m);
r[1] = sp_add (y[0], y[1], m);
r[1] = sp_mul (r[1], t[0], m, d);
r[0] = sp_mul (x[0], y[0], m, d);
r[2] = sp_mul (x[1], y[1], m, d);
r[1] = sp_sub (r[1], r[0], m);
r[1] = sp_sub (r[1], r[2], m);
return;
}
if (len == 3)
{
r[0] = sp_mul (x[0], y[0], m, d);
r[2] = sp_mul (x[1], y[1], m, d);
r[4] = sp_mul (x[2], y[2], m, d);
t[0] = sp_add (x[0], x[1], m);
t[1] = sp_add (y[0], y[1], m);
r[1] = sp_mul (t[0], t[1], m, d);
r[1] = sp_sub (r[1], r[0], m);
r[1] = sp_sub (r[1], r[2], m);
t[0] = sp_add (x[1], x[2], m);
t[1] = sp_add (y[1], y[2], m);
r[3] = sp_mul (t[0], t[1], m, d);
r[3] = sp_sub (r[3], r[2], m);
r[3] = sp_sub (r[3], r[4], m);
t[0] = sp_add (x[0], x[2], m);
t[1] = sp_add (y[0], y[2], m);
t[2] = sp_mul (t[0], t[1], m, d);
t[2] = sp_sub (t[2], r[0], m);
t[2] = sp_sub (t[2], r[4], m);
r[2] = sp_add (r[2], t[2], m);
return;
}
k = len / 2;
l = len - k;
z = t + 2 * l - 1;
for (i = 0; i < k; i++)
{
z[i] = sp_sub (x[i], x[l + i], m);
r[i] = sp_sub (y[i], y[l + i], m);
}
if (l > k)
{
z[k] = x[k];
r[k] = y[k];
}
spv_mul_karatsuba (t, z, r, r + l, l, m, d);
z = t + 2 * l - 2;
r[2 * l - 1] = t[2 * l - 2];
spv_mul_karatsuba (r, x, y, z, l, m, d);
spv_mul_karatsuba (r + 2 * l, x + l, y + l, z, k, m, d);
t[2 * l - 2] = r[2 * l - 1];
r[2 * l - 1] = 0;
spv_add (r + 2 * l, r + 2 * l, r + l, l - 1, m);
if (k > 1)
{
spv_add (r + l, r + 2 * l, r, l, m);
spv_add (r + 2 * l, r + 2 * l, r + 3 * l, 2 * k - 1 - l, m);
}
else
{
r[l] = sp_add (r[2 * l], r[0], m);
if (len == 3)
r[l + 1] = r[1];
}
spv_sub (r + l, r + l, t, 2 * l - 1, m);
}
/* calculate r[k], ..., r[l - 1] of the product r = x * y
*
* other coeffs in r are undefined
* r has size at least the next power of two >= prod_len
* allow l == 0 if the full product is required */
void spv_mul (spv_t r, spv_t x, spv_size_t x_len, spv_t y,
spv_size_t y_len, spv_size_t k, spv_size_t l, int monic, spm_t spm)
{
/* FIXME: add assertions */
if (l == 0)
l = x_len + y_len - 1 + monic;
ASSERT (x_len > 0);
ASSERT (y_len > 0);
ASSERT (monic == 0 || monic == 1);
ASSERT (k < l);
ASSERT (l <= x_len + y_len - 1 + monic);
x_len = MIN (l, x_len);
y_len = MIN (l, y_len);
int square = (x == y && x_len == y_len);
int equal_op;
spv_size_t prod_len = x_len + y_len - 1 + monic;
spv_size_t max_len = MAX (x_len, y_len);
/* ensure either x != r != y or r == x */
if (r == y && r != x)
{
spv_t t = y;
spv_size_t t_len = y_len;
y = x; y_len = x_len;
x = t; x_len = t_len;
}
equal_op = (r == x);
if (prod_len < MUL_NTT_THRESHOLD)
{
spv_t t = (spv_t) malloc (2 * max_len * sizeof (sp_t));
/* the original contents of x, y, these are respectively
* either x, y themselves or copies */
spv_t x0, y0;
/* we cannot rely on x or y being large enough to allow us
* to zero-pad them to max_len so if x_len != y_len then
* we copy the smaller vector */
if (equal_op || x_len < y_len)
{
/* karatsuba implementation requires r != x, r != y */
x0 = (spv_t) malloc (max_len * sizeof (sp_t));
spv_set (x0, x, x_len);
spv_set_zero (x0 + x_len, max_len - x_len);
}
else
x0 = x;
if (y_len < x_len)
{
y0 = (spv_t) malloc (max_len * sizeof (sp_t));
spv_set (y0, y, y_len);
spv_set_zero (y0 + y_len, max_len - y_len);
}
else
y0 = square ? x0 : y;
spv_mul_karatsuba (r, x0, y0, t, max_len, spm->sp, spm->mul_c);
free (t);
if (monic)
{
/* FIXME crop to range [k, l) */
r[prod_len - 1] = 0;
spv_add (r + x_len, r + x_len, y0, y_len, spm->sp);
spv_add (r + y_len, r + y_len, x0, x_len, spm->sp);
}
if (equal_op || x_len < y_len)
free (x0);
if (y_len < x_len)
free (y0);
return;
}
spv_t x_ntt, y_ntt;
spv_size_t i, ntt_size;
ntt_size = 1 << ceil_log_2 (prod_len);
/* threshold seems to give reasonably
* good results but needs fine-tuning properly */
if (prod_len < 3 * ntt_size / 4 /* && ntt_size / 2 >= x_len && ntt_size / 2 >= y_len */)
{
ntt_size >>= 1;
if (prod_len - ntt_size > k)
/* high/middle coeffs overflow into
* middle/low coeffs */
spv_mul (r + ntt_size,
x, prod_len - ntt_size,
y, prod_len - ntt_size,
0, prod_len - ntt_size, 0, spm);
else if (l > ntt_size)
/* high/middle coeffs overflow into low coeffs */
spv_mul (r + ntt_size,
x, l - ntt_size,
y, l - ntt_size,
0, l - ntt_size, 0, spm);
/* else */
/* high coeffs overflow into low coeffs */
}
/* printf ("x_len = %u, y_len = %u, k = %u, l = %u, monic = %u, prod_len = %u, ntt_size = %u\n",
x_len, y_len, k, l, monic, prod_len, ntt_size); */
/* variable x_ntt only exists to help readability */
x_ntt = r;
if (!equal_op)
spv_set (x_ntt, x, x_len);
spv_set_zero (x_ntt + x_len, ntt_size - x_len);
if (square)
{
if (monic)
x_ntt[x_len] = 1;
spv_sqr_ntt_gfp (r, x_ntt, ntt_size, spm);
}
else
{
y_ntt = (spv_t) malloc (ntt_size * sizeof (sp_t));
spv_set (y_ntt, y, y_len);
spv_set_zero (y_ntt + y_len, ntt_size - y_len);
if (monic)
x_ntt[x_len] = y_ntt[y_len] = 1;
spv_mul_ntt_gfp (r, x_ntt, y_ntt, ntt_size, spm);
free (y_ntt);
}
if (prod_len > k + ntt_size)
{
/* FIXME maybe deal with these cases separately */
/* ####............
* ######## */
/* ############....
* ....######## */
/* ########........
* ........#### */
/* ....########....
* ........ */
for (i = 0; i < prod_len - ntt_size; i++)
{
sp_t u = r[ntt_size + i];
r[ntt_size + i] = sp_sub (r[i], u, spm->sp);
r[i] = u;
}
}
else if (l > ntt_size)
{
/* ########........
* #### */
for (i = 0; i < l - ntt_size; i++)
r[ntt_size + i] = sp_sub (r[i], r[ntt_size + i], spm->sp);
}
if (monic)
/* everything is correct except for the x^{prod_len} term which may
* have wrapped round */
r[prod_len % ntt_size] = sp_sub (r[prod_len % ntt_size], 1, spm->sp);
}
#endif
void
spv_random (spv_t x, spv_size_t len, sp_t m)
{
spv_size_t i;
mpn_random (x, len);
for (i = 0; i < len; i++)
if (x[i] >= m)
x[i] -= m;
}
syntax highlighted by Code2HTML, v. 0.9.1