// C++ back end: walks a DProg AST and writes a C++ header (_hh_os) and
// implementation (_cc_os) computing the program.

#include "cxx_codegen.hh"
#include "codegen.hh"
#include "ast.hh"
#include "symbol_checking.hh"
#include "type_checking.hh"
#include
#include
#include
#include
#include <cstdlib>

using namespace ast;
using namespace std;
using namespace symbol_checking;
using namespace type_checking;
using namespace codegen;

namespace codegen {
  // Emits, into `os`, the statements computing the value of a
  // range/values function application.
  void cxx_generate_fun(Fun &fun, ostream &os = std::cout);
}

// Maps a DSL basic type to the spelling of the corresponding C++ type.
// An unknown type is a programmer error: assert() reports it in debug
// builds, and abort() guarantees we never fall off the end of this
// non-void function (undefined behaviour) when NDEBUG removes the assert.
static const char *translate_basic_type(basic_type_t bt)
{
  switch (bt) {
  case TYPE_CHAR:  return "char";
  case TYPE_INT:   return "int";
  case TYPE_FLOAT: return "double";
  default:         break;
  }
  assert(false);
  std::abort();
}

// Relational operators are emitted with surrounding spaces so they can be
// pasted between two already-emitted operands.
void CXXCodeGen::visit(BinRelLT &)  { _cc_os << " < "; }
void CXXCodeGen::visit(BinRelLTE &) { _cc_os << " <= "; }
void CXXCodeGen::visit(BinRelGT &)  { _cc_os << " > "; }
void CXXCodeGen::visit(BinRelGTE &) { _cc_os << " >= "; }
void CXXCodeGen::visit(BinRelEQ &)  { _cc_os << " == "; }
void CXXCodeGen::visit(BinRelNEQ &) { _cc_os << " != "; }

// Emits only the header of a C++ `for` loop over the range index:
// - Initialiser: begin(), adjusted by +1/-1 when begin_br()->include() is
//   false.
// - Condition: index, end_br(), end().
// - Step: ++index when begin_br()->is_less(), --index otherwise.
// The loop body (and its braces) is emitted by the caller.
void CXXCodeGen::visit(RExpr &rexpr)
{
  _cc_os << "for (int " << rexpr.index() << " = ";
  rexpr.begin()->accept(*this);
  if (!rexpr.begin_br()->include()) {
    if (rexpr.begin_br()->is_less())
      _cc_os << " + 1";
    else
      _cc_os << " - 1";
  }
  _cc_os << "; " << rexpr.index();
  rexpr.end_br()->accept(*this);
  rexpr.end()->accept(*this);
  if (rexpr.begin_br()->is_less())
    _cc_os << "; ++" << rexpr.index() << ")\n";
  else
    _cc_os << "; --" << rexpr.index() << ")\n";
}

// Range and values functions are generated by the dedicated function
// generators below.
void CXXCodeGen::visit(Range &range)   { cxx_generate_fun(range, _cc_os); }
void CXXCodeGen::visit(Values &values) { cxx_generate_fun(values, _cc_os); }

// A simple function stores its expression in `int _val_<ast_id>`.
void CXXCodeGen::visit(SimpleFun &simple_fun)
{
  _cc_os << "int _val_" << simple_fun.ast_id() << " = ";
  simple_fun.expr()->accept(*this);
  _cc_os << ";\n";
}

void CXXCodeGen::visit(IDExpr &id_expr)           { _cc_os << id_expr.id(); }
void CXXCodeGen::visit(IntegerExpr &integer_expr) { _cc_os << integer_expr.integer(); }
void CXXCodeGen::visit(MatrixExpr &matrix_expr)   { emit_matrix_index(matrix_expr); }

// Emits `id(arg1,arg2,...)`.
void CXXCodeGen::visit(FunCallExpr &call_expr)
{
  _cc_os << call_expr.id() << '(';
  const list *expr_list = call_expr.expr_list();
  list::const_iterator i;
  const char *sep = "";
  for (i = expr_list->begin(); i != expr_list->end(); ++i) {
    _cc_os << sep;
    (*i)->accept(*this);
    sep = ",";
  }
  _cc_os << ')';
}

