/*
 $Id: linsystemA.cc,v 1.4 1996/11/20 10:00:07 roitzsch Exp $
 (C)opyright 1996 by Konrad-Zuse-Center, Berlin
 All rights reserved.
 Part of the Kaskade distribution
*/

#include "linsystem.h"
#include "statistic.h"

#include "sysmat.h"
#include "precond.h"
#include "dirichlet.h"

#include "cmdpars.h"
extern CmdPars  Cmd;

extern ostream* infoFile;

//-------------------------------------------------------------------------


// Driver for the solution of the linear system held by this object.
//
//   x                   start vector on entry, solution on exit; entries with
//                       a Dirichlet condition are overwritten first
//   eNorm, error        stored in ENorm / Error (Error = min of both) for the
//                       convergence tests
//   reqGlobalPrecision  stored in globalPrecision
//   step                step 0 switches to the residual test and enables the
//                       level-0 direct solver limit
//   maxIter             iteration limit; <= 0 means 10*dimension, at least 10;
//                       may be overridden by the "maxIter" command parameter
//
// Returns the member `status` (False when the iterative solver failed in
// every attempt).
Bool LinSystem:: solve(Vector<Num>& x, Real eNorm, Real error,
		       Real reqGlobalPrecision, int step, int maxIter)
{
    const  Real tiny = machMin(Real(0.0));
    static Real accTime = 0.0;          // CPU time summed over all calls
    Timer  timer, accTimer;
    int    idx, attempt, nConstrained = 0;

    status = True;                      // compEnergy relies on this flag

    ENorm = eNorm;                      // used by the convergence tests
    Error = (error < eNorm) ? error : eNorm;
    globalPrecision = reqGlobalPrecision;

    if (Cmd.isTrue("checkDiagonal"))  A->checkDiagonal();

    // impose the Dirichlet values and count the constrained entries
    FORALL(x,idx)
    {
        if (dirichletBCs->isSet(idx))
        {
            x[idx] = dirichletBCs->value(idx);
            ++nConstrained;
        }
    }

    if ((A->Dim() <= directSolverLimit) ||
        (step == 0 && A->Dim() <= level0DirectSolverLimit))
    {
        // small system: direct solution
        if (infoLinSystem) cout << "\n\t\tDirect solution by LU-Factorization\n";
        A->Decompose();
        A->FBSubst(x,b);
    }
    else if (nConstrained == A->Dim())
    {
        status = True;                  // every node is constrained
    }
    else
    {
        // iterative solution, possibly with restarts
        if (maxIter <= 0) maxIter = 10*A->Dim();
        if (maxIter < 10) maxIter = 10;

        int maxRestarts = 0;

        Cmd.get("maxIter",    &maxIter);
        Cmd.get("maxRestarts",&maxRestarts);

        if (step==0 || ENorm<tiny) convTest = residual;
        else                       convTest = convTest0;

        for (attempt=0; attempt<=maxRestarts; ++attempt)
        {
            switch (solver)
            {
                case cgOmin:
                {
                    if (infoLinSystem)  cout << "\n\tLinear Solver: CG";
                    status = CG(x,maxIter);
                    break;
                }
                case cgOdir:
                {
                    if (infoLinSystem)  cout << "\n\tLinear Solver: CG (OrthoDir)";
                    status = CGOdir(x,maxIter);
                    break;
                }
                case crOmin:
                {
                    if (infoLinSystem)  cout << "\n\tLinear Solver: CR (OrthoMin)";
                    status = CROmin(x,maxIter);
                    break;
                }
                case crOdir:
                {
                    if (infoLinSystem)  cout << "\n\tLinear Solver: CR (OrthoDir)";
                    status = CROdir(x,maxIter);
                    break;
                }
                case biCG:
                {
                    if (infoLinSystem)  cout << "\n\tLinear Solver: BiCG";
                    status = BiCG(x,maxIter);
                    break;
                }
                case biCGStab:
                {
                    if (infoLinSystem)  cout << "\n\tLinear Solver: BiCGStab";
                    status = BiCGStab(x,maxIter);
                    break;
                }
                case cgs:
                {
                    if (infoLinSystem)  cout << "\n\tLinear Solver: CGS";
                    status = CGS(x,maxIter);
                    break;
                }
                case lsqCG:
                {
                    if (infoLinSystem)  cout << "\n\tLinear Solver: LSQ-CG";
                    status = LSQCG(x,maxIter);
                    break;
                }
                case ddCG:
                {
                    if (infoLinSystem)  cout << "\n\tLinear Solver: Dd-CG";
                    status = DdCG(x,maxIter);
                    break;
                }
                case gmRes:
                {
                    // note: maxIter is handed over as GMRes' restart limit
                    if (infoLinSystem)  cout << "\n\tLinear Solver: GMRes";
                    status = GMRes(x,maxIter, maxOrthoGMRes);
                    break;
                }
                case relax:
                {
                    if (infoLinSystem)  cout << "\n\tLinear Solver: Relaxation";
                    status = Relax(x,maxIter);
                    break;
                }
                case nonLinRelax:
                {
                    if (infoLinSystem) cout << "\n\tLinear Solver: NL-Relaxation";
                    status = NonLinRelax(x,maxIter);
                    break;
                }
                case testSolver1:
                {
                    if (infoLinSystem)  cout << "\n\tLinear Solver: Test1";
                    status = TestSolver1(x,maxIter);
                    break;
                }
                case testSolver2:
                {
                    if (infoLinSystem)  cout << "\n\tLinear Solver: Test2";
                    status = TestSolver2(x,maxIter);
                    break;
                }
            }

            if (status == True) break;  // converged, no further attempt

            if (infoLinSystem)
                cout << "  restart (at " << (attempt+1)*maxIter << ")";
        }
    }

    if (status == False)
    {
        cout << "\n** Solution of Linear System failed\n";
    }

    if (timeLinSystem)
    {
        cout << "\n\tLinear System Solution: ";
        timer.cpu();
    }

    accTime += accTimer.cpu(False);

    if (accTimeLinSystem)
    {
        cout << Form("\tAccumulated time:\t %1.2f sec.\n", accTime);
    }

    if (Cmd.isTrue("writeCpuTime")) *infoFile << accTime;

    // if (Cmd.isTrue("Eigenvalues"))  Lanczos(x,A);

    return status;
}
//-------------------------------------------------------------------------


