#include <NTL/ZZX.h>

#include <NTL/new.h>

NTL_START_IMPL



// Returns a reference to a single shared ZZX instance that is
// default-constructed on first use and never modified by this function.
const ZZX& ZZX::zero()
{
   static ZZX zero_poly;

   return zero_poly;
}



// Converts the polynomial a into x: the coefficient vector a.rep is
// converted into x.rep, then x is normalized.
void conv(ZZ_pX& x, const ZZX& a)
{
   conv(x.rep, a.rep);
   // presumably needed because converted coefficients may become zero -- TODO confirm
   x.normalize();
}

// Converts the polynomial a into x: the coefficient vector a.rep is
// converted into x.rep, then x is normalized (trailing zero
// coefficients are stripped).
void conv(ZZX& x, const ZZ_pX& a)
{
   conv(x.rep, a.rep);
   x.normalize();
}


// Reads a polynomial from s: the coefficient vector is read directly into
// x.rep, then x is normalized so that any trailing zero coefficients in the
// input are stripped.  Returns the stream s.
istream& operator>>(istream& s, ZZX& x)
{
   s >> x.rep;
   x.normalize();
   return s;
}

// Writes a to s by writing its coefficient vector a.rep; returns whatever
// the vector output operator returns.
ostream& operator<<(ostream& s, const ZZX& a)
{
   return s << a.rep;
}


void ZZX::normalize()
{
   long n;
   const ZZ* p;

   n = rep.length();
   if (n == 0) return;
   p = rep.elts() + n;
   while (n > 0 && IsZero(*--p)) {
      n--;
   }
   rep.SetLength(n);
}


// Returns 1 if a has no coefficients stored (empty rep), 0 otherwise.
long IsZero(const ZZX& a)
{
   const long len = a.rep.length();
   return len == 0;
}


// Returns 1 if a consists of exactly one coefficient and that coefficient
// is one; returns 0 otherwise.
long IsOne(const ZZX& a)
{
   if (a.rep.length() != 1)
      return 0;

   return IsOne(a.rep[0]) != 0;
}

// Coefficient-wise equality: returns 1 if a and b have the same number of
// coefficients and every pair of corresponding coefficients is equal,
// 0 otherwise.
long operator==(const ZZX& a, const ZZX& b)
{
   const long len = a.rep.length();
   if (len != b.rep.length())
      return 0;

   const ZZ *lhs = a.rep.elts();
   const ZZ *rhs = b.rep.elts();

   for (long k = 0; k < len; k++, lhs++, rhs++) {
      if (*lhs != *rhs)
         return 0;
   }

   return 1;
}


// Compares the polynomial a with the integer constant b.
// For b == 0 this is the zero test; otherwise a must have degree 0 and
// its single coefficient must equal b.
long operator==(const ZZX& a, long b)
{
   if (b == 0) return IsZero(a);

   if (deg(a) == 0)
      return a.rep[0] == b;

   return 0;
}

// Compares the polynomial a with the integer constant b.
// For zero b this is the zero test; otherwise a must have degree 0 and
// its single coefficient must equal b.
long operator==(const ZZX& a, const ZZ& b)
{
   if (IsZero(b)) return IsZero(a);

   if (deg(a) == 0)
      return a.rep[0] == b;

   return 0;
}


// Stores the coefficient of X^i of a into x.
// Indices outside 0..deg(a) yield zero.
void GetCoeff(ZZ& x, const ZZX& a, long i)
{
   if (i >= 0 && i <= deg(a))
      x = a.rep[i];
   else
      clear(x);
}

// Sets the coefficient of X^i in x to a, growing x if i exceeds its
// current degree (intermediate coefficients are cleared), and normalizes
// the result.  Reports an error via Error() for negative i or when
// NTL_OVERFLOW(i, 1, 0) holds.
void SetCoeff(ZZX& x, long i, const ZZ& a)
{
   if (i < 0) 
      Error("SetCoeff: negative index");

   if (NTL_OVERFLOW(i, 1, 0))
      Error("overflow in SetCoeff");

   const long old_deg = deg(x);

   if (i <= old_deg) {
      // coefficient already exists: overwrite in place
      x.rep[i] = a;
      x.normalize();
      return;
   }

   /* careful: a may alias a coefficient of x */

   const long capacity = x.rep.allocated();

   if (capacity > 0 && i >= capacity) {
      // take a private copy of a before resizing past the current
      // allocation (presumably the resize may relocate the storage
      // that a refers to -- TODO confirm)
      ZZ saved = a;
      x.rep.SetLength(i+1);
      x.rep[i] = saved;
   }
   else {
      x.rep.SetLength(i+1);
      x.rep[i] = a;
   }

   // zero the newly exposed coefficients between the old top and i
   for (long k = old_deg+1; k < i; k++)
      clear(x.rep[k]);

   x.normalize();
}


