/*
 * Copyright (c) 2002-2006 Samit Basu
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
 *
 */
//FIXME
//#include "config.h"

#ifdef USE_MPI

// MPI bindings for FreeMat (mpisend, mpirecv, mpibcast, mpireduce, ...).
// NOTE(review): the header name on the next line has been lost from this
// copy of the file (most likely <mpi.h>) -- restore it before building.
#include
#include "HandleList.hpp"
#include "MPIWrap.hpp"
#include "Malloc.hpp"
#include "FunctionDef.hpp"
#include "Context.hpp"

// Table that maps the integer communicator handles seen by FreeMat code to
// MPI communicators.
// NOTE(review): the template argument appears to have been lost here too
// (presumably HandleList<MPI_Comm>) -- confirm against the original source.
HandleList comms;
// Error handler created and attached to MPI_COMM_WORLD by InitializeMPIWrap().
MPI_Errhandler errhdl;

// MPI error callback: looks up the text for the MPI error code and throws it
// as a FreeMat Exception.
// NOTE(review): this throws a C++ exception out of a callback that is invoked
// from inside the MPI library; that is only safe if the MPI implementation
// tolerates unwinding through its frames. MPI_Error_string also requires a
// buffer of at least MPI_MAX_ERROR_STRING chars -- confirm 4096 suffices and
// that resultlen is always < 4096 on the target MPI.
void MPIErrHandler(MPI_Comm *comm, int *errorcode, ...) {
  char buffer[4096];
  int resultlen;
  MPI_Error_string(*errorcode,buffer,&resultlen);
  buffer[resultlen] = 0;
  throw Exception(buffer);
}

// Setup performed right after MPI_Init (see MPIInit). It registers
// MPI_COMM_WORLD in the handle table; this is presumably handle 1, the
// default used by the functions below. It also installs MPIErrHandler on
// MPI_COMM_WORLD, so MPI errors there raise exceptions instead of using
// MPI's default fatal handler.
void InitializeMPIWrap() {
  comms.assignHandle(MPI_COMM_WORLD);
  MPI_Comm_create_errhandler(MPIErrHandler,&errhdl);
  MPI_Comm_set_errhandler(MPI_COMM_WORLD, errhdl);
}

// Helper function: returns the integer scalar stored in A. A is copied
// first -- presumably because getContentsAsIntegerScalar() is non-const.
int ArrayToInt(const Array &A) {
  Array tmp(A);
  return tmp.getContentsAsIntegerScalar();
}

/*
 * Send an array via MPI:
 * Arguments: MPISend(A, dest, tag, communicator)
 * Throws an exception if there is an error.
 *
 * NOTE(review): the design sketches in this and the following comments are
 * historical. MPISend/MPIRecv in this file send a packed-length message
 * followed by the packed buffer, both under the caller's tag.
 *
 * How does this work? Well, tag is a positive integer - we
 * use it as upper bits in the message id, reserving lower
 * bits for the submessage.
 *
 * Suppose A is a homogeneous array -
 * Part 1 - type
 * Part 2 - dimension data
 * Part 3 - array contents
 *
 * Suppose A is a structure array -
 * Part 1 - type
 * Part 2 - dimension data
 * Part 3 - field names
 * Part 4 - the component arrays
 */

/*
 * A simpler solution is to pack the data into a buffer, send the buffer
 * using the raw protocol, and unpack it on the other side. The drawback is
 * that the data gets duplicated in the buffer, which is acceptable. The real
 * difficulty is determining how much buffer space is needed, which is
 * non-trivial.
 */

/*
 * We could do something like this:
 *  Is the array homogeneous?
 */

// An array transmission is composed of how many array transmissions:
// 1. How many arrays are contained in A? --> ArrayCount
// 2. Send msgID + ArrayCount
// 3. For each Array

// Simplest method:
// 1. Pack the array into a buffer (and determine the resulting length)
// 2. Send msgID*2 as the length of the buffer
// 3. Send msgID*2+1 as the contents of the buffer

// Returns the number of bytes MPI_Pack_size reports for packing `count`
// items of type `atype` on communicator `comm` (an upper bound per MPI).
int getCanonicalSize(int count, MPI_Datatype atype, MPI_Comm comm) {
  int size;
  MPI_Pack_size(count,atype,comm,&size);
  return size;
}

