/*
    DFT++ is a density functional package developed by the research group
    of Professor Tomas Arias

    Copyright 1996-2003 Sohrab Ismail-Beigi

    This file is part of DFT++.

    DFT++ is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    DFT++ is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with DFT++; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA

    Please see the file CREDITS for a list of authors.

    For academic users, we request that publications using results obtained with
    this software reference

    "New algebraic formulation of density functional calculation," by Sohrab Ismail-Beigi
    and T.A. Arias, Computer Physics Communications 128:1-2, 1-45 (June 2000).

    and, if using the wavelet basis, further reference

    "Multiresolution analysis of electronic structure: semicardinal and wavelet bases,"
    T.A. Arias, Reviews of Modern Physics 71:1, 267-311 (January 1999).

    and 

    "Robust ab initio calculation of condensed matter: transparent convergence through
    semicardinal multiresolution analysis,'' I.P. Daykov, T.A. Arias, and
    Torkel D. Engeness, Physical Review Letters, 90:21, 216402 (May 2003).

    For your convenience, preprints of the above articles may be obtained from
    http://arXiv.org/abs/cond-mat/9909130, 9805262, and 0204411, respectively.
*/

/*
 * Routines that do various matrix multiplications.  These routines
 * are the computational kernels for all the matrix multiplies in the code,
 * so optimizing them is the way to improve matrix-multiplication
 * performance.
 */

#include "header.h"

//
// This routine does the multiplication
//
//       Bres[i][j] += sum_k { B1[i][k]*B2[j][k] }
//
// B1   is s1 x cs in memory layout
// B2   is s2 x cs in memory layout
// Bres is s1 x s2 in memory layout
//
// The matrix multiply is done for the sublocks in memory of size:
// B1   n1 x nc
// B2   n2 x nc
// Bres n1 x n2
//

// extern "C" void esvzgemul(void *,int,char *,
// 			  void *,int,char *,
// 			  void *,int,int,int,int);