// Binary arithmetic is fully parenthesised, so the DSL's grouping survives
// regardless of C++ precedence.
void CXXCodeGen::visit(BinOpExpr &binop_expr)
{
  _cc_os << '(';
  binop_expr.left()->accept(*this);
  _cc_os << ' ' << binop_expr.binop() << ' ';
  binop_expr.right()->accept(*this);
  _cc_os << ')';
}

// Negation is parenthesised. Without the parentheses a nested negation
// would be emitted as "--x", which C++ lexes as the pre-decrement
// operator rather than two minus signs.
void CXXCodeGen::visit(NEGExpr &neg_expr)
{
  _cc_os << "(-";
  neg_expr.expr()->accept(*this);
  _cc_os << ')';
}

void CXXCodeGen::visit(RelBExpr &bexpr)
{
  _cc_os << '(';
  bexpr.left()->accept(*this);
  bexpr.rel()->accept(*this);
  bexpr.right()->accept(*this);
  _cc_os << ')';
}

void CXXCodeGen::visit(ANDBExpr &bexpr)
{
  _cc_os << '(';
  bexpr.left()->accept(*this);
  _cc_os << " && ";
  bexpr.right()->accept(*this);
  _cc_os << ')';
}

void CXXCodeGen::visit(ORBExpr &bexpr)
{
  _cc_os << '(';
  bexpr.left()->accept(*this);
  _cc_os << " || ";
  bexpr.right()->accept(*this);
  _cc_os << ')';
}

void CXXCodeGen::visit(NOTBExpr &bexpr)
{
  _cc_os << '!';
  bexpr.bexpr()->accept(*this);
}

// `true` is emitted as the integer literal 1.
void CXXCodeGen::visit(TrueExpr &) { _cc_os << "1"; }

// ---=== FUN GENERATOR ==================================---

#include "functions.hh"

namespace codegen {

  // Generates a reducing function (range or values). The result lives in
  // `_val_<fun.ast_id()>`, starts at the wrapper's neutral element, and
  // folds in every contribution with the wrapper's combine().
  class CXXFunGenerator : public CXXCodeGen {
    Fun &_fun;
    functions::FunctionWrapper &_wrapper;
  public:
    // NOTE(review): dereferences get_function(fun.id()) unchecked; assumes
    // the symbol checker only lets known function ids reach here — confirm.
    CXXFunGenerator(Fun &fun, ostream &hh_os, ostream &cc_os)
      : CXXCodeGen(hh_os, cc_os), _fun(fun),
        _wrapper(*(functions::get_function(fun.id()))) {}
    virtual void visit(Range &range);
    virtual void visit(Values &values);
    virtual void visit(WhereExpr &where_expr);
    virtual void visit(WhenExpr &when_expr);
  };

  // Generates the special `select` function: an if / else-if chain that
  // assigns the value of the first alternative whose condition holds. It
  // ends in `assert(0)` when none does.
  class CXXSelectGenerator : public CXXCodeGen {
    Fun &_fun;
  public:
    CXXSelectGenerator(Fun &fun, ostream &hh_os, ostream &cc_os)
      : CXXCodeGen(hh_os, cc_os), _fun(fun) {}
    virtual void visit(Values &values);
    virtual void visit(WhenExpr &when_expr);
  };
}

// Declares the accumulator, then emits the nested range loops.
void CXXFunGenerator::visit(Range &range)
{
  assert(&_fun == &range);
  _cc_os << "int _val_" << _fun.ast_id() << " = "
         << _wrapper.neutral_element() << ";\n";
  range.where_expr()->accept(*this);
}

