// Still need:
//
//  Function calls -
//    The most general approach would be to just bridge calls to built-in functions.
//    But that seems suboptimal.
//
//  Dynamic array resizing.  The current approach is not the best.  Since we no longer cache the
//  addresses and sizes of the arrays in local variables, the best approach is to have the
//  v_resize function use the actual Array class interface to resize the arrays.  This avoids
//  the problems with memory references and leaks.  On the other hand, it does mean that the
//  v_resize function will have to be a member of the JITVM class.  Or at least, it will have
//  to proxy to such a function.
//
//  Range checking.
//  Type promotion on loop entry (and scalar variables)
//
//
// Just for fun, mind you....
//
// List of reasonable functions to JIT
//
// sec
// csc
// tan
// atan
// cot
// exp
// expm1
// ceil
// floor
// round
// pi, e
// float
// single
// double
// int32
// nan/NaN/inf/Inf
// IsNaN/isnan/IsInf/isinf
// eps/feps
//
// What happens with something like:
//   if (t>0)
//     pi = 3.5;
//   else
//     a = pi;
//   end
// In this case, we assign to PI in the first block, and
// read from it in the second.  This should be avoided -
// it will not compile correctly -- assignment to a variable
// that has the name of a JIT function should disable the JIT
// mechanism.

#include "JITVM.hpp"
#ifdef HAVE_LLVM
#include "llvm/Analysis/Verifier.h"
#include "llvm/Bitcode/ReaderWriter.h"
#if 0
#include "llvm/Pass.h"
#include "llvm/PassManager.h"
#include "llvm/LinkAllPasses.h"
#endif
#include "llvm/Target/TargetData.h"
// NOTE(review): the header name on this #include was lost in extraction;
// <math.h> is inferred from the atan()/exp() calls below -- confirm.
#include <math.h>

