/* Copyright (C) 1999-2005 Andy Adler
 * conv2: 2D convolution for octave
 *
 * $Id: conv2.cc,v 1.12 2005/12/29 03:50:18 pkienzle Exp $
 * 
 * This is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the
 * Free Software Foundation; either version 2, or (at your option) any
 * later version.
 * 
 * This software is distributed in the hope that it will be useful, but WITHOUT
 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
 * for more details.
 * 
## 2000-05-17: Paul Kienzle
##    * change argument to vector conversion to work for 2.1 series octave
##      as well as 2.0 series
## 2001-02-05: Paul Kienzle
##    * accept complex arguments

 */

#include <octave/oct.h>
using namespace std;

// Larger of a and b.  This is a function-like macro, so the argument
// that wins the comparison is evaluated twice; pass side-effect-free
// expressions only.
#define MAX(a,b) ((a) > (b) ? (a) : (b))

// Selects how much of the convolution conv2 returns:
//   SHAPE_FULL  - the whole result (size of A plus size of the kernel minus 1)
//   SHAPE_SAME  - a part with the same size as the matrix A
//   SHAPE_VALID - size of A minus size of the kernel plus 1, never below 0
enum Shape { SHAPE_FULL, SHAPE_SAME, SHAPE_VALID };

// Declarations of the double and Complex versions of the separable
// conv2 (two vectors plus a matrix), emitted only when
// CXX_NEW_FRIEND_TEMPLATE_DECL is not defined.
// NOTE(review): presumably a workaround for compilers without that
// feature -- confirm against the build configuration.
#if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL)
extern MArray2<double>
conv2 (MArray<double>&, MArray<double>&, MArray2<double>&, Shape);

extern MArray2<Complex>
conv2 (MArray<Complex>&, MArray<Complex>&, MArray2<Complex>&, Shape);
#endif

/*
 * conv2 (R, C, A, ishape): separable 2-D convolution.
 *
 * Every column of A is convolved with C, and every row of that
 * intermediate result is convolved with R; in other words A is
 * filtered with the rank-one kernel C(:) * R(:).'.
 *
 *   R       taps applied along the rows of A (across the column index)
 *   C       taps applied along the columns of A (down the row index)
 *   A       matrix to filter; read through data() with element (i,j)
 *           at offset i + rows*j
 *   ishape  SHAPE_FULL   -> (Am + Cm - 1) x (An + Rn - 1)
 *           SHAPE_SAME   -> Am x An
 *           SHAPE_VALID  -> (Am - Cm + 1) x (An - Rn + 1)
 *           An extent that would be negative is clamped to 0.
 *
 * For any other ishape error() is called; if it returns, the result
 * is a 0x0 matrix.
 */
template <class T>
MArray2<T>
conv2 (MArray<T>& R, MArray<T>& C, MArray2<T>& A, Shape ishape)
{
  const int Rn = R.length ();
  const int Cm = C.length ();
  const int Am = A.rows ();
  const int An = A.columns ();

  /*
   * Size of the output matrix (Matlab compatible: in the separable
   * form it is derived from the matrix argument A).
   *
   * edgM / edgN is the output index at which the last tap of C / R
   * lines up with element 0 of A; for smaller output indices the
   * kernel overhangs the leading edge of A.
   */
  int outM = 0;
  int outN = 0;
  int edgM = 0;
  int edgN = 0;

  switch (ishape)
    {
      case SHAPE_FULL:
        outM = Am + Cm - 1;
        outN = An + Rn - 1;
        edgM = Cm - 1;
        edgN = Rn - 1;
        break;

      case SHAPE_SAME:
        outM = Am;
        outN = An;
        // Matlab convention for the "central" part: rounding the edge
        // down makes the output start ceil((len-1)/2) samples into
        // the full result.
        edgM = (Cm - 1) / 2;
        edgN = (Rn - 1) / 2;
        break;

      case SHAPE_VALID:
        outM = Am - Cm + 1;
        outN = An - Rn + 1;
        edgM = edgN = 0;
        break;

      default:
        error("conv2: invalid value of parameter ishape");
    }

  // Empty operands can drive an extent below zero for every shape
  // (e.g. FULL with Am == 0 and Cm == 0 gives -1); never ask for a
  // negative dimension.
  if (outM < 0) outM = 0;
  if (outN < 0) outN = 0;

  MArray2<T> O(outM,outN);

  // Raw storage is indexed directly.  Only base pointers and column
  // starts are formed, so no pointer ever refers outside its array,
  // even when a vector is empty.
  const T* Adata = A.data ();
  const T* Cdata = C.data ();
  const T* Rdata = R.data ();

  /*
   * X holds, for the current output row, the 1-D convolution of each
   * column of A with C.  It is then convolved with R to produce the
   * row of O.  There is no efficiency advantage to doing either
   * direction first.
   */
  MArray<T> X( An );

  for (int oi = 0; oi < outM; oi++)
    {
      // First tap of C and first row of A that take part in output
      // row oi; identical for every column.
      const int ci0 = Cm - 1 - MAX(0, edgM - oi);
      const int ai0 = MAX(0, oi - edgM);

      for (int oj = 0; oj < An; oj++)
        {
          const T* Acol = Adata + Am*oj;
          T sum = 0;

          for (int ci = ci0, ai = ai0; ci >= 0 && ai < Am; ci--, ai++)
            sum += Acol[ai] * Cdata[ci];

          X(oj) = sum;
        }

      // Fetched after X has been written for this row.
      const T* Xdata = X.data ();

      for (int oj = 0; oj < outN; oj++)
        {
          T sum = 0;

          for (int rj = Rn - 1 - MAX(0, edgN - oj),
                   aj = MAX(0, oj - edgN);
               rj >= 0 && aj < An;
               rj--, aj++)
            sum += Xdata[aj] * Rdata[rj];

          O(oi,oj) = sum;
        }
    }

  return O;
}

