#include "BSprivate.h"

/* a data structure used only in factoring */
/* nodes in a list of i-nodes that need to be received/sent during */
/* factorization */
/* One entry per i-node column block being assembled during a color. */
/* BSfactorn keeps these in a tree keyed by global column number. */
typedef struct __list_type {
	int	length;					/* total # of rows in the assembled columns; */
								/* summed over every contributing piece */
	int	num_cols;				/* # of columns in the i-node */
	int	cur_len;				/* # of rows copied into the buffers so far; */
								/* the next piece is written at this row offset */
	FLOAT	*d_buffer;			/* allocated buffer: num_cols*length values */
								/* (leading dimension = length), then i_buffer */
	int	*i_buffer;				/* global row numbers, one per row; points */
								/* into the d_buffer allocation, not freed alone */
	int	buf_size;				/* size in bytes of the d_buffer allocation; */
								/* also used as the size of the outgoing message */
	int	cl_ind; 				/* index of the clique, if local (-1 if not) */
	int	in_ind; 				/* index of the i-node if local (-1 if not) */
} list_type;

#include	"BStree.h"

extern void BSupdt_matrix(BStree, int,int ,int *, FLOAT *, int, int *,
     FLOAT *, int ,int ,BScl_2_inode *,BSinode *, BSprocinfo *);

/*+ BSfactorn - Compute the incomplete factor of a matrix

    Input Parameters:
.   A - The sparse matrix to be factored
.   comm - the communication structure of the factoring
.   procinfo - the usual processor info

    Output Parameters:
.   A - The factored sparse matrix

    Returns:
    0 if successful.  If the dense factorization (DPOTRF) or the
    inversion (DTRTRI) of any clique reports a failure, returns
    -(i+1), where i is the zero-based index of the color being
    processed when the failure was detected.

    Notes:
    The matrix is processed one color at a time.  For each color:
    the pieces of every i-node column that belong to cliques owned
    elsewhere are sent out; the locally owned cliques are factored
    and inverted; the pieces of each locally owned column are
    collected into one buffer; the collected columns are multiplied
    by the inverted clique (DTRMM) and sent back; finally both the
    local and the returned columns are used to update the remaining
    matrix through BSupdt_matrix.

    On the factorization-failure return all per-color buffers and
    both trees are released.

 +*/
