// Copyright (C) 1999 Jean-Marc Valin

#include "BufferedNode.h"
#include "Buffer.h"
#include "Vector.h"
#include <stdlib.h>
#include <math.h>

using namespace std;

namespace FD {

// Forward declaration: lets the DECLARE_NODE(NLMS) line below name the class
// before its full definition appears further down in this file.
class NLMS;

// NOTE(review): DECLARE_NODE is a project macro not shown here; presumably it
// registers NLMS with the FD node factory under the name used in the
// /*Node ... END*/ metadata block that follows.
DECLARE_NODE(NLMS)
/*Node
 *
 * @name NLMS
 * @category DSP:Adaptive
 * @description Normalized LMS algorithm
 *
 * @input_name INPUT
 * @input_description The input of the adaptive FIR filter
 *
 * @input_name REF
 * @input_description The reference (desired) signal; the filter coefficients adapt so that OUTPUT tracks it
 *
 * @output_name OUTPUT
 * @output_description The output of the adaptive FIR filter (not the residue)
 *
 * @parameter_name FILTER_LENGTH
 * @parameter_description Length of the adaptive FIR filter
 *
 * @parameter_name ALPHA
 * @parameter_description Adaptation rate of the filter coefficients
 *
 * @parameter_name BETA
 * @parameter_description Adaptation rate of the normalization energy estimate
 *
 * @parameter_name POWER
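 * @parameter_description Exponent applied to the input energy estimate when normalizing the coefficient update (update step = ALPHA*error/energy^POWER)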
 *
END*/


/* Adaptive FIR filter whose taps are updated with a normalized LMS rule.
   For every sample i of a frame:
      y[i]  = sum_j a[j] * x[i-j]                 (j = 0 .. size-1)
      err   = ref[i] - y[i]
      E     = (1-beta)*E + beta*x[i]^2           (smoothed input energy)
      a[j] += alpha*err/E^power * x[i-j]
   OUTPUT carries y (the filter output), not the residue err. */
class NLMS : public BufferedNode {

   int inputID;      // port index of INPUT (filter input x)
   int refID;        // port index of REF (desired signal)
   int outputID;     // port index of OUTPUT
   int size;         // FILTER_LENGTH: number of FIR taps
   Vector<float> a;  // adaptive taps; a[j] multiplies x[i-j]
   float alpha;      // ALPHA: coefficient step size
   float beta;       // BETA: smoothing factor of the energy estimate E
   float E;          // running, exponentially smoothed estimate of x^2
   float power;      // POWER: exponent applied to E in the normalization

   /* Returns the adaptive state to its starting point: all-zero taps and a
      small positive energy estimate. Shared by initialize() and reset(). */
   void clearAdaptiveState()
   {
      for (int j=0;j<size;j++)
         a[j] = 0;//1.0/size;
      E=1e-6;
   }

public:
   NLMS(string nodeName, ParameterSet params)
   : BufferedNode(nodeName, params)
   {
      // NOTE(review): presumably makes the framework compute frames in
      // increasing count order, which the recursive tap/energy update needs.
      inOrder = true;
      inputID = addInput("INPUT");
      refID = addInput("REF");
      outputID = addOutput("OUTPUT");

      size = dereference_cast<int> (parameters.get("FILTER_LENGTH"));
      alpha = dereference_cast<float> (parameters.get("ALPHA"));
      beta = dereference_cast<float> (parameters.get("BETA"));
      power = dereference_cast<float> (parameters.get("POWER"));
      // NOTE(review): FILTER_LENGTH is not validated; calculate() assumes size >= 1.
      a.resize(size,0.0);
      // calculate() reads INPUT frame count-1 to seed the filter history.
      inputsCache[inputID].lookBack=1;
   }

   void initialize()
   {
      clearAdaptiveState();
      BufferedNode::initialize();
   }

   void reset()
   {
      BufferedNode::reset();
      clearAdaptiveState();
   }

   /* Filters INPUT frame `count` with the current taps and adapts the taps
      sample by sample towards REF frame `count`. The result is stored in
      out[count]. */
   void calculate(int output_id, int count, Buffer &out)
   {
      ObjectRef inputValue = getInput(inputID, count);
      ObjectRef refValue = getInput(refID, count);

      const Vector<float> &in = object_cast<Vector<float> > (inputValue);
      const Vector<float> &ref = object_cast<Vector<float> > (refValue);
      int inputLength = in.size();
      // NOTE(review): REF is assumed to be at least inputLength samples long;
      // this is not checked.

      Vector<float> &output = *Vector<float>::alloc(inputLength);
      out[count] = &output;

      // Working buffer: size-1 samples of history followed by the current
      // frame, so x[i-j] is valid for 0 <= i < inputLength, 0 <= j < size.
      int histLength = size-1;
      DYN_VEC(float, inputLength+histLength, _x);
      float *x=_x+histLength;

      // History defaults to silence (first frame, or too few past samples).
      for (int i=0;i<histLength;i++)
         _x[i]=0.0;
      if (count > 0)
      {
         ObjectRef pastInputValue = getInput(inputID, count-1);
         const Vector<float> &past = object_cast<Vector<float> > (pastInputValue);
         // Right-align the tail of the previous frame against the current
         // frame. Index with the previous frame's own length so a frame of a
         // different (or too short) length is never read out of bounds.
         int pastLength = past.size();
         int copyLength = pastLength < histLength ? pastLength : histLength;
         for (int i=0;i<copyLength;i++)
            _x[histLength-copyLength+i] = past[pastLength-copyLength+i];
      }
      for (int i=0;i<inputLength;i++)
         x[i]=in[i];

      for (int i=0;i<inputLength;i++)
      {
         // FIR output with the taps as they stand before this sample's update.
         float y = 0;
         for (int j=0;j<size;j++)
            y += a[j]*x[i-j];
         output[i] = y;
         float err = ref[i]-y;

         E = (1-beta)*E + beta*x[i]*x[i];
         double denom = pow(E, power);
         // After a long run of exactly-zero input E underflows to 0; dividing
         // by it would yield inf/NaN and permanently poison every tap, so the
         // update is skipped for such samples (also catches a NaN energy).
         if (!(denom > 0))
            continue;
         float norm = alpha*err/denom;
         for (int j=0;j<size;j++)
            a[j] += norm*x[i-j];
      }
   }

};

}//namespace FD


// syntax highlighted by Code2HTML, v. 0.9.1