#include "cxx_codegen.hh"
#include "codegen.hh"
#include "ast.hh"
#include "symbol_checking.hh"
#include "type_checking.hh"

#include <cassert>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <set>
#include <sstream>

using namespace ast;
using namespace std;
using namespace symbol_checking;
using namespace type_checking;
using namespace codegen;

namespace codegen {
    // Forward declaration: CXXCodeGen::visit(Range&) and visit(Values&)
    // below call this before its definition further down in this file.
    // Emits the code that computes the value of `fun` into `os`, which
    // defaults to std::cout.
    void cxx_generate_fun(Fun &fun, ostream &os = std::cout);
};

/*
 * Map a front-end basic type to the spelling of the C++ type used in
 * generated code.  The language's "float" is emitted as C++ "double".
 *
 * An unhandled type is a programming error: it trips the assertion in
 * debug builds and aborts in NDEBUG builds, so control can never fall off
 * the end of this non-void function.
 */
static const char *
translate_basic_type(basic_type_t bt)
{
    switch (bt)
	{
	case TYPE_CHAR:
	    return "char";

	case TYPE_INT:
	    return "int";

	case TYPE_FLOAT:
	    return "double";

	default:
	    break;
	}

    assert(false && "translate_basic_type: unhandled basic_type_t");
    std::abort();
}


void
CXXCodeGen::visit(BinRelLT &)
{
    // Relational operator `<`, space-padded to sit between two operands.
    static const char rel_op[] = " < ";
    _cc_os << rel_op;
}

void
CXXCodeGen::visit(BinRelLTE &)
{
    // Relational operator `<=`, space-padded to sit between two operands.
    static const char rel_op[] = " <= ";
    _cc_os << rel_op;
}

void
CXXCodeGen::visit(BinRelGT &)
{
    // Relational operator `>`, space-padded to sit between two operands.
    static const char rel_op[] = " > ";
    _cc_os << rel_op;
}

void
CXXCodeGen::visit(BinRelGTE &)
{
    // Relational operator `>=`, space-padded to sit between two operands.
    static const char rel_op[] = " >= ";
    _cc_os << rel_op;
}

void
CXXCodeGen::visit(BinRelEQ &)
{
    // Equality operator `==`, space-padded to sit between two operands.
    static const char rel_op[] = " == ";
    _cc_os << rel_op;
}

void
CXXCodeGen::visit(BinRelNEQ &)
{
    // Inequality operator `!=`, space-padded to sit between two operands.
    static const char rel_op[] = " != ";
    _cc_os << rel_op;
}

/*
 * Emit the header of a counting loop for a range expression:
 *
 *   for (int <idx> = <begin>[ +/- 1]; <idx> <end-rel> <end>; ++/--<idx>)
 *
 * The begin bracket decides the direction: "less" counts upward,
 * otherwise downward.  An exclusive begin bracket starts one step past
 * the written bound.  The caller emits the loop body.
 */
void
CXXCodeGen::visit(RExpr &rexpr)
{
    const bool ascending = rexpr.begin_br()->is_less();

    _cc_os << "for (int " << rexpr.index() << " = ";
    rexpr.begin()->accept(*this);
    if (!rexpr.begin_br()->include())
	_cc_os << (ascending ? " + 1" : " - 1");

    _cc_os << "; " << rexpr.index();
    rexpr.end_br()->accept(*this);
    rexpr.end()->accept(*this);

    _cc_os << (ascending ? "; ++" : "; --") << rexpr.index() << ")\n";
}

void
CXXCodeGen::visit(Range &range)
{
    // A range aggregation is delegated to the per-function generator,
    // which writes the code computing its value into our .cc stream.
    Fun &as_fun = range;
    cxx_generate_fun(as_fun, _cc_os);
}

void
CXXCodeGen::visit(Values &values)
{
    // A `values` aggregation is delegated to the per-function generator,
    // which writes the code computing its value into our .cc stream.
    Fun &as_fun = values;
    cxx_generate_fun(as_fun, _cc_os);
}

/*
 * Bind the value of a simple function to a fresh temporary named
 * `_val_<ast id>`, which the enclosing construct reads afterwards.
 * NOTE(review): the temporary is always declared `int`, whatever the
 * expression's type.
 */
void
CXXCodeGen::visit(SimpleFun &simple_fun)
{
    ostream &out = _cc_os;
    out << "int _val_" << simple_fun.ast_id() << " = ";
    simple_fun.expr()->accept(*this);
    out << ";\n";
}