// Sets the coefficient of X^i in x to one (via set()), growing x if i
// exceeds its current degree and clearing the coefficients in between.
// Reports an error via Error() for negative i or when
// NTL_OVERFLOW(i, 1, 0) holds.
void SetCoeff(ZZX& x, long i)
{
   if (i < 0) 
      Error("coefficient index out of range");

   if (NTL_OVERFLOW(i, 1, 0))
      Error("overflow in SetCoeff");

   const long old_deg = deg(x);

   if (old_deg < i) {
      x.rep.SetLength(i+1);

      // zero the newly exposed coefficients below position i
      for (long k = old_deg+1; k < i; k++)
         clear(x.rep[k]);
   }

   set(x.rep[i]);
   x.normalize();
}


// Sets x to the monomial X: x is cleared first, then the coefficient of
// X^1 is set to one.
void SetX(ZZX& x)
{
   clear(x);
   SetCoeff(x, 1);
}


// Returns 1 if a is exactly the monomial X (degree 1, leading coefficient
// one, constant term zero), 0 otherwise.
long IsX(const ZZX& a)
{
   if (deg(a) != 1)
      return 0;

   if (!IsOne(LeadCoeff(a)))
      return 0;

   return IsZero(ConstTerm(a)) != 0;
}
      
      

// Returns a reference to the coefficient of X^i of a, or to ZZ::zero()
// when i lies outside 0..deg(a).
const ZZ& coeff(const ZZX& a, long i)
{
   if (i >= 0 && i <= deg(a))
      return a.rep[i];

   return ZZ::zero();
}


// Returns a reference to the coefficient at index deg(a), or to
// ZZ::zero() when a is the zero polynomial.
const ZZ& LeadCoeff(const ZZX& a)
{
   if (!IsZero(a))
      return a.rep[deg(a)];

   return ZZ::zero();
}

// Returns a reference to the coefficient at index 0, or to ZZ::zero()
// when a is the zero polynomial.
const ZZ& ConstTerm(const ZZX& a)
{
   if (!IsZero(a))
      return a.rep[0];

   return ZZ::zero();
}



// Sets x to the constant polynomial a.  A zero a yields an empty
// coefficient vector; otherwise x holds the single coefficient a.
void conv(ZZX& x, const ZZ& a)
{
   if (IsZero(a)) {
      x.rep.SetLength(0);
      return;
   }

   x.rep.SetLength(1);
   x.rep[0] = a;
}


// Sets x to the constant polynomial a.  a == 0 yields an empty
// coefficient vector; otherwise x holds the single coefficient a.
void conv(ZZX& x, long a)
{
   if (a == 0) {
      x.rep.SetLength(0);
      return;
   }

   x.rep.SetLength(1);
   conv(x.rep[0], a);
}


// Builds x from the coefficient vector a: the vector is copied into x.rep
// and x is then normalized, so trailing zero entries of a are dropped.
void conv(ZZX& x, const vec_ZZ& a)
{
   x.rep = a;
   x.normalize();
}


// x = a + b.
// The alias checks (&x != &a, &x != &b) below allow x to be the same
// object as a or b.
void add(ZZX& x, const ZZX& a, const ZZX& b)
{
   long da = deg(a);
   long db = deg(b);
   long minab = min(da, db);
   long maxab = max(da, db);
   // make room for maxab+1 coefficients; element pointers are fetched
   // only after this resize
   x.rep.SetLength(maxab+1);

   long i;
   const ZZ *ap, *bp; 
   ZZ* xp;

   // add the minab+1 low-order coefficients present in both operands
   for (i = minab+1, ap = a.rep.elts(), bp = b.rep.elts(), xp = x.rep.elts();
        i; i--, ap++, bp++, xp++)
      add(*xp, (*ap), (*bp));

   // copy the remaining high-order coefficients of the longer operand;
   // skipped when x is that operand, since x.rep is then the source itself
   if (da > minab && &x != &a)
      for (i = da-minab; i; i--, xp++, ap++)
         *xp = *ap;
   else if (db > minab && &x != &b)
      for (i = db-minab; i; i--, xp++, bp++)
         *xp = *bp;
   else
      // reached when da == db (top coefficients may have cancelled) or
      // when x aliases the longer operand
      x.normalize();
}