// Declarations of the double and Complex versions of the
// matrix-by-matrix conv2, emitted only when
// CXX_NEW_FRIEND_TEMPLATE_DECL is not defined.
// NOTE(review): presumably a workaround for compilers without that
// feature -- confirm against the build configuration.
#if !defined (CXX_NEW_FRIEND_TEMPLATE_DECL)
extern MArray2<double>
conv2 (MArray2<double>&, MArray2<double>&, Shape);

extern MArray2<Complex>
conv2 (MArray2<Complex>&, MArray2<Complex>&, Shape);
#endif

/*
 * conv2 (A, B, ishape): direct 2-D convolution of matrix A with
 * kernel B.
 *
 *   A       matrix to filter
 *   B       kernel
 *           (both are read through data() with element (i,j) at
 *           offset i + rows*j)
 *   ishape  SHAPE_FULL   -> (Am + Bm - 1) x (An + Bn - 1)
 *           SHAPE_SAME   -> Am x An
 *           SHAPE_VALID  -> (Am - Bm + 1) x (An - Bn + 1)
 *           An extent that would be negative is clamped to 0.
 *
 * Convolution works fastest if A is the larger of the two matrices.
 *
 * NOTE: in order to be Matlab compatible the output size is always
 * derived from the first argument, which gives arguably wrong sizes
 * for 'valid' when the smaller matrix is passed first.
 */
template <class T>
MArray2<T>
conv2 (MArray2<T>&A, MArray2<T>&B, Shape ishape)
{
  const int Am = A.rows ();
  const int An = A.columns ();
  const int Bm = B.rows ();
  const int Bn = B.columns ();

  // edgM / edgN is the output index at which the last row / column of
  // B lines up with row / column 0 of A.
  int outM = 0;
  int outN = 0;
  int edgM = 0;
  int edgN = 0;

  switch (ishape)
    {
      case SHAPE_FULL:
        outM = Am + Bm - 1;
        outN = An + Bn - 1;
        edgM = Bm - 1;
        edgN = Bn - 1;
        break;

      case SHAPE_SAME:
        outM = Am;
        outN = An;
        edgM = (Bm - 1) / 2;
        edgN = (Bn - 1) / 2;
        break;

      case SHAPE_VALID:
        outM = Am - Bm + 1;
        outN = An - Bn + 1;
        edgM = edgN = 0;
        break;
    }

  // Empty operands can drive an extent below zero for every shape
  // (e.g. FULL with two 0x0 matrices gives -1 x -1); never ask for a
  // negative dimension.
  if (outM < 0) outM = 0;
  if (outN < 0) outN = 0;

  MArray2<T> O(outM,outN);

  // Raw storage is indexed directly; the original author measured this
  // as roughly 2.5 x faster than  sum += A(ai,aj) * B(bi,bj).
  // Only base pointers and in-range column starts are formed, so no
  // pointer ever refers outside its array.
  const T* Adata = A.data ();
  const T* Bdata = B.data ();

  for (int oi = 0; oi < outM; oi++)
    {
      // First row of B and of A that take part in output row oi.
      const int bi0 = Bm - 1 - MAX(0, edgM - oi);
      const int ai0 = MAX(0, oi - edgM);

      for (int oj = 0; oj < outN; oj++)
        {
          T sum = 0;

          for (int bj = Bn - 1 - MAX(0, edgN - oj),
                   aj = MAX(0, oj - edgN);
               bj >= 0 && aj < An;
               bj--, aj++)
            {
              const T* Acol = Adata + Am*aj;
              const T* Bcol = Bdata + Bm*bj;

              for (int bi = bi0, ai = ai0; bi >= 0 && ai < Am; bi--, ai++)
                sum += Acol[ai] * Bcol[bi];
            }

          O(oi,oj) = sum;
        }
    }

  return O;
}