void
CXXCodeGen::visit(IDExpr &id_expr)
{
    // Identifiers are emitted verbatim.
    ostream &out = _cc_os;
    out << id_expr.id();
}

void
CXXCodeGen::visit(IntegerExpr &integer_expr)
{
    // Integer literals are emitted with the stream's default formatting.
    ostream &out = _cc_os;
    out << integer_expr.integer();
}

void
CXXCodeGen::visit(MatrixExpr &matrix_expr)
{
    // A matrix read uses the same indexing syntax as an update target.
    this->emit_matrix_index(matrix_expr);
}

void
CXXCodeGen::visit(FunCallExpr &matrix_expr)
{
    _cc_os << matrix_expr.id() << '(';
    const list<Expr*> *expr_list = matrix_expr.expr_list();
    list<Expr*>::const_iterator i = expr_list->begin();
    bool first = true;
    for (i = expr_list->begin(); i != expr_list->end(); ++i)
	{
	    if (!first) _cc_os << ','; first = false;
	    (*i)->accept(*this);
	}
    _cc_os << ')';
}


/*
 * Emit an arithmetic binary operation as `(lhs op rhs)`.  Every operation
 * is fully parenthesised, so the generated C++ never depends on operator
 * precedence.
 */
void
CXXCodeGen::visit(BinOpExpr &binop_expr)
{
    ostream &out = _cc_os;
    out << '(';
    binop_expr.left()->accept(*this);
    out << ' ' << binop_expr.binop() << ' ';
    binop_expr.right()->accept(*this);
    out << ')';
}

/*
 * Emit unary negation as `-(operand)`.
 *
 * The parentheses are required.  Emitting a bare '-' before an operand
 * whose text itself starts with '-' (a nested negation, for example)
 * would produce `--x`.  C++ lexes that as the pre-decrement operator,
 * not as a double negation.
 */
void
CXXCodeGen::visit(NEGExpr &neg_expr)
{
    _cc_os << "-(";
    neg_expr.expr()->accept(*this);
    _cc_os << ')';
}


/*
 * Emit a relational test as `(lhs<rel>rhs)`; the relation visitor
 * supplies the operator together with its surrounding spaces.
 */
void
CXXCodeGen::visit(RelBExpr &bexpr)
{
    ostream &out = _cc_os;
    out << '(';
    bexpr.left()->accept(*this);
    bexpr.rel()->accept(*this);
    bexpr.right()->accept(*this);
    out << ')';
}

/* Emit a conjunction as `(lhs && rhs)`. */
void
CXXCodeGen::visit(ANDBExpr &bexpr)
{
    static const char and_op[] = " && ";
    ostream &out = _cc_os;

    out << '(';
    bexpr.left()->accept(*this);
    out << and_op;
    bexpr.right()->accept(*this);
    out << ')';
}

/* Emit a disjunction as `(lhs || rhs)`. */
void
CXXCodeGen::visit(ORBExpr &bexpr)
{
    static const char or_op[] = " || ";
    ostream &out = _cc_os;

    out << '(';
    bexpr.left()->accept(*this);
    out << or_op;
    bexpr.right()->accept(*this);
    out << ')';
}

/*
 * Emit logical negation as `!operand`.  No parentheses are added here.
 * The relational, && and || forms emitted in this file already wrap
 * themselves in parentheses.
 */
void
CXXCodeGen::visit(NOTBExpr &bexpr)
{
    ostream &out = _cc_os;
    out << '!';
    bexpr.bexpr()->accept(*this);
}

void
CXXCodeGen::visit(TrueExpr &bexpr)
{
    _cc_os << "1";
}


// ---=== FUN GENERATOR ==================================---

#include "functions.hh"

namespace codegen {
    /*
     * Generates the code for an aggregating function application (a Range
     * or Values node whose function is not `select`).  The result is
     * accumulated in `_val_<ast id of _fun>`.  That variable starts at the
     * wrapper's neutral_element() and each produced value is folded in via
     * the wrapper's combine() (see the visit methods defined below).
     */
    class CXXFunGenerator : public CXXCodeGen {
	Fun &_fun;                               // the one node this generator handles
	functions::FunctionWrapper &_wrapper;    // aggregation behaviour for _fun.id()