int BSfactorn(BSpar_mat *A, BScomm *comm, BSprocinfo *procinfo)
{
	BMphase *to_phase, *from_phase;
	BMmsg *msg;
	int	i, j, k;
	int	cl_ind, in_ind;
	int	count, size, ind;
	int	length, num_cols, gnum, offset;
	int	found;
	int	*intptr, *intptr2, *i_buffer;
	FLOAT *d_buffer;
	BScl_2_inode *clique2inode;
	BSnumbering *color2clique;
	BSinode *inodes;
	int	*data_ptr, data_len, msg_len;
	FLOAT *msg_buf, *matrix;
	FLOAT *nzptr;
	char UP = 'U';
	char DI = 'N';
	char TR = 'N';
	char SIDE = 'R';
	FLOAT one = 1.0;
	int	info, fact_error;
	list_type *list_data_ptr;
	BStree_ptr	tree_node_ptr;
	BStree	recv_tree;
	BStree	inode_tree;
	int	dummy = 0;
	int	*g_row_num, *iperm;

	BMinit_comp_msg(comm->from_msg,procinfo); CHKERRN(0);
	BMinit_comp_msg(comm->to_msg,procinfo); CHKERRN(0);
	color2clique = A->color2clique;
	clique2inode = A->clique2inode;
	inodes = A->inodes->list;
	g_row_num = A->global_row_num->numbers;
	iperm = A->inv_perm->perm;

	/* build a lookup tree over every i-node, keyed by its global column */
	/* number; each node stores { i-node index, clique index } */
	/* this tree is handed to BSupdt_matrix for every update */
	MY_INIT_TREE(inode_tree,(sizeof(int)*2));
	for (i=0;i<color2clique->length-1;i++) {
		for (cl_ind=color2clique->numbers[i];
			cl_ind<color2clique->numbers[i+1];cl_ind++) {
			for (in_ind=clique2inode->inode_index[cl_ind];
				in_ind<clique2inode->inode_index[cl_ind+1];in_ind++) {
				MY_INSERT_TREE_NODE(inode_tree,(inodes[in_ind].gcol_num),
					found,tree_node_ptr,dummy);
				data_ptr = (int *) MY_GET_TREE_DATA(tree_node_ptr);
				data_ptr[0] = in_ind;
				data_ptr[1] = cl_ind;
			}
		}
	}

	fact_error = FALSE;
	/* now do this phase by phase */
	for (i=0;i<color2clique->length-1;i++) {
		to_phase = BMget_phase(comm->to_msg,i); CHKERRN(0);
		from_phase = BMget_phase(comm->from_msg,i); CHKERRN(0);

		/* send my part of each column in this color */
		/* use the "from" part here */
		msg = NULL;
		while ((msg = BMnext_msg(from_phase,msg)) != NULL) { 
			CHKERRN(0);
			/* allocate space for the message */
			/* layout: num_cols*length values (column after column), */
			/* followed by length ints holding the global row numbers */
			data_ptr = BMget_user(msg,&data_len); CHKERRN(0);
			in_ind = data_ptr[1];
			size = inodes[in_ind].length*inodes[in_ind].num_cols*sizeof(FLOAT)
				+ inodes[in_ind].length*sizeof(int);
			MY_MALLOCN(msg_buf,(FLOAT *),size,1);
			BMset_msg_ptr(msg,msg_buf); CHKERRN(0);
			BMset_msg_size(msg,size); CHKERRN(0);
			count = 0;
			for (j=0;j<inodes[in_ind].num_cols;j++) {
				nzptr = &(inodes[in_ind].nz[j*inodes[in_ind].length]);
				for (k=0;k<inodes[in_ind].length;k++) {
					msg_buf[count] = nzptr[k];
					count++;
				}
			}
			/* the row numbers go right after the values; convert each */
			/* local row number to a global one on the way */
			intptr = (int *) &(msg_buf[count]);
			intptr2 = (int *) inodes[in_ind].row_num;
			for (j=0;j<inodes[in_ind].length;j++) {
				intptr[j] = g_row_num[iperm[intptr2[j]]];
			}
			BMsend_block_msg(from_phase,msg,procinfo); CHKERRN(0);
			/* NOTE(review): freeing right after the send assumes that */
			/* BMsend_block_msg no longer needs msg_buf once it returns */
			MY_FREE(msg_buf);
			BMset_msg_ptr(msg,NULL); CHKERRN(0);
		}
		CHKERRN(0);
		/* a NULL message is passed once all messages of the phase are sent */
		BMsend_block_msg(from_phase,NULL,procinfo); CHKERRN(0);


		/* factor the cliques in this color */
		for (cl_ind=color2clique->numbers[i];
			cl_ind<color2clique->numbers[i+1];cl_ind++) {
			if (procinfo->my_id == clique2inode->proc[cl_ind]) {
				/* factor this clique */
				/* stored in the upper diagonal */
				size = clique2inode->d_mats[cl_ind].size;
				matrix = clique2inode->d_mats[cl_ind].matrix;
				DPOTRF(&UP,&size,matrix,&size,&info);
				MLOG_flop((size*size*size)/6);
				/* only remember the failure here; it is acted upon after */
				/* the communication below has completed */
				if (info != 0) fact_error = TRUE;
				/* now invert the clique */
				DTRTRI(&UP,&DI,&size,matrix,&size,&info);
				MLOG_flop((size*size*size)/6);
				if (info != 0) fact_error = TRUE;
			}
		}

		/* figure out what I am going to receive and make room for it */
		/* use the "to" part here */
		MY_INIT_TREE(recv_tree,sizeof(list_type));
		msg = NULL;
		while ((msg = BMnext_msg(to_phase,msg)) != NULL) {
			CHKERRN(0);
			/* search list for this number */
			/* user data: { global column number, # rows, # columns } */
			data_ptr = BMget_user(msg,&data_len); CHKERRN(0);
			gnum = data_ptr[0];
			length = data_ptr[1];
			num_cols = data_ptr[2];
			MY_INSERT_TREE_NODE(recv_tree,gnum,found,tree_node_ptr,dummy);
			list_data_ptr = (list_type *) MY_GET_TREE_DATA(tree_node_ptr);
			if (found) {
				/* another piece of a column already seen: just add rows */
				list_data_ptr->length += length;
			} else {
				list_data_ptr->length = length;
				list_data_ptr->num_cols = num_cols;
				list_data_ptr->cur_len = 0;
				list_data_ptr->d_buffer = NULL;
				list_data_ptr->i_buffer = NULL;
				/* -1 means "no local piece found (yet)" */
				list_data_ptr->cl_ind = -1;
				list_data_ptr->in_ind = -1;
			}
		}
		CHKERRN(0);
		/* add any of my parts to the list, if necessary */
		for (cl_ind=color2clique->numbers[i];
			cl_ind<color2clique->numbers[i+1];cl_ind++) {
			/* only work on columns that I own */
			if (clique2inode->proc[cl_ind] != procinfo->my_id) continue;
			for (in_ind=clique2inode->inode_index[cl_ind];
				in_ind<clique2inode->inode_index[cl_ind+1];in_ind++) {
				MY_INSERT_TREE_NODE(recv_tree,(inodes[in_ind].gcol_num),
					found,tree_node_ptr,dummy);
				list_data_ptr = (list_type *) MY_GET_TREE_DATA(tree_node_ptr);
				if (found) {
					list_data_ptr->cl_ind = cl_ind;
					list_data_ptr->in_ind = in_ind;
					list_data_ptr->length += inodes[in_ind].length;
				} else {
					list_data_ptr->length = inodes[in_ind].length;
					list_data_ptr->num_cols = inodes[in_ind].num_cols;
					list_data_ptr->cur_len = 0;
					list_data_ptr->d_buffer = NULL;
					list_data_ptr->i_buffer = NULL;
					list_data_ptr->cl_ind = cl_ind;
					list_data_ptr->in_ind = in_ind;
				}
			}
		}
		/* make sure that no one has been received out of order */
		/* (an entry whose cl_ind is still -1 has no local piece) */
		/* NOTE(review): without error_check such an entry would reach */
		/* the loops below with in_ind == -1 -- confirm it cannot occur */
		if (procinfo->error_check) {
			MY_FIRST_IN_TREE(recv_tree,tree_node_ptr);
			while (! IS_TREE_PTR_NULL(tree_node_ptr)) {
				list_data_ptr = (list_type *) MY_GET_TREE_DATA(tree_node_ptr);
				MY_NEXT_IN_TREE(tree_node_ptr);
				if (list_data_ptr->cl_ind == -1) {
					MY_SETERRCN(FACTOR_ERROR,"Received column out of order\n");
				}
			}
		}

		/* allocate space for each whole column */
		/* one allocation holds the values and then the row numbers, */
		/* the same layout as the messages */
		MY_FIRST_IN_TREE(recv_tree,tree_node_ptr);
		while (! IS_TREE_PTR_NULL(tree_node_ptr)) {
			list_data_ptr = (list_type *) MY_GET_TREE_DATA(tree_node_ptr);
			MY_NEXT_IN_TREE(tree_node_ptr);
			list_data_ptr->buf_size = 
				(list_data_ptr->num_cols*list_data_ptr->length*sizeof(FLOAT)) +
				(list_data_ptr->length*sizeof(int));
			MY_MALLOCN(list_data_ptr->d_buffer,(FLOAT *),
				list_data_ptr->buf_size,2);
			list_data_ptr->i_buffer = (int *)
				&(list_data_ptr->d_buffer[list_data_ptr->num_cols*
				list_data_ptr->length]);
		}

		/* put my part of the column into the collected column */
		/* the local piece always occupies the first rows (cur_len is 0) */
		MY_FIRST_IN_TREE(recv_tree,tree_node_ptr);
		while (! IS_TREE_PTR_NULL(tree_node_ptr)) {
			list_data_ptr = (list_type *) MY_GET_TREE_DATA(tree_node_ptr);
			MY_NEXT_IN_TREE(tree_node_ptr);
			in_ind = list_data_ptr->in_ind;
			length = inodes[in_ind].length;
			num_cols = inodes[in_ind].num_cols;
			nzptr = inodes[in_ind].nz;
			d_buffer = list_data_ptr->d_buffer;
			for (j=0;j<num_cols;j++) {
				for (k=0;k<length;k++) {
					d_buffer[k] = nzptr[(j*length)+k];
				}
				/* step to the next column of the collected buffer */
				d_buffer += list_data_ptr->length;
			}
			i_buffer = list_data_ptr->i_buffer;
			intptr = inodes[in_ind].row_num;
			/* while copying, convert i_buffer from new local to new global */
			for (j=0;j<length;j++) {
				i_buffer[j] = g_row_num[iperm[intptr[j]]];
			}
			list_data_ptr->cur_len += length;
		}

		/* collect all the parts of "my" columns */
		while ((msg = BMrecv_block_msg(to_phase,procinfo)) != NULL) {
			CHKERRN(0);
			BMcheck_on_async_block(from_phase); CHKERRN(0);
			msg_buf = (FLOAT *) BMget_msg_ptr(msg); CHKERRN(0);
			BMset_msg_ptr(msg,NULL); CHKERRN(0);
			data_ptr = BMget_user(msg,&msg_len); CHKERRN(0);
			gnum = data_ptr[0];
			length = data_ptr[1];
			num_cols = data_ptr[2];
			MY_SEARCH_TREE(recv_tree,gnum,found,tree_node_ptr);
			if (found) {
				list_data_ptr = (list_type *) MY_GET_TREE_DATA(tree_node_ptr);
				/* append this piece below the rows collected so far */
				d_buffer = &(list_data_ptr->d_buffer[list_data_ptr->cur_len]);
				count = 0;
				for (j=0;j<num_cols;j++) {
					for (k=0;k<length;k++) {
						d_buffer[k] = msg_buf[count];
						count++;
					}
					d_buffer += list_data_ptr->length;
				}
				i_buffer = &(list_data_ptr->i_buffer[list_data_ptr->cur_len]);
				/* the sender's global row numbers follow the values */
				intptr = (int *) &(msg_buf[num_cols*length]);
				for (j=0;j<length;j++) {
					i_buffer[j] = intptr[j];
				}
				list_data_ptr->cur_len += length;
			} else {
				MY_SETERRCN(FACTOR_ERROR,"Can't find matching column\n");
			}
		}
		CHKERRN(0);
		BMfree_block_msg(to_phase); CHKERRN(0);
		BMfinish_on_async_block(from_phase); CHKERRN(0);

		/* see if the factorizations went okay, if not, then return */
		/* we have to wait until here to make sure all messages are received */
		/* make sure anything that needs free'ing is free'ed */
		/* everyone agrees on success/error */
		/* (info is only passed as the work argument of GISUM here) */
		GISUM(&fact_error,1,&info,procinfo->procset);
		if (fact_error != 0) {
			/* free up space and return with the negative phase number */
			MY_FIRST_IN_TREE(recv_tree,tree_node_ptr);
			while (! IS_TREE_PTR_NULL(tree_node_ptr)) {
				list_data_ptr = (list_type *) MY_GET_TREE_DATA(tree_node_ptr);
				MY_NEXT_IN_TREE(tree_node_ptr);
				MY_FREE(list_data_ptr->d_buffer);
			}
			MY_FREE_TREE(recv_tree);
			/* the i-node lookup tree is otherwise only released at the */
			/* normal exit; release it here too so that a failed */
			/* factorization does not leak it */
			MY_FREE_TREE(inode_tree);
			/* NOTE(review): BMfinish_comp_msg is not called on this */
			/* path -- confirm whether the comm structures need it */
			return(-(i+1));
		}

		/* factor these columns */
		/* DTRMM with SIDE='R': the collected block (length x j) is */
		/* multiplied on the right by the upper triangular j x j part */
		/* of the inverted clique that starts at (offset,offset) */
		MY_FIRST_IN_TREE(recv_tree,tree_node_ptr);
		while (! IS_TREE_PTR_NULL(tree_node_ptr)) {
			list_data_ptr = (list_type *) MY_GET_TREE_DATA(tree_node_ptr);
			MY_NEXT_IN_TREE(tree_node_ptr);
			if (list_data_ptr->length <= 0) continue;
			cl_ind = list_data_ptr->cl_ind;
			in_ind = list_data_ptr->in_ind;
			j = inodes[in_ind].num_cols;
			offset = inodes[in_ind].gcol_num - clique2inode->g_offset[cl_ind];
			size = clique2inode->d_mats[cl_ind].size;
			matrix = clique2inode->d_mats[cl_ind].matrix;
			matrix += ((offset*size)+offset);
			DTRMM(&SIDE,&UP,&TR,&DI,&(list_data_ptr->length),&j,&one,
				matrix,&size,list_data_ptr->d_buffer,&(list_data_ptr->length));
			MLOG_flop(2*((j*(j+1))/2)*list_data_ptr->length*j);
		}

		/* send these columns out */
		/* every contributor gets the whole collected column back */
		msg = NULL;
		while ((msg = BMnext_msg(to_phase,msg)) != NULL) {
			CHKERRN(0);
			data_ptr = BMget_user(msg,&data_len); CHKERRN(0);
			gnum = data_ptr[0];
			MY_SEARCH_TREE(recv_tree,gnum,found,tree_node_ptr);
			if (found) {
				list_data_ptr = (list_type *) MY_GET_TREE_DATA(tree_node_ptr);
			} else {
				/* NOTE(review): the code below relies on this macro */
				/* leaving the function -- confirm */
				MY_SETERRCN(FACTOR_ERROR,"Can't find matching column\n");
			}
			BMset_msg_ptr(msg,list_data_ptr->d_buffer); CHKERRN(0);
			BMset_msg_size(msg,list_data_ptr->buf_size); CHKERRN(0);
			BMsend_block_msg(to_phase,msg,procinfo); CHKERRN(0);
			/* the buffer stays owned by the recv_tree entry */
			BMset_msg_ptr(msg,NULL); CHKERRN(0);
		}
		CHKERRN(0);
		BMsend_block_msg(to_phase,NULL,procinfo); CHKERRN(0);

		/* retrieve my part of the column and put it in the data structure */
		/* this must precede the update below: BSupdt_matrix reorders */
		/* the rows of d_buffer in place */
		MY_FIRST_IN_TREE(recv_tree,tree_node_ptr);
		while (! IS_TREE_PTR_NULL(tree_node_ptr)) {
			list_data_ptr = (list_type *) MY_GET_TREE_DATA(tree_node_ptr);
			MY_NEXT_IN_TREE(tree_node_ptr);
			if (list_data_ptr->length <= 0) continue;
			in_ind = list_data_ptr->in_ind;
			length = inodes[in_ind].length;
			num_cols = inodes[in_ind].num_cols;
			nzptr = inodes[in_ind].nz;
			d_buffer = list_data_ptr->d_buffer;
			for (j=0;j<num_cols;j++) {
				for (k=0;k<length;k++) {
					nzptr[(j*length)+k] = d_buffer[k];
				}
				d_buffer += list_data_ptr->length;
			}
		}

		/* now update the remaining matrix using the local columns */
		MY_FIRST_IN_TREE(recv_tree,tree_node_ptr);
		while (! IS_TREE_PTR_NULL(tree_node_ptr)) {
			list_data_ptr = (list_type *) MY_GET_TREE_DATA(tree_node_ptr);
			MY_NEXT_IN_TREE(tree_node_ptr);
			if (list_data_ptr->length <= 0) continue;
			in_ind = list_data_ptr->in_ind;
			BSupdt_matrix(inode_tree,list_data_ptr->length,
				list_data_ptr->num_cols,
				list_data_ptr->i_buffer,list_data_ptr->d_buffer,
				inodes[in_ind].length,
				inodes[in_ind].row_num,inodes[in_ind].nz,
				color2clique->numbers[i+1],
				color2clique->numbers[color2clique->length-1],clique2inode,
				inodes,procinfo); CHKERRN(0);
		}

		/* receive the incoming columns and update the remaining matrix */
		/* strip out the "local" part and put it in the matrix */
		while ((msg = BMrecv_block_msg(from_phase,procinfo)) != NULL) {
			CHKERRN(0);
			BMcheck_on_async_block(to_phase); CHKERRN(0);
			msg_buf = (FLOAT *) BMget_msg_ptr(msg); CHKERRN(0);
			BMset_msg_ptr(msg,NULL); CHKERRN(0);
			data_ptr = BMget_user(msg,&msg_len); CHKERRN(0);
			cl_ind = data_ptr[0];
			in_ind = data_ptr[1];
			/* let gnum be first global row location of my part */
			gnum = g_row_num[iperm[inodes[in_ind].row_num[0]]];
			num_cols = inodes[in_ind].num_cols;
			length = inodes[in_ind].length;
			/* turn the size in bytes into the # of rows of the whole */
			/* column: each row has num_cols values and one row number */
			msg_len = BMget_msg_size(msg); CHKERRN(0);
			msg_len /= ((sizeof(FLOAT)*num_cols)+sizeof(int));

			/* find my local part */
			i_buffer = (int *) &(msg_buf[num_cols*msg_len]);
			for (ind=0;ind<msg_len;ind++) {
				if (gnum == i_buffer[ind]) break;
			}
			/* NOTE(review): if gnum is not present, ind == msg_len and */
			/* the copy below reads past the column -- confirm it cannot */
			/* happen */
			nzptr = inodes[in_ind].nz;
			d_buffer = &(msg_buf[ind]);
			for (j=0;j<num_cols;j++) {
				for (k=0;k<length;k++) {
					nzptr[(j*length)+k] = d_buffer[k];
				}
				d_buffer += msg_len;
			}

			/* update the remaining matrix */
			BSupdt_matrix(inode_tree,msg_len,num_cols,i_buffer,msg_buf,length,
				inodes[in_ind].row_num,nzptr,color2clique->numbers[i+1],
				color2clique->numbers[color2clique->length-1],clique2inode,
				inodes,procinfo); CHKERRN(0);
		}
		CHKERRN(0);
		BMfree_block_msg(from_phase); CHKERRN(0);
		BMfinish_on_async_block(to_phase); CHKERRN(0);

		/* now free up the list */
		/* (i_buffer lives inside the d_buffer allocation) */
		MY_FIRST_IN_TREE(recv_tree,tree_node_ptr);
		while (! IS_TREE_PTR_NULL(tree_node_ptr)) {
			list_data_ptr = (list_type *) MY_GET_TREE_DATA(tree_node_ptr);
			MY_NEXT_IN_TREE(tree_node_ptr);
			MY_FREE(list_data_ptr->d_buffer);
		}
		MY_FREE_TREE(recv_tree);
	}

	MY_FREE_TREE(inode_tree);
	BMfinish_comp_msg(comm->from_msg,procinfo); CHKERRN(0);
	BMfinish_comp_msg(comm->to_msg,procinfo); CHKERRN(0);
	return(0);
}

