// SKKの辞書を管理するクラス群
// SKKFileDic関係をclean up しないといけない

#ifdef HAVE_CONFIG_H
# include "config.h"
#endif

#include <fcntl.h>
#include <unistd.h>
#include <netdb.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <stdlib.h>
#include <netinet/in.h>

#include <stdio.h>
#include <string.h>
#include <map>

#ifndef __GNUC__
# ifdef HAVE_ALLOCA_H
#  include <alloca.h>
# endif
#endif

#include "skkconv.h"

// Candidates of one dictionary entry that belong to a specific okurigana.
// get_dic_ent_str() writes such a group as "[okuri/cand1/cand2]/".
struct SKKOkuriEnt{
    std::list<jstring_t> head; // candidates recorded for this okurigana
    jstring_t okuri; // the okurigana string this group is keyed by
};

// Parsed form of one dictionary line: the reading (key) and its candidates.
struct SKKDicEnt{
    bool cached;// entry is held in learn_map (original note: it is released at save time)
    int hash; // hash slot this entry occupies inside learn_map

    jstring_t head; // reading, without the okurigana marker character
    cchar okuri_head; // okurigana marker appended to the key; 0 when the entry has none
    std::list<jstring_t> okurinasi; // candidates that are not tied to an okurigana
    std::list<SKKOkuriEnt> okuri; // candidate groups, one per okurigana

    // Records candidate c as the most recently chosen one for state s.
    void commit(SKKStat *s,jstring_t *c);
private:
};

// Dictionary that asks a remote SKK server for candidates over a TCP socket.
class SKKNetDic : public SKKDic{
public:
    SKKNetDic();
    // Closes the server connection, if one is open, so that the socket
    // descriptor is not leaked when the dictionary object is destroyed.
    ~SKKNetDic(){
	if ( sock_fd != -1 ){
	    close(sock_fd);
	    sock_fd = -1;
	}
    }
    // Connects to the host named by the SKKSERVER environment variable.
    // Returns false when the variable is unset or the connection fails.
    bool init();
    // Returns the server's candidates for the current state, or 0 if none.
    virtual Candidates *getCandidates(SKKStat *);
    // Does nothing in this implementation.
    virtual void commit(SKKStat *s,jstring_t *a);
private:
    bool connect_serv(char *hostname);
    bool send_request(SKKStat *);
    void split_reply(Candidates *,char *r);
    char *recv_reply();
    int sock_fd;// -1 means "not connected"
};

// Dictionary backed by a file that is mmap()ed read-only and indexed by a
// hash of each line's key. When opened as the learning dictionary, looked-up
// and committed entries are additionally kept in learn_map and written back
// by save().
class SKKFileDic : public SKKDic{
public:
    SKKFileDic();
    virtual ~SKKFileDic();
    // Maps file fn and builds the index. For a learning dictionary a file
    // that cannot be mapped is not treated as a failure.
    bool init(char *fn,bool isLearn);

    virtual Candidates *getCandidates(SKKStat *);
    virtual void commit(SKKStat *,jstring_t *);
    // Rewrites the dictionary file if commit() has been called.
    virtual void save();
private:
    bool do_map(char *fn);
    bool make_hash();
    int next_line(int );
    void add_hash_entry(int i,char *);
    int compare_jstring_with_str(jstring_t *sy,char *str);
    void add_to_learn_ent(SKKStat *,SKKDicEnt *);

    SKKDicEnt *get_dic_ent(SKKStat *);
    SKKDicEnt *get_dic_ent_from_off_map(SKKStat *st,jstring_t *s,int h);
    SKKDicEnt *get_dic_ent_from_learn(jstring_t *head,cchar okuri,int h);

    void save_1(FILE *fp,bool okuri);
    void save_1_orig(FILE *fp,bool okuri);
    void save_1_learn(FILE *fp,bool okuri);

    std::map<int,int> off_map;// hash slot -> byte offset of the line inside the mapped file
    std::map<int,SKKDicEnt *> learn_map;// hash slot -> entry kept in memory
    int m_col; // number of hash collisions seen while building off_map
    int m_size; // size in bytes of the mapped file
    char *m_fn; // strdup()ed file name
    char *m_ptr; // start of the mapping; 0 when no file is mapped
    bool m_isLearn; // true when this is the writable learning dictionary
    bool mIsDirty; // set by commit(); save() does nothing while it is false
};