// Dispatch to the convergence criterion selected by the member convTest.
// The energy based tests receive deltaE and eAe, the residual based ones
// only the residual data.  An unknown selector yields False.
Bool LinSystem:: convergenceTest(int iter, Vector<Num>& x, Vector<Num>& r, 
				 Real bNorm, Real deltaE, Real eAe)
{
    Bool converged = False;

    switch (convTest)
    {
      case ci:
        converged = ciConvergence(iter, x, r, bNorm, deltaE, eAe);
        break;

      case ccgDd:
        converged = ccgDdConvergence(iter, x, r, bNorm, deltaE, eAe);
        break;

      case ccgDB:
        converged = ccgDBConvergence(iter, x, r, bNorm, deltaE, eAe);
        break;

      case residual:
        converged = residualConvergence(iter, x, r, bNorm);
        break;

      case decayOfResidual:
        converged = residualDecay(iter, x, r, bNorm);
        break;

      case vectorIteration:
        converged = vectorIterConvergence(iter, x, r, bNorm);
        break;
    }

    return converged;
}
//-------------------------------------------------------------------------


// Compute the functional value Fct and the energy norm ENorm for x, store
// them in the members and hand them back through fct / eNorm.
// After a failed solve (status == False) both outputs are set to the
// largest representable Real.  With print set and infoLinSystem active the
// functional and the relative residual share are reported on cout.
void LinSystem:: compEnergy(Vector<Num>& x, Real *eNorm, Num* fct, Bool print)
{
    static Real tiny = machMin(Real(0));
    Num xAx, xb, xbSave, xr;

    if (status == False)                // the solution failed
    {
        const Real huge = machMax(Real(0));
        *fct   = huge;
        *eNorm = huge;
        return;
    }

    if (A->DirectSolution())
    {
        // direct solution: no matrix-vector product is used here
        xb     = cdot(x,b);
        xbSave = cdot(x,bSave);

        if (symmetry == sym)
        {
            Fct   = -0.5*xb + Fct0;
            ENorm = Abs(Fct +xbSave);
        }
        else
        {
            Fct   = -0.5*xb + Fct0 + xbSave;
            ENorm = Abs(Fct + xb + E0);
        }

        xr = 0.0;
    }
    else
    {
        Vector<Num> Ax(A->Dim());
        xb = cdot(x,b);

        A->Mult(Ax,x);
        xAx = 0.5*cdot(x,Ax);           //  1/2 x*A*x

        if (symmetry == sym)
        {
            Fct   = xAx - xb + Fct0;
            ENorm = Abs(Fct + cdot(x,bSave));
        }
        else
        {
            xbSave = cdot(x,bSave);
            Fct   = xAx - xb + Fct0 + xbSave;
            ENorm = Abs(Fct + xb + E0);
        }

        // residual forces; xAx already carries the factor 1/2
        xr = 0.5*(2.*xAx - xb);
    }

    if (print && infoLinSystem && Abs(Fct) > tiny)
    {
        cout << "\t  F=" << Fct << Form("  FRes/F=%1.2g",Abs(xr/Fct));
        if (Abs(xr/Fct) > globalPrecision)
            cout << " > " << globalPrecision << " (global Precision)";
        cout << "\n";
    }

    *fct   = Fct;
    *eNorm = ENorm;

    //if (Cmd.isTrue("printCCGError"))
    //  cout << "\n\tCCG: estimated error = " << Error/ENorm *100 << " %\n";
}
//-------------------------------------------------------------------------
//-------------------------------------------------------------------------


// Convergence test on the relative residual: converged when
// |r| < tol*|b| with tol = 0.1*extPrecFactor*globalPrecision.
// The call with iter == 0 sets up the tolerance and the statistics kept in
// function statics; it only reports convergence for a vanishing residual.
Bool LinSystem:: residualConvergence(int iter, Vector<Num>& /*x*/, Vector<Num>& r, 
				     Real bNorm)
{
    static const Real  tiny=10*machMin(Real(0));
    static Real tol, prevNorm, accRatio;    // state carried between calls
    const char *headerFmt =   "%14s %7s %7.3g %10s %10s   |b|=%1.3g";
    const char *rowFmt    = "\n%14d %15.3g %10.3g %9.3g";

    const Real rNorm = Norm(r);

    if (statistic->active) 
    {
        statistic->ZD_IntWrite(statistic->idIteSteps,iter);
        statistic->ZD_RealWrite(statistic->idIteErr,rNorm/bNorm);
    }

    if (iter == 0)                          // solver set-up
    {
        tol      = 0.1 * extPrecFactor * globalPrecision;
        prevNorm = rNorm;
        accRatio = 1.0;

        if (infoLinSystem)
            cout << Form(headerFmt, "iter", "r/b < ", tol, "ratio(r)", 
                         "<ratio(r)>", bNorm);

        if (infoLinSystem == 1)
            cout << Form(rowFmt, Abs(iter), quot(rNorm,bNorm), 0.0, 0.0);

        if (!(rNorm < tiny)) return False;

        // unusual convergence: r == 0
        if (infoLinSystem)
            cout << Form(rowFmt, Abs(iter), quot(rNorm,bNorm), 0.0, 0.0) 
                 << "\n";
        return True;
    }

    accRatio *= rNorm/prevNorm;

    if (rNorm < tol*bNorm)                  // converged
    {
        if (infoLinSystem)
            cout << Form(rowFmt, Abs(iter), quot(rNorm,bNorm), rNorm/prevNorm, 
                         pow(accRatio, 1.0/Abs(iter))) << "\n";

        if (Cmd.isTrue("writeIterations"))
            *infoFile << A->Dim() << "  " << iter << "\n";

        return True;
    }

    // progress line every infoLinSystem-th step, on the first step and for
    // negative iteration numbers
    if (infoLinSystem && (iter%infoLinSystem==0 || iter==1 || iter<0))
        cout << Form(rowFmt, Abs(iter), rNorm/bNorm, rNorm/prevNorm,
                     pow(accRatio, 1.0/Abs(iter)));

    prevNorm = rNorm;
    return False;
}
//-------------------------------------------------------------------------


