#include "BSprivate.h"

/*@ BSpar_sym_solve - Solve a symmetric positive definite system of equations
                  using conjugate gradients preconditioned by one of several
                  preconditioners.  The rhs can be a block of vectors.  The
                  user should not call this function directly, but BSpar_solve().

    Input Parameters:
.   BS - the number of vectors in the RHS
.   A - a sparse matrix (must be marked symmetric)
.   fact_A - the incomplete factored version of A, if any
.   comm_A - the communication structure for A
.   in_rhs - the contiguous block of BS vectors, each of length
             A->num_rows, forming the rhs
.   pre_option - the preconditioner to use
$        PRE_ICC: incomplete Cholesky factorization
$        PRE_ILU: incomplete LU factorization
$        PRE_SSOR: Symmetric successive over-relaxation
$        PRE_BJACOBI: Block Jacobi
.   err_tol - the tolerance to which to solve the problem;
              iteration stops once, for EVERY vector in the block, the
              estimated norm of the residual divided by the norm of its
              rhs is less than err_tol (for a zero rhs the plain residual
              norm is compared instead)
.   max_iter - the maximum number of iterations to take
.   guess - if TRUE, start from a zero initial guess (the input contents
            of out_x are ignored); otherwise the program assumes that
            out_x contains an initial guess
.   procinfo - the usual processor stuff

    Output Parameters:
.   out_x - the contiguous block of vectors containing the solution
            (left untouched when a negative value is returned)
.   residual - array of BS entries; on a normal return, the norm of
               rhs - A*x for each vector recomputed from scratch, divided
               by the norm of that rhs (or undivided for a zero rhs).
               It is measured in the permuted, diagonally scaled
               variables the solver works in internally.

    Returns:
    The number of iterations or a negative number indicating the number
    of iterations prior to finding that the matrix (or preconditioner) 
    is not positive definite.

    Notes:
    The preconditioner (fact_A) must be computed prior to calling this
    routine.  For more information on the preconditioners, see the manual.
 @*/
