#include "BSprivate.h"

/*@ BSback_solve - Backward triangular matrix solution on a 
                   single vector

    Input Parameters:
.   A - The sparse matrix; A->icc_storage selects the symmetric (ICC)
        or the unsymmetric (ILU) data layout
.   x - The rhs, indexed by local row
.   comm - The communication structure for A
.   procinfo - the usual processor information

    Output Parameters:
.   x - on exit contains the solution vector

    Returns:
    void

    Notes:
    Colors are processed from last to first; the cliques of color i are
    col2cl[i] .. col2cl[i+1]-1.  For each color this routine
    (1) in the ILU case multiplies x by the stored diagonal-clique matrices,
    (2) sends the data other processors need (partial sums in the ICC case,
        x values in the ILU case),
    (3) applies the locally owned i-nodes,
    (4) receives and applies the contributions of other processors, and
    (5) in the ICC case multiplies x by the upper triangle of the stored
        diagonal-clique matrices.
    In the ICC case the send/receive roles of comm->to_msg and
    comm->from_msg are swapped relative to the ILU case.
 @*/
void BSback_solve(BSpar_mat *A, FLOAT *x, BScomm *comm, BSprocinfo *procinfo)
{
	BMcomp_msg *from_msg, *to_msg;
	BMphase *to_phase, *from_phase;
	BMmsg *msg;
	int	i, j, k;
	int	cl_ind, in_ind;
	int	count, size, ind, num_cols;
	int	nz_ld;	/* leading dimension (full row count) of an i-node's nz block */
	int *row;
	FLOAT *nz;
	BScl_2_inode *clique2inode = A->clique2inode;
	BSnumbering *color2clique = A->color2clique;
	BSinode *inodes = A->inodes->list;
	int	*in_index = clique2inode->inode_index;
	int	*proc = clique2inode->proc;
	BSdense	*d_mats = clique2inode->d_mats;
	int	*data_ptr, msg_len;
	FLOAT *msg_buf, *matrix;
	int	my_id = procinfo->my_id;
	FLOAT *work;
	char UP = 'U';
	char TR = 'T';
	char NTR = 'N';
	char ND = 'N';
	int	*col2cl = color2clique->numbers;
	/* col2cl[0..length-1] delimits length-1 colors; never reused below */
	int	length = color2clique->length;
	int	start, finish, symmetric;
	int	ione = 1;
	FLOAT one = 1.0;
	FLOAT zero = 0.0;
	FLOAT minus_one = -1.0;

	/* Is the symmetric data structure used? */
	symmetric = A->icc_storage;

	if(symmetric) {
		from_msg = comm->to_msg; /* we do mean to switch these */
		to_msg = comm->from_msg;
	} else {
		from_msg = comm->from_msg; /* do not switch for ILU case */
		to_msg = comm->to_msg;
	}

	/* get some work space */
	MY_MALLOC(work,(FLOAT *),sizeof(FLOAT)*A->num_rows,1);

	/* post for all messages */
	BMinit_comp_msg(from_msg,procinfo); CHKERR(0);

	/* now do this phase by phase, last color first */
	for (i=length-2;i>=0;i--) {
		start = col2cl[i];
		finish = col2cl[i+1];

		if(!symmetric) {
			/* apply the stored diagonal-clique matrices to find the answers */
			for (cl_ind=start;cl_ind<finish;cl_ind++) {
				if (my_id == proc[cl_ind]) {
					size = d_mats[cl_ind].size;
					ind = d_mats[cl_ind].local_ind;
					matrix = d_mats[cl_ind].matrix;
					/* can't do much better (likely) on this DGEMV */
					DGEMV(&NTR,&size,&size,&one,matrix,&size,&(x[ind]),&ione,&zero,
						work,&ione);
					for (k=0; k<size; k++) x[ind+k] = work[k];
				}
			}
		}

		/* first send my messages */
		/* in the ICC case this involves computing partial sums */
		to_phase = BMget_phase(to_msg,i); CHKERR(0);
		msg = NULL;
		while ((msg = BMnext_msg(to_phase,msg)) != NULL) {
			CHKERR(0);
			msg_buf = (FLOAT *) BMget_msg_ptr(msg); CHKERR(0);
			data_ptr = BMget_user(msg,&msg_len); CHKERR(0);
			if(symmetric) {
				/* data_ptr[0..1] is the inclusive range of cliques to send */
				count = 0;
				for (cl_ind=data_ptr[0];cl_ind<=data_ptr[1];cl_ind++) {
					for (in_ind=in_index[cl_ind];
						in_ind<in_index[cl_ind+1];in_ind++) {
						row = inodes[in_ind].row_num;
						nz = inodes[in_ind].nz;
						size = inodes[in_ind].length;
						num_cols = inodes[in_ind].num_cols;
						if (size > 0) {
#ifdef MY_BLAS_DGEMV_ON
							if (num_cols > DGEMV_UNROLL_LVL) {
								for (k=0;k<size;k++) work[k] = x[row[k]];
								DGEMV(&TR,&size,&num_cols,&one,nz,&size,
									work,&ione,&zero,&(msg_buf[count]),&ione);
							} else {
								MY_DGEMV_Y_1101(size,num_cols,nz,size,x,row,
									&(msg_buf[count]));
							}
#else
							for (k=0;k<size;k++) work[k] = x[row[k]];
							DGEMV(&TR,&size,&num_cols,&one,nz,&size,
								work,&ione,&zero,&(msg_buf[count]),&ione);
#endif
						}
						count += num_cols;
					}
				}
			} else {
				/* data_ptr lists the local rows whose values are sent */
				for (j=0; j<msg_len; j++)
					msg_buf[j] = x[data_ptr[j]];
			}
			BMsendf_msg(msg,procinfo); CHKERR(0);
		}
		CHKERR(0);

		/* do some local work, multiply by the i-nodes */
		for (cl_ind=start;cl_ind<finish;cl_ind++) {
			if (my_id == proc[cl_ind]) {
				ind = d_mats[cl_ind].local_ind;
				for (in_ind=in_index[cl_ind];
					in_ind<in_index[cl_ind+1];in_ind++) {
					size = inodes[in_ind].length;
					num_cols = inodes[in_ind].num_cols;
					row = inodes[in_ind].row_num;
					nz = inodes[in_ind].nz;
					if(symmetric) {
						if (size > 0) {
#ifdef MY_BLAS_DGEMV_ON
							if (num_cols > DGEMV_UNROLL_LVL) {
								for (k=0;k<size;k++) work[k] = x[row[k]];
								DGEMV(&TR,&size,&num_cols,&minus_one,nz,&size,
									work,&ione,&one,&(x[ind]),&ione);
							} else {
								MY_DGEMVM1_Y_1111(size,num_cols,nz,size,x,row,
									&(x[ind]));
							}
#else
							for (k=0;k<size;k++) work[k] = x[row[k]];
							DGEMV(&TR,&size,&num_cols,&minus_one,nz,&size,
								work,&ione,&one,&(x[ind]),&ione);
#endif
						}
					} else {
						/* ILU: only the first below_diag rows of the i-node
						   (the entries above the pivot) take part here, but
						   the nz block is still stored with all nz_ld rows,
						   so nz_ld, not size, is its leading dimension. */
						nz_ld = size;
						size = inodes[in_ind].below_diag;
						if (size > 0) {
#ifdef MY_BLAS_DGEMV_ON
							if (num_cols > DGEMV_UNROLL_LVL) {
								DGEMV(&NTR,&size,&num_cols,&one,nz,&nz_ld,&(x[ind]),
									&ione,&zero,work,&ione);
								for (k=0;k<size;k++) x[row[k]] -= work[k];
							} else {
								MY_DGEMVM1_N_1111(size,num_cols,nz,nz_ld,&(x[ind]),
									x,row);
							}
#else
							DGEMV(&NTR,&size,&num_cols,&one,nz,&nz_ld,&(x[ind]),
								&ione,&zero,work,&ione);
							for (k=0;k<size;k++) x[row[k]] -= work[k];
#endif
						}
					}
					ind += num_cols;
				}
			}
		}

		/* receive my messages and update my rhs */
		from_phase = BMget_phase(from_msg,i); CHKERR(0);
		while ((msg = BMrecv_msg(from_phase)) != NULL) {
			CHKERR(0);
			msg_buf = (FLOAT *) BMget_msg_ptr(msg); CHKERR(0);
			data_ptr = BMget_user(msg,&msg_len); CHKERR(0);
			if(symmetric) {
				/* received partial sums are subtracted from the listed rows */
				msg_len = BMget_msg_size(msg); CHKERR(0);
				for (j=0;j<msg_len;j++) x[data_ptr[j]] -= msg_buf[j];
			} else {
				/* received x values are applied through the local i-nodes of
				   the inclusive clique range data_ptr[0..1] */
				count = 0;
				for (cl_ind=data_ptr[0]; cl_ind<=data_ptr[1]; cl_ind++) {
					for (in_ind=in_index[cl_ind];
						in_ind<in_index[cl_ind+1]; in_ind++) {
						row = inodes[in_ind].row_num;
						nz = inodes[in_ind].nz;
						nz_ld = inodes[in_ind].length;
						num_cols = inodes[in_ind].num_cols;
						/* same above-pivot restriction as the local ILU update */
						size = inodes[in_ind].below_diag;
						if (size > 0) {
#ifdef MY_BLAS_DGEMV_ON
							if (num_cols > DGEMV_UNROLL_LVL) {
								DGEMV(&NTR,&size,&num_cols,&one,nz,&nz_ld,&(msg_buf[count]),
									&ione,&zero,work,&ione);
								for (k=0;k<size;k++) x[row[k]] -= work[k];
							} else {
								MY_DGEMVM1_N_1111(size,num_cols,nz,nz_ld,
									&(msg_buf[count]),x,row);
							}
#else
							DGEMV(&NTR,&size,&num_cols,&one,nz,&nz_ld,&(msg_buf[count]),
								&ione,&zero,work,&ione);
							for (k=0;k<size;k++) x[row[k]] -= work[k];
#endif
						}
						count += num_cols;
					}
				}
			}
			BMfree_msg(msg); CHKERR(0);
		}
		CHKERR(0);

		if(symmetric) {
			/* apply the stored diagonal-clique matrices to find the answers */
			for (cl_ind=start;cl_ind<finish;cl_ind++) {
				if (my_id == proc[cl_ind]) {
					/* multiply by the upper triangle of the clique matrix
					   (uplo='U', no transpose, diag='N').
					   NOTE(review): an earlier comment here said the diagonal
					   is assumed to be all 1's and only the strictly upper part
					   is used, yet diag='N' reads the stored diagonal; confirm
					   against the factorization. */
					size = d_mats[cl_ind].size;
#ifdef MY_BLAS_DTRMV_ON
					MY_DTRMV_N_U(size,d_mats[cl_ind].matrix,size,
						&(x[d_mats[cl_ind].local_ind]),work);
#else
					DTRMV(&UP,&NTR,&ND,&size,d_mats[cl_ind].matrix,&size,
						&(x[d_mats[cl_ind].local_ind]),&ione);
#endif
				}
			}
		}

	}
	MY_FREE(work);
	/* wait for all of the sent messages to finish */
	BMfinish_comp_msg(to_msg,procinfo); CHKERR(0);
	MLOG_flop((2*A->local_nnz));
}


/* syntax highlighted by Code2HTML, v. 0.9.1 */