// x = a + b, where b is an integer constant added to the constant term.
// Only coefficient 0 changes; the remaining coefficients of a are copied
// into x.  The branches differ in how they cope with aliasing.
void add(ZZX& x, const ZZX& a, const ZZ& b)
{
   long n = a.rep.length();
   if (n == 0) {
      // a has no coefficients: the result is just the constant b
      conv(x, b);
   }
   else if (&x == &a) {
      // in place: only the constant term needs updating
      add(x.rep[0], a.rep[0], b);
      x.normalize();
   }
   else if (x.rep.MaxLength() == 0) {
      // NOTE(review): MaxLength() == 0 presumably means x has never held
      // any coefficients, so b cannot alias one and a plain copy is safe
      // -- confirm against the vector documentation
      x = a;
      add(x.rep[0], a.rep[0], b);
      x.normalize();
   }
   else {
      // ugly...b could alias a coeff of x

      // compute the constant term first, before x.rep is resized, so b is
      // read while x's storage is still untouched
      ZZ *xp = x.rep.elts(); 
      add(xp[0], a.rep[0], b);
      x.rep.SetLength(n);
      // re-fetch the element pointer after the resize
      xp = x.rep.elts();
      const ZZ *ap = a.rep.elts();
      long i;
      for (i = 1; i < n; i++)
         xp[i] = ap[i];
      x.normalize();
   }
}


// x = a + b, where b is a machine integer added to the constant term.
void add(ZZX& x, const ZZX& a, long b)
{
   // a has no coefficients: the result is just the constant b
   if (a.rep.length() == 0) {
      conv(x, b);
      return;
   }

   if (&x != &a)
      x = a;

   add(x.rep[0], x.rep[0], b);
   x.normalize();
}


// x = a - b.
// The alias check (&x != &a) below allows x to be the same object as a;
// the b-tail branch has no alias check because the tail must be negated
// even when x is b.
void sub(ZZX& x, const ZZX& a, const ZZX& b)
{
   long da = deg(a);
   long db = deg(b);
   long minab = min(da, db);
   long maxab = max(da, db);
   // make room for maxab+1 coefficients; element pointers are fetched
   // only after this resize
   x.rep.SetLength(maxab+1);

   long i;
   const ZZ *ap, *bp; 
   ZZ* xp;

   // subtract the minab+1 low-order coefficients present in both operands
   for (i = minab+1, ap = a.rep.elts(), bp = b.rep.elts(), xp = x.rep.elts();
        i; i--, ap++, bp++, xp++)
      sub(*xp, (*ap), (*bp));

   // handle the remaining high-order coefficients of the longer operand:
   // copy a's tail (unless x is a), or negate b's tail
   if (da > minab && &x != &a)
      for (i = da-minab; i; i--, xp++, ap++)
         *xp = *ap;
   else if (db > minab)
      for (i = db-minab; i; i--, xp++, bp++)
         negate(*xp, *bp);
   else
      // reached when da == db (top coefficients may have cancelled) or
      // when x aliases the longer operand a
      x.normalize();

}

