// Copyright (C) 1999 Jean-Marc Valin
#include "kmeans.h"
#include "ObjectParser.h"
using namespace std;
namespace FD {
// NOTE(review): DECLARE_TYPE is defined outside this file; presumably it
// registers KMeans with the object-parsing type registry -- confirm.
DECLARE_TYPE(KMeans)
//@implements VQ
void KMeans::split (const vector<float *> &data, int len)
{
int nbMeans = means.size();
#ifdef STACK_ALLOC
float totalDist[nbMeans];
int belongs[data.size()];
int accum[data.size()];
#else
float *totalDist = new float [nbMeans];
int *belongs = new int [data.size()];
int *accum = new int [data.size()];
#endif
int i;
for (i=0; i<nbMeans;i++)
totalDist[i]=0;
for (i=0; i<nbMeans;i++)
accum[i]=0;
for (i=0; i<data.size();i++)
{
float tmp;
belongs[i] = getClassID(data[i], &tmp);
totalDist[belongs[i]] += tmp;
}
float max_dist = 0;
int maxID=0;
for (i=0; i<nbMeans;i++)
if (totalDist[i] > max_dist)
{
max_dist=totalDist[i];
maxID=i;
}
/* for (i=0; i<nbMeans;i++)
if (totalDist[i]/accum[i] > max_dist)
{
max_dist=totalDist[i]/accum[i];
maxID=i;
}
*/
/*cerr << "about to perform split\n";
cerr << "nbMeans = " << nbMeans << endl;
cerr << "length = " << length << endl;
cerr << "maxID = " << maxID << endl;*/
means.resize(nbMeans+1);
means[nbMeans].resize(length);
for (i=0; i<length;i++)
{
float factor = .99 + ((rand() % 2000) *.00001);
//factor = 1.01;
means[nbMeans][i]=means[maxID][i]*factor;
//cerr << means[nbMeans][i] << " " << means[maxID][i] << endl;
}
nbMeans++;
#ifndef STACK_ALLOC
delete [] totalDist;
delete [] belongs;
delete [] accum;
#endif
}
/**
 * Read-only access to one centroid.
 *
 * @param i index into means; it is passed straight to means[] without any
 *          range check in this function
 * @return reference to the i-th centroid
 */
const vector<float> &KMeans::operator[] (int i) const
{
   const vector<float> &centroid = means[i];
   return centroid;
}
void KMeans::bsplit ()
{
int nbMeans = means.size();
int i;
means.resize(nbMeans*2);
for (i=nbMeans;i<nbMeans*2;i++)
{
means[i].resize(length);
for (int j=0; j<length;j++)
{
float factor = .99 + ((rand() % 2000) *.00001);
//factor = 1.01;
means[i][j]=means[i-nbMeans][j]*factor;
//cerr << means[nbMeans][i] << " " << means[maxID][i] << endl;
}
}
nbMeans*=2;
}
void KMeans::update (const vector<float *> &data, int len)
{
int nbMeans = means.size();
#ifdef STACK_ALLOC
float totalDist[nbMeans];
int belongs[data.size()];
int accum[data.size()];
#else
float *totalDist = new float [nbMeans];
int *belongs = new int [data.size()];
int *accum = new int [data.size()];
#endif
int i,j;
for (i=0; i<nbMeans;i++)
totalDist[i]=0;
for (i=0; i<nbMeans;i++)
accum[i]=0;
for (i=0; i<data.size();i++)
{
float tmp;
belongs[i] = getClassID(data[i], &tmp);
totalDist[belongs[i]] += tmp;
}
for (i=0;i<nbMeans;i++)
for (j=0;j<length;j++)
means[i][j]=0;
for (i=0; i<data.size();i++)
{
int meanID=belongs[i];
//cerr << "meanID = " << meanID << endl;
accum[meanID]++;
float *theMean = &means[meanID][0];
float *theData = data[i];
float *end = theData+length;
while (theData < end-3)
{
*theMean++ += *theData++;
*theMean++ += *theData++;
*theMean++ += *theData++;
*theMean++ += *theData++;
}
while (theData<end)
*theMean++ += *theData++;
/*for (j=0;j<length;j++)
means[meanID][j] += data[i][j];*/
}
for (i=0; i<nbMeans;i++)
{
if (accum[i]==0)
{
cerr << "empty vector " << i << "\n";
int id = rand()%data.size();
for (j=0;j<length;j++)
means[i][j] = data[id][j];
} else {
float accum_1 = 1.0/accum[i];
//cerr << "mean " << i << ": (accum = " << accum[i] << ") ";
for (j=0;j<length;j++)
{
//cerr << means[i][j] << " ";
means[i][j] *= accum_1;
//cerr << means[i][j] << " ";
}
//cerr << endl;
}
}
#ifndef STACK_ALLOC
delete [] totalDist;
delete [] belongs;
delete [] accum;
#endif
}
void KMeans::train (int codeSize, const vector<float *> &data, int len, bool binary)
{
int i,j;
//cerr << "void KMeans::train (" << codeSize << ", " << data << ", "<<len <<")" << endl;
length=len;
means.resize(1);
means[0].resize(length);
//accum.resize(1);
for (i=0;i<length;i++)
means[0][i] = 0;
for (i=0;i<data.size();i++)
for (j=0;j<length;j++)
means[0][j] += data[i][j];
//accum[0]=data.size();
for (j=0;j<length;j++)
means[0][j] /= data.size();
int splitID=0;
//cerr << "init done..." << endl;
if (binary)
{
for (i=0;i<codeSize;i++)
{
bsplit ();
for (j=0;j<10;j++)
update(data, len);
}
for (j=0;j<30;j++)
update(data, len);
} else {
for (i=1;i<codeSize;i++)
{
cerr << "iter " << i << endl;
split (data, len);
for (j=0;j<4;j++)
update(data, len);
}
for (j=0;j<30;j++)
update(data, len);
}
}
/**
 * Finds the centroid with the smallest distance to v.
 *
 * Distances are computed with the member `dist`; when it equals `euclidian`
 * that function is called by name instead of through `dist`.  On ties the
 * lowest index wins.
 *
 * @param v           input vector, read for `length` components
 * @param dist_return if non-null, receives the smallest distance found
 * @return index of the selected centroid
 */
int KMeans::getClassID (const float *v, float *dist_return) const
{
   // Centroid 0 is the initial candidate.
   int best = 0;
   float bestDist = dist(&means[0][0], v, length);

   const int nbMeans = means.size();
   for (int k = 1; k < nbMeans; k++)
   {
      const float candidate = (dist==euclidian)
         ? euclidian(&means[k][0], v, length)
         : dist(&means[k][0], v, length);
      if (candidate < bestDist)
      {
         bestDist = candidate;
         best = k;
      }
   }

   if (dist_return)
      *dist_return = bestDist;
   return best;
}
void KMeans::calcDist (const float *v, float *dist_return) const
{
for (int i=0;i<means.size();i++)
{
if (dist==euclidian)
dist_return[i] = euclidian(&means[i][0], v, length);
else
dist_return[i] = dist(&means[i][0], v, length);
}
}
/**
 * Computes the weighted average of the centroids.
 *
 * out[j] is set to sum_i (w[i] / sum(w)) * means[i][j].  The weights are
 * normalized by their sum; no check is made that the sum is non-zero.
 *
 * A size mismatch between w/out and the codebook is reported on cerr.  If
 * continuing would index past the end of w or of a centroid (w shorter than
 * the codebook, or out longer than a centroid), the function returns without
 * modifying out.
 *
 * @param w   one weight per centroid
 * @param out receives the weighted mean; expected to have the centroid
 *            dimension
 */
void KMeans::weightMeans (const Vector<float> &w, Vector<float> &out) const
{
   const size_t nbMeans = means.size();
   // An empty codebook has no centroid 0 to take the dimension from.
   const size_t dim = means.empty() ? 0 : means[0].size();

   if ( !(w.size() == nbMeans && out.size() == dim) )
   {
      cerr << "sizes don't match in KMeans::weightMeans\n";
      cerr << w.size() << " "
           << nbMeans << " "
           << out.size() << " "
           << dim << endl;
      // Only give up when the loops below would read out of bounds.
      if (w.size() < nbMeans || out.size() > dim)
         return;
   }

   for (size_t j = 0; j < out.size(); j++)
      out[j] = 0;

   float sum = 0;
   for (size_t i = 0; i < nbMeans; i++)
      sum += w[i];

   float norm = 1.0/sum;
   for (size_t i = 0; i < nbMeans; i++)
   {
      float scale = norm*w[i];
      for (size_t j = 0; j < out.size(); j++)
         out[j] += scale*means[i][j];
   }
}
/**
 * Writes the codebook to out in the tagged text form
 * "<KMeans <means ...> <length ...> >".
 *
 * @param out destination stream
 */
void KMeans::printOn(ostream &out) const
{
   out << "<KMeans " << endl
       << "<means " << means << ">" << endl
       << "<length " << length << ">" << endl
       << ">\n";
}
/**
 * Reads the fields of a KMeans object from in.
 *
 * The input is a sequence of "<tag value >" entries terminated by a lone
 * '>'.  Recognized tags are "length" and "means"; each value must be followed
 * by a whitespace-separated ">" token.
 *
 * @param in source stream
 * @throws ParsingException* (allocated with new) on an unexpected character,
 *         an unknown tag, a value that fails to parse, or a missing ">"
 */
void KMeans::readFrom (istream &in)
{
   string tag;
   while (1)
   {
      // Initialized because a failed extraction (EOF, bad stream) leaves ch
      // untouched; '\0' then falls through to the "'<' expected" error
      // instead of comparing an indeterminate value.
      char ch = '\0';
      in >> ch;
      if (ch == '>') break;
      else if (ch != '<')
         throw new ParsingException ("KMeans::readFrom : Parse error: '<' expected");
      in >> tag;
      if (tag == "length")
         in >> length;
      else if (tag == "means")
         in >> means;
      else
         throw new ParsingException ("KMeans::readFrom : unknown argument: " + tag);
      if (!in) throw new ParsingException ("KMeans::readFrom : Parse error trying to build " + tag);
      // Each field is closed by its own ">" token.
      in >> tag;
      if (tag != ">")
         throw new ParsingException ("KMeans::readFrom : Parse error: '>' expected ");
   }
}
/**
 * Stream extraction for KMeans.
 *
 * mdl.readFrom() is invoked only when isValidType(in, "KMeans") succeeds;
 * otherwise mdl is left untouched.
 *
 * @return in, to allow chaining
 */
istream &operator >> (istream &in, KMeans &mdl)
{
   if (isValidType(in, "KMeans"))
      mdl.readFrom(in);
   return in;
}
}//namespace FD
// syntax highlighted by Code2HTML, v. 0.9.1