/*---------------------------------------------------------------------------*
 *                                   IT++			             *
 *---------------------------------------------------------------------------*
 * Copyright (c) 1995-2004 by Tony Ottosson, Thomas Eriksson, Pål Frenger,   *
 * Tobias Ringström, and Jonas Samuelsson.                                   *
 *                                                                           *
 * Permission to use, copy, modify, and distribute this software and its     *
 * documentation under the terms of the GNU General Public License is hereby *
 * granted. No representations are made about the suitability of this        *
 * software for any purpose. It is provided "as is" without expressed or     *
 * implied warranty. See the GNU General Public License for more details.    *
 *---------------------------------------------------------------------------*/

/*! 
  \file 
  \brief Gaussian mixture model

  \version 1.3

  \date 2004/09/28 07:21:31
*/


#include "base/random.h"
#include "base/timing.h"
#include "srccode/vqtrain.h"
#include "base/matfunc.h"
#include "base/specmat.h"
#include "srccode/gmm.h"
#include <string>
#include <fstream>

using std::ifstream;
using std::ofstream;
using std::cout;
using std::endl;


namespace itpp {

  //! Construct an empty model: no mixture components and zero dimension.
  GMM::GMM()
  {
    M = 0;
    d = 0;
  }

  //! Construct a model by reading it from \a filename (see GMM::load for the format).
  GMM::GMM(string filename)
  {
    this->load(filename);
  }

  /*!
    Construct a placeholder model with \a M_in components of dimension \a d_in.

    Means and variances are all zero and every component gets weight 1/M_in.
    NOTE(review): with zero variances compute_internals() produces infinite
    normalisation terms, so callers are expected to set the covariances
    (set_covariance) before evaluating likelihoods.
  */
  GMM::GMM(int M_in, int d_in)
  {
    d=d_in;
    M=M_in;

    const int n_params=M*d;
    m=zeros(n_params);
    sigma=zeros(n_params);

    // Uniform prior weights.
    w=1./M*ones(M);

    compute_internals();
  }

  void GMM::init_from_vq(const vec &codebook, int dim)
  {

    mat		C(dim,dim);
    int		i;
    vec		v;

    d=dim;
    M=codebook.length()/dim;

    m=codebook;
    w=ones(M)/double(M);

    C.clear();
    for (i=0;i<M;i++) {
      v=codebook.mid(i*d,d);
      C=C+outer_product(v,v);
    }
    C=1./M*C;
    sigma.set_length(M*d);
    for (i=0;i<M;i++) {
      sigma.replace_mid(i*d,diag(C));
    }

    compute_internals();
  }

  //void GMM::init(const vec &m_in, const vec &sigma_in, const vec &w_in)
  //{
  //	m=m_in;
  //	sigma=sigma_in;
  //	w=w_in;
  //
  //	compute_internals();
  //}
  void GMM::init(const vec &w_in, const mat &m_in, const mat &sigma_in)
  {
    int		i,j;
    d=m_in.rows();
    M=m_in.cols();

    m.set_length(M*d);
    sigma.set_length(M*d);
    for (i=0;i<M;i++) {
      for (j=0;j<d;j++) {
	m(i*d+j)=m_in(j,i);
	sigma(i*d+j)=sigma_in(j,i);
      }
    }
    w=w_in;

    compute_internals();
  }

  void GMM::set_mean(const mat &m_in)
  {
    int		i,j;

    d=m_in.rows();
    M=m_in.cols();

    m.set_length(M*d);
    for (i=0;i<M;i++) {
      for (j=0;j<d;j++) {
	m(i*d+j)=m_in(j,i);
      }
    }
    compute_internals();
  }

  /*!
    Overwrite the mean of component \a i with \a means.

    The offset is i*length(means), i.e. \a means is assumed to have the model
    dimension. When \a compflag is false the derived normalisation terms are
    not refreshed (useful when setting many components in a row).
  */
  void GMM::set_mean(int i, const vec &means, bool compflag)
  {
    const int offset=i*length(means);
    m.replace_mid(offset,means);
    if (!compflag) return;
    compute_internals();
  }

  /*!
    Replace all diagonal variances; column c of \a sigma_in holds the variances
    of component c.

    Also resets d and M from the matrix size. The means are not resized.
  */
  void GMM::set_covariance(const mat &sigma_in)
  {
    d=sigma_in.rows();
    M=sigma_in.cols();

    sigma.set_length(M*d);
    int pos=0;
    for (int comp=0;comp<M;comp++)
      for (int row=0;row<d;row++)
        sigma(pos++)=sigma_in(row,comp);

    compute_internals();
  }