// x = a - b, where b is an integer constant subtracted from the constant
// term.  Only coefficient 0 changes; the remaining coefficients of a are
// copied into x.  The branches differ in how they cope with aliasing.
void sub(ZZX& x, const ZZX& a, const ZZ& b)
{
   long n = a.rep.length();
   if (n == 0) {
      // a has no coefficients: the result is -b
      conv(x, b);
      negate(x, x);
   }
   else if (&x == &a) {
      // in place: only the constant term needs updating
      sub(x.rep[0], a.rep[0], b);
      x.normalize();
   }
   else if (x.rep.MaxLength() == 0) {
      // NOTE(review): MaxLength() == 0 presumably means x has never held
      // any coefficients, so b cannot alias one and a plain copy is safe
      // -- confirm against the vector documentation
      x = a;
      sub(x.rep[0], a.rep[0], b);
      x.normalize();
   }
   else {
      // ugly...b could alias a coeff of x

      // compute the constant term first, before x.rep is resized, so b is
      // read while x's storage is still untouched
      ZZ *xp = x.rep.elts();
      sub(xp[0], a.rep[0], b);
      x.rep.SetLength(n);
      // re-fetch the element pointer after the resize
      xp = x.rep.elts();
      const ZZ *ap = a.rep.elts();
      long i;
      for (i = 1; i < n; i++)
         xp[i] = ap[i];
      x.normalize();
   }
}

// x = a - b, where b is a machine integer subtracted from the constant
// term.
void sub(ZZX& x, const ZZX& a, long b)
{
   // subtracting zero: plain copy, no normalization performed
   if (b == 0) {
      x = a;
      return;
   }

   if (a.rep.length() != 0) {
      if (&x != &a)
         x = a;
      sub(x.rep[0], x.rep[0], b);
   }
   else {
      // a has no coefficients: the result is the constant -b
      x.rep.SetLength(1);
      conv(x.rep[0], b);
      negate(x.rep[0], x.rep[0]);
   }

   x.normalize();
}

// x = a - b, where a is a machine integer: computed as (-b) + a.
void sub(ZZX& x, long a, const ZZX& b)
{
   negate(x, b);
   add(x, x, a);
}


// x = b - a, where b is an integer constant and a is a polynomial.
// The constant term of the result is b - a[0]; every other coefficient is
// the negation of the corresponding coefficient of a.  The branches differ
// in how they cope with aliasing.
void sub(ZZX& x, const ZZ& b, const ZZX& a)
{
   long n = a.rep.length();
   if (n == 0) {
      // a has no coefficients: the result is just the constant b
      conv(x, b);
   }
   else if (x.rep.MaxLength() == 0) {
      // NOTE(review): MaxLength() == 0 presumably means x has never held
      // any coefficients, so b cannot alias one -- confirm against the
      // vector documentation
      negate(x, a);
      // x.rep[0] now holds -a[0]; adding b to it gives b - a[0].
      // (Adding b to a.rep[0] instead would yield a[0] + b, which
      // disagrees with the branch below.)
      add(x.rep[0], x.rep[0], b);
      x.normalize();
   }
   else {
      // ugly...b could alias a coeff of x

      // compute the constant term first, before x.rep is resized, so b is
      // read while x's storage is still untouched
      ZZ *xp = x.rep.elts();
      sub(xp[0], b, a.rep[0]);
      x.rep.SetLength(n);
      // re-fetch the element pointer after the resize
      xp = x.rep.elts();
      const ZZ *ap = a.rep.elts();
      long i;
      for (i = 1; i < n; i++)
         negate(xp[i], ap[i]);
      x.normalize();
   }
}



// x = -a: x is resized to the length of a and each coefficient of a is
// negated into the matching slot of x.  No normalization is performed.
void negate(ZZX& x, const ZZX& a)
{
   const long len = a.rep.length();
   x.rep.SetLength(len);

   // pointers are fetched after the resize
   const ZZ* src = a.rep.elts();
   ZZ* dst = x.rep.elts();

   for (long k = 0; k < len; k++)
      negate(dst[k], src[k]);
}

// Returns the largest NumBits() value over the coefficients of f at
// indices 0..deg(f); returns 0 when that range is empty.
long MaxBits(const ZZX& f)
{
   long widest = 0;
   const long top = deg(f);

   for (long k = 0; k <= top; k++) {
      const long bits = NumBits(f.rep[k]);
      if (bits > widest)
         widest = bits;
   }

   return widest;
}


