/*
 $Id: linsystemB.cc,v 1.3 1996/11/20 10:00:08 roitzsch Exp $
 (C)opyright 1996 by Konrad-Zuse-Center, Berlin
 All rights reserved.
 Part of the Kaskade distribution
*/

#include "linsystem.h"

#include "sysmat.h"
#include "precond.h"

#include "cmdpars.h"
extern CmdPars Cmd;

extern ostream* infoFile;

//-------------------------------------------------------------------------
//-------------------------------------------------------------------------


// Preconditioned conjugate-residual iteration (minimal-residual form).
//
// Solves A*x = b starting from the guess stored in x.
// - Each step chooses the step length so that the new residual is minimal in
//   the norm induced by the preconditioner (preCond->invert).
// - The images q = A*p of the search directions are kept mutually conjugate
//   with respect to that preconditioner: <q_new, C q_old> = 0.
//
// x       : in = initial guess, out = last iterate
// maxIter : upper bound on the number of iterations
// returns : True if convergenceTest() accepted the initial guess or an
//           iterate within maxIter steps, False otherwise.
Bool LinSystem:: CROmin(Vector<Num>& x, int maxIter) 
{
    int  k, i;
    Num  stepLen, conjCoef, qCq;
    Real rhsNorm, energyDelta = 0.0;
    const int n = A->Dim();
    Vector<Num> dir(n), Adir(n), CAdir(n), res(n), Cres(n), ACres(n);

    preCond->initialize(A,x,b);
    rhsNorm = normRhs();

    // res = b - A*x : residual of the starting guess
    preCond->AMult(res,A,x);
    FORALL(res,i) res[i] = b[i] - res[i];

    // the first search direction is the preconditioned residual
    preCond->invert(Cres,A,res);
    FORALL(dir,i) dir[i] = Cres[i];
    preCond->AMult(Adir,A,dir);

    if (convergenceTest(0, x, res, rhsNorm, 0.0))
    {
        preCond->close(A,x,b);
        return True;
    }

    for (k = 1; k <= maxIter; ++k)
    {
        preCond->invert(CAdir,A,Adir);
        qCq     = dot(Adir,CAdir);              // <A p, C A p>
        stepLen = dot(Cres,Adir) / qCq;

        // advance the iterate; update residual and preconditioned residual
        FORALL(x,i)
        {
            x[i]    += stepLen * dir[i];
            res[i]  -= stepLen * Adir[i];
            Cres[i] -= stepLen * CAdir[i];
        }

        // energy contribution of this step, handed to the convergence test
        energyDelta = Abs(stepLen*stepLen*qCq);
        if (convergenceTest(k, x, res, rhsNorm, energyDelta)) break;

        preCond->AMult(ACres,A,Cres);
        conjCoef = -dot(ACres,CAdir) / qCq;     // makes <q_new, C q_old> = 0

        // new direction and its image under A, both by the same recurrence
        FORALL(dir,i)
        {
            dir[i]  = Cres[i]  + conjCoef * dir[i];
            Adir[i] = ACres[i] + conjCoef * Adir[i];
        }
    }

    preCond->close(A,x,b);
    return k <= maxIter;
}
//-------------------------------------------------------------------------