// Declares the accumulator, then emits one guarded contribution per
// `when` alternative.
void CXXFunGenerator::visit(Values &values)
{
  assert(&_fun == &values);
  _cc_os << "int _val_" << _fun.ast_id() << " = "
         << _wrapper.neutral_element() << ";\n";
  list *when_exprs = values.when_expr_list();
  list::iterator i;
  for (i = when_exprs->begin(); i != when_exprs->end(); ++i)
    (*i)->accept(*this);
}

// Emits one loop (plus opening brace) per range expression. The innermost
// body computes the inner function and folds its `_val_` into the
// accumulator; then one closing brace is emitted per loop.
void CXXFunGenerator::visit(WhereExpr &where_expr)
{
  list *rexpr_list = where_expr.rexpr_list();
  list::const_iterator i;
  for (i = rexpr_list->begin(); i != rexpr_list->end(); ++i) {
    (*i)->accept(*this);
    _cc_os << "{\n";
  }

  // The inner function is generated by a plain CXXCodeGen, not by this
  // reducing generator.
  CXXCodeGen cgen(_hh_os, _cc_os);
  ostringstream acc;
  acc << "_val_" << _fun.ast_id();
  ostringstream newval;
  newval << "_val_" << where_expr.fun()->ast_id();
  where_expr.fun()->accept(cgen);
  _cc_os << "_val_" << _fun.ast_id() << " = "
         << _wrapper.combine(acc.str(), newval.str()) << ";\n";

  for (i = rexpr_list->begin(); i != rexpr_list->end(); ++i) {
    _cc_os << "}\n";
  }
}

// Emits `if (cond) { <inner fun>; acc = combine(acc, inner); }`.
void CXXFunGenerator::visit(WhenExpr &when_expr)
{
  CXXCodeGen cgen(_hh_os, _cc_os);
  ostringstream acc;
  acc << "_val_" << _fun.ast_id();
  ostringstream newval;
  newval << "_val_" << when_expr.fun()->ast_id();
  _cc_os << "if (";
  when_expr.bexpr()->accept(cgen);
  _cc_os << ")\n{\n";
  when_expr.fun()->accept(cgen);
  _cc_os << "_val_" << _fun.ast_id() << " = "
         << _wrapper.combine(acc.str(), newval.str()) << ";\n}\n";
}

// Emits the first alternative as `if`, the rest (through
// visit(WhenExpr&)) as `else if`, and a final `else { assert(0); }`.
void CXXSelectGenerator::visit(Values &values)
{
  _cc_os << "int _val_" << values.ast_id() << ";\n";
  CXXCodeGen cgen(_hh_os, _cc_os);
  list *when_exprs = values.when_expr_list();
  list::iterator i = when_exprs->begin();
  if (i == when_exprs->end()) {
    // No alternatives: nothing can ever be selected. Guarding here also
    // avoids dereferencing begin() of an empty list.
    _cc_os << "assert(0);\n";
    return;
  }
  _cc_os << "if (";
  (*i)->bexpr()->accept(cgen);
  _cc_os << ")\n{\n";
  (*i)->fun()->accept(cgen);
  _cc_os << "_val_" << values.ast_id() << " = _val_"
         << (*i)->fun()->ast_id() << ";\n}\n";
  for (++i; i != when_exprs->end(); ++i)
    (*i)->accept(*this);
  _cc_os << "else {\n" << "assert(0);\n" << "}\n";
}

// Emits one `else if (cond) { <inner fun>; result = inner; }` branch.
void CXXSelectGenerator::visit(WhenExpr &when_expr)
{
  CXXCodeGen cgen(_hh_os, _cc_os);
  _cc_os << "else if (";
  when_expr.bexpr()->accept(cgen);
  _cc_os << ")\n{\n";
  when_expr.fun()->accept(cgen);
  _cc_os << "_val_" << _fun.ast_id() << " = _val_"
         << when_expr.fun()->ast_id() << ";\n}\n";
}

