// Copyright (C) 2001 Jean-Marc Valin
#include "Cell.h"
#include <string>
#include "ObjectParser.h"
#include "misc.h"
#include <algorithm>
#ifdef HAVE_FLOAT_H
#include <float.h>
#endif
using namespace std;
namespace FD {
// Invokes the project's DECLARE_TYPE macro for Cell.
// NOTE(review): the macro is defined outside this file; presumably it hooks
// Cell into the object type registry used when parsing -- confirm.
DECLARE_TYPE(Cell)
//@implements VQ
void Cell::recursiveSplit (const vector<pair<int, float *> > &data, int level)
{
if (level <= 0)
//if (data.size() < 50)
{
cout << "LEAF: " << data.size() << endl;
return;
}
int dim;
float thresh;
//cerr << "aa\n";
split(data, dim, thresh);
//cerr << "bb\n";
vector<pair<int, float *> > firstData;
vector<pair<int, float *> > secondData;
for (int i=0;i<data.size();i++)
if (data[i].second[dim] < thresh)
{
//cerr << i << "(" << data[i].second[0] << "," << data[i].second[1] << ") goes to first\n";
firstData.insert(firstData.end(), data[i]);
} else {
//cerr << i << "(" << data[i].second[0] << "," << data[i].second[1] << ") goes to second\n";
secondData.insert(secondData.end(), data[i]);
}
splitDimension = dim;
threshold = thresh;
//cout << dimension << endl;
first = new Cell (dimension, numberClasses);
second = new Cell (dimension, numberClasses);
terminal = false;
first->recursiveSplit(firstData, level-1);
second->recursiveSplit(secondData, level-1);
}
void Cell::split(const vector<pair<int, float *> > &data, int &bestDim, float &bestThreshold)
{
bestDim=0;
int nbEqual=0;
bestThreshold=0;
float bestMutual = -FLT_MAX;
for (int i=0;i<dimension;i++)
{
//cerr << "dim " << i << endl;
float threshold;
float currentMutual;
findThreshold(data, i, threshold, currentMutual);
//cerr << "threshold: " << threshold << " currentMutual: " << currentMutual << endl;
bool isBest = false;
if (currentMutual > bestMutual)
{
isBest=true;
nbEqual=0;
}
if (currentMutual == bestMutual)
{
cerr << "randomizing at " << currentMutual << " size = " << data.size() << "\n";
nbEqual++;
if (rand()%nbEqual==0)
isBest=true;
}
if (isBest)
{
bestMutual=currentMutual;
bestDim=i;
bestThreshold=threshold;
}
}
//cerr << "bestDim: " << bestDim << " bestThreshold: " << bestThreshold << endl;
//if (some condition on bestMutual) don't perform the split
//splitWithThreshold(data, bestDim, bestThreshold);
}
/*void Cell::findThreshold(const vector<pair<int, float *> > &data, int dim, float &bestThresh, float &bestScore)
{
if (data.size()==0)
{
bestThresh=0;
bestScore = 0;
return;
}
float min_value = FLT_MAX, max_value = -FLT_MAX;
int min_ind = 0, max_ind = 0;
int i,k;
for (i=0;i<data.size();i++)
{
if (data[i].second[dim] > max_value)
{
max_value = data[i].second[dim];
max_ind=i;
}
if (data[i].second[dim] < min_value)
{
min_value = data[i].second[dim];
min_ind=i;
}
}
bestThresh = 0;
bestScore = -FLT_MAX;
float thresh;
float score;
for (thresh = min_value; thresh < max_value; thresh += (max_value-min_value)/15.0)
{
int sumAi = 0, sumBi = 0;
vector<int> Ai (numberClasses, 0);
vector<int> Bi (numberClasses, 0);
for (k=0;k<data.size();k++)
{
if (data[k].second[dim] >= thresh)
{
sumAi++;
Ai[data[k].first]++;
} else {
sumBi++;
Bi[data[k].first]++;
}
}
double weight = double(sumAi)/data.size();
//cerr << "weight: " << weight << " sumAi = " << sumAi << endl;
score = 0.0;
for (i = 0;i<numberClasses;i++)
{
//cerr << "A[" << i << "] = " << Ai[i] << "\t" << "Ai[i] / sumAi = " << (double( Ai[i] ) / sumAi ) << "\t";
//cerr << "B[" << i << "] = " << Bi[i] << "\t" << "Bi[i] / sumBi = " << (double( Bi[i] ) / sumBi ) << "\t";
if (sumAi)
score -= weight * entropy_funct (double( Ai[i] ) / sumAi );
if (sumBi)
score -= (1-weight) * entropy_funct (double( Bi[i] ) / sumBi );
//cerr << "score = " << score << endl;
}
cerr << "got " << score << " for threshold " << thresh << endl;
if (score > bestScore)
{
bestThresh = thresh;
bestScore = score;
}
}
}*/
/*
void Cell::findThreshold(const vector<pair<int, float *> > &data, int dim, float &bestThresh, float &bestScore)
{
float sum = 0, s2 = 0;
int i,k;
for (i=0;i<data.size();i++)
{
sum += data[i].second[dim];
s2+= sqr(data[i].second[dim]);
}
if (data.size()<=1)
{
bestThresh=0;
bestScore=0;
return;
}
sum /= data.size();
s2=sqrt(s2/data.size() - sqr(sum) );
//cerr << "s2 = " << s2 << " N = " << data.size() << endl;
float min_value = sum - 1.*s2;
float max_value = sum + 1.*s2;
//thresh=sum/data.size();
//if (data.size()==0) thresh=0;
bestThresh = 0;
bestScore = -FLT_MAX;
float thresh;
float score;
for (thresh = min_value; thresh < max_value; thresh += (max_value-min_value)/15.0)
{
int sumAi = 0, sumBi = 0;
vector<int> Ai (numberClasses, 0);
vector<int> Bi (numberClasses, 0);
for (k=0;k<data.size();k++)
{
if (data[k].second[dim] >= thresh)
{
sumAi++;
Ai[data[k].first]++;
} else {
sumBi++;
Bi[data[k].first]++;
}
}
double weight = double(sumAi)/data.size();
score = - numberClasses * .01*abs(thresh-sum)/s2;
//score = 0;
for (i = 0;i<numberClasses;i++)
{
score += - weight * entropy_funct (double( Ai[i] ) / sumAi )
- (1-weight) * entropy_funct (double( Bi[i] ) / sumBi );
}
if (score > bestScore)
{
bestThresh = thresh;
bestScore = score;
}
}
}
*/
// qsort()-compatible comparator for two floats.
//
// a, b : pointers to the float values to compare.
// Returns a negative value if *a < *b, zero if they compare equal and a
// positive value if *a > *b, giving an ascending order when passed to qsort().
//
// The earlier version returned the boolean (*a < *b), which is never negative
// and therefore does not satisfy the three-way contract qsort() relies on.
static int float_less(const void *a, const void *b)
{
   const float lhs = *static_cast<const float *>(a);
   const float rhs = *static_cast<const float *>(b);
   return (lhs > rhs) - (lhs < rhs);
}
//find threshold using split at median and mutual information
void Cell::findThreshold(const vector<pair<int, float *> > &data, int dim, float &thresh, float &score)
{
float sum = 0;
int i,k;
if (data.size()==0) thresh=0;
else {
//float sorted[data.size()];
float *sorted = new float [data.size()];
for (i=0;i<data.size();i++)
sorted[i] = data[i].second[dim];
//qsort(sorted,data.size(),sizeof(float), float_less);
sort (sorted,sorted+data.size());
thresh=sorted[data.size()/2];
delete [] sorted;
}
int sumAi = 0, sumBi = 0;
vector<int> Ai (numberClasses, 0);
vector<int> Bi (numberClasses, 0);
for (k=0;k<data.size();k++)
{
if (data[k].second[dim] >= thresh)
{
sumAi++;
Ai[data[k].first]++;
} else {
sumBi++;
Bi[data[k].first]++;
}
}
double weight = double(sumAi)/data.size();
score = 0.0;
for (i = 0;i<numberClasses;i++)
{
score += - weight * entropy_funct (double( Ai[i] ) / sumAi )
- (1-weight) * entropy_funct (double( Bi[i] ) / sumBi );
}
//cerr << score << " " << sumAi << " " << sumBi << " " << weight << " " << Ai[0] << " " << Ai[1] << " " << Bi[0] << " " << Bi[1] << endl;
}
//find threshold using split at average and mutual information
/*void Cell::findThreshold(const vector<pair<int, float *> > &data, int dim, float &thresh, float &score)
{
float sum = 0;
int i,k;
for (i=0;i<data.size();i++)
sum += data[i].second[dim];
thresh=sum/data.size();
if (data.size()==0) thresh=0;
int sumAi = 0, sumBi = 0;
vector<int> Ai (numberClasses, 0);
vector<int> Bi (numberClasses, 0);
for (k=0;k<data.size();k++)
{
if (data[k].second[dim] >= thresh)
{
sumAi++;
Ai[data[k].first]++;
} else {
sumBi++;
Bi[data[k].first]++;
}
}
double weight = double(sumAi)/data.size();
score = 0.0;
for (i = 0;i<numberClasses;i++)
{
score += - weight * entropy_funct (double( Ai[i] ) / sumAi )
- (1-weight) * entropy_funct (double( Bi[i] ) / sumBi );
}
}
*/
/*void Cell::findThreshold(const vector<pair<int, float *> > &data, int dim, float &thresh, float &score)
{
float sum = 0;
int i,k;
for (i=0;i<data.size();i++)
sum += data[i].second[dim];
thresh=sum/data.size();
if (data.size()==0) thresh=0;
vector<int> scores (numberClasses, 0);
for (i=0;i<data.size();i++)
if (data[i].second[dim] >= thresh)
{
scores[data[i].first]++;
}
else
{
scores[data[i].first]--;
}
score = 0.0;
for (i=0;i<numberClasses;i++)
score += abs(scores[i]);
}*/
/*
void Cell::findThreshold(const vector<pair<int, float *> > &data, int dim, float &bestThresh, float &bestScore)
{
float sum = 0, s2 = 0;
int i,k;
for (i=0;i<data.size();i++)
{
sum += data[i].second[dim];
s2+= sqr(data[i].second[dim]);
}
if (data.size()<=1)
{
bestThresh=0;
bestScore=0;
return;
}
sum /= data.size();
s2=sqrt(s2/data.size() - sqr(sum) );
//cerr << "s2 = " << s2 << " N = " << data.size() << endl;
float min_value = sum - 1.5*s2;
float max_value = sum + 1.5*s2;
//thresh=sum/data.size();
//if (data.size()==0) thresh=0;
bestThresh = 0;
bestScore = -FLT_MAX;
float thresh;
float score;
for (thresh = min_value; thresh < max_value; thresh += (max_value-min_value)/15.0)
{
int diff=0;
//vector<int> scores (numberClasses, 0);
int binA[numberClasses];
int binB[numberClasses];
for (i=0;i<numberClasses;i++)
{
binA[i]=0;
binB[i]=0;
}
for (i=0;i<data.size();i++)
if (data[i].second[dim] >= thresh)
{
diff++;
binA[data[i].first]++;
}
else
{
diff--;
binB[data[i].first]++;
}
score = 0.0;
for (i=0;i<numberClasses;i++)
score += abs(binA[i]-binB[i])/(binA[i]+binB[i]+15);
score -= .3*numberClasses*abs(diff)/data.size();
//cerr << .49*numberClasses*abs(diff)/data.size() << endl;
if (score > bestScore)
{
bestThresh = thresh;
bestScore = score;
}
}
}
*/
// Assigns consecutive identifiers to the terminal cells of this subtree.
//
// start : identifier to give to the first terminal cell encountered.
// Returns the next identifier that has not been used.
//
// The subtree under `first` is numbered before the one under `second`.
int Cell::setNumbering(int start)
{
   if (!terminal)
   {
      // The second subtree continues where the first one stopped.
      const int afterFirst = first->setNumbering(start);
      return second->setNumbering(afterFirst);
   }

   cellID = start;
   return start + 1;
}
// Returns the cellID of the terminal cell that `vect` falls into.
//
// vect : feature vector; it is indexed with the splitDimension of every
//        interior cell visited on the way down.
int Cell::belongs(float *vect) const
{
   // A terminal cell answers with its own identifier.
   if (terminal)
      return cellID;

   // Interior cell: values strictly below the threshold follow `first`,
   // all others follow `second`.
   const bool goesFirst = vect[splitDimension] < threshold;
   return goesFirst ? first->belongs(vect) : second->belongs(vect);
}
// Accumulates, for every feature vector, a hit on the cell it belongs to.
//
// features : feature vectors to classify with belongs().
// templ    : [in/out] counters indexed by the value belongs() returns; the
//            matching entry is incremented once per feature vector. No bounds
//            check is done, so it must already be large enough.
void Cell::calcTemplate (const vector<float *> &features, vector<int> &templ) const
{
   for (size_t idx = 0; idx < features.size(); ++idx)
   {
      const int leaf = belongs(features[idx]);
      ++templ[leaf];
   }
}
// Writes this cell to `out` as a "<Cell ... >" record.
//
// dimension, numberClasses and terminal are always written. A terminal cell
// then writes its cellID; an interior cell writes threshold, splitDimension
// and both children, which are printed through operator<<.
void Cell::printOn(ostream &out) const
{
   out << "<Cell " << endl
       << "<dimension " << dimension << ">" << endl
       << "<numberClasses " << numberClasses << ">" << endl
       << "<terminal " << terminal << ">" << endl;

   if (!terminal)
   {
      out << "<threshold " << threshold << ">" << endl
          << "<splitDimension " << splitDimension << ">" << endl;
      out << "<first " << *first << ">" << endl;
      out << "<second " << *second << ">" << endl;
   }
   else
   {
      out << "<cellID " << cellID << ">" << endl;
   }

   out << ">\n";
}
// Stream insertion for Cell: forwards to Cell::printOn() and returns the
// stream so that insertions can be chained.
ostream &operator << (ostream &out, const Cell &cell)
{
   cell.printOn(out);

   return out;
}
// Reads the fields of this cell from `in`.
//
// The input is a sequence of "<tag value>" entries closed by a single '>'.
// Recognised tags are dimension, numberClasses, terminal, cellID, threshold,
// splitDimension, first and second; the last two hold a nested Cell that is
// allocated here and read through operator>>.
//
// Throws a ParsingException (by pointer, as elsewhere in this function) when
// the input ends early, a tag is unknown, a value cannot be read or the
// closing '>' of an entry is missing.
void Cell::readFrom (istream &in)
{
   string tag;
   while (1)
   {
      // First character of an entry, or the '>' that closes the cell.
      // It is initialised and the stream is checked so that a failed read
      // is reported instead of testing an indeterminate character.
      char ch = 0;
      in >> ch;
      if (!in) throw new ParsingException ("Cell::readFrom : Parse error: unexpected end of input");
      if (ch == '>') break;

      in >> tag;
      if (tag == "dimension")
         in >> dimension;
      else if (tag == "numberClasses")
         in >> numberClasses;
      else if (tag == "terminal")
         in >> terminal;
      else if (tag == "cellID")
         in >> cellID;
      else if (tag == "threshold")
         in >> threshold;
      else if (tag == "splitDimension")
         in >> splitDimension;
      else if (tag == "first")
      {
         // NOTE(review): if reading the child throws, the new Cell is not
         // released; Cell's ownership rules are not visible here -- confirm.
         Cell *tmp = new Cell;
         in >> *tmp;
         first = tmp;
      }
      else if (tag == "second")
      {
         Cell *tmp = new Cell;
         in >> *tmp;
         second = tmp;
      } else
         throw new ParsingException ("Cell::readFrom : unknown argument: " + tag);

      if (!in) throw new ParsingException ("Cell::readFrom : Parse error trying to build " + tag);

      // Every entry must be closed by a stand-alone ">" token.
      in >> tag;
      if (tag != ">") throw new ParsingException ("Cell::readFrom : Parse error: '>' expected ");
   }
}
// Stream extraction for Cell.
//
// Cell::readFrom() is only invoked when isValidType(in, "Cell") succeeds;
// the stream is returned in both cases.
// NOTE(review): isValidType is defined elsewhere; presumably it consumes the
// leading "<Cell" marker -- confirm.
istream &operator >> (istream &in, Cell &cell)
{
   if (isValidType(in, "Cell"))
      cell.readFrom(in);

   return in;
}
}//namespace FD
// syntax highlighted by Code2HTML, v. 0.9.1