void
small_block_matrix_mult(int n1, int n2, int nc,
			int s1, int s2, int cs,
			scalar *B1, scalar *B2, scalar *Bres)
{
// This uses the essl esvzgemul routine to do the work.
// using it turned out to not give ANY performance benefit.
//
//   static scalar Btemp[10000];
//   esvzgemul(B2,cs,"T",B1,cs,"N",Btemp,s2,n2,nc,n1);
//   for (int i=0; i < n1; i++)
//     {
//       register int is2 = i*s2;
//       for (int j=0; j < n2; j++)
// 	{
// 	  Bres[is2+j].x += Btemp[is2+j].x;
// 	  Bres[is2+j].y += Btemp[is2+j].y;
// 	}
//     }
//   return;

  // The code blocks below does this:
  //    scalar sum = 0.0;
  //    for (register int k=0; k<nc; k++)
  //       sum += B1[r1][k] * B2[r2][k];
  //    Bres[r1][r2] += sum;
  //
  // If we have even sized blocks, do register level 2 x 2 blocks
  if ( (n1%2==0) && (n2%2==0) && (nc%2==0) )
    {
      for (int r1 = 0; r1 < n1; r1+=2)		
	for (int r2 = 0; r2 < n2; r2+=2)
	  {
	    int r1cs = r1*cs;
	    int r1s2 = r1*s2;
	    int r2cs = r2*cs;
	    double ax, bx, cx, dx, ex, fx, gx, hx;
	    double ay, by, cy, dy, ey, fy, gy, hy;
	    double wx, xx, yx, zx, wy, xy, yy, zy;
	    wx = xx = yx = zx = wy = xy = yy = zy = 0.0;
	    for (int k=0; k < nc; k+=2)
	      {
		// a = B1[r1][k]    b = B1[r1][k+1],
		// c = B1[r1+1][k]  d = B1[r1+1][k+1]
		ax = B1[r1cs+k].x;
		ay = B1[r1cs+k].y;
		bx = B1[r1cs+k+1].x;
		by = B1[r1cs+k+1].y;
		cx = B1[r1cs+cs+k].x;
		cy = B1[r1cs+cs+k].y;
		dx = B1[r1cs+cs+k+1].x;
		dy = B1[r1cs+cs+k+1].y;
		// e = B2[r2][k]    f = B2[r2][k+1]
		// g = B2[r2+1][k]  h = B2[r2+1][k+1]
		ex = B2[r2cs+k].x;
		ey = B2[r2cs+k].y;
		fx = B2[r2cs+k+1].x;
		fy = B2[r2cs+k+1].y;
		gx = B2[r2cs+cs+k].x;
		gy = B2[r2cs+cs+k].y;
		hx = B2[r2cs+cs+k+1].x;
		hy = B2[r2cs+cs+k+1].y;
		// w = a*e + b*f    x = a*g + b*h
		// y = c*e + d*f    z = c*g + d*h
		wx += ax*ex - ay*ey + bx*fx - by*fy;
		wy += ax*ey + ay*ex + bx*fy + by*fx;
		xx += ax*gx - ay*gy + bx*hx - by*hy;
		xy += ax*gy + ay*gx + bx*hy + by*hx;
		yx += cx*ex - cy*ey + dx*fx - dy*fy;
		yy += cx*ey + cy*ex + dx*fy + dy*fx;
		zx += cx*gx - cy*gy + dx*hx - dy*hy;
		zy += cx*gy + cy*gx + dx*hy + dy*hx;
	      }
	    // Bres[r1][r2]   += w    Bres[r1][r2+1]   += x
	    // Bres[r1+1][r2] += y    Bres[r1+1][r2+1] += z
	    Bres[r1s2+r2].x += wx;
	    Bres[r1s2+r2].y += wy;
	    Bres[r1s2+r2+1].x += xx;
	    Bres[r1s2+r2+1].y += xy;
	    Bres[r1s2+s2+r2].x += yx;
	    Bres[r1s2+s2+r2].y += yy;
	    Bres[r1s2+s2+r2+1].x += zx;
	    Bres[r1s2+s2+r2+1].y += zy;	
	  }
    }
  // The blocks aren't even sized; so do the simpler matrix multiply
  else
    {
      for (int r1 = 0; r1 < n1; r1++)
	for (int r2 = 0; r2 < n2; r2++)
	  {
	    int r1cs = r1*cs;
	    int r1s2 = r1*s2;
	    int r2cs = r2*cs;
	    double sx = 0.0;
	    double sy = 0.0;
	    for (int k=0; k < nc; k++)
	      {
		double ax;
		double ay;
		double bx;
		double by;
		ax = B1[r1cs+k].x;
		ay = B1[r1cs+k].y;
		bx = B2[r2cs+k].x;
		by = B2[r2cs+k].y;
		sx += ax * bx - ay * by;
		sy += ay * bx + ax * by;
	      }
	    Bres[r1s2+r2].x += sx;
	    Bres[r1s2+r2].y += sy;
	  }
    }
}

// Tile sizes for the blocked multiplies below.  The routines in this file
// stage their operands in fixed-size arrays dimensioned by these constants
// (e.g. b1[BL1][BLK], b2[BL2][BLK], out[BL1][BL2]) and pass them to
// small_block_matrix_mult() as the strides s1, s2 and cs:
//   BL1 - rows of an output tile (block of the first operand)
//   BL2 - columns of an output tile (block of the second operand)
//   BLK - length of one slab of the summation index k
// Because they are even, full tiles satisfy the even-extent condition of
// the 2x2 path in small_block_matrix_mult().
#define BL1 32
#define BL2 32
#define BLK 32