  /*!
    Overwrite the diagonal variances of component \a i with \a covariances.

    The offset is i*length(covariances). When \a compflag is false the derived
    normalisation terms are not refreshed.
  */
  void GMM::set_covariance(int i, const vec &covariances, bool compflag)
  {
    const int offset=i*length(covariances);
    sigma.replace_mid(offset,covariances);
    if (!compflag) return;
    compute_internals();
  }

  void GMM::marginalize(int d_new)
  {
    it_error_if(d_new>d,"GMM.marginalize: cannot change to a larger dimension");

    vec		mnew(d_new*M),sigmanew(d_new*M);
    int		i,j;

    for (i=0;i<M;i++) {
      for (j=0;j<d_new;j++) {
	mnew(i*d_new+j)=m(i*d+j);
	sigmanew(i*d_new+j)=sigma(i*d+j);
      }
    }
    m=mnew;
    sigma=sigmanew;
    d=d_new;

    compute_internals(); 
  }

  void GMM::join(const GMM &newgmm)
  {
    if (d==0) {
      w=newgmm.w;
      m=newgmm.m;
      sigma=newgmm.sigma;
      d=newgmm.d;
      M=newgmm.M;
    } else {
      it_error_if( d!=newgmm.d,"GMM.join: cannot join GMMs of different dimension");

      w=concat(double(M)/(M+newgmm.M)*w,double(newgmm.M)/(M+newgmm.M)*newgmm.w);
      w=w/sum(w);
      m=concat(m,newgmm.m);
      sigma=concat(sigma,newgmm.sigma);

      M=M+newgmm.M;
    }
    compute_internals(); 
  }

  void GMM::clear()
  {
    w.set_length(0);
    m.set_length(0);
    sigma.set_length(0);
    d=0;
    M=0;
  }

  /*!
    Write the model to a text file readable by GMM::load.

    Format: a header line "M d", then one weight per line, then M lines of d
    means, then M lines of d variances (space separated).

    Values are written with 17 significant digits so a save/load round trip
    reproduces the doubles. The default stream precision of 6 digits is lossy.
    Errors opening or writing the file are reported through it_error_if.
  */
  void GMM::save(string filename)
  {
    ofstream	f(filename.c_str());
    it_error_if(!f,std::string("GMM::save : cannot open file ")+filename);

    f.precision(17);

    f << M << " " << d << '\n';
    for (int i=0;i<w.length();i++) {
      f << w(i) << '\n';
    }
    // Separator-before-element form also works for d==0 (no m(i*d) access).
    for (int i=0;i<M;i++) {
      for (int j=0;j<d;j++) {
        if (j>0) f << " ";
        f << m(i*d+j);
      }
      f << '\n';
    }
    for (int i=0;i<M;i++) {
      for (int j=0;j<d;j++) {
        if (j>0) f << " ";
        f << sigma(i*d+j);
      }
      f << '\n';
    }

    f.flush();
    it_error_if(!f,std::string("GMM::save : error writing file ")+filename);
  }

  /*!
    Read a model written by GMM::save and recompute the derived terms.

    Format: "M d", then M weights, then M*d means, then M*d variances, all
    whitespace separated. A missing file, a malformed or negative header, or
    truncated data is reported through it_error_if. On success, prints the
    number of mixtures and the dimension to cout.
  */
  void GMM::load(string filename)
  {
    ifstream	GMMFile(filename.c_str());

    it_error_if(!GMMFile,std::string("GMM::load : cannot open file ")+filename);

    GMMFile >> M >> d ;
    it_error_if(!GMMFile || M<0 || d<0,
                std::string("GMM::load : invalid header in file ")+filename);

    w.set_length(M);
    for (int i=0;i<M;i++) {
      GMMFile >> w(i) ;
    }
    m.set_length(M*d);
    for (int i=0;i<M;i++) {
      for (int j=0;j<d;j++) {
        GMMFile >> m(i*d+j) ;
      }
    }
    sigma.set_length(M*d);
    for (int i=0;i<M;i++) {
      for (int j=0;j<d;j++) {
        GMMFile >> sigma(i*d+j) ;
      }
    }
    // Any failed extraction above leaves the stream in a failed state.
    it_error_if(!GMMFile,
                std::string("GMM::load : missing or malformed data in file ")+filename);

    compute_internals();
    cout << "  mixtures:" << M << "  dim:" << d << endl ;
  }