// Front-end dictionary: combines the learning dictionary, the read-only
// file dictionaries and the optional network dictionary into one result.
class SKKMasterDic : public SKKDic{
public:
    SKKMasterDic();
    virtual ~SKKMasterDic();
    // Returns the merged, duplicate-free candidate list of all dictionaries,
    // or 0 when no dictionary has a candidate.
    virtual Candidates *getCandidates(SKKStat *);
    // Forwards the committed candidate to the learning dictionary, if any.
    virtual void commit(SKKStat *,jstring_t *);
private:
    void open_file_dics();
    void open_file_dic(char *);
    Candidates *merge_candidate(Candidates *x,Candidates *y);
    void merge_candidate_append(Candidates *dst,Candidates *src);
    
    SKKNetDic *m_netdic; // 0 when no server connection could be set up
    std::list<SKKFileDic *>m_filedics; // read-only dictionaries, in the order they were opened
    SKKFileDic *m_learndic; // writable dictionary; 0 when none is configured
};

// Factory: builds the master dictionary and hands it out through the
// SKKDic base-class pointer.
SKKDic *createSKKDic()
{
    SKKDic *dic = new SKKMasterDic();
    return dic;
}

// TCP port SKKNetDic::connect_serv() uses when connecting to the SKK server.
int skk_serv_port = 1178;

// global functions for skk dictionary
// Returns the number of characters before the end of the line that starts
// at s. The line ends at the first '\n'; a NUL byte is also treated as the
// end so that a final line without a newline does not send the scan past
// the end of a NUL-terminated buffer.
static int linelen(char *s)
{
    int i = 0;
    while (s[i] != '\n' && s[i] != '\0') {
	i++;
    }
    return i;
}

// Returns a malloc()ed, NUL-terminated copy of the line starting at s.
// The line terminator itself is not copied.
static char *linedup(char *s)
{
    const int len = linelen(s);
    char *copy = (char *)malloc(len + 1);
    memcpy(copy, s, len);
    copy[len] = '\0';
    return copy;
}

// Hashes a key string with the multiply-by-33 scheme.
// The accumulation is done in unsigned arithmetic: the value overflows for
// all but very short keys, and overflowing a signed int is undefined
// behaviour. The wrapped result is converted back to int on return.
static int calc_hash(jstring_t *s)
{
    unsigned int h = 0;
    for (jstring_t::iterator i = s->begin(); i != s->end(); ++i) {
	h = h * 33u + (unsigned int)(*i);
    }
    return (int)h;
}

static int calc_str_hash(char *s)
{
    jstring_t sy;
    int i;
    for ( i = 0 ; s[i] && s[i]!=' ' ; i++);
    char *buf= (char *)alloca(i+2);
    strncpy(buf,s,i);
    buf[i] = 0;
    str_to_jstring(&sy,buf);
    return calc_hash(&sy);
}

// Appends the lookup key for state st to s: the reading, followed by the
// okurigana marker character when one is set.
static void make_dic_key_str(SKKStat *st,jstring_t *s)
{
    append_jstring(s, &st->head);
    if (!st->okuri_head) {
	return;
    }
    s->push_back(st->okuri_head);
}

// Appends the dictionary-file representation of entry e to s (without a
// trailing newline):
//   key " /" cand "/" cand "/" ... "[" okuri "/" cand ... "]/" ...
void get_dic_ent_str(SKKDicEnt *e,jstring_t *s)
{
    typedef std::list<jstring_t>::iterator cand_iter;
    typedef std::list<SKKOkuriEnt>::iterator group_iter;

    // Key: reading plus optional okurigana marker, then " /".
    append_jstring(s, &e->head);
    if (e->okuri_head) {
	s->push_back(e->okuri_head);
    }
    s->push_back(' ');
    s->push_back('/');

    // Plain candidates, each followed by '/'.
    for (cand_iter c = e->okurinasi.begin(); c != e->okurinasi.end(); ++c) {
	append_jstring(s, &*c);
	s->push_back('/');
    }

    // Okurigana groups: "[okuri/cand/cand]/".
    for (group_iter g = e->okuri.begin(); g != e->okuri.end(); ++g) {
	SKKOkuriEnt &group = *g;
	s->push_back('[');
	append_jstring(s, &group.okuri);
	for (cand_iter c = group.head.begin(); c != group.head.end(); ++c) {
	    s->push_back('/');
	    append_jstring(s, &*c);
	}
	s->push_back(']');
	s->push_back('/');
    }
}

