// Copyright (C) 1999 Jean-Marc Valin

#include "net_types.h"
#include "msvq.h"
#include "Vector.h"
#include "BufferedNode.h"
#include <sstream>

using namespace std;

namespace FD {

class MSVQTrain;
DECLARE_NODE(MSVQTrain)
/*Node
 *
 * @name MSVQTrain
 * @category VQ
 * @require MSVQ
 * @description Training of a multi-stage vector quantizer
 *
 * @input_name FRAMES
 * @input_description No description available
 *
 * @output_name OUTPUT
 * @output_description No description available
 *
 * @parameter_name STAGES
 * @parameter_description No description available
 *
 * @parameter_name BINARY
 * @parameter_description No description available
 *
END*/


class MSVQTrain : public BufferedNode {
      
protected:
   
   /**The ID of the 'output' output*/
   int outputID;
   
   /**The ID of the 'frames' input*/
   int framesInputID;
   
   /**Number of means to train model*/
   vector<int> stages;

public:
   
   MSVQTrain(string nodeName, ParameterSet params) 
      : BufferedNode(nodeName, params)
   { 
      try {

	 outputID = addOutput("OUTPUT");
	 framesInputID = addInput("FRAMES");
	 //cerr << "MSVQTrain initialization done\n";
         
         //stages = ObjectRef(new Vector<int>);
	 //Vector<int> &val = object_cast<Vector<int> > (stages);
	 stringstream str_vector(object_cast <String> (parameters.get("STAGES")).c_str());
	 str_vector >> stages;
	 

	 //nbMeans = dereference_cast<int> (parameters.get("MEANS"));
      } catch (BaseException *e)
      {
	 //e->print(cerr);
	 throw e->add(new NodeException(NULL, "Exception caught in MSVQTrain constructor", __FILE__, __LINE__));
      }
   }

   void initialize()
   {
      processCount=-1;
      this->Node::initialize();
   }

   void reset()
   {
      processCount=-1;
      this->Node::reset();
   }

   /**Ask for the node's output which ID (number) is output_id 
      and for the 'count' iteration */
   virtual void calculate(int output_id, int count, Buffer &out)
   {
      bool binary = false;
      if (parameters.exist("BINARY"))
         binary = dereference_cast<bool> (parameters.get("BINARY"));
      int i;
      NodeInput framesInput = inputs[framesInputID];
	    
      cerr << "getting frames..." << endl;
      ObjectRef matRef = framesInput.node->getOutput(framesInput.outputID,count);
      cerr << "got frames..." << endl;
      Vector<ObjectRef>  &mat = object_cast<Vector<ObjectRef> > (matRef);
	    
      MSVQ *vq = new MSVQ(stages);
	    
      vector <float *> data(mat.size());
      for (i=0;i<mat.size();i++)
         data[i]= &object_cast <Vector<float> > (mat[i])[0];
      int length =  object_cast <Vector<float> > (mat[0]).size();
	    
      cerr << "training..." << endl;
      vq->train(data,length,binary);
      cerr << "training complete." << endl;

      out[count] = ObjectRef(vq);
   }

};
}//namespace FD


syntax highlighted by Code2HTML, v. 0.9.1