#include "BSprivate.h"

/*
 * BSbuild_solution - add the GMRES correction built from the first i+1
 * Krylov vectors to the current iterate, independently for each of the
 * BS vectors in the block.
 *
 * For every column q it solves the (i+1)x(i+1) upper-triangular system
 *     R y = s
 * by back substitution, where R is h[q][0..i][0..i] (the Hessenberg matrix
 * after the Givens rotations applied by BSupdate_hessenberg) and s is
 * s[q][0..i].  It then forms
 *     x(:,q) += sum_{ii=0..i} y[ii] * v[ii](:,q).
 *
 * Parameters:
 *   i        - index of the last Krylov vector to use (0-based)
 *   n        - local length of one vector
 *   BS       - number of vectors in the block
 *   x        - block iterate; column q is x[q*n .. q*n+n-1]; updated in place
 *   y        - scratch array of at least i+1 entries (overwritten)
 *   v        - Krylov basis; v[ii] holds BS columns of length n
 *   h        - per-column triangularized Hessenberg matrices (read only)
 *   s        - per-column rotated right-hand sides (read only)
 *   procinfo - processor information (not referenced directly in this body)
 */
void BSbuild_solution(int i,int n,int BS,FLOAT *x,FLOAT *y,
	 FLOAT **v,FLOAT ***h,FLOAT **s,BSprocinfo *procinfo)
{
	int ii, j, q;
	FLOAT temp, pivot;

	for (q=0;q<BS;q++) {
		/* back substitution, last row first */
		for (ii=i;ii>=0;ii--) {
			temp = s[q][ii];
			for (j=ii+1;j<=i;j++) {
				temp -= h[q][ii][j] * y[j];
			}
			pivot = h[q][ii][ii];
			/*
			 * A zero pivot occurs, e.g., for a column whose Krylov vectors
			 * are all zero (zero rhs column, or a column that converged
			 * exactly earlier in the cycle).  That direction contributes
			 * nothing, so use 0 instead of 0/0 = NaN, which would
			 * otherwise be added into x.
			 */
			y[ii] = (pivot != 0.0) ? temp / pivot : 0.0;
		}
		/* x(:,q) += V(:,0..i) * y */
		for (ii=0;ii<=i;ii++) {
			for (j=0;j<n;j++) {
				x[q*n+j] += v[ii][q*n+j] * y[ii];
			}
		}
	}
	MLOG_flop((2*BS*((i+1)*n + (i*(i+1))/2)));
}

/*
 * BSupdate_hessenberg - perform GMRES step i for each of the BS vectors in
 * the block.  It does three things:
 *   1. orthogonalizes w against v[0..i] (modified Gram-Schmidt: each inner
 *      product uses the already-updated w), storing the coefficients in
 *      column i of h;
 *   2. normalizes the remainder into the next basis vector v[i+1];
 *   3. applies the previous Givens rotations to the new column of h,
 *      computes a new rotation that eliminates h[q][i+1][i], and applies it
 *      to the rotated rhs s.
 *
 * Parameters:
 *   i        - current step (0-based); column i of h is filled
 *   n        - local length of one vector
 *   BS       - number of vectors in the block; column q of a block vector
 *              is stored at [q*n .. q*n+n-1]
 *   w        - on entry the new candidate vector (the caller passes the
 *              preconditioned product with v[i]); overwritten
 *   v        - Krylov basis; v[0..i] read, v[i+1] written
 *   h        - per-column Hessenberg matrices; column i written and
 *              reduced to upper-triangular form, h[q][i+1][i] set to 0
 *   sn, cs   - per-column Givens sines/cosines; entry i written
 *   s        - per-column rotated rhs; entries i and i+1 updated.
 *              |s[q][i+1]| is the caller's residual-norm estimate.
 *   procinfo - processor information, passed to BSpar_bip
 *
 * NOTE(review): if CHKERR returns early, wdot is not freed; confirm the
 * CHKERR semantics in BSprivate.h.
 */