//
// This routine does the operation M = Y1^Y2 via blocked matrix
// multiplies that call the routine above for the actual work.
// So all that is done below is to loop over blocks, load the
// blocks from memory, and then to output the result blocks to M.
//
// n1 and n2 are the number of columns of Y1 and Y2 resp.
// N is the length of the columns of Y1 and Y2.
//
// offsetMrow/col are offsets to the output M (see the formulae below).
//
// transpose==1 means Y1 and Y2 are actually input in transpose format.
// transpose==0 means they are in "normal" format
//
// i.e. in transpose==0 mode, the routine does
//
// M(i+offsetMrow,j+offsetMcol) =
//         sum_k { conj(Y1.col[i].c[k])*Y2.col[j].c[k] }
//
// offsetY2rowtranspose is an offset added to the row index used to
// access Y2 in transpose mode (needed specifically in the distributed
// case, where the data arrives transposed; see dist_multiply.c for the
// details).
//
// The computation is split into three routines so that the threaded
// and serial implementations can share the same interface:
//
// The routine Y1dagY2_partial_mult() does the actual work of multiplying
//    for indices ib going from start_ib to start_ib+n_ib_todo-1.  The
//    entire multiply would have ib going from 0 to n1-1.
//
// The routine Y1dagY2_thread() is a wrapper routine that a thread
//    executes:  it gets the input data given to it, decodes it,
//    and then calls Y1dagY2_partial_mult() with the correct start_ib
//    and n_ib_todo indices.
//
// The master routine Y1dagY2_block_matrix_mult() is the actual routine
// called by the outside, whose job is to compute M=Y1^Y2.  It either
// calls Y1dagY2_partial_mult() with one big multiply (if running serial),
// or spawns threads that will run Y1dagY2_thread() (if running threads).
//
static void
Y1dagY2_partial_mult(const ColumnBundle &Y1,
		     const ColumnBundle &Y2,
		     Matrix &M,
		     int start_ib,int n_ib_todo, int n2, int N,
		     int offsetMrow,int offsetMcol,
		     int transpose,int offsetY2rowtranspose)
{
  // loop over blocks of size BL1xBL2 of the output matrix M
  // staring ib at start_ib and ending at start_ib+n_ib_todo-1
  int ib,jb;
  for (ib=start_ib; ib < start_ib+n_ib_todo; ib+=BL1)
    for (jb=0; jb < n2; jb+=BL2)
      {
	// calculate sizes of output block
	int si = (n_ib_todo-(ib-start_ib)) >= BL1 ? BL1 : n_ib_todo%BL1;
	int sj = (n2-jb)                   >= BL2 ? BL2 : n2%BL2;

	// zero output block
	scalar out[BL1][BL2];
	int i,j;
	for (i=0; i < si; i++)
	  for (j=0; j < sj; j++)
	    out[i][j] = 0.0;

	// input blocks for multiply loop below
	scalar b1[BL1][BLK],b2[BL2][BLK];

	// loop over long direction k for the sum
	int kb;
	for (kb=0; kb < N; kb+=BLK)
	  {
	    // size of k-block
	    int sk = (N-kb) >= BLK ? BLK : N%BLK;

	    // get data from Y1 and Y2 into input blocks:
	    // if in transpose mode...
	    int k;
	    if (transpose)
	      {
		for (k=0; k < sk; k++)
		  for (i=0; i < si; i++)
		    {
#if defined SCALAR_IS_COMPLEX
		      b1[i][k] = conj(Y1.col[kb+k]->data.d[ib+i]);
#elif defined SCALAR_IS_REAL
		      b1[i][k] = Y1.col[kb+k]->data.d[ib+i];
#else
#error scalar is neither real nor complex!
#endif
		    }
		for (k=0; k < sk; k++)
		  for (j=0; j < sj; j++)
		    b2[j][k] = Y2.col[kb+k]->data.d[jb+j+offsetY2rowtranspose];
	      } 
	    // non-transpose mode
	    else
	      {
		for (i=0; i < si; i++)
		  for (k=0; k < sk; k++)
		    {
#if defined SCALAR_IS_COMPLEX
		      b1[i][k] = conj(Y1.col[ib+i]->data.d[kb+k]);
#elif defined SCALAR_IS_REAL
		      b1[i][k] = Y1.col[ib+i]->data.d[kb+k];
#else
#error scalar is neither real nor complex!
#endif
		    }
		for (j=0; j < sj; j++)
		  for (k=0; k < sk; k++)
		    b2[j][k] = Y2.col[jb+j]->data.d[kb+k];
	      }
	    // multiply blocks: out[i][j] += sum_k { b1[i][k]*b2[j][k] }
	    small_block_matrix_mult(si,sj,sk,BL1,BL2,BLK,
				    (scalar *)b1,(scalar *)b2,(scalar *)out);
	  } // over kb

	// now write out the output block to the matrix M
	for (i=0; i < si; i++)
	  for (j=0; j < sj; j++)
	    M(ib+i+offsetMrow,jb+j+offsetMcol) = out[i][j];
      } // over ib,jb
}

#ifdef DFT_THREAD
//
// Thread body for Y1dagY2_block_matrix_mult().  arg is a dft_thread_data
// record with the following fields:
//   p1, p2, p3  carry Y1, Y2 and M;
//   start, n    give the range of ib this thread works on;
//   i1..i6      carry n2, N, offsetMrow, offsetMcol, transpose and
//               offsetY2rowtranspose.
// The thread runs the partial multiply and then releases the record
// with myfree().
//
static void *
Y1dagY2_thread(void *arg)
{
  dft_thread_data *job = (dft_thread_data *)arg;

  const ColumnBundle &lhs = *(ColumnBundle *)job->p1;
  const ColumnBundle &rhs = *(ColumnBundle *)job->p2;
  Matrix &result = *(Matrix *)job->p3;

  Y1dagY2_partial_mult(lhs, rhs, result,
                       job->start, job->n,   // this thread's ib range
                       job->i1, job->i2,     // n2, N
                       job->i3, job->i4,     // offsetMrow, offsetMcol
                       job->i5, job->i6);    // transpose, offsetY2rowtranspose

  myfree(arg);
  return NULL;
}
#endif // DFT_THREAD