// Convergence test "CI": converged when sqrt(eAe) drops below a tolerance
// derived at iter == 0 from globalPrecision, ENorm, the previous error
// estimate Error, the growth of the system dimension and the iteration
// error left over from the previous solve.
//
//   iter     0 = set-up call, otherwise the iteration number
//   r        current residual (only its norm is used)
//   bNorm    norm of the right hand side, used for reporting
//   deltaE   increment added to the member Error in every iteration
//   eAe      value whose square root is compared with the tolerance
//
// Returns True on convergence.  A tolerance of (almost) zero forces
// convergence in the first iteration.
Bool LinSystem:: ciConvergence(int iter, Vector<Num>& /*x*/, Vector<Num>& r, 
			       Real bNorm, Real deltaE, Real eAe)
{
    static int  dimM1=0;                // dimension of the previous system
    static Real tol, iterErrM1=0.0, tiny=10*machMin(Real(0));
    Real ratio=1.0, expo, rNorm, iterErr;
    Real KrauseFactor = 1.4789365294274;		// (truncated!)

    const char *format0 =   "%17s %10s %10s %10s tol=%1.2g |b|=%1.3g";
    const char *format = "\n%17d %10.3g %10.3g %10.3g";

    rNorm = Norm(r);

    if (iter==0) 	
    {  
	if (Abs(Error) > tiny && Abs(ENorm) > tiny)
	{
	    if (dimM1 == 0) dimM1 = A->Dim();
	    ratio = Real(dimM1)/Real(A->Dim());

	    expo = 1.0/Real(spaceDim);
	    // The guard above and the line below both work on Abs(Error);
	    // take the root of the magnitude here as well so that a negative
	    // estimate cannot turn the tolerance into NaN.
	    tol  = sqrt(globalPrecision*Abs(ENorm)) / 
					(sqrt(Abs(Error))*pow(ratio,expo));
	    expo = Real(spaceDim+1) / 2.0;
	    tol  = sqrt(Abs(Error)) * pow(tol,expo) + KrauseFactor*iterErrM1;

	    tol *= 0.1 * extPrecFactor;
	}
	else tol = 0.0;                 // leads to forced convergence below

	Error = 0.0;                    // restart the accumulation of deltaE

	if (infoLinSystem) 
	   cout << Form(format0, "iter(CI)", "eAe", "sum dE", "r/b", 
			  sqr(tol), bNorm);

	if (rNorm < tiny) 		// unusual convergence: r == 0
	{
	    if (infoLinSystem)
	      cout << Form(format, iter, eAe, Error, rNorm/bNorm) << "\n";
	    return True;
	}
	else return False; 
    }

    Error += deltaE;
    iterErr = sqrt(eAe);

    if (tol <= tiny)  iterErr = 0.0;		// forced convergence

    if (iterErr <= tol || rNorm < tiny)	 			// converged
    {						
    	if (infoLinSystem)
	  cout << Form(format, iter, eAe, Error, rNorm/bNorm) << "\n";

	// remember dimension and remaining iteration error for the next solve
	dimM1 = A->Dim();
	iterErrM1 = iterErr;

	if (Cmd.isTrue("writeIterations")) 
	  *infoFile << A->Dim() << "  " << iter << "\n";

	return True;
    }

    if (infoLinSystem)
	if (iter%infoLinSystem==0 || iter==1 || iter<0) 
	cout << Form(format, iter, eAe, Error, rNorm/bNorm);

    return False;
}
//-------------------------------------------------------------------------

	// deltaE is the "eps(jk)" in script, ratio is Theta


// Convergence test "CCG-DB": converged when eAe falls below
// tol = 0.1*extPrecFactor*globalPrecision*real(ENorm) or the residual
// vanishes.  A tolerance below `tiny` is replaced by 1e10, which forces
// convergence.  deltaE is summed up in the member Error.
Bool LinSystem:: ccgDBConvergence(int iter, Vector<Num>& /*x*/, Vector<Num>& r, 
				  Real bNorm, Real deltaE, Real eAe)
{
    static Real tol;
    static Real tiny=10*machMin(Real(0));
    const Real rNorm = Norm(r);

    const char *headerFmt =   "%17s %10s %10s %10s tol=%1.2g |b|=%1.3g";
    const char *rowFmt    = "\n%17d %10.3g %10.3g %10.3g";

    if (iter == 0)                          // solver set-up
    {
        tol = 0.1 * extPrecFactor * globalPrecision * ::real(ENorm);
        if (tol < tiny) tol = 1e10;         // forced convergence

        Error = 0.0;

        if (infoLinSystem)
            cout << Form(headerFmt, "iter(CCGDB)", "eAe", "sum dE", "r/b", 
                         sqr(tol), bNorm);

        if (!(rNorm < tiny)) return False;

        // unusual convergence: r == 0
        if (infoLinSystem)
            cout << Form(rowFmt, iter, eAe, Error, rNorm/bNorm) << "\n";
        return True;
    }

    Error += deltaE;

    if (eAe < tol || rNorm < tiny)          // converged
    {
        if (infoLinSystem)
            cout << Form(rowFmt, iter, eAe, Error, rNorm/bNorm) << "\n";

        if (Cmd.isTrue("writeIterations"))
            *infoFile << A->Dim() << "  " << iter << "\n";

        return True;
    }

    if (infoLinSystem && (iter%infoLinSystem==0 || iter==1 || iter<0))
        cout << Form(rowFmt, iter, eAe, Error, rNorm/bNorm);

    return False;
}
//-------------------------------------------------------------------------

	// deltaE is the "eps(jk)" in script, ratio is Theta


// Convergence test "CCG-Dd": converged when deltaE/(1 - min(0.9, ratio))
// falls below tol = 0.1*extPrecFactor*globalPrecision*real(ENorm), where
// ratio is the quotient of two successive deltaE values, or when the
// residual vanishes.
//
//   iter     0 = set-up call, otherwise the iteration number
//   r        current residual (only its norm is used)
//   bNorm    norm of the right hand side, used for reporting
//   deltaE   increment of this iteration, summed up in the member Error
//
// On convergence the accumulated Error is scaled with an estimated
// contraction factor taken from the ratio of the previous and the current
// system dimension (1/4 on the first solve).
Bool LinSystem:: ccgDdConvergence(int iter, Vector<Num>& /*x*/, Vector<Num>& r, 
				  Real bNorm, Real deltaE, Real /*eAe*/)
{
    static int dimM1=0;                 // dimension of the previous system
    static Real tol, deltaEM1, tiny=10*machMin(Real(0));
    Real rNorm, ratio;

    const char *format0 =   "%17s %7s%8.3g %10s %10s  |b|=%1.3g";
    const char *format = "\n%17d %15.3g %10.3g %10.3g";

    rNorm = Norm(r);

    if (iter==0) 					// solver-setup
    {  
	tol = 0.1 * extPrecFactor * globalPrecision * ::real(ENorm);
	if (tol < tiny) tol = 1e10;			// forced convergence

	ratio = 1.0;
	Error = 0.0;
	if (infoLinSystem) 
	   cout << Form(format0,
			  "iter(CCG)", "dE < ", tol, "ratio(dE)", "r/b", bNorm);

	// rNorm already holds Norm(r); no need to evaluate the norm again
	if (infoLinSystem == 1)
	  cout << Form(format, iter, deltaE, ratio, rNorm/bNorm);

	if (rNorm < tiny) 		// unusual convergence: r == 0
	{
	    if (infoLinSystem)
	      cout << Form(format, iter, deltaE, ratio, rNorm/bNorm) << "\n";
	    return True;
	}
	else return False; 
    }

    if (deltaE == 0.0) 
      { cout << "\n*** CCG-Convergence: dE = 0 ! No CG-Solver ?\n"; }

    // NOTE(review): deltaEM1 may be zero here if the previous deltaE was
    // zero; the quotient is left as in the original.
    if (iter > 1) ratio = deltaE / deltaEM1;
    else	  ratio = 1.0;

    Error += deltaE;
    deltaEM1 = deltaE;

    if (deltaE/(1.0-Min(0.9,ratio)) < tol || rNorm < tiny)	// converged
    {						
    	if (infoLinSystem)
	  cout << Form(format, iter, deltaE, ratio, rNorm/bNorm) << "\n";

	if (dimM1) 
	{
	    ratio  = Real(dimM1)/Real(A->Dim());  // ~ estim. contraction factor
	    Error *= 0.5 * ratio/(1.0-ratio);
	}
	else Error *= 0.5 * 1./3.; 		  // contraction factor = 1/4

	dimM1 = A->Dim();
	if (Cmd.isTrue("writeIterations")) *infoFile << A->Dim() << "  " 
								<< iter << "\n";
	return True;
    }

    if (infoLinSystem)
	if (iter%infoLinSystem==0 || iter==1 || iter<0) 
	cout << Form(format, iter, deltaE, ratio, rNorm/bNorm);

    // A vanishing residual has already been handled by the convergence
    // branch above, so nothing is left to check here.
    return False;
}
//-------------------------------------------------------------------------


