/*
 * Copyright (c) 2002-2006 Samit Basu
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 *
 */

//FIXME
//#include "config.h"
#ifdef USE_MPI

#include <mpi.h>
#include <string>
#include <vector>
#include "HandleList.hpp"
#include "MPIWrap.hpp"
#include "Malloc.hpp"
#include "FunctionDef.hpp"
#include "Context.hpp"

// Table mapping the integer communicator handles seen by FreeMat code to
// MPI communicators.  InitializeMPIWrap registers MPI_COMM_WORLD here, and
// the wrappers below use handle 1 as their default communicator.
// NOTE(review): assumes HandleList numbers handles starting at 1 -- confirm
// against HandleList.hpp.
HandleList<MPI_Comm> comms;
// Error handler created by InitializeMPIWrap from MPIErrHandler and
// attached to MPI_COMM_WORLD.
MPI_Errhandler errhdl;

// MPI communicator error handler: converts an MPI error code into a FreeMat
// Exception carrying MPI's textual description of the error.
//   comm      - communicator on which the error occurred (unused)
//   errorcode - MPI error code to describe
// Always throws; never returns normally.
void MPIErrHandler(MPI_Comm *comm, int *errorcode, ...) {
  // MPI_Error_string writes at most MPI_MAX_ERROR_STRING characters.
  char buffer[MPI_MAX_ERROR_STRING+1];
  // Initialised so that a failing MPI_Error_string cannot leave us indexing
  // the buffer with an indeterminate value.
  int resultlen = 0;
  if (MPI_Error_string(*errorcode,buffer,&resultlen) != MPI_SUCCESS)
    throw Exception("MPI error (no error string available)");
  // Defensive clamp so the terminator always lands inside the buffer.
  if (resultlen < 0)
    resultlen = 0;
  if (resultlen > MPI_MAX_ERROR_STRING)
    resultlen = MPI_MAX_ERROR_STRING;
  buffer[resultlen] = 0;
  throw Exception(buffer);
}

void InitializeMPIWrap() {
  comms.assignHandle(MPI_COMM_WORLD);
  MPI_Comm_create_errhandler(MPIErrHandler,&errhdl);
  MPI_Comm_set_errhandler(MPI_COMM_WORLD, errhdl);
}

// Helper function
int ArrayToInt(const Array &A) {
  Array tmp(A);
  return tmp.getContentsAsIntegerScalar();
}
  
/* 
 * Send an array via MPI:
 *  Arguments: MPISend(A, dest, tag, communicator)
 *  Throws an exception if there is an error.
 *
 *  How does this work?  Well, tag is a positive integer - we
 *  use it as upper bits in the message id, reserving lower
 *  bits for the submessage.
 *
 *  Suppose A is an array (homogeneous) -
 *     Part 1 - type
 *     Part 2 - dimension data
 *     Part 3 - array contents
 *
 *  Suppose A is an array (structure)
 *     Part 1 - type
 *     Part 2 - dimension data
 *     Part 3 - fields (if it's a structure)
 *     Part 4 - the component arrays
 */


/*
 * A simpler solution is to pack and unpack the data into a buffer and then
 * send the buffer using the raw protocol.  The problem with this is that 
 * buffers get duplicated...  But that's ok -  The problem is to determine
 * how much buffer space is needed - that too is non-trivial - consider...
 */

/*
 * We could do something like this:
 *  Is the array homogeneous?
 *
 *
 */

// An array transmission is composed of how many array transmissions:
//   1.  How many arrays are contained in A? --> ArrayCount
//   2.  Send msgID + ArrayCount
//   3.  For each Array

// Simplest method:
//   1.  Pack the array into a buffer (and determine the resulting length)
//   2.  Send msgID*2 as the length of the buffer
//   3.  Send msgID*2+1 as the contents of the buffer

// Return the number of bytes MPI_Pack_size reports as sufficient for packing
// `count` items of datatype `atype` on communicator `comm`.
int getCanonicalSize(int count, MPI_Datatype atype, MPI_Comm comm) {
  int packedBytes;
  MPI_Pack_size(count, atype, comm, &packedBytes);
  return packedBytes;
}

// Compute an upper bound, in bytes, on the buffer space packArrayMPI needs to
// serialise array `a` for communicator `comm`.  Cell and structure arrays are
// sized recursively over their component arrays.
// Throws an Exception for 64-bit integer arrays and for any array class not
// handled below (previously control fell off the end of the function).
int getArrayByteFootPrint(Array &a, MPI_Comm comm) {
  // Header written by packArrayMPI for every array: one int for the class
  // code, one int for the dimension count, then up to maxDims dimension
  // lengths -- i.e. maxDims+2 ints in the worst case.
  int overhead = getCanonicalSize(maxDims+2,MPI_INT, comm);
  Class dataClass(a.getDataClass());
  // Is input array a reference type?
  if (a.isReferenceType()) {
    if (dataClass == FM_CELL_ARRAY) {
      // Header plus every contained array.
      int total = 0;
      Array *dp = (Array *) a.getDataPointer();
      for (int i=0;i<a.getLength();i++)
        total += getArrayByteFootPrint(dp[i],comm);
      return (total+overhead);
    }
    // Structure array: a field count, then each field name encoded as a
    // length followed by its characters, then every field of every element.
    stringVector fieldnames(a.getFieldNames());
    int fieldcount = fieldnames.size();
    int fieldsize = getCanonicalSize(1,MPI_INT,comm);
    for (int j=0;j<fieldcount;j++)
      fieldsize += getCanonicalSize(1,MPI_INT,comm) +
        getCanonicalSize(fieldnames[j].size(),MPI_CHAR,comm);
    int total = 0;
    Array *dp = (Array *) a.getDataPointer();
    for (int i=0;i<a.getLength()*fieldcount;i++)
      total += getArrayByteFootPrint(dp[i],comm);
    return (total+overhead+fieldsize+1);
  }
  switch(dataClass) {
  case FM_LOGICAL:
    return(overhead+getCanonicalSize(a.getLength(),MPI_CHAR,comm));
  case FM_UINT8:
    return(overhead+getCanonicalSize(a.getLength(),MPI_UNSIGNED_CHAR,comm));
  case FM_INT8:
    return(overhead+getCanonicalSize(a.getLength(),MPI_CHAR,comm));
  case FM_UINT16:
    return(overhead+getCanonicalSize(a.getLength(),MPI_UNSIGNED_SHORT,comm));
  case FM_INT16:
    return(overhead+getCanonicalSize(a.getLength(),MPI_SHORT,comm));
  case FM_UINT32:
    return(overhead+getCanonicalSize(a.getLength(),MPI_UNSIGNED,comm));
  case FM_INT32:
    return(overhead+getCanonicalSize(a.getLength(),MPI_INT,comm));
  case FM_FLOAT:
    return(overhead+getCanonicalSize(a.getLength(),MPI_FLOAT,comm));
  case FM_DOUBLE:
    return(overhead+getCanonicalSize(a.getLength(),MPI_DOUBLE,comm));
  case FM_COMPLEX:
    // Two reals per complex element.
    return(overhead+getCanonicalSize(a.getLength()*2,MPI_FLOAT,comm));
  case FM_DCOMPLEX:
    return(overhead+getCanonicalSize(a.getLength()*2,MPI_DOUBLE,comm));
  case FM_STRING:
    return(overhead+getCanonicalSize(a.getLength(),MPI_CHAR,comm));
  case FM_INT64:
  case FM_UINT64:
    throw Exception("MPI support for 64 bit values is still needed!");
  default:
    break;
  }
  throw Exception("unsupported array type for MPI transfer");
}