  /*!
    Mixture density at \a x: the sum over components of
    w(i) * N(x; m_i, diag(sigma_i)).
  */
  double GMM::likelihood(const vec &x)
  {
    double density=0;
    for (int comp=0;comp<M;comp++)
      density+=w(comp)*likelihood_aposteriori(x,comp);
    return density;
  }

  /*!
    Per-component weighted densities w(i) * N(x; m_i, diag(sigma_i)).

    The result is not normalised. Divide by its sum (== likelihood(x)) to get
    posterior component probabilities.
  */
  vec GMM::likelihood_aposteriori(const vec &x)
  {
    vec weighted(M);
    for (int comp=0;comp<M;comp++)
      weighted(comp)=w(comp)*likelihood_aposteriori(x,comp);
    return weighted;
  }

  /*!
    Unweighted Gaussian density of component \a mixture evaluated at \a x.

    Uses the precomputed normweight (normalisation constant) and normexp
    (-0.5/variance per coordinate) from compute_internals().
  */
  double GMM::likelihood_aposteriori(const vec &x, int mixture)
  {
    it_error_if(d!=x.length(),"GMM::likelihood_aposteriori : dimensions does not match");

    const int base=mixture*d;
    double exponent=0;
    for (int k=0;k<d;k++)
      exponent+=normexp(base+k)*sqr(x(k)-m(base+k));

    return normweight(mixture)*std::exp(exponent);
  }

  void GMM::compute_internals()
  {
    int		i,j;
    double	s;
    double	constant=1.0/std::pow(2*pi,d/2.0);

    normweight.set_length(M);
    normexp.set_length(M*d);

    for (i=0;i<M;i++) {
      s=1;
      for (j=0;j<d;j++) {
	normexp(i*d+j)=-0.5/sigma(i*d+j);  // check time
	s*=sigma(i*d+j);
      }
      normweight(i) = constant/std::sqrt(s);
    }

  }

  /*!
    Draw one random vector from the mixture.

    A component k is chosen with probability w(k) by inverting the cumulative
    weights with a uniform draw. The sample is then
    m_k + sqrt(sigma_k) .* randn(d).

    The cumulative weights are recomputed on every call. The previous
    function-local static cache was shared by every GMM object and never
    refreshed, so it sampled from the wrong weights after set_weight/join/load
    or when a second model was used.
  */
  vec GMM::draw_sample()
  {
    it_error_if(w.length()==0,"GMM::draw_sample : the model has no mixture components");

    const double u=randu();

    vec cumweight=cumsum(w);
    const int last=cumweight.length()-1;
    it_error_if(std::abs(cumweight(last)-1)>1e-6,"weight does not sum to 1");
    // Absorb rounding so the search below always terminates at or before 'last'.
    cumweight(last)=1;

    int k=0;
    while (u>cumweight(k)) k++;

    return elem_mult(sqrt(sigma.mid(k*d,d)),randn(d))+m.mid(k*d,d);
  }