// Parses one okurigana group and appends it to e->okuri.
//
// i must point just behind the opening '[' of a group of the form
// "okuri/cand/cand/]/" and the underlying string must be terminated by a
// 0 element (make_dic_ent_from_line() appends one). Returns an iterator to
// the first element behind the group.
//
// Every scan also stops at the 0 terminator, so a truncated or malformed
// group can no longer run the iterator past the end of the string; in that
// case whatever was parsed so far is stored and the iterator is left on
// the terminator.
jstring_t::iterator parse_okuri_ent(SKKDicEnt *e,jstring_t::iterator i)
{
    SKKOkuriEnt o;

    // The okurigana comes first, up to the first '/'.
    while ( *i != 0 && *i != '/' ){
	o.okuri.push_back(*i);
	i++;
    }
    if ( *i == '/' ){
	i++;
    }
    // Candidates, separated by '/', up to the closing ']'.
    while ( *i != 0 && *i != ']' ){
	jstring_t x;
	while ( *i != 0 && *i != ']' && *i != '/' ){
	    x.push_back(*i);
	    i++;
	}
	o.head.push_back(x);
	if ( *i == '/' ){
	    i++;
	}
    }
    // Skip the closing "]/" only as far as it is really there.
    if ( *i == ']' ){
	i++;
    }
    if ( *i == '/' ){
	i++;
    }
    e->okuri.push_back(o);
    return i;
}

// Parses the next item of a dictionary line into e and returns an iterator
// to the element behind it.
//
// An item is either an okurigana group (starts with '[', handled by
// parse_okuri_ent()) or a plain candidate terminated by '/', which is
// appended to e->okurinasi. The underlying string must be terminated by a
// 0 element; the scan stops there as well, so a line whose last candidate
// lacks the closing '/' no longer runs the iterator past the end.
jstring_t::iterator parse_a_ent(SKKDicEnt *e,jstring_t::iterator i)
{
    if ( *i == '[' ){
	i++;
	return parse_okuri_ent(e,i);
    }
    jstring_t x;
    while ( *i != 0 && *i != '/' ){
	x.push_back(*i);
	i++;
    }
    e->okurinasi.push_back(x);
    if ( *i == '/' ){
	i++;
    }
    return i;
}

// Parses the dictionary line starting at ls into a newly allocated
// SKKDicEnt. Returns 0 when the line has no "key /..." structure or no
// candidates; otherwise the caller owns the returned entry.
//
// The entry's hash is the hash of the key as it appears in the line, and
// cached is false.
SKKDicEnt *make_dic_ent_from_line(char *ls)
{
    char *l = linedup(ls);
    if ( !l ){
	return 0;
    }
    jstring_t s;
    str_to_jstring(&s,l);
    int h = calc_str_hash(l);
    free(l);

    // 0 terminator the parse helpers rely on.
    s.push_back(0);

    SKKDicEnt *e = new SKKDicEnt();
    e->okuri_head = 0;
    e->cached = false;
    e->hash = h;

    // Key: a reading that starts with a wide character and is followed by a
    // narrow one has that narrow character as its okurigana marker.
    jstring_t::iterator i;
    for ( i = s.begin(); i != s.end() && *i != ' ' ; i++){
	if ( *s.begin()> 255 && *i < 256 ){
	    // okuri
	    e->okuri_head = *i;
	    i++;
	    break;
	}else{
	    e->head.push_back(*i);
	}
    }

    if ( i == s.end() || *i != ' '){
	delete e;
	return 0;
    }

    // Continue behind the space and the first '/'.
    i++;
    i++;

    // The iterator is compared against end() before every dereference:
    // a parse helper may legitimately stop on the terminator or step just
    // past it, and neither position may be read blindly.
    bool parsed = false;
    while ( i != s.end() && *i ){
	i = parse_a_ent(e,i);
	parsed = true;
    }
    if ( parsed ){
	return e;
    }
    printf("Syntax error in skk dictionary.\n");
    delete e;
    return 0;
}

//リストの先頭にある要素を追加する、同じものがリスト中にあれば
//削除する
void jstring_unique_prepend(std::list<jstring_t> *sl,jstring_t *s)
{
    std::list<jstring_t>::iterator i;
    for ( i = sl->begin() ; i != sl->end() ; i++){
	if ( *i == *s ){
	    sl->erase(i);
	    goto prepend;
	}
    }
prepend:
    sl->push_front(*s);
}

// Writes string s to fp in its external (char *) encoding.
//
// The text is written with fputs() rather than being passed to fprintf()
// as the format string: dictionary text may contain '%', which fprintf()
// would interpret as a conversion specification.
void fprint_jstring(FILE *fp,jstring_t *s)
{
    char *str = jstring_to_str(s);
    if ( !str ){
	return ;
    }
    fputs(str,fp);
    free((void *)str);
}

//
// Sets up the network dictionary (kept only when init() succeeds) and then
// opens the configured file dictionaries.
SKKMasterDic::SKKMasterDic()
{
    SKKNetDic *net = new SKKNetDic();
    if (net->init()) {
	m_netdic = net;
    } else {
	delete net;
	m_netdic = 0;
    }

    // Must be cleared before open_file_dics(), which may assign it.
    m_learndic = 0;
    open_file_dics();
}