// Map a wire class code (as produced by encodeDataClassAsInteger) back to
// the corresponding array Class.
// Throws an Exception for an unrecognised code; previously control fell off
// the end of the function, which is undefined behaviour.
Class decodeDataClassFromInteger(int code) {
  switch(code) {
  case 1024:
    return FM_CELL_ARRAY;
  case 1025:
    return FM_STRUCT_ARRAY;
  case 1026:
    return FM_LOGICAL;
  case 1027:
    return FM_UINT8;
  case 1028:
    return FM_INT8;
  case 1029:
    return FM_UINT16;
  case 1030:
    return FM_INT16;
  case 1031:
    return FM_UINT32;
  case 1032:
    return FM_INT32;
  case 1033:
    return FM_FLOAT;
  case 1034:
    return FM_DOUBLE;
  case 1035:
    return FM_COMPLEX;
  case 1036:
    return FM_DCOMPLEX;
  case 1037:
    return FM_STRING;
  }
  throw Exception("unrecognized array type code in MPI message");
}

// Map an array Class to the integer class code sent over the wire (the
// inverse of decodeDataClassFromInteger).
// Throws an Exception for 64-bit integer classes and for any class without a
// code; previously unlisted classes fell off the end of the function (UB).
int encodeDataClassAsInteger(Class dataClass) {
  switch (dataClass) {
  case FM_CELL_ARRAY:
    return 1024;
  case FM_STRUCT_ARRAY:
    return 1025;
  case FM_LOGICAL:
    return 1026;
  case FM_UINT8:
    return 1027;
  case FM_INT8:
    return 1028;
  case FM_UINT16:
    return 1029;
  case FM_INT16:
    return 1030;
  case FM_UINT32:
    return 1031;
  case FM_INT32:
    return 1032;
  case FM_FLOAT:
    return 1033;
  case FM_DOUBLE:
    return 1034;
  case FM_COMPLEX:
    return 1035;
  case FM_DCOMPLEX:
    return 1036;
  case FM_STRING:
    return 1037;
  case FM_INT64:
  case FM_UINT64:
    throw Exception("MPI support for 64 bit values is still needed!");
  default:
    break;
  }
  throw Exception("unsupported array type for MPI transfer");
}
  
// Serialise array `a` into `buffer` (of `bufsize` bytes) using MPI_Pack,
// starting at byte offset *packpos and advancing *packpos past the packed
// data.  The buffer must be at least getArrayByteFootPrint(a) bytes.
// Layout: class code, dimension count, dimension lengths; then the component
// arrays (cell), the field count/names followed by the component arrays
// (struct), or the raw element data (everything else).
// Throws an Exception for 64-bit integer arrays and unsupported classes.
void packArrayMPI(Array &a, void *buffer, int bufsize, int *packpos, MPI_Comm comm) {
  Class dataClass(a.getDataClass());
  int idclass = encodeDataClassAsInteger(dataClass);
  MPI_Pack(&idclass,1,MPI_INT,buffer,bufsize,packpos,comm);
  int dimlength = a.getDimensions().getLength();
  MPI_Pack(&dimlength,1,MPI_INT,buffer,bufsize,packpos,comm);
  for (int j=0;j<dimlength;j++) {
    int tmp = a.getDimensionLength(j);
    MPI_Pack(&tmp,1,MPI_INT,buffer,bufsize,packpos,comm);
  }
  if (a.isReferenceType()) {
    if (dataClass == FM_CELL_ARRAY) {
      Array *dp = (Array *) a.getDataPointer();
      for (int i=0;i<a.getLength();i++)
        packArrayMPI(dp[i],buffer,bufsize,packpos,comm);
    } else {
      stringVector fieldnames(a.getFieldNames());
      int fieldcnt(fieldnames.size());
      MPI_Pack(&fieldcnt,1,MPI_INT,buffer,bufsize,packpos,comm);
      for (int i=0;i<fieldcnt;i++) {
        int flen = fieldnames[i].size();
        MPI_Pack(&flen,1,MPI_INT,buffer,bufsize,packpos,comm);
        MPI_Pack((void*) fieldnames[i].c_str(),flen,MPI_CHAR,buffer,bufsize,packpos,comm);
      }
      Array *dp = (Array *) a.getDataPointer();
      for (int i=0;i<a.getLength()*fieldcnt;i++)
        packArrayMPI(dp[i],buffer,bufsize,packpos,comm);
    }
    return;
  }
  // Plain numeric/string data: pick the MPI datatype and scalar count, then
  // pack everything with a single call.
  MPI_Datatype atype;
  int count = a.getLength();
  switch(dataClass) {
  case FM_LOGICAL:
  case FM_INT8:
  case FM_STRING:
    atype = MPI_CHAR;
    break;
  case FM_UINT8:
    atype = MPI_UNSIGNED_CHAR;
    break;
  case FM_UINT16:
    atype = MPI_UNSIGNED_SHORT;
    break;
  case FM_INT16:
    atype = MPI_SHORT;
    break;
  case FM_UINT32:
    atype = MPI_UNSIGNED;
    break;
  case FM_INT32:
    atype = MPI_INT;
    break;
  case FM_FLOAT:
    atype = MPI_FLOAT;
    break;
  case FM_DOUBLE:
    atype = MPI_DOUBLE;
    break;
  case FM_COMPLEX:
    // Complex arrays hold two reals per element.  Packing only getLength()
    // reals (as before) truncated the data and desynchronised
    // unpackArrayMPI, which reads 2*N reals.
    atype = MPI_FLOAT;
    count *= 2;
    break;
  case FM_DCOMPLEX:
    atype = MPI_DOUBLE;
    count *= 2;
    break;
  case FM_INT64:
  case FM_UINT64:
    throw Exception("MPI support for 64 bit values is still needed!");
  default:
    throw Exception("unsupported array type for MPI transfer");
  }
  MPI_Pack((void *) a.getDataPointer(),count,atype,buffer,bufsize,packpos,comm);
}