// Convergence test on the decay of the residual relative to the initial
// residual: converged when |r| < tol*|r(0)| with tol = 0.1*extPrecFactor.
// The call with iter == 0 records |r(0)| and sets the member Error to the
// largest representable Real (dummy value).
Bool LinSystem:: residualDecay(int iter, Vector<Num>& /*x*/, Vector<Num>& r, 
			       Real /*bNorm*/)
{
    static const Real  tiny=10*machMin(Real(0));
    static Real tol, startNorm, prevNorm;   // state carried between calls
    const char *headerFmt = "%20s %7s%8.3g %12s   |r(0)|=%1.3g";
    const char *rowFmt    = "\n%20d %15.3g %12.3g";

    const Real rNorm = Norm(r);

    if (iter == 0)                          // solver set-up
    {
        tol = 0.1 * extPrecFactor;
			// *preCond->precFactor(globalPrecision,Error,ENorm);

        startNorm = rNorm;
        prevNorm  = rNorm;
        Error = machMax(Real(0.0));         // dummy

        if (infoLinSystem)
            cout << Form(headerFmt, "iteration", "r/r0 < ", tol, "ratio(r)",rNorm);

        if (infoLinSystem == 1)
            cout << Form(rowFmt, Abs(iter), rNorm/startNorm, rNorm/prevNorm);

        if (!(startNorm < tiny)) return False;

        // unusual convergence: r == 0
        if (infoLinSystem)
            cout << Form(rowFmt, Abs(iter), rNorm/startNorm,
                         rNorm/prevNorm) << "\n";
        return True;
    }

    if (rNorm < tol*startNorm)              // converged
    {
        if (infoLinSystem)
            cout << Form(rowFmt, Abs(iter), rNorm/startNorm, rNorm/prevNorm)<<"\n";

        if (Cmd.isTrue("writeIterations"))
            *infoFile << A->Dim() << "\t " << iter << "\n";

        return True;
    }

    if (infoLinSystem && (iter%infoLinSystem==0 || iter==1 || iter<0))
        cout << Form(rowFmt, Abs(iter), rNorm/startNorm, rNorm/prevNorm);

    prevNorm = rNorm;
    return False;
}
//-------------------------------------------------------------------------


// Convergence test for vector iterations: stops when the iterate has grown
// beyond |b|/tol^2 or when the residual satisfies |r| < tol*|b|, with
// tol = 0.1*extPrecFactor*globalPrecision fixed in the iter == 0 call.
Bool LinSystem:: vectorIterConvergence(int iter, Vector<Num>& x, Vector<Num>& r, 
				       Real bNorm)
{
    static const Real  tiny=10*machMin(Real(0));
    static Real tol, prevNorm;              // state carried between calls
    const char *headerFmt = "%18s %5s%8.3g %7s%6.3g %12s   |b|=%1.3g";
    const char *rowFmt    = "\n%18d %13.3g %13.3g %12.3g";

    const Real rNorm = Norm(r);
    const Real xNorm = Norm(x);

    if (iter == 0)                          // solver set-up
    {
        tol = 0.1 * extPrecFactor * globalPrecision;
			// * preCond->precFactor(globalPrecision,Error,ENorm);

        if (infoLinSystem)
            cout << Form(headerFmt, "iteration", "x/b > ", 1./sqr(tol), 
                         "r/b < ", tol, "ratio(r)", bNorm);
        prevNorm = rNorm;

        if (!(rNorm < tiny)) return False;

        // unusual convergence: r == 0
        if (infoLinSystem)
            cout << Form(rowFmt, Abs(iter), xNorm/bNorm, rNorm/bNorm, 
                         rNorm/prevNorm) << "\n";
        return True;
    }

    if (xNorm > bNorm/sqr(tol) || rNorm < tol*bNorm)
    {
        if (infoLinSystem)
            cout << Form(rowFmt, Abs(iter), xNorm/bNorm, rNorm/bNorm, 
                         rNorm/prevNorm) << "\n";

        if (Cmd.isTrue("writeIterations"))
            *infoFile << A->Dim() << "  " << iter << "\n";

        return True;
    }

    if (infoLinSystem && (iter%infoLinSystem==0 || iter==1 || iter<0))
        cout << Form(rowFmt,Abs(iter),xNorm/bNorm,rNorm/bNorm,rNorm/prevNorm);

    prevNorm = rNorm;
    return False;
}
//-------------------------------------------------------------------------
//-------------------------------------------------------------------------