    public:
	// NOTE(review): dereferences functions::get_function(fun.id())
	// without a null check; presumably unknown function ids were
	// rejected earlier (e.g. during checking) -- confirm.
	CXXFunGenerator(Fun &fun, ostream &hh_os, ostream &cc_os)
	    : CXXCodeGen(hh_os, cc_os), _fun(fun),
	      _wrapper(*(functions::get_function(fun.id())))
	{}

	virtual void visit(Range       &range);
	virtual void visit(Values      &values);

	virtual void visit(WhereExpr   &where_expr);
	virtual void visit(WhenExpr    &when_expr);
    };

    /*
     * Generates the code for the built-in `select` function.  The result
     * is the value of the first `when` branch whose condition holds,
     * emitted as an if / else-if chain.  The chain ends in a runtime
     * assert(0) in the generated code for the no-match case.
     */
    class CXXSelectGenerator : public CXXCodeGen {
	Fun &_fun;                               // the `select` node being generated

    public:
	CXXSelectGenerator(Fun &fun, ostream &hh_os, ostream &cc_os)
	    : CXXCodeGen(hh_os, cc_os), _fun(fun)
	{}

	virtual void visit(Values      &values);
	virtual void visit(WhenExpr    &when_expr);
    };
};

/*
 * Declare the accumulator for a range aggregation, initialised to the
 * wrapper's neutral element.  Then generate the where-clause, which loops
 * over the ranges and folds each value into the accumulator.
 */
void
CXXFunGenerator::visit(Range &range)
{
    // This generator is constructed for exactly one function node.
    assert(&_fun == &range);

    _cc_os << "int _val_" << _fun.ast_id() << " = ";
    _cc_os << _wrapper.neutral_element() << ";\n";

    range.where_expr()->accept(*this);
}

/*
 * Declare the accumulator for a `values` aggregation, initialised to the
 * wrapper's neutral element.  Then generate each `when` branch; every
 * branch whose condition holds combines its value into the accumulator.
 */
void
CXXFunGenerator::visit(Values &values)
{
    // This generator is constructed for exactly one function node.
    assert(&_fun == &values);

    _cc_os << "int _val_" << _fun.ast_id() << " = ";
    _cc_os << _wrapper.neutral_element() << ";\n";

    list<WhenExpr*> *branches = values.when_expr_list();
    for (list<WhenExpr*>::iterator b = branches->begin();
	 b != branches->end(); ++b)
	(*b)->accept(*this);
}

/*
 * Generate the loop nest of a where-clause.  There is one `for (...) {`
 * per range expression.  Inside it the body function's value is computed
 * and combined into the accumulator `_val_<_fun id>`.
 */
void
CXXFunGenerator::visit(WhereExpr &where_expr)
{
    list<RExpr*> *ranges = where_expr.rexpr_list();

    // Open the loop nest.
    for (list<RExpr*>::const_iterator r = ranges->begin();
	 r != ranges->end(); ++r)
	{
	    (*r)->accept(*this);
	    _cc_os << "{\n";
	}

    // The body is generated by a plain CXXCodeGen.  That way a nested
    // Range/Values goes through cxx_generate_fun with its own generator
    // instead of this one.
    CXXCodeGen body_gen(_hh_os, _cc_os);

    // Names of the accumulator and of the freshly computed value.
    ostringstream acc, newval;
    acc << "_val_" << _fun.ast_id();
    newval << "_val_" << where_expr.fun()->ast_id();

    where_expr.fun()->accept(body_gen);

    _cc_os << "_val_" << _fun.ast_id() << " = "
	   << _wrapper.combine(acc.str(), newval.str()) << ";\n";

    // Close the loop nest: one brace per range expression.
    for (list<RExpr*>::size_type n = ranges->size(); n > 0; --n)
	_cc_os << "}\n";
}

/*
 * Generate one `when` branch of a `values` aggregation:
 *
 *   if (<cond>)
 *   {
 *   <compute value>
 *   _val_<acc> = combine(_val_<acc>, _val_<value>);
 *   }
 */
void
CXXFunGenerator::visit(WhenExpr &when_expr)
{
    CXXCodeGen branch_gen(_hh_os, _cc_os);

    ostringstream acc, newval;
    acc << "_val_" << _fun.ast_id();
    newval << "_val_" << when_expr.fun()->ast_id();

    _cc_os << "if (";
    when_expr.bexpr()->accept(branch_gen);
    _cc_os << ")\n{\n";

    when_expr.fun()->accept(branch_gen);
    _cc_os << "_val_" << _fun.ast_id() << " = ";
    _cc_os << _wrapper.combine(acc.str(), newval.str()) << ";\n}\n";
}