// Inverse of packArrayMPI: read one array from `buffer` (of `bufsize` bytes)
// starting at byte offset *packpos, advancing *packpos past the consumed
// data, and return the reconstructed Array.  Cell and structure arrays are
// rebuilt recursively.
// Throws an Exception for 64-bit integer arrays, unsupported classes, and a
// negative field-name length in the message.
Array unpackArrayMPI(void *buffer, int bufsize, int *packpos, MPI_Comm comm) {
  int idclass;
  MPI_Unpack(buffer,bufsize,packpos,&idclass,1,MPI_INT,comm);
  Class dataClass = decodeDataClassFromInteger(idclass);
  int dimlength;
  MPI_Unpack(buffer,bufsize,packpos,&dimlength,1,MPI_INT,comm);
  Dimensions outDim;
  for (int j=0;j<dimlength;j++) {
    int tmp;
    MPI_Unpack(buffer,bufsize,packpos,&tmp,1,MPI_INT,comm);
    outDim[j] = tmp;
  }
  if (dataClass == FM_CELL_ARRAY) {
    Array *dp = new Array[outDim.getElementCount()];
    for (int i=0;i<outDim.getElementCount();i++)
      dp[i] = unpackArrayMPI(buffer,bufsize,packpos,comm);
    return Array(FM_CELL_ARRAY,outDim,dp);
  } else if (dataClass == FM_STRUCT_ARRAY) {
    int fieldcnt;
    MPI_Unpack(buffer,bufsize,packpos,&fieldcnt,1,MPI_INT,comm);
    stringVector fieldnames;
    for (int j=0;j<fieldcnt;j++) {
      int fieldnamelength;
      MPI_Unpack(buffer,bufsize,packpos,&fieldnamelength,1,MPI_INT,comm);
      if (fieldnamelength < 0)
        throw Exception("corrupt MPI message: negative field name length");
      // The std::string owns the scratch space, so nothing leaks if
      // MPI_Unpack throws (the old malloc'd buffer did).
      std::string fieldname(fieldnamelength,'\0');
      if (fieldnamelength > 0)
        MPI_Unpack(buffer,bufsize,packpos,&fieldname[0],fieldnamelength,MPI_CHAR,comm);
      fieldnames.push_back(fieldname);
    }
    Array *dp = new Array[fieldcnt*outDim.getElementCount()];
    for (int i=0;i<fieldcnt*outDim.getElementCount();i++)
      dp[i] = unpackArrayMPI(buffer,bufsize,packpos,comm);
    return Array(FM_STRUCT_ARRAY,outDim,dp,false,fieldnames);
  }
  // Plain numeric/string data: choose the element size, MPI datatype and
  // scalar count, then allocate and unpack with a single call.
  int count = outDim.getElementCount();
  size_t elemSize;
  MPI_Datatype atype;
  switch(dataClass) {
  case FM_LOGICAL:
    elemSize = sizeof(logical);
    atype = MPI_CHAR;
    break;
  case FM_UINT8:
    elemSize = sizeof(uint8);
    atype = MPI_UNSIGNED_CHAR;
    break;
  case FM_INT8:
    elemSize = sizeof(int8);
    atype = MPI_CHAR;
    break;
  case FM_UINT16:
    elemSize = sizeof(uint16);
    atype = MPI_UNSIGNED_SHORT;
    break;
  case FM_INT16:
    elemSize = sizeof(int16);
    atype = MPI_SHORT;
    break;
  case FM_UINT32:
    elemSize = sizeof(uint32);
    atype = MPI_UNSIGNED;
    break;
  case FM_INT32:
    elemSize = sizeof(int32);
    atype = MPI_INT;
    break;
  case FM_FLOAT:
    elemSize = sizeof(float);
    atype = MPI_FLOAT;
    break;
  case FM_DOUBLE:
    elemSize = sizeof(double);
    atype = MPI_DOUBLE;
    break;
  case FM_COMPLEX:
    // Two reals per complex element.
    elemSize = sizeof(float);
    atype = MPI_FLOAT;
    count *= 2;
    break;
  case FM_DCOMPLEX:
    elemSize = sizeof(double);
    atype = MPI_DOUBLE;
    count *= 2;
    break;
  case FM_STRING:
    elemSize = sizeof(char);
    atype = MPI_CHAR;
    break;
  case FM_INT64:
  case FM_UINT64:
    throw Exception("MPI support for 64 bit values is still needed!");
  default:
    // Previously this path used an uninitialised data pointer.
    throw Exception("unsupported array type in MPI message");
  }
  void *cp = Malloc(elemSize*count);
  MPI_Unpack(buffer,bufsize,packpos,cp,count,atype,comm);
  return Array(dataClass,outDim,cp);
}