// Dispatches on the function name: `select` has its own generator, and
// every other function is a reduction. Both generators receive std::cout
// as their header stream; none of their visitors shown here write to it.
void codegen::cxx_generate_fun(Fun &fun, ostream &cc_os)
{
  // A std::string makes `==` a content comparison, whatever fun.id()
  // returns.
  static const string select = "select";
  if (fun.id() == select) {
    CXXSelectGenerator sgen(fun, std::cout, cc_os);
    fun.accept(sgen);
    return;
  }

  CXXFunGenerator fungen(fun, std::cout, cc_os);
  fun.accept(fungen);
}

// ---=== MAIN GENERATOR =================================---

// Emits the lvalue for an update target: `name[i]` when there is exactly
// one index, `name.cell(i, j, ...)` otherwise. The separator loop avoids
// dereferencing begin() of an empty index list.
void CXXCodeGen::emit_matrix_index(Update &update)
{
  const list *indices = update.indices();
  if (indices->size() == 1) {
    _cc_os << update.name() << '[' << *(indices->begin()) << ']';
    return;
  }
  _cc_os << update.name() << ".cell(";
  const char *sep = "";
  list::const_iterator j;
  for (j = indices->begin(); j != indices->end(); ++j) {
    _cc_os << sep << *j;
    sep = ", ";
  }
  _cc_os << ')';
}

// Same shape as the Update overload, but the indices are expressions that
// are generated recursively.
void CXXCodeGen::emit_matrix_index(MatrixExpr &matrix_expr)
{
  const list *expr_list = matrix_expr.expr_list();
  if (expr_list->size() == 1) {
    _cc_os << matrix_expr.id() << '[';
    (*(expr_list->begin()))->accept(*this);
    _cc_os << ']';
    return;
  }
  _cc_os << matrix_expr.id() << ".cell(";
  const char *sep = "";
  list::const_iterator j;
  for (j = expr_list->begin(); j != expr_list->end(); ++j) {
    _cc_os << sep;
    (*j)->accept(*this);
    sep = ", ";
  }
  _cc_os << ')';
}

// Computes the update's function value, then assigns it to the indexed
// matrix cell.
void CXXCodeGen::visit(Update &update)
{
  update.fun()->accept(*this);
  emit_matrix_index(update);
  _cc_os << " = _val_" << update.fun()->ast_id() << ";\n";
}

namespace codegen {

  // Base for type visitors that print C++ declarations of symbols to _os.
  class CXXGenerator : public type_checking::TypeVisitor {
  protected:
    ostream &_os;
    CXXGenerator(ostream &os) : _os(os) {}
    void print(const ValType &val_type);
    void print(const IndexType &index_type);
    void print(const MatrixType &matrix_type, bool ref);
    void print(const FunctionType &function_type);
  };

  // Prints a comma-separated parameter list. Matrices are passed by
  // reference; function-typed parameters are not supported yet.
  class CXXParameterGenerator : public CXXGenerator {
    bool _first;
  public:
    CXXParameterGenerator(ostream &os) : CXXGenerator(os), _first(true) {}
    void comma_if_not_first() {
      if (_first)
        _first = false; // toggle flag
      else
        _os << ", ";
    }
    virtual void visit(MatrixType &matrix_type) {
      comma_if_not_first();
      print(matrix_type, true);
    }
    virtual void visit(FunctionType &function_type) {
      /* FIXME: NOT SUPPORTED RIGHT NOW */
    }
    virtual void visit(ValType &val_type) {
      comma_if_not_first();
      print(val_type);
    }
    virtual void visit(IndexType &index_type) {
      comma_if_not_first();
      // Print the IndexType itself (as CXXGlobalGenerator does) instead of
      // converting its name back into a temporary IndexType.
      print(index_type);
    }
  };