/*
%!test
%! b = [0,1,2,3;1,8,12,12;4,20,24,21;7,22,25,18];
%! assert(conv2([0,1;1,2],[1,2,3;4,5,6;7,8,9]),b);
%!test
%! ## separable form: v1 filters the columns, v2 filters the rows
%! assert(conv2([1;2],[1,2,3],1),[1,2,3;2,4,6]);
*/
DEFUN_DLD (conv2, args, ,
  "-*- texinfo -*-\n\
@deftypefn {Loadable Function} {y =} conv2 (@var{a}, @var{b}, @var{shape})\n\
@deftypefnx {Loadable Function} {y =} conv2 (@var{v1}, @var{v2}, @var{M}, @var{shape})\n\
\n\
Returns 2D convolution of @var{a} and @var{b} where the size\n\
of @var{y} is given by\n\
\n\
@table @asis\n\
@item @var{shape}= 'full'\n\
returns full 2-D convolution\n\
@item @var{shape}= 'same'\n\
same size as a. 'central' part of convolution\n\
@item @var{shape}= 'valid'\n\
only parts which do not include zero-padded edges\n\
@end table\n\
\n\
By default @var{shape} is 'full'. When the third argument is a matrix\n\
returns the convolution of the matrix @var{M} by the vector @var{v1}\n\
in the column direction and by vector @var{v2} in the row direction\n\
@end deftypefn")
{
  octave_value_list retval;
  int nargin = args.length ();
  string shape= "full"; //default
  bool separable= false;
  Shape ishape;

  /*
   * Work out which calling form is used:
   *   conv2 (a, b)            2-D, shape "full"
   *   conv2 (a, b, shape)     2-D, third argument is a string
   *   conv2 (v1, v2, M)       separable, third argument is not a string
   *   conv2 (v1, v2, M, shape)
   */
  if (nargin < 2 )
    {
     print_usage ("conv2");
     return retval;
    }
  else if (nargin == 3)
    {
      if ( args(2).is_string() )
        shape= args(2).string_value();
      else
        separable= true;
    }
  else if (nargin >= 4)
    {
      separable= true;
      shape= args(3).string_value();
    }

  if ( shape == "full" )
    ishape = SHAPE_FULL;
  else if ( shape == "same" )
    ishape = SHAPE_SAME;
  else if ( shape == "valid" )
    ishape = SHAPE_VALID;
  else
    {
      error("Shape type not valid");
      print_usage ("conv2");
      return retval;
    }

   if (separable)
     {

     /*
      * If user requests separable, check first two params are vectors
      */
       if (
         !( 1== args(0).rows() || 1== args(0).columns() )
          ||
         !( 1== args(1).rows() || 1== args(1).columns() ) )
         {
          print_usage ("conv2");
          return retval;
         }

     /*
      * The separable conv2 template takes the row-direction vector
      * first and the column-direction vector second, so v2 and v1 are
      * passed in that order: v1 filters the columns of M and v2 its
      * rows, as the help text states.
      */
       if (  args(0).is_complex_type() ||
             args(1).is_complex_type() ||
             args(2).is_complex_type() )
         {
           ComplexColumnVector v1 (args(0).complex_vector_value());
           ComplexColumnVector v2 (args(1).complex_vector_value());
           ComplexMatrix a (args(2).complex_matrix_value());
           ComplexMatrix c(conv2(v2, v1, a, ishape));
           retval(0) = c;
         }
       else
         {
           ColumnVector v1 (args(0).vector_value());
           ColumnVector v2 (args(1).vector_value());
           Matrix a (args(2).matrix_value());
           Matrix c(conv2(v2, v1, a, ishape));
           retval(0) = c;
         }
     } // if (separable)
   else
     {

       // The complex path is taken as soon as either operand is complex.
       if ( args(0).is_complex_type() ||
            args(1).is_complex_type())
         {
           ComplexMatrix a (args(0).complex_matrix_value());
           ComplexMatrix b (args(1).complex_matrix_value());
           ComplexMatrix c(conv2(a, b, ishape));
           retval(0) = c;
         }
       else
         {
           Matrix a (args(0).matrix_value());
           Matrix b (args(1).matrix_value());
           Matrix c(conv2(a, b, ishape));
           retval(0) = c;
         }

     } // else (not separable)

   return retval;
}


// Explicit instantiations of both conv2 templates (separable and
// matrix-by-matrix) for the element types double and Complex.
template MArray2<double>
conv2 (MArray<double>&, MArray<double>&, MArray2<double>&, Shape);

template MArray2<double>
conv2 (MArray2<double>&, MArray2<double>&, Shape);

template MArray2<Complex>
conv2 (MArray<Complex>&, MArray<Complex>&, MArray2<Complex>&, Shape);

template MArray2<Complex>
conv2 (MArray2<Complex>&, MArray2<Complex>&, Shape);


/* syntax highlighted by Code2HTML, v. 0.9.1 */