/*
 * Generate `select`: the value of the first `when` branch whose condition
 * holds, as an if / else-if chain.  The first branch is emitted here as
 * `if`; the remaining ones go through visit(WhenExpr) as `else if`.  A
 * final `else` fails at run time via assert(0).
 *
 * An empty branch list is handled explicitly, without dereferencing
 * begin().  It generates the result declaration followed by the same
 * unconditional failure.
 */
void
CXXSelectGenerator::visit(Values &values)
{
    _cc_os << "int _val_" << values.ast_id() << ";\n";

    list<WhenExpr*> *when_exprs = values.when_expr_list();
    if (when_exprs->empty())
	{
	    // No branch can ever match.
	    _cc_os << "assert(0);\n";
	    return;
	}

    CXXCodeGen cgen(_hh_os, _cc_os);
    list<WhenExpr*>::iterator i = when_exprs->begin();
    WhenExpr *first = *i;

    _cc_os << "if (";
    first->bexpr()->accept(cgen);
    _cc_os << ")\n{\n";
    first->fun()->accept(cgen);
    _cc_os << "_val_" << values.ast_id()
	   << " = _val_" << first->fun()->ast_id()
	   << ";\n}\n";

    for (++i; i != when_exprs->end(); ++i)
	(*i)->accept(*this);

    _cc_os << "else {\n"
	   << "assert(0);\n"
	   << "}\n";
}

/*
 * Generate a non-first branch of `select` as an `else if` arm.  The arm
 * assigns the branch's value to the select result `_val_<_fun id>`.
 */
void
CXXSelectGenerator::visit(WhenExpr &when_expr)
{
    CXXCodeGen branch_gen(_hh_os, _cc_os);

    _cc_os << "else if (";
    when_expr.bexpr()->accept(branch_gen);
    _cc_os << ")\n{\n";

    when_expr.fun()->accept(branch_gen);
    _cc_os << "_val_" << _fun.ast_id() << " = _val_";
    _cc_os << when_expr.fun()->ast_id() << ";\n}\n";
}

/*
 * Emit the code computing the value of `fun` into `cc_os`.
 *
 * `select` picks a single branch instead of aggregating, so it has a
 * dedicated generator.  Every other function is handled by
 * CXXFunGenerator, which aggregates through the function's wrapper.
 *
 * NOTE(review): std::cout is passed as the header stream.  Neither
 * generator appears to write to it, but confirm before relying on that.
 */
void
codegen::cxx_generate_fun(Fun &fun, ostream &cc_os)
{
    static string select = "select";

    if (fun.id() == select)
	{
	    CXXSelectGenerator select_gen(fun, std::cout, cc_os);
	    fun.accept(select_gen);
	}
    else
	{
	    CXXFunGenerator aggregate_gen(fun, std::cout, cc_os);
	    fun.accept(aggregate_gen);
	}
}

// ---=== MAIN GENERATOR =================================---

/*
 * Emit the lvalue of the cell an Update writes to.  Arity-1 matrices are
 * plain arrays and become `name[i]`.  Every other arity uses the matrix
 * class's cell accessor: `name.cell(i, j, ...)`.
 *
 * The index list is walked with a separator rather than by dereferencing
 * begin() up front.  An empty list therefore yields `name.cell()` instead
 * of an invalid dereference.
 */
void
CXXCodeGen::emit_matrix_index(Update &update)
{
    const list<const char *> *indices = update.indices();

    if (indices->size() == 1)
	{
	    _cc_os << update.name() << '[' << indices->front() << ']';
	    return;
	}

    _cc_os << update.name() << ".cell(";
    const char *sep = "";
    for (list<const char *>::const_iterator j = indices->begin();
	 j != indices->end(); ++j)
	{
	    _cc_os << sep << *j;
	    sep = ", ";
	}
    _cc_os << ')';
}

/*
 * Emit a matrix read.  Arity-1 matrices are plain arrays and become
 * `id[e]`.  Every other arity uses the matrix class's cell accessor:
 * `id.cell(e1, e2, ...)`.
 *
 * The subscript list is walked with a separator rather than by
 * dereferencing begin() up front.  An empty list therefore yields
 * `id.cell()` instead of an invalid dereference.
 */
