// Copyright (C) 1999 Jean-Marc Valin
#include "BufferedNode.h"
#include "ObjectRef.h"
#include "kmeans.h"
#include "Vector.h"
#include "CodebookMap.h"
using namespace std;
namespace FD {
// Forward declaration so the class name can be handed to DECLARE_NODE before
// the class body appears further down.
class CMTrain;
// DECLARE_NODE is a macro defined outside this file; presumably it registers
// CMTrain with the framework's node factory -- TODO confirm against its definition.
DECLARE_NODE(CMTrain)
/*Node
*
* @name CMTrain
* @category VQ
* @require CMap
* @description Trains a codebook map
*
* @input_name TRAIN_IN
* @input_type Vector<ObjectRef>
* @input_description Input feature accumulator
*
* @input_name TRAIN_OUT
* @input_type Vector<ObjectRef>
* @input_description Output feature accumulator
*
* @input_name VQ
* @input_type KMeans
* @input_description Already trained vector quantizer
*
* @output_name OUTPUT
* @output_type CodebookMap
* @output_description Resulting codebook map
*
END*/
class CMTrain : public BufferedNode {
protected:
/**The ID of the 'trainIN' input*/
int trainInID;
/**The ID of the 'trainOut' input*/
int trainOutID;
/**The ID of the 'output' output*/
int outputID;
/**The ID of the 'nnet' input*/
int netInputID;
public:
/**Constructor, takes the name of the node and a set of parameters*/
CMTrain(string nodeName, ParameterSet params)
: BufferedNode(nodeName, params)
{
outputID = addOutput("OUTPUT");
netInputID = addInput("VQ");
trainInID = addInput("TRAIN_IN");
trainOutID = addInput("TRAIN_OUT");
}
/**Class specific initialization routine.
Each class will call its subclass initialize() method*/
virtual void initialize()
{
processCount=-1;
NodeInput trainInInput = inputs[trainInID];
//cerr << "in name = " << trainInInput.outputID << endl ;
NodeInput trainOutInput = inputs[trainOutID];
//cerr << "out name = " << trainOutInput.outputID << endl;
this->Node::initialize();
}
/**Class reset routine.
Each class will call its superclass reset() method*/
virtual void reset()
{
processCount=-1;
this->Node::reset();
}
/**Ask for the node's output which ID (number) is output_id
and for the 'count' iteration */
virtual void calculate(int output_id, int count, Buffer &out)
{
int i,j;
NodeInput trainInInput = inputs[trainInID];
ObjectRef trainInValue = trainInInput.node->getOutput(trainInInput.outputID,count);
NodeInput trainOutInput = inputs[trainOutID];
ObjectRef trainOutValue = trainOutInput.node->getOutput(trainOutInput.outputID,count);
NodeInput netInput = inputs[netInputID];
ObjectRef netValue = netInput.node->getOutput(netInput.outputID,count);
//cerr << "inputs calculated\n";
Vector<ObjectRef> &inBuff = object_cast<Vector<ObjectRef> > (trainInValue);
Vector<ObjectRef> &outBuff = object_cast<Vector<ObjectRef> > (trainOutValue);
//cerr << "inputs converted\n";
vector <float *> in(inBuff.size());
for (i=0;i<inBuff.size();i++)
in[i]=&object_cast <Vector<float> > (inBuff[i])[0];
vector <float *> vout(outBuff.size());
for (i=0;i<outBuff.size();i++)
vout[i]=&object_cast <Vector<float> > (outBuff[i])[0];
//FFNet *net = new FFNet( topo );
RCPtr<VQ> vq = netValue;
out[count] = ObjectRef(new CodebookMap(vq,in,vout,object_cast <Vector<float> > (outBuff[0]).size()));
}
};
}//namespace FD
// syntax highlighted by Code2HTML, v. 0.9.1