#include "BSprivate.h"
/*@ BSfor_solve - Forward triangular matrix solution on a
single vector

Input Parameters:
. A - The sparse matrix
. x - The rhs
. comm - The communication structure for A
. procinfo - the usual processor information

Output Parameters:
. x - on exit contains the solution vector

Returns:
void

Notes:
The solve proceeds color by color (phase by phase). For each color:
(1) with symmetric (ICC) storage, every locally owned clique of the
color multiplies its slice of x by the transpose of the inverted
diagonal block kept in its upper triangle; (2) the x values needed
by other processors for this phase are packed and sent; (3) the
contributions of the locally owned inodes of the color are
subtracted from x; (4) incoming messages for this phase are received
and the contributions of the corresponding remote inodes are
subtracted from x. With non-symmetric (ILU) storage the first
below_diag rows of each inode are skipped, so only entries below the
pivot take part in the forward solve.
@*/
void BSfor_solve(BSpar_mat *A, FLOAT *x, BScomm *comm, BSprocinfo *procinfo)
{
  BMphase *to_phase, *from_phase;
  BMmsg *msg;
  int i, j, k;
  int cl_ind, in_ind, symmetric;
  int count, size, length, ind, num_cols, below;
  int *row;
  FLOAT *nz;
  BScl_2_inode *clique2inode;
  BSnumbering *color2clique;
  BSinode *inodes;
  int *data_ptr, msg_len;
  FLOAT *msg_buf, *matrix;
  FLOAT *work;
  char UP = 'U';
  char TR = 'T';
  char NTR = 'N';
  char ND = 'N';
  int ione = 1;
  FLOAT one = 1.0;
  FLOAT zero = 0.0;

  /* Is the symmetric data structure used? */
  symmetric = A->icc_storage;
  color2clique = A->color2clique;
  clique2inode = A->clique2inode;
  inodes = A->inodes->list;

  /* scratch space for the dense DGEMV results (at most num_rows long) */
  MY_MALLOC(work,(FLOAT *),sizeof(FLOAT)*A->num_rows,1);

  /* post for all messages */
  BMinit_comp_msg(comm->from_msg,procinfo); CHKERR(0);

  /* now do this phase by phase */
  for (i=0;i<color2clique->length-1;i++) {
    if (symmetric) {
      /* find my portion of the solution using the cliques on the diagonal */
      for (cl_ind=color2clique->numbers[i];
           cl_ind<color2clique->numbers[i+1];cl_ind++) {
        if (procinfo->my_id == clique2inode->proc[cl_ind]) {
          /* the clique is stored, inverted, in the upper triangle */
          size = clique2inode->d_mats[cl_ind].size;
          ind = clique2inode->d_mats[cl_ind].local_ind;
          matrix = clique2inode->d_mats[cl_ind].matrix;
#ifdef MY_BLAS_DTRMV_ON
          MY_DTRMV_T_U(size,matrix,size,&(x[ind]));
#else
          DTRMV(&UP,&TR,&ND,&size,matrix,&size,&(x[ind]),&ione);
#endif
        }
      }
    }

    /* now send my messages: pack the requested entries of x */
    to_phase = BMget_phase(comm->to_msg,i); CHKERR(0);
    msg = NULL;
    while ((msg = BMnext_msg(to_phase,msg)) != NULL) {
      CHKERR(0);
      msg_buf = (FLOAT *) BMget_msg_ptr(msg); CHKERR(0);
      data_ptr = BMget_user(msg,&msg_len); CHKERR(0);
      for (j=0;j<msg_len;j++) {
        msg_buf[j] = x[data_ptr[j]];
      }
      BMsendf_msg(msg,procinfo); CHKERR(0);
    }
    CHKERR(0);

    /* do some local work: subtract my inodes' contributions */
    for (cl_ind=color2clique->numbers[i];
         cl_ind<color2clique->numbers[i+1];cl_ind++) {
      if (procinfo->my_id == clique2inode->proc[cl_ind]) {
        ind = clique2inode->d_mats[cl_ind].local_ind;
        for (in_ind=clique2inode->inode_index[cl_ind];
             in_ind<clique2inode->inode_index[cl_ind+1];in_ind++) {
          row = inodes[in_ind].row_num;
          nz = inodes[in_ind].nz;
          size = inodes[in_ind].length;
          num_cols = inodes[in_ind].num_cols;
          if (symmetric) {
            if (size > 0) {
#ifdef MY_BLAS_DGEMV_ON
              if (num_cols > DGEMV_UNROLL_LVL) {
                DGEMV(&NTR,&size,&num_cols,&one,nz,&size,&(x[ind]),
                      &ione,&zero,work,&ione);
                for (k=0;k<size;k++) x[row[k]] -= work[k];
              } else {
                MY_DGEMVM1_N_1111(size,num_cols,nz,size,&(x[ind]),x,
                                  row);
              }
#else
              DGEMV(&NTR,&size,&num_cols,&one,nz,&size,&(x[ind]),
                    &ione,&zero,work,&ione);
              for (k=0;k<size;k++) x[row[k]] -= work[k];
#endif
            }
          } else {
            /* ILU: only entries below the pivot are used. Skip the
               first below_diag rows by advancing BOTH the value and the
               row-index pointers; the leading dimension of nz is still
               the full inode length. (Previously row[k+j] used a stale
               or uninitialized j.) */
            length = inodes[in_ind].length;
            below = inodes[in_ind].below_diag;
            size -= below;
            nz += below;
            row += below;
            if (size > 0) {
#ifdef MY_BLAS_DGEMV_ON
              if (num_cols > DGEMV_UNROLL_LVL) {
                DGEMV(&NTR,&size,&num_cols,&one,nz,&length,&(x[ind]),
                      &ione,&zero,work,&ione);
                for (k=0;k<size;k++) x[row[k]] -= work[k];
              } else {
                MY_DGEMVM1_N_1111(size,num_cols,nz,length,&(x[ind]),
                                  x,row);
              }
#else
              DGEMV(&NTR,&size,&num_cols,&one,nz,&length,&(x[ind]),
                    &ione,&zero,work,&ione);
              for (k=0;k<size;k++) x[row[k]] -= work[k];
#endif
            }
          }
          ind += num_cols;
        }
      }
    }

    /* receive my messages and do non-local work */
    from_phase = BMget_phase(comm->from_msg,i); CHKERR(0);
    while ((msg = BMrecv_msg(from_phase)) != NULL) {
      CHKERR(0);
      msg_buf = (FLOAT *) BMget_msg_ptr(msg); CHKERR(0);
      data_ptr = BMget_user(msg,&msg_len); CHKERR(0);
      count = 0;
      /* data_ptr[0]..data_ptr[1] is the inclusive range of cliques
         whose x values arrived in this message */
      for (cl_ind=data_ptr[0];cl_ind<=data_ptr[1];cl_ind++) {
        for (in_ind=clique2inode->inode_index[cl_ind];
             in_ind<clique2inode->inode_index[cl_ind+1];in_ind++) {
          row = inodes[in_ind].row_num;
          nz = inodes[in_ind].nz;
          size = inodes[in_ind].length;
          num_cols = inodes[in_ind].num_cols;
          if (symmetric) {
            if (size > 0) {
#ifdef MY_BLAS_DGEMV_ON
              if (num_cols > DGEMV_UNROLL_LVL) {
                DGEMV(&NTR,&size,&num_cols,&one,nz,&size,
                      &(msg_buf[count]),&ione,&zero,work,&ione);
                for (k=0;k<size;k++) x[row[k]] -= work[k];
              } else {
                MY_DGEMVM1_N_1111(size,num_cols,nz,size,
                                  &(msg_buf[count]),x,row);
              }
#else
              DGEMV(&NTR,&size,&num_cols,&one,nz,&size,
                    &(msg_buf[count]),&ione,&zero,work,&ione);
              for (k=0;k<size;k++) x[row[k]] -= work[k];
#endif
            }
          } else {
            /* ILU: skip rows at or above the pivot (see local case) */
            length = inodes[in_ind].length;
            below = inodes[in_ind].below_diag;
            size -= below;
            nz += below;
            row += below;
            if (size > 0) {
#ifdef MY_BLAS_DGEMV_ON
              if (num_cols > DGEMV_UNROLL_LVL) {
                DGEMV(&NTR,&size,&num_cols,&one,nz,&length,
                      &(msg_buf[count]),&ione,&zero,work,&ione);
                for (k=0;k<size;k++) x[row[k]] -= work[k];
              } else {
                MY_DGEMVM1_N_1111(size,num_cols,nz,length,
                                  &(msg_buf[count]),x,row);
              }
#else
              DGEMV(&NTR,&size,&num_cols,&one,nz,&length,
                    &(msg_buf[count]),&ione,&zero,work,&ione);
              for (k=0;k<size;k++) x[row[k]] -= work[k];
#endif
            }
          }
          count += num_cols;
        }
      }
      BMfree_msg(msg); CHKERR(0);
    }
    CHKERR(0);
  }
  MY_FREE(work);

  /* wait for all of the sent messages to finish */
  BMfinish_comp_msg(comm->to_msg,procinfo); CHKERR(0);
  MLOG_flop((2*A->local_nnz));
}
/* syntax highlighted by Code2HTML, v. 0.9.1 */