void
CXXCodeGen::emit_matrix_index(MatrixExpr &matrix_expr)
{
    const list<Expr*> *subscripts = matrix_expr.expr_list();

    if (subscripts->size() == 1)
	{
	    _cc_os << matrix_expr.id() << '[';
	    subscripts->front()->accept(*this);
	    _cc_os << ']';
	    return;
	}

    _cc_os << matrix_expr.id() << ".cell(";
    const char *sep = "";
    for (list<Expr*>::const_iterator j = subscripts->begin();
	 j != subscripts->end(); ++j)
	{
	    _cc_os << sep;
	    (*j)->accept(*this);
	    sep = ", ";
	}
    _cc_os << ')';
}


/*
 * Emit one matrix update.  First the code that computes the right-hand
 * side into `_val_<fun id>`, then `<cell> = _val_<fun id>;`.
 */
void
CXXCodeGen::visit(Update &update)
{
    update.fun()->accept(*this);

    emit_matrix_index(update);
    _cc_os << " = _val_";
    _cc_os << update.fun()->ast_id();
    _cc_os << ";\n";
}

namespace codegen {

    /*
     * Common base of the TypeVisitor-based declaration printers.  Each
     * subclass decides, per kind of symbol, what to write to `_os`.  The
     * print() helpers write the C++ spelling of a single declaration
     * (type followed by name, no trailing ';').
     */
    class CXXGenerator : public type_checking::TypeVisitor {
    protected:
	ostream &_os;
	CXXGenerator(ostream &os) : _os(os) {}

	void print (const ValType    	&val_type);
	void print (const IndexType  	&index_type);
	// `ref` requests a reference declarator for the 2-D / n-D matrix
	// classes (used for parameters); 1-D arrays print as `T name[]`.
	void print (const MatrixType 	&matrix_type, bool ref);
	void print (const FunctionType 	&function_type);
    };

    /*
     * Emits the comma-separated parameter list of the generated entry
     * point (without the surrounding parentheses).  Matrices are passed
     * by reference.  Function-typed symbols are not supported and are
     * skipped.
     */
    class CXXParameterGenerator : public CXXGenerator {
	bool _first;
    public:
	CXXParameterGenerator(ostream &os) :
	    CXXGenerator(os), _first(true)
	{}

	// Write ", " before every parameter except the first one.
	void comma_if_not_first()
	{
	    if (_first) _first = false;
	    else _os << ", ";
	}

	virtual void visit(MatrixType &matrix_type)
	{ comma_if_not_first(); print(matrix_type, true); }

	virtual void visit(FunctionType &)
	{ /* FIXME: NOT SUPPORTED RIGHT NOW */ }

	virtual void visit(ValType &val_type)
	{ comma_if_not_first(); print(val_type); }

	// Print the IndexType itself, as CXXGlobalGenerator does, rather
	// than converting its name into a temporary declaration object.
	virtual void visit(IndexType &index_type)
	{ comma_if_not_first(); print(index_type); }
    };

    /* Emits one `extern` declaration per global symbol. */
    class CXXGlobalGenerator : public CXXGenerator {
    public:
	CXXGlobalGenerator(ostream &os) : CXXGenerator(os) {}

	virtual void visit(MatrixType &matrix_type)
	{ _os << "extern "; print(matrix_type, false); _os << ";\n"; }

	virtual void visit(FunctionType &function_type)
	{ _os << "extern "; print(function_type); _os << ";\n";  }

	virtual void visit(ValType &val_type)
	{ _os << "extern "; print(val_type); _os << ";\n";  }

	virtual void visit(IndexType &index_type)
	{ _os << "extern "; print(index_type); _os << ";\n"; }
    };


    /*
     * Emits definitions of local symbols at the top of the generated
     * function body.  Only matrices are supported; other kinds of local
     * are skipped.  Dimension expressions are printed through `_cgen`.
     */
    class CXXLocalGenerator : public CXXGenerator {
	CXXCodeGen &_cgen;
    public:
	CXXLocalGenerator(ostream &os, CXXCodeGen &cgen)
	    : CXXGenerator(os), _cgen(cgen) {}

	virtual void visit(MatrixType &matrix_type);

	virtual void visit(FunctionType &)
	{ /* FIXME: not supported yet */  }