// Preconditioned conjugate gradient iteration.
//   x        start vector on entry, iterate on exit
//   maxIter  maximal number of iterations
// Returns True when the convergence test succeeded within maxIter steps.
Bool LinSystem:: CG(Vector<Num>& x, int maxIter) 
{
    int step, k;
    Num stepLen, mix, rho, rhoOld;
    Real bNorm;
    const int n = A->Dim();
    Vector<Num> dir(n), res(n), work(n);

    preCond->initialize(A,x,b);
    bNorm = normRhs();

    // initial residual  res = b - A*x
    preCond->AMult(work,A,x);
    FORALL(res,k) res[k] = b[k]-work[k];

    // preconditioned residual, also the first search direction
    preCond->invert(work,A,res);
    FORALL(dir,k) dir[k] = work[k];

    rho = dot(res,work);

    if (convergenceTest(0, x, res, bNorm, 0.0, 0.0))
    {
        preCond->close(A,x,b);
        return True;
    }

    for (step=1; step<=maxIter; ++step)
    {
        preCond->AMult(work,A,dir);

        stepLen = rho / dot(dir,work);
        FORALL(x,k)   x[k]   += stepLen * dir[k];
        FORALL(res,k) res[k] -= stepLen * work[k];

        if (convergenceTest(step, x, res, bNorm, Abs(stepLen*rho), Abs(rho)))
            break;

        preCond->invert(work,A,res);

        rhoOld = rho;
        rho    = dot(work,res);

        mix = rho / rhoOld;
        FORALL(dir,k) dir[k] = work[k] + mix*dir[k];
    }

    preCond->close(A,x,b);
    return step <= maxIter;             // False if the loop ran out
}
//-------------------------------------------------------------------------


// Conjugate gradient iteration in the OrthoDir form (three-term recurrence
// for the search directions).
//   x        start vector on entry, iterate on exit
//   maxIter  maximal number of iterations
// Returns True when the convergence test succeeded within maxIter steps.
Bool LinSystem:: CGOdir (Vector<Num>& x, int maxIter) 
{
    int iter, i;
    Num alpha, pAp, gamma, sigma, temp, pApM1=0.0, deltaE;
    Real bNorm;
    int dim = A->Dim();
    Vector<Num> p(dim), r(dim), Ap(dim), CAp(dim), pM1(dim), 
    		ApM1(dim);

    // pM1 enters the first direction update multiplied by sigma = 0; it must
    // hold finite values there, so do not rely on the constructor to clear it.
    FORALL(pM1,i) pM1[i] = 0.0;

    preCond->initialize(A,x,b);
    bNorm = normRhs();
    
    preCond->AMult(p,A,x);
    FORALL(r,i) r[i] = b[i]-p[i];		// initial residual
    
    preCond->invert(p,A,r);			// first search direction

    if (convergenceTest(0, x, r, bNorm, 0.0))
    {  
	preCond->close(A,x,b); 
	return True; 
    }
   

    for (iter=1; iter<=maxIter; ++iter)
    {
	preCond->AMult(Ap,A,p);
	preCond->invert(CAp,A,Ap);

	alpha = dot(r,p);
	pAp   = dot(p,Ap);
	alpha = alpha / pAp;

	FORALL(x,i) x[i] += alpha *  p[i];
	FORALL(r,i) r[i] -= alpha * Ap[i];

	deltaE = alpha*alpha*pAp;

	if (convergenceTest(iter, x, r, bNorm, Abs(deltaE)))  break;

	// coefficients of the three-term recurrence; the term with the
	// previous direction does not exist in the first step
	gamma = dot(CAp,Ap) / pAp;
	if (iter > 1) sigma = dot(CAp,ApM1) / pApM1;
	else          sigma = 0.0;

	FORALL(p,i) 
	{
	    temp = p[i];
	    p[i] = CAp[i] - gamma*p[i] - sigma*pM1[i];
	    pM1[i] = temp;

	    ApM1[i] = Ap[i];
	}
	pApM1 = pAp;
    }

    preCond->close(A,x,b);
    return iter <= maxIter;			// False if the loop ran out
}
//-------------------------------------------------------------------------

/*
  Bool LinSystem:: CG (Vector<Num>& x, int maxIter)
  {
  int iter, i;
  Num alpha, beta, rr, rrM1;
  Real bNorm;
  Vector<Num> p(A->Dim()), r(A->Dim()), aux(A->Dim());
  static Real tiny = 10.*machMin(Real(0));
  
  bNorm = normRhs();
  
  A->Mult(aux,x);
  FORALL(p,i) p[i] = r[i] = b[i]-aux[i];
  
  rr = dot(r,r);
  convergenceTest(sqrt(Abs(rr)), bNorm, 0, 0.0);
  
  for (iter=1; iter<=maxIter; ++iter)
  {
  A->Mult(aux,p);
  alpha = dot(p,aux);
  
  if (Abs(alpha) > tiny) 			// unusual convergence alpha=0 ?
  {
  alpha = rr / alpha;
  FORALL(x,i) x[i] += alpha * p[i];
  }
  
  if (convergenceTest(sqrt(Abs(rr)),bNorm,iter,Abs(alpha*rr))) break;
  
  FORALL(r,i) r[i] -= alpha * aux[i];
  
  rrM1 = rr;
  rr   = dot(r,r);
  beta = rr / rrM1;
  FORALL(p,i) p[i] = r[i] + beta*p[i];
  }
  
  return iter <= maxIter;
  }
  */
//-------------------------------------------------------------------------
//-------------------------------------------------------------------------


// Conjugate gradient squared iteration with preconditioning.
//   x        start vector on entry, iterate on exit
//   maxIter  maximal number of iterations
// Returns True when the convergence test succeeded within maxIter steps.
// The energy increment handed to the convergence test is always zero here
// (its computation is commented out).
Bool LinSystem:: CGS(Vector<Num>& x, int maxIter)
{
    int step, k;
    Num rho, rhoOld, alpha, beta, deltaE=0.0;
    Real bNorm;
    const int n = A->Dim();
    Vector<Num> p(n),res(n),shadow(n),u(n),q(n),v(n);

    preCond->initialize(A,x,b);
    bNorm = normRhs();

    // initial residual and its fixed copy (shadow residual)
    preCond->AMult(u,A,x);
    FORALL(res,k)
    {
        res[k]    = b[k]-u[k];
        shadow[k] = res[k];
    }

    FORALL(p,k)
    {
        q[k] = 0.0;
        p[k] = q[k];
    }

    rho  = cdot(shadow,res);
    beta = rho;

    if (convergenceTest(0, x, res, bNorm, 0.0))
    {
        preCond->close(A,x,b);
        return True;
    }

    for (step=1; step<=maxIter; ++step)
    {
        FORALL(u,k)
        {
            u[k] = res[k] + beta*q[k];
            p[k] = u[k] + beta*(q[k] + beta*p[k]);
        }

        preCond->invert(q,A,p);
        preCond->AMult (v,A,q);

        alpha = rho/cdot(shadow,v);

        FORALL(q,k)
        {
            q[k]  = u[k] - alpha*v[k];
            u[k] += q[k];
        }

        preCond->invert(v,A,u);
        preCond->AMult (u,A,v);

        FORALL(x,k)   x[k]   += alpha*v[k];
        FORALL(res,k) res[k] -= alpha*u[k];

        rhoOld = rho;
        rho    = cdot(shadow,res);

        // deltaE = conj(alpha)*alpha*cdot(v,u);

        if (convergenceTest(step, x, res, bNorm, Abs(deltaE))) break;

        beta = rho/rhoOld;
    }

    preCond->close(A,x,b);
    return step <= maxIter;             // False if the loop ran out
}
//-------------------------------------------------------------------------