// x = a * b using the quadratic schoolbook method.
// Squaring (a and b being the same object) is delegated to PlainSqr.
// x may alias a or b: an aliased operand is copied to a local first.
void PlainMul(ZZX& x, const ZZX& a, const ZZX& b)
{
   if (&a == &b) {
      PlainSqr(x, a);
      return;
   }

   const long da = deg(a);
   const long db = deg(b);

   // a negative degree means that operand is zero
   if (da < 0 || db < 0) {
      clear(x);
      return;
   }

   const long d = da+db;

   const ZZ *ap, *bp;
   ZZX la, lb;

   // read from a private copy when the output aliases an input
   if (&x != &a)
      ap = a.rep.elts();
   else {
      la = a;
      ap = la.rep.elts();
   }

   if (&x != &b)
      bp = b.rep.elts();
   else {
      lb = b;
      bp = lb.rep.elts();
   }

   x.rep.SetLength(d+1);
   ZZ *xp = x.rep.elts();

   ZZ prod, sum;

   // coefficient i of the product is the sum of a[j]*b[i-j] over all valid j
   for (long i = 0; i <= d; i++) {
      const long lo = max(0, i-db);
      const long hi = min(da, i);

      clear(sum);
      for (long j = lo; j <= hi; j++) {
         mul(prod, ap[j], bp[i-j]);
         add(sum, sum, prod);
      }
      xp[i] = sum;
   }

   x.normalize();
}

// x = a * a using the quadratic schoolbook method, exploiting symmetry:
// each cross product a[j]*a[i-j] is computed once and doubled, and the
// middle square term is added when the number of terms is odd.
// x may alias a: the operand is then copied to a local first.
void PlainSqr(ZZX& x, const ZZX& a)
{
   const long da = deg(a);

   // a negative degree means a is zero
   if (da < 0) {
      clear(x);
      return;
   }

   const long d = 2*da;

   ZZX la;
   const ZZ *ap;

   if (&x != &a)
      ap = a.rep.elts();
   else {
      la = a;
      ap = la.rep.elts();
   }

   x.rep.SetLength(d+1);
   ZZ *xp = x.rep.elts();

   ZZ prod, sum;

   for (long i = 0; i <= d; i++) {
      const long lo = max(0, i-da);
      const long hi = min(da, i);
      const long terms = hi - lo + 1;     // number of products for X^i
      const long pairs = terms >> 1;      // symmetric pairs among them
      const long last = lo + pairs - 1;   // last index of the first half

      clear(sum);
      for (long j = lo; j <= last; j++) {
         mul(prod, ap[j], ap[i-j]);
         add(sum, sum, prod);
      }

      // each cross product occurs twice
      add(sum, sum, sum);

      // odd count: the middle term is a square and occurs once
      if (terms & 1) {
         sqr(prod, ap[last + 1]);
         add(sum, sum, prod);
      }

      xp[i] = sum;
   }

   x.normalize();
}



// Schoolbook multiplication on raw coefficient arrays: writes the sa+sb-1
// coefficients of the product of ap[0..sa-1] and bp[0..sb-1] into xp.
// Does nothing when either length is zero.  Temporaries are function-local
// statics, so they persist between calls.
static
void PlainMul(ZZ *xp, const ZZ *ap, long sa, const ZZ *bp, long sb)
{
   if (sa == 0 || sb == 0) return;

   const long sx = sa+sb-1;

   static ZZ prod, sum;

   for (long i = 0; i < sx; i++) {
      const long lo = max(0, i-sb+1);
      const long hi = min(sa-1, i);

      clear(sum);
      for (long j = lo; j <= hi; j++) {
         mul(prod, ap[j], bp[i-j]);
         add(sum, sum, prod);
      }
      xp[i] = sum;
   }
}



// Folds the sb-coefficient array b at position hsa into T (hsa entries):
// T[k] = b[k] + b[hsa+k] for the sb-hsa positions where a high part
// exists, and T[k] = b[k] for the remaining positions below hsa.
static
void KarFold(ZZ *T, const ZZ *b, long sb, long hsa)
{
   const long overlap = sb - hsa;
   long k;

   // positions where both low and high halves contribute
   for (k = 0; k < overlap; k++)
      add(T[k], b[k], b[hsa+k]);

   // positions covered by the low half only
   for (k = overlap; k < hsa; k++)
      T[k] = b[k];
}

// T[k] -= b[k] for k = 0..sb-1.
static
void KarSub(ZZ *T, const ZZ *b, long sb)
{
   for (; sb > 0; sb--, T++, b++)
      sub(*T, *T, *b);
}