// Saves the learning dictionary and destroys every dictionary owned by
// this object.
SKKMasterDic::~SKKMasterDic()
{
    if (m_learndic != 0) {
	m_learndic->save();
	delete m_learndic;
    }

    // Deleting a null pointer is a no-op, so no guard is needed here.
    delete m_netdic;

    typedef std::list<SKKFileDic *>::iterator dic_iter;
    for (dic_iter it = m_filedics.begin(); it != m_filedics.end(); ++it) {
	SKKFileDic *dic = *it;
	delete dic;
    }
}

// Collects candidates for s from the learning dictionary, then from every
// file dictionary, then from the network dictionary, merging them in that
// order without duplicates. Returns a new Candidates object owned by the
// caller, or 0 when no dictionary produced a result.
Candidates *SKKMasterDic::getCandidates(SKKStat *s)
{
    Candidates *acc = 0;
    if (m_learndic) {
	acc = m_learndic->getCandidates(s);
    }

    typedef std::list<SKKFileDic *>::iterator dic_iter;
    for (dic_iter it = m_filedics.begin(); it != m_filedics.end(); ++it) {
	Candidates *found = (*it)->getCandidates(s);
	Candidates *merged = merge_candidate(acc, found);
	// Both inputs were copied into merged; delete on null is a no-op.
	delete found;
	delete acc;
	acc = merged;
    }

    Candidates *net = 0;
    if (m_netdic) {
	net = m_netdic->getCandidates(s);
    }

    Candidates *result = merge_candidate(acc, net);
    if (!result) {
	// merge_candidate() yields 0 only when both inputs are 0.
	return 0;
    }
    delete acc;
    delete net;
    return result;
}

// Passes the committed candidate c on to the learning dictionary.
// Nothing is recorded when c is empty or when the candidate list of s
// reports nth == 0.
void SKKMasterDic::commit(SKKStat *s,jstring_t *c)
{
    if (c->size() == 0 || (s->cands && s->cands->nth == 0)) {
	return;
    }
    if (!m_learndic) {
	return;
    }
    m_learndic->commit(s, c);
}

void SKKMasterDic::open_file_dics()
{
    char *fnbuf;
    atom_t a=0;
    char *s;

    // ホームディレクトリにある辞書
    do{
	a = get_bound_atoms(A_skk_personal_dic,a);
	if (a) {
	    s = get_atom_name(a);
	    fnbuf = (char *)alloca(strlen(homedir)+ strlen(s)+2);
	    sprintf(fnbuf,"%s/%s",homedir,s);
	    open_file_dic(fnbuf);
	}
    }while(a);

    // 共有されている辞書
    a = 0;
    do{
	a = get_bound_atoms(A_skk_share_dic, a);
	if (a) {
	    s = get_atom_name(a);
	    open_file_dic(s);
	}
    } while(a);

    // 書き込みできる辞書
    a = get_bound_atoms(A_skk_learn_dic,0);
    if ( nr_bound_atoms(A_skk_learn_dic) > 1 ){
	printf("many files are specified as skk learn dic.\n");
    }
    if ( a && ( s = get_atom_name(a))){
	fnbuf = (char *)alloca(strlen(homedir)+ strlen(s)+2);
	sprintf(fnbuf,"%s/%s",homedir,s);
	SKKFileDic *fdic = new SKKFileDic();
	fdic->init(fnbuf,true);
	m_learndic = fdic;
    }
}

void SKKMasterDic::open_file_dic(char *fn)
{
    SKKFileDic *fdic= new SKKFileDic();
    if ( fdic->init(fn,false)){
	m_filedics.push_back(fdic);
    }else{
	delete fdic;
    }
}

// Returns a new Candidates object holding the candidates of x followed by
// those of y that are not already present. Either argument may be 0;
// when both are 0 the result is 0 as well.
Candidates *SKKMasterDic::merge_candidate(Candidates *x,Candidates *y)
{
    if (x == 0 && y == 0) {
	return 0;
    }
    Candidates *merged = new Candidates();
    merge_candidate_append(merged, x);
    merge_candidate_append(merged, y);
    return merged;
}

// Appends to dst every candidate of src that dst does not contain yet,
// keeping the order of src. A null src is ignored.
void SKKMasterDic::merge_candidate_append(Candidates *dst, Candidates *src)
{
    if (src == 0) {
	return;
    }
    typedef std::vector<jstring_t>::iterator cand_iter;
    for (cand_iter s = src->cands.begin(); s != src->cands.end(); ++s) {
	// Look for an equal candidate already in dst.
	cand_iter d = dst->cands.begin();
	while (d != dst->cands.end() && !(*d == *s)) {
	    ++d;
	}
	if (d == dst->cands.end()) {
	    dst->cands.push_back(*s);
	}
    }
}