  // Prints an `extern` declaration for each global symbol.
  class CXXGlobalGenerator : public CXXGenerator {
  public:
    CXXGlobalGenerator(ostream &os) : CXXGenerator(os) {}
    virtual void visit(MatrixType &matrix_type) {
      _os << "extern ";
      print(matrix_type, false);
      _os << ";\n";
    }
    virtual void visit(FunctionType &function_type) {
      _os << "extern ";
      print(function_type);
      _os << ";\n";
    }
    virtual void visit(ValType &val_type) {
      _os << "extern ";
      print(val_type);
      _os << ";\n";
    }
    virtual void visit(IndexType &index_type) {
      _os << "extern ";
      print(index_type);
      _os << ";\n";
    }
  };

  // Prints local definitions; only matrices are supported so far.
  // NOTE(review): the dimension expressions are written by _cgen to its own
  // _cc_os, not to _os. Output is only coherent when both are the same
  // stream (as in visit(DProg)).
  class CXXLocalGenerator : public CXXGenerator {
    CXXCodeGen &_cgen;
  public:
    CXXLocalGenerator(ostream &os, CXXCodeGen &cgen)
      : CXXGenerator(os), _cgen(cgen) {}
    virtual void visit(MatrixType &matrix_type);
    virtual void visit(FunctionType &function_type) { /* FIXME: not supported yet */ }
    virtual void visit(ValType &val_type) { /* FIXME: not supported yet */ }
    virtual void visit(IndexType &index_type) { /* FIXME: not supported yet */ }
  };
}

// `<c++ type> <name>`
void CXXGenerator::print(const ValType &val_type)
{
  _os << translate_basic_type(val_type.type()) << ' ' << val_type.name();
}

// Index variables are always `unsigned int`.
void CXXGenerator::print(const IndexType &index_type)
{
  _os << "unsigned int " << index_type.name();
}

// Arity 1 becomes a C array (`ref` is not used there), arity 2 becomes
// dprog::Matrix2<T>, and higher arities become dprog::Matrix<T>. `ref`
// selects `&` for the class types.
void CXXGenerator::print(const MatrixType &matrix_type, bool ref)
{
  switch (matrix_type.arity()) {
  case 1:
    _os << translate_basic_type(matrix_type.cell_type()) << ' '
        << matrix_type.name() << "[]";
    break;
  case 2:
    _os << "dprog::Matrix2<" << translate_basic_type(matrix_type.cell_type())
        << "> " << (ref ? '&' : ' ') << matrix_type.name();
    break;
  default:
    _os << "dprog::Matrix<" << translate_basic_type(matrix_type.cell_type())
        << "> " << (ref ? '&' : ' ') << matrix_type.name();
    break;
  }
}

// `<ret> <name> (<t1>, <t2>, ...)`, with parameter types only and no names.
void CXXGenerator::print(const FunctionType &function_type)
{
  // FIXME: this only works for external declarations!
  _os << translate_basic_type(function_type.return_type()) << ' '
      << function_type.name() << " (";
  const list *par_types = function_type.parameter_types();
  list::const_iterator i;
  const char *sep = "";
  for (i = par_types->begin(); i != par_types->end(); ++i) {
    _os << sep << translate_basic_type(*i);
    sep = ", ";
  }
  _os << ')';
}

// Prints `void\n<method name> (<parameters>)` without a trailing `;` or body.
void CXXCodeGen::print_parameters(ostream &os)
{
  os << "void\n" << _method_name << " (";
  CXXParameterGenerator pgen(os);
  set *symbols = symbol_checking::parameters();
  set::iterator i;
  for (i = symbols->begin(); i != symbols->end(); ++i)
    (*i)->accept(pgen);
  os << ')';
}

// Prints an extern declaration for every global symbol.
void CXXCodeGen::print_globals(std::ostream &os)
{
  CXXGlobalGenerator gen(os);
  set *symbols = symbol_checking::globals();
  set::iterator i;
  for (i = symbols->begin(); i != symbols->end(); ++i)
    (*i)->accept(gen);
}