// BiCGStab iteration with preconditioning.
//   x        start vector on entry, iterate on exit
//   maxIter  maximal number of iterations
// Returns True when the convergence test succeeded within maxIter steps.
// The energy increment handed to the convergence test is always zero here
// (its computation is commented out).
Bool LinSystem:: BiCGStab(Vector<Num>& x, int maxIter)
{
    int step, k;
    Num rho, rhoOld, omega, alpha, beta, deltaE=0.0;
    Real bNorm;
    const int n = A->Dim();

    Vector<Num> p(n), res(n), shadow(n), t(n), 
                v(n), y(n), z(n);

    preCond->initialize(A,x,b);
    bNorm = normRhs();

    // initial residual and its fixed copy (shadow residual)
    preCond->AMult(y,A,x);
    FORALL(res,k)
    {
        res[k]    = b[k]-y[k];
        shadow[k] = res[k];
    }

    omega = 1.0;
    alpha = omega;
    rho   = alpha;

    FORALL(p,k)
    {
        v[k] = 0.0;
        p[k] = v[k];
    }

    if (convergenceTest(0, x, res, bNorm, 0.0))
    {
        preCond->close(A,x,b);
        return True;
    }

    for (step=1; step<=maxIter; ++step)
    {
        rhoOld = rho;
        rho    = cdot(shadow,res);
        beta   = rho/rhoOld * alpha/omega;

        FORALL(p,k) p[k] = res[k] + beta*(p[k] - omega*v[k]);

        preCond->invert(y,A,p);
        preCond->AMult (v,A,y);

        alpha = rho/cdot(shadow,v);

        // intermediate residual s, stored in res
        FORALL(res,k) res[k] -= alpha*v[k];
        preCond->invert(z,A,res);

        preCond->AMult(t,A,z);

	//preCond->invert(aux,A,t);		// omit this ?
	//omega = cdot(aux,z) / cdot(aux,aux);

        omega = cdot(t,z) / cdot(t,t);

        FORALL(x,k)
        {
            x[k]   += alpha*y[k] + omega*z[k];
            res[k] -= omega*t[k];
        }

	//deltaE = conj(alpha)*alpha*cdot(y,v) + conj(omega)*omega*cdot(z,t) +
	//	 conj(alpha)*omega*cdot(y,t) + conj(omega)*alpha*cdot(z,v);

        if (convergenceTest(step, x, res, bNorm, Abs(deltaE))) break;
    }

    preCond->close(A,x,b);
    return step <= maxIter;             // False if the loop ran out
}
//-------------------------------------------------------------------------

	// This is GMRes, the GPPS (General Purpose Problem Solver)

	// parts of the algorithm were copied from Andreas Hohmann, 
	// 			so blame HIM if it does not work!
	// maxOrtho: max. number of orthogonalization steps in inner iteration

		  // right preconditioning version


// Restarted GMRes, right preconditioning version.
//   x            start vector on entry, iterate on exit
//   maxRestarts  maximal number of restarts of the inner iteration
//   maxOrtho     maximal number of orthogonalization steps per inner iteration
// Returns True when an inner iteration reported convergence.
Bool LinSystem:: GMRes(Vector<Num>& x, int maxRestarts, int maxOrtho)
{
    int i, j, nUsed, cycle, innerStatus;
    Real bNorm;

    Vector<Num> rhs(maxOrtho+1);
    Vector<Num> work(A->Dim());
    Matrix<Num> R(maxOrtho,maxOrtho);

    // orthonormal system; members beyond the first are allocated on demand
    // by the inner iteration
    Vector<Vector<Num>*> basis(maxOrtho+1);
    FORALL(basis,i) basis[i] = 0;
    basis[1] = new Vector<Num>(A->Dim());

    preCond->initialize(A,x,b);
    bNorm = normRhs();

    for (cycle=0; cycle<=maxRestarts; ++cycle)
    {
        innerStatus = GMResInnerIteration(x, maxOrtho, bNorm, rhs, R, basis,
                                          work, cycle, &nUsed);

        // compute solution R^(-1)*g of least squares problem (in place)
        GMResRSolve(rhs,R,rhs,nUsed);

        // combine the basis vectors with the computed coefficients
        FORALL(work,j) work[j] = 0.0;
        for (i=1; i<=nUsed; i++)
        {
            FORALL(work,j) work[j] += rhs[i] * (*basis[i])[j];
        }

        // apply the preconditioner and update the solution; basis[1] serves
        // as scratch space, it is rebuilt by the next inner iteration
        preCond->invert(*basis[1],A,work);
        FORALL(x,j) x[j] += (*basis[1])[j];

        if (innerStatus == 1) break;    // convergence in inner iteration
        if (infoLinSystem) cout << "   restart";
    }

    preCond->close(A,x,b);
    FORALL(basis,i) delete basis[i];
    return cycle <= maxRestarts;        // False if all cycles were used up
}
//-------------------------------------------------------------------------