//!
//@Module MPISEND MPI Send Array
//@@Section MPI
//@@Usage
//This function sends an array to a destination node on a 
//given communicator with a specific message tag.  Note that
//there has to be a matching receive issued by the destination
//node.  The general syntax for its use is
//@[
//  mpisend(x,rank,tag,comm)
//@]
//where @|x| is the array to send, @|rank| is the rank of the
//node to receive the message, @|tag| is the message tag, and
//@|comm| is the handle of the communicator to use.  If no
//communicator is specified, then @|MPI_COMM_WORLD| is used.
//@@Function Internals
//The @|mpisend| command works by packing the array into a 
//linear buffer and then sending two messages.  The first
//message captures the size of the buffer, and the second
//contains the actual data.  The matching @|mpirecv| command
//reads the two messages, decodes the buffer, and returns
//the resulting array.
//@@Example
//The @|mpisend| command is fairly straightforward to use.
//Its power is in the ability to send arrays of arbitrary
//complexity, including cell arrays, structures, strings, etc.
//Here is an example of an @|mpisend| and @|mpirecv| being used
//on the same node to pass a structure through MPI.
//@<
//mpiinit
//x.color = 'blue';
//x.pi = 3;
//x.cells = {'2',2};
//mpisend(x,0,32);
//y = mpirecv(0,32)
//@>
//!
ArrayVector MPISend(int nargout, const ArrayVector& args) {
  // mpisend(x,rank,tag[,comm]): pack x and send it as two messages on the
  // same tag -- first the packed byte count, then the MPI_PACKED payload.
  // Returns no values.
  if ((args.size() < 3) || (args.size() > 4))
    throw Exception("Expect 4 arguments to MPISend: array to send, destination rank, message tag and (optionally communicator handle -- defaults to MPI_COMM_WORLD)");
  Array A(args[0]);
  Array tmp(args[1]);
  int dest(tmp.getContentsAsIntegerScalar());
  tmp = args[2];
  int tag(tmp.getContentsAsIntegerScalar());
  int comhandle;
  if (args.size() > 3) {
    tmp = args[3];
    comhandle = tmp.getContentsAsIntegerScalar();
  } else {
    comhandle = 1;
  }
  MPI_Comm comm(comms.lookupHandle(comhandle));
  // Upper bound on the packed size; packpos ends up holding the exact size.
  int bufsize = getArrayByteFootPrint(A,comm);
  // std::vector owns the staging buffer, so it is released even if packing
  // or sending throws (the previous malloc'd buffer leaked in that case).
  // At least one byte is allocated so &buffer[0] is always valid.
  std::vector<char> buffer(bufsize > 0 ? bufsize : 1);
  int packpos = 0;
  packArrayMPI(A,&buffer[0],bufsize,&packpos,comm);
  MPI_Send(&packpos,1,MPI_INT,dest,tag,comm);
  MPI_Send(&buffer[0],packpos,MPI_PACKED,dest,tag,comm);
  return ArrayVector();
}

//!
//@Module MPIBARRIER MPI Barrier
//@@Section MPI
//@@Usage
//This function is used as a synchronization point for all
//processes in a group.  All processes are blocked until
//every process calls @|mpibarrier|.  The general syntax 
//for its use is
//@[
//  mpibarrier(comm)
//@]
//where @|comm| is the communicator.  If no communicator is
//provided, it defaults to @|MPI_COMM_WORLD|.
//!
ArrayVector MPIBarrier(int nargout, const ArrayVector& args) {
  // Block until every process on the communicator reaches this point.
  // Handle 1 (MPI_COMM_WORLD) is used when no handle argument is given.
  int comhandle = 1;
  if (args.size() > 0) {
    Array handleArg(args[0]);
    comhandle = handleArg.getContentsAsIntegerScalar();
  }
  MPI_Barrier(comms.lookupHandle(comhandle));
  return ArrayVector();
}

/*
 * Broadcast an array via MPI:
 *  Arguments: A = MPIBcast(A,root,communicator)
 */
//!
//@Module MPIBCAST MPI Broadcast
//@@Section MPI
//@@Usage
//This function is used to broadcast an array to all group
//members.  The syntax for its use 
//@[
//  B = mpibcast(A,root,comm)
//@]
//where @|A| is the array to broadcast, @|root| is the rank
//of the root of the broadcast, and @|comm| is the communicator
//to do the broadcast on.  If no communicator is provided, it
//defaults to @|MPI_COMM_WORLD|.  Note that in practice, the process
//running at the root will use the syntax:
//@[
//  mpibcast(A,root,comm),
//@]
//while the remaining processes will use the syntax
//@[
//  B = mpibcast([],root,comm).
//@]
//!
ArrayVector MPIBcast(int nargout, const ArrayVector& args) {
  // B = mpibcast(A,root[,comm]): the root packs A and broadcasts its size,
  // then the packed bytes; every other rank receives and unpacks them.
  // All ranks return one value (the root returns its own A).
  if ((args.size() < 2) || (args.size() > 3))
    throw Exception("expect 3 arguments to MPIBcast: array, root rank and (optionally) communicator handle");
  Array A(args[0]);
  Array tmp(args[1]);
  int root(tmp.getContentsAsIntegerScalar());
  int comhandle;
  if (args.size() > 2) {
    tmp = args[2];
    comhandle = tmp.getContentsAsIntegerScalar();
  } else {
    comhandle = 1;
  }
  MPI_Comm comm(comms.lookupHandle(comhandle));
  // Get our rank
  int ourrank;
  MPI_Comm_rank(comm,&ourrank);
  ArrayVector retval;
  // Buffers below are std::vector-owned so they are released even if an
  // MPI call or (un)packing throws; at least one byte is allocated so
  // &buffer[0] is always valid.
  if (ourrank == root) {
    // Marshal the array into a message.
    int bufsize = getArrayByteFootPrint(A,comm);
    std::vector<char> buffer(bufsize > 0 ? bufsize : 1);
    int packpos = 0;
    packArrayMPI(A,&buffer[0],bufsize,&packpos,comm);
    // First broadcast the size, then the data.
    MPI_Bcast(&packpos,1,MPI_INT,root,comm);
    MPI_Bcast(&buffer[0],packpos,MPI_PACKED,root,comm);
    // Return it to the sender...
    retval.push_back(A);
  } else {
    // We are not the originator - wait for the size to appear.
    int msgsize;
    MPI_Bcast(&msgsize,1,MPI_INT,root,comm);
    std::vector<char> buffer(msgsize > 0 ? msgsize : 1);
    MPI_Bcast(&buffer[0],msgsize,MPI_PACKED,root,comm);
    int packpos = 0;
    Array A2(unpackArrayMPI(&buffer[0],msgsize,&packpos,comm));
    retval.push_back(A2);
  }
  return retval;
}