//
// Y1dagY2_block_matrix_mult: public entry point computing M = Y1^Y2.
// A serial build does the whole ib range [0,n1) in one call to
// Y1dagY2_partial_mult().  A DFT_THREAD build hands n1, the operands and
// Y1dagY2_thread() to dft_call_threads(), which distributes the work.
//
void
Y1dagY2_block_matrix_mult(const ColumnBundle &Y1,
                          const ColumnBundle &Y2,
                          Matrix &M,
                          int n1, int n2, int N,
                          int offsetMrow,int offsetMcol,
                          int transpose,int offsetY2rowtranspose)
{
#ifndef DFT_THREAD
  // serial: one call covers every output row
  Y1dagY2_partial_mult(Y1, Y2, M,
                       0, n1, n2, N,
                       offsetMrow, offsetMcol,
                       transpose, offsetY2rowtranspose);
#else
  // threaded: pack the arguments and let the thread driver split [0,n1)
  dft_call_threads(n1,
                   (void *)&Y1, (void *)&Y2, (void *)&M, NULL, NULL,
                   n2, N, offsetMrow, offsetMcol,
                   transpose, offsetY2rowtranspose,
                   Y1dagY2_thread);
#endif
}



//
// This routine does the operation YM = Y*M or YM += Y*M via blocked matrix
// multiplies that call the routine small_block_matrix_mult()
// to do the actual FLOP work. So all that is done below is to loop over
// blocks, load the blocks from memory, and then to output the result
// blocks to YM.
//
// N is the length of the columns of YM and Y.
// nrM is the number of rows of M used in the sum (starting at row offsetMrow).
// ncM is the number of cols of M used (starting at column offsetMcol).
//
// offsetMrow/col are offsets of the input M (see the formulae below) used
// to read information out of it.
//
// accum==0 does YM =  Y*M
// accum==1 does YM += Y*M
//
// transpose==0 means they are in normal format
// transpose==1 means Y and YM are actually input in transpose format.
//
// The overall routine calculation does
//
//        YM(i,j) = sum_k { Y(i,k)*M(k+offsetMrow,j+offsetMcol) }
//
// In transpose mode, often the routine is called with Y and YM being
// the same column_bundle, so the routine is written so as to be able to do
// the multiplication in place:  i.e. YM and Y being the same will not
// affect things.  This is the role of the temporary matrix temp
// below.
//
// Like the multiplication routine above, this is also broken
// into three pieces to ease threading:
//
// Y_M_partial_mult() does the multiplication for rYb going from
// rYb_start to rYb_start+n_rYb_todo-1 (the entire multiplication has
// rYb going from 0 to N-1)
//
// Y_M_thread() is run by a thread and it calls Y_M_partial_mult() with
// appropriate ranges of rYb
//
// Y_M_block_matrix_mult() is the routine called from the outside and
// it either calls Y_M_partial_mult() once with rYb going from 0 to N-1
// (if serial), or launches threads to do the job.
//
static void
Y_M_partial_mult(const ColumnBundle &Y,
                 const Matrix &M,
                 ColumnBundle &YM,
                 int rYb_start,int n_rYb_todo, int nrM, int ncM,
                 int offsetMrow, int offsetMcol,
                 int transpose,
                 int accum)
{
  // Computes, for rY in [rYb_start, rYb_start+n_rYb_todo) and j in [0,ncM),
  //     YM(rY,j) (=, or += if accum) sum_{k<nrM} { Y(rY,k)*M(k+offsetMrow,j+offsetMcol) }
  // one stripe of BL1 rows at a time.

  // Nothing to do (e.g. a thread given an empty range, or no output
  // columns): skip the work-space allocation altogether.
  if (n_rYb_todo <= 0 || ncM <= 0)
    return;

  // Stripe work space.  A whole BL1 x ncM stripe of Y*M is finished here
  // before any of it is stored to YM, which is what allows Y and YM to be
  // the same ColumnBundle.
  Matrix temp(BL1,ncM);

  for (int rYb=rYb_start; rYb < rYb_start+n_rYb_todo; rYb+=BL1)
    {
      // rows in this stripe (the last stripe may be partial)
      int srY = (n_rYb_todo-(rYb-rYb_start)) >= BL1 ? BL1 : n_rYb_todo%BL1;

      // blocks of BL2 output columns
      for (int cMb=0; cMb < ncM; cMb+=BL2)
        {
          int scM = (ncM-cMb) >= BL2 ? BL2 : ncM%BL2;

          scalar out[BL1][BL2];
          for (int rY=0; rY < srY; rY++)
            for (int cM=0; cM < scM; cM++)
              out[rY][cM] = 0.0;

          // by[rY][k] = Y(rYb+rY, kb+k);  bm[cM][k] = M(kb+k, cMb+cM)
          // (bm is stored transposed because the kernel takes B2[j][k])
          scalar by[BL1][BLK],bm[BL2][BLK];

          // accumulate over slabs of BLK rows of M
          for (int kb=0; kb < nrM; kb+=BLK)
            {
              int sk = (nrM-kb) >= BLK ? BLK : nrM%BLK;

              // loop order follows Y's storage in each layout
              if (transpose)
                {
                  for (int rY=0; rY < srY; rY++)
                    for (int k=0; k < sk; k++)
                      by[rY][k] = Y.col[rYb+rY]->data.d[kb+k];
                }
              else
                {
                  for (int k=0; k < sk; k++)
                    for (int rY=0; rY < srY; rY++)
                      by[rY][k] = Y.col[kb+k]->data.d[rYb+rY];
                }
              for (int k=0; k < sk; k++)
                for (int cM=0; cM < scM; cM++)
                  bm[cM][k] = M(kb+k+offsetMrow,cMb+cM+offsetMcol);

              small_block_matrix_mult(srY,scM,sk,BL1,BL2,BLK,
                                      (scalar *)by,(scalar *)bm,(scalar *)out);
            }

          // park the finished block in the stripe work space
          for (int rY=0; rY < srY; rY++)
            for (int cM=0; cM < scM; cM++)
              temp(rY,cMb+cM) = out[rY][cM];
        }

      // Store the stripe into YM.  temp holds exactly ncM columns, so the
      // column loops must stop at ncM: running them to M.nc would read past
      // temp and write past the intended columns of YM whenever ncM < M.nc
      // (i.e. whenever only a sub-block of M is used).
      if (transpose)
        {
          if (accum)
            for (int rY=0; rY < srY; rY++)
              for (int cM=0; cM < ncM; cM++)
                YM.col[rYb+rY]->data.d[cM] += temp(rY,cM);
          else
            for (int rY=0; rY < srY; rY++)
              for (int cM=0; cM < ncM; cM++)
                YM.col[rYb+rY]->data.d[cM] = temp(rY,cM);
        }
      else
        {
          if (accum)
            for (int cM=0; cM < ncM; cM++)
              for (int rY=0; rY < srY; rY++)
                YM.col[cM]->data.d[rYb+rY] += temp(rY,cM);
          else
            for (int cM=0; cM < ncM; cM++)
              for (int rY=0; rY < srY; rY++)
                YM.col[cM]->data.d[rYb+rY] = temp(rY,cM);
        }
    }
}