	virtual void visit(ValType &)
	{ /* FIXME: not supported yet */  }

	virtual void visit(IndexType &)
	{ /* FIXME: not supported yet */  }
    };
};

/* Scalar declaration: `<C++ type> <name>`. */
void
CXXGenerator::print (const ValType &val_type)
{
    const char *cxx_type = translate_basic_type(val_type.type());
    _os << cxx_type << ' ' << val_type.name();
}

/* Index declaration: always `unsigned int <name>`. */
void
CXXGenerator::print (const IndexType &index_type)
{
    _os << "unsigned int ";
    _os << index_type.name();
}


/*
 * Matrix declaration.
 *   arity 1:  `T name[]`                      (ref is ignored)
 *   arity 2:  `dprog::Matrix2<T> &name`       (or ' ' instead of '&')
 *   other:    `dprog::Matrix<T> &name`        (or ' ' instead of '&')
 */
void
CXXGenerator::print (const MatrixType &matrix_type, bool ref)
{
    const char *cell = translate_basic_type(matrix_type.cell_type());

    if (matrix_type.arity() == 1)
	{
	    _os << cell << ' ' << matrix_type.name() << "[]";
	    return;
	}

    const char *class_template = (matrix_type.arity() == 2)
	? "dprog::Matrix2<"
	: "dprog::Matrix<";
    const char declarator = ref ? '&' : ' ';

    _os << class_template << cell << "> " << declarator << matrix_type.name();
}

void
CXXGenerator::print (const FunctionType &function_type)
{
    // FIXME: this only works for external declarations!
    _os << translate_basic_type(function_type.return_type()) << ' '
	<< function_type.name() << " (";

    const list<basic_type_t> *par_types = function_type.parameter_types();
    list<basic_type_t>::const_iterator i;
    const char *sep = "";
    for (i = par_types->begin(); i != par_types->end(); ++i)
	{
	    _os << sep << translate_basic_type(*i);
	    sep = ", ";
	}
    _os << ')';
}

/*
 * Print the signature of the generated entry point,
 * `void\n<method name> (<params>)`, without a trailing ';' or body.
 *
 * NOTE(review): the parameters come from a std::set keyed on TypeInfo
 * pointers, so their order follows pointer values rather than
 * declaration order.  Confirm callers do not rely on a specific order.
 */
void
CXXCodeGen::print_parameters(ostream &os)
{
    os << "void\n" << _method_name << " (";

    CXXParameterGenerator param_printer(os);
    set<TypeInfo*> *params = symbol_checking::parameters();
    for (set<TypeInfo*>::iterator p = params->begin(); p != params->end(); ++p)
	(*p)->accept(param_printer);

    os << ')';
}


void
CXXCodeGen::print_globals(std::ostream &os)
{
    CXXGlobalGenerator gen(os);
    set<TypeInfo*> *symbols = symbol_checking::globals();
    set<TypeInfo*>::iterator i;
    for (i = symbols->begin(); i != symbols->end(); ++i)
	(*i)->accept(gen);
}

/*
 * Emit the definition of a local matrix, sized by its dimension
 * expressions:
 *   arity 1:  T name[d0];
 *   arity 2:  dprog::Matrix2< T > name(d0, d1);
 *   arity n:  dprog::Matrix< T > name(n, d0, ..., dn-1);
 *
 * The class templates are namespace-qualified, the same way
 * CXXGenerator::print spells them for parameters and extern
 * declarations.  The generated code therefore does not depend on a
 * using-directive in cxx_dprog.hh.
 *
 * NOTE(review): dimension expressions are printed via _cgen, i.e. to
 * that generator's own output stream, while the rest goes to _os.  Both
 * must be the same stream (print_locals passes _cc_os for both).  A 1-D
 * local whose dimension is not a compile-time constant becomes a
 * variable-length array, which is a compiler extension, not standard C++.
 */
