#include "pargrid.h"

/*+ worker - Solve a shifted indefinite problem associated with a grid

     Input Parameters:
     grid - the given grid
     procinfo - the processor information (in BlockSolve format)

     Notes:
     A is assembled from the grid with get_mat3d(); B is assembled the
     same way after setting grid->positive = TRUE (this field is left
     set). A is then combined with B via BSmat_subtract(A,B,grid->shift),
     permuted, optionally diagonally scaled, and a copy of it is factored.
     If BSfactor fails, the copy's nonzeros are restored and its diagonal
     is set to an increasing value (1.0, 1.1, 1.2, ...) until the
     factorization succeeds. The system is then solved n_solves times
     with pseudo-random right-hand sides using BSpar_solve().

     The grid is released with free_grid() before returning, so the
     caller must not use it afterwards.
 +*/

void worker(par_grid *grid, BSprocinfo *procinfo)
{
	int	i, j, solve, n, bs, n_solves = 1, write_option = 0;
	int	A_global_nnz, num_iter;
	BSspmat *A, *B;
	BSpar_mat *pA, *f_pA;
	BScomm *Acomm, *f_comm;
	FLOAT shifted_diag, *residual;
	FLOAT	*x, *rhs, shift, t;

	/* set shift to use */
	shift = grid->shift;

	/* number grid to use in matrix assembly */
	num_grid3d(grid,procinfo);

	/* now call the routines to set up the matrix A */
	A = get_mat3d(grid,procinfo);

	/* Set symmetry and storage scheme to be used */
	BSset_mat_symmetric(A,grid->symmetric);
	BSset_mat_icc_storage(A,grid->icc_storage);

	/* write out matrix */
	if(write_option) write_mat_matlab("MATA.m",A,procinfo);

	/* now set up the matrix B (grid->positive stays TRUE afterwards) */
	grid->positive = TRUE;
	B = get_mat3d(grid,procinfo);

	/* write out matrix */
	if(write_option) write_mat_matlab("MATB.m",B,procinfo);

	/* form the shifted operator in A from A, B and shift;
	   NOTE(review): presumably A <- A - shift*B, confirm with BSmat_subtract */
	BSmat_subtract(A,B,shift);

	/* write out matrix */
	if(write_option) write_mat_matlab("MATC.m",A,procinfo);

	/* permute the matrix */
	pA = BSmain_perm(procinfo,A); CHKERR(0);

	/* count nnzs for display: only one triangle is stored locally, so
	   double it and remove the double-counted diagonal */
	A_global_nnz = 2*pA->local_nnz - pA->num_rows;
	GISUM(&A_global_nnz,1,&i,procinfo->procset);
	if(procinfo->my_id==0) {
		printf("o  ");
		printf("Number of nonzeros = %d\n",A_global_nnz);
	}

	/* diagonally scale the matrix */
	if(procinfo->scaling) {
		BSscale_diag(pA,pA->diag,procinfo); CHKERR(0);
	}

	/* set up the communication structure for triangular matrix solution */
	Acomm = BSsetup_forward(pA,procinfo); CHKERR(0);

	/* get a copy of the sparse matrix */
	f_pA = BScopy_par_mat(pA); CHKERR(0);

	/* set up a communication structure for factorization */
	f_comm = BSsetup_factor(f_pA,procinfo); CHKERR(0);

	bs = procinfo->num_rhs;
	/* set up block communication if requested */
	BSsetup_block(pA,Acomm,bs,procinfo);

	/* shifted_diag is the initial diagonal */
	shifted_diag = 1.0;

	/* factor the matrix until successful; each retry restarts from the
	   original nonzeros with a larger diagonal value */
	while (BSfactor(f_pA,f_comm,procinfo) != 0) {
		CHKERR(0);
		/* recopy the nonzeroes */
		BScopy_nz(pA,f_pA); CHKERR(0);
		/* increment the diagonal shift */
		shifted_diag += 0.1;
		BSset_diag(f_pA,shifted_diag,procinfo); CHKERR(0);
	}
	CHKERR(0);

	if(procinfo->my_id==0) {
		printf("o  ");
		printf("Solving the same linear system %d times with differing RHSs\n",
			n_solves);
		printf("o  ");
		printf("Shift required for factorization = %f\n",shifted_diag);
	}

	/* work vectors: bs right-hand sides / solutions of length n each,
	   stored column-wise; allocated once and reused for every solve */
	n = A->num_rows;
	rhs = (FLOAT *) MALLOC(sizeof(FLOAT)*bs*n);
	x = (FLOAT *) MALLOC(sizeof(FLOAT)*bs*n);
	residual = (FLOAT *) MALLOC(sizeof(FLOAT)*procinfo->num_rhs);
	t = A->global_num_rows;
	t = 1.0/sqrt(t);

	srand48((long)(11311));
	/* 'solve' counts solves; 'j' indexes RHS columns. They must be
	   distinct: sharing one index corrupts the solve-loop count. */
	for (solve=0; solve<n_solves; solve++) {

		/* set up the rhs and a zero initial guess for x */
		for (j=0; j<bs; j++) {
			for (i=0; i<n; i++) {
				rhs[i+j*n] = t*j*i + drand48();
				x[i+j*n] = 0.0;
			}
		}

		/* write out rhs */
		if(write_option) write_vec_matlab("RHS.m",rhs,A,procinfo);

		/* solve it */
		BSctx_set_max_it(procinfo,200);
		BSctx_set_guess(procinfo,TRUE);
		BSctx_set_tol(procinfo,1.0e-7);
		BSctx_set_restart(procinfo,20);
		num_iter = BSpar_solve(pA,f_pA,Acomm,rhs,x,residual,procinfo);
		CHKERR(0);

		if (procinfo->my_id == 0) {
			printf("o  ");
			printf("Took %d iterations: residuals = ",num_iter);
			for (i=0; i<bs; i++)
				printf("%e ",residual[i]);
			printf("\n");
		}

		/* write out ans */
		if(write_option) write_vec_matlab("ANS.m",x,A,procinfo);
	}

	FREE(rhs);
	FREE(x);
	FREE(residual);

	/* free the grid */
	free_grid(grid);

	/* free the spmat */
	BSfree_easymat(A);
	BSfree_easymat(B);

	/* free the par mat, etc. */
	BSfree_par_mat(pA);
	BSfree_copy_par_mat(f_pA);
	BSfree_comm(Acomm);
	BSfree_comm(f_comm);
}


/* syntax highlighted by Code2HTML, v. 0.9.1 */