// We want some basic functions to be available to the JIT
// such as sin, cos, abs, x^n, tan
// we also need division to work.
using namespace llvm; llvm::FunctionType *ddFuncTy; llvm::FunctionType *ffFuncTy; llvm::FunctionType *iiFuncTy; static llvm::FunctionType* GetScalarFunctionType(const Type* arg, const Type* ret) { std::vector tArgs; tArgs.push_back(arg); llvm::FunctionType* fret = llvm::FunctionType::get(ret,tArgs,false, (ParamAttrsList*) 0); return fret; } JITFunction::JITFunction(string fun, const llvm::Type *ret, const llvm::Type *arg, llvm::Module* mod) { retType = ret; argType = arg; funcType = GetScalarFunctionType(arg,ret); funcAddress = new llvm::Function(funcType,GlobalValue::ExternalLinkage,fun,mod); } void JITVM::initialize_JIT_functions() { JITScalars.insertSymbol("pi",ConstantFP::get(Type::DoubleTy,4.0*atan(1.0))); JITScalars.insertSymbol("e",ConstantFP::get(Type::DoubleTy,exp(1.0))); JITScalars.insertSymbol("true",int32_const(1)); JITScalars.insertSymbol("false",int32_const(0)); // SIN JITDoubleFuncs.insertSymbol("sin",JITFunction("sin",Type::getPrimitiveType(Type::DoubleTyID), Type::getPrimitiveType(Type::DoubleTyID),M)); JITFloatFuncs.insertSymbol("sin",JITFunction("sinf",Type::getPrimitiveType(Type::FloatTyID), Type::getPrimitiveType(Type::FloatTyID),M)); // COS JITDoubleFuncs.insertSymbol("cos",JITFunction("cos",Type::getPrimitiveType(Type::DoubleTyID), Type::getPrimitiveType(Type::DoubleTyID),M)); JITFloatFuncs.insertSymbol("cos",JITFunction("cosf",Type::getPrimitiveType(Type::FloatTyID), Type::getPrimitiveType(Type::FloatTyID),M)); // ABS JITDoubleFuncs.insertSymbol("abs",JITFunction("fabs",Type::getPrimitiveType(Type::DoubleTyID), Type::getPrimitiveType(Type::DoubleTyID),M)); JITFloatFuncs.insertSymbol("abs",JITFunction("fabsf",Type::getPrimitiveType(Type::FloatTyID), Type::getPrimitiveType(Type::FloatTyID),M)); JITIntFuncs.insertSymbol("abs",JITFunction("abs",IntegerType::get(32), IntegerType::get(32),M)); // SEC // JITDoubleFuncs.insertSymbol("sec",JITFunction("sec",Type::getPrimitiveType(Type::DoubleTyID), // Type::getPrimitiveType(Type::DoubleTyID),M)); // 
JITFloatFuncs.insertSymbol("sec",JITFunction("secf",Type::getPrimitiveType(Type::FloatTyID), // Type::getPrimitiveType(Type::FloatTyID),M)); } static inline bool isi(JITScalar arg) { return arg->getType()->isInteger(); } static inline bool isfd(JITScalar arg) { return arg->getType()->isFloatingPoint(); } static inline bool isf(JITScalar arg) { return (arg->getType()->getTypeID() == Type::FloatTyID); } static inline bool isd(JITScalar arg) { return (arg->getType()->getTypeID() == Type::DoubleTyID); } JITScalar JITVM::compile_boolean_op(Instruction::BinaryOps op, JITScalar arg1, JITScalar arg2, string inst) { arg1 = cast(arg1,IntegerType::get(1),false,ip,""); arg2 = cast(arg2,IntegerType::get(1),false,ip,""); return BinaryOperator::create(op, arg1,arg2,"",ip); } // add immediate to immediate - this instruction can be replaced by // constant folding, but for now, leave it in. JITScalar JITVM::compile_binary_op(BinaryOperator::BinaryOps opcode, JITScalar arg1, JITScalar arg2, string inst) { const Type* outType; if (arg1->getType() == arg2->getType()) outType = arg1->getType(); else if ((isi(arg1) && isfd(arg2)) || (isi(arg2) && isfd(arg1))) outType = Type::DoubleTy; else if ((isf(arg1) && isd(arg2)) || (isd(arg1) && isf(arg2))) outType = Type::DoubleTy; arg1 = cast(arg1,outType,true,ip); arg2 = cast(arg2,outType,true,ip); return BinaryOperator::create(opcode,arg1,arg2,"",ip); } JITScalar JITVM::compile_comparison_op(byte op, JITScalar arg1, JITScalar arg2, string inst) { const Type* outType; if (arg1->getType() == arg2->getType()) outType = arg1->getType(); else if ((isi(arg1) && isfd(arg2)) || (isi(arg2) && isfd(arg1))) outType = Type::DoubleTy; else if ((isf(arg1) && isd(arg2)) || (isd(arg1) && isf(arg2))) outType = Type::DoubleTy; arg1 = cast(arg1,outType,true,ip); arg2 = cast(arg2,outType,true,ip); if (outType->isInteger()) { switch (op) { default: throw Exception("Unrecognized comparison op"); case '<': return new ICmpInst(ICmpInst::ICMP_SLT,arg1,arg2,"",ip); case 
TOK_LE: return new ICmpInst(ICmpInst::ICMP_SLE,arg1,arg2,"",ip); case TOK_EQ: return new ICmpInst(ICmpInst::ICMP_EQ,arg1,arg2,"",ip); case TOK_GE: return new ICmpInst(ICmpInst::ICMP_SGE,arg1,arg2,"",ip); case '>': return new ICmpInst(ICmpInst::ICMP_SGT,arg1,arg2,"",ip); case TOK_NE: return new ICmpInst(ICmpInst::ICMP_NE,arg1,arg2,"",ip); } } else { switch (op) { default: throw Exception("Unrecognized comparison op"); case '<': return new FCmpInst(FCmpInst::FCMP_OLT,arg1,arg2,"",ip); case TOK_LE: return new FCmpInst(FCmpInst::FCMP_OLE,arg1,arg2,"",ip); case TOK_EQ: return new FCmpInst(FCmpInst::FCMP_OEQ,arg1,arg2,"",ip); case TOK_GE: return new FCmpInst(FCmpInst::FCMP_OGE,arg1,arg2,"",ip); case '>': return new FCmpInst(FCmpInst::FCMP_OGT,arg1,arg2,"",ip); case TOK_NE: return new FCmpInst(FCmpInst::FCMP_ONE,arg1,arg2,"",ip); } } } static const Type* array_dereference(const Type* t) { const PointerType* p = dynamic_cast(t); if (!p) throw Exception("Expected pointer type in argument to array_dereference"); return p->getElementType(); } void JITVM::compile_assignment(tree t, Interpreter* m_eval) { tree s(t.first()); string symname(s.first().text()); JITScalar rhs(compile_expression(t.second(),m_eval)); JITSymbolInfo *v = symbols.findSymbol(symname); if (!v) { if (s.numchildren() == 1) v = add_argument_scalar(symname,m_eval,rhs,false); else v = add_argument_array(symname,m_eval); if (!v) throw Exception("Undefined variable reference:" + symname); } if (s.numchildren() == 1) { if (v->data_value->getType() != PointerType::get(rhs->getType())) throw Exception("polymorphic assignment to scalar detected."); if (!v->is_scalar) throw Exception("scalar assignment to array variable."); new StoreInst(rhs, v->data_value, ip); return; } if (s.numchildren() > 2) throw Exception("multiple levels of dereference not handled yet..."); if (v->is_scalar) throw Exception("array indexing of scalar values..."); tree q(s.second()); if (!q.is(TOK_PARENS)) throw Exception("non parenthetical 
dereferences not handled yet..."); if (q.numchildren() == 0) throw Exception("Expecting at least 1 array reference for dereference..."); if (q.numchildren() > 2) throw Exception("Expecting at most 2 array references for dereference..."); if (v->data_value->getType() != PointerType::get(PointerType::get(rhs->getType()))) { // Handle type promotion. The only case that I can think of // (with int, float and double) is if the destination array is // double, in which case, we can promote the RHS to type double const Type* p_p_array_type = v->data_value->getType(); const Type* array_type = array_dereference(array_dereference(p_p_array_type)); if (array_type && array_type->getTypeID() == Type::DoubleTyID) { rhs = cast(rhs,array_type,true,ip,""); } else throw Exception("polymorphic assignment to array detected"); } if (q.numchildren() == 1) { JITScalar arg1 = compile_expression(q.first(),m_eval); arg1 = cast(arg1,IntegerType::get(32),false,ip); // Add code to check the address against the bounds and resize if necessary Value* under_range = new ICmpInst(ICmpInst::ICMP_SLT,arg1,int32_const(1),"",ip); BasicBlock *bb1 = new BasicBlock("under_range",func,0); BasicBlock *bb2 = new BasicBlock("not_under_range",func,0); new BranchInst(bb1,bb2,under_range,ip); new StoreInst(int32_const(2),return_val,false,bb1); new BranchInst(func_epilog,bb1); Value* over_range = new ICmpInst(ICmpInst::ICMP_SGT, arg1,new LoadInst(v->num_length,"",false,bb2),"",bb2); BasicBlock *bb3 = new BasicBlock("need_resize",func,0); BasicBlock *bb4 = new BasicBlock("valid_range",func,0); new BranchInst(bb3,bb4,over_range,bb2); // Need to resize std::vector resize_params; resize_params.push_back(this_ptr); resize_params.push_back(int32_const(v->argument_index)); resize_params.push_back(arg1); new CallInst(v_resize_func_ptr, &resize_params[0], 3, "", bb3); new StoreInst(arg1,v->num_length,bb3); int argnum = v->argument_index; Value* p_arg = cast(get_input_argument(3*argnum,bb3), 
PointerType::get(map_dataclass_type(v->inferred_type)), false,bb3,""); new StoreInst(p_arg,v->data_value,bb3); new BranchInst(bb4,bb3); ip = bb4; arg1 = BinaryOperator::create(Instruction::Sub,arg1,int32_const(1),"",ip); Value *g = new LoadInst(v->data_value,"",false,ip); JITScalar address = new GetElementPtrInst(g, arg1, "", ip); new StoreInst(rhs, address, false, ip); } else if (q.numchildren() == 2) { JITScalar arg1 = compile_expression(q.first(),m_eval); JITScalar arg2 = compile_expression(q.second(),m_eval); arg1 = cast(arg1,IntegerType::get(32),false,ip); arg2 = cast(arg2,IntegerType::get(32),false,ip); Value* under_range1 = new ICmpInst(ICmpInst::ICMP_SLT,arg1,int32_const(1),"",ip); Value* under_range2 = new ICmpInst(ICmpInst::ICMP_SLT,arg2,int32_const(1),"",ip); Value* under_range = BinaryOperator::create(Instruction::Or, under_range1,under_range2,"",ip); BasicBlock *bb1 = new BasicBlock("under_range",func,0); BasicBlock *bb2 = new BasicBlock("not_under_range",func,0); new BranchInst(bb1,bb2,under_range,ip); new StoreInst(int32_const(2),return_val,false,bb1); new BranchInst(func_epilog,bb1); Value* over_range1 = new ICmpInst(ICmpInst::ICMP_SGT, arg1,new LoadInst(v->num_rows,"",false,bb2),"",bb2); Value* over_range2 = new ICmpInst(ICmpInst::ICMP_SGT, arg2,new LoadInst(v->num_cols,"",false,bb2),"",bb2); Value* over_range = BinaryOperator::create(Instruction::Or, over_range1,over_range2,"",bb2); BasicBlock *bb3 = new BasicBlock("need_resize",func,0); BasicBlock *bb4 = new BasicBlock("valid_range",func,0); new BranchInst(bb3,bb4,over_range,bb2); // Need to resize std::vector resize_params; resize_params.push_back(this_ptr); resize_params.push_back(int32_const(v->argument_index)); resize_params.push_back(arg1); resize_params.push_back(arg2); new CallInst(m_resize_func_ptr, &resize_params[0], 4, "", bb3); int argnum = v->argument_index; Value* p_arg = cast(get_input_argument(3*argnum,bb3), PointerType::get(map_dataclass_type(v->inferred_type)), false,bb3,""); new 
StoreInst(p_arg,v->data_value,bb3); Value* r_arg = cast(get_input_argument(3*argnum+1,bb3), PointerType::get(IntegerType::get(32)),false,bb3); Value* c_arg = cast(get_input_argument(3*argnum+2,bb3), PointerType::get(IntegerType::get(32)),false,bb3); copy_value(r_arg,v->num_rows,bb3); copy_value(c_arg,v->num_cols,bb3); new StoreInst(BinaryOperator::create(Instruction::Mul, new LoadInst(c_arg, "", false, bb3), new LoadInst(r_arg, "", false, bb3), "",bb3),v->num_length,false,bb3); new BranchInst(bb4,bb3); ip = bb4; // Add code to check range arg1 = BinaryOperator::create(Instruction::Sub,arg1,int32_const(1),"",ip); arg2 = BinaryOperator::create(Instruction::Sub,arg2,int32_const(1),"",ip); JITScalar lin = BinaryOperator::create(Instruction::Mul,arg2, new LoadInst(v->num_rows,"",false,ip), "",ip); lin = BinaryOperator::create(Instruction::Add,lin,arg1,"",ip); Value *g = new LoadInst(v->data_value, "", false, ip); JITScalar address = new GetElementPtrInst(g, lin, "", ip); new StoreInst(rhs, address, false, ip); } } void JITVM::compile_if_statement(tree t, Interpreter* m_eval) { JITScalar main_cond(cast(compile_expression(t.first(),m_eval), IntegerType::get(1),false,ip,"")); BasicBlock *if_true = new BasicBlock("if_true",func,0); BasicBlock *if_continue = new BasicBlock("if_continue",func,0); BasicBlock *if_exit = new BasicBlock("if_exit",func,0); new BranchInst(if_true,if_continue,main_cond,ip); ip = if_true; compile_block(t.second(),m_eval); new BranchInst(if_exit,ip); unsigned n=2; while (n < t.numchildren() && t.child(n).is(TOK_ELSEIF)) { ip = if_continue; JITScalar ttest(cast(compile_expression(t.child(n).first(),m_eval), IntegerType::get(1),false,ip,"")); if_true = new BasicBlock("elseif_true",func,0); if_continue = new BasicBlock("elseif_continue",func,0); new BranchInst(if_true,if_continue,ttest,ip); ip = if_true; compile_block(t.child(n).second(),m_eval); new BranchInst(if_exit,ip); n++; } if (t.last().is(TOK_ELSE)) { ip = if_continue; 
compile_block(t.last().first(),m_eval); new BranchInst(if_exit,ip); } else { new BranchInst(if_exit,if_continue); } ip = if_exit; } JITScalar JITVM::bool_const(int32 x) { return ConstantInt::get(Type::Int1Ty,x); } JITScalar JITVM::int32_const(int32 x) { return ConstantInt::get(Type::Int32Ty,x); } JITScalar JITVM::cast(JITScalar value, const Type *type, bool sgnd, BasicBlock *where, string name) { return CastInst::create(CastInst::getCastOpcode(value,sgnd,type,sgnd), value, type, name, where); } JITScalar JITVM::get_input_argument(int arg, BasicBlock* where) { Value *s = new GetElementPtrInst(ptr_inputs,int32_const(arg),"",where); s = new LoadInst(s,"",false,where); return s; } void JITVM::copy_value(JITScalar source, JITScalar dest, BasicBlock* where) { new StoreInst(new LoadInst(source, "", false, where), dest, false, where); } const Type* JITVM::map_dataclass_type(Class aclass) { switch (aclass) { default: throw Exception("JIT does not support"); case FM_INT32: return IntegerType::get(32); case FM_FLOAT: return Type::FloatTy; case FM_DOUBLE: return Type::DoubleTy; } return NULL; } JITSymbolInfo* JITVM::add_argument_array(string name, Interpreter* m_eval) { ArrayReference ptr(m_eval->getContext()->lookupVariable(name)); Class aclass = FM_FUNCPTR_ARRAY; if (!ptr.valid()) return NULL; if (!ptr->is2D()) throw Exception("Cannot JIT multi-dimensional array:" + name); if (ptr->isString() || ptr->isReferenceType()) throw Exception("Cannot JIT strings or reference types:" + name); if (ptr->isComplex()) throw Exception("Cannot JIT complex arrays:" + name); aclass = ptr->dataClass(); // Map the array class to an llvm type const Type* ctype(map_dataclass_type(aclass)); // Allocate local variables for the row, column and pointer to data Value* r = new AllocaInst(IntegerType::get(32),name+"_rows",func_prolog); Value* c = new AllocaInst(IntegerType::get(32),name+"_cols",func_prolog); Value* p = new AllocaInst(PointerType::get(ctype),name+"_data",func_prolog); Value* l = new 
AllocaInst(IntegerType::get(32),name+"_len",func_prolog); // Get pointers to the argument array elements Value* r_arg = cast(get_input_argument(3*argument_count+1,func_prolog), PointerType::get(IntegerType::get(32)),false,func_prolog,name+"_rows_in"); Value* c_arg = cast(get_input_argument(3*argument_count+2,func_prolog), PointerType::get(IntegerType::get(32)),false,func_prolog,name+"_cols_in"); Value* p_arg = cast(get_input_argument(3*argument_count,func_prolog), PointerType::get(ctype),false,func_prolog,name+"_data_in"); // Initialize the local variables from the argument array copy_value(r_arg,r,func_prolog); copy_value(c_arg,c,func_prolog); new StoreInst(p_arg,p,func_prolog); new StoreInst(BinaryOperator::create(Instruction::Mul, new LoadInst(c, "", false, func_prolog), new LoadInst(r, "", false, func_prolog), "",func_prolog),l,false,func_prolog); symbols.insertSymbol(name,JITSymbolInfo(true,argument_count,false,true, r,c,l,p,aclass,false)); argument_count++; return symbols.findSymbol(name); } // FIXME - Simplify JITSymbolInfo* JITVM::add_argument_scalar(string name, Interpreter* m_eval, JITScalar val, bool override) { ArrayReference ptr(m_eval->getContext()->lookupVariable(name)); Class aclass = FM_FUNCPTR_ARRAY; if (!val && !ptr.valid()) return NULL; if (!ptr.valid() || override) { if (isi(val)) aclass = FM_INT32; else if (isf(val)) aclass = FM_FLOAT; else if (isd(val)) aclass = FM_DOUBLE; } else { if (!ptr->isScalar()) throw Exception("Expect " + name + " to be a scalar"); if (ptr->isString() || ptr->isReferenceType()) throw Exception("Cannot JIT strings or reference types:" + name); if (ptr->isComplex()) throw Exception("Cannot JIT complex arrays:" + name); aclass = ptr->dataClass(); } Value *s = new GetElementPtrInst(ptr_inputs,int32_const(3*argument_count),"",func_prolog); s = new LoadInst(s, "", false, func_prolog); const Type* ctype(map_dataclass_type(aclass)); Value *r = cast(s,PointerType::get(ctype),false,func_prolog,name+"_in"); Value *t = new 
AllocaInst(ctype,name,func_prolog); new StoreInst(new LoadInst(r, "", false, func_prolog), t, false, func_prolog); new StoreInst(new LoadInst(t, "", false, func_epilog), r, false, func_epilog); symbols.insertSymbol(name,JITSymbolInfo(true,argument_count,true,true,NULL, NULL,NULL,t,aclass,false)); argument_count++; return symbols.findSymbol(name); } JITScalar JITVM::compile_scalar_function(string symname, Interpreter* m_eval) { // Look up the function in the set of scalars JITScalar *val; val = JITScalars.findSymbol(symname); if (!val) throw Exception("No JIT version of function " + symname); return *val; } JITScalar JITVM::compile_function_call(tree t, Interpreter* m_eval) { // First, make sure it is a function string symname(t.first().text()); FuncPtr funcval; if (!m_eval->lookupFunction(symname,funcval)) throw Exception("Couldn't find function " + symname); if (funcval->type() != FM_BUILT_IN_FUNCTION) throw Exception("Can only JIT built in functions - not " + symname); if (t.numchildren() != 2) return compile_scalar_function(symname,m_eval); // Evaluate the argument tree s(t.second()); if (!s.is(TOK_PARENS)) throw Exception("Expecting function arguments."); if (s.numchildren() > 1) throw Exception("Cannot JIT functions that take more than one argument"); if (s.numchildren() == 0) return compile_scalar_function(symname,m_eval); else { JITScalar arg = compile_expression(s.first(),m_eval); // First look up direct functions - also try double arg functions, as type // promotion means sin(int32) --> sin(double) JITFunction *fun; if (isi(arg)) { fun = JITIntFuncs.findSymbol(symname); if (!fun) fun = JITDoubleFuncs.findSymbol(symname); } else if (isf(arg)) { fun = JITFloatFuncs.findSymbol(symname); if (!fun) fun = JITDoubleFuncs.findSymbol(symname); } else if (isd(arg)) { fun = JITDoubleFuncs.findSymbol(symname); } if (!fun) throw Exception("No JIT version of function " + symname); if (!fun->argType) throw Exception("JIT version of function " + symname + " takes no 
arguments"); //The function exists and is defined - call it return new CallInst(fun->funcAddress,cast(arg,fun->argType,false,ip,""),"",ip); } } JITScalar JITVM::compile_rhs(tree t, Interpreter* m_eval) { string symname(t.first().text()); JITSymbolInfo *v = symbols.findSymbol(symname); if (!v) { if (t.numchildren() == 1) v = add_argument_scalar(symname,m_eval); else v = add_argument_array(symname,m_eval); if (!v) return compile_function_call(t,m_eval); } if (t.numchildren() == 1) { if (!v->is_scalar) throw Exception("non-scalar reference returned in scalar context!"); return new LoadInst(v->data_value, "", false, ip); } if (t.numchildren() > 2) throw Exception("multiple levels of dereference not handled yet..."); if (v->is_scalar) throw Exception("array indexing of scalar values..."); tree s(t.second()); if (!s.is(TOK_PARENS)) throw Exception("non parenthetical dereferences not handled yet..."); if (s.numchildren() == 0) throw Exception("Expecting at least 1 array reference for dereference..."); if (s.numchildren() > 2) throw Exception("Expecting at most 2 array references for dereference..."); if (s.numchildren() == 1) { JITScalar arg1 = compile_expression(s.first(),m_eval); arg1 = cast(arg1,IntegerType::get(32),false,ip); arg1 = BinaryOperator::create(Instruction::Sub,arg1,int32_const(1),"",ip); Value* under_range = new ICmpInst(ICmpInst::ICMP_SLT,arg1,int32_const(0),"",ip); Value* over_range = new ICmpInst(ICmpInst::ICMP_SGE, arg1, new LoadInst(v->num_length,"",false,ip), "",ip); Value* out_range = BinaryOperator::create(Instruction::Or, under_range,over_range,"",ip); BasicBlock *bb1 = new BasicBlock("",func,0); BasicBlock *bb2 = new BasicBlock("",func,0); new BranchInst(bb1, bb2, out_range, ip); new StoreInst(int32_const(1),return_val,false,bb1); new BranchInst(func_epilog,bb1); ip = bb2; Value *g = new LoadInst(v->data_value, "", false, ip); JITScalar address = new GetElementPtrInst(g, arg1, "", ip); return new LoadInst(address, "", false, ip); } else if 
(s.numchildren() == 2) { JITScalar arg1 = compile_expression(s.first(),m_eval); JITScalar arg2 = compile_expression(s.second(),m_eval); arg1 = cast(arg1,IntegerType::get(32),false,ip); arg2 = cast(arg2,IntegerType::get(32),false,ip); arg1 = BinaryOperator::create(Instruction::Sub,arg1,int32_const(1),"",ip); arg2 = BinaryOperator::create(Instruction::Sub,arg2,int32_const(1),"",ip); Value* under_range_1 = new ICmpInst(ICmpInst::ICMP_SLT,arg1,int32_const(0),"",ip); Value* under_range_2 = new ICmpInst(ICmpInst::ICMP_SLT,arg2,int32_const(0),"",ip); Value* over_range_1 = new ICmpInst(ICmpInst::ICMP_SGE, arg1, new LoadInst(v->num_rows,"",false,ip), "",ip); Value* over_range_2 = new ICmpInst(ICmpInst::ICMP_SGE, arg2, new LoadInst(v->num_cols,"",false,ip), "",ip); Value* out_range = BinaryOperator::create(Instruction::Or, BinaryOperator::create(Instruction::Or, under_range_1, under_range_2,"",ip), BinaryOperator::create(Instruction::Or, over_range_1, over_range_2,"",ip), "",ip); BasicBlock *bb1 = new BasicBlock("",func,0); BasicBlock *bb2 = new BasicBlock("",func,0); new BranchInst(bb1, bb2, out_range, ip); new StoreInst(int32_const(1),return_val,false,bb1); new BranchInst(func_epilog,bb1); ip = bb2; JITScalar lin = BinaryOperator::create(Instruction::Mul,arg2, new LoadInst(v->num_rows,"",false,ip), "",ip); lin = BinaryOperator::create(Instruction::Add,lin,arg1,"",ip); Value *g = new LoadInst(v->data_value, "", false, ip); JITScalar address = new GetElementPtrInst(g, lin, "", ip); return new LoadInst(address, "", false, ip); } throw Exception("dereference not handled yet..."); } JITScalar JITVM::compile_expression(tree t, Interpreter* m_eval) { switch(t.token()) { case TOK_VARIABLE: return compile_rhs(t,m_eval); case TOK_INTEGER: return int32_const(ArrayToInt32(t.array())); case TOK_FLOAT: return ConstantFP::get(Type::FloatTy,ArrayToDouble(t.array())); case TOK_DOUBLE: return ConstantFP::get(Type::DoubleTy,ArrayToDouble(t.array())); case TOK_COMPLEX: case TOK_DCOMPLEX: case 
TOK_STRING: case TOK_END: case ':': case TOK_MATDEF: case TOK_CELLDEF: throw Exception("JIT compiler does not support complex, string, END, matrix or cell defs"); case '+': return compile_binary_op(Instruction::Add, compile_expression(t.first(),m_eval), compile_expression(t.second(),m_eval),"add"); case '-': return compile_binary_op(Instruction::Sub, compile_expression(t.first(),m_eval), compile_expression(t.second(),m_eval),"sub"); case '*': case TOK_DOTTIMES: return compile_binary_op(Instruction::Mul, compile_expression(t.first(),m_eval), compile_expression(t.second(),m_eval),"mul"); case '/': case TOK_DOTRDIV: { JITScalar arg1 = compile_expression(t.first(),m_eval); JITScalar arg2 = compile_expression(t.second(),m_eval); if (!(isf(arg1) && isf(arg2))) { arg1 = cast(arg1,Type::DoubleTy,true,ip); arg2 = cast(arg2,Type::DoubleTy,true,ip); } return compile_binary_op(Instruction::FDiv,arg1,arg2,"div"); } case '\\': case TOK_DOTLDIV: { JITScalar arg1 = compile_expression(t.first(),m_eval); JITScalar arg2 = compile_expression(t.second(),m_eval); if (!(isf(arg1) && isf(arg2))) { arg1 = cast(arg1,Type::DoubleTy,true,ip); arg2 = cast(arg2,Type::DoubleTy,true,ip); } return compile_binary_op(Instruction::FDiv,arg2,arg1,"div"); } // FIXME: Are shortcuts handled correctly here? 
case TOK_SOR: case '|': return compile_boolean_op(Instruction::Or, compile_expression(t.first(),m_eval), compile_expression(t.second(),m_eval),"or"); case TOK_SAND: case '&': return compile_boolean_op(Instruction::And, compile_expression(t.first(),m_eval), compile_expression(t.second(),m_eval),"and"); case '<': return compile_comparison_op(t.token(), compile_expression(t.first(),m_eval), compile_expression(t.second(),m_eval),"lt"); case TOK_LE: return compile_comparison_op(t.token(), compile_expression(t.first(),m_eval), compile_expression(t.second(),m_eval),"le"); case '>': return compile_comparison_op(t.token(), compile_expression(t.first(),m_eval), compile_expression(t.second(),m_eval),"gt"); case TOK_GE: return compile_comparison_op(t.token(), compile_expression(t.first(),m_eval), compile_expression(t.second(),m_eval),"ge"); case TOK_EQ: return compile_comparison_op(t.token(), compile_expression(t.first(),m_eval), compile_expression(t.second(),m_eval),"eq"); case TOK_NE: return compile_comparison_op(t.token(), compile_expression(t.first(),m_eval), compile_expression(t.second(),m_eval),"ne"); case TOK_UNARY_MINUS: { JITScalar val(compile_expression(t.first(),m_eval)); return BinaryOperator::create(Instruction::Sub, Constant::getNullValue(val->getType()), val,"",ip); } case TOK_UNARY_PLUS: return compile_expression(t.first(),m_eval); case '~': { JITScalar val(compile_expression(t.first(),m_eval)); val = cast(val,IntegerType::get(1),false,ip); return BinaryOperator::create(Instruction::Xor,val,bool_const(1),"",ip); } case '^': throw Exception("^ is not currently handled by the JIT compiler"); case TOK_DOTPOWER: throw Exception(".^ is not currently handled by the JIT compiler"); case '\'': throw Exception("' is not currently handled by the JIT compiler"); case TOK_DOTTRANSPOSE: throw Exception(".' 
is not currently handled by the JIT compiler"); case '@': throw Exception("@ is not currently handled by the JIT compiler"); default: throw Exception("Unrecognized expression!"); } } void JITVM::compile_statement_type(tree t, Interpreter *m_eval) { switch(t.token()) { case '=': compile_assignment(t,m_eval); break; case TOK_MULTI: throw Exception("multi function calls do not JIT compile"); case TOK_SPECIAL: throw Exception("special function calls do not JIT compile"); case TOK_FOR: compile_for_block(t,m_eval); break; case TOK_WHILE: throw Exception("nested while loops do not JIT compile"); case TOK_IF: compile_if_statement(t,m_eval); break; case TOK_BREAK: throw Exception("break is not currently handled by the JIT compiler"); case TOK_CONTINUE: throw Exception("continue is not currently handled by the JIT compiler"); case TOK_DBSTEP: throw Exception("dbstep is not currently handled by the JIT compiler"); case TOK_DBTRACE: throw Exception("dbtrace is not currently handled by the JIT compiler"); case TOK_RETURN: throw Exception("return is not currently handled by the JIT compiler"); case TOK_SWITCH: throw Exception("switch is not currently handled by the JIT compiler"); case TOK_TRY: throw Exception("try is not currently handled by the JIT compiler"); case TOK_QUIT: throw Exception("quit is not currently handled by the JIT compiler"); case TOK_RETALL: throw Exception("retall is not currently handled by the JIT compiler"); case TOK_KEYBOARD: throw Exception("keyboard is not currently handled by the JIT compiler"); case TOK_GLOBAL: throw Exception("global is not currently handled by the JIT compiler"); case TOK_PERSISTENT: throw Exception("persistent is not currently handled by the JIT compiler"); case TOK_EXPR: compile_expression(t.first(),m_eval); break; case TOK_NEST_FUNC: break; default: throw Exception("Unrecognized statement type"); } } void JITVM::compile_statement(tree t, Interpreter *m_eval) { if (t.is(TOK_STATEMENT) && (t.first().is(TOK_EXPR) || 
t.first().is(TOK_SPECIAL) || t.first().is(TOK_MULTI) || t.first().is('='))) throw Exception("JIT compiler doesn't work with verbose statements"); compile_statement_type(t.first(),m_eval); } void JITVM::compile_block(tree t, Interpreter *m_eval) { const treeVector &statements(t.children()); for (treeVector::const_iterator i=statements.begin(); i!=statements.end();i++) compile_statement(*i,m_eval); } void JITVM::compile_for_block(tree t, Interpreter *m_eval) { if (!(t.first().is('=') && t.first().second().is(':') && t.first().second().first().is(TOK_INTEGER) && t.first().second().second().is(TOK_INTEGER))) throw Exception("For loop cannot be compiled - need integer bounds"); string loop_start(t.first().second().first().text()); string loop_stop(t.first().second().second().text()); string loop_index(t.first().first().text()); // Allocate a slot for the loop index register JITSymbolInfo* v = add_argument_scalar(loop_index,m_eval,ConstantInt::get(APInt(32, loop_start, 10)),true); JITScalar loop_index_address = v->data_value; new StoreInst(ConstantInt::get(APInt(32, loop_start, 10)), loop_index_address, false, ip); BasicBlock *loopbody = new BasicBlock("for_body",func,0); BasicBlock *looptest = new BasicBlock("for_test",func,0); BasicBlock *loopexit = new BasicBlock("for_exit",func,0); new BranchInst(looptest, ip); // Create 3 blocks ip = loopbody; compile_block(t.second(),m_eval); JITScalar loop_index_value = new LoadInst(loop_index_address, "", false, ip); JITScalar next_loop_value = BinaryOperator::create(Instruction::Add,loop_index_value,int32_const(1),"",ip); new StoreInst(next_loop_value, loop_index_address, false, ip); new BranchInst(looptest, ip); loop_index_value = new LoadInst(loop_index_address, "", false, looptest); JITScalar loop_comparison = new ICmpInst(ICmpInst::ICMP_SLE, loop_index_value, ConstantInt::get(APInt(32, loop_stop, 10)), "", looptest); new BranchInst(loopbody, loopexit, loop_comparison, looptest); ip = loopexit; } void JITVM::v_resize(void* 
base, int argnum, int r_new) { JITVM *this_ptr = static_cast(base); if (!this_ptr) throw Exception("vector resize failed"); this_ptr->array_inputs[argnum]->vectorResize(r_new); this_ptr->args[3*argnum] = (void*) this_ptr->array_inputs[argnum]->getReadWriteDataPointer(); *((int*)(this_ptr->args[3*argnum+1])) = this_ptr->array_inputs[argnum]->rows(); *((int*)(this_ptr->args[3*argnum+2])) = this_ptr->array_inputs[argnum]->columns(); } void JITVM::m_resize(void* base, int argnum, int r_new, int c_new) { JITVM *this_ptr = static_cast(base); if (!this_ptr) throw Exception("matrix resize failed"); Dimensions newDim(r_new,c_new); this_ptr->array_inputs[argnum]->resize(newDim); this_ptr->args[3*argnum] = (void*) this_ptr->array_inputs[argnum]->getReadWriteDataPointer(); *((int*)(this_ptr->args[3*argnum+1])) = this_ptr->array_inputs[argnum]->rows(); *((int*)(this_ptr->args[3*argnum+2])) = this_ptr->array_inputs[argnum]->columns(); } void JITVM::compile(tree t, Interpreter *m_eval) { // The signature for the compiled function should be: // int func(void** inputs); M = new Module("test"); initialize_JIT_functions(); // InitializeJITFunctions(M); std::vector DispatchFuncArgs; PointerType* void_pointer = PointerType::get(IntegerType::get(8)); PointerType* void_pointer_pointer = PointerType::get(void_pointer); PointerType* int32_pointer = PointerType::get(IntegerType::get(32)); std::vector vResizeFuncArgs; vResizeFuncArgs.push_back(void_pointer); //this pointer vResizeFuncArgs.push_back(IntegerType::get(32)); //array index vResizeFuncArgs.push_back(IntegerType::get(32)); //new row count vResizeFuncTy = llvm::FunctionType::get(Type::VoidTy, vResizeFuncArgs,false, (ParamAttrsList *) 0); std::vector mResizeFuncArgs; mResizeFuncArgs.push_back(void_pointer); //this pointer mResizeFuncArgs.push_back(IntegerType::get(32)); //array index mResizeFuncArgs.push_back(IntegerType::get(32)); //new row count mResizeFuncArgs.push_back(IntegerType::get(32)); //new col count mResizeFuncTy = 
llvm::FunctionType::get(Type::VoidTy, mResizeFuncArgs,false, (ParamAttrsList *) 0); DispatchFuncArgs.push_back(void_pointer_pointer); //argument array DispatchFuncArgs.push_back(PointerType::get(vResizeFuncTy)); //vector resize func DispatchFuncArgs.push_back(PointerType::get(mResizeFuncTy)); //matrix resize func DispatchFuncArgs.push_back(void_pointer); //this pointer llvm::FunctionType* DispatchFuncType = llvm::FunctionType::get(IntegerType::get(32), DispatchFuncArgs, false, (ParamAttrsList *)0); func = new Function(DispatchFuncType, GlobalValue::ExternalLinkage, "initArray", M); func->setCallingConv(CallingConv::C); Function::arg_iterator args = func->arg_begin(); ptr_inputs = args++; ptr_inputs->setName("inputs"); v_resize_func_ptr = args++; v_resize_func_ptr->setName("v_resize_func"); m_resize_func_ptr = args++; m_resize_func_ptr->setName("m_resize_func"); this_ptr = args++; this_ptr->setName("this_ptr"); ip = 0; argument_count = 0; func_prolog = new BasicBlock("func_prolog",func,0); func_body = new BasicBlock("func_body",func,0); func_epilog = new BasicBlock("func_epilog",func,0); return_val = new AllocaInst(IntegerType::get(32),"return_code",func_prolog); new StoreInst(int32_const(0),return_val,false,func_prolog); ip = func_body; compile_for_block(t,m_eval); new BranchInst(func_body,func_prolog); new BranchInst(func_epilog,ip); new ReturnInst(new LoadInst(return_val, "", false, func_epilog),func_epilog); std::cout << (*M); #if 0 if (0) { PassManager PM; PM.add(new TargetData(M)); PM.add(createVerifierPass()); // Verify that input is correct PM.add((Pass*)createLowerSetJmpPass()); // Lower llvm.setjmp/.longjmp // If the -strip-debug command line option was specified, do it. 
PM.add((Pass*)createRaiseAllocationsPass()); // call %malloc -> malloc inst PM.add((Pass*)createCFGSimplificationPass()); // Clean up disgusting code PM.add((Pass*)createPromoteMemoryToRegisterPass());// Kill useless allocas PM.add((Pass*)createGlobalOptimizerPass()); // Optimize out global vars PM.add((Pass*)createGlobalDCEPass()); // Remove unused fns and globs PM.add((Pass*)createIPConstantPropagationPass());// IP Constant Propagation PM.add((Pass*)createDeadArgEliminationPass()); // Dead argument elimination PM.add((Pass*)createInstructionCombiningPass()); // Clean up after IPCP & DAE PM.add((Pass*)createCFGSimplificationPass()); // Clean up after IPCP & DAE PM.add((Pass*)createPruneEHPass()); // Remove dead EH info PM.add((Pass*)createFunctionInliningPass()); // Inline small functions PM.add((Pass*)createArgumentPromotionPass()); // Scalarize uninlined fn args PM.add((Pass*)createTailDuplicationPass()); // Simplify cfg by copying code PM.add((Pass*)createInstructionCombiningPass()); // Cleanup for scalarrepl. PM.add((Pass*)createCFGSimplificationPass()); // Merge & remove BBs PM.add((Pass*)createScalarReplAggregatesPass()); // Break up aggregate allocas PM.add((Pass*)createInstructionCombiningPass()); // Combine silly seq's PM.add((Pass*)createCondPropagationPass()); // Propagate conditionals PM.add((Pass*)createTailCallEliminationPass()); // Eliminate tail calls PM.add((Pass*)createCFGSimplificationPass()); // Merge & remove BBs PM.add((Pass*)createReassociatePass()); // Reassociate expressions PM.add((Pass*)createLoopRotatePass()); PM.add((Pass*)createLICMPass()); // Hoist loop invariants PM.add((Pass*)createLoopUnswitchPass()); // Unswitch loops. 
PM.add((Pass*)createInstructionCombiningPass()); // Clean up after LICM/reassoc PM.add((Pass*)createIndVarSimplifyPass()); // Canonicalize indvars PM.add((Pass*)createLoopUnrollPass()); // Unroll small loops PM.add((Pass*)createInstructionCombiningPass()); // Clean up after the unroller PM.add((Pass*)createLoadValueNumberingPass()); // GVN for load instructions PM.add((Pass*)createGCSEPass()); // Remove common subexprs PM.add((Pass*)createSCCPPass()); // Constant prop with SCCP // Run instcombine after redundancy elimination to exploit opportunities // opened up by them. PM.add((Pass*)createInstructionCombiningPass()); PM.add((Pass*)createCondPropagationPass()); // Propagate conditionals PM.add((Pass*)createDeadStoreEliminationPass()); // Delete dead stores PM.add((Pass*)createAggressiveDCEPass()); // SSA based 'Aggressive DCE' PM.add((Pass*)createCFGSimplificationPass()); // Merge & remove BBs PM.add((Pass*)createSimplifyLibCallsPass()); // Library Call Optimizations PM.add((Pass*)createDeadTypeEliminationPass()); // Eliminate dead types PM.add((Pass*)createConstantMergePass()); // Merge dup global constants PM.run(*M); } std::cout << *M; #endif } void JITVM::run(Interpreter *m_eval) { // Collect the list of arguments stringVector argumentList(symbols.getCompletions("")); // Allocate the argument array args = (void**) malloc(sizeof(void*)*argumentList.size()*3); // For each argument in the array, retrieve it from the interpreter array_inputs.reserve(argumentList.size()); for (int i=0;igetContext()->lookupVariable(argumentList[i])); if (!ptr.valid()) { if (!v->is_scalar) throw Exception("cannot create array types in the loop"); m_eval->getContext()->insertVariable(argumentList[i], Array(v->inferred_type, Dimensions(1,1), Array::allocateArray(v->inferred_type,1))); ptr = m_eval->getContext()->lookupVariable(argumentList[i]); if (!ptr.valid()) throw Exception("unable to create variable " + argumentList[i]); } args[3*v->argument_index] = (void*) 
ptr->getReadWriteDataPointer(); args[3*v->argument_index+1] = (int*) malloc(sizeof(int)); *((int*)(args[3*v->argument_index+1])) = ptr->rows(); args[3*v->argument_index+2] = (int*) malloc(sizeof(int));; *((int*)(args[3*v->argument_index+2])) = ptr->columns(); array_inputs[v->argument_index] = &(*ptr); } } // std::ofstream p("jit.bc"); // WriteBitcodeToFile(M,p); // p.close(); // return; ExistingModuleProvider* MP = new ExistingModuleProvider(M); ExecutionEngine* EE = ExecutionEngine::create(MP, false); std::vector GVargs; GVargs.push_back(GenericValue(args)); GVargs.push_back(GenericValue((void*) &v_resize)); GVargs.push_back(GenericValue((void*) &m_resize)); GVargs.push_back(GenericValue((void*) this)); GenericValue gv = EE->runFunction(func,GVargs); delete EE; if (gv.IntVal == 1) throw Exception("Index exceeds variable dimensions"); } #endif