#include "BSprivate.h"

/*@ BSb_forward - Forward triangular matrix multiplication on a 
                  block of vectors

    Input Parameters:
.   A - The sparse matrix
.   x - The contiguous block of input vectors; vector i starts at
        x[i*A->num_rows]
.   comm - The communication structure for A
.   block_size - the number of input vectors
.   procinfo - the usual processor information

    Output Parameters:
.   b - on exit these vectors contain the forward (lower-triangular)
        product of A with x; laid out exactly like x

    Returns:
    void

    Notes:
    If A->icc_storage is not set, or procinfo->single is set, the work
    is delegated to BSforward1 (single) or BSforward, once per vector.
    Otherwise all vectors are processed together:
      1. b is set to D*x, where D is A->save_diag, or the identity when
         save_diag is NULL;
      2. for every locally owned clique, the strictly lower part of its
         dense diagonal block and each of its inode blocks are applied
         to the local x and accumulated into b;
      3. contributions from x values received from other processors
         are accumulated into b, phase by phase.
    Outgoing messages are packed and sent before any local work so that
    communication can overlap computation.

    NOTE(review): CHKERR presumably returns on error; if so, work, boff
    and xoff allocated below are not freed on those early-exit paths.
 @*/
void BSb_forward(BSpar_mat *A, FLOAT *x, FLOAT *b, BScomm *comm,
		int block_size, BSprocinfo *procinfo)
{
	BMphase *to_phase, *from_phase;
	BMmsg *msg;
	int	i, j, k, n;
	int	cl_ind, in_ind;
	int	count, size, ind, num_cols;
	int *row;
	FLOAT *nz;
	BScl_2_inode *clique2inode;
	BSnumbering *color2clique;
	BSinode *inodes;
	int	*data_ptr, msg_len;
	FLOAT *msg_buf, *matrix;
	FLOAT *work;
	FLOAT *bptr, *xptr, *wptr;
	FLOAT **boff, **xoff;
	char UP = 'L';
	char TR = 'N';
	char ND = 'N';
	char SIDE = 'L';
	int	ione = 1;
	FLOAT one = 1.0;
	FLOAT zero = 0.0;

	/* local rows per vector; also the stride between vectors in x and b */
	n = A->num_rows;

	if((!A->icc_storage)||(procinfo->single)) {
		/* no ICC storage, or single-processor run: one vector at a time */
		for (i=0;i<block_size;i++) {
			if(procinfo->single) {
				BSforward1(A,&(x[n*i]),&(b[n*i]),comm,procinfo); CHKERR(0);
			} else {
				BSforward(A,&(x[n*i]),&(b[n*i]),comm,procinfo); CHKERR(0);
			}
		}
		return;
	}

	color2clique = A->color2clique;
	clique2inode = A->clique2inode;
	inodes = A->inodes->list;

	/* scratch space for up to block_size columns of n entries each */
	MY_MALLOC(work,(FLOAT *),sizeof(FLOAT)*n*block_size,1);

	/* per-vector base pointers into b and x */
	MY_MALLOC(boff,(FLOAT **),sizeof(FLOAT *)*block_size,1);
	MY_MALLOC(xoff,(FLOAT **),sizeof(FLOAT *)*block_size,1);
	for (i=0;i<block_size;i++) {
		boff[i] = &(b[i*n]);
		xoff[i] = &(x[i*n]);
	}

	/* post for all messages */
	BMinit_comp_msg(comm->from_msg,procinfo); CHKERR(0);

	/* diagonal contribution: b = D*x, with D = I when save_diag is NULL */
	if (A->save_diag == NULL) {
		for (i=0;i<block_size;i++) {
			bptr = boff[i];
			xptr = xoff[i];
			for (j=0;j<n;j++) bptr[j] = xptr[j];
		}
	} else {
		for (i=0;i<block_size;i++) {
			bptr = boff[i];
			xptr = xoff[i];
			for (j=0;j<n;j++) bptr[j] = A->save_diag[j]*xptr[j];
		}
	}

	/* send every phase's outgoing messages first: for each vector j the
	   buffer holds msg_len x entries, gathered at the local indices given
	   by the message's user data */
	for (i=0;i<color2clique->length-1;i++) {
		to_phase = BMget_phase(comm->to_msg,i); CHKERR(0);
		msg = NULL;
		while ((msg = BMnext_msg(to_phase,msg)) != NULL) {
			CHKERR(0);
			msg_buf = (FLOAT *) BMget_msg_ptr(msg); CHKERR(0);
			data_ptr = BMget_user(msg,&msg_len); CHKERR(0);
			for (j=0;j<block_size;j++) {
				wptr = &(msg_buf[j*msg_len]);
				xptr = xoff[j];
				for (k=0;k<msg_len;k++) {
					wptr[k] = xptr[data_ptr[k]];
				}
			}
			BMsendf_msg(msg,procinfo); CHKERR(0);
		}
		CHKERR(0);
	}

	/* local work: every clique owned by this processor, in color order */
	for (i=0;i<color2clique->length-1;i++) {
		for (cl_ind=color2clique->numbers[i];
			cl_ind<color2clique->numbers[i+1];cl_ind++) {
			if (procinfo->my_id == clique2inode->proc[cl_ind]) {
				/* first, multiply the clique's dense diagonal block; only the
				   strictly lower triangular part is used because the diagonal
				   was already applied when b was initialized.  Assuming the
				   block is column-major with leading dimension size, matrix+1
				   with order size-1 addresses exactly that strict lower part. */
				size = clique2inode->d_mats[cl_ind].size;
				ind = clique2inode->d_mats[cl_ind].local_ind;
				matrix = clique2inode->d_mats[cl_ind].matrix;
				j = size-1;
				matrix++;
				if (size > 1) {
					/* gather x[ind .. ind+size-2] of every vector into work */
					nz = work;
					for (k=0;k<block_size;k++) {
						DCOPY(&j,&(xoff[k][ind]),&ione,nz,&ione);
						nz += j;
					}
					DTRMM(&SIDE,&UP,&TR,&ND,&j,&block_size,&one,matrix,&size,
						work,&j);
					/* result feeds rows ind+1 .. ind+size-1 of each b */
					nz = work;
					for (k=0;k<block_size;k++) {
						DAXPY(&j,&one,nz,&ione,&(boff[k][ind+1]),&ione);
						nz += j;
					}
				}

				/* now, multiply the inodes: each covers num_cols consecutive
				   columns starting at ind and scatters into rows row[] */
				for (in_ind=clique2inode->inode_index[cl_ind];
					in_ind<clique2inode->inode_index[cl_ind+1];in_ind++) {
					row = inodes[in_ind].row_num;
					nz = inodes[in_ind].nz;
					size = inodes[in_ind].length;
					num_cols = inodes[in_ind].num_cols;
					if (size > 0) {
						DGEMM(&TR,&TR,&size,&block_size,&num_cols,&one,
							nz,&size,&(x[ind]),&n,&zero,work,&size);
						for (j=0;j<block_size;j++) {
							bptr = boff[j];
							wptr = &(work[j*size]);
							for (k=0;k<size;k++) {
								bptr[row[k]] += wptr[k];
							}
						}
					}
					ind += num_cols;
				}
			}
		}
	}

	/* receive my messages and do non-local work: each message carries,
	   for every vector, the x values needed by the inodes of cliques
	   data_ptr[0] .. data_ptr[1] */
	for (i=0;i<color2clique->length-1;i++) {
		from_phase = BMget_phase(comm->from_msg,i); CHKERR(0);
		while ((msg = BMrecv_msg(from_phase)) != NULL) {
			CHKERR(0);
			msg_buf = (FLOAT *) BMget_msg_ptr(msg); CHKERR(0);
			data_ptr = BMget_user(msg,&msg_len); CHKERR(0);
			/* per-vector length, used as the leading dimension of msg_buf
			   (presumably the message size is counted in FLOATs) */
			msg_len = BMget_msg_size(msg); CHKERR(0);
			msg_len /= block_size;
			count = 0;
			for (cl_ind=data_ptr[0];cl_ind<=data_ptr[1];cl_ind++) {
				for (in_ind=clique2inode->inode_index[cl_ind];
					in_ind<clique2inode->inode_index[cl_ind+1];in_ind++) {
					row = inodes[in_ind].row_num;
					nz = inodes[in_ind].nz;
					size = inodes[in_ind].length;
					num_cols = inodes[in_ind].num_cols;
					if (size > 0) {
						DGEMM(&TR,&TR,&size,&block_size,&num_cols,&one,
							nz,&size,&(msg_buf[count]),&msg_len,
							&zero,work,&size);
						for (j=0;j<block_size;j++) {
							bptr = boff[j];
							wptr = &(work[j*size]);
							for (k=0;k<size;k++) {
								bptr[row[k]] += wptr[k];
							}
						}
					}
					count += num_cols;
				}
			}
			BMfree_msg(msg); CHKERR(0);
		}
		CHKERR(0);
	}

	MY_FREE(xoff);
	MY_FREE(boff);
	MY_FREE(work);

	/* wait for all of the sent messages to finish */
	BMfinish_comp_msg(comm->to_msg,procinfo); CHKERR(0);
	MLOG_flop((2*block_size*A->local_nnz));
}


/* syntax highlighted by Code2HTML, v. 0.9.1 */