#include <NTL/ZZX.h>
#include <NTL/new.h>
NTL_START_IMPL
const ZZX& ZZX::zero()
{
static ZZX z;
return z;
}
void conv(ZZ_pX& x, const ZZX& a)
{
conv(x.rep, a.rep);
x.normalize();
}
void conv(ZZX& x, const ZZ_pX& a)
{
conv(x.rep, a.rep);
x.normalize();
}
istream& operator>>(istream& s, ZZX& x)
{
s >> x.rep;
x.normalize();
return s;
}
ostream& operator<<(ostream& s, const ZZX& a)
{
return s << a.rep;
}
void ZZX::normalize()
{
long n;
const ZZ* p;
n = rep.length();
if (n == 0) return;
p = rep.elts() + n;
while (n > 0 && IsZero(*--p)) {
n--;
}
rep.SetLength(n);
}
long IsZero(const ZZX& a)
{
return a.rep.length() == 0;
}
long IsOne(const ZZX& a)
{
return a.rep.length() == 1 && IsOne(a.rep[0]);
}
long operator==(const ZZX& a, const ZZX& b)
{
long i, n;
const ZZ *ap, *bp;
n = a.rep.length();
if (n != b.rep.length()) return 0;
ap = a.rep.elts();
bp = b.rep.elts();
for (i = 0; i < n; i++)
if (ap[i] != bp[i]) return 0;
return 1;
}
long operator==(const ZZX& a, long b)
{
if (b == 0)
return IsZero(a);
if (deg(a) != 0)
return 0;
return a.rep[0] == b;
}
long operator==(const ZZX& a, const ZZ& b)
{
if (IsZero(b))
return IsZero(a);
if (deg(a) != 0)
return 0;
return a.rep[0] == b;
}
void GetCoeff(ZZ& x, const ZZX& a, long i)
{
if (i < 0 || i > deg(a))
clear(x);
else
x = a.rep[i];
}
void SetCoeff(ZZX& x, long i, const ZZ& a)
{
long j, m;
if (i < 0)
Error("SetCoeff: negative index");
if (NTL_OVERFLOW(i, 1, 0))
Error("overflow in SetCoeff");
m = deg(x);
if (i > m) {
/* careful: a may alias a coefficient of x */
long alloc = x.rep.allocated();
if (alloc > 0 && i >= alloc) {
ZZ aa = a;
x.rep.SetLength(i+1);
x.rep[i] = aa;
}
else {
x.rep.SetLength(i+1);
x.rep[i] = a;
}
for (j = m+1; j < i; j++)
clear(x.rep[j]);
}
else
x.rep[i] = a;
x.normalize();
}
void SetCoeff(ZZX& x, long i)
{
long j, m;
if (i < 0)
Error("coefficient index out of range");
if (NTL_OVERFLOW(i, 1, 0))
Error("overflow in SetCoeff");
m = deg(x);
if (i > m) {
x.rep.SetLength(i+1);
for (j = m+1; j < i; j++)
clear(x.rep[j]);
}
set(x.rep[i]);
x.normalize();
}
void SetX(ZZX& x)
{
clear(x);
SetCoeff(x, 1);
}
long IsX(const ZZX& a)
{
return deg(a) == 1 && IsOne(LeadCoeff(a)) && IsZero(ConstTerm(a));
}
const ZZ& coeff(const ZZX& a, long i)
{
if (i < 0 || i > deg(a))
return ZZ::zero();
else
return a.rep[i];
}
const ZZ& LeadCoeff(const ZZX& a)
{
if (IsZero(a))
return ZZ::zero();
else
return a.rep[deg(a)];
}
const ZZ& ConstTerm(const ZZX& a)
{
if (IsZero(a))
return ZZ::zero();
else
return a.rep[0];
}
void conv(ZZX& x, const ZZ& a)
{
if (IsZero(a))
x.rep.SetLength(0);
else {
x.rep.SetLength(1);
x.rep[0] = a;
}
}
void conv(ZZX& x, long a)
{
if (a == 0)
x.rep.SetLength(0);
else {
x.rep.SetLength(1);
conv(x.rep[0], a);
}
}
void conv(ZZX& x, const vec_ZZ& a)
{
x.rep = a;
x.normalize();
}
void add(ZZX& x, const ZZX& a, const ZZX& b)
{
long da = deg(a);
long db = deg(b);
long minab = min(da, db);
long maxab = max(da, db);
x.rep.SetLength(maxab+1);
long i;
const ZZ *ap, *bp;
ZZ* xp;
for (i = minab+1, ap = a.rep.elts(), bp = b.rep.elts(), xp = x.rep.elts();
i; i--, ap++, bp++, xp++)
add(*xp, (*ap), (*bp));
if (da > minab && &x != &a)
for (i = da-minab; i; i--, xp++, ap++)
*xp = *ap;
else if (db > minab && &x != &b)
for (i = db-minab; i; i--, xp++, bp++)
*xp = *bp;
else
x.normalize();
}
void add(ZZX& x, const ZZX& a, const ZZ& b)
{
long n = a.rep.length();
if (n == 0) {
conv(x, b);
}
else if (&x == &a) {
add(x.rep[0], a.rep[0], b);
x.normalize();
}
else if (x.rep.MaxLength() == 0) {
x = a;
add(x.rep[0], a.rep[0], b);
x.normalize();
}
else {
// ugly...b could alias a coeff of x
ZZ *xp = x.rep.elts();
add(xp[0], a.rep[0], b);
x.rep.SetLength(n);
xp = x.rep.elts();
const ZZ *ap = a.rep.elts();
long i;
for (i = 1; i < n; i++)
xp[i] = ap[i];
x.normalize();
}
}
void add(ZZX& x, const ZZX& a, long b)
{
if (a.rep.length() == 0) {
conv(x, b);
}
else {
if (&x != &a) x = a;
add(x.rep[0], x.rep[0], b);
x.normalize();
}
}
void sub(ZZX& x, const ZZX& a, const ZZX& b)
{
long da = deg(a);
long db = deg(b);
long minab = min(da, db);
long maxab = max(da, db);
x.rep.SetLength(maxab+1);
long i;
const ZZ *ap, *bp;
ZZ* xp;
for (i = minab+1, ap = a.rep.elts(), bp = b.rep.elts(), xp = x.rep.elts();
i; i--, ap++, bp++, xp++)
sub(*xp, (*ap), (*bp));
if (da > minab && &x != &a)
for (i = da-minab; i; i--, xp++, ap++)
*xp = *ap;
else if (db > minab)
for (i = db-minab; i; i--, xp++, bp++)
negate(*xp, *bp);
else
x.normalize();
}
void sub(ZZX& x, const ZZX& a, const ZZ& b)
{
long n = a.rep.length();
if (n == 0) {
conv(x, b);
negate(x, x);
}
else if (&x == &a) {
sub(x.rep[0], a.rep[0], b);
x.normalize();
}
else if (x.rep.MaxLength() == 0) {
x = a;
sub(x.rep[0], a.rep[0], b);
x.normalize();
}
else {
// ugly...b could alias a coeff of x
ZZ *xp = x.rep.elts();
sub(xp[0], a.rep[0], b);
x.rep.SetLength(n);
xp = x.rep.elts();
const ZZ *ap = a.rep.elts();
long i;
for (i = 1; i < n; i++)
xp[i] = ap[i];
x.normalize();
}
}
void sub(ZZX& x, const ZZX& a, long b)
{
if (b == 0) {
x = a;
return;
}
if (a.rep.length() == 0) {
x.rep.SetLength(1);
conv(x.rep[0], b);
negate(x.rep[0], x.rep[0]);
}
else {
if (&x != &a) x = a;
sub(x.rep[0], x.rep[0], b);
}
x.normalize();
}
void sub(ZZX& x, long a, const ZZX& b)
{
negate(x, b);
add(x, x, a);
}
void sub(ZZX& x, const ZZ& b, const ZZX& a)
{
long n = a.rep.length();
if (n == 0) {
conv(x, b);
}
else if (x.rep.MaxLength() == 0) {
negate(x, a);
add(x.rep[0], a.rep[0], b);
x.normalize();
}
else {
// ugly...b could alias a coeff of x
ZZ *xp = x.rep.elts();
sub(xp[0], b, a.rep[0]);
x.rep.SetLength(n);
xp = x.rep.elts();
const ZZ *ap = a.rep.elts();
long i;
for (i = 1; i < n; i++)
negate(xp[i], ap[i]);
x.normalize();
}
}
void negate(ZZX& x, const ZZX& a)
{
long n = a.rep.length();
x.rep.SetLength(n);
const ZZ* ap = a.rep.elts();
ZZ* xp = x.rep.elts();
long i;
for (i = n; i; i--, ap++, xp++)
negate((*xp), (*ap));
}
long MaxBits(const ZZX& f)
{
long i, m;
m = 0;
for (i = 0; i <= deg(f); i++) {
m = max(m, NumBits(f.rep[i]));
}
return m;
}
void PlainMul(ZZX& x, const ZZX& a, const ZZX& b)
{
if (&a == &b) {
PlainSqr(x, a);
return;
}
long da = deg(a);
long db = deg(b);
if (da < 0 || db < 0) {
clear(x);
return;
}
long d = da+db;
const ZZ *ap, *bp;
ZZ *xp;
ZZX la, lb;
if (&x == &a) {
la = a;
ap = la.rep.elts();
}
else
ap = a.rep.elts();
if (&x == &b) {
lb = b;
bp = lb.rep.elts();
}
else
bp = b.rep.elts();
x.rep.SetLength(d+1);
xp = x.rep.elts();
long i, j, jmin, jmax;
ZZ t, accum;
for (i = 0; i <= d; i++) {
jmin = max(0, i-db);
jmax = min(da, i);
clear(accum);
for (j = jmin; j <= jmax; j++) {
mul(t, ap[j], bp[i-j]);
add(accum, accum, t);
}
xp[i] = accum;
}
x.normalize();
}
void PlainSqr(ZZX& x, const ZZX& a)
{
long da = deg(a);
if (da < 0) {
clear(x);
return;
}
long d = 2*da;
const ZZ *ap;
ZZ *xp;
ZZX la;
if (&x == &a) {
la = a;
ap = la.rep.elts();
}
else
ap = a.rep.elts();
x.rep.SetLength(d+1);
xp = x.rep.elts();
long i, j, jmin, jmax;
long m, m2;
ZZ t, accum;
for (i = 0; i <= d; i++) {
jmin = max(0, i-da);
jmax = min(da, i);
m = jmax - jmin + 1;
m2 = m >> 1;
jmax = jmin + m2 - 1;
clear(accum);
for (j = jmin; j <= jmax; j++) {
mul(t, ap[j], ap[i-j]);
add(accum, accum, t);
}
add(accum, accum, accum);
if (m & 1) {
sqr(t, ap[jmax + 1]);
add(accum, accum, t);
}
xp[i] = accum;
}
x.normalize();
}
static
void PlainMul(ZZ *xp, const ZZ *ap, long sa, const ZZ *bp, long sb)
{
if (sa == 0 || sb == 0) return;
long sx = sa+sb-1;
long i, j, jmin, jmax;
static ZZ t, accum;
for (i = 0; i < sx; i++) {
jmin = max(0, i-sb+1);
jmax = min(sa-1, i);
clear(accum);
for (j = jmin; j <= jmax; j++) {
mul(t, ap[j], bp[i-j]);
add(accum, accum, t);
}
xp[i] = accum;
}
}
static
void KarFold(ZZ *T, const ZZ *b, long sb, long hsa)
{
long m = sb - hsa;
long i;
for (i = 0; i < m; i++)
add(T[i], b[i], b[hsa+i]);
for (i = m; i < hsa; i++)
T[i] = b[i];
}
static
void KarSub(ZZ *T, const ZZ *b, long sb)
{
long i;
for (i = 0; i < sb; i++)
sub(T[i], T[i], b[i]);
}
static
void KarAdd(ZZ *T, const ZZ *b, long sb)
{
long i;
for (i = 0; i < sb; i++)
add(T[i], T[i], b[i]);
}
static
void KarFix(ZZ *c, const ZZ *b, long sb, long hsa)
{
long i;
for (i = 0; i < hsa; i++)
c[i] = b[i];
for (i = hsa; i < sb; i++)
add(c[i], c[i], b[i]);
}
static void PlainMul1(ZZ *xp, const ZZ *ap, long sa, const ZZ& b)
{
long i;
for (i = 0; i < sa; i++)
mul(xp[i], ap[i], b);
}
static
void KarMul(ZZ *c, const ZZ *a,
long sa, const ZZ *b, long sb, ZZ *stk)
{
if (sa < sb) {
{ long t = sa; sa = sb; sb = t; }
{ const ZZ *t = a; a = b; b = t; }
}
if (sb == 1) {
if (sa == 1)
mul(*c, *a, *b);
else
PlainMul1(c, a, sa, *b);
return;
}
if (sb == 2 && sa == 2) {
mul(c[0], a[0], b[0]);
mul(c[2], a[1], b[1]);
add(stk[0], a[0], a[1]);
add(stk[1], b[0], b[1]);
mul(c[1], stk[0], stk[1]);
sub(c[1], c[1], c[0]);
sub(c[1], c[1], c[2]);
return;
}
long hsa = (sa + 1) >> 1;
if (hsa < sb) {
/* normal case */
long hsa2 = hsa << 1;
ZZ *T1, *T2, *T3;
T1 = stk; stk += hsa;
T2 = stk; stk += hsa;
T3 = stk; stk += hsa2 - 1;
/* compute T1 = a_lo + a_hi */
KarFold(T1, a, sa, hsa);
/* compute T2 = b_lo + b_hi */
KarFold(T2, b, sb, hsa);
/* recursively compute T3 = T1 * T2 */
KarMul(T3, T1, hsa, T2, hsa, stk);
/* recursively compute a_hi * b_hi into high part of c */
/* and subtract from T3 */
KarMul(c + hsa2, a+hsa, sa-hsa, b+hsa, sb-hsa, stk);
KarSub(T3, c + hsa2, sa + sb - hsa2 - 1);
/* recursively compute a_lo*b_lo into low part of c */
/* and subtract from T3 */
KarMul(c, a, hsa, b, hsa, stk);
KarSub(T3, c, hsa2 - 1);
clear(c[hsa2 - 1]);
/* finally, add T3 * X^{hsa} to c */
KarAdd(c+hsa, T3, hsa2-1);
}
else {
/* degenerate case */
ZZ *T;
T = stk; stk += hsa + sb - 1;
/* recursively compute b*a_hi into high part of c */
KarMul(c + hsa, a + hsa, sa - hsa, b, sb, stk);
/* recursively compute b*a_lo into T */
KarMul(T, a, hsa, b, sb, stk);
KarFix(c, T, hsa + sb - 1, hsa);
}
}
void KarMul(ZZX& c, const ZZX& a, const ZZX& b)
{
if (IsZero(a) || IsZero(b)) {
clear(c);
return;
}
if (&a == &b) {
KarSqr(c, a);
return;
}
vec_ZZ mem;
const ZZ *ap, *bp;
ZZ *cp;
long sa = a.rep.length();
long sb = b.rep.length();
if (&a == &c) {
mem = a.rep;
ap = mem.elts();
}
else
ap = a.rep.elts();
if (&b == &c) {
mem = b.rep;
bp = mem.elts();
}
else
bp = b.rep.elts();
c.rep.SetLength(sa+sb-1);
cp = c.rep.elts();
long maxa, maxb, xover;
maxa = MaxBits(a);
maxb = MaxBits(b);
xover = 2;
if (sa < xover || sb < xover)
PlainMul(cp, ap, sa, bp, sb);
else {
/* karatsuba */
long n, hn, sp, depth;
n = max(sa, sb);
sp = 0;
depth = 0;
do {
hn = (n+1) >> 1;
sp += (hn << 2) - 1;
n = hn;
depth++;
} while (n >= xover);
ZZVec stk;
stk.SetSize(sp,
((maxa + maxb + NumBits(min(sa, sb)) + 2*depth + 10)
+ NTL_ZZ_NBITS-1)/NTL_ZZ_NBITS);
KarMul(cp, ap, sa, bp, sb, stk.elts());
}
c.normalize();
}
void PlainSqr(ZZ* xp, const ZZ* ap, long sa)
{
if (sa == 0) return;
long da = sa-1;
long d = 2*da;
long i, j, jmin, jmax;
long m, m2;
static ZZ t, accum;
for (i = 0; i <= d; i++) {
jmin = max(0, i-da);
jmax = min(da, i);
m = jmax - jmin + 1;
m2 = m >> 1;
jmax = jmin + m2 - 1;
clear(accum);
for (j = jmin; j <= jmax; j++) {
mul(t, ap[j], ap[i-j]);
add(accum, accum, t);
}
add(accum, accum, accum);
if (m & 1) {
sqr(t, ap[jmax + 1]);
add(accum, accum, t);
}
xp[i] = accum;
}
}
static
void KarSqr(ZZ *c, const ZZ *a, long sa, ZZ *stk)
{
if (sa == 1) {
sqr(*c, *a);
return;
}
if (sa == 2) {
sqr(c[0], a[0]);
sqr(c[2], a[1]);
mul(c[1], a[0], a[1]);
add(c[1], c[1], c[1]);
return;
}
if (sa == 3) {
sqr(c[0], a[0]);
mul(c[1], a[0], a[1]);
add(c[1], c[1], c[1]);
sqr(stk[0], a[1]);
mul(c[2], a[0], a[2]);
add(c[2], c[2], c[2]);
add(c[2], c[2], stk[0]);
mul(c[3], a[1], a[2]);
add(c[3], c[3], c[3]);
sqr(c[4], a[2]);
return;
}
long hsa = (sa + 1) >> 1;
long hsa2 = hsa << 1;
ZZ *T1, *T2;
T1 = stk; stk += hsa;
T2 = stk; stk += hsa2-1;
KarFold(T1, a, sa, hsa);
KarSqr(T2, T1, hsa, stk);
KarSqr(c + hsa2, a+hsa, sa-hsa, stk);
KarSub(T2, c + hsa2, sa + sa - hsa2 - 1);
KarSqr(c, a, hsa, stk);
KarSub(T2, c, hsa2 - 1);
clear(c[hsa2 - 1]);
KarAdd(c+hsa, T2, hsa2-1);
}
void KarSqr(ZZX& c, const ZZX& a)
{
if (IsZero(a)) {
clear(c);
return;
}
vec_ZZ mem;
const ZZ *ap;
ZZ *cp;
long sa = a.rep.length();
if (&a == &c) {
mem = a.rep;
ap = mem.elts();
}
else
ap = a.rep.elts();
c.rep.SetLength(sa+sa-1);
cp = c.rep.elts();
long maxa, xover;
maxa = MaxBits(a);
xover = 2;
if (sa < xover)
PlainSqr(cp, ap, sa);
else {
/* karatsuba */
long n, hn, sp, depth;
n = sa;
sp = 0;
depth = 0;
do {
hn = (n+1) >> 1;
sp += hn+hn+hn - 1;
n = hn;
depth++;
} while (n >= xover);
ZZVec stk;
stk.SetSize(sp,
((2*maxa + NumBits(sa) + 2*depth + 10)
+ NTL_ZZ_NBITS-1)/NTL_ZZ_NBITS);
KarSqr(cp, ap, sa, stk.elts());
}
c.normalize();
}
NTL_END_IMPL
syntax highlighted by Code2HTML, v. 0.9.1