// Starts out unconnected (sock_fd == -1).
SKKNetDic::SKKNetDic()
    : sock_fd(-1)
{
}

// Connects to the server named by the SKKSERVER environment variable.
// Returns true only when the variable is set and the connection succeeds.
bool SKKNetDic::init()
{
    char *host = my_getenv("SKKSERVER");
    if (!host) {
	return false;
    }
    if (connect_serv(host)) {
	return true;
    }
    printf("failed to connect SKKSERVER(%s)\n",host);
    return false;
}

// No-op: this implementation does nothing when a candidate is committed.
void SKKNetDic::commit(SKKStat *,jstring_t *)
{
}

// Opens a TCP connection to hostname on port skk_serv_port.
// On success the connected socket is stored in sock_fd and true is
// returned. On any failure sock_fd is -1, false is returned, and a socket
// that was already created is closed again instead of being leaked.
bool SKKNetDic::connect_serv(char *hostname)
{
    struct sockaddr_in peer;
    struct hostent *server_ip;

    sock_fd = -1;

    // Resolve first so that no socket exists yet if resolution fails.
    // Only an IPv4 address of the expected size may be copied into
    // sin_addr below.
    server_ip = gethostbyname(hostname);
    if ( !server_ip
	 || server_ip->h_addrtype != AF_INET
	 || server_ip->h_length != (int)sizeof(peer.sin_addr)){
	return false;
    }

    int fd = socket(AF_INET,SOCK_STREAM,0);
    if ( fd == -1 ){
	return false;
    }

    memset(&peer,0,sizeof(peer));
    peer.sin_family = AF_INET;
    peer.sin_port = htons(skk_serv_port);
    memcpy(&peer.sin_addr,server_ip->h_addr,sizeof(peer.sin_addr));

    if ( connect(fd,(struct sockaddr*)&peer,sizeof(peer)) == -1){
	close(fd);
	return false;
    }
    sock_fd = fd;
    return true;
}

// Asks the server for candidates matching st. Returns a new Candidates
// object, or 0 when there is no connection, the request fails, or the
// reply contains no candidate.
Candidates *SKKNetDic::getCandidates(SKKStat *st)
{
    Candidates *cand = new Candidates();

    char *reply = 0;
    if (sock_fd > -1 && send_request(st)) {
	reply = recv_reply();
    }
    if (reply) {
	split_reply(cand, reply);
	free(reply);
    }

    if (cand->cands.size() != 0) {
	return cand;
    }
    delete cand;
    return 0;
}

// Sends the lookup request "1<key> \n" for state st to the server.
// Returns true when the whole request was sent.
//
// The request buffer is sized from the actual length of the encoded key
// instead of from an assumed two bytes per character, so it cannot
// overflow whatever byte width jstring_to_str() produces.
bool SKKNetDic::send_request(SKKStat *st)
{
    jstring_t sy;
    make_dic_key_str(st,&sy);
    char *de = jstring_to_str(&sy);
    if ( !de ){
	return false;
    }

    size_t keylen = strlen(de);
    // '1' + key + ' ' + '\n' + NUL
    char *p = (char *)malloc(keylen + 4);
    if ( !p ){
	free(de);
	return false;
    }
    p[0] = '1';
    memcpy(p + 1,de,keylen);
    memcpy(p + 1 + keylen," \n",3);
    free(de);

    int c = (int)(keylen + 3);
    bool sent = ( send(sock_fd,p,c,0) == c );
    free(p);
    if ( !sent ){
	printf("failed to send.");
	return false;
    }
    return true;
}

// Reads one reply line from the server.
// Returns a malloc()ed, NUL-terminated copy of the line without its
// "\n" / "\r\n" terminator (the caller frees it), or NULL when the peer
// closed the connection, a receive error occurred, or memory ran out.
//
// The socket is peeked at until a complete line is visible, and only that
// line is then consumed, so data behind the newline stays queued.
char *SKKNetDic::recv_reply()
{
    int len=16;
    char *p = (char *)malloc(len);
    if ( !p ){
	return NULL;
    }
    do{
	int l,i;
	l = recv(sock_fd,p,len-1,MSG_PEEK);
	// 0 = connection closed, negative = error. Treating only 0 as
	// the end would loop forever on a failing socket.
	if ( l <= 0 ){
	    free(p);
	    return NULL;
	}
	for ( i = 0 ; i < l ; i++){
	    if ( p[i]=='\n' ){
		// Consume exactly the line, including its newline.
		recv(sock_fd,p,i+1,0);
		p[i]=0;
		if ( i>0 && p[i-1]=='\r'){
		    p[i-1]=0;
		}
		return p;
	    }
	}
	if ( l == len-1 ){
	    // Buffer full without a newline: grow it and peek again.
	    char *q = (char *)realloc(p,len * 2);
	    if ( !q ){
		free(p);
		return NULL;
	    }
	    p = q;
	    len = len * 2;
	}
    }while(1);
    return NULL;
}