// T[k] += b[k] for k = 0..sb-1.
static
void KarAdd(ZZ *T, const ZZ *b, long sb)
{
   for (; sb > 0; sb--, T++, b++)
      add(*T, *T, *b);
}

// Merges the sb-entry array b into c: the first hsa entries of c are
// overwritten with b, and the entries from hsa up to sb-1 are added to
// what c already holds there.
static
void KarFix(ZZ *c, const ZZ *b, long sb, long hsa)
{
   long k;

   // low part: plain copy
   for (k = 0; k < hsa; k++)
      c[k] = b[k];

   // overlapping part: accumulate
   for (k = hsa; k < sb; k++)
      add(c[k], c[k], b[k]);
}

// Multiplies each of the sa coefficients ap[k] by the scalar b, storing
// the results in xp[k].
static void PlainMul1(ZZ *xp, const ZZ *ap, long sa, const ZZ& b)
{
   for (; sa > 0; sa--, xp++, ap++)
      mul(*xp, *ap, b);
}



// Recursive Karatsuba multiplication on raw coefficient arrays.
// Computes the product of a (sa coefficients) and b (sb coefficients) into
// c, using stk as scratch space; each recursion level carves its
// temporaries off the front of stk and passes the remainder down.
// Assumes sa >= 1 and sb >= 1 (the ZZX wrapper in this file only calls it
// with both lengths >= 2).
static
void KarMul(ZZ *c, const ZZ *a, 
            long sa, const ZZ *b, long sb, ZZ *stk)
{
   // arrange for a to be the longer operand (sa >= sb)
   if (sa < sb) {
      { long t = sa; sa = sb; sb = t; }
      { const ZZ *t = a; a = b; b = t; }
   }

   // base case: b is a single coefficient
   if (sb == 1) {
      if (sa == 1)
         mul(*c, *a, *b);
      else
         PlainMul1(c, a, sa, *b);

      return;
   }

   // base case: 2x2 product with three multiplications;
   // c[1] = (a0+a1)(b0+b1) - a0*b0 - a1*b1
   if (sb == 2 && sa == 2) {
      mul(c[0], a[0], b[0]);
      mul(c[2], a[1], b[1]);
      add(stk[0], a[0], a[1]);
      add(stk[1], b[0], b[1]);
      mul(c[1], stk[0], stk[1]);
      sub(c[1], c[1], c[0]);
      sub(c[1], c[1], c[2]);

      return;

   }

   // split point: ceil(sa/2)
   long hsa = (sa + 1) >> 1;

   if (hsa < sb) {
      /* normal case */

      long hsa2 = hsa << 1;

      ZZ *T1, *T2, *T3;

      // reserve hsa + hsa + (2*hsa - 1) scratch entries for this level
      T1 = stk; stk += hsa;
      T2 = stk; stk += hsa;
      T3 = stk; stk += hsa2 - 1;

      /* compute T1 = a_lo + a_hi */

      KarFold(T1, a, sa, hsa);

      /* compute T2 = b_lo + b_hi */

      KarFold(T2, b, sb, hsa);

      /* recursively compute T3 = T1 * T2 */

      KarMul(T3, T1, hsa, T2, hsa, stk);

      /* recursively compute a_hi * b_hi into high part of c */
      /* and subtract from T3 */

      KarMul(c + hsa2, a+hsa, sa-hsa, b+hsa, sb-hsa, stk);
      KarSub(T3, c + hsa2, sa + sb - hsa2 - 1);


      /* recursively compute a_lo*b_lo into low part of c */
      /* and subtract from T3 */

      KarMul(c, a, hsa, b, hsa, stk);
      KarSub(T3, c, hsa2 - 1);

      // the low product fills c[0..hsa2-2] and the high product starts at
      // c[hsa2]; zero the one slot between them before accumulating
      clear(c[hsa2 - 1]);

      /* finally, add T3 * X^{hsa} to c */

      KarAdd(c+hsa, T3, hsa2-1);
   }
   else {
      /* degenerate case */
      // b fits entirely within the low half of a (sb <= hsa), so only a
      // is split

      ZZ *T;

      // reserve hsa + sb - 1 scratch entries for the low partial product
      T = stk; stk += hsa + sb - 1;

      /* recursively compute b*a_hi into high part of c */

      KarMul(c + hsa, a + hsa, sa - hsa, b, sb, stk);

      /* recursively compute b*a_lo into T */

      KarMul(T, a, hsa, b, sb, stk);

      // copy the first hsa entries of T into c and add the rest onto the
      // high partial product already stored from c[hsa] upward
      KarFix(c, T, hsa + sb - 1, hsa);
   }
}