void BSupdate_hessenberg(int i,int n,int BS,FLOAT *w,FLOAT **v,
		FLOAT ***h, FLOAT **sn, FLOAT **cs, FLOAT **s,BSprocinfo *procinfo)
{
	int j, k, q;
	FLOAT temp, hnorm, *wdot;

	MY_MALLOC(wdot,(FLOAT *),sizeof(FLOAT)*BS,1);

	/* modified Gram-Schmidt against v[0..i], every column at once */
	for (k=0;k<=i;k++) {
		/* presumably the per-column global inner products <w, v[k]> */
		BSpar_bip(n,w,v[k],BS,wdot,procinfo); CHKERR(0);
		for (q=0;q<BS;q++) {
			h[q][k][i] = wdot[q];
			for (j=0;j<n;j++) {
				w[q*n+j] -= h[q][k][i] * v[k][q*n+j];
			}
		}
	}
	/* one axpy of length n per column for each of the i+1 basis vectors */
	MLOG_flop((2*BS*n*(i+1)));

	/* subdiagonal entry = norm of what is left of w */
	BSpar_bip(n,w,w,BS,wdot,procinfo); CHKERR(0);
	for (q=0;q<BS;q++) {
		h[q][i+1][i] = sqrt(wdot[q]);
	}

	/* next basis vector v[i+1] = w / h[i+1][i] */
	for (q=0;q<BS;q++) {
		hnorm = h[q][i+1][i];
		if (hnorm != 0.0) {
			for (j=0;j<n;j++) {
				v[i+1][q*n+j] = w[q*n+j]/hnorm;
			}
		} else {
			/*
			 * Breakdown for this column (w is exactly zero after
			 * orthogonalization, e.g. a zero rhs column).  Keep the new
			 * basis vector zero; dividing would fill it with NaN and
			 * corrupt every later entry of h for this column.
			 */
			for (j=0;j<n;j++) {
				v[i+1][q*n+j] = 0.0;
			}
		}
	}
	MLOG_flop((BS*n));

	for (q=0;q<BS;q++) {
		/* apply the previous rotations to the new column */
		for (k=0;k<i;k++) {
			temp      =  cs[q][k] * h[q][k][i] + sn[q][k] * h[q][k+1][i];
			h[q][k+1][i] = -sn[q][k] * h[q][k][i] + cs[q][k] * h[q][k+1][i];
			h[q][k][i]   =  temp;
		}
		/* compute the new rotation that eliminates h[q][i+1][i] */
		temp = sqrt(h[q][i][i] * h[q][i][i] + h[q][i+1][i] * h[q][i+1][i]);
		if (temp == 0.0) {
			/* both entries are zero: nothing to eliminate, so use the
			   identity rotation instead of cs = sn = 0/0 */
			cs[q][i] = 1.0;
			sn[q][i] = 0.0;
		} else {
			cs[q][i]  =  h[q][i][i] / temp;
			sn[q][i]  =  h[q][i+1][i] / temp;
		}
		/* rotate the rhs: s[q][i+1] becomes the residual estimate */
		temp   =  cs[q][i] * s[q][i];
		s[q][i+1] = -sn[q][i] * s[q][i];
		s[q][i]   =  temp;
		/* rotate the new column itself */
		h[q][i][i] = cs[q][i] * h[q][i][i] + sn[q][i] * h[q][i+1][i];
		h[q][i+1][i] = 0.0;
	}
	MLOG_flop((BS*6*i));
	MY_FREE(wdot);
}

/*@ BSpar_gmres - Solve a nonsymmetric system of equations
                  using restarted gmres preconditioned by one of several
                  preconditioners.  The rhs can be a block of vectors.

    Input Parameters:
.   BS - the number of vectors in the RHS
.   A - a sparse matrix
.   fact_A - the incomplete factored version of A, if any
.   comm_A - the communication structure for A
.   in_rhs - the contiguous block of vectors forming the rhs
.   pre_option - the preconditioner to use
                 PRE_ICC: incomplete Cholesky factorization
                 PRE_ILU: incomplete LU factorization
                 PRE_SSOR: Successive over relaxation
                 PRE_BJACOBI: Block Jacobi
.   restart - the number of gmres steps in a cycle before restarting
              (the dimension of the Krylov subspace)
.   err_tol - the tolerance to which to solve the problem.
              Iteration stops when, for every vector in the block, the
              estimated norm of the preconditioned residual divided by
              the norm of the rhs is less than err_tol.
.   max_iter - the maximum total number of iterations, counted across
               restarts
.   guess - if TRUE, start from x = 0 and ignore the input contents of
            out_x; otherwise out_x is assumed to contain an initial guess
.   procinfo - the usual processor stuff

    Output Parameters:
.   out_x - the contiguous block of vectors containing the solution
.   residual - array of BS entries.  On return, for each vector, the norm
               of the final (unpreconditioned) residual divided by the norm
               of its rhs, or the plain residual norm if that rhs is zero.

    Returns:
    The total number of gmres iterations taken (0 if the initial guess
    already satisfies err_tol).  Errors detected by the lower-level
    routines are propagated through the CHKERRN mechanism.

    Notes:
    The preconditioners must be computed prior to calling BSpar_gmres.
    For more information on the preconditioners, see the manual.
    All norms are taken after the diagonal scaling of A, if any.
 @*/