#ifdef DFT_THREAD
//
// Thread body for Y_M_block_matrix_mult().  arg is a dft_thread_data
// record with the following fields:
//   p1, p2, p3  carry Y, M and YM;
//   start, n    give the range of rYb this thread works on;
//   i1..i6      carry nrM, ncM, offsetMrow, offsetMcol, transpose and accum.
// The thread runs the partial multiply and then releases the record
// with myfree().
//
static void *
Y_M_thread(void *arg)
{
  dft_thread_data *job = (dft_thread_data *)arg;

  const ColumnBundle &Yin  = *(ColumnBundle *)job->p1;
  const Matrix       &Min  = *(Matrix *)job->p2;
  ColumnBundle       &Yout = *(ColumnBundle *)job->p3;

  Y_M_partial_mult(Yin, Min, Yout,
                   job->start, job->n,   // this thread's rYb range
                   job->i1, job->i2,     // nrM, ncM
                   job->i3, job->i4,     // offsetMrow, offsetMcol
                   job->i5, job->i6);    // transpose, accum

  myfree(arg);
  return NULL;
}
#endif  // DFT_THREAD


//
// Y_M_block_matrix_mult: public entry point computing YM = Y*M (or
// YM += Y*M when accum is set).  A serial build does the whole rYb range
// [0,N) in one call to Y_M_partial_mult().  A DFT_THREAD build hands N,
// the operands and Y_M_thread() to dft_call_threads(), which distributes
// the work.
//
void
Y_M_block_matrix_mult(const ColumnBundle &Y,
                      const Matrix &M,
                      ColumnBundle &YM,
                      int N, int nrM, int ncM,
                      int offsetMrow, int offsetMcol,
                      int transpose,
                      int accum)
{
#ifndef DFT_THREAD
  // serial: one call covers every row of the result
  Y_M_partial_mult(Y, M, YM,
                   0, N, nrM, ncM,
                   offsetMrow, offsetMcol,
                   transpose, accum);
#else
  // threaded: pack the arguments and let the thread driver split [0,N)
  dft_call_threads(N,
                   (void *)&Y, (void *)&M, (void *)&YM, NULL, NULL,
                   nrM, ncM, offsetMrow, offsetMcol,
                   transpose, accum,
                   Y_M_thread);
#endif
}