// c = a * b using Karatsuba multiplication.
// A zero operand yields zero; squaring (a and b being the same object) is
// delegated to KarSqr.  c may alias a or b: the aliased operand's
// coefficients are copied into a local vector before c is resized.
void KarMul(ZZX& c, const ZZX& a, const ZZX& b)
{
   if (IsZero(a) || IsZero(b)) {
      clear(c);
      return;
   }

   if (&a == &b) {
      KarSqr(c, a);
      return;
   }

   // holds a copy of whichever operand aliases c; at most one can, since
   // a and b are distinct objects at this point
   vec_ZZ mem;

   const ZZ *ap, *bp;
   ZZ *cp;

   long sa = a.rep.length();
   long sb = b.rep.length();

   if (&a == &c) {
      mem = a.rep;
      ap = mem.elts();
   }
   else
      ap = a.rep.elts();

   if (&b == &c) {
      mem = b.rep;
      bp = mem.elts();
   }
   else
      bp = b.rep.elts();

   c.rep.SetLength(sa+sb-1);
   cp = c.rep.elts();

   long maxa, maxb, xover;

   // coefficient bit sizes, used below to size the scratch integers
   maxa = MaxBits(a);
   maxb = MaxBits(b);
   // operand lengths below xover use the schoolbook routine
   xover = 2;

   if (sa < xover || sb < xover)
      PlainMul(cp, ap, sa, bp, sb);
   else {
      /* karatsuba */

      long n, hn, sp, depth;

      // total scratch entries: each halving level contributes 4*hn - 1,
      // matching the hsa + hsa + (2*hsa - 1) entries the recursive KarMul
      // reserves per level in its normal case
      n = max(sa, sb);
      sp = 0;
      depth = 0;
      do {
         hn = (n+1) >> 1;
         sp += (hn << 2) - 1;
         n = hn;
         depth++;
      } while (n >= xover);

      // second argument: a bit bound converted to units of NTL_ZZ_NBITS,
      // rounded up (presumably the per-element size of the scratch
      // integers -- TODO confirm against the ZZVec documentation)
      ZZVec stk;
      stk.SetSize(sp, 
         ((maxa + maxb + NumBits(min(sa, sb)) + 2*depth + 10) 
          + NTL_ZZ_NBITS-1)/NTL_ZZ_NBITS);

      KarMul(cp, ap, sa, bp, sb, stk.elts());
   }

   c.normalize();
}






// Schoolbook squaring on raw coefficient arrays: writes the 2*sa-1
// coefficients of the square of ap[0..sa-1] into xp.  Does nothing when
// sa is zero.  Each cross product is computed once and doubled; the middle
// square term is added when the number of terms is odd.  Temporaries are
// function-local statics, so they persist between calls.
void PlainSqr(ZZ* xp, const ZZ* ap, long sa)
{
   if (sa == 0) return;

   const long da = sa-1;
   const long d = 2*da;

   static ZZ prod, sum;

   for (long i = 0; i <= d; i++) {
      const long lo = max(0, i-da);
      const long hi = min(da, i);
      const long terms = hi - lo + 1;     // number of products for X^i
      const long pairs = terms >> 1;      // symmetric pairs among them
      const long last = lo + pairs - 1;   // last index of the first half

      clear(sum);
      for (long j = lo; j <= last; j++) {
         mul(prod, ap[j], ap[i-j]);
         add(sum, sum, prod);
      }

      // each cross product occurs twice
      add(sum, sum, sum);

      // odd count: the middle term is a square and occurs once
      if (terms & 1) {
         sqr(prod, ap[last + 1]);
         add(sum, sum, prod);
      }

      xp[i] = sum;
   }
}


