/* solveMT.c */
#include "../spoolesMT.h"
#include "../../Ideq.h"
#include "../../timings.h"
#include "../../SolveMap.h"
#include "../../SubMtxList.h"
/*--------------------------------------------------------------------*/
/*
---------------------------------
declarations for static functions
---------------------------------
*/
static void * FrontMtx_workerSolve ( void *arg ) ;
static void setupListObjects ( FrontMtx *frontmtx, SolveMap *solvemap,
SubMtxList *forwAggList, SubMtxList *backAggList,
int msglvl, FILE *msgFile ) ;
/*--------------------------------------------------------------------*/
/*
--------------------------------------------------------------------
data structure for each thread
global data
frontmtx -- FrontMtx object that contains the matrix
factorization
solmtx -- DenseMtx object, on output it holds the solution
rhsmtx -- DenseMtx object, on input it holds the right
hand side
solvemap -- SolveMap object that holds the map of computations
to processes for the forward and backward solve
mtxmanager -- SubMtxManager object to manage working storage
forwAggList -- SubMtxList to hold aggregate objects
during the forward solve
backAggList -- SubMtxList to hold aggregate objects
during the backward solve
p_mtx -- array of pointers to SubMtx objects that point
to global solution SubMtx objects for the fronts
frontIsDone -- char vector used for synchronization
there are three steps to the solve
(1) (L + I) Y = B
(2) D Z = Y
(3) (I + U) X = B
frontIsDone[J] = 'N' --> J at step is not yet complete
frontIsDone[J] = 'Y' --> J at step is complete
lowerIsDone -- vector used for synchronization
'Y' --> forward solve for iproc complete
'N' --> forward solve for iproc not complete
diagIsDone -- vector used for synchronization
'Y' --> diagonal solve for iproc complete
'N' --> diagonal solve for iproc not complete
upperIsDone -- vector used for synchronization
'Y' --> backward solve for iproc complete
'N' --> backward solve for iproc not complete
local data
myid -- thread id
cpus -- vector to hold cpu breakdown
msglvl -- local message level
msgFile -- local message file
--------------------------------------------------------------------
*/
typedef struct _SolveData SolveData ;
struct _SolveData {
FrontMtx *frontmtx ;
DenseMtx *solmtx ;
DenseMtx *rhsmtx ;
SolveMap *solvemap ;
SubMtxManager *mtxmanager ;
int myid ;
SubMtxList *forwAggList ;
SubMtxList *backAggList ;
SubMtx **p_mtx ;
char *frontIsDone ;
char *lowerIsDone ;
char *diagIsDone ;
char *upperIsDone ;
double *cpus ;
int msglvl ;
FILE *msgFile ;
} ;
/*--------------------------------------------------------------------*/
/*
----------------------------------------------------
multithreaded solve method for (L + I)D(I + U) X = B
or (U^T + I)D(I + U) X = B.
created -- 98mar19, cca
----------------------------------------------------
*/
void
FrontMtx_MT_solve (
FrontMtx *frontmtx,
DenseMtx *solmtx,
DenseMtx *rhsmtx,
SubMtxManager *mtxmanager,
SolveMap *solvemap,
double cpus[],
int msglvl,
FILE *msgFile
) {
char buffer[20] ;
char *diagIsDone, *frontIsDone, *lowerIsDone, *upperIsDone ;
SubMtx **p_mtx ;
SubMtxList *backAggList, *forwAggList ;
SolveData *data, *dataObjects ;
double t1, t2 ;
FILE *fp ;
int ii, J, myid, nfront, nthread, rc ;
/*
---------------
check the input
---------------
*/
if ( frontmtx == NULL || solmtx == NULL || rhsmtx == NULL
|| mtxmanager == NULL || solvemap == NULL || cpus == NULL
|| (msglvl > 0 && msgFile == NULL) ) {
fprintf(stderr, "\n fatal error in FrontMtx_MT_solve()"
"\n frontmtx = %p, solmtx = %p, rhsmtx = %p, mtxmanager = %p"
"\n solvemap = %p, cpus = %p"
"\n msglvl = %d, msgFile = %p"
"\n\n bad input\n", frontmtx, solmtx, rhsmtx, mtxmanager,
solvemap, cpus, msglvl, msgFile) ;
exit(-1) ;
}
MARKTIME(t1) ;
nfront = FrontMtx_nfront(frontmtx) ;
nthread = SolveMap_nproc(solvemap) ;
/*
--------------------------------------------------
set up the forward and backward solve list objects
--------------------------------------------------
*/
backAggList = SubMtxList_new() ;
forwAggList = SubMtxList_new() ;
setupListObjects(frontmtx, solvemap, forwAggList, backAggList,
msglvl, msgFile) ;
if ( msglvl > 1 ) {
fprintf(msgFile, "\n\n forward aggregate list") ;
SubMtxList_writeForHumanEye(forwAggList, msgFile) ;
fflush(msgFile) ;
}
/*
---------------------------
allocate the global vectors
---------------------------
*/
ALLOCATE(p_mtx, struct _SubMtx *, nfront) ;
for ( J = 0 ; J < nfront ; J++ ) {
p_mtx[J] = NULL ;
}
frontIsDone = CVinit(nfront, 'N') ;
lowerIsDone = CVinit(solvemap->nproc, 'N') ;
diagIsDone = CVinit(solvemap->nproc, 'N') ;
upperIsDone = CVinit(solvemap->nproc, 'N') ;
/*
----------------------------------
allocate the local data structures
----------------------------------
*/
ALLOCATE(dataObjects, struct _SolveData, nthread) ;
for ( myid = 0, data = dataObjects ; myid < nthread ; myid++, data++ ) {
data->frontmtx = frontmtx ;
data->solmtx = solmtx ;
data->rhsmtx = rhsmtx ;
data->solvemap = solvemap ;
data->mtxmanager = mtxmanager ;
data->myid = myid ;
data->forwAggList = forwAggList ;
data->backAggList = backAggList ;
data->p_mtx = p_mtx ;
data->frontIsDone = frontIsDone ;
data->lowerIsDone = lowerIsDone ;
data->diagIsDone = diagIsDone ;
data->upperIsDone = upperIsDone ;
data->cpus = cpus ;
data->msglvl = msglvl ;
if ( msglvl > 0 ) {
sprintf(buffer, "solve.res.%d", myid) ;
if ( (fp = fopen(buffer, "w")) == NULL ) {
fprintf(stderr, "\n fatal error, unable to open file %s",
buffer) ;
exit(-1) ;
}
data->msgFile = fp ;
} else {
data->msgFile = NULL ;
}
}
MARKTIME(t2) ;
if ( msglvl > 1 ) {
fprintf(msgFile, "\n CPU %8.3f : initialization time", t2 - t1) ;
fflush(msgFile) ;
}
/*
--------------------------------------
switch over the type of thread library
--------------------------------------
*/
#if THREAD_TYPE == TT_SOLARIS
/*
---------------
solaris threads
---------------
*/
MARKTIME(t1) ;
thr_setconcurrency(nthread) ;
MARKTIME(t2) ;
if ( msglvl > 2 ) {
fprintf(msgFile, "\n CPU %8.3f : set concurrency time", t2 - t1) ;
}
MARKTIME(t1) ;
for ( myid = 0, data = dataObjects ;
myid < nthread - 1 ;
myid++, data++ ) {
rc = thr_create(NULL, 0, FrontMtx_workerSolve, data, 0, NULL) ;
if ( rc != 0 ) {
fprintf(stderr,
"\n fatal error, myid = %d, rc = %d from thr_create",
myid, rc) ;
exit(-1) ;
}
}
MARKTIME(t2) ;
if ( msglvl > 1 ) {
fprintf(msgFile, "\n CPU %8.3f : thread creation time", t2 - t1) ;
fflush(msgFile) ;
}
MARKTIME(t1) ;
FrontMtx_workerSolve(data) ;
for ( myid = 0 ; myid < nthread - 1 ; myid++ ) {
thr_join(0, 0, 0) ;
}
MARKTIME(t2) ;
if ( msglvl > 1 ) {
fprintf(msgFile, "\n CPU %8.3f : thread join time", t2 - t1) ;
fflush(msgFile) ;
}
#endif
#if THREAD_TYPE == TT_POSIX
/*
-------------
POSIX threads
-------------
*/
{
pthread_attr_t attr ;
pthread_t *tids ;
void *status ;
/*
##### NOTE: for SGI machines, this command must be present
##### for the thread scheduling to be efficient.
##### this is NOT a POSIX call, but SGI needs it anyway
pthread_setconcurrency(nthread) ;
*/
MARKTIME(t1) ;
pthread_attr_init(&attr) ;
/*
pthread_attr_setscope(&attr, PTHREAD_SCOPE_SYSTEM) ;
*/
pthread_attr_setscope(&attr, PTHREAD_SCOPE_PROCESS) ;
ALLOCATE(tids, pthread_t, nthread) ;
for ( myid = 0, data = dataObjects ; myid < nthread ; myid++, data++ ) {
rc = pthread_create(&tids[myid], &attr,
FrontMtx_workerSolve, data) ;
if ( rc != 0 ) {
fprintf(stderr,
"\n fatal error, myid = %d, rc = %d from pthread_create",
myid, rc) ;
exit(-1) ;
}
}
MARKTIME(t2) ;
if ( msglvl > 1 ) {
fprintf(msgFile, "\n CPU %8.3f : thread creation time", t2 - t1) ;
fflush(msgFile) ;
}
MARKTIME(t1) ;
for ( myid = 0 ; myid < nthread ; myid++ ) {
pthread_join(tids[myid], &status) ;
}
MARKTIME(t2) ;
if ( msglvl > 1 ) {
fprintf(msgFile, "\n CPU %8.3f : thread join time", t2 - t1) ;
fflush(msgFile) ;
}
FREE(tids) ;
pthread_attr_destroy(&attr) ;
}
#endif
/*
---------------------
free the working data
---------------------
*/
MARKTIME(t1) ;
DVzero(8, cpus) ;
for ( myid = 0, data = dataObjects ; myid < nthread ; myid++, data++ ) {
for ( ii = 0 ; ii < 8 ; ii++ ) {
cpus[ii] += data->cpus[ii] ;
}
}
FREE(dataObjects) ;
FREE(p_mtx) ;
CVfree(frontIsDone) ;
CVfree(lowerIsDone) ;
CVfree(diagIsDone) ;
CVfree(upperIsDone) ;
SubMtxList_free(forwAggList) ;
SubMtxList_free(backAggList) ;
MARKTIME(t2) ;
if ( msglvl > 1 ) {
fprintf(msgFile, "\n CPU %8.3f : free working data", t2 - t1) ;
fflush(msgFile) ;
}
return ; }
/*--------------------------------------------------------------------*/
/*
----------------------------------------
worker method to solve the linear system
created -- 97jun27, cca
----------------------------------------
*/
static void *
FrontMtx_workerSolve (
void *arg
) {
char *diagIsDone, *frontIsDone, *lowerIsDone,
*status, *upperIsDone ;
DenseMtx *rhsmtx, *solmtx ;
SubMtx **p_agg, **p_mtx ;
SubMtxList *backAggList, *forwAggList ;
SubMtxManager *mtxmanager ;
FrontMtx *frontmtx ;
double t0, t1, t2, t3 ;
double *cpus ;
SolveData *data ;
FILE *msgFile ;
Ideq *dequeue ;
int everybodyDone, iproc, I, J, K, msglvl, myid,
nfront, nproc, nrhs ;
int *fch, *nactiveChild, *owners, *par, *sib ;
IP **heads ;
SolveMap *solvemap ;
Tree *tree ;
/*
-------------------------------
extract pointers and dimensions
-------------------------------
*/
MARKTIME(t0) ;
data = (SolveData *) arg ;
frontmtx = data->frontmtx ;
rhsmtx = data->rhsmtx ;
solmtx = data->solmtx ;
solvemap = data->solvemap ;
mtxmanager = data->mtxmanager ;
forwAggList = data->forwAggList ;
backAggList = data->backAggList ;
p_mtx = data->p_mtx ;
frontIsDone = data->frontIsDone ;
lowerIsDone = data->lowerIsDone ;
diagIsDone = data->diagIsDone ;
upperIsDone = data->upperIsDone ;
myid = data->myid ;
cpus = data->cpus ;
msglvl = data->msglvl ;
msgFile = data->msgFile ;
nfront = frontmtx->nfront ;
nproc = solvemap->nproc ;
tree = frontmtx->frontETree->tree ;
par = tree->par ;
fch = tree->fch ;
sib = tree->sib ;
nrhs = rhsmtx->ncol ;
owners = SolveMap_owners(solvemap) ;
/*
-----------------------------------------------
set up for the forward solve
(1) IP lists for the forward solve
(2) local status[] vector
(3) create the dequeue for the forward solve
(4) create the number of active children vector
-----------------------------------------------
*/
MARKTIME(t1) ;
heads = SolveMap_forwardSetup(solvemap, myid, msglvl, msgFile) ;
status = CVinit(nfront, 'F') ;
dequeue = FrontMtx_setUpDequeue(frontmtx, owners, myid, status, heads,
'W', 'F', msglvl, msgFile) ;
FrontMtx_loadActiveLeaves(frontmtx, status, 'W', dequeue) ;
nactiveChild = FrontMtx_nactiveChild(frontmtx, status, myid) ;
MARKTIME(t2) ;
cpus[0] += t2 - t1 ;
if ( msglvl > 2 ) {
IP *ip ;
fprintf(msgFile, "\n\n forward update lists") ;
for ( J = 0 ; J < nfront ; J++ ) {
if ( heads[J] != NULL ) {
fprintf(msgFile, "\n %d : ", J) ;
for ( ip = heads[J] ; ip != NULL ; ip = ip->next ) {
fprintf(msgFile, " %d", ip->val) ;
}
}
}
fflush(msgFile) ;
}
if ( msglvl > 2 ) {
fprintf(msgFile, "\n\n dequeue set up for forward solve") ;
Ideq_writeForHumanEye(dequeue, msgFile) ;
fprintf(msgFile, "\n\n nactiveChild") ;
IVfprintf(msgFile, nfront, nactiveChild) ;
fflush(msgFile) ;
}
/*
------------------------
load the right hand side
------------------------
*/
MARKTIME(t1) ;
p_agg = FrontMtx_loadRightHandSide(frontmtx, rhsmtx, owners, myid,
mtxmanager, msglvl, msgFile) ;
MARKTIME(t2) ;
cpus[1] += t2 - t1 ;
/*
-------------
forward solve
-------------
*/
#if MYDEBUG > 0
fprintf(stdout, "\n\n thread %d : starting forward solve", myid) ;
fflush(stdout) ;
#endif
MARKTIME(t1) ;
while ( (J = Ideq_removeFromHead(dequeue)) != -1 ) {
if ( msglvl > 1 ) {
fprintf(msgFile,
"\n\n ### forward solve, checking out J = %d, status %c",
J, status[J]) ;
fflush(msgFile) ;
}
FrontMtx_forwardVisit(frontmtx, J, nrhs, owners, myid, mtxmanager,
forwAggList, p_mtx, frontIsDone, heads,
p_agg, status, msglvl, msgFile) ;
if ( status[J] == 'F' ) {
/*
-------------------
front J is finished
-------------------
*/
if ( msglvl > 1 ) {
fprintf(msgFile, "\n front J is finished") ;
fflush(msgFile) ;
}
if ( (K = par[J]) != -1 && --nactiveChild[K] == 0 ) {
/*
-----------------------------------
all children of parent are finished
add parent to head of dequeue
-----------------------------------
*/
if ( msglvl > 1 ) {
fprintf(msgFile, ", placing %d on head of dequeue", K) ;
fflush(msgFile) ;
}
Ideq_insertAtHead(dequeue, K) ;
}
} else {
/*
-------------------------------------------------
front J is not finished, place on tail of dequeue
-------------------------------------------------
*/
if ( msglvl > 1 ) {
fprintf(msgFile, "\n front J is not finished") ;
fflush(msgFile) ;
}
Ideq_insertAtTail(dequeue, J) ;
}
}
IP_free(heads[nfront+1]) ;
FREE(heads) ;
IVfree(nactiveChild) ;
Ideq_free(dequeue) ;
MARKTIME(t2) ;
cpus[2] += t2 - t1 ;
if ( msglvl > 1 ) {
fprintf(msgFile, "\n\n forward solve finished") ;
fflush(msgFile) ;
}
/*
------------------------------------------
all processes wait here until all are done
------------------------------------------
*/
lowerIsDone[myid] = 'Y' ;
do {
if ( msglvl > 2 ) {
fprintf(msgFile, "\n\n 1. checking out lowerIsDone[] vector: ") ;
fflush(msgFile) ;
}
everybodyDone = 1 ;
for ( iproc = 0 ; iproc < nproc ; iproc++ ) {
if ( msglvl > 2 ) {
fprintf(msgFile, " %c", lowerIsDone[iproc]) ;
fflush(msgFile) ;
}
if ( lowerIsDone[iproc] != 'Y' ) {
everybodyDone = 0 ;
break ;
}
}
} while ( everybodyDone == 0 ) ;
/*
---------------------
do the diagonal solve
---------------------
*/
MARKTIME(t1) ;
for ( J = 0 ; J < nfront ; J++ ) {
if ( owners[J] == myid ) {
FrontMtx_diagonalVisit(frontmtx, J, owners, myid, p_mtx,
frontIsDone, p_agg, msglvl, msgFile) ;
/*
------------------------------------------------------------
set frontIsDone[J] to 'N' to be ready for the backward solve
------------------------------------------------------------
*/
frontIsDone[J] = 'N' ;
}
}
MARKTIME(t2) ;
cpus[3] += t2 - t1 ;
/*
------------------------------------------
all processes wait here until all are done
------------------------------------------
*/
diagIsDone[myid] = 'Y' ;
do {
if ( msglvl > 2 ) {
fprintf(msgFile, "\n\n 2. checking out diagIsDone[] vector: ") ;
fflush(msgFile) ;
}
everybodyDone = 1 ;
for ( iproc = 0 ; iproc < nproc ; iproc++ ) {
if ( msglvl > 2 ) {
fprintf(msgFile, " %c", diagIsDone[iproc]) ;
fflush(msgFile) ;
}
if ( diagIsDone[iproc] != 'Y' ) {
everybodyDone = 0 ;
break ;
}
}
} while ( everybodyDone == 0 ) ;
if ( msglvl > 2 ) {
fprintf(msgFile, "\n everybody is done") ;
fflush(msgFile) ;
}
/*
---------------------------------------------
set up for the backward solve
(1) IP lists for the backward solve
(2) set up the dequeue for the backward solve
---------------------------------------------
*/
MARKTIME(t1) ;
heads = SolveMap_backwardSetup(solvemap, myid, msglvl, msgFile) ;
if ( msglvl > 2 ) {
fprintf(msgFile, "\n\n ") ;
for ( J = 0 ; J < nfront ; J++ ) {
IP *ip ;
fprintf(msgFile, "\n %d :", J) ;
fflush(msgFile) ;
for ( ip = heads[J] ; ip != NULL ; ip = ip->next ) {
fprintf(msgFile, " %d", ip->val) ;
fflush(msgFile) ;
}
}
}
dequeue = FrontMtx_setUpDequeue(frontmtx, owners, myid, status, heads,
'W', 'F', msglvl, msgFile) ;
FrontMtx_loadActiveRoots(frontmtx, status, 'W', dequeue) ;
MARKTIME(t2) ;
cpus[1] += t2 - t1 ;
/*
--------------
backward solve
--------------
*/
#if MYDEBUG > 0
fflush(stdout) ;
#endif
if ( msglvl > 2 ) {
fprintf(msgFile, "\n\n thread %d : starting backward solve", myid) ;
fflush(msgFile) ;
}
MARKTIME(t1) ;
while ( (J = Ideq_removeFromHead(dequeue)) != -1 ) {
if ( msglvl > 1 ) {
fprintf(msgFile,
"\n\n ### backward solve, checking out J = %d, status %c",
J, status[J]) ;
fflush(msgFile) ;
}
FrontMtx_backwardVisit(frontmtx, J, nrhs, owners, myid, mtxmanager,
backAggList, p_mtx, frontIsDone, heads,
p_agg, status, msglvl, msgFile) ;
if ( status[J] == 'F' ) {
/*
-------------------
front J is finished
-------------------
*/
if ( msglvl > 1 ) {
fprintf(msgFile, "\n front J is finished") ;
fflush(msgFile) ;
}
for ( I = fch[J] ; I != -1 ; I = sib[I] ) {
if ( status[I] == 'W' ) {
Ideq_insertAtHead(dequeue, I) ;
}
}
} else {
/*
-------------------------------------------------
front J is not finished, place on tail of dequeue
-------------------------------------------------
*/
if ( msglvl > 1 ) {
fprintf(msgFile, "\n front J is not finished") ;
fflush(msgFile) ;
}
Ideq_insertAtTail(dequeue, J) ;
}
}
CVfree(status) ;
IP_free(heads[nfront+1]) ;
FREE(heads) ;
Ideq_free(dequeue) ;
MARKTIME(t2) ;
cpus[4] += t2 - t1 ;
/*
------------------------------------------
all processes wait here until all are done
------------------------------------------
*/
upperIsDone[myid] = 'Y' ;
do {
if ( msglvl > 2 ) {
fprintf(msgFile, "\n\n 2. checking out upperIsDone[] vector: ") ;
fflush(msgFile) ;
}
everybodyDone = 1 ;
for ( iproc = 0 ; iproc < nproc ; iproc++ ) {
if ( msglvl > 2 ) {
fprintf(msgFile, " %c", upperIsDone[iproc]) ;
fflush(msgFile) ;
}
if ( upperIsDone[iproc] != 'Y' ) {
everybodyDone = 0 ;
break ;
}
}
} while ( everybodyDone == 0 ) ;
/*
--------------------------
store the solution entries
--------------------------
*/
MARKTIME(t1) ;
FrontMtx_storeSolution(frontmtx, owners, myid, mtxmanager,
p_mtx, solmtx, msglvl, msgFile) ;
MARKTIME(t2) ;
cpus[5] += t2 - t1 ;
/*
------------------------
free the working storage
------------------------
*/
FREE(p_agg) ;
return(NULL) ; }
/*--------------------------------------------------------------------*/
/*
-----------------------------------------------------------------
set up the aggregate list objects for the forward and back solves
created -- 98mar20, cca
-----------------------------------------------------------------
*/
static void
setupListObjects (
FrontMtx *frontmtx,
SolveMap *solvemap,
SubMtxList *forwAggList,
SubMtxList *backAggList,
int msglvl,
FILE *msgFile
) {
char *flags ;
int J, nfront ;
int *aggcounts ;
IV *aggcountsIV ;
nfront = SolveMap_nfront(solvemap) ;
flags = CVinit(nfront, 'N') ;
/*
------------------------------------------------------
set up the aggregate list object for the forward solve
------------------------------------------------------
*/
aggcountsIV = SolveMap_lowerAggregateIV(solvemap, -1, msglvl, msgFile) ;
aggcounts = IV_entries(aggcountsIV) ;
for ( J = 0 ; J < nfront ; J++ ) {
if ( aggcounts[J] > 1 ) {
flags[J] = 'Y' ;
}
}
SubMtxList_init(forwAggList, nfront, aggcounts, 1, flags) ;
IV_free(aggcountsIV) ;
/*
-------------------------------------------------------
set up the aggregate list object for the backward solve
-------------------------------------------------------
*/
aggcountsIV = SolveMap_upperAggregateIV(solvemap, -1, msglvl, msgFile) ;
aggcounts = IV_entries(aggcountsIV) ;
for ( J = 0 ; J < nfront ; J++ ) {
if ( aggcounts[J] > 1 ) {
flags[J] = 'Y' ;
}
}
SubMtxList_init(backAggList, nfront, aggcounts, 1, flags) ;
IV_free(aggcountsIV) ;
CVfree(flags) ;
return ; }
/*--------------------------------------------------------------------*/
syntax highlighted by Code2HTML, v. 0.9.1