/*+ BSupdt_matrix - Update the matrix using a bunch of i-nodes 

    Input Parameters:
.   inode_tree - tree searched by row number; the data of a node is
                 two ints: an i-node index followed by a clique index
.   len - number of rows in the update block
.   num_cols - number of columns in the update block
.   updt_index - one row number per row of the update block
.   updt_inode - values of the update block, stored column after
                 column with leading dimension len
.   len2 - number of rows in the local piece
.   my_inode_ind - row numbers of the local piece; they are matched
                   against the row_num entries of the target i-nodes
.   my_inode - values of the local piece, leading dimension len2
.   start - not referenced
.   end - not referenced
.   clique2inode - per clique: owning processor, global offset and
                   dense diagonal block
.   inodes - the array of i-nodes
.   procinfo - the usual processor info

    Output Parameters:
.   updt_index - reordered by BSheap_sort1 (presumably ascending)
.   updt_inode - rows permuted to follow the new order of updt_index
.   inodes - the nz values of the matching i-nodes are updated
.   clique2inode - the dense blocks of locally owned cliques are updated

    Returns:
    void

    Notes:
    The update block is walked row by row.  A row whose number is found
    in inode_tree selects a target i-node; the rows that the local piece
    and the target have in common are updated, then (if the clique of
    the target is local) the dense clique block, and the walk skips
    ahead by the number of columns of the target.  Rows that are not
    found are skipped one at a time.

+*/
void BSupdt_matrix(BStree inode_tree,int len,int num_cols,int *updt_index,
FLOAT *updt_inode,int len2,int *my_inode_ind,FLOAT *my_inode,
int start,int end,BScl_2_inode *clique2inode,BSinode *inodes,
BSprocinfo *procinfo)
{
	int	row, col, c;
	int	pos, first, next;
	int	mine, theirs;
	int	target_in, target_cl;
	int	base, col_off, row_off, dim;
	int	hit;
	BSpermutation *order;
	FLOAT	*scratch, *column, *dmat, *dest;
	int	*target_rows;
	int	*node_data;
	BStree_ptr	node;

	/* sort the row numbers and carry the values of every column along */
	order = BSalloc_permutation(len); CHKERR(0);
	for (row=0;row<len;row++) {
		order->perm[row] = row;
	}
	BSheap_sort1(len,updt_index,order->perm); CHKERR(0);
	MY_MALLOC(scratch,(FLOAT *),sizeof(FLOAT)*len,1);
	for (col=0;col<num_cols;col++) {
		column = &(updt_inode[col*len]);
		BSiperm_dvec(column,scratch,order); CHKERR(0);
		for (row=0;row<len;row++) {
			column[row] = scratch[row];
		}
	}
	MY_FREE(scratch);
	BSfree_permutation(order); CHKERR(0);

	/* from here on scratch holds one value per column of the update */
	MY_MALLOC(scratch,(FLOAT *),sizeof(FLOAT)*num_cols,1);
	pos = 0;
	while (pos < len) {
		MY_SEARCH_TREE(inode_tree,updt_index[pos],hit,node);
		if (!hit) {
			/* nothing to update for this row */
			pos++;
			continue;
		}
		node_data = (int *) MY_GET_TREE_DATA(node);
		target_in = node_data[0];
		target_cl = node_data[1];

		/* i-node update: merge the two row lists, acting on matches */
		mine = 0;
		theirs = 0;
		target_rows = inodes[target_in].row_num;
		while ((mine<len2) && (theirs<inodes[target_in].length)) {
			if (my_inode_ind[mine] < target_rows[theirs]) {
				mine++;
			} else if (my_inode_ind[mine] > target_rows[theirs]) {
				theirs++;
			} else {
				/* grab the matching row of the local piece first */
				for (c=0;c<num_cols;c++) {
					scratch[c] = my_inode[(c*len2)+mine];
				}
				dest = &(inodes[target_in].nz[theirs]);
				for (col=0;col<inodes[target_in].num_cols;col++) {
					for (c=0;c<num_cols;c++) {
						(*dest) -= scratch[c]*
							updt_inode[(c*len)+pos+col];
					}
					/* same row, next column of the target */
					dest += inodes[target_in].length;
				}
				MLOG_flop(2*num_cols*inodes[target_in].num_cols);
				mine++;
				theirs++;
			}
		}

		/* clique update (stored in UPPER triangle), only if it is local */
		if (clique2inode->proc[target_cl] == procinfo->my_id) {
			base = clique2inode->g_offset[target_cl];
			col_off = updt_index[pos]-base;
			dmat = clique2inode->d_mats[target_cl].matrix;
			dim = clique2inode->d_mats[target_cl].size;
			for (col=0;col<inodes[target_in].num_cols;col++) {
				first = pos + col;
				for (c=0;c<num_cols;c++) {
					scratch[c] = updt_inode[(c*len)+first];
				}
				/* walk down the rows until we leave the clique block */
				for (next=first;next<len;next++) {
					row_off = updt_index[next]-base;
					if (row_off >= dim) {
						break;
					}
					for (c=0;c<num_cols;c++) {
						dmat[(row_off*dim)+col_off+col] -= scratch[c]*
							updt_inode[(c*len)+next];
					}
					MLOG_flop(2*num_cols);
				}
			}
		}

		/* skip the rest of the rows that belong to this i-node */
		pos += inodes[target_in].num_cols;
	}
	MY_FREE(scratch);
}


/* syntax highlighted by Code2HTML, v. 0.9.1 */