int BSpar_sym_solve(int BS, BSpar_mat *A, BSpar_mat *fact_A, BScomm *comm_A,
	FLOAT *in_rhs, FLOAT *out_x, int pre_option, FLOAT err_tol, int max_iter,
	FLOAT *residual, int guess, BSprocinfo *procinfo)
{
	int	i, j;
	int	cur_step, cur_phase;
	int	done;
	FLOAT	*resid, *z, *p;
	FLOAT	*cg_beta, *cg_alpha;
	FLOAT *bnorm;
	FLOAT *x;
	FLOAT *rhs;
	FLOAT tval;
	FLOAT	*t_x, *t_rhs;
	int	n;

	/* NOTE(review): the CHKERRN/MY_MALLOCN early returns below still leak
	   whatever has already been allocated; fixing that needs knowledge of
	   how those macros are defined. */

	if(!A->symmetric) {
		MY_SETERRCN(PAR_SOLVE_ERROR,"Trying to solve nonsymmetric system with CG\n");
	}

	n = A->num_rows;

	/* reorder the rhs into the matrix's internal (permuted) ordering */
	MY_MALLOCN(rhs,(FLOAT *),sizeof(FLOAT)*n*BS,1);
	for (i=0;i<BS;i++) {
		BSperm_dvec(&(in_rhs[i*n]),&(rhs[i*n]),A->perm); CHKERRN(0);
	}

	/* allocate space for x */
	MY_MALLOCN(x,(FLOAT *),sizeof(FLOAT)*n*BS,2);

	/* allocate space for cg vectors */
	MY_MALLOCN(resid,(FLOAT *),sizeof(FLOAT)*n*BS,3);
	MY_MALLOCN(p,(FLOAT *),sizeof(FLOAT)*n*BS,4);
	MY_MALLOCN(z,(FLOAT *),sizeof(FLOAT)*n*BS,5);
	MY_MALLOCN(cg_alpha,(FLOAT *),sizeof(FLOAT)*BS,6);
	MY_MALLOCN(cg_beta,(FLOAT *),sizeof(FLOAT)*BS,7);
	MY_MALLOCN(bnorm,(FLOAT *),sizeof(FLOAT)*BS,8);

	/* form the initial guess: zero, or the caller's out_x permuted */
	if (guess) {
		for (i=0;i<n*BS;i++) {
			x[i] = 0.0;
		}
	} else {
		for (i=0;i<BS;i++) {
			BSperm_dvec(&(out_x[i*n]),&(x[i*n]),A->perm); CHKERRN(0);
		}
	}

	/* Scale the rhs and x to match a diagonally scaled system:
	   rhs <- rhs / sqrt(d) and x <- x * sqrt(d).  The inverse scaling is
	   applied to x after the solve. */
	if(A->scale_diag!=NULL) {
		for (i=0;i<BS;i++) {
			t_rhs = &(rhs[i*n]);
			t_x = &(x[i*n]);
			for (j=0;j<n;j++) {
				tval = sqrt(A->scale_diag[j]);
				t_rhs[j] /= tval;
				t_x[j] *= tval;
			}
		}
	}

	/* get the norm of each rhs; if every rhs is zero there is nothing
	   to iterate on */
	BSpar_bip(n,rhs,rhs,BS,bnorm,procinfo); CHKERRN(0);
	done = TRUE;
	for (i=0;i<BS;i++) {
		bnorm[i] = sqrt(bnorm[i]);
		if (bnorm[i] != 0.0) done = FALSE;
	}
	cur_step = 0;
	cur_phase = 0;
	/* BSpar_bcg appears to drive CG by reverse communication: each call
	   returns a request (mat-vec or preconditioner solve) which is
	   serviced here before calling it again. */
	while ((!done) && (cur_step < max_iter)) {
		switch(BSpar_bcg(n,rhs,x,resid,z,p,cg_beta,cg_alpha,
			&cur_step,&cur_phase,BS,procinfo)) {
			case CG_MATVECR: {
				BStri_mult(A,comm_A,NULL,NULL,x,resid,NULL,NULL,0.0,BS,
					procinfo); CHKERRN(0);
				break;
			}
			case CG_MATVECZ: {
				BStri_mult(A,comm_A,NULL,NULL,p,z,NULL,NULL,0.0,BS,
					procinfo); CHKERRN(0);
				break;
			}
			case CG_MSOLVE: {
				/* convergence test on the current residual, then apply the
				   preconditioner z = M^{-1} resid if not yet converged */
				BSpar_bip(n,resid,resid,BS,residual,procinfo); CHKERRN(0);
				for (i=0;i<BS;i++) {
					if (bnorm[i] != 0.0) {
						residual[i] = sqrt(residual[i])/bnorm[i];
					} else {
						residual[i] = sqrt(residual[i]);
					}
				}
				done = TRUE;
				for (i=0;i<BS;i++) {
					if (residual[i] >= err_tol) done = FALSE;
				}
				if (!done) {
					for (i=0;i<n*BS;i++) {
						z[i] = resid[i];
					}
					BStri_solve(A,fact_A,comm_A,z,pre_option,BS,procinfo);
					CHKERRN(0);
				}
				break;
			}
			default: {
				/* Any other code is the breakdown case described under
				   "Returns" (not positive definite).  Release every work
				   array so this early return does not leak them; out_x is
				   deliberately left untouched.
				   NOTE(review): if breakdown is reported before cur_step
				   has been advanced, -cur_step is 0 and looks like success. */
				MY_FREE(z);
				MY_FREE(p);
				MY_FREE(resid);
				MY_FREE(cg_alpha);
				MY_FREE(cg_beta);
				MY_FREE(bnorm);
				MY_FREE(rhs);
				MY_FREE(x);
				return(-cur_step);
			}
		}
		CHKERRN(0);
	}

	/* recompute the true residual rhs - A*x rather than trusting the
	   recurrence-updated one */
	BStri_mult(A,comm_A,NULL,NULL,x,resid,NULL,NULL,0.0,BS,procinfo); CHKERRN(0);
	for (i=0;i<n*BS;i++) {
		resid[i] = rhs[i]-resid[i];
	}
	BSpar_bip(n,resid,resid,BS,residual,procinfo); CHKERRN(0);
	for (i=0;i<BS;i++) {
		if (bnorm[i] != 0.0) {
			residual[i] = sqrt(residual[i])/bnorm[i];
		} else {
			residual[i] = sqrt(residual[i]);
		}
	}

	MY_FREE(z);
	MY_FREE(p);
	MY_FREE(resid);
	MY_FREE(cg_alpha);
	MY_FREE(cg_beta);
	MY_FREE(bnorm);

	/* Rescale X back to the unscaled variables */
	if(A->scale_diag!=NULL) {
		for (i=0;i<BS;i++) {
			t_x = &(x[i*n]);
			for (j=0;j<n;j++) {
				t_x[j] /= sqrt(A->scale_diag[j]);
			}
		}
	}

	/* reorder the solution vector back to the caller's ordering */
	for (i=0;i<BS;i++) {
		BSiperm_dvec(&(x[i*n]),&(out_x[i*n]),A->perm); CHKERRN(0);
	}

	MY_FREE(rhs);
	MY_FREE(x);

	/* return the number of iterations */
	return(cur_step);
}


/* syntax highlighted by Code2HTML, v. 0.9.1 */