// One GMRes cycle: builds up to maxOrtho orthonormal vectors, keeps the
// triangular factor R up to date with Givens rotations and applies the same
// rotations to the right hand side g of the least squares problem.
//
//   x         current iterate (read only here)
//   maxOrtho  maximal number of orthogonalization steps
//   bNorm     norm of the right hand side for the convergence test
//   g         out: rotated right hand side; g[k+1] is handed to the
//             convergence test as the residual after step k
//   R         out: triangular factor, row k of the storage belongs to step k
//   v         orthonormal system; missing members are allocated here
//   aux       work vector of the system dimension
//   nRestart  number of the current cycle; the set-up call of the
//             convergence test (iter == 0) is made in cycle 0 only
//   kMax      out: number of steps performed (0 if the start residual
//             already passed the convergence test)
//
// Returns True when the convergence test succeeded.
Bool LinSystem:: GMResInnerIteration(Vector<Num>& x, int maxOrtho, 
				     Real bNorm, Vector<Num>& g,
				     Matrix<Num>& R, Vector<Vector<Num>*> &v, 
				     Vector<Num>& aux, int nRestart, int* kMax)
{
    int i, j, k, iter;
    Num h, r;
    Real tmp;
    Vector<Num> c(maxOrtho), s(maxOrtho);	// Givens coefficients
    Vector<Num> rDummy(1);			// carries the residual norm

    preCond->AMult(aux,A,x);			
    FORALL(b,i) (*v[1])[i] = b[i] - aux[i];	// start residual
    
    FORALL(g,i) g[i] = 0.0;
    g[1] = Norm(*v[1]);

    // Evaluate the set-up call before normalizing: if the start residual
    // already satisfies the test (it only does so for a vanishing residual)
    // dividing by g[1] would produce non-finite basis vectors.  Report
    // convergence with an empty Krylov space instead.
    rDummy[1] = g[1];
    if (nRestart==0 && convergenceTest(0, rDummy, rDummy, bNorm, 0.0))
    {
	*kMax = 0;
	return True;
    }

    // NOTE(review): a vanishing residual at the start of a later cycle is
    // not detected here, as in the original.
    FORALL(*v[1],i) (*v[1])[i] /= g[1];		


    for (k=1; k<=maxOrtho; ++k)
    {
	*kMax = k;
	if (v[k+1]==0) v[k+1] = new Vector<Num>(A->Dim());

		    // next orthonormal vector v[k+1];
	
	preCond->invert(aux,A,*v[k]);
	preCond->AMult(*v[k+1],A,aux);
	
	for (i=1; i<=k; ++i) R(k,i) = dot(*v[i], *v[k+1]);
	for (i=1; i<=k; ++i)
	{
	    FORALL(*v[k+1],j) (*v[k+1])[j] -= R(k,i) * (*v[i])[j];
	}

	// NOTE(review): h == 0 is not guarded before the division.
	h = Norm(*v[k+1]);
	FORALL(*v[k+1],j) (*v[k+1])[j] /= h;


		// transform row k of R with previous Givens rotations:

	GMResQMult(k-1, R(k), c, s);

		  // compute new Givens coefficients c[k] and s[k];

    	r = R(k,k);
	tmp = sqrt(Abs(sqr(r)+sqr(h))); 	// !!!	
	c[k] =  r/tmp;
	s[k] = -h/tmp; 
	R(k,k) = tmp;

		// update constant vector g with new coefficients:
    	
    	GMResGivensMult(k, g.v, c, s);  
    	
	// the last step of a cycle is passed with a negative number
	iter = (k<maxOrtho)?  k : -k;
	rDummy[1] = g[k+1];

	if (convergenceTest(iter, rDummy, rDummy, bNorm, 0.0))  return True;
    }
    return False;
}
//-------------------------------------------------------------------------

	//  apply the Givens rotations Q = F_k ... F_1 to the vector v


// Apply the Givens rotations 1..k, in this order, to the vector v.
void LinSystem:: GMResQMult (int k, Num* v, Vector<Num>& c, Vector<Num>& s) 
{
    int rot = 1;
    while (rot <= k)
    {
        GMResGivensMult(rot, v, c, s);
        ++rot;
    }
}

	//  apply the j-th Givens rotation F_j to the vector v

void LinSystem:: GMResGivensMult (int j, Num* v, Vector<Num>& c, 
				  Vector<Num>& s) 
{
    Num tmp = c[j]*v[j] - s[j]*v[j+1];
    v[j+1]  = s[j]*v[j] + c[j]*v[j+1];
    v[j]    = tmp;
}

	//  solve the triangular system R_k * lhs = rhs by back substitution
	//  (R(i,j) holds the coefficient of lhs[i] in equation j)

// Back substitution for the leading k equations of the triangular system.
// lhs and rhs may be the same vector: rhs[row] is read before lhs[row] is
// written.
void LinSystem:: GMResRSolve (Vector<Num>& lhs, Matrix<Num>& R, 
			      Vector<Num>& rhs, int k)
{
    int row = k;
    while (row >= 1)
    {
        Num acc = rhs[row];

        // subtract the contributions of the unknowns already computed
        for (int col=row+1; col<=k; ++col) acc -= R(col,row) * lhs[col];

        acc /= R(row,row);
        lhs[row] = acc;
        --row;
    }
}
//-------------------------------------------------------------------------


/*		// left preconditioning version of GMRes:


  int LinSystem:: GMRes(Vector<Num>& x, int maxRestarts, int maxOrtho)
  {
  
  int i, j, kMax, nRestart, gmResStatus;
  Real bNorm;
  
  Vector<Real> g(maxOrtho+1);
  Vector<Num>  aux(A->Dim());
  Matrix<Num>  R(maxOrtho,maxOrtho);
  
  Vector(NumVectorP) v(maxOrtho+1); 		// orthonormal system
  FORALL(v,i) v[i] = 0;
  v[1] = new Vector<Num>(A->Dim());
  
  preCond->initialize(A,x,b);
  bNorm = preCond->normOfRhs(A,b,aux);  // compute norm of preconditioned
  //   vector b for convergence test
  
  
  for (nRestart=0; nRestart<=maxRestarts; ++nRestart)
  {
  gmResStatus = GMResInnerIteration(x, maxOrtho, bNorm, g, R, v, aux, &kMax);
  
  // compute solution R^(-1)*g of least squares problem:
  
  GMResRSolve(g,R,g,kMax);
  for (i=1; i<=kMax; i++) 
  FORALL(x,j) x[j] += g[i] * (*v[i])[j];
  
  if (gmResStatus == 1) break;		// convergence in inner iteration
  if (infoLinSystem) cout << "   restart\n";
  }
  
  preCond->close(A,x,b);
  FORALL(v,i) delete v[i];
  return nRestart <= maxRestarts;
  }
  //-------------------------------------------------------------------------
  
  int LinSystem:: GMResInnerIteration (Vector<Num>& x, int maxOrtho, 
  Real bNorm, Vector<Real>& g,
  Matrix<Num>& R, Vector(NumVectorP)& v, 
  Vector<Num>& aux, int* kMax)
  {
  int i, j, k, iter;
  Num h, r;
  Real tmp, t;
  Vector<Real> c(maxOrtho), s(maxOrtho);
  
  preCond->AMult(aux,A,x);			
  FORALL(b,i) aux[i] = b[i] - aux[i];
  preCond->invert(*v[1],A,aux);	   // preconditioned residual -> v[1];
  
  FORALL(g,i) g[i] = 0.0;
  g[1] = Norm(*v[1]);
  FORALL(*v[1],i) (*v[1])[i] /= g[1];		
  
  (*convergenceTest)(Abs(g[1]), bNorm, 0, 0.0);
  
  
  for (k=1; k<=maxOrtho; ++k)
  {
  *kMax = k;
  if (v[k+1]==0) v[k+1] = new Vector<Num>(A->Dim());
  
  // next orthonormal vector v[k+1];
  
  preCond->AMult(aux,A,*v[k]);
  preCond->invert(*v[k+1],A,aux);
  
  for (i=1; i<=k; ++i) R(k,i) = dot(*v[i], *v[k+1]);
  for (i=1; i<=k; ++i)
  {
  FORALL(*v[k+1],j) (*v[k+1])[j] -= R(k,i) * (*v[i])[j];
  }
  
  h = Norm(*v[k+1]);
  FORALL(*v[k+1],j) (*v[k+1])[j] /= h;
  
  
  // transform row k of R with previous Givens rotations:
  
  GMResQMult(k-1, R(k), c, s);
  
  // compute new Givens coefficients c[k] and s[k];
  
  r = R(k,k);
  tmp = sqrt(sqr(r)+sqr(h)); 
  c[k] =  r/tmp;
  s[k] = -h/tmp; 
  R(k,k) = tmp;
  
  // update constant vector g with new coefficients:
  
  GMResGivensMult(k, g.v, c, s);  
  
  iter = (k<maxOrtho)?  k : -k;
  if ((*convergenceTest)(Abs(g[k+1]), bNorm, iter, 0.0))  return 1;
  }
  return 0;
  }
  */