/*
 * Recv an array via MPI:
 *  Arguments: A = MPIRecv(source, tag, communicator)
 *  Throws an exception if there is an error.
 */
//!
//@Module MPIRECV MPI Receive Array
//@@Section MPI
//@@Usage
//This function receives an array from a source node on
//a given communicator with the specified tag.  The 
//general syntax for its use is
//@[
//  y = mpirecv(rank,tag,comm)
//@]
//where @|rank| is the rank of the node sending the message,
//@|tag| is the message tag and @|comm| is the communicator
//to use.  If no communicator is provided, then @|MPI_COMM_WORLD|
//is used.
//@@Example
//The @|mpirecv| command is fairly straightforward to use.
//Its power is in the ability to receive arrays of arbitrary
//complexity, including cell arrays, structures, strings, etc.
//Here is an example of an @|mpisend| and @|mpirecv| being used
//on the same node to pass a structure through MPI.
//@<
//mpiinit
//x.color = 'blue';
//x.pi = 3;
//x.cells = {'2',2};
//mpisend(x,0,32);
//y = mpirecv(0,32)
//@>
//!
ArrayVector MPIRecv(int nargout, const ArrayVector& args) {
  // [y,src,tag] = mpirecv(rank,tag[,comm]): receive the two-message
  // encoding produced by mpisend and return the unpacked array together
  // with the actual source rank and tag.
  if ((args.size() < 2) || (args.size() > 3))
    throw Exception("Expect 3 arguments to MPIRecv: source rank, message tag and (optionally communicator handle -- defaults to MPI_COMM_WORLD)");
  Array tmp(args[0]);
  int source(tmp.getContentsAsIntegerScalar());
  tmp = args[1];
  int tag(tmp.getContentsAsIntegerScalar());
  int comhandle;
  if (args.size() > 2) {
    tmp = args[2];
    comhandle = tmp.getContentsAsIntegerScalar();
  } else {
    comhandle = 1;
  }
  MPI_Comm comm(comms.lookupHandle(comhandle));
  // First message: packed payload size.  Its status pins down the actual
  // sender and tag, so the payload is taken from that same sender/tag even
  // if wildcard values were passed in.
  int msgsize;
  MPI_Status status;
  MPI_Recv(&msgsize,1,MPI_INT,source,tag,comm,&status);
  // Second message: the packed array.  std::vector owns the buffer so it is
  // released even if receiving or unpacking throws; at least one byte is
  // allocated so &buffer[0] is always valid.
  std::vector<char> buffer(msgsize > 0 ? msgsize : 1);
  MPI_Recv(&buffer[0],msgsize,MPI_PACKED,status.MPI_SOURCE,status.MPI_TAG,comm,MPI_STATUS_IGNORE);
  int packpos = 0;
  Array A(unpackArrayMPI(&buffer[0],msgsize,&packpos,comm));
  ArrayVector retval;
  retval.push_back(A);
  retval.push_back(Array::int32Constructor(status.MPI_SOURCE));
  retval.push_back(Array::int32Constructor(status.MPI_TAG));
  return retval;
}

//!
//@Module MPICOMMRANK MPI Communicator Rank
//@@Section MPI
//@@Usage
//This function returns the rank of a process within
//the specified communicator.  The general syntax for 
//its use is
//@[
//   y = mpicommrank(comm)
//@]
//where @|comm| is the communicator to use.  If no communicator
//is provided, then @|MPI_COMM_WORLD| is used.  The returned value
//@|y| is the rank of the current process in the communicator.
//@@Example
//Here is a simple example of using @|mpicommrank| to obtain the
//process rank.  It defaults to 0, because the process is the root
//of the group (which contains only itself).
//@<
//mpiinit
//mpicommrank
//@>
//!
ArrayVector MPICommRank(int nargout, const ArrayVector& args) {
  // Rank of this process within the communicator (default handle 1,
  // i.e. MPI_COMM_WORLD), returned as an int32 scalar.
  int comhandle = 1;
  if (args.size() != 0) {
    Array handleArg(args[0]);
    comhandle = handleArg.getContentsAsIntegerScalar();
  }
  int myRank;
  MPI_Comm_rank(comms.lookupHandle(comhandle), &myRank);
  ArrayVector result;
  result.push_back(Array::int32Constructor(myRank));
  return result;
}

//!
//@Module MPICOMMSIZE MPI Communicator Size
//@@Section MPI
//@@Usage
//This function returns the size of the group using the
//given communicator.  The general syntax for its use is
//@[
//    y = mpicommsize(comm)
//@]
//where @|comm| is the communicator to use.  If no communicator 
//is provided, then @|MPI_COMM_WORLD| is assumed.  
//@@Example
//Here is a simple example of using @|mpicommsize|:
//@<
//mpiinit
//mpicommsize
//@>
//!
ArrayVector MPICommSize(int nargout, const ArrayVector& args) {
  // Number of processes in the communicator's group (default handle 1,
  // i.e. MPI_COMM_WORLD), returned as an int32 scalar.
  int comhandle = 1;
  if (args.size() != 0) {
    Array handleArg(args[0]);
    comhandle = handleArg.getContentsAsIntegerScalar();
  }
  int groupSize;
  MPI_Comm_size(comms.lookupHandle(comhandle), &groupSize);
  ArrayVector result;
  result.push_back(Array::int32Constructor(groupSize));
  return result;
}

/*
 * syntax: x = mpiallreduce(y,operation,comm)
 */