// OK, for now, I'm going to use this method.
// Estimates the number of bytes needed to pack array `a` on `comm`. The
// visible part starts from a fixed overhead of maxDims+1 MPI_INTs
// (presumably the class/dimension header). For cell arrays it walks the
// contained Array elements.
int getArrayByteFootPrint(Array &a, MPI_Comm comm) {
  unsigned int overhead;
  // How many bytes in the overhead
  overhead = getCanonicalSize(maxDims+1,MPI_INT, comm);
  Class dataClass(a.getDataClass());
  // Is input array a reference type?
  if (a.isReferenceType()) {
    if (dataClass == FM_CELL_ARRAY) {
      int total = 0;
      Array *dp;
      dp = (Array *)a.getDataPointer();
      // NOTE(review): this copy of the file is corrupted here. Everything
      // from the loop bound through the beginning of the MPISEND
      // documentation is missing, likely including the rest of this
      // function and packArrayMPI/unpackArrayMPI. Restore from the
      // original source.
      for (int i=0;i
//!
/*
 * mpisend(A, dest, tag [, comm])
 *
 * Serializes A with packArrayMPI into a single MPI_PACKED buffer and sends it
 * to rank `dest` as two messages that share the same `tag`:
 *  1. the packed length (one MPI_INT);
 *  2. the packed bytes.
 * MPIRecv consumes exactly this pair. Communicator handle 1 (documented as
 * MPI_COMM_WORLD) is used when no handle is given.
 *
 * Throws Exception for a wrong argument count. MPI failures are also raised
 * as Exception on communicators that carry MPIErrHandler.
 */
ArrayVector MPISend(int nargout, const ArrayVector& args) {
  if ((args.size() < 3) || (args.size() > 4))
    throw Exception("Expect 3 or 4 arguments to MPISend: array to send, destination rank, message tag and (optionally communicator handle -- defaults to MPI_COMM_WORLD)");
  Array A(args[0]);
  Array tmp(args[1]);
  int dest(tmp.getContentsAsIntegerScalar());
  tmp = args[2];
  int tag(tmp.getContentsAsIntegerScalar());
  int comhandle = 1;
  if (args.size() > 3) {
    tmp = args[3];
    comhandle = tmp.getContentsAsIntegerScalar();
  }
  MPI_Comm comm(comms.lookupHandle(comhandle));

  // Owns the malloc'd pack buffer and frees it on every exit path.
  // packArrayMPI and MPI_Send may throw (MPIErrHandler turns MPI errors into
  // exceptions). The old explicit free() was skipped in that case, which
  // leaked the buffer.
  struct PackBuffer {
    void *ptr;
    explicit PackBuffer(int nbytes) : ptr(malloc(nbytes > 0 ? nbytes : 1)) {}
    ~PackBuffer() { free(ptr); }
  };

  // Upper bound on the packed size of A.
  int bufsize = getArrayByteFootPrint(A,comm);
  PackBuffer buf(bufsize);
  int packpos = 0;
  packArrayMPI(A,buf.ptr,bufsize,&packpos,comm);
  // packpos now holds the number of bytes actually written by the packer.
  MPI_Send(&packpos,1,MPI_INT,dest,tag,comm);
  MPI_Send(buf.ptr,packpos,MPI_PACKED,dest,tag,comm);
  return ArrayVector();
}

//!
//@Module MPIBARRIER MPI Barrier
//@@Section MPI
//@@Usage
//This function is used as a synchronization point for all
//processes in a group. All processes are blocked until
//every process calls @|mpibarrier|. The general syntax
//for its use is
//@[
//   mpibarrier(comm)
//@]
//where @|comm| is the communicator. If no communicator is
//provided, it defaults to @|MPI_COMM_WORLD|.
//!
ArrayVector MPIBarrier(int nargout, const ArrayVector& args) {
  // Handle 1 is the default communicator (MPI_COMM_WORLD).
  int comhandle = 1;
  if (args.size() > 0) {
    Array tmp(args[0]);
    comhandle = tmp.getContentsAsIntegerScalar();
  }
  MPI_Comm comm(comms.lookupHandle(comhandle));
  MPI_Barrier(comm);
  return ArrayVector();
}

/*
 * Broadcast an array via MPI:
 * Arguments: A = MPIBcast(A,root,communicator)
 */
//!
//@Module MPIBCAST MPI Broadcast
//@@Section MPI
//@@Usage
//This function is used to broadcast an array to all group
//members.
The syntax for its use //@[ // B = mpibcast(A,root,comm) //@] //where @|A| is the array to broadcast, @|root| is the rank //of the root of the broadcast, and @|comm| is the communicator //to do the broadcast on. If no communicator is provided, it //defaults to @|MPI_COMM_WORLD|. Note that in practice, the process //running at the root will use the syntax: //@[ // mpibcast(A,root,comm), //@] //while the remaining processes will use the syntax //@[ // B = mpibcast([],root,comm). //@] //! ArrayVector MPIBcast(int nargout, const ArrayVector& args) { if ((args.size() < 2) || (args.size() > 3)) throw Exception("expect 3 arguments to MPIBcast: array, root rank and (optionally) communicator handle"); Array A(args[0]); Array tmp(args[1]); int root(tmp.getContentsAsIntegerScalar()); int comhandle; if (args.size() > 2) { tmp = args[2]; comhandle = tmp.getContentsAsIntegerScalar(); } else { comhandle = 1; } MPI_Comm comm(comms.lookupHandle(comhandle)); // Get our rank int ourrank; MPI_Comm_rank(comm,&ourrank); ArrayVector retval; // Are we the originator of this broadcast? if (ourrank == root) { // Marshall the array into a message int Asize = getArrayByteFootPrint(A,comm); int bufsize = Asize; void *cp = malloc(Asize); int packpos = 0; packArrayMPI(A,cp,bufsize,&packpos,comm); // First broadcast the size MPI_Bcast(&packpos,1,MPI_INT,root,comm); // Then broadcast the data MPI_Bcast(cp,packpos,MPI_PACKED,root,comm); // Clean up free(cp); // Return it to the sender... retval.push_back(A); } else { // We are note the originator - wait for the size to // appear int msgsize; MPI_Bcast(&msgsize,1,MPI_INT,root,comm); void *cp = malloc(msgsize); MPI_Bcast(cp,msgsize,MPI_PACKED,root,comm); int packpos = 0; Array A2(unpackArrayMPI(cp,msgsize,&packpos,comm)); free(cp); retval.push_back(A2); } return retval; } /* * Recv an array via MPI: * Arguments: A = MPIRecv(source, tag, communicator) * Throws an exception if there is an error. */ //! 
//@Module MPIRECV MPI Receive Array //@@Section MPI //@@Usage //This function receives an array from a source node on //a given communicator with the specified tag. The //general syntax for its use is //@[ // y = mpirecv(rank,tag,comm) //@] //where @|rank| is the rank of the node sending the message, //@|tag| is the message tag and @|comm| is the communicator //to use. If no communicator is provided, then @|MPI_COMM_WORLD| //is used. //@@Example //The @|mpirecv| command is fairly straightforward to use. //Its power is in the ability to receive arrays of arbitrary //complexity, including cell arrays, structures, strings, etc. //Here is an example of an @|mpisend| and @|mpirecv| being used //on the same node to pass a structure through MPI. //@< //mpiinit //x.color = 'blue'; //x.pi = 3; //x.cells = {'2',2}; //mpisend(x,0,32); //y = mpirecv(0,32) //@> //! ArrayVector MPIRecv(int nargout, const ArrayVector& args) { if ((args.size() < 2) || (args.size() > 3)) throw Exception("Expect 3 arguments to MPIRecv: source rank, message tag and (optionally communicator handle -- defaults to MPI_COMM_WORLD)"); Array tmp(args[0]); int source(tmp.getContentsAsIntegerScalar()); tmp = args[1]; int tag(tmp.getContentsAsIntegerScalar()); int comhandle; if (args.size() > 2) { tmp = args[2]; comhandle = tmp.getContentsAsIntegerScalar(); } else { comhandle = 1; } MPI_Comm comm(comms.lookupHandle(comhandle)); int msgsize; MPI_Status status; MPI_Recv(&msgsize,1,MPI_INT,source,tag,comm,&status); void *cp = malloc(msgsize); MPI_Recv(cp,msgsize,MPI_PACKED,status.MPI_SOURCE,status.MPI_TAG,comm,MPI_STATUS_IGNORE); int packpos = 0; Array A(unpackArrayMPI(cp,msgsize,&packpos,comm)); free(cp); ArrayVector retval; retval.push_back(A); retval.push_back(Array::int32Constructor(status.MPI_SOURCE)); retval.push_back(Array::int32Constructor(status.MPI_TAG)); return retval; } //! 
//@Module MPICOMMRANK MPI Communicator Rank //@@Section MPI //@@Usage //This function returns the rank of a process within //the specified communicator. The general syntax for //its use is //@[ // y = mpicommrank(comm) //@] //where @|comm| is the communicator to use. If no communicator //is provided, then @|MPI_COMM_WORLD| is used. The returned value //@|y| is the rank of the current process in the communicator. //@@Example //Here is a simple example of using @|mpicommrank| to obtain the //process rank. It defaults to 0, because the process is the root //of the group (which contains only itself). //@< //mpiinit //mpicommrank //@> //! ArrayVector MPICommRank(int nargout, const ArrayVector& args) { int comhandle; if (args.size() == 0) { comhandle = 1; } else { Array tmp(args[0]); comhandle = tmp.getContentsAsIntegerScalar(); } MPI_Comm comm(comms.lookupHandle(comhandle)); int rank; MPI_Comm_rank(comm,&rank); ArrayVector retval; retval.push_back(Array::int32Constructor(rank)); return retval; } //! //@Module MPICOMMSIZE MPI Communicator Size //@@Section MPI //@@Usage //This function returns the size of the group using the //given communicator. The general syntax for its use is //@[ // y = mpicommsize(comm) //@] //where @|comm| is the communicator to use. If no communicator //is provided, then @|MPI_COMM_WORLD| is assumed. //@@Example //Here is a simple example of using @|mpicommsize|: //@< //mpiinit //mpicommrank //@> //! ArrayVector MPICommSize(int nargout, const ArrayVector& args) { int comhandle; if (args.size() == 0) { comhandle = 1; } else { Array tmp(args[0]); comhandle = tmp.getContentsAsIntegerScalar(); } MPI_Comm comm(comms.lookupHandle(comhandle)); int size; MPI_Comm_size(comm,&size); ArrayVector retval; retval.push_back(Array::int32Constructor(size)); return retval; } /* * syntax: x = mpiallreduce(y,operation,root,comm) */ //! 
//@Module MPIALLREDUCE MPI All Reduce Operation //@@Section MPI //@@Usage //This function implements the all-reduce operation using MPI. //The general syntax for its use is //@[ // x = mpiallreduce(y,operation,comm) //@] //! ArrayVector MPIAllReduce(int nargout, const ArrayVector& args) { int comhandle; if (args.size() < 3) comhandle = 1; else comhandle = ArrayToInt(args[2]); MPI_Comm comm(comms.lookupHandle(comhandle)); if (args.size() < 2) throw Exception("mpiallreduce requires an array, an operation"); char *op; Array oper(args[1]); op = oper.getContentsAsCString(); MPI_Op mpiop; switch (*op) { case '+': mpiop = MPI_SUM; break; case '*': mpiop = MPI_PROD; break; case '<': mpiop = MPI_MIN; break; case '>': mpiop = MPI_MAX; break; default: throw Exception(std::string("unrecognized mpiop type:") + op + ": supported types are '+','*','>' and '<'"); } Array source(args[0]); Array dest(source); Class dataClass(source.getDataClass()); switch (dataClass) { case FM_LOGICAL: MPI_Allreduce((void*)source.getDataPointer(),dest.getReadWriteDataPointer(),source.getLength(),MPI_UNSIGNED_CHAR,mpiop,comm); break; case FM_UINT8: MPI_Allreduce((void*)source.getDataPointer(),dest.getReadWriteDataPointer(),source.getLength(),MPI_UNSIGNED_CHAR,mpiop,comm); break; case FM_INT8: MPI_Allreduce((void*)source.getDataPointer(),dest.getReadWriteDataPointer(),source.getLength(),MPI_CHAR,mpiop,comm); break; case FM_UINT16: MPI_Allreduce((void*)source.getDataPointer(),dest.getReadWriteDataPointer(),source.getLength(),MPI_UNSIGNED_SHORT,mpiop,comm); break; case FM_INT16: MPI_Allreduce((void*)source.getDataPointer(),dest.getReadWriteDataPointer(),source.getLength(),MPI_SHORT,mpiop,comm); break; case FM_UINT32: MPI_Allreduce((void*)source.getDataPointer(),dest.getReadWriteDataPointer(),source.getLength(),MPI_UNSIGNED,mpiop,comm); break; case FM_INT32: MPI_Allreduce((void*)source.getDataPointer(),dest.getReadWriteDataPointer(),source.getLength(),MPI_INT,mpiop,comm); break; case FM_FLOAT: 
MPI_Allreduce((void*)source.getDataPointer(),dest.getReadWriteDataPointer(),source.getLength(),MPI_FLOAT,mpiop,comm); break; case FM_DOUBLE: MPI_Allreduce((void*)source.getDataPointer(),dest.getReadWriteDataPointer(),source.getLength(),MPI_DOUBLE,mpiop,comm); break; case FM_COMPLEX: MPI_Allreduce((void*)source.getDataPointer(),dest.getReadWriteDataPointer(),2*source.getLength(),MPI_FLOAT,mpiop,comm); break; case FM_DCOMPLEX: MPI_Allreduce((void*)source.getDataPointer(),dest.getReadWriteDataPointer(),2*source.getLength(),MPI_DOUBLE,mpiop,comm); break; case FM_INT64: case FM_UINT64: throw Exception("MPI support for 64 bit values is still needed!"); default: throw Exception("unsupported array type in argument to allreduce - must be a numerical type"); } ArrayVector retval; retval.push_back(dest); return retval; } /* * syntax: x = mpireduce(y,operation,root,comm) */ //! //@Module MPIREDUCE MPI Reduce Operation //@@Section MPI //@@Usage //This function implements the reduction operation using MPI. //The general syntax for its use is //@[ // x = mpireduce(y,operation,root,comm) //@] //where @|y| is the current processes contribution to the //reduction operation, @|operation| is either @|'+','*','>','<'| for //an additive, multiplicative, max or min type reduction operations //respectively, //@|root| is the rank of the process that will retrieve the //result of the reduction operation, and @|comm| is the MPI //communicator handle. If no communicator is provided, //then @|MPI_COMM_WORLD| is used by default. Note that FreeMat does not //check to ensure that the reduction operation @|y| arguments //are all the same size across the various processes in the group. //Instead, you must make sure that each process passes the same //sized array to the @|mpireduce| operation. //! 
/*
 * x = mpireduce(y, operation, root [, comm])
 *
 * Wrapper over MPI_Reduce. Each process contributes y and the reduced array
 * is delivered to rank `root`. The class of y selects the MPI datatype.
 * Complex arrays are reduced as interleaved (real, imag) scalars, which is
 * only meaningful for '+'.
 */
ArrayVector MPIReduce(int nargout, const ArrayVector& args) {
  int comhandle;
  if (args.size() < 4)
    comhandle = 1;
  else
    comhandle = ArrayToInt(args[3]);
  MPI_Comm comm(comms.lookupHandle(comhandle));
  if (args.size() < 3)
    throw Exception("mpireduce requires an array, an operation, and a root rank");
  int root = ArrayToInt(args[2]);
  char *op;
  Array oper(args[1]);
  op = oper.getContentsAsCString();
  MPI_Op mpiop;
  switch (*op) {
  case '+': mpiop = MPI_SUM; break;
  case '*': mpiop = MPI_PROD; break;
  case '<': mpiop = MPI_MIN; break;
  case '>': mpiop = MPI_MAX; break;
  default:
    throw Exception(std::string("unrecognized mpiop type:") + op + ": supported types are '+','*','>' and '<'");
  }
  Array source(args[0]);
  Array dest(source);
  Class dataClass(source.getDataClass());
  // Map the array class onto the MPI datatype of one stored scalar. Complex
  // arrays are sent as twice as many real scalars.
  MPI_Datatype dtype;
  int count = source.getLength();
  bool isComplex = false;
  switch (dataClass) {
  case FM_LOGICAL:
  case FM_UINT8:
    dtype = MPI_UNSIGNED_CHAR;
    break;
  case FM_INT8:
    // MPI_CHAR is not a valid type for reduction operations; use the
    // arithmetic MPI_SIGNED_CHAR instead.
    dtype = MPI_SIGNED_CHAR;
    break;
  case FM_UINT16:
    dtype = MPI_UNSIGNED_SHORT;
    break;
  case FM_INT16:
    dtype = MPI_SHORT;
    break;
  case FM_UINT32:
    dtype = MPI_UNSIGNED;
    break;
  case FM_INT32:
    dtype = MPI_INT;
    break;
  case FM_FLOAT:
    dtype = MPI_FLOAT;
    break;
  case FM_DOUBLE:
    dtype = MPI_DOUBLE;
    break;
  case FM_COMPLEX:
    dtype = MPI_FLOAT;
    count *= 2;
    isComplex = true;
    break;
  case FM_DCOMPLEX:
    dtype = MPI_DOUBLE;
    count *= 2;
    isComplex = true;
    break;
  case FM_INT64:
  case FM_UINT64:
    throw Exception("MPI support for 64 bit values is still needed!");
  default:
    throw Exception("unsupported array type in argument to reduce - must be a numerical type");
  }
  // Component-wise product or min/max of (real, imag) pairs is not the
  // complex result, so only sums are accepted for complex data.
  if (isComplex && (*op != '+'))
    throw Exception("mpireduce only supports the '+' operation for complex arrays");
  // Only the root's receive buffer is significant to MPI_Reduce. On other
  // ranks, dest presumably keeps its copy of y.
  MPI_Reduce((void*)source.getDataPointer(),dest.getReadWriteDataPointer(),count,dtype,mpiop,root,comm);
  ArrayVector retval;
  retval.push_back(dest);
  return retval;
}

//!
//@Module MPIINITIALIZED MPI Initialized Test
//@@Section MPI
//@@Usage
//This function tests to see if MPI is already initialized.
//The general syntax for its use is
//@[
//   x = mpiinitialized
//@]
//It returns a logical 1 if @|mpiinit| has already been called,
//and a logical 0 otherwise.
//@@Example
//Here we call @|mpiinitialized| before and after a call to
//@|mpiinit|.
//@<
//mpiinitialized
//mpiinit
//mpiinitialized
//@>
//!
ArrayVector MPIInitialized(int nargout, const ArrayVector& args) {
  // MPI_Initialized may be called before MPI_Init.
  int flag;
  MPI_Initialized(&flag);
  ArrayVector retval;
  retval.push_back(Array::logicalConstructor(flag));
  return retval;
}

//!
//@Module MPIINIT MPI Initialize
//@@Section MPI
//@@Usage
//This function initializes the MPI subsystem and joins
//the current FreeMat process to the MPI environment.
//The general syntax for its use is
//@[
//   mpiinit
//@]
//Note that @|mpiinit| must be called before any other
//MPI routines (with the exception of @|mpiinitialized|), or
//an MPI error will occur.
//!
ArrayVector MPIInit(int nargout, const ArrayVector& args) {
  // Calling MPI_Init twice is an error, so a repeated mpiinit is a no-op.
  int flag;
  MPI_Initialized(&flag);
  if (flag)
    return ArrayVector();
  MPI_Init(NULL,NULL);
  // Register MPI_COMM_WORLD as the default handle and install the
  // exception-raising error handler.
  InitializeMPIWrap();
  return ArrayVector();
}

//!
//@Module MPICOMMSPAWN MPI Communicator Spawn
//@@Section MPI
//@@Usage
//This function uses MPI to spawn one or more new processes.
//The full power of the underlying routine, @|MPI_Comm_spawn|, is not
//yet available via the @|mpicommspawn| routine in FreeMat. The
//general syntax for its use is
//@[
//   errcodes = mpicommspawn(command,argv,maxprocs,root,comm)
//@]
//where @|command| is the command to execute, @|argv| is a cell-array
//of strings to pass as arguments to @|command|, @|maxprocs| is the
//number of processes to spawn, @|root| is the node that will actually
//do the process spawn, and @|comm| is the communicator to use.
//If no communicator is specified, @|comm| defaults to @|MPI_COMM_SELF|.
//If @|root| is not specified, it defaults to 0. If @|maxprocs| is
//not specified, it defaults to 1. If @|argv| is not specified no
//arguments are passed to the spawned processes.
//!
/*
 * Argument marshalling for MPI_Comm_spawn. The command name is converted to
 * a C string. argv is built as a NULL-terminated char* array from either a
 * single string or a cell array of strings; a missing or empty argv gives
 * NULL.
 *
 * NOTE(review): this copy of the file is corrupted inside the cell-array
 * loop below. The rest of MPICommSpawn is missing, along with the functions
 * that followed it (apparently including MPIFinalize, MPICommGetParent and
 * MPIIntercommMerge, which are registered below but not defined in the
 * visible text) and the start of the registration function. Restore from
 * the original source before building.
 */
ArrayVector MPICommSpawn(int nargout, const ArrayVector& args) {
  char *command;
  char **argv;
  int maxprocs;
  int root;
  MPI_Info info;
  MPI_Info_create(&info);
  MPI_Comm intercomm;
  MPI_Comm comm;
  if (args.size() == 0)
    throw Exception("mpicommspawn requires at least one argument (name of the command to spawn");
  if (args.size() > 5)
    throw Exception("too many arguments to mpicommspawn.");
  Array t1(args[0]);
  command = t1.getContentsAsCString();
  if (args.size() < 2)
    argv = NULL;
  else {
    Array t2(args[1]);
    if (t2.isEmpty())
      argv = NULL;
    else {
      if (t2.isString()) {
        // One argument plus the terminating NULL entry.
        argv = (char**) malloc(sizeof(char*)*2);
        argv[1] = NULL;
        argv[0] = t2.getContentsAsCString();
      } else if (t2.getDataClass() == FM_CELL_ARRAY) {
        Array *dp;
        dp = (Array*) t2.getDataPointer();
        int len;
        len = t2.getLength();
        // len arguments plus the terminating NULL entry.
        argv = (char**) malloc(sizeof(char*)*(len+1));
        argv[len] = 0;
        // NOTE(review): corrupted text begins on the next line. The loop
        // body, the end of MPICommSpawn and the header of the builtin
        // registration function are missing.
        for (int m=0;maddFunction("mpisend",MPISend,4,0,"array","dest","tag","communicator");
  // Builtin registration. The two integers appear to be the maximum number
  // of inputs and outputs -- confirm against Context::addFunction.
  context->addFunction("mpirecv",MPIRecv,3,3,"source","tag","communicator");
  context->addFunction("mpibcast",MPIBcast,3,1,"array","root","communicator");
  context->addFunction("mpibarrier",MPIBarrier,1,0,"communicator");
  context->addFunction("mpicommrank",MPICommRank,1,1,"communicator");
  context->addFunction("mpicommsize",MPICommSize,1,1,"communicator");
  context->addFunction("mpireduce",MPIReduce,4,1,"y","operation","root","comm");
  // NOTE(review): MPIAllReduce treats its third argument (args[2]) as a
  // communicator handle, so the argument name "root" below looks wrong
  // (probably should be "comm").
  context->addFunction("mpiallreduce",MPIAllReduce,3,1,"y","operation","root");
  context->addFunction("mpiinitialized",MPIInitialized,0,1);
  context->addFunction("mpiinit",MPIInit,0,0);
  context->addFunction("mpifinalize",MPIFinalize,0,0);
  context->addFunction("mpicommgetparent",MPICommGetParent,0,1);
  context->addFunction("mpicommspawn",MPICommSpawn,5,2, "command","args","maxprocs","root","comm");
  context->addFunction("mpiintercommmerge",MPIIntercommMerge,2,1,"intercomm","highflag");
}

#endif