// Preconditioned conjugate-residual iteration, three-term-recurrence variant.
//
// Solves A*x = b starting from the guess stored in x.
// - The step length alpha minimizes the residual in the preconditioner norm.
// - Each new search direction is built from C*q (q = A*p). It is then
//   corrected against the current and previous directions with coefficients
//   gamma and sigma, chosen so that the new q is C-orthogonal to the previous
//   two (assuming a symmetric preconditioner).
// - q itself is advanced by the same recurrence instead of recomputing A*p.
//
// x       : in = initial guess, out = last iterate
// maxIter : upper bound on the number of iterations
// returns : True if convergenceTest() accepted the initial guess or an
//           iterate within maxIter steps, False otherwise.
Bool LinSystem:: CROdir(Vector<Num>& x, int maxIter) 
{
    int  iter, i;
    Num  alpha, sigma, gamma, qCq, qM1qM1=0.0, temp;
    Real bNorm, deltaE=0.0;
    int  dim = A->Dim();
    Vector<Num> p(dim), q(dim), r(dim), Cq(dim), ACq(dim),
                pM1(dim), CqM1(dim), qM1(dim);

    preCond->initialize(A,x,b);
    bNorm = normRhs();

    preCond->AMult(r,A,x);
    FORALL(r,i) r[i] = b[i]-r[i];               // initial residual

    preCond->invert(p,A,r);                     // initial search direction
    preCond->AMult (q,A,p);

    if (convergenceTest(0, x, r, bNorm, 0.0))
    {
        preCond->close(A,x,b);
        return True;
    }

    // The recurrence below reads pM1/qM1 already in the first iteration
    // (multiplied by sigma = 0) before they are ever assigned. Clear them so
    // that uninitialized storage (possibly NaN/Inf; 0*NaN is NaN) cannot
    // contaminate the first update.
    FORALL(pM1,i)
    {
        pM1[i] = 0.0;
        qM1[i] = 0.0;
    }

    for (iter=1; iter<=maxIter; ++iter)
    {
        preCond->invert(Cq,A,q);
        qCq = dot(q,Cq);                        // qCq = (p,ACA*p)
        alpha = dot(r,Cq)/qCq;

        FORALL(x,i) x[i] += alpha * p[i];
        FORALL(r,i) r[i] -= alpha * q[i];

        deltaE = Abs(alpha*alpha*qCq);

        if (convergenceTest(iter, x, r, bNorm, deltaE)) break;

        preCond->AMult(ACq,A,Cq);

        // orthogonalization coefficients against the current and the
        // previous q (no previous q exists in the first iteration)
        gamma = dot(ACq,Cq)/qCq;
        if (iter > 1) sigma = dot(ACq,CqM1)/qM1qM1;
        else          sigma = 0.0;

        // three-term update of p and q; the old values are shifted into
        // pM1/qM1 for the next iteration
        FORALL(p,i)
        {
            temp = p[i];
            p[i] = Cq[i] - gamma*p[i] - sigma*pM1[i];
            pM1[i] = temp;

            temp = q[i];
            q[i] = ACq[i] - gamma*q[i] - sigma*qM1[i];
            qM1[i] = temp;
        }
        // q is advanced by the recurrence above; recomputing it explicitly
        // would be: preCond->AMult(q,A,p);

        FORALL(Cq,i) CqM1[i] = Cq[i];
        qM1qM1 = qCq;
    }

    preCond->close(A,x,b);
    return iter <= maxIter;
}
//-------------------------------------------------------------------------
//-------------------------------------------------------------------------

	//  Solve the normal-equation system A^T*A*x = A^T*b
	//  (right-preconditioned version).


// Conjugate gradients on the normal equations A^T*A*x = A^T*b, right
// preconditioned.
//
// - The normal-equation residual A^T*(b - A*x) is passed through the
//   preconditioner (preCond->invert).
// - Each correction applied to x is the preconditioned search direction.
// - The step length uses |A * correction|^2 (via cdot).
//
// x       : in = initial guess, out = last iterate
// maxIter : upper bound on the number of iterations
// returns : True if convergenceTest() accepted the initial guess or an
//           iterate within maxIter steps, False otherwise.
Bool LinSystem:: LSQCG (Vector<Num>& x, int maxIter)
{
    int  it, i;
    Num  step, beta, sNormSq, sNormSqOld;
    Real rhsNorm;
    const int n = A->Dim();
    Vector<Num> dir(n), res(n), work(n), vec(n);

    preCond->initialize(A,x,b);
    rhsNorm = normRhs();

    // res = b - A*x
    preCond->AMult(res,A,x);
    FORALL(res,i) res[i] = b[i] - res[i];

    // first search direction: preconditioned A^T*res
    preCond->ATMult(work,A,res);
    preCond->invert(dir,A,work);

    sNormSq = cdot(dir,dir);

    if (convergenceTest(0, x, res, rhsNorm, 0.0))
    {
        preCond->close(A,x,b);
        return True;
    }

    for (it = 1; it <= maxIter; ++it)
    {
        // work = preconditioned direction (correction in x-space),
        // vec  = A*work
        preCond->invert(work,A,dir);
        preCond->AMult(vec,A,work);

        step = sNormSq / cdot(vec,vec);
        FORALL(x,i)
        {
            x[i]   += step * work[i];
            res[i] -= step * vec[i];
        }

        // vec = preconditioned A^T*res : the new normal-equation residual
        preCond->ATMult(work,A,res);
        preCond->invert(vec,A,work);

        sNormSqOld = sNormSq;
        sNormSq    = cdot(vec,vec);

        if (convergenceTest(it, x, res, rhsNorm, 0.0)) break;

        beta = sNormSq / sNormSqOld;
        FORALL(dir,i) dir[i] = vec[i] + beta * dir[i];
    }

    preCond->close(A,x,b);
    return it <= maxIter;
}
//-------------------------------------------------------------------------