//!
//@Module MPIALLREDUCE MPI All Reduce Operation
//@@Section MPI
//@@Usage
//This function implements the all-reduce operation using MPI.
//The general syntax for its use is
//@[
//   x = mpiallreduce(y,operation,comm)
//@]
//!
ArrayVector MPIAllReduce(int nargout, const ArrayVector& args) {
  // x = mpiallreduce(y,operation[,comm]): element-wise reduction of y
  // across all ranks; every rank receives the result.  operation is one of
  // '+', '*', '<' (min) or '>' (max).
  int comhandle;
  if (args.size() < 3)
    comhandle = 1;
  else 
    comhandle = ArrayToInt(args[2]);
  MPI_Comm comm(comms.lookupHandle(comhandle));
  if (args.size() < 2)
    throw Exception("mpiallreduce requires an array, an operation");
  char *op;
  Array oper(args[1]);
  op = oper.getContentsAsCString();
  MPI_Op mpiop;
  switch (*op) {
  case '+':
    mpiop = MPI_SUM;
    break;
  case '*':
    mpiop = MPI_PROD;
    break;
  case '<':
    mpiop = MPI_MIN;
    break;
  case '>':
    mpiop = MPI_MAX;
    break;
  default:
    throw Exception(std::string("unrecognized mpiop type:") + op + ": supported types are '+','*','>' and '<'");
  }
  Array source(args[0]);
  Array dest(source);
  Class dataClass(source.getDataClass());
  // Select the MPI datatype and scalar count, then reduce with one call.
  MPI_Datatype atype;
  int count = source.getLength();
  switch (dataClass) {
  case FM_LOGICAL:
  case FM_UINT8:
    atype = MPI_UNSIGNED_CHAR;
    break;
  case FM_INT8:
    // MPI defines no reduction operations on MPI_CHAR (a character type);
    // MPI_SIGNED_CHAR is the matching integer type.
    atype = MPI_SIGNED_CHAR;
    break;
  case FM_UINT16:
    atype = MPI_UNSIGNED_SHORT;
    break;
  case FM_INT16:
    atype = MPI_SHORT;
    break;
  case FM_UINT32:
    atype = MPI_UNSIGNED;
    break;
  case FM_INT32:
    atype = MPI_INT;
    break;
  case FM_FLOAT:
    atype = MPI_FLOAT;
    break;
  case FM_DOUBLE:
    atype = MPI_DOUBLE;
    break;
  case FM_COMPLEX:
  case FM_DCOMPLEX:
    // Complex data is reduced as 2*N independent reals.  That is correct
    // for '+', but a component-wise product/min/max is not the complex
    // result, so those operations are rejected instead of silently wrong.
    if (*op != '+')
      throw Exception("mpiallreduce only supports the '+' operation on complex arrays");
    if (dataClass == FM_COMPLEX)
      atype = MPI_FLOAT;
    else
      atype = MPI_DOUBLE;
    count *= 2;
    break;
  case FM_INT64:
  case FM_UINT64:
    throw Exception("MPI support for 64 bit values is still needed!");
  default:
    throw Exception("unsupported array type in argument to allreduce - must be a numerical type");
  }
  MPI_Allreduce((void*)source.getDataPointer(),dest.getReadWriteDataPointer(),count,atype,mpiop,comm);
  ArrayVector retval;
  retval.push_back(dest);
  return retval;
}


/*
 * syntax: x = mpireduce(y,operation,root,comm)
 */
//!
//@Module MPIREDUCE MPI Reduce Operation
//@@Section MPI
//@@Usage
//This function implements the reduction operation using MPI.
//The general syntax for its use is
//@[
//   x = mpireduce(y,operation,root,comm)
//@]
//where @|y| is the current processes contribution to the
//reduction operation, @|operation| is either @|'+','*','>','<'| for
//an additive, multiplicative, max or min type reduction operations
//respectively,
//@|root| is the rank of the process that will retrieve the
//result of the reduction operation, and @|comm| is the MPI
//communicator handle.  If no communicator is provided,
//then @|MPI_COMM_WORLD| is used by default.  Note that FreeMat does not
//check to ensure that the reduction operation @|y| arguments
//are all the same size across the various processes in the group.
//Instead, you must make sure that each process passes the same
//sized array to the @|mpireduce| operation.
//!
ArrayVector MPIReduce(int nargout, const ArrayVector& args) {
  // x = mpireduce(y,operation,root[,comm]): element-wise reduction of y
  // across all ranks, delivered to `root`.  operation is one of '+', '*',
  // '<' (min) or '>' (max).  On non-root ranks the returned array is
  // whatever MPI leaves in the (unused) receive buffer.
  int comhandle;
  if (args.size() < 4)
    comhandle = 1;
  else 
    comhandle = ArrayToInt(args[3]);
  MPI_Comm comm(comms.lookupHandle(comhandle));
  if (args.size() < 3)
    throw Exception("mpireduce requires an array, an operation, and a root rank");
  int root = ArrayToInt(args[2]);
  char *op;
  Array oper(args[1]);
  op = oper.getContentsAsCString();
  MPI_Op mpiop;
  switch (*op) {
  case '+':
    mpiop = MPI_SUM;
    break;
  case '*':
    mpiop = MPI_PROD;
    break;
  case '<':
    mpiop = MPI_MIN;
    break;
  case '>':
    mpiop = MPI_MAX;
    break;
  default:
    throw Exception(std::string("unrecognized mpiop type:") + op + ": supported types are '+','*','>' and '<'");
  }
  Array source(args[0]);
  Array dest(source);
  Class dataClass(source.getDataClass());
  // Select the MPI datatype and scalar count, then reduce with one call.
  MPI_Datatype atype;
  int count = source.getLength();
  switch (dataClass) {
  case FM_LOGICAL:
  case FM_UINT8:
    atype = MPI_UNSIGNED_CHAR;
    break;
  case FM_INT8:
    // MPI defines no reduction operations on MPI_CHAR (a character type);
    // MPI_SIGNED_CHAR is the matching integer type.
    atype = MPI_SIGNED_CHAR;
    break;
  case FM_UINT16:
    atype = MPI_UNSIGNED_SHORT;
    break;
  case FM_INT16:
    atype = MPI_SHORT;
    break;
  case FM_UINT32:
    atype = MPI_UNSIGNED;
    break;
  case FM_INT32:
    atype = MPI_INT;
    break;
  case FM_FLOAT:
    atype = MPI_FLOAT;
    break;
  case FM_DOUBLE:
    atype = MPI_DOUBLE;
    break;
  case FM_COMPLEX:
  case FM_DCOMPLEX:
    // Complex data is reduced as 2*N independent reals.  That is correct
    // for '+', but a component-wise product/min/max is not the complex
    // result, so those operations are rejected instead of silently wrong.
    if (*op != '+')
      throw Exception("mpireduce only supports the '+' operation on complex arrays");
    if (dataClass == FM_COMPLEX)
      atype = MPI_FLOAT;
    else
      atype = MPI_DOUBLE;
    count *= 2;
    break;
  case FM_INT64:
  case FM_UINT64:
    throw Exception("MPI support for 64 bit values is still needed!");
  default:
    throw Exception("unsupported array type in argument to reduce - must be a numerical type");
  }
  MPI_Reduce((void*)source.getDataPointer(),dest.getReadWriteDataPointer(),count,atype,mpiop,root,comm);
  ArrayVector retval;
  retval.push_back(dest);
  return retval;
}