// Defines a local matrix sized by its dimension expressions:
// - arity 1: `T name[d];`
// - arity 2: `dprog::Matrix2< T > name(d1, d2);`
// - otherwise: `dprog::Matrix< T > name(arity, d1, ...);`
// The class types are namespace-qualified exactly as in
// CXXGenerator::print, so the generated file does not depend on a
// `using namespace dprog`.
void CXXLocalGenerator::visit(MatrixType &matrix_type)
{
  const list *dims = matrix_type.dim_exprs();
  list::const_iterator i;
  const char *sep = "";
  switch (matrix_type.arity()) {
  case 1:
    _os << translate_basic_type(matrix_type.cell_type()) << ' '
        << matrix_type.name() << '[';
    dims->front()->accept(_cgen);
    _os << "];\n";
    break;
  case 2:
    _os << "dprog::Matrix2< " << translate_basic_type(matrix_type.cell_type())
        << " > " << matrix_type.name() << '(';
    for (i = dims->begin(); i != dims->end(); ++i) {
      _os << sep;
      (*i)->accept(_cgen);
      sep = ", ";
    }
    _os << ");\n";
    break;
  default:
    // The arity is passed first, so every dimension is preceded by ", ".
    _os << "dprog::Matrix< " << translate_basic_type(matrix_type.cell_type())
        << " > " << matrix_type.name() << '(' << matrix_type.arity();
    sep = ", ";
    for (i = dims->begin(); i != dims->end(); ++i) {
      _os << sep;
      (*i)->accept(_cgen);
    }
    _os << ");\n";
    break;
  }
}

// Prints a definition for every local symbol (see CXXLocalGenerator).
void CXXCodeGen::print_locals(ostream &os)
{
  CXXLocalGenerator gen(os, *this);
  set *symbols = symbol_checking::locals();
  set::iterator i;
  for (i = symbols->begin(); i != symbols->end(); ++i)
    (*i)->accept(gen);
}

// Writes to the header: an include guard around the function prototype.
// Writes to the implementation, in order:
// - extern globals,
// - the function signature,
// - the local definitions,
// - one nested loop per program range expression,
// - all updates in the innermost body.
void CXXCodeGen::visit(DProg &dprog)
{
  const list *update_list = dprog.update_list();
  list::const_iterator u_itr;
  const list *rexpr_list = dprog.rexpr_list();
  list::const_iterator re_itr;

  _hh_os << "#ifndef DPROG_" << _method_name << "_HH\n"
         << "#define DPROG_" << _method_name << "_HH\n\n"
         << "#include \n\n";
  print_parameters(_hh_os);
  _hh_os << ";\n\n#endif\n";

  _cc_os << "#include \n";
  print_globals(_cc_os);
  print_parameters(_cc_os);
  _cc_os << "\n{\n";
  print_locals(_cc_os);

  for (re_itr = rexpr_list->begin(); re_itr != rexpr_list->end(); ++re_itr) {
    (*re_itr)->accept(*this);
    _cc_os << "{\n";
  }

  for (u_itr = update_list->begin(); u_itr != update_list->end(); ++u_itr)
    (*u_itr)->accept(*this);

  for (re_itr = rexpr_list->begin(); re_itr != rexpr_list->end(); ++re_itr) {
    _cc_os << "}\n";
  }

  _cc_os << "}\n";
}

// Entry point:
// - output_name == 0: the implementation goes to stdout and the header is
//   discarded into /dev/null.
// - otherwise: writes <output_name>.hh and <output_name>.cc, and passes
//   output_name to CXXCodeGen.
void codegen::cxx_emit_code(DProg *dprog, const char *output_name)
{
  if (output_name == 0) {
    // use stdout
    ofstream hh_os("/dev/null");
    CXXCodeGen cgen(hh_os, std::cout);
    dprog->accept(cgen);
    hh_os.close();
  } else {
    string hh_file = string(output_name) + ".hh";
    string cc_file = string(output_name) + ".cc";
    ofstream hh_os(hh_file.c_str());
    ofstream cc_os(cc_file.c_str());
    CXXCodeGen cgen(hh_os, cc_os, output_name);
    dprog->accept(cgen);
    hh_os.close();
    cc_os.close();
  }
}