//
// matrix_matrix_block_matrix_mult: computes mprod = m1*m2, i.e.
//     mprod(i,j) = sum_k { m1(i,k)*m2(k,j) }
// tile by tile with small_block_matrix_mult().
//
// The matrices are accessed through their row-major storage c[row*nc+col].
// Shapes are not checked here.  NOTE(review): callers are assumed to
// guarantee m1.nc == m2.nr and that mprod is m1.nr x m2.nc -- confirm.
// mprod must not alias m1 or m2: finished tiles are written to mprod
// while later tiles are still being read from the inputs.
//
void
matrix_matrix_block_matrix_mult(const Matrix &m1,
                                const Matrix &m2,
                                Matrix &mprod)
{
  // loop over BL1 x BL2 tiles of mprod
  for (int ib=0; ib < m1.nr; ib+=BL1)
    {
      int si = (m1.nr-ib) >= BL1 ? BL1 : m1.nr%BL1;

      for (int jb=0; jb < m2.nc; jb+=BL2)
        {
          int sj = (m2.nc-jb) >= BL2 ? BL2 : m2.nc%BL2;

          // Per-call tile buffers (as in the other routines of this file).
          // Function-static buffers would be shared by every concurrent or
          // re-entrant call and could be overwritten mid-multiply.
          scalar ini[BL1][BLK], inj[BL2][BLK], out[BL1][BL2];

          // zero output tile
          for (int i=0; i < si; i++)
            for (int j=0; j < sj; j++)
              out[i][j].x = out[i][j].y = 0.0;

          // loop over blocks of columns of m1 (i.e. rows of m2)
          for (int kb=0; kb < m1.nc; kb+=BLK)
            {
              int sk = (m1.nc-kb) >= BLK ? BLK : m1.nc%BLK;

              // ini[i][k] = m1(ib+i, kb+k); i outer so m1 is read along its rows
              for (int i=0; i < si; i++)
                for (int k=0; k < sk; k++)
                  {
                    ini[i][k].x = m1.c[(ib+i)*m1.nc+kb+k].x;
                    ini[i][k].y = m1.c[(ib+i)*m1.nc+kb+k].y;
                  }
              // inj[j][k] = m2(kb+k, jb+j): stored transposed because the
              // kernel takes its second operand as B2[j][k]
              for (int j=0; j < sj; j++)
                for (int k=0; k < sk; k++)
                  {
                    inj[j][k].x = m2.c[(kb+k)*m2.nc+jb+j].x;
                    inj[j][k].y = m2.c[(kb+k)*m2.nc+jb+j].y;
                  }

              // out[i][j] += sum_k { ini[i][k]*inj[j][k] }
              small_block_matrix_mult(si,sj,sk,BL1,BL2,BLK,
                                      (scalar *)ini,(scalar *)inj,(scalar *)out);
            }

          // write out tile; i outer so mprod is written along its rows
          for (int i=0; i < si; i++)
            for (int j=0; j < sj; j++)
              {
                mprod.c[(ib+i)*mprod.nc+jb+j].x = out[i][j].x;
                mprod.c[(ib+i)*mprod.nc+jb+j].y = out[i][j].y;
              }
        }
    }
}


// syntax highlighted by Code2HTML, v. 0.9.1