//!
//@Module MPIINITIALIZED MPI Initialized Test
//@@Section MPI
//@@Usage
//This function tests to see if MPI is already initialized.
//The general syntax for its use is
//@[
//   x = mpiinitialized
//@]
//It returns a logical 1 if @|mpiinit| has already been called,
//and a logical 0 otherwise.
//@@Example
//Here we call @|mpiinitialized| before and after a call to 
//@|mpiinit|.
//@<
//mpiinitialized
//mpiinit
//mpiinitialized
//@>
//!
ArrayVector MPIInitialized(int nargout, const ArrayVector& args) {
  // Report, as a logical scalar, whether MPI_Init has already been called.
  int initialized = 0;
  MPI_Initialized(&initialized);
  ArrayVector result;
  result.push_back(Array::logicalConstructor(initialized));
  return result;
}

//!
//@Module MPIINIT MPI Initialize
//@@Section MPI
//@@Usage
//This function initializes the MPI subsystem and joins
//the current FreeMat process to the MPI environment.
//The general syntax for its use is
//@[
//   mpiinit
//@]
//Note that @|mpiinit| must be called before any other
//MPI routines (with the exception of @|mpiinitialized|), or
//an MPI error will occur.
//!
ArrayVector MPIInit(int nargout, const ArrayVector& args) {
  // Start MPI and register the wrapper state, unless MPI is already up, in
  // which case this is a no-op.  Returns no values.
  int alreadyUp;
  MPI_Initialized(&alreadyUp);
  if (!alreadyUp) {
    MPI_Init(NULL, NULL);
    InitializeMPIWrap();
  }
  return ArrayVector();
}

//!
//@Module MPICOMMSPAWN MPI Communicator Spawn
//@@Section MPI
//@@Usage
//This function uses MPI to spawn processes on members of a group.
//The full power of the underlying routine, @|MPI_Comm_spawn| is not
//yet available via the @|mpicommspawn| routine in FreeMat.  The
//general syntax for its use is
//@[
//  errcodes = mpicommspawn(command,argv,maxprocs,root,comm)
//@]
//where @|command| is the command to execute, @|argv| is a cell-array
//of strings to pass as arguments to @|command|, @|maxprocs| is the
//number of processes to spawn, @|root| is the node that will actually
//do the process spawn, and @|comm| is the communicator to use.
//If no communicator is specified, @|comm| defaults to @|MPI_COMM_SELF|.
//If @|root| is not specified, it defaults to 0.  If @|maxprocs| is
//not specified, it defaults to 1.  If @|argv| is not specified no
//arguments are passed to the spawned processes.
//!
// Spawn `maxprocs` copies of `command` via MPI_Comm_spawn.
//   args[0]  command   : string, required
//   args[1]  argv      : string or cell array of strings (optional; empty => no args)
//   args[2]  maxprocs  : integer >= 1 (default 1)
//   args[3]  root      : rank that performs the spawn (default 0)
//   args[4]  comm      : communicator handle (default MPI_COMM_SELF)
// Returns { handle to the resulting intercommunicator,
//           maxprocs x 1 int32 array of per-process error codes }.
// Throws Exception on bad arguments.
ArrayVector MPICommSpawn(int nargout, const ArrayVector& args) {
  if (args.size() == 0)
    throw Exception("mpicommspawn requires at least one argument (name of the command to spawn)");
  if (args.size() > 5)
    throw Exception("too many arguments to mpicommspawn.");
  Array t1(args[0]);
  // NOTE(review): ownership of the buffers returned by getContentsAsCString
  // is not visible here; if they are heap-allocated they are not released.
  char *command = t1.getContentsAsCString();
  // Pointer array for argv. The std::vector frees itself on every exit path,
  // including exceptions, unlike the previous malloc'd array, which was never freed.
  std::vector<char*> argvList;
  if (args.size() >= 2) {
    Array t2(args[1]);
    if (!t2.isEmpty()) {
      if (t2.isString()) {
        argvList.push_back(t2.getContentsAsCString());
      } else if (t2.getDataClass() == FM_CELL_ARRAY) {
        const Array *dp = static_cast<const Array*>(t2.getDataPointer());
        int len = t2.getLength();
        for (int m=0;m<len;m++) {
          Array q(dp[m]);
          argvList.push_back(q.getContentsAsCString());
        }
      } else
        throw Exception("mpicommspawn requires the argument array to either be a string or a cell array of strings (i.e., {'arg','arg',...'arg'}).");
    }
  }
  // MPI expects a NULL-terminated argv, or MPI_ARGV_NULL for "no arguments".
  char **argv = MPI_ARGV_NULL;
  if (!argvList.empty()) {
    argvList.push_back(NULL);
    argv = &argvList[0];
  }
  int maxprocs = (args.size() < 3) ? 1 : ArrayToInt(args[2]);
  // Guard the allocation below against zero/negative sizes.
  if (maxprocs < 1)
    throw Exception("mpicommspawn requires maxprocs to be at least 1.");
  int root = (args.size() < 4) ? 0 : ArrayToInt(args[3]);
  MPI_Comm comm = (args.size() < 5) ? MPI_COMM_SELF
                                    : comms.lookupHandle(ArrayToInt(args[4]));
  int *errcodes = (int*) Malloc(sizeof(int)*maxprocs);
  MPI_Comm intercomm;
  // An info object with no keys is equivalent to MPI_INFO_NULL; using the
  // null handle avoids creating (and leaking) an MPI_Info.
  // Errors are reported through the communicator's error handler
  // (installed by InitializeMPIWrap for MPI_COMM_WORLD), so the return code is not checked.
  MPI_Comm_spawn(command,argv,maxprocs,MPI_INFO_NULL,root,
                 comm,&intercomm,errcodes);
  ArrayVector retarr;
  retarr.push_back(Array::int32Constructor(comms.assignHandle(intercomm)));
  Dimensions dim;
  dim[0] = maxprocs;
  dim[1] = 1;
  // Construct the Array directly; `Array::Array(...)` names the constructor
  // and is ill-formed in standard C++. The Malloc'd errcodes buffer is
  // handed to the Array as its data (presumably it takes ownership).
  retarr.push_back(Array(FM_INT32,dim,errcodes));
  return retarr;
}

