#include "BSprivate.h"
/*@ BSb_backward - Backward triangular matrix multiplication on a
block of vectors
Input Parameters:
. A - The sparse matrix
. x - The contiguous block of input vectors
. comm - The communication structure for A
. block_size - the number of input vectors
. procinfo - the usual processor information
Output Parameters:
. b - on exit these vectors contain A*x
Returns:
void
@*/
void BSb_backward(BSpar_mat *A, FLOAT *x, FLOAT *b, BScomm *comm,
int block_size, BSprocinfo *procinfo)
{
BMcomp_msg *from_msg, *to_msg;
BMphase *to_phase, *from_phase;
BMmsg *msg;
int i, j, k, n;
int cl_ind, in_ind;
int count, size, ind, num_cols;
int *row;
FLOAT *nz;
BScl_2_inode *clique2inode;
BSnumbering *color2clique;
BSinode *inodes;
int *data_ptr, msg_len;
FLOAT *msg_buf, *matrix;
FLOAT *work;
FLOAT *bptr, *xptr, *wptr;
FLOAT **boff, **xoff;
char UP = 'L';
char TR = 'T';
char NTR = 'N';
char ND = 'N';
char SIDE = 'L';
int ione = 1;
FLOAT one = 1.0;
FLOAT zero = 0.0;
FLOAT DDOT();
if(!A->icc_storage) {
/* nonsymmetric version done in BSb_forward() */
return;
}
if(procinfo->single) {
/* No single version so call BSbackward1 BS times */
n = A->num_rows;
for (i=0;i<block_size;i++) {
BSbackward1(A,&(x[n*i]),&(b[n*i]),comm,procinfo); CHKERR(0);
}
return;
}
color2clique = A->color2clique;
clique2inode = A->clique2inode;
inodes = A->inodes->list;
/* get some work space */
MY_MALLOC(work,(FLOAT *),sizeof(FLOAT)*A->num_rows*block_size,1);
/* calculate b and x offsets */
MY_MALLOC(boff,(FLOAT **),sizeof(FLOAT *)*block_size,1);
MY_MALLOC(xoff,(FLOAT **),sizeof(FLOAT *)*block_size,1);
for (i=0;i<block_size;i++) {
boff[i] = &(b[i*A->num_rows]);
xoff[i] = &(x[i*A->num_rows]);
}
/* REMEMBER, the to and from phase are switched here */
from_msg = comm->to_msg;
to_msg = comm->from_msg;
/* post for all messages */
BMinit_comp_msg(from_msg,procinfo); CHKERR(0);
/* REMEMBER, the diagonal has already been taken care */
/* now do this phase by phase */
for (i=color2clique->length-2;i>=0;i--) {
/* first send my messages */
/* this will involve computing partial sums */
to_phase = BMget_phase(to_msg,i); CHKERR(0);
msg = NULL;
while ((msg = BMnext_msg(to_phase,msg)) != NULL) {
CHKERR(0);
msg_buf = (FLOAT *) BMget_msg_ptr(msg); CHKERR(0);
data_ptr = BMget_user(msg,&msg_len); CHKERR(0);
msg_len = BMget_msg_size(msg); CHKERR(0);
msg_len /= block_size;
count = 0;
for (cl_ind=data_ptr[0];cl_ind<=data_ptr[1];cl_ind++) {
for (in_ind=clique2inode->inode_index[cl_ind];
in_ind<clique2inode->inode_index[cl_ind+1];in_ind++) {
row = inodes[in_ind].row_num;
nz = inodes[in_ind].nz;
size = inodes[in_ind].length;
num_cols = inodes[in_ind].num_cols;
if (size > 0) {
for (j=0;j<block_size;j++) {
wptr = &(work[j*size]);
xptr = xoff[j];
for (k=0;k<size;k++) wptr[k] = xptr[row[k]];
}
DGEMM(&TR,&NTR,&num_cols,&block_size,&size,&one,
nz,&size,work,&size,&zero,&(msg_buf[count]),
&msg_len);
}
count += num_cols;
}
}
BMsendf_msg(msg,procinfo); CHKERR(0);
}
CHKERR(0);
}
/* do some local work */
for (i=color2clique->length-2;i>=0;i--) {
for (cl_ind=color2clique->numbers[i];
cl_ind<color2clique->numbers[i+1];cl_ind++) {
if (procinfo->my_id == clique2inode->proc[cl_ind]) {
/* first, multiply the clique */
/* only do the strictly upper triangular part */
/* we ASSUME the diagonal is all 1's */
size = clique2inode->d_mats[cl_ind].size;
ind = clique2inode->d_mats[cl_ind].local_ind;
matrix = clique2inode->d_mats[cl_ind].matrix;
if (size > 1) {
j = size-1;
matrix++;
nz = work;
for (k=0;k<block_size;k++) {
DCOPY(&j,&(xoff[k][ind+1]),&ione,nz,&ione);
nz += j;
}
DTRMM(&SIDE,&UP,&TR,&ND,&j,&block_size,&one,matrix,
&size,work,&j);
nz = work;
for (k=0;k<block_size;k++) {
DAXPY(&j,&one,nz,&ione,&(boff[k][ind]),&ione);
nz += j;
}
}
/* now, multiply the inodes */
for (in_ind=clique2inode->inode_index[cl_ind];
in_ind<clique2inode->inode_index[cl_ind+1];in_ind++) {
row = inodes[in_ind].row_num;
nz = inodes[in_ind].nz;
size = inodes[in_ind].length;
num_cols = inodes[in_ind].num_cols;
if (size > 0) {
for (j=0;j<block_size;j++) {
wptr = &(work[j*size]);
xptr = xoff[j];
for (k=0;k<size;k++) wptr[k] = xptr[row[k]];
}
DGEMM(&TR,&NTR,&num_cols,&block_size,&size,&one,
nz,&size,work,&size,&one,&(b[ind]),
&(A->num_rows));
}
ind += num_cols;
}
}
}
}
/* receive my messages and update my rhs */
for (i=color2clique->length-2;i>=0;i--) {
from_phase = BMget_phase(from_msg,i); CHKERR(0);
while ((msg = BMrecv_msg(from_phase)) != NULL) {
CHKERR(0);
msg_buf = (FLOAT *) BMget_msg_ptr(msg); CHKERR(0);
data_ptr = BMget_user(msg,&msg_len); CHKERR(0);
msg_len = BMget_msg_size(msg); CHKERR(0);
msg_len /= block_size;
for (j=0;j<block_size;j++) {
wptr = &(msg_buf[j*msg_len]);
bptr = boff[j];
for (k=0;k<msg_len;k++) bptr[data_ptr[k]] += wptr[k];
}
BMfree_msg(msg); CHKERR(0);
}
CHKERR(0);
}
MY_FREE(xoff);
MY_FREE(boff);
MY_FREE(work);
/* wait for all of the sent messages to finish */
BMfinish_comp_msg(to_msg,procinfo); CHKERR(0);
MLOG_flop((2*A->local_nnz*block_size));
}
/* syntax highlighted by Code2HTML, v. 0.9.1 */