//-------------------------------------------------------------------------
//-------------------------------------------------------------------------


		// right preconditioning version of BiCG


// BiCG iteration, right preconditioning version.
//   x        start vector on entry, iterate on exit
//   maxIter  maximal number of iterations
// Returns True when the convergence test succeeded within maxIter steps.
Bool LinSystem:: BiCG (Vector<Num>& x, int maxIter)
{
    int   step, k;
    Num   alpha, beta, rho, rhoOld;
    Real  bNorm, deltaE;
    const int n = A->Dim();

    Vector<Num> dir(n), dirT(n), res(n), resT(n), 
                work(n), h(n);

    preCond->initialize(A,x,b);

    bNorm = normRhs();

    // initial residual and search direction, together with their
    // counterparts for the transposed system
    preCond->AMult(res,A,x);
    FORALL(res,k) res[k] = b[k]-res[k];

    FORALL(dir,k)
    {
        dir[k]  = res[k];
        resT[k] = conj(res[k]);
        dirT[k] = resT[k];
    }

    rho = cdot(resT,res);

    if (convergenceTest(0, x, res, bNorm, 0.0))
    {
        preCond->close(A,x,b);
        return True;
    }

    for (step=1; step<=maxIter; ++step)
    {
        preCond->invert(h,A,dir);
        preCond->AMult(work,A,h);

        alpha = rho/cdot(dirT,work);

        FORALL(x,k)   x[k]   += alpha * h[k];
        FORALL(res,k) res[k] -= alpha * work[k];

        deltaE = Abs(conj(alpha)*alpha * cdot(h,work));

        // the same step for the transposed system
        preCond->ATMult(h,A,dirT);
        preCond->invert(work,A,h);

        FORALL(resT,k) resT[k] -= alpha * work[k];

        rhoOld = rho;
        rho    = cdot(resT,res);

        if (convergenceTest(step, x, res, bNorm, deltaE)) break;

        beta = rho / rhoOld;
        FORALL(dir,k)
        {
            dir[k]  = res[k]  + beta*dir[k];
            dirT[k] = resT[k] + beta*dirT[k];
        }
    }

    preCond->close(A,x,b);
    return step <= maxIter;             // False if the loop ran out
}
//-------------------------------------------------------------------------


// "Dd-CG" iteration: works on the preconditioned residual and takes its
// search directions from the transposed operator.
//   x        start vector on entry, iterate on exit
//   maxIter  maximal number of iterations
// Returns True when the convergence test succeeded within maxIter steps.
Bool LinSystem:: DdCG (Vector<Num>& x, int maxIter)
{
    int step, k;
    Num rho, rhoOld, dirNorm2, factor;
    Real bNorm;
    Vector<Num> dir(A->Dim()), res(A->Dim()), work(A->Dim());

    preCond->initialize(A,x,b);

    // NOTE(review): the result in work is overwritten below before it is
    // read; the call is kept in case invert has side effects - confirm.
    preCond->invert(work,A,b);
    bNorm = normRhs();

    // initial (preconditioned) residual
    preCond->AMult(res,A,x);
    FORALL(res,k) work[k] = b[k]-res[k];
    preCond->invert(res,A,work);

    rhoOld = dot(res,res);
    rho    = rhoOld;

    if (convergenceTest(0, x, res, bNorm, 0.0))
    {
        preCond->close(A,x,b);
        return True;
    }

    // first search direction
    preCond->invert(work,A,res);
    preCond->ATMult(dir,A,work);

    for (step=1; step<=maxIter; ++step)
    {
        dirNorm2 = dot(dir,dir);
        factor   = rho/dirNorm2;
        FORALL(x,k) x[k] += factor*dir[k];

        // new residual, recomputed from the updated iterate
        preCond->AMult(res,A,x);
        FORALL(work,k) work[k] = b[k]-res[k];
        preCond->invert(res,A,work);

        rhoOld = rho;
        rho    = dot(res,res);

        if (convergenceTest(step, x, res, bNorm, 0.0)) break;

        // next search direction; res is reused as scratch space here
        preCond->invert(work,A,res);
        preCond->ATMult(res,A,work);
        factor = rho/rhoOld;
        FORALL(dir,k) dir[k] = res[k] + factor*dir[k];
    }

    preCond->close(A,x,b);
    return step <= maxIter;             // False if the loop ran out
}
//-------------------------------------------------------------------------
//-------------------------------------------------------------------------


// Relaxation (defect correction) iteration: x is updated with the
// preconditioner applied to the current residual until the convergence
// test succeeds.  The first pass (iteration number 0) also serves as the
// set-up call of the convergence test.
//   x        start vector on entry, iterate on exit
//   maxIter  number of the last admissible iteration
// Returns True when the convergence test succeeded.
Bool LinSystem:: Relax(Vector<Num>& x, int maxIter)
{
    int  k, step;
    Real bNorm, eAe;
    Vector<Num> res(A->Dim()), corr(A->Dim());

    preCond->initialize(A,x,b);
    bNorm = normRhs();

    for (step=0; step<=maxIter; ++step)
    {
        // residual of the current iterate
        preCond->AMult(corr,A,x);
        FORALL(res,k) res[k] = b[k] - corr[k];

        // defect correction
        preCond->invert(corr,A,res);
        FORALL(x,k) x[k] += corr[k];

        eAe = Abs(dot(corr,res));

        if (convergenceTest(step, x, res, bNorm, 0.0, eAe))  break;
    }

    preCond->close(A,x,b);
    return step <= maxIter;             // False if the loop ran out
}
//-------------------------------------------------------------------------


// syntax highlighted by Code2HTML, v. 0.9.1