//!
//@Module MPIINTERCOMMMERGE MPI Intercommunicator Merge
//@@Section MPI
//@@Usage
//This routine merges the current process with an existing group.
//The general syntax for its use is 
//@[
//  newcomm = mpiintercommmerge(comm,highflag)
//@]
//where @|comm| is the communicator we want to merge onto,
//@|highflag| determines if our rank is at the high end of
//ranks in the new communicator or at the bottom end, and
//@|newcomm| is the handle to the new communicator.
//!
ArrayVector MPIIntercommMerge(int nargout, const ArrayVector& args) {
  // The intercommunicator handle is mandatory.
  if (args.size() < 1)
    throw Exception("must supply a handle for the intercommunicator");
  MPI_Comm intercomm = comms.lookupHandle(ArrayToInt(args[0]));
  // highflag selects whether this group's ranks are placed high (nonzero)
  // or low (0) in the merged communicator; it defaults to 0.
  int high = (args.size() >= 2) ? ArrayToInt(args[1]) : 0;
  MPI_Comm merged;
  MPI_Intercomm_merge(intercomm, high, &merged);
  // Hand the merged intracommunicator back to FreeMat as a new handle.
  ArrayVector result;
  result.push_back(Array::int32Constructor(comms.assignHandle(merged)));
  return result;
}
  
//!
//@Module MPICOMMGETPARENT MPI Get Parent Communicator
//@@Section MPI
//@@Usage
//This routine returns the communicator for the group that
//spawned the current process.  Calling this routine for a 
//process that was not spawned using @|mpicommspawn| will 
//cause an error.  The general syntax for its use is
//@[
//  comm = mpicommgetparent
//@]
//!
// Return a handle for the parent (spawning) intercommunicator.
// If the parent communicator is already registered in `comms` (compared with
// MPI_Comm_compare == MPI_IDENT), its existing handle is returned; otherwise
// a new handle is assigned.
// Throws Exception if this process was not spawned (no parent).
ArrayVector MPICommGetParent(int nargout, const ArrayVector& args) {
  MPI_Comm parent;
  MPI_Comm_get_parent(&parent);
  // A process not started via MPI_Comm_spawn has no parent; MPI reports this
  // as MPI_COMM_NULL. Fail with a clear message instead of passing the null
  // communicator to MPI_Comm_compare below.
  if (parent == MPI_COMM_NULL)
    throw Exception("mpicommgetparent: this process was not spawned by mpicommspawn, so it has no parent communicator");
  ArrayVector retval;
  // Map the communicator back to an existing handle if there is one.
  int maxsize = comms.maxHandle();
  for (int i=1;i<=maxsize;i++) {
    int result;
    MPI_Comm_compare(parent,comms.lookupHandle(i),&result);
    if (result == MPI_IDENT) {
      retval.push_back(Array::int32Constructor(i));
      return retval;
    }
  }
  // Not seen before: register it.
  retval.push_back(Array::int32Constructor(comms.assignHandle(parent)));
  return retval;
}

//!
//@Module MPIFINALIZE MPI Finalize
//@@Section MPI
//@@Usage
//This routine will shut down the MPI interface.  Once called,
//no more MPI calls can be made (except for @|mpiinitialized|).
//The syntax for its use is
//@[
//  mpifinalize
//@]
//!
// Shut down MPI.
// Calling MPI_Finalize before MPI_Init, or a second time, is erroneous in MPI.
// So this is a no-op unless MPI is initialized and not yet finalized,
// mirroring MPIInit, which is a no-op when MPI is already initialized.
ArrayVector MPIFinalize(int nargout, const ArrayVector& args) {
  // Both queries may be called at any time, even outside Init/Finalize.
  int initialized = 0;
  MPI_Initialized(&initialized);
  int finalized = 0;
  MPI_Finalized(&finalized);
  if (initialized && !finalized)
    MPI_Finalize();
  return ArrayVector();
}

// Register every MPI built-in with the interpreter context.
// addFunction arguments: name, implementation, max input count,
// output count, then the input argument names.
void LoadMPIFunctions(Context*context) {
  // Point-to-point and collective communication.
  context->addFunction("mpisend", MPISend, 4, 0,
                       "array", "dest", "tag", "communicator");
  context->addFunction("mpirecv", MPIRecv, 3, 3,
                       "source", "tag", "communicator");
  context->addFunction("mpibcast", MPIBcast, 3, 1,
                       "array", "root", "communicator");
  context->addFunction("mpibarrier", MPIBarrier, 1, 0,
                       "communicator");
  // Communicator queries.
  context->addFunction("mpicommrank", MPICommRank, 1, 1,
                       "communicator");
  context->addFunction("mpicommsize", MPICommSize, 1, 1,
                       "communicator");
  // Reductions.
  context->addFunction("mpireduce", MPIReduce, 4, 1,
                       "y", "operation", "root", "comm");
  // NOTE(review): an allreduce has no root; the third name here may be
  // intended as the communicator -- confirm against MPIAllReduce.
  context->addFunction("mpiallreduce", MPIAllReduce, 3, 1,
                       "y", "operation", "root");
  // Environment lifecycle.
  context->addFunction("mpiinitialized", MPIInitialized, 0, 1);
  context->addFunction("mpiinit", MPIInit, 0, 0);
  context->addFunction("mpifinalize", MPIFinalize, 0, 0);
  // Dynamic process management.
  context->addFunction("mpicommgetparent", MPICommGetParent, 0, 1);
  context->addFunction("mpicommspawn", MPICommSpawn, 5, 2,
                       "command", "args", "maxprocs", "root", "comm");
  context->addFunction("mpiintercommmerge", MPIIntercommMerge, 2, 1,
                       "intercomm", "highflag");
}

#endif


// syntax highlighted by Code2HTML, v. 0.9.1