#include "cxx_codegen.hh"
#include "codegen.hh"
#include "ast.hh"
#include "symbol_checking.hh"
#include "type_checking.hh"
#include <cassert>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <set>
#include <sstream>
using namespace ast;
using namespace std;
using namespace symbol_checking;
using namespace type_checking;
using namespace codegen;
namespace codegen {
void cxx_generate_fun(Fun &fun, ostream &os = std::cout);
};
// Map a scalar type of the source language to the C++ type name that is
// used in the generated code (TYPE_FLOAT is emitted as "double").
//
// Any other value is a programmer error. The original only asserted and then
// fell off the end of this non-void function, which is undefined behaviour
// when NDEBUG is defined. We now abort unconditionally instead.
static const char *
translate_basic_type(basic_type_t bt)
{
    switch (bt)
    {
    case TYPE_CHAR:
        return "char";
    case TYPE_INT:
        return "int";
    case TYPE_FLOAT:
        return "double";
    default:
        break;
    }
    assert(false && "translate_basic_type: unsupported basic type");
    std::abort();
}
// Relational-operator visitors. Each one prints its C++ operator surrounded
// by single spaces, so it separates cleanly from the operands printed on
// either side (see visit(RelBExpr) and visit(RExpr)).
void
CXXCodeGen::visit(BinRelLT &)
{
    _cc_os << ' ' << "<" << ' ';
}
void
CXXCodeGen::visit(BinRelLTE &)
{
    _cc_os << ' ' << "<=" << ' ';
}
void
CXXCodeGen::visit(BinRelGT &)
{
    _cc_os << ' ' << ">" << ' ';
}
void
CXXCodeGen::visit(BinRelGTE &)
{
    _cc_os << ' ' << ">=" << ' ';
}
void
CXXCodeGen::visit(BinRelEQ &)
{
    _cc_os << ' ' << "==" << ' ';
}
void
CXXCodeGen::visit(BinRelNEQ &)
{
    _cc_os << ' ' << "!=" << ' ';
}
// Emit only the header of a C++ for-loop over the range variable
// rexpr.index(). The caller writes the loop body and its braces.
// The opening bracket decides the direction: when begin_br()->is_less() is
// true the loop counts upwards, otherwise it counts downwards.
void
CXXCodeGen::visit(RExpr &rexpr)
{
    const bool ascending = rexpr.begin_br()->is_less();

    // Initialiser: start at begin(), moved one step inwards when the opening
    // bracket does not include its bound.
    _cc_os << "for (int " << rexpr.index() << " = ";
    rexpr.begin()->accept(*this);
    if (!rexpr.begin_br()->include())
        _cc_os << (ascending ? " + 1" : " - 1");

    // Condition: "<index> <relation from end_br()> <end()>".
    _cc_os << "; " << rexpr.index();
    rexpr.end_br()->accept(*this);
    rexpr.end()->accept(*this);

    // Step one unit towards the end bound.
    if (ascending)
        _cc_os << "; ++" << rexpr.index() << ")\n";
    else
        _cc_os << "; --" << rexpr.index() << ")\n";
}
// A range aggregate nested inside an expression is generated by the
// dedicated function-node generator, writing into this visitor's stream.
void
CXXCodeGen::visit(Range &range)
{
    ostream &body_os = _cc_os;
    cxx_generate_fun(range, body_os);
}
// A 'values' function node is handed to the function-node generator
// (which also handles the special 'select' case), writing into _cc_os.
void
CXXCodeGen::visit(Values &values)
{
    ostream &body_os = _cc_os;
    cxx_generate_fun(values, body_os);
}
// A plain expression function: bind its value to the temporary
// "int _val_<ast id> = <expr>;".
void
CXXCodeGen::visit(SimpleFun &simple_fun)
{
    _cc_os << "int _val_" << simple_fun.ast_id() << " = ";
    simple_fun.expr()->accept(*this);
    _cc_os << ";\n";
}
// Identifiers are emitted verbatim.
void
CXXCodeGen::visit(IDExpr &id_expr)
{
    ostream &out = _cc_os;
    out << id_expr.id();
}
// Integer literals are streamed as-is.
void
CXXCodeGen::visit(IntegerExpr &integer_expr)
{
    ostream &out = _cc_os;
    out << integer_expr.integer();
}
// A matrix read is printed by the shared indexing helper.
void
CXXCodeGen::visit(MatrixExpr &matrix_expr)
{
    this->emit_matrix_index(matrix_expr);
}
void
CXXCodeGen::visit(FunCallExpr &matrix_expr)
{
_cc_os << matrix_expr.id() << '(';
const list<Expr*> *expr_list = matrix_expr.expr_list();
list<Expr*>::const_iterator i = expr_list->begin();
bool first = true;
for (i = expr_list->begin(); i != expr_list->end(); ++i)
{
if (!first) _cc_os << ','; first = false;
(*i)->accept(*this);
}
_cc_os << ')';
}
// Binary arithmetic is fully parenthesised, "(<left> <op> <right>)", so the
// source grouping survives regardless of C++ operator precedence.
void
CXXCodeGen::visit(BinOpExpr &binop_expr)
{
    _cc_os << "(";
    binop_expr.left()->accept(*this);
    _cc_os << " " << binop_expr.binop() << " ";
    binop_expr.right()->accept(*this);
    _cc_os << ")";
}
// Emit arithmetic negation as "-(<expr>)".
//
// The operand is parenthesised because printing a bare '-' directly in front
// of an operand that itself starts with '-' (a nested NEGExpr, or possibly a
// negative integer value) would produce "--x". C++ lexes that as the
// decrement operator, not as a double negation.
void
CXXCodeGen::visit(NEGExpr &neg_expr)
{
    _cc_os << "-(";
    neg_expr.expr()->accept(*this);
    _cc_os << ')';
}
// Comparison "(<left><rel><right>)". The relation visitor supplies the
// spaces around the operator.
void
CXXCodeGen::visit(RelBExpr &bexpr)
{
    _cc_os << "(";
    bexpr.left()->accept(*this);
    bexpr.rel()->accept(*this);
    bexpr.right()->accept(*this);
    _cc_os << ")";
}
// Logical conjunction, parenthesised: "(<left> && <right>)".
void
CXXCodeGen::visit(ANDBExpr &bexpr)
{
    _cc_os << "(";
    bexpr.left()->accept(*this);
    _cc_os << ' ' << "&&" << ' ';
    bexpr.right()->accept(*this);
    _cc_os << ")";
}
// Logical disjunction, parenthesised: "(<left> || <right>)".
void
CXXCodeGen::visit(ORBExpr &bexpr)
{
    _cc_os << "(";
    bexpr.left()->accept(*this);
    _cc_os << ' ' << "||" << ' ';
    bexpr.right()->accept(*this);
    _cc_os << ")";
}
// Logical negation: "!" immediately followed by the operand.
void
CXXCodeGen::visit(NOTBExpr &bexpr)
{
    _cc_os << "!";
    bexpr.bexpr()->accept(*this);
}
void
CXXCodeGen::visit(TrueExpr &bexpr)
{
_cc_os << "1";
}
// ---=== FUN GENERATOR ==================================---
#include "functions.hh"
namespace codegen {
// Generates code for one aggregate function node (_fun). It declares an
// accumulator _val_<ast id> seeded with the wrapper's neutral element and
// folds each generated value into it with the wrapper's combine().
class CXXFunGenerator : public CXXCodeGen {
// The function node this generator was created for.
Fun &_fun;
// Aggregate description looked up by the function's id.
functions::FunctionWrapper &_wrapper;
public:
// NOTE(review): get_function()'s result is dereferenced unchecked; confirm
// it can never return null for ids that reach this generator.
CXXFunGenerator(Fun &fun, ostream &hh_os, ostream &cc_os)
: CXXCodeGen(hh_os, cc_os), _fun(fun),
_wrapper(*(functions::get_function(fun.id())))
{}
virtual void visit(Range &range);
virtual void visit(Values &values);
virtual void visit(WhereExpr &where_expr);
virtual void visit(WhenExpr &when_expr);
};
// Generates code for the special 'select' function: an if / else if / else
// chain that stores the first alternative whose guard holds in _val_<ast id>.
class CXXSelectGenerator : public CXXCodeGen {
// The 'select' node this generator was created for.
Fun &_fun;
public:
CXXSelectGenerator(Fun &fun, ostream &hh_os, ostream &cc_os)
: CXXCodeGen(hh_os, cc_os), _fun(fun)
{}
virtual void visit(Values &values);
virtual void visit(WhenExpr &when_expr);
};
};
// Range aggregate: declare the accumulator, seed it with the neutral element,
// then let the where-clause generate the loops that fold values into it.
void
CXXFunGenerator::visit(Range &range)
{
    // This generator only handles the node it was constructed for.
    assert(&_fun == &range);
    _cc_os << "int _val_" << _fun.ast_id() << " = "
           << _wrapper.neutral_element() << ";\n";
    range.where_expr()->accept(*this);
}
// 'values' aggregate: declare the accumulator, seed it with the neutral
// element, then let each guarded alternative fold its value into it.
void
CXXFunGenerator::visit(Values &values)
{
    assert(&_fun == &values);
    _cc_os << "int _val_" << _fun.ast_id() << " = "
           << _wrapper.neutral_element() << ";\n";
    list<WhenExpr*> *alternatives = values.when_expr_list();
    for (list<WhenExpr*>::iterator alt = alternatives->begin();
         alt != alternatives->end(); ++alt)
        (*alt)->accept(*this);
}
// Where-clause of a range aggregate. It opens one for-loop (plus a brace) per
// range expression, generates the inner function with a plain CXXCodeGen,
// folds its _val_<inner id> into the accumulator via _wrapper.combine(),
// and finally closes every brace it opened.
void
CXXFunGenerator::visit(WhereExpr &where_expr)
{
    list<RExpr*> *ranges = where_expr.rexpr_list();
    unsigned opened = 0;
    for (list<RExpr*>::const_iterator r = ranges->begin();
         r != ranges->end(); ++r)
    {
        (*r)->accept(*this);
        _cc_os << "{\n";
        ++opened;
    }

    CXXCodeGen inner_gen(_hh_os, _cc_os);
    ostringstream acc_name;
    acc_name << "_val_" << _fun.ast_id();
    ostringstream inner_name;
    inner_name << "_val_" << where_expr.fun()->ast_id();
    where_expr.fun()->accept(inner_gen);
    _cc_os << "_val_" << _fun.ast_id()
           << " = " << _wrapper.combine(acc_name.str(), inner_name.str())
           << ";\n";

    while (opened > 0)
    {
        _cc_os << "}\n";
        --opened;
    }
}
// One guarded alternative of a 'values' aggregate:
//   if (<guard>) { <inner fun>; _val_<acc> = combine(_val_<acc>, _val_<inner>); }
void
CXXFunGenerator::visit(WhenExpr &when_expr)
{
    CXXCodeGen inner_gen(_hh_os, _cc_os);
    ostringstream acc_name;
    acc_name << "_val_" << _fun.ast_id();
    ostringstream inner_name;
    inner_name << "_val_" << when_expr.fun()->ast_id();

    _cc_os << "if (";
    when_expr.bexpr()->accept(inner_gen);
    _cc_os << ")\n{\n";
    when_expr.fun()->accept(inner_gen);
    _cc_os << "_val_" << _fun.ast_id() << " = "
           << _wrapper.combine(acc_name.str(), inner_name.str())
           << ";\n}\n";
}
// Generate a 'select': declare _val_<ast id>, then emit an
//   if (g0) {...} else if (g1) {...} ... else { assert(0); }
// chain that assigns the value of the first alternative whose guard holds.
//
// An empty alternative list used to dereference end() (undefined behaviour
// in the generator). It now emits the same runtime failure that the trailing
// else-branch produces when no guard matches.
void
CXXSelectGenerator::visit(Values &values)
{
    _cc_os << "int _val_" << values.ast_id() << ";\n";
    list<WhenExpr*> *when_exprs = values.when_expr_list();
    list<WhenExpr*>::iterator i = when_exprs->begin();
    if (i == when_exprs->end())
    {
        _cc_os << "assert(0);\n";
        return;
    }

    // The first alternative opens the chain with a plain 'if'.
    CXXCodeGen cgen(_hh_os, _cc_os);
    _cc_os << "if (";
    (*i)->bexpr()->accept(cgen);
    _cc_os << ")\n{\n";
    (*i)->fun()->accept(cgen);
    _cc_os << "_val_" << values.ast_id()
           << " = _val_" << (*i)->fun()->ast_id()
           << ";\n}\n";

    // Remaining alternatives become 'else if' branches (see visit(WhenExpr)).
    for (++i; i != when_exprs->end(); ++i)
        (*i)->accept(*this);

    _cc_os << "else {\n"
           << "assert(0);\n"
           << "}\n";
}
// A non-first 'select' alternative:
//   else if (<guard>) { <inner fun>; _val_<select id> = _val_<inner id>; }
void
CXXSelectGenerator::visit(WhenExpr &when_expr)
{
    CXXCodeGen inner_gen(_hh_os, _cc_os);
    _cc_os << "else if (";
    when_expr.bexpr()->accept(inner_gen);
    _cc_os << ")\n{\n";
    when_expr.fun()->accept(inner_gen);
    _cc_os << "_val_" << _fun.ast_id() << " = _val_"
           << when_expr.fun()->ast_id() << ";\n}\n";
}
// Generate the code for a function node into cc_os. The nested generators get
// std::cout as their header stream.
// 'select' picks a single alternative instead of combining values, so it has
// its own generator; every other id goes through CXXFunGenerator.
void
codegen::cxx_generate_fun(Fun &fun, ostream &cc_os)
{
    static const string select_name = "select";
    if (fun.id() == select_name)
    {
        CXXSelectGenerator select_gen(fun, std::cout, cc_os);
        fun.accept(select_gen);
    }
    else
    {
        CXXFunGenerator aggregate_gen(fun, std::cout, cc_os);
        fun.accept(aggregate_gen);
    }
}
// ---=== MAIN GENERATOR =================================---
// Print the target cell of an update from its index-variable names:
// "name[i]" for a single index, "name.cell(i, j, ...)" otherwise.
void
CXXCodeGen::emit_matrix_index(Update &update)
{
    const list<const char *> *indices = update.indices();
    list<const char *>::const_iterator idx = indices->begin();
    if (indices->size() == 1)
    {
        _cc_os << update.name() << '[';
        _cc_os << *idx;
        _cc_os << ']';
        return;
    }
    _cc_os << update.name() << ".cell(";
    _cc_os << *idx;
    for (++idx; idx != indices->end(); ++idx)
        _cc_os << ", " << *idx;
    _cc_os << ')';
}
// Print a matrix read whose indices are expressions: "name[e0]" for a single
// index, "name.cell(e0, e1, ...)" otherwise. Each index expression is
// generated by this visitor.
void
CXXCodeGen::emit_matrix_index(MatrixExpr &matrix_expr)
{
    const list<Expr*> *subscripts = matrix_expr.expr_list();
    list<Expr*>::const_iterator e = subscripts->begin();
    if (subscripts->size() == 1)
    {
        _cc_os << matrix_expr.id() << '[';
        (*e)->accept(*this);
        _cc_os << ']';
        return;
    }
    _cc_os << matrix_expr.id() << ".cell(";
    (*e)->accept(*this);
    while (++e != subscripts->end())
    {
        _cc_os << ", ";
        (*e)->accept(*this);
    }
    _cc_os << ')';
}
// An update statement: first generate the right-hand-side function (its
// visitor declares _val_<ast id>), then store it into the target cell:
//   <cell> = _val_<ast id>;
void
CXXCodeGen::visit(Update &update)
{
    update.fun()->accept(*this);
    emit_matrix_index(update);
    _cc_os << " = _val_" << update.fun()->ast_id() << ";\n";
}
namespace codegen {
// Base for the TypeInfo visitors that print C++ declarations for the
// program's symbols. All output goes to the stream given at construction.
class CXXGenerator : public type_checking::TypeVisitor {
protected:
    ostream &_os;
    CXXGenerator(ostream &os) : _os(os) {}
    void print (const ValType &val_type);
    void print (const IndexType &index_type);
    void print (const MatrixType &matrix_type, bool ref);
    void print (const FunctionType &function_type);
};

// Prints a comma-separated formal-parameter list. Matrices are passed by
// reference; function-typed parameters are currently skipped.
class CXXParameterGenerator : public CXXGenerator {
    bool _first;
public:
    CXXParameterGenerator(ostream &os) :
        CXXGenerator(os), _first(true)
    {}
    // Emit ", " before every printed parameter except the first.
    void comma_if_not_first()
    {
        if (_first) _first = false; // toggle flag
        else _os << ", ";
    }
    virtual void visit(MatrixType &matrix_type)
    { comma_if_not_first(); print(matrix_type, true); }
    virtual void visit(FunctionType &function_type)
    { /* FIXME: NOT SUPPORTED RIGHT NOW */ }
    virtual void visit(ValType &val_type)
    { comma_if_not_first(); print(val_type); }
    // Use the IndexType overload directly, as CXXGlobalGenerator does.
    // Previously only index_type.name() was passed, and no print() overload
    // takes a name.
    virtual void visit(IndexType &index_type)
    { comma_if_not_first(); print(index_type); }
};

// Prints one "extern <declaration>;" line per global symbol.
class CXXGlobalGenerator : public CXXGenerator {
public:
    CXXGlobalGenerator(ostream &os) : CXXGenerator(os) {}
    virtual void visit(MatrixType &matrix_type)
    { _os << "extern "; print(matrix_type, false); _os << ";\n"; }
    virtual void visit(FunctionType &function_type)
    { _os << "extern "; print(function_type); _os << ";\n"; }
    virtual void visit(ValType &val_type)
    { _os << "extern "; print(val_type); _os << ";\n"; }
    virtual void visit(IndexType &index_type)
    { _os << "extern "; print(index_type); _os << ";\n"; }
};

// Prints local definitions. Only matrices are supported; their dimension
// expressions are generated through the borrowed CXXCodeGen.
class CXXLocalGenerator : public CXXGenerator {
    CXXCodeGen &_cgen;
public:
    CXXLocalGenerator(ostream &os, CXXCodeGen &cgen)
        : CXXGenerator(os), _cgen(cgen) {}
    virtual void visit(MatrixType &matrix_type);
    virtual void visit(FunctionType &function_type)
    { /* FIXME: not supported yet */ }
    virtual void visit(ValType &val_type)
    { /* FIXME: not supported yet */ }
    virtual void visit(IndexType &index_type)
    { /* FIXME: not supported yet */ }
};
}
// Scalar declaration: "<C++ type> <name>".
void
CXXGenerator::print (const ValType &val_type)
{
    const char *type_name = translate_basic_type(val_type.type());
    _os << type_name << ' ' << val_type.name();
}
// Index variables are always declared as "unsigned int <name>".
void
CXXGenerator::print (const IndexType &index_type)
{
    _os << "unsigned int ";
    _os << index_type.name();
}
// Matrix declaration:
//   arity 1: "<T> name[]"                (plain array; 'ref' is ignored)
//   arity 2: "dprog::Matrix2<T> &name"   ('&' only when ref, else a space)
//   other:   "dprog::Matrix<T> &name"
void
CXXGenerator::print (const MatrixType &matrix_type, bool ref)
{
    const int arity = matrix_type.arity();
    const char *cell = translate_basic_type(matrix_type.cell_type());
    if (arity == 1)
    {
        _os << cell << ' ' << matrix_type.name() << "[]";
        return;
    }
    const char *tmpl = (arity == 2) ? "dprog::Matrix2<" : "dprog::Matrix<";
    const char ref_mark = ref ? '&' : ' ';
    _os << tmpl << cell << "> " << ref_mark << matrix_type.name();
}
void
CXXGenerator::print (const FunctionType &function_type)
{
// FIXME: this only works for external declarations!
_os << translate_basic_type(function_type.return_type()) << ' '
<< function_type.name() << " (";
const list<basic_type_t> *par_types = function_type.parameter_types();
list<basic_type_t>::const_iterator i;
const char *sep = "";
for (i = par_types->begin(); i != par_types->end(); ++i)
{
_os << sep << translate_basic_type(*i);
sep = ", ";
}
_os << ')';
}
// Print the generated entry point's signature, "void\n<method> (<params>)",
// without a trailing ';' or body. It is used for both the header and the
// definition.
// NOTE(review): parameters are visited in set<TypeInfo*> order, i.e. by
// pointer value, not by declaration order; confirm callers of the generated
// function do not rely on a particular order.
void
CXXCodeGen::print_parameters(ostream &os)
{
    os << "void\n" << _method_name << " (";
    CXXParameterGenerator param_gen(os);
    set<TypeInfo*> *params = symbol_checking::parameters();
    for (set<TypeInfo*>::iterator sym = params->begin();
         sym != params->end(); ++sym)
        (*sym)->accept(param_gen);
    os << ')';
}
void
CXXCodeGen::print_globals(std::ostream &os)
{
CXXGlobalGenerator gen(os);
set<TypeInfo*> *symbols = symbol_checking::globals();
set<TypeInfo*>::iterator i;
for (i = symbols->begin(); i != symbols->end(); ++i)
(*i)->accept(gen);
}
// Emit a local definition for a matrix, sized by its dimension expressions:
//   arity 1: "<T> name[d0];"
//   arity 2: "dprog::Matrix2< <T> > name(d0, d1);"
//   other:   "dprog::Matrix< <T> > name(<arity>, d0, d1, ...);"
// The matrix templates are qualified with dprog::, matching the parameter and
// extern declarations printed by CXXGenerator::print. Previously the locals
// used unqualified names.
// NOTE(review): dimension expressions are emitted through _cgen, which writes
// to its own output stream rather than _os. That is only correct when both are
// the same stream, as when print_locals is called with _cc_os.
// NOTE(review): the arity-1 form is a variable-length array if d0 is not a
// compile-time constant; confirm the target compiler accepts it.
void
CXXLocalGenerator::visit(MatrixType &matrix_type)
{
    const list<Expr*> *dims = matrix_type.dim_exprs();
    list<Expr*>::const_iterator i;
    const char *sep = "";
    switch (matrix_type.arity())
    {
    case 1:
        _os << translate_basic_type(matrix_type.cell_type())
            << ' '
            << matrix_type.name()
            << '[';
        dims->front()->accept(_cgen);
        _os << "];\n";
        break;
    case 2:
        _os << "dprog::Matrix2< "
            << translate_basic_type(matrix_type.cell_type())
            << " > "
            << matrix_type.name()
            << '(';
        sep = "";
        for (i = dims->begin(); i != dims->end(); ++i)
        {
            _os << sep;
            (*i)->accept(_cgen);
            sep = ", ";
        }
        _os << ");\n";
        break;
    default:
        // The generic Matrix constructor takes the arity first, then the
        // extent of each dimension.
        _os << "dprog::Matrix< "
            << translate_basic_type(matrix_type.cell_type())
            << " > "
            << matrix_type.name()
            << '('
            << matrix_type.arity();
        for (i = dims->begin(); i != dims->end(); ++i)
        {
            _os << ", ";
            (*i)->accept(_cgen);
        }
        _os << ");\n";
        break;
    }
}
// Print definitions for all local symbols. Dimension expressions are
// generated through this CXXCodeGen.
void
CXXCodeGen::print_locals(ostream &os)
{
    CXXLocalGenerator local_gen(os, *this);
    set<TypeInfo*> *locals = symbol_checking::locals();
    for (set<TypeInfo*>::iterator sym = locals->begin();
         sym != locals->end(); ++sym)
        (*sym)->accept(local_gen);
}
// Top-level generation for a whole program.
//   Header (_hh_os): include guard, #include <cxx_dprog.hh>, prototype.
//   Source (_cc_os): #include, extern globals, the function definition with
//   its locals, one for-loop per range expression around all updates, and a
//   closing brace for every loop plus the function body.
void
CXXCodeGen::visit(DProg &dprog)
{
    const list<Update*> *updates = dprog.update_list();
    const list<RExpr*> *loops = dprog.rexpr_list();

    // ---- header ----
    _hh_os << "#ifndef DPROG_" << _method_name << "_HH\n"
           << "#define DPROG_" << _method_name << "_HH\n\n"
           << "#include <cxx_dprog.hh>\n\n";
    print_parameters(_hh_os);
    _hh_os << ";\n\n#endif\n";

    // ---- implementation ----
    _cc_os << "#include <cxx_dprog.hh>\n";
    print_globals(_cc_os);
    print_parameters(_cc_os);
    _cc_os << "\n{\n";
    print_locals(_cc_os);

    unsigned open_loops = 0;
    for (list<RExpr*>::const_iterator loop = loops->begin();
         loop != loops->end(); ++loop)
    {
        (*loop)->accept(*this);
        _cc_os << "{\n";
        ++open_loops;
    }
    for (list<Update*>::const_iterator upd = updates->begin();
         upd != updates->end(); ++upd)
        (*upd)->accept(*this);
    while (open_loops > 0)
    {
        _cc_os << "}\n";
        --open_loops;
    }
    _cc_os << "}\n";
}
// Generate C++ for a program.
//   output_name == 0: write only the implementation to stdout. The header
//     is generated into an in-memory buffer and discarded, which avoids
//     depending on a platform-specific "/dev/null" device.
//   otherwise: write "<output_name>.hh" and "<output_name>.cc", using
//     output_name as the generated method name.
// Failures to open or write either output file are reported on stderr
// instead of being silently ignored.
void
codegen::cxx_emit_code(DProg *dprog, const char *output_name)
{
    if (output_name == 0)
    {
        ostringstream discarded_hh;
        CXXCodeGen cgen(discarded_hh, std::cout);
        dprog->accept(cgen);
        return;
    }

    string hh_file = string(output_name) + ".hh";
    string cc_file = string(output_name) + ".cc";
    ofstream hh_os(hh_file.c_str());
    if (!hh_os)
    {
        cerr << "error: cannot open '" << hh_file << "' for writing\n";
        return;
    }
    ofstream cc_os(cc_file.c_str());
    if (!cc_os)
    {
        cerr << "error: cannot open '" << cc_file << "' for writing\n";
        return;
    }

    CXXCodeGen cgen(hh_os, cc_os, output_name);
    dprog->accept(cgen);

    // close() flushes. A failed flush or an earlier failed write leaves the
    // stream in a failed state.
    hh_os.close();
    cc_os.close();
    if (!hh_os)
        cerr << "error: failed writing '" << hh_file << "'\n";
    if (!cc_os)
        cerr << "error: failed writing '" << cc_file << "'\n";
}
// syntax highlighted by Code2HTML, v. 0.9.1