void
CXXLocalGenerator::visit(MatrixType &matrix_type)
{
    const list<Expr*> *dims = matrix_type.dim_exprs();
    const char *cell = translate_basic_type(matrix_type.cell_type());
    list<Expr*>::const_iterator i;

    switch (matrix_type.arity())
	{
	case 1:
	    _os << cell << ' ' << matrix_type.name() << '[';
	    dims->front()->accept(_cgen);
	    _os << "];\n";
	    break;

	case 2:
	    {
		_os << "dprog::Matrix2< " << cell << " > "
		    << matrix_type.name() << '(';
		const char *sep = "";
		for (i = dims->begin(); i != dims->end(); ++i)
		    {
			_os << sep;
			(*i)->accept(_cgen);
			sep = ", ";
		    }
		_os << ");\n";
	    }
	    break;

	default:
	    // The generic class takes the arity first, then every dimension.
	    _os << "dprog::Matrix< " << cell << " > "
		<< matrix_type.name() << '(' << matrix_type.arity();
	    for (i = dims->begin(); i != dims->end(); ++i)
		{
		    _os << ", ";
		    (*i)->accept(_cgen);
		}
	    _os << ");\n";
	    break;
	}
}


/*
 * Print definitions of the local symbols.  Dimension expressions are
 * generated by this CXXCodeGen, so `os` should be its own .cc stream.
 */
void
CXXCodeGen::print_locals(ostream &os)
{
    CXXLocalGenerator local_printer(os, *this);
    set<TypeInfo*> *locals = symbol_checking::locals();
    for (set<TypeInfo*>::iterator l = locals->begin(); l != locals->end(); ++l)
	(*l)->accept(local_printer);
}


/*
 * Generate the whole program.
 *
 * Header: an include guard around the entry-point prototype.
 * Implementation: extern declarations for globals, then the entry point.
 * Its body holds the local definitions and a nest of one loop per range
 * expression; all updates run in the innermost loop body.
 */
void
CXXCodeGen::visit(DProg &dprog)
{
    // ---- header --------------------------------------------------------
    _hh_os << "#ifndef DPROG_" << _method_name << "_HH\n";
    _hh_os << "#define DPROG_" << _method_name << "_HH\n\n";
    _hh_os << "#include <cxx_dprog.hh>\n\n";
    print_parameters(_hh_os);
    _hh_os << ";\n\n#endif\n";

    // ---- implementation ------------------------------------------------
    _cc_os << "#include <cxx_dprog.hh>\n";
    print_globals(_cc_os);
    print_parameters(_cc_os);
    _cc_os << "\n{\n";
    print_locals(_cc_os);

    // Open one loop per range expression.
    const list<RExpr*> *ranges = dprog.rexpr_list();
    for (list<RExpr*>::const_iterator r = ranges->begin();
	 r != ranges->end(); ++r)
	{
	    (*r)->accept(*this);
	    _cc_os << "{\n";
	}

    const list<Update*> *updates = dprog.update_list();
    for (list<Update*>::const_iterator u = updates->begin();
	 u != updates->end(); ++u)
	(*u)->accept(*this);

    // Close the loop bodies, then the function body.
    for (list<RExpr*>::size_type n = ranges->size(); n > 0; --n)
	_cc_os << "}\n";
    _cc_os << "}\n";
}


/*
 * Generate C++ for `dprog`.
 *
 * With output_name == 0, the implementation goes to stdout and the
 * header text is discarded.  Otherwise `<output_name>.hh` and
 * `<output_name>.cc` are written.  output_name is also passed to
 * CXXCodeGen, presumably as the generated function's name.
 *
 * If either output file cannot be opened, or writing fails, an error is
 * reported on stderr instead of silently producing no output.
 */
void
codegen::cxx_emit_code(DProg *dprog, const char *output_name)
{
    if (output_name == 0)
	{
	    // Discard the header in memory.  This avoids opening "/dev/null",
	    // which does not exist on every platform.
	    ostringstream hh_sink;
	    CXXCodeGen cgen(hh_sink, std::cout);
	    dprog->accept(cgen);
	    return;
	}

    string hh_file = string(output_name) + ".hh";
    string cc_file = string(output_name) + ".cc";
    ofstream hh_os(hh_file.c_str());
    ofstream cc_os(cc_file.c_str());
    if (!hh_os || !cc_os)
	{
	    std::cerr << "error: cannot open output file '"
		      << (!hh_os ? hh_file : cc_file) << "'\n";
	    return;
	}

    CXXCodeGen cgen(hh_os, cc_os, output_name);
    dprog->accept(cgen);
    hh_os.close();
    cc_os.close();

    // close() sets failbit if flushing the buffered output failed.
    if (hh_os.fail() || cc_os.fail())
	std::cerr << "error: failed writing '" << hh_file << "' or '"
		  << cc_file << "'\n";
}


// syntax highlighted by Code2HTML, v. 0.9.1