// Splits a server reply line into candidates and appends them to c.
//
// A successful reply has the form "1/cand1/cand2/.../": the first
// character is the status and each candidate is terminated by '/'.
// Replies with any other status character carry no candidates and are
// ignored; without this check a '/' occurring in such a reply would be
// mistaken for a candidate terminator.
void SKKNetDic::split_reply(Candidates *c,char *r)
{
    if ( r[0] != '1' ){
	return ;
    }
    int l = strlen(r);
    // A candidate is at most l-2 characters long; +1 for its NUL.
    char *buf = (char *)alloca(l + 1);
    int o = 0;
    // Start behind the status character and the first '/'.
    for ( int i = 2 ; i < l ; i++){
	if ( r[i] != '/'){
	    buf[o] = r[i];
	    o++;
	}else{
	    buf[o] = 0;
	    o = 0;
	    jstring_t s;
	    str_to_jstring(&s,buf);
	    c->cands.push_back(s);
	}
    }
}

// Creates a dictionary with no file mapped. Every scalar member gets a
// defined value, so the object is in a known state even before init()
// is called.
SKKFileDic::SKKFileDic()
    : m_col(0), m_size(0), m_fn(0), m_ptr(0),
      m_isLearn(false), mIsDirty(false)
{
}

// Releases the file name, the file mapping and every entry kept in
// learn_map. Those entries are allocated with new and handed to
// add_to_learn_ent(); nothing else in this class deletes them.
// Note that the destructor does not save; call save() first.
SKKFileDic::~SKKFileDic()
{
    std::map<int,SKKDicEnt *>::iterator i;
    for ( i = learn_map.begin() ; i != learn_map.end() ; i++){
	delete (*i).second;
    }
    learn_map.clear();

    if ( m_fn ){
	free(m_fn);
    }
    if ( m_ptr ){
	munmap(m_ptr,m_size);
    }
}

// Maps dictionary file fn and builds its hash index.
// When the file cannot be mapped the call still succeeds for a learning
// dictionary and fails for any other.
bool SKKFileDic::init(char *fn,bool isLearn)
{
    m_isLearn = isLearn;

    if (do_map(fn)) {
	return make_hash();
    }
    // A learning dictionary is allowed to fail here.
    return isLearn;
}

// Returns the candidates this dictionary has for st as a new Candidates
// object, or 0 when the key is unknown.
//
// If st carries okurigana and the entry has a group for exactly that
// okurigana, only that group's candidates are returned; otherwise the
// entry's plain candidates are returned.
//
// An entry that is not held in learn_map is owned by this function and is
// deleted on every return path, including the okurigana-match path.
Candidates *SKKFileDic::getCandidates(SKKStat *st)
{
    SKKDicEnt *e = get_dic_ent(st);
    if ( !e ){
	return 0;
    }

    Candidates *c = new Candidates();
    bool matched = false;

    // With okurigana, a group matching it takes precedence.
    if ( st->okuri.size()){
	std::list<SKKOkuriEnt>::iterator i;
	for ( i = e->okuri.begin() ; i != e->okuri.end() ; i++){
	    if ( (*i).okuri == st->okuri ){
		std::list<jstring_t>::iterator j;
		for ( j = (*i).head.begin() ; j != (*i).head.end() ; j++){
		    c->cands.push_back(*j);
		}
		matched = true;
		break;
	    }
	}
    }

    if ( !matched ){
	std::list<jstring_t>::iterator i;
	for ( i = e->okurinasi.begin() ; i != e->okurinasi.end() ; i ++){
	    c->cands.push_back(*i);
	}
    }

    if ( !e->cached ){
	delete e;
    }

    return c;
}

// Looks up the entry for st, first in learn_map and then in the mapped
// file. An entry found in the file of a learning dictionary is put into
// learn_map before it is returned. Returns 0 when the key is unknown.
SKKDicEnt *SKKFileDic::get_dic_ent(SKKStat *st)
{
    jstring_t key;
    make_dic_key_str(st, &key);
    const int h = calc_hash(&key);

    SKKDicEnt *hit = get_dic_ent_from_learn(&st->head, st->okuri_head, h);
    if (hit) {
	return hit;
    }

    hit = get_dic_ent_from_off_map(st, &key, h);
    if (hit && m_isLearn) {
	// Learning dictionary: cache the entry.
	add_to_learn_ent(st, hit);
    }
    return hit;
}