/* 		// LSQCG with left preconditioning 

  Bool LinSystem:: LSQCG(Vector<Num>& x, int maxIter)
  {
  
  
  int  iter, i;
  Num  alpha, beta, rTsq, rTsqM1;
  Real bNorm;
  int  dim = A->Dim();
  Vector<Num> p(dim), r(dim), aux(dim), h(dim);
  
  preCond->initialize(A,x,b);
  
  preCond->invert(aux,A,b);
  bNorm = normRhs();
  
  preCond->AMult(r,A,x);
  FORALL(r,i) aux[i] = b[i]-r[i];	
  preCond->invert(r,A,aux);   		// initial residual
  
  preCond->invert(aux,A,r);   
  preCond->ATMult(p,A,aux);			// initial search direction
  
  rTsq = dot(p,p);
  convergenceTest(0, x, r, bNorm, 0.0);
  
  
  for (iter=1; iter<=maxIter; ++iter)
  {
  preCond->AMult(h,A,p);
  preCond->invert(aux,A,h); 
  
  alpha = rTsq/dot(aux,aux);
  
  FORALL(x,i) x[i] += alpha * p[i];
  FORALL(r,i) r[i] -= alpha * aux[i];
  
  preCond->invert(aux,A,r); 
  preCond->ATMult(h,A,aux);
  
  rTsqM1 = rTsq;
  rTsq   = dot(h,h);
  
  if (convergenceTest(iter, x, r, bNorm, 0.0)) break;
  
  beta = rTsq / rTsqM1;
  FORALL(p,i) p[i] = h[i] + beta*p[i];
  }
  
  preCond->close(A,x,b);
  return iter <= maxIter;
  }
  */
//-------------------------------------------------------------------------
//-------------------------------------------------------------------------

// -- 	CG with different preconditioners for preconditioning and 
//	iteration error estimation <r,Cr>