  /*!
    Train a diagonal-covariance GMM with \a M components using EM.

    \param TrainingData  Training vectors, all of the same dimension (taken
                         from TrainingData(0)). Must not be empty.
    \param M             Number of mixture components.
    \param NOITER        Maximum number of EM iterations.
    \param VERBOSE       Print per-iteration log-likelihood and timing. When
                         false, a one-line progress indicator is still
                         written to cout.
    \return The trained model.

    Initialisation: means from vqtrain(TrainingData, M, ...). Each variance
    is half the per-coordinate mean of squares of the training data. Weights
    are uniform. Iteration stops early when the relative change of the
    average log-likelihood drops below 1e-6.
  */
  GMM gmmtrain(Array<vec> &TrainingData, int M, int NOITER, bool VERBOSE)
  {
    // The dimension is read from the first vector, so the set must be non-empty.
    it_error_if(TrainingData.length()==0,"gmmtrain: no training data");

    mat			mean;
    int			i,j,d=TrainingData(0).length();
    vec			sig;
    GMM			gmm(M,d);
    vec			m(d*M);
    vec			sigma(d*M);
    vec			w(M);
    vec			normweight(M);
    vec			normexp(d*M);
    double		LL=0,LLold,fx;
    double		constant=1.0/std::pow(2*pi,d/2.0);
    int			T=TrainingData.length();
    vec			x1;
    int			t,n;
    vec			msum(d*M);
    vec			sigmasum(d*M);
    vec			wsum(M);
    vec			p_aposteriori(M);
    vec			x2;
    double		s;
    vec			temp1,temp2;
    //double		MINIMUM_VARIANCE=0.03;  // variance floor, currently disabled

    //-----------initialization-----------------------------------

    // NOTE(review): assumes vqtrain returns one codevector per column.
    mean=vqtrain(TrainingData,M,200000,0.5,VERBOSE);
    for (i=0;i<M;i++) gmm.set_mean(i,mean.get_col(i),false);
    //	for (i=0;i<M;i++) gmm.set_mean(i,TrainingData(randi(0,TrainingData.length()-1)),false);
    // Uncentered second moment of the data, per coordinate.
    sig=zeros(d);
    for (i=0;i<TrainingData.length();i++) sig+=sqr(TrainingData(i));
    sig/=TrainingData.length();
    for (i=0;i<M;i++) gmm.set_covariance(i,0.5*sig,false);

    gmm.set_weight(1.0/M*ones(M));

    //-----------optimization-----------------------------------

    tic();
    // Work on flat local copies of the parameters for speed.
    for (i=0;i<M;i++) {
      temp1=gmm.get_mean(i);
      temp2=gmm.get_covariance(i);
      for (j=0;j<d;j++) {
        m(i*d+j)=temp1(j);
        sigma(i*d+j)=temp2(j);
      }
      w(i)=gmm.get_weight(i);
    }
    for (n=0;n<NOITER;n++) {
      // Precompute normalisation terms; the weight is folded into normweight.
      for (i=0;i<M;i++) {
        s=1;
        for (j=0;j<d;j++) {
          normexp(i*d+j)=-0.5/sigma(i*d+j);
          s*=sigma(i*d+j);
        }
        normweight(i) = constant*w(i)/std::sqrt(s);
      }
      LLold=LL;
      wsum.clear();
      msum.clear();
      sigmasum.clear();
      LL=0;

      // E-step: posterior responsibilities and sufficient statistics.
      for (t=0;t<T;t++) {
        x1=TrainingData(t);
        x2=sqr(x1);
        fx=0;
        for (i=0;i<M;i++) {
          s=0;
          for (j=0;j<d;j++) {
            s+=normexp(i*d+j)*sqr(x1(j)-m(i*d+j));
          }
          p_aposteriori(i)=normweight(i)*std::exp(s);
          fx+=p_aposteriori(i);
        }
        // NOTE(review): if every component density underflows to 0, fx==0 and
        // the posteriors become NaN; consider a log-sum-exp formulation.
        p_aposteriori/=fx;
        LL=LL+std::log(fx);

        for (i=0;i<M;i++) {
          wsum(i)+=p_aposteriori(i);
          for (j=0;j<d;j++) {
            msum(i*d+j)+=p_aposteriori(i)*x1(j);
            sigmasum(i*d+j)+=p_aposteriori(i)*x2(j);
          }
        }
      }

      // M-step.
      for (i=0;i<M;i++) {
        // A component with no posterior mass at all would yield 0/0 = NaN,
        // which would poison every later iteration. Keep its previous mean
        // and variance and just give it zero weight.
        if (wsum(i)>0) {
          for (j=0;j<d;j++) {
            m(i*d+j)=msum(i*d+j)/wsum(i);
            sigma(i*d+j)=sigmasum(i*d+j)/wsum(i)-sqr(m(i*d+j));
          }
        }
        w(i)=wsum(i)/T;
      }
      LL=LL/T;

      if (std::abs((LL-LLold)/LL) < 1e-6) break;
      if (VERBOSE) {
        cout << n << ":   " << LL << "   " << std::abs((LL-LLold)/LL) << "   " << toc() <<  endl ;
        cout << "---------------------------------------" << endl ;
        tic();
      } else {
        cout << n << ": LL =  " << LL << "   " << std::abs((LL-LLold)/LL) << "\r" ;cout.flush();
      }
    }
    for (i=0;i<M;i++) {
      gmm.set_mean(i,m.mid(i*d,d),false);
      gmm.set_covariance(i,sigma.mid(i*d,d),false);
    }
    gmm.set_weight(w);
    return gmm;
  }

} // namespace itpp


// syntax highlighted by Code2HTML, v. 0.9.1