// Searches the mapped file for the line whose key equals sy, probing
// off_map from slot h upwards until a matching line or an empty slot is
// found. Returns a newly parsed entry, or 0 when there is no mapping or no
// matching line.
SKKDicEnt *
SKKFileDic::get_dic_ent_from_off_map(SKKStat *st,jstring_t *sy,int h)
{
    if (!m_ptr) {
	return 0;
    }
    for (;; h++) {
	std::map<int,int>::iterator slot = off_map.find(h);
	if (slot == off_map.end()) {
	    return 0;
	}
	char *line = &m_ptr[slot->second];
	if (compare_jstring_with_str(sy, line) == 0) {
	    return make_dic_ent_from_line(line);
	}
    }
}

// Searches learn_map for the entry with reading *s and okurigana marker
// okuri, probing from slot h upwards until a match or an empty slot is
// found. Returns the stored entry, or 0.
SKKDicEnt *SKKFileDic::get_dic_ent_from_learn(jstring_t *s,cchar okuri,int h)
{
    for (;; h++) {
	std::map<int,SKKDicEnt *>::iterator slot = learn_map.find(h);
	if (slot == learn_map.end()) {
	    return 0;
	}
	SKKDicEnt *ent = slot->second;
	if (ent->head == *s && ent->okuri_head == okuri) {
	    return ent;
	}
    }
}

// Records candidate c as chosen for st. A key that has no entry yet gets a
// fresh one in learn_map. Marks the dictionary as needing a save.
void SKKFileDic::commit(SKKStat *st,jstring_t *c)
{
    SKKDicEnt *ent = get_dic_ent(st);
    if (ent == 0) {
	ent = new SKKDicEnt();
	add_to_learn_ent(st, ent);
    }
    ent->commit(st, c);
    mIsDirty = true;
}

void SKKFileDic::save()
{
    if ( !mIsDirty ){
	return ;
    }

    FILE *fp;
    char *fn;
    fn = (char *)alloca(strlen(m_fn)+5);
    strcpy(fn,m_fn);
    strcat(fn,".bak");
    fp = fopen( fn, "w");
    if ( ! fp ){
        return ;
    }
    printf("saving dictionary.(%s)\n",fn);
    save_1(fp,true);
    save_1(fp,false);
    fclose(fp);
    rename(fn,m_fn);
}

// Writes one half of the dictionary to fp: the entries that have an
// okurigana marker when okuri is true, the others when it is false.
// Entries from the mapped file are written first, then those in learn_map.
void SKKFileDic::save_1(FILE *fp,bool okuri)
{
    save_1_orig(fp,okuri);
    save_1_learn(fp,okuri);
}

// Writes to fp those entries of the mapped file that are of the requested
// kind (with okurigana marker when okuri is true, without it otherwise)
// and that are not also present in learn_map; the learn_map version of an
// entry is written by save_1_learn() instead.
//
// Every entry parsed here is deleted again, whether it was written or not.
void SKKFileDic::save_1_orig(FILE *fp,bool okuri)
{
    std::map<int,int>::iterator i;
    for ( i = off_map.begin() ; i != off_map.end() ; i++){
	SKKDicEnt *e = make_dic_ent_from_line(&m_ptr[(*i).second]);
	if ( !e ){
	    continue;
	}
	bool wanted = ( (e->okuri_head != 0) == okuri );
	if ( wanted
	     && !get_dic_ent_from_learn(&e->head,e->okuri_head,e->hash)){
	    // need to save
	    jstring_t s;
	    get_dic_ent_str(e,&s);
	    s.push_back('\n');
	    fprint_jstring(fp,&s);
	}
	delete e;
    }
}

void SKKFileDic::save_1_learn(FILE *fp,bool okuri)
{
    std::map<int,SKKDicEnt *>::iterator i;
    for ( i = learn_map.begin() ; i!= learn_map.end() ;i++){
	SKKDicEnt *e = (*i).second;
	if ( (e->okuri_head && okuri ) || (!e->okuri_head && !okuri)){
	    jstring_t s;
	    get_dic_ent_str(e,&s);
	    s.push_back('\n');
	    fprint_jstring(fp,&s);
	}
    }
}