// Experimental CG solver with two separate preconditioners.
//
// - The preconditioner that drives the CG iteration is chosen by the
//   command-line option "ccgPrec": "jacobi" -> A->DiagDiv, "sgs" -> the
//   A->Fm1 / DiagMult / FmT sequence, anything else -> identity.
// - preCond->invert is used only to compute the estimate <r, C r>. That
//   estimate is passed to convergenceTest as its last argument.
// - The estimate passed at step k is computed from the residual of step k-1.
//
// x       : in = initial guess, out = last iterate
// maxIter : upper bound on the number of iterations
// returns : True if convergenceTest() accepted the initial guess or an
//           iterate within maxIter steps, False otherwise.
Bool LinSystem:: TestSolver1(Vector<Num>& x, int maxIter)
{
    int iter, i;
    Num alpha, beta, rr, rrM1, eAe;
    Real bNorm;
    int  dim = A->Dim();
    Vector<Num> p(dim), r(dim), aux(dim);

    preCond->initialize(A,x,b);
    bNorm = normRhs();

    A->Mult(aux,x);                             // aux = A*x
    FORALL(r,i) r[i] = b[i]-aux[i];             // initial residual

    // -- the usual preconditioning for the cg:
    if      (Cmd.isSet("ccgPrec","jacobi")) A->DiagDiv(aux,r);
    else if (Cmd.isSet("ccgPrec","sgs"))
    {
        A->Fm1(aux,r);
        A->DiagMult(aux,aux);
        A->FmT(aux,aux);
    }
    else  FORALL(r,i) aux[i] = r[i];            // preconditioned residual s

    FORALL(p,i) p[i] = aux[i];                  // initial search direction

    rr = dot(r,aux);

    // Stop before the first step if the initial guess is already accepted,
    // as the other solvers in this file do. Previously the result was
    // ignored: with a zero initial residual, rr == 0 and p == 0, so the first
    // step length below would be 0/0.
    if (convergenceTest(0, x, r, bNorm, 0.0, 0.0))
    {
        preCond->close(A,x,b);
        return True;
    }

    // error estimate <r, C r> for the first convergence test
    preCond->invert(aux,A,r);
    eAe = dot(r,aux);

    for (iter=1; iter<=maxIter; ++iter)
    {
        A->Mult(aux,p);                         // aux = A*p

        alpha = rr / dot(p,aux);
        FORALL(x,i) x[i] += alpha * p[i];
        FORALL(r,i) r[i] -= alpha * aux[i];

        if (convergenceTest(iter, x, r, bNorm, Abs(alpha*rr), Abs(eAe))) break;

        // error estimate <r, C r> for the next convergence test
        preCond->invert(aux,A,r);
        eAe = dot(r,aux);

        // -- the usual preconditioning for the cg:
        if      (Cmd.isSet("ccgPrec","jacobi")) A->DiagDiv(aux,r);
        else if (Cmd.isSet("ccgPrec","sgs"))
        {
            A->Fm1(aux,r);
            A->DiagMult(aux,aux);
            A->FmT(aux,aux);
        }
        else  FORALL(r,i) aux[i] = r[i];

        rrM1 = rr;
        rr   = dot(aux,r);

        beta = rr / rrM1;
        FORALL(p,i) p[i] = aux[i] + beta*p[i];
    }

    preCond->close(A,x,b);
    return iter <= maxIter;
}
//-------------------------------------------------------------------------

// -- 	relaxation method with different preconditioners for relaxation
//	(=preconditioning) and iteration error estimation <r,Cr>


// Experimental damped relaxation solver.
//
// Each sweep does the following:
// - recomputes the residual r = b - A*x;
// - reads the damping factor from the "Omega" command-line option
//   (default 1.0);
// - forms a correction chosen by "ccgPrec":
//     "jacobi" -> A->DiagDiv(.., omega)
//     "richar" -> omega*r
//     otherwise -> the A->Fm1 / DiagMult / FmT sequence with omega
//   (the last one is presumably an SSOR-type sweep -- TODO confirm);
// - adds the correction to x.
//
// The error estimate passed to convergenceTest is |<correction, r>|.
// The energy increment argument is always 0.0; its computation is disabled.
// preCond is only initialized and closed here.
//
// x       : in = initial guess, out = last iterate
// maxIter : sweeps are numbered 0..maxIter
// returns : True if convergenceTest() reported convergence, False otherwise.
Bool LinSystem:: TestSolver2(Vector<Num>& x, int maxIter)
{
    int  i, step;
    Real bNorm, dE = 0.0, eAe;
    const int n = A->Dim();
    Vector<Num> res(n), corr(n);

    preCond->initialize(A,x,b);
    bNorm = normRhs();

    for (step = 0; step <= maxIter; ++step)
    {
        // res = b - A*x  (corr serves as scratch for A*x)
        A->Mult(corr,x);
        FORALL(res,i) res[i] = b[i] - corr[i];

        // damping factor, re-read from the command line in every sweep
        Real omega = 1.0;
        Cmd.get("Omega",&omega);

        // corr = damped correction for the selected relaxation
        if (Cmd.isSet("ccgPrec","jacobi"))
        {
            A->DiagDiv(corr,res,omega);
        }
        else if (Cmd.isSet("ccgPrec","richar"))
        {
            FORALL(res,i) corr[i] = omega*res[i];
        }
        else
        {
            A->Fm1(corr,res,omega);
            A->DiagMult(corr,corr);
            A->FmT(corr,corr,omega);
        }

        FORALL(x,i) x[i] += corr[i];

        // error estimate from the correction just applied
        eAe = Abs(dot(corr,res));

        if (convergenceTest(step, x, res, bNorm, dE, eAe)) break;
    }

    preCond->close(A,x,b);
    return step <= maxIter;
}
//-------------------------------------------------------------------------
//-------------------------------------------------------------------------