// Recursive Karatsuba squaring on raw coefficient arrays.
// Computes the square of a (sa coefficients) into c, using stk as scratch
// space; each recursion level carves its temporaries off the front of stk
// and passes the remainder down.  Assumes sa >= 1.
static
void KarSqr(ZZ *c, const ZZ *a, long sa, ZZ *stk)
{
   // base case: single coefficient
   if (sa == 1) {
      sqr(*c, *a);
      return;
   }

   // base case: (a0 + a1 X)^2 = a0^2 + 2 a0 a1 X + a1^2 X^2
   if (sa == 2) {
      sqr(c[0], a[0]);
      sqr(c[2], a[1]);
      mul(c[1], a[0], a[1]);
      add(c[1], c[1], c[1]);

      return;
   }

   // base case: three coefficients, expanded directly;
   // stk[0] temporarily holds a1^2 for the X^2 term (2 a0 a2 + a1^2)
   if (sa == 3) {
      sqr(c[0], a[0]);
      mul(c[1], a[0], a[1]);
      add(c[1], c[1], c[1]);
      sqr(stk[0], a[1]);
      mul(c[2], a[0], a[2]);
      add(c[2], c[2], c[2]);
      add(c[2], c[2], stk[0]);
      mul(c[3], a[1], a[2]);
      add(c[3], c[3], c[3]);
      sqr(c[4], a[2]);

      return;
 
   }

   // split point: ceil(sa/2)
   long hsa = (sa + 1) >> 1;
   long hsa2 = hsa << 1;

   ZZ *T1, *T2;

   // reserve hsa + (2*hsa - 1) scratch entries for this level
   T1 = stk; stk += hsa;
   T2 = stk; stk += hsa2-1;

   // T1 = a_lo + a_hi, T2 = T1^2
   KarFold(T1, a, sa, hsa);
   KarSqr(T2, T1, hsa, stk);


   // a_hi^2 goes into the high part of c and is subtracted from T2
   KarSqr(c + hsa2, a+hsa, sa-hsa, stk);
   KarSub(T2, c + hsa2, sa + sa - hsa2 - 1);


   // a_lo^2 goes into the low part of c and is subtracted from T2,
   // leaving the cross term in T2
   KarSqr(c, a, hsa, stk);
   KarSub(T2, c, hsa2 - 1);

   // the low square fills c[0..hsa2-2] and the high square starts at
   // c[hsa2]; zero the one slot between them before accumulating
   clear(c[hsa2 - 1]);

   // add the cross term, shifted by hsa positions
   KarAdd(c+hsa, T2, hsa2-1);
}

      
// c = a * a using Karatsuba squaring.
// A zero operand yields zero.  c may alias a: the coefficients of a are
// then copied into a local vector before c is resized.
void KarSqr(ZZX& c, const ZZX& a)
{
   if (IsZero(a)) {
      clear(c);
      return;
   }

   vec_ZZ mem;

   const long sa = a.rep.length();

   // read from a private copy when the output aliases the input
   const ZZ *ap;
   if (&a != &c)
      ap = a.rep.elts();
   else {
      mem = a.rep;
      ap = mem.elts();
   }

   c.rep.SetLength(sa+sa-1);
   ZZ *cp = c.rep.elts();

   // coefficient bit size, used below to size the scratch integers
   const long maxa = MaxBits(a);

   // operand lengths below xover use the schoolbook routine
   const long xover = 2;

   if (sa >= xover) {
      /* karatsuba */

      // total scratch entries: each halving level contributes 3*hn - 1,
      // matching the hsa + (2*hsa - 1) entries the recursive KarSqr
      // reserves per level
      long sp = 0;
      long depth = 0;
      long n = sa;
      do {
         const long hn = (n+1) >> 1;
         sp += hn+hn+hn - 1;
         n = hn;
         depth++;
      } while (n >= xover);

      // second argument: a bit bound converted to units of NTL_ZZ_NBITS,
      // rounded up
      const long bit_bound = 2*maxa + NumBits(sa) + 2*depth + 10;

      ZZVec stk;
      stk.SetSize(sp, (bit_bound + NTL_ZZ_NBITS-1)/NTL_ZZ_NBITS);

      KarSqr(cp, ap, sa, stk.elts());
   }
   else
      PlainSqr(cp, ap, sa);

   c.normalize();
}

NTL_END_IMPL


// syntax highlighted by Code2HTML, v. 0.9.1