/* mpzspm.c - "mpz small prime moduli" - pick a set of small primes large
   enough to represent a mpzv

  Copyright 2005 Dave Newman.

  The SP Library is free software; you can redistribute it and/or modify
  it under the terms of the GNU Lesser General Public License as published by
  the Free Software Foundation; either version 2.1 of the License, or (at your
  option) any later version.

  The SP Library is distributed in the hope that it will be useful, but
  WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
  or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU Lesser General Public
  License for more details.

  You should have received a copy of the GNU Lesser General Public License
  along with the SP Library; see the file COPYING.LIB.  If not, write to
  the Free Software Foundation, Inc., 59 Temple Place - Suite 330, Boston,
  MA 02111-1307, USA.
*/

#include <stdio.h> /* for printf */
#include <stdlib.h>
#include "sp.h"

mpzspm_t
mpzspm_init (spv_size_t max_len, mpz_t modulus)
{
  unsigned int ub, i, j;
  mpz_t P, S, T;
  sp_t p, a;
  mpzspm_t mpzspm;
  
  mpzspm = (mpzspm_t) malloc (sizeof (__mpzspm_struct));
  
  /* Upper bound for the number of primes we need.
   * Let minp, maxp denote the min, max permissible prime,
   * S the sum of p_1, p_2, ..., p_ub,
   * P the product of p_1, p_2, ..., p_ub/
   * 
   * Choose ub s.t.
   *
   *     ub * log(minp) >= log(4 * max_len * modulus^2 * maxp^4)
   * 
   * =>  P >= minp ^ ub >= 4 * max_len * modulus^2 * maxp^4
   *                    >= 4 * max_len * modulus^2 * (ub * maxp)^2
   *                    >= 4 * max_len * modulus^2 * S^2
   * 
   * So we need at most ub primes to satisfy this condition. */
  
  ub = (2 + 2 * mpz_sizeinbase (modulus, 2) + ceil_log_2 (max_len) + \
      4 * SP_NUMB_BITS) / (SP_NUMB_BITS - 1);
  
  mpzspm->spm = (spm_t *) malloc (ub * sizeof (spm_t));
  mpzspm->sp_num = 0;

  /* product of primes selected so far */
  mpz_init_set_ui (P, 1);
  /* sum of primes selected so far */
  mpz_init (S);
  mpz_init (T); 
  mpz_mul (T, modulus, modulus);
  mpz_mul_ui (T, T, max_len);
  
  /* find primes congruent to 1 mod max_len so we can do
   * a ntt of size max_len */
  p = 1;
  do
    {
      do
        p -= max_len;
      while (!sp_prime(p));
      
      /* all primes must have top bit set */
      if (p < SP_MIN)
        {
	  printf ("not enough primes in interval\n");
	  return NULL;
	}
      
      mpzspm->spm[mpzspm->sp_num++] = spm_init (max_len, p);
      
      mpz_mul_ui (P, P, p);
      mpz_add_ui (S, S, p);

      /* we want P > 4 * max_len * (modulus * S)^2 */
      mpz_mul (T, S, modulus);
      mpz_mul (T, T, T);
      mpz_mul_ui (T, T, max_len);
      mpz_add (T, T, T);
      mpz_add (T, T, T);
    }
  while (mpz_cmp (P, T) <= 0);

  mpz_init_set (mpzspm->modulus, modulus);
  
  mpzspm->max_ntt_size = max_len;
  
  mpzspm->crt1 = (mpzv_t) malloc (mpzspm->sp_num * sizeof (mpz_t));
  mpzspm->crt2 = (mpzv_t) malloc ((mpzspm->sp_num + 2) * sizeof (mpz_t));
  mpzspm->crt3 = (spv_t) malloc (mpzspm->sp_num * sizeof (sp_t));
  mpzspm->crt4 = (spv_t *) malloc (mpzspm->sp_num * sizeof (spv_t));
  mpzspm->crt5 = (spv_t) malloc (mpzspm->sp_num * sizeof (sp_t));

  for (i = 0; i < mpzspm->sp_num; i++)
    mpzspm->crt4[i] = (spv_t) malloc (mpzspm->sp_num * sizeof (sp_t));
  
  for (i = 0; i < mpzspm->sp_num; i++)
    {
      p = mpzspm->spm[i]->sp;
      
      /* crt3[i] = (P / p)^{-1} mod p */
      mpz_fdiv_q_ui (T, P, p);
      a = mpz_fdiv_ui (T, p);
      mpzspm->crt3[i] = sp_inv (a, p, mpzspm->spm[i]->mul_c);
     
      /* crt1[i] = (P / p) mod modulus */
      mpz_init (mpzspm->crt1[i]);
      mpz_mod (mpzspm->crt1[i], T, modulus);

      /* crt4[i][j] = ((P / p[i]) mod modulus) mod p[j] */
      for (j = 0; j < mpzspm->sp_num; j++)
	mpzspm->crt4[j][i] = mpz_fdiv_ui (mpzspm->crt1[i], mpzspm->spm[j]->sp);
      
      /* crt5[i] = (-P mod modulus) mod p */
      mpz_mod (T, P, modulus);
      mpz_sub (T, modulus, T);
      mpzspm->crt5[i] = mpz_fdiv_ui (T, p);
    }
  
  mpz_set_ui (T, 0);
  
  for (i = 0; i < mpzspm->sp_num + 2; i++)
    {
      mpz_mod (T, T, modulus);
      mpz_init_set (mpzspm->crt2[i], T);
      mpz_sub (T, T, P);
    }
  
  mpz_clear (P);
  mpz_clear (S);
  mpz_clear (T);

  return mpzspm;
}

void mpzspm_clear (mpzspm_t mpzspm)
{
  unsigned int i;

  for (i = 0; i < mpzspm->sp_num; i++)
    {
      mpz_clear (mpzspm->crt1[i]);
      free (mpzspm->crt4[i]);
      spm_clear (mpzspm->spm[i]);
    }

  for (i = 0; i < mpzspm->sp_num + 2; i++)
    mpz_clear (mpzspm->crt2[i]);
  
  free (mpzspm->crt1);
  free (mpzspm->crt2);
  free (mpzspm->crt3);
  free (mpzspm->crt4);
  free (mpzspm->crt5);
  
  mpz_clear (mpzspm->modulus);
  free (mpzspm->spm);
  free (mpzspm);
}



syntax highlighted by Code2HTML, v. 0.9.1