// Remembers fn in m_fn and maps the file read-only into memory.
// On success m_ptr and m_size describe the mapping and true is returned.
// On failure false is returned and m_ptr is left unchanged; m_fn is still
// set so that save() can create the file later.
//
// The file descriptor is closed on every path: the mapping does not need
// it once mmap() has returned, and the error paths must not leak it.
bool SKKFileDic::do_map(char *fn)
{
    if ( m_fn ){
	free(m_fn);
    }
    m_fn = strdup(fn);
    if ( !m_fn ){
	return false;
    }
    int fd = open ( m_fn , O_RDONLY );
    if ( fd == -1 ){
	// maybe file not exist
	return false;
    }
    struct stat st;
    if ( fstat( fd,&st) == -1){
	perror("Failed to stat dictionary.");
	close(fd);
	return false;
    }
    void *ptr = mmap(0,st.st_size,PROT_READ,MAP_SHARED,fd,0);
    if ( ptr == MAP_FAILED ){
	perror("Failed to mmap dictionary.");
	close(fd);
	return false;
    }
    m_size = st.st_size;
    m_ptr = (char *)ptr;
    close(fd);
    return true;
}

// Builds off_map by walking the mapped file line by line. Lines starting
// with ';' are skipped; every other line is indexed by its key.
// Always returns true.
bool SKKFileDic::make_hash()
{
    m_col = 0;
    for (int pos = 0; pos >= 0 && pos < m_size; pos = next_line(pos)) {
	if (m_ptr[pos] == ';') {
	    continue;
	}
	add_hash_entry(pos, &m_ptr[pos]);
    }
    return true;
}

// Returns the offset of the line following the one that contains offset i,
// or -1 when there is no further line in the mapped file.
int SKKFileDic::next_line(int i)
{
    int pos = i;
    for (; pos < m_size; pos++) {
	if (m_ptr[pos] == '\n') {
	    break;
	}
    }
    // No newline found, or the newline is the last byte of the file.
    return (pos >= m_size - 1) ? -1 : pos + 1;
}

// Registers the line at file offset `of` (whose text starts at s) in
// off_map. On a collision the next higher free slot is used, and m_col is
// incremented once for every occupied slot that is skipped.
void SKKFileDic::add_hash_entry(int of,char *s)
{
    int slot = calc_str_hash(s);
    while (off_map.find(slot) != off_map.end()) {
	slot++;
	m_col++;
    }
    off_map.insert(std::make_pair(slot, of));
}

// Compares key *sy with the key of the raw dictionary line str, i.e. the
// characters of str before the first space (or NUL).
// Returns 0 when they are equal and 1 otherwise.
int SKKFileDic::compare_jstring_with_str(jstring_t *sy,char *str)
{
    int keylen = 0;
    while (str[keylen] != '\0' && str[keylen] != ' ') {
	keylen++;
    }

    // Temporary NUL-terminated copy of the key for str_to_jstring().
    char *key = (char *)alloca(keylen + 2);
    memcpy(key, str, keylen);
    key[keylen] = '\0';

    jstring_t converted;
    str_to_jstring(&converted, key);
    return (converted == *sy) ? 0 : 1;
}

// Stores entry e in learn_map under the key of st. The entry's reading and
// okurigana marker are taken from st, it is flagged as cached, and its
// hash field is set to the slot it ends up in (the first free slot at or
// above the key's hash).
void SKKFileDic::add_to_learn_ent(SKKStat *st,SKKDicEnt *e)
{
    jstring_t key;
    make_dic_key_str(st, &key);
    int slot = calc_hash(&key);

    e->cached = true;
    e->head = st->head;
    e->okuri_head = st->okuri_head;

    while (learn_map.find(slot) != learn_map.end()) {
	slot++;
    }
    learn_map.insert(std::make_pair(slot, e));
    e->hash = slot;
}

// Learns that candidate *s was chosen for state st.
// With okurigana in st, *s is moved to the front of the matching okurigana
// group, or a new group holding only *s is appended when none matches.
// In every case *s is also moved to the front of the plain candidate list.
void SKKDicEnt::commit(SKKStat *st,jstring_t *s)
{
    if (st->okuri.size()) {
	// Learn for the okurigana group.
	std::list<SKKOkuriEnt>::iterator group = okuri.begin();
	while (group != okuri.end() && !(group->okuri == st->okuri)) {
	    ++group;
	}
	if (group != okuri.end()) {
	    jstring_unique_prepend(&group->head, s);
	} else {
	    // No group for this okurigana yet: add one.
	    SKKOkuriEnt fresh;
	    fresh.okuri = st->okuri;
	    fresh.head.push_back(*s);
	    okuri.push_back(fresh);
	}
    }
    // Learn for the part without okurigana.
    jstring_unique_prepend(&okurinasi, s);
}
/*
 * Local variables:
 *  c-indent-level: 4
 *  c-basic-offset: 4
 * End:
 */


// syntax highlighted by Code2HTML, v. 0.9.1