// Preconditioned relaxation with an energy-based stopping criterion.
//
// Each sweep does the following:
// - computes r = b - A*x;
// - applies the preconditioner to get a correction e (preCond->invert);
// - updates x += e.
//
// Convergence is decided by nonLinEnergyConvergence(). It compares
// 0.5*|cdot(A*e, e)| with the energy of the new x reported by compEnergy().
//
// x       : in = initial guess, out = last iterate
// maxIter : upper bound on the number of sweeps
// returns : True if nonLinEnergyConvergence() reported convergence within
//           maxIter sweeps, False otherwise.
Bool LinSystem:: NonLinRelax(Vector<Num>& x, int maxIter)
{
    int  i, sweep;
    Real energyX, energyCorr;
    Num  unused;                                // second output of compEnergy, ignored
    Vector<Num> res(A->Dim()), corr(A->Dim());

    preCond->initialize(A,x,b);

    // iteration 0: sets up the tolerance and prints the table header
    nonLinEnergyConvergence(0, 0.0, 0.0);

    for (sweep = 1; sweep <= maxIter; ++sweep)
    {
        // res = b - A*x  (corr serves as scratch for A*x)
        preCond->AMult(corr,A,x);
        FORALL(res,i) res[i] = b[i] - corr[i];

        // corr = preconditioned residual, then x += corr
        preCond->invert(corr,A,res);
        FORALL(x,i) x[i] += corr[i];

        // energy of the correction: 0.5*|cdot(A*corr, corr)|
        preCond->AMult(res,A,corr);
        energyCorr = 0.5*Abs(cdot(res,corr));

        compEnergy(x, &energyX, &unused);

        if (nonLinEnergyConvergence(sweep, energyX, energyCorr)) break;
    }

    preCond->close(A,x,b);
    return sweep <= maxIter;
}
//-------------------------------------------------------------------------
   
   
// Energy-ratio stopping test used by NonLinRelax.
//
// iter == 0 : initializes the tolerance (0.1 * extPrecFactor *
//             sqrt(globalPrecision)) and the previous ratio (1.0). Prints
//             the table header if infoLinSystem is nonzero. Returns False.
// otherwise : ratio = sqrt(quot(ENormE, ENormX)).
//             - If ratio < tolerance: prints the final line (when
//               infoLinSystem is nonzero), appends "dim  iter" to *infoFile
//               when "writeIterations" is set, and returns True.
//             - Else: prints a progress line every infoLinSystem iterations
//               (and for iter == 1 or iter < 0), remembers the ratio for the
//               convergence-rate column, and returns False.
// The tolerance and previous ratio are function-level statics, so they are
// shared by every LinSystem object.
Bool LinSystem:: nonLinEnergyConvergence(int iter, Real ENormX, Real ENormE)
{
    static Real tolerance, prevRatio;

    const char *headFmt = "%20s %7s%8.3g %12s";
    const char *lineFmt = "\n%20d %15.3g %15.3g";

    if (iter == 0)                                      // solver setup
    {
        tolerance = 0.1 * extPrecFactor * sqrt(globalPrecision);
        prevRatio = 1.0;

        if (infoLinSystem)
            cout << Form(headFmt, "iteration", "rel.Error < ",
                         tolerance, "conv.rate") << "\t abs.Error";
        return False;
    }

    Real ratio = sqrt(quot(ENormE,ENormX));

    if (!(ratio < tolerance))                           // not yet converged
    {
        if (infoLinSystem && (iter%infoLinSystem==0 || iter==1 || iter<0))
            cout << Form(lineFmt, abs(iter), ratio, ratio/prevRatio)
                 << "\t " << sqrt(ENormE);

        prevRatio = ratio;
        return False;
    }

    // converged
    if (infoLinSystem)
        cout << Form(lineFmt, abs(iter), ratio, ratio/prevRatio)
             << "\t " << sqrt(ENormE) << "\n";

    if (Cmd.isTrue("writeIterations"))
        *infoFile << A->Dim() << "  " << iter << "\n";

    return True;
}




// syntax highlighted by Code2HTML, v. 0.9.1