int BSpar_gmres(int BS, BSpar_mat *A, BSpar_mat *fact_A, BScomm *comm_A,
	FLOAT *in_rhs, FLOAT *out_x, int pre_option, int restart, FLOAT err_tol,
	int max_iter, FLOAT *residual, int guess, BSprocinfo *procinfo)
{
	int	i, j, k, n, m;
	int	cur_step, done;
	FLOAT *rhs, *x, *bnorm;
	FLOAT *resid, *w;
	FLOAT ***h, **v;
	FLOAT **cs, **sn, **s, *y;
	FLOAT tval;
	FLOAT *t_x, *t_rhs;
	FLOAT error, temp;

	n = A->num_rows;
	m = restart;	/* Krylov subspace dimension per restart cycle */

	/* reorder the rhs into A's permuted ordering */
	MY_MALLOCN(rhs,(FLOAT *),sizeof(FLOAT)*n*BS,1);
	for (i=0;i<BS;i++) {
		BSperm_dvec(&(in_rhs[i*n]),&(rhs[i*n]),A->perm); CHKERRN(0);
	}

	/* allocate space for x */
	MY_MALLOCN(x,(FLOAT *),sizeof(FLOAT)*n*BS,2);
	MY_MALLOCN(bnorm,(FLOAT *),sizeof(FLOAT)*BS,3);

	/* allocate space for gmres vectors */
	MY_MALLOCN(resid,(FLOAT *),sizeof(FLOAT)*n*BS,4);
	MY_MALLOCN(w,(FLOAT *),sizeof(FLOAT)*n*BS,5);

	/*
	 * The pointer tables (v, h, h[i], cs, sn, s) are sized by the pointer
	 * type they hold.  Sizing them by sizeof(FLOAT) under-allocates
	 * whenever FLOAT is narrower than a pointer (e.g. float on LP64).
	 */
	MY_MALLOCN(v,(FLOAT **),sizeof(FLOAT *)*(m+1),6);
	for (i=0;i<(m+1);i++) {
		MY_MALLOCN(v[i],(FLOAT *),sizeof(FLOAT)*n*BS,6);
	}
	MY_MALLOCN(h,(FLOAT ***),sizeof(FLOAT **)*BS,7);
	for (i=0;i<BS;i++) {
		MY_MALLOCN(h[i],(FLOAT **),sizeof(FLOAT *)*(m+1),7);
		for (j=0;j<(m+1);j++) {
			MY_MALLOCN(h[i][j],(FLOAT *),sizeof(FLOAT)*m,7);
		}
	}
	MY_MALLOCN(cs,(FLOAT **),sizeof(FLOAT *)*BS,8);
	MY_MALLOCN(sn,(FLOAT **),sizeof(FLOAT *)*BS,9);
	MY_MALLOCN(s,(FLOAT **),sizeof(FLOAT *)*BS,10);
	for (i=0;i<BS;i++) {
		MY_MALLOCN(cs[i],(FLOAT *),sizeof(FLOAT)*m,8);
		MY_MALLOCN(sn[i],(FLOAT *),sizeof(FLOAT)*m,9);
		MY_MALLOCN(s[i],(FLOAT *),sizeof(FLOAT)*(m+1),10);
	}
	MY_MALLOCN(y,(FLOAT *),sizeof(FLOAT)*m,11);

	/* form the initial guess */
	if (guess) {
		for (i=0;i<n*BS;i++) {
			x[i] = 0.0;
		}
	} else {
		for (i=0;i<BS;i++) {
			BSperm_dvec(&(out_x[i*n]),&(x[i*n]),A->perm); CHKERRN(0);
		}
	}

	/* scale the rhs and x */
	if(A->scale_diag!=NULL) {
		for (i=0;i<BS;i++) {
			t_rhs = &(rhs[i*n]);
			t_x = &(x[i*n]);
			for (j=0;j<n;j++) {
				tval = sqrt(fabs(A->scale_diag[j]));
				t_rhs[j] /= tval;
				t_x[j] *= tval;
			}
		}
	}

	/* get the norm of B; there is nothing to solve if every rhs is zero */
	BSpar_bip(n,rhs,rhs,BS,bnorm,procinfo); CHKERRN(0);
	done = TRUE;
	for (i=0;i<BS;i++) {
		bnorm[i] = sqrt(bnorm[i]);
		if (bnorm[i] != 0.0) done = FALSE;
	}
	cur_step = 0;
	while (!done) {
		/* initial residual rhs - A x, then the preconditioner solve */
		BStri_mult(A,comm_A,NULL,NULL,x,resid,NULL,NULL,0.0,
			BS,procinfo); CHKERRN(0);
		for (i=0;i<n*BS;i++) {
			resid[i] = rhs[i]-resid[i];
		}
		BStri_solve(A,fact_A,comm_A,resid,pre_option,BS,procinfo); CHKERRN(0);
		BSpar_bip(n,resid,resid,BS,residual,procinfo); CHKERRN(0);
		/* form initial v[0] = resid/||resid|| and s = ||resid|| e_1 */
		for (i=0;i<BS;i++) {
			residual[i] = sqrt(residual[i]);
			s[i][0] = residual[i];
			temp = residual[i];
			/* a zero residual column gives a zero v[0] column, not 0/0 */
			if(temp==0.0) temp = 1.0;
			for (j=0;j<n;j++) {
				k = i*n+j;
				v[0][k] = resid[k]/temp;
			}
			if (bnorm[i] != 0.0) {
				residual[i] /= bnorm[i];
			}
		}
		done = TRUE;
		for (i=0;i<BS;i++) {
			if (residual[i] >= err_tol) done = FALSE;
		}
		/*
		 * Already converged: stop.  Running a cycle anyway wastes work and,
		 * when the residual is exactly zero (e.g. an exact initial guess),
		 * builds the Krylov basis from a zero vector and puts 0/0 = NaN
		 * into the solution.
		 */
		if (done) break;

		/* GMRES iteration cycle */
		for (i=0;i<m;i++) {
			cur_step++;
			/* w = preconditioner solve applied to A v[i] */
			BStri_mult(A,comm_A,NULL,NULL,v[i],w,NULL,NULL,0.0,BS,
					procinfo); CHKERRN(0);
			BStri_solve(A,fact_A,comm_A,w,pre_option,BS,
					procinfo); CHKERRN(0);
			BSupdate_hessenberg(i,n,BS,w,v,h,sn,cs,s,procinfo); CHKERRN(0);
			/* |s[j][i+1]| estimates the preconditioned residual norm */
			done = TRUE;
			for (j=0;j<BS;j++) {
				temp = bnorm[j];
				if (temp==0.0) temp = 1.0;
				error = fabs(s[j][i+1])/temp;
				if (error >= err_tol) done = FALSE;
			}
			if(cur_step >= max_iter) done = TRUE;
			if(done) {
				BSbuild_solution(i,n,BS,x,y,v,h,s,procinfo);
				break;
			}
		}
		/* subspace exhausted without convergence: update x, then restart */
		if (!done) {
			BSbuild_solution(m-1,n,BS,x,y,v,h,s,procinfo);
		}
	}
	/* compute final (unpreconditioned) residual rhs - A x */
	BStri_mult(A,comm_A,NULL,NULL,x,resid,NULL,NULL,0.0,BS,
			procinfo); CHKERRN(0);
	for (i=0;i<n*BS;i++) {
		resid[i] = rhs[i]-resid[i];
	}
	BSpar_bip(n,resid,resid,BS,residual,procinfo); CHKERRN(0);
	for (i=0;i<BS;i++) {
		residual[i] = sqrt(residual[i]);
		if (bnorm[i] != 0.0) {
			residual[i] /= bnorm[i];
		}
	}

	MY_FREE(bnorm);
	MY_FREE(resid);
	MY_FREE(w);
	for (i=0;i<BS;i++) {
		for (j=0;j<(m+1);j++) {
			MY_FREE(h[i][j]);
		}
		MY_FREE(h[i]);
	}
	MY_FREE(h);
	for (i=0;i<(m+1);i++) {
		MY_FREE(v[i]);
	}
	MY_FREE(v);
	for (i=0;i<BS;i++) {
		MY_FREE(cs[i]);
		MY_FREE(sn[i]);
		MY_FREE(s[i]);
	}
	MY_FREE(cs);
	MY_FREE(sn);
	MY_FREE(s);
	MY_FREE(y);
	MY_FREE(rhs);

	/* undo the diagonal scaling of x */
	if(A->scale_diag!=NULL) {
		for (i=0;i<BS;i++) {
			t_x = &(x[i*n]);
			for (j=0;j<n;j++) {
				t_x[j] /= sqrt(fabs(A->scale_diag[j]));
			}
		}
	}

	/* reorder the solution vector back to the caller's ordering */
	for (i=0;i<BS;i++) {
		BSiperm_dvec(&(x[i*n]),&(out_x[i*n]),A->perm); CHKERRN(0);
	}

	MY_FREE(x);

	/* return the number of iterations */
	return(cur_step);
}



/* syntax highlighted by Code2HTML, v. 0.9.1 */