// Copyright (C) 1999 Jean-Marc Valin #include "net_types.h" #include "msvq.h" #include "Vector.h" #include "BufferedNode.h" #include using namespace std; namespace FD { class MSVQTrain; DECLARE_NODE(MSVQTrain) /*Node * * @name MSVQTrain * @category VQ * @require MSVQ * @description Training of a multi-stage vector quantizer * * @input_name FRAMES * @input_description No description available * * @output_name OUTPUT * @output_description No description available * * @parameter_name STAGES * @parameter_description No description available * * @parameter_name BINARY * @parameter_description No description available * END*/ class MSVQTrain : public BufferedNode { protected: /**The ID of the 'output' output*/ int outputID; /**The ID of the 'frames' input*/ int framesInputID; /**Number of means to train model*/ vector stages; public: MSVQTrain(string nodeName, ParameterSet params) : BufferedNode(nodeName, params) { try { outputID = addOutput("OUTPUT"); framesInputID = addInput("FRAMES"); //cerr << "MSVQTrain initialization done\n"; //stages = ObjectRef(new Vector); //Vector &val = object_cast > (stages); stringstream str_vector(object_cast (parameters.get("STAGES")).c_str()); str_vector >> stages; //nbMeans = dereference_cast (parameters.get("MEANS")); } catch (BaseException *e) { //e->print(cerr); throw e->add(new NodeException(NULL, "Exception caught in MSVQTrain constructor", __FILE__, __LINE__)); } } void initialize() { processCount=-1; this->Node::initialize(); } void reset() { processCount=-1; this->Node::reset(); } /**Ask for the node's output which ID (number) is output_id and for the 'count' iteration */ virtual void calculate(int output_id, int count, Buffer &out) { bool binary = false; if (parameters.exist("BINARY")) binary = dereference_cast (parameters.get("BINARY")); int i; NodeInput framesInput = inputs[framesInputID]; cerr << "getting frames..." << endl; ObjectRef matRef = framesInput.node->getOutput(framesInput.outputID,count); cerr << "got frames..." << endl; Vector &mat = object_cast > (matRef); MSVQ *vq = new MSVQ(stages); vector data(mat.size()); for (i=0;i > (mat[i])[0]; int length = object_cast > (mat[0]).size(); cerr << "training..." << endl; vq->train(data,length,binary); cerr << "training complete." << endl; out[count] = ObjectRef(vq); } }; }//namespace FD