// Still need:
//
// Function calls -
// The most general approach would be to bridge every call to the interpreter's
// built-in functions, but that seems suboptimal.
//
// Dynamic array resizing. The current approach is not the best. Since we no longer cache the
// addresses and sizes of the arrays in local variables, the best approach is to have the
// v_resize function use the actual Array class interface to resize the arrays. This avoids
// the problems with memory references and leaks. On the other hand, it does mean that the
// v_resize function will have to be a member of the JITVM class. Or at least, it will have
// to proxy to such a function.
//
// Range checking.
// Type promotion on loop entry (and scalar variables)
//
//
// Just for fun, mind you....
//
// List of reasonable functions to JIT
//
// sec
// csc
// tan
// atan
// cot
// exp
// expm1
// ceil
// floor
// round
// pi, e
// float
// single
// double
// int32
// nan/NaN/inf/Inf
// IsNaN/isnan/IsInf/isinf
// eps/feps
//
// What happens with something like:
// if (t>0)
// pi = 3.5;
// else
// a = pi;
// end
// In this case, we assign to pi in the first branch and
// read from it in the second. This must be avoided because
// it will not compile correctly: an assignment to a variable
// that has the name of a JIT function should disable the JIT
// mechanism.
#include "JITVM.hpp"
#ifdef HAVE_LLVM
#include "llvm/Analysis/Verifier.h"
#include "llvm/Bitcode/ReaderWriter.h"
#if 0
#include "llvm/Pass.h"
#include "llvm/PassManager.h"
#include "llvm/LinkAllPasses.h"
#endif
#include "llvm/Target/TargetData.h"
#include <fstream>
// We want some basic functions to be available to the JIT
// such as sin, cos, abs, x^n, tan
// we also need division to work.
using namespace llvm;
// Global LLVM function-type handles. Judging by the names these are meant for
// double->double, float->float and int->int scalar functions.
// NOTE(review): none of them is referenced in the visible part of this file;
// confirm whether they are used elsewhere before removing.
llvm::FunctionType *ddFuncTy = 0;
llvm::FunctionType *ffFuncTy = 0;
llvm::FunctionType *iiFuncTy = 0;
// Build the LLVM signature `ret (arg)`: a single-parameter, non-variadic
// function type with no parameter attributes.
static llvm::FunctionType* GetScalarFunctionType(const Type* arg,
						 const Type* ret) {
  std::vector<const Type*> params(1, arg);
  return llvm::FunctionType::get(ret, params, false, (ParamAttrsList*) 0);
}
// Declare an externally-linked function named `fun` with signature ret(arg)
// in module `mod`. The argument/return types, the function type and the
// resulting llvm::Function are recorded so calls to it can be emitted later.
JITFunction::JITFunction(string fun, const llvm::Type *ret, const llvm::Type *arg, llvm::Module* mod) {
  funcType = GetScalarFunctionType(arg, ret);
  retType = ret;
  argType = arg;
  funcAddress = new llvm::Function(funcType, GlobalValue::ExternalLinkage, fun, mod);
}
void JITVM::initialize_JIT_functions() {
JITScalars.insertSymbol("pi",ConstantFP::get(Type::DoubleTy,4.0*atan(1.0)));
JITScalars.insertSymbol("e",ConstantFP::get(Type::DoubleTy,exp(1.0)));
JITScalars.insertSymbol("true",int32_const(1));
JITScalars.insertSymbol("false",int32_const(0));
// SIN
JITDoubleFuncs.insertSymbol("sin",JITFunction("sin",Type::getPrimitiveType(Type::DoubleTyID),
Type::getPrimitiveType(Type::DoubleTyID),M));
JITFloatFuncs.insertSymbol("sin",JITFunction("sinf",Type::getPrimitiveType(Type::FloatTyID),
Type::getPrimitiveType(Type::FloatTyID),M));
// COS
JITDoubleFuncs.insertSymbol("cos",JITFunction("cos",Type::getPrimitiveType(Type::DoubleTyID),
Type::getPrimitiveType(Type::DoubleTyID),M));
JITFloatFuncs.insertSymbol("cos",JITFunction("cosf",Type::getPrimitiveType(Type::FloatTyID),
Type::getPrimitiveType(Type::FloatTyID),M));
// ABS
JITDoubleFuncs.insertSymbol("abs",JITFunction("fabs",Type::getPrimitiveType(Type::DoubleTyID),
Type::getPrimitiveType(Type::DoubleTyID),M));
JITFloatFuncs.insertSymbol("abs",JITFunction("fabsf",Type::getPrimitiveType(Type::FloatTyID),
Type::getPrimitiveType(Type::FloatTyID),M));
JITIntFuncs.insertSymbol("abs",JITFunction("abs",IntegerType::get(32),
IntegerType::get(32),M));
// SEC
// JITDoubleFuncs.insertSymbol("sec",JITFunction("sec",Type::getPrimitiveType(Type::DoubleTyID),
// Type::getPrimitiveType(Type::DoubleTyID),M));
// JITFloatFuncs.insertSymbol("sec",JITFunction("secf",Type::getPrimitiveType(Type::FloatTyID),
// Type::getPrimitiveType(Type::FloatTyID),M));
}
// Predicates on the LLVM type of a JIT scalar.
static inline bool isi(JITScalar value) {
  // Any integer width, including the i1 produced by comparisons.
  return value->getType()->isInteger();
}
static inline bool isfd(JITScalar value) {
  // Either floating-point width.
  return value->getType()->isFloatingPoint();
}
static inline bool isf(JITScalar value) {
  return value->getType()->getTypeID() == Type::FloatTyID;
}
static inline bool isd(JITScalar value) {
  return Type::DoubleTyID == value->getType()->getTypeID();
}
// Reduce a JIT scalar to an i1 truth value using "nonzero is true".
// A plain cast to i1 is wrong here: for i32 it emits a truncation (keeping
// only the low bit, so 2 becomes false), and for floating point it emits an
// fp->int conversion (so 0.5 becomes false).
static JITScalar boolean_operand(JITScalar v, BasicBlock* where) {
  const Type* t = v->getType();
  if (t == Type::Int1Ty)
    return v;
  if (t->isFloatingPoint())
    return new FCmpInst(FCmpInst::FCMP_UNE, v, Constant::getNullValue(t), "", where);
  return new ICmpInst(ICmpInst::ICMP_NE, v, Constant::getNullValue(t), "", where);
}
// Emit an element-wise logical operation (`op` is And or Or) on the truth
// values of arg1 and arg2, producing an i1. Both operands are always
// evaluated; no short-circuiting is performed here.
JITScalar JITVM::compile_boolean_op(Instruction::BinaryOps op, JITScalar arg1, JITScalar arg2, string inst) {
  arg1 = boolean_operand(arg1, ip);
  arg2 = boolean_operand(arg2, ip);
  return BinaryOperator::create(op, arg1, arg2, "", ip);
}
// Emit arithmetic `arg1 <opcode> arg2` after promoting both operands to a
// common type:
//   * identical types (other than i1) are used unchanged;
//   * if either operand is floating point and the types differ -> double;
//   * otherwise (i1 mixed with i32, or i1 with i1) -> i32.
// i1 (logical) operands are converted unsigned so that true becomes 1, not -1.
// Previously an i1/i32 mix left the result type uninitialized.
// Constant operands could be folded here; for now an instruction is always emitted.
JITScalar JITVM::compile_binary_op(BinaryOperator::BinaryOps opcode,
				   JITScalar arg1, JITScalar arg2, string inst) {
  const Type* t1 = arg1->getType();
  const Type* t2 = arg2->getType();
  const Type* outType;
  if (t1 == t2 && t1 != Type::Int1Ty)
    outType = t1;
  else if (isfd(arg1) || isfd(arg2))
    outType = Type::DoubleTy;
  else
    outType = Type::Int32Ty;
  arg1 = cast(arg1,outType,t1 != Type::Int1Ty,ip);
  arg2 = cast(arg2,outType,t2 != Type::Int1Ty,ip);
  return BinaryOperator::create(opcode,arg1,arg2,"",ip);
}
// Emit comparison `arg1 <op> arg2` (op is '<', '>', TOK_LE, TOK_GE, TOK_EQ or
// TOK_NE), producing an i1.
// Operand promotion:
//   * two floating-point operands of the same type are compared as-is;
//   * any other mix that involves floating point is compared as double;
//   * integer operands (including i1) are compared as signed i32.
// i1 is zero-extended, so true compares as 1; a signed i1 compare would treat
// true as -1, making true < false. Previously a mix of i1 and i32 left the
// compare type uninitialized.
// Floating-point compares are "ordered": any NaN operand yields false,
// including for TOK_NE.
JITScalar JITVM::compile_comparison_op(byte op, JITScalar arg1, JITScalar arg2, string inst) {
  const Type* t1 = arg1->getType();
  const Type* t2 = arg2->getType();
  const Type* outType;
  if (t1 == t2 && t1->isFloatingPoint())
    outType = t1;
  else if (isfd(arg1) || isfd(arg2))
    outType = Type::DoubleTy;
  else
    outType = Type::Int32Ty;
  arg1 = cast(arg1,outType,t1 != Type::Int1Ty,ip);
  arg2 = cast(arg2,outType,t2 != Type::Int1Ty,ip);
  if (outType->isInteger()) {
    switch (op) {
    default: throw Exception("Unrecognized comparison op");
    case '<': return new ICmpInst(ICmpInst::ICMP_SLT,arg1,arg2,"",ip);
    case TOK_LE: return new ICmpInst(ICmpInst::ICMP_SLE,arg1,arg2,"",ip);
    case TOK_EQ: return new ICmpInst(ICmpInst::ICMP_EQ,arg1,arg2,"",ip);
    case TOK_GE: return new ICmpInst(ICmpInst::ICMP_SGE,arg1,arg2,"",ip);
    case '>': return new ICmpInst(ICmpInst::ICMP_SGT,arg1,arg2,"",ip);
    case TOK_NE: return new ICmpInst(ICmpInst::ICMP_NE,arg1,arg2,"",ip);
    }
  } else {
    switch (op) {
    default: throw Exception("Unrecognized comparison op");
    case '<': return new FCmpInst(FCmpInst::FCMP_OLT,arg1,arg2,"",ip);
    case TOK_LE: return new FCmpInst(FCmpInst::FCMP_OLE,arg1,arg2,"",ip);
    case TOK_EQ: return new FCmpInst(FCmpInst::FCMP_OEQ,arg1,arg2,"",ip);
    case TOK_GE: return new FCmpInst(FCmpInst::FCMP_OGE,arg1,arg2,"",ip);
    case '>': return new FCmpInst(FCmpInst::FCMP_OGT,arg1,arg2,"",ip);
    case TOK_NE: return new FCmpInst(FCmpInst::FCMP_ONE,arg1,arg2,"",ip);
    }
  }
}
// Return the pointee type of pointer type `t`. Throws if `t` is not a pointer.
static const Type* array_dereference(const Type* t) {
  if (const PointerType* ptrType = dynamic_cast<const PointerType*>(t))
    return ptrType->getElementType();
  throw Exception("Expected pointer type in argument to array_dereference");
}
// Compile an assignment statement. Supported left-hand sides:
//   x = rhs        scalar store (types must match exactly);
//   x(i) = rhs     1-based linear index into a 2-D array;
//   x(i,j) = rhs   1-based row/column index into a 2-D array.
// A variable not yet in the JIT symbol table is bound as a new argument on
// first use. For indexed stores:
//   * an index < 1 stores 2 into return_val and jumps to func_epilog;
//   * an index past the current extent grows the array through the
//     v_resize / m_resize callbacks, and the cached data pointer, rows,
//     columns and length are then reloaded from the argument slots.
// A value of a different type stored into a double array is converted to
// double; any other type mismatch throws.
void JITVM::compile_assignment(tree t, Interpreter* m_eval) {
  tree s(t.first());
  string symname(s.first().text());
  JITScalar rhs(compile_expression(t.second(),m_eval));
  JITSymbolInfo *v = symbols.findSymbol(symname);
  if (!v) {
    if (s.numchildren() == 1)
      v = add_argument_scalar(symname,m_eval,rhs,false);
    else
      v = add_argument_array(symname,m_eval);
    if (!v) throw Exception("Undefined variable reference:" + symname);
  }
  if (s.numchildren() == 1) {
    // Check the kind of target first, so an array target gets the accurate
    // message rather than a type-mismatch one.
    if (!v->is_scalar)
      throw Exception("scalar assignment to array variable.");
    if (v->data_value->getType() != PointerType::get(rhs->getType()))
      throw Exception("polymorphic assignment to scalar detected.");
    new StoreInst(rhs, v->data_value, ip);
    return;
  }
  if (s.numchildren() > 2)
    throw Exception("multiple levels of dereference not handled yet...");
  if (v->is_scalar)
    throw Exception("array indexing of scalar values...");
  tree q(s.second());
  if (!q.is(TOK_PARENS))
    throw Exception("non parenthetical dereferences not handled yet...");
  if (q.numchildren() == 0)
    throw Exception("Expecting at least 1 array reference for dereference...");
  if (q.numchildren() > 2)
    throw Exception("Expecting at most 2 array references for dereference...");
  if (v->data_value->getType() != PointerType::get(PointerType::get(rhs->getType()))) {
    // Type promotion: the only supported case (with int, float and double)
    // is a double destination array, into which the RHS is converted.
    const Type* p_p_array_type = v->data_value->getType();
    const Type* array_type = array_dereference(array_dereference(p_p_array_type));
    if (array_type && array_type->getTypeID() == Type::DoubleTyID) {
      // i1 must convert unsigned so that true becomes 1.0, not -1.0.
      rhs = cast(rhs,array_type,rhs->getType() != Type::Int1Ty,ip,"");
    } else
      throw Exception("polymorphic assignment to array detected");
  }
  if (q.numchildren() == 1) {
    JITScalar arg1 = compile_expression(q.first(),m_eval);
    // Signed conversion, so a negative floating-point index becomes a negative
    // integer and is caught by the range check. i1 must zero-extend.
    arg1 = cast(arg1,IntegerType::get(32),arg1->getType() != Type::Int1Ty,ip);
    // Indices below 1 abort with return code 2.
    Value* under_range = new ICmpInst(ICmpInst::ICMP_SLT,arg1,int32_const(1),"",ip);
    BasicBlock *bb1 = new BasicBlock("under_range",func,0);
    BasicBlock *bb2 = new BasicBlock("not_under_range",func,0);
    new BranchInst(bb1,bb2,under_range,ip);
    new StoreInst(int32_const(2),return_val,false,bb1);
    new BranchInst(func_epilog,bb1);
    // Indices past the current length trigger a resize.
    Value* over_range = new ICmpInst(ICmpInst::ICMP_SGT,
				     arg1,new LoadInst(v->num_length,"",false,bb2),"",bb2);
    BasicBlock *bb3 = new BasicBlock("need_resize",func,0);
    BasicBlock *bb4 = new BasicBlock("valid_range",func,0);
    new BranchInst(bb3,bb4,over_range,bb2);
    // Need to resize: grow so the array holds arg1 elements.
    std::vector<Value*> resize_params;
    resize_params.push_back(this_ptr);
    resize_params.push_back(int32_const(v->argument_index));
    resize_params.push_back(arg1);
    new CallInst(v_resize_func_ptr, &resize_params[0], 3, "", bb3);
    new StoreInst(arg1,v->num_length,bb3);
    // Reload the (possibly relocated) data pointer and the new dimensions from
    // the argument slots, as the 2-D path does. Otherwise later (i,j)
    // accesses to this array would use stale row/column counts.
    int argnum = v->argument_index;
    Value* p_arg = cast(get_input_argument(3*argnum,bb3),
			PointerType::get(map_dataclass_type(v->inferred_type)),
			false,bb3,"");
    new StoreInst(p_arg,v->data_value,bb3);
    Value* r_arg = cast(get_input_argument(3*argnum+1,bb3),
			PointerType::get(IntegerType::get(32)),false,bb3);
    Value* c_arg = cast(get_input_argument(3*argnum+2,bb3),
			PointerType::get(IntegerType::get(32)),false,bb3);
    copy_value(r_arg,v->num_rows,bb3);
    copy_value(c_arg,v->num_cols,bb3);
    new BranchInst(bb4,bb3);
    // Store the value at the 0-based offset arg1-1.
    ip = bb4;
    arg1 = BinaryOperator::create(Instruction::Sub,arg1,int32_const(1),"",ip);
    Value *g = new LoadInst(v->data_value,"",false,ip);
    JITScalar address = new GetElementPtrInst(g, arg1, "", ip);
    new StoreInst(rhs, address, false, ip);
  } else if (q.numchildren() == 2) {
    JITScalar arg1 = compile_expression(q.first(),m_eval);
    JITScalar arg2 = compile_expression(q.second(),m_eval);
    arg1 = cast(arg1,IntegerType::get(32),arg1->getType() != Type::Int1Ty,ip);
    arg2 = cast(arg2,IntegerType::get(32),arg2->getType() != Type::Int1Ty,ip);
    // Either index below 1 aborts with return code 2.
    Value* under_range1 = new ICmpInst(ICmpInst::ICMP_SLT,arg1,int32_const(1),"",ip);
    Value* under_range2 = new ICmpInst(ICmpInst::ICMP_SLT,arg2,int32_const(1),"",ip);
    Value* under_range = BinaryOperator::create(Instruction::Or,
						under_range1,under_range2,"",ip);
    BasicBlock *bb1 = new BasicBlock("under_range",func,0);
    BasicBlock *bb2 = new BasicBlock("not_under_range",func,0);
    new BranchInst(bb1,bb2,under_range,ip);
    new StoreInst(int32_const(2),return_val,false,bb1);
    new BranchInst(func_epilog,bb1);
    Value* cur_rows = new LoadInst(v->num_rows,"",false,bb2);
    Value* cur_cols = new LoadInst(v->num_cols,"",false,bb2);
    Value* over_range1 = new ICmpInst(ICmpInst::ICMP_SGT,arg1,cur_rows,"",bb2);
    Value* over_range2 = new ICmpInst(ICmpInst::ICMP_SGT,arg2,cur_cols,"",bb2);
    Value* over_range = BinaryOperator::create(Instruction::Or,
					       over_range1,over_range2,"",bb2);
    BasicBlock *bb3 = new BasicBlock("need_resize",func,0);
    BasicBlock *bb4 = new BasicBlock("valid_range",func,0);
    new BranchInst(bb3,bb4,over_range,bb2);
    // Need to resize. Grow each dimension to max(index, current extent).
    // Resizing to (arg1, arg2) outright would shrink a dimension that is
    // already large enough (e.g. for a 5x5 A, A(2,7) = x would give 2x7).
    Value* new_rows = new SelectInst(over_range1, arg1, cur_rows, "", bb3);
    Value* new_cols = new SelectInst(over_range2, arg2, cur_cols, "", bb3);
    std::vector<Value*> resize_params;
    resize_params.push_back(this_ptr);
    resize_params.push_back(int32_const(v->argument_index));
    resize_params.push_back(new_rows);
    resize_params.push_back(new_cols);
    new CallInst(m_resize_func_ptr, &resize_params[0], 4, "", bb3);
    // Reload the data pointer, dimensions and length from the argument slots.
    int argnum = v->argument_index;
    Value* p_arg = cast(get_input_argument(3*argnum,bb3),
			PointerType::get(map_dataclass_type(v->inferred_type)),
			false,bb3,"");
    new StoreInst(p_arg,v->data_value,bb3);
    Value* r_arg = cast(get_input_argument(3*argnum+1,bb3),
			PointerType::get(IntegerType::get(32)),false,bb3);
    Value* c_arg = cast(get_input_argument(3*argnum+2,bb3),
			PointerType::get(IntegerType::get(32)),false,bb3);
    copy_value(r_arg,v->num_rows,bb3);
    copy_value(c_arg,v->num_cols,bb3);
    new StoreInst(BinaryOperator::create(Instruction::Mul,
					 new LoadInst(c_arg, "", false, bb3),
					 new LoadInst(r_arg, "", false, bb3),
					 "",bb3),v->num_length,false,bb3);
    new BranchInst(bb4,bb3);
    // Store at the 0-based offset (j-1)*rows + (i-1).
    ip = bb4;
    arg1 = BinaryOperator::create(Instruction::Sub,arg1,int32_const(1),"",ip);
    arg2 = BinaryOperator::create(Instruction::Sub,arg2,int32_const(1),"",ip);
    JITScalar lin = BinaryOperator::create(Instruction::Mul,arg2,
					   new LoadInst(v->num_rows,"",false,ip),
					   "",ip);
    lin = BinaryOperator::create(Instruction::Add,lin,arg1,"",ip);
    Value *g = new LoadInst(v->data_value, "", false, ip);
    JITScalar address = new GetElementPtrInst(g, lin, "", ip);
    new StoreInst(rhs, address, false, ip);
  }
}
void JITVM::compile_if_statement(tree t, Interpreter* m_eval) {
JITScalar main_cond(cast(compile_expression(t.first(),m_eval),
IntegerType::get(1),false,ip,""));
BasicBlock *if_true = new BasicBlock("if_true",func,0);
BasicBlock *if_continue = new BasicBlock("if_continue",func,0);
BasicBlock *if_exit = new BasicBlock("if_exit",func,0);
new BranchInst(if_true,if_continue,main_cond,ip);
ip = if_true;
compile_block(t.second(),m_eval);
new BranchInst(if_exit,ip);
unsigned n=2;
while (n < t.numchildren() && t.child(n).is(TOK_ELSEIF)) {
ip = if_continue;
JITScalar ttest(cast(compile_expression(t.child(n).first(),m_eval),
IntegerType::get(1),false,ip,""));
if_true = new BasicBlock("elseif_true",func,0);
if_continue = new BasicBlock("elseif_continue",func,0);
new BranchInst(if_true,if_continue,ttest,ip);
ip = if_true;
compile_block(t.child(n).second(),m_eval);
new BranchInst(if_exit,ip);
n++;
}
if (t.last().is(TOK_ELSE)) {
ip = if_continue;
compile_block(t.last().first(),m_eval);
new BranchInst(if_exit,ip);
} else {
new BranchInst(if_exit,if_continue);
}
ip = if_exit;
}
// Return an i1 (boolean) LLVM constant built from x.
JITScalar JITVM::bool_const(int32 x) {
  const IntegerType* i1 = Type::Int1Ty;
  JITScalar constant = ConstantInt::get(i1, x);
  return constant;
}
// Return an i32 LLVM constant with value x.
JITScalar JITVM::int32_const(int32 x) {
  const IntegerType* i32 = Type::Int32Ty;
  JITScalar constant = ConstantInt::get(i32, x);
  return constant;
}
// Emit a conversion of `value` to `type` at the end of block `where`, named
// `name`. LLVM chooses the cast opcode; `sgnd` is passed as the signedness of
// both the source and the destination.
JITScalar JITVM::cast(JITScalar value, const Type *type, bool sgnd, BasicBlock *where, string name) {
  Instruction::CastOps opcode = CastInst::getCastOpcode(value, sgnd, type, sgnd);
  return CastInst::create(opcode, value, type, name, where);
}
// Load entry `arg` of the compiled function's `inputs` pointer array,
// emitting the address computation and the load at the end of `where`.
JITScalar JITVM::get_input_argument(int arg, BasicBlock* where) {
  Value* slot = new GetElementPtrInst(ptr_inputs, int32_const(arg), "", where);
  return new LoadInst(slot, "", false, where);
}
// Emit `*dest = *source` at the end of block `where`.
void JITVM::copy_value(JITScalar source, JITScalar dest, BasicBlock* where) {
  Value* loaded = new LoadInst(source, "", false, where);
  new StoreInst(loaded, dest, false, where);
}
// Map an interpreter data class to the LLVM element type used by the JIT:
// FM_INT32 -> i32, FM_FLOAT -> float, FM_DOUBLE -> double. Throws for any
// other class.
const Type* JITVM::map_dataclass_type(Class aclass) {
  if (aclass == FM_INT32)
    return IntegerType::get(32);
  if (aclass == FM_FLOAT)
    return Type::FloatTy;
  if (aclass == FM_DOUBLE)
    return Type::DoubleTy;
  throw Exception("JIT does not support");
}
// Bind the interpreter array `name` to argument slot k = argument_count.
// The inputs array holds three entries per bound variable: [3k] the data
// pointer, [3k+1] a pointer to the row count, [3k+2] a pointer to the column
// count. These are copied into allocas in func_prolog, together with a length
// computed as rows*cols. Returns NULL if the interpreter has no such variable.
// Throws for N-D, string, reference-type or complex arrays.
JITSymbolInfo* JITVM::add_argument_array(string name, Interpreter* m_eval) {
  ArrayReference var(m_eval->getContext()->lookupVariable(name));
  if (!var.valid())
    return NULL;
  if (!var->is2D())
    throw Exception("Cannot JIT multi-dimensional array:" + name);
  if (var->isString() || var->isReferenceType())
    throw Exception("Cannot JIT strings or reference types:" + name);
  if (var->isComplex())
    throw Exception("Cannot JIT complex arrays:" + name);
  Class aclass = var->dataClass();
  const Type* elemType = map_dataclass_type(aclass);
  const Type* i32 = IntegerType::get(32);
  const int slot = 3*argument_count;
  // Local storage for rows, columns, data pointer and length.
  Value* rowsVar = new AllocaInst(i32, name+"_rows", func_prolog);
  Value* colsVar = new AllocaInst(i32, name+"_cols", func_prolog);
  Value* dataVar = new AllocaInst(PointerType::get(elemType), name+"_data", func_prolog);
  Value* lenVar  = new AllocaInst(i32, name+"_len", func_prolog);
  // Typed views of this variable's three argument slots.
  Value* rowsIn = cast(get_input_argument(slot+1,func_prolog),
		       PointerType::get(i32),false,func_prolog,name+"_rows_in");
  Value* colsIn = cast(get_input_argument(slot+2,func_prolog),
		       PointerType::get(i32),false,func_prolog,name+"_cols_in");
  Value* dataIn = cast(get_input_argument(slot,func_prolog),
		       PointerType::get(elemType),false,func_prolog,name+"_data_in");
  // Seed the locals from the arguments.
  copy_value(rowsIn, rowsVar, func_prolog);
  copy_value(colsIn, colsVar, func_prolog);
  new StoreInst(dataIn, dataVar, func_prolog);
  Value* cols = new LoadInst(colsVar, "", false, func_prolog);
  Value* rows = new LoadInst(rowsVar, "", false, func_prolog);
  Value* len = BinaryOperator::create(Instruction::Mul, cols, rows, "", func_prolog);
  new StoreInst(len, lenVar, false, func_prolog);
  symbols.insertSymbol(name,JITSymbolInfo(true,argument_count,false,true,
					  rowsVar,colsVar,lenVar,dataVar,aclass,false));
  argument_count++;
  return symbols.findSymbol(name);
}
// Bind scalar variable `name` to argument slot k = argument_count.
// The element type is taken from `val` when the interpreter has no such
// variable, or when `override` is set. Otherwise it comes from the
// interpreter's variable, which must be a real, non-string, non-reference
// scalar. The value is loaded from inputs[3k] into a local alloca in
// func_prolog and written back to the same slot in func_epilog.
// Returns NULL when neither `val` nor an interpreter variable is available.
JITSymbolInfo* JITVM::add_argument_scalar(string name, Interpreter* m_eval, JITScalar val, bool override) {
  ArrayReference var(m_eval->getContext()->lookupVariable(name));
  if (!val && !var.valid()) return NULL;
  Class aclass = FM_FUNCPTR_ARRAY;
  if (var.valid() && !override) {
    if (!var->isScalar())
      throw Exception("Expect " + name + " to be a scalar");
    if (var->isString() || var->isReferenceType())
      throw Exception("Cannot JIT strings or reference types:" + name);
    if (var->isComplex())
      throw Exception("Cannot JIT complex arrays:" + name);
    aclass = var->dataClass();
  } else if (isi(val)) {
    aclass = FM_INT32;
  } else if (isf(val)) {
    aclass = FM_FLOAT;
  } else if (isd(val)) {
    aclass = FM_DOUBLE;
  }
  Value* raw = get_input_argument(3*argument_count, func_prolog);
  const Type* ctype = map_dataclass_type(aclass);
  Value* inPtr = cast(raw, PointerType::get(ctype), false, func_prolog, name+"_in");
  Value* local = new AllocaInst(ctype, name, func_prolog);
  // Copy in on entry, copy back out on exit.
  new StoreInst(new LoadInst(inPtr, "", false, func_prolog), local, false, func_prolog);
  new StoreInst(new LoadInst(local, "", false, func_epilog), inPtr, false, func_epilog);
  symbols.insertSymbol(name,JITSymbolInfo(true,argument_count,true,true,NULL,
					  NULL,NULL,local,aclass,false));
  argument_count++;
  return symbols.findSymbol(name);
}
// Resolve an argument-less "function" (e.g. pi, e, true, false) to its
// constant from the JITScalars table. Throws if none is registered.
JITScalar JITVM::compile_scalar_function(string symname, Interpreter* m_eval) {
  JITScalar *constant = JITScalars.findSymbol(symname);
  if (constant)
    return *constant;
  throw Exception("No JIT version of function " + symname);
}
// Compile a call to a built-in function with at most one argument.
// Zero-argument uses resolve to constants via compile_scalar_function. A
// one-argument call is bound to the best JIT table entry for the argument's
// type, falling back to the double version for int and float arguments. The
// argument is then converted to the function's parameter type.
JITScalar JITVM::compile_function_call(tree t, Interpreter* m_eval) {
  // First, make sure it is a built-in function
  string symname(t.first().text());
  FuncPtr funcval;
  if (!m_eval->lookupFunction(symname,funcval))
    throw Exception("Couldn't find function " + symname);
  if (funcval->type() != FM_BUILT_IN_FUNCTION)
    throw Exception("Can only JIT built in functions - not " + symname);
  if (t.numchildren() != 2)
    return compile_scalar_function(symname,m_eval);
  // Evaluate the argument
  tree s(t.second());
  if (!s.is(TOK_PARENS))
    throw Exception("Expecting function arguments.");
  if (s.numchildren() > 1)
    throw Exception("Cannot JIT functions that take more than one argument");
  if (s.numchildren() == 0)
    return compile_scalar_function(symname,m_eval);
  JITScalar arg = compile_expression(s.first(),m_eval);
  // Prefer an exact-type implementation; fall back to the double version,
  // since type promotion means sin(int32) --> sin(double).
  JITFunction *fun = NULL;
  if (isi(arg)) {
    fun = JITIntFuncs.findSymbol(symname);
    if (!fun) fun = JITDoubleFuncs.findSymbol(symname);
  } else if (isf(arg)) {
    fun = JITFloatFuncs.findSymbol(symname);
    if (!fun) fun = JITDoubleFuncs.findSymbol(symname);
  } else if (isd(arg)) {
    fun = JITDoubleFuncs.findSymbol(symname);
  }
  if (!fun) throw Exception("No JIT version of function " + symname);
  if (!fun->argType) throw Exception("JIT version of function " + symname + " takes no arguments");
  // i32 arguments must convert signed (sin(-1) must see -1.0, not 4294967295.0).
  // i1 (logical) arguments convert unsigned so that true becomes 1.
  bool sgnd = (arg->getType() != Type::Int1Ty);
  return new CallInst(fun->funcAddress,cast(arg,fun->argType,sgnd,ip,""),"",ip);
}
// Compile a variable reference used as a value: `x`, `x(i)` or `x(i,j)`,
// with 1-based indices. Unknown names are bound as arguments; if no such
// variable exists, the name is compiled as a function call instead.
// An out-of-range index stores 1 into return_val and jumps to func_epilog.
JITScalar JITVM::compile_rhs(tree t, Interpreter* m_eval) {
  string symname(t.first().text());
  JITSymbolInfo *v = symbols.findSymbol(symname);
  if (!v) {
    if (t.numchildren() == 1)
      v = add_argument_scalar(symname,m_eval);
    else
      v = add_argument_array(symname,m_eval);
    if (!v)
      return compile_function_call(t,m_eval);
  }
  if (t.numchildren() == 1) {
    if (!v->is_scalar)
      throw Exception("non-scalar reference returned in scalar context!");
    return new LoadInst(v->data_value, "", false, ip);
  }
  if (t.numchildren() > 2)
    throw Exception("multiple levels of dereference not handled yet...");
  if (v->is_scalar)
    throw Exception("array indexing of scalar values...");
  tree s(t.second());
  if (!s.is(TOK_PARENS))
    throw Exception("non parenthetical dereferences not handled yet...");
  if (s.numchildren() == 0)
    throw Exception("Expecting at least 1 array reference for dereference...");
  if (s.numchildren() > 2)
    throw Exception("Expecting at most 2 array references for dereference...");
  if (s.numchildren() == 1) {
    JITScalar arg1 = compile_expression(s.first(),m_eval);
    // Signed conversion, so a negative floating-point index stays negative
    // and is rejected below. i1 must zero-extend.
    arg1 = cast(arg1,IntegerType::get(32),arg1->getType() != Type::Int1Ty,ip);
    arg1 = BinaryOperator::create(Instruction::Sub,arg1,int32_const(1),"",ip);
    Value* under_range = new ICmpInst(ICmpInst::ICMP_SLT,arg1,int32_const(0),"",ip);
    Value* over_range = new ICmpInst(ICmpInst::ICMP_SGE,
				     arg1,
				     new LoadInst(v->num_length,"",false,ip),
				     "",ip);
    Value* out_range = BinaryOperator::create(Instruction::Or,
					      under_range,over_range,"",ip);
    BasicBlock *bb1 = new BasicBlock("",func,0);
    BasicBlock *bb2 = new BasicBlock("",func,0);
    new BranchInst(bb1, bb2, out_range, ip);
    new StoreInst(int32_const(1),return_val,false,bb1);
    new BranchInst(func_epilog,bb1);
    ip = bb2;
    Value *g = new LoadInst(v->data_value, "", false, ip);
    JITScalar address = new GetElementPtrInst(g, arg1, "", ip);
    return new LoadInst(address, "", false, ip);
  } else if (s.numchildren() == 2) {
    JITScalar arg1 = compile_expression(s.first(),m_eval);
    JITScalar arg2 = compile_expression(s.second(),m_eval);
    arg1 = cast(arg1,IntegerType::get(32),arg1->getType() != Type::Int1Ty,ip);
    arg2 = cast(arg2,IntegerType::get(32),arg2->getType() != Type::Int1Ty,ip);
    arg1 = BinaryOperator::create(Instruction::Sub,arg1,int32_const(1),"",ip);
    arg2 = BinaryOperator::create(Instruction::Sub,arg2,int32_const(1),"",ip);
    Value* under_range_1 = new ICmpInst(ICmpInst::ICMP_SLT,arg1,int32_const(0),"",ip);
    Value* under_range_2 = new ICmpInst(ICmpInst::ICMP_SLT,arg2,int32_const(0),"",ip);
    Value* over_range_1 = new ICmpInst(ICmpInst::ICMP_SGE,
				       arg1,
				       new LoadInst(v->num_rows,"",false,ip),
				       "",ip);
    Value* over_range_2 = new ICmpInst(ICmpInst::ICMP_SGE,
				       arg2,
				       new LoadInst(v->num_cols,"",false,ip),
				       "",ip);
    Value* out_range = BinaryOperator::create(Instruction::Or,
					      BinaryOperator::create(Instruction::Or,
								     under_range_1,
								     under_range_2,"",ip),
					      BinaryOperator::create(Instruction::Or,
								     over_range_1,
								     over_range_2,"",ip),
					      "",ip);
    BasicBlock *bb1 = new BasicBlock("",func,0);
    BasicBlock *bb2 = new BasicBlock("",func,0);
    new BranchInst(bb1, bb2, out_range, ip);
    new StoreInst(int32_const(1),return_val,false,bb1);
    new BranchInst(func_epilog,bb1);
    ip = bb2;
    // Column-major linear offset: (j-1)*rows + (i-1).
    JITScalar lin = BinaryOperator::create(Instruction::Mul,arg2,
					   new LoadInst(v->num_rows,"",false,ip),
					   "",ip);
    lin = BinaryOperator::create(Instruction::Add,lin,arg1,"",ip);
    Value *g = new LoadInst(v->data_value, "", false, ip);
    JITScalar address = new GetElementPtrInst(g, lin, "", ip);
    return new LoadInst(address, "", false, ip);
  }
  throw Exception("dereference not handled yet...");
}
// Compile expression tree `t` to an LLVM value, emitting code at `ip`.
// Literals become constants, and variable references go through
// compile_rhs. Arithmetic, logical and comparison operators are delegated to
// the compile_*_op helpers. Unsupported constructs throw.
JITScalar JITVM::compile_expression(tree t, Interpreter* m_eval) {
  switch(t.token()) {
  case TOK_VARIABLE:     return compile_rhs(t,m_eval);
  case TOK_INTEGER:      return int32_const(ArrayToInt32(t.array()));
  case TOK_FLOAT:        return ConstantFP::get(Type::FloatTy,ArrayToDouble(t.array()));
  case TOK_DOUBLE:       return ConstantFP::get(Type::DoubleTy,ArrayToDouble(t.array()));
  case TOK_COMPLEX:
  case TOK_DCOMPLEX:
  case TOK_STRING:
  case TOK_END:
  case ':':
  case TOK_MATDEF:
  case TOK_CELLDEF:      throw Exception("JIT compiler does not support complex, string, END, matrix or cell defs");
  case '+':
    return compile_binary_op(Instruction::Add,
			     compile_expression(t.first(),m_eval),
			     compile_expression(t.second(),m_eval),"add");
  case '-':
    return compile_binary_op(Instruction::Sub,
			     compile_expression(t.first(),m_eval),
			     compile_expression(t.second(),m_eval),"sub");
  case '*':
  case TOK_DOTTIMES:
    return compile_binary_op(Instruction::Mul,
			     compile_expression(t.first(),m_eval),
			     compile_expression(t.second(),m_eval),"mul");
  case '/':
  case TOK_DOTRDIV:
    {
      // Division is floating point: float/float stays float, and anything
      // else is done in double. i1 converts unsigned so true -> 1.0.
      JITScalar arg1 = compile_expression(t.first(),m_eval);
      JITScalar arg2 = compile_expression(t.second(),m_eval);
      if (!(isf(arg1) && isf(arg2))) {
	arg1 = cast(arg1,Type::DoubleTy,arg1->getType() != Type::Int1Ty,ip);
	arg2 = cast(arg2,Type::DoubleTy,arg2->getType() != Type::Int1Ty,ip);
      }
      return compile_binary_op(Instruction::FDiv,arg1,arg2,"div");
    }
  case '\\':
  case TOK_DOTLDIV:
    {
      // a \ b == b / a, with the same promotion rules as '/'.
      JITScalar arg1 = compile_expression(t.first(),m_eval);
      JITScalar arg2 = compile_expression(t.second(),m_eval);
      if (!(isf(arg1) && isf(arg2))) {
	arg1 = cast(arg1,Type::DoubleTy,arg1->getType() != Type::Int1Ty,ip);
	arg2 = cast(arg2,Type::DoubleTy,arg2->getType() != Type::Int1Ty,ip);
      }
      return compile_binary_op(Instruction::FDiv,arg2,arg1,"div");
    }
    // FIXME: || and && are compiled without short-circuit evaluation.
  case TOK_SOR:
  case '|':
    return compile_boolean_op(Instruction::Or,
			      compile_expression(t.first(),m_eval),
			      compile_expression(t.second(),m_eval),"or");
  case TOK_SAND:
  case '&':
    return compile_boolean_op(Instruction::And,
			      compile_expression(t.first(),m_eval),
			      compile_expression(t.second(),m_eval),"and");
  case '<':
    return compile_comparison_op(t.token(),
				 compile_expression(t.first(),m_eval),
				 compile_expression(t.second(),m_eval),"lt");
  case TOK_LE:
    return compile_comparison_op(t.token(),
				 compile_expression(t.first(),m_eval),
				 compile_expression(t.second(),m_eval),"le");
  case '>':
    return compile_comparison_op(t.token(),
				 compile_expression(t.first(),m_eval),
				 compile_expression(t.second(),m_eval),"gt");
  case TOK_GE:
    return compile_comparison_op(t.token(),
				 compile_expression(t.first(),m_eval),
				 compile_expression(t.second(),m_eval),"ge");
  case TOK_EQ:
    return compile_comparison_op(t.token(),
				 compile_expression(t.first(),m_eval),
				 compile_expression(t.second(),m_eval),"eq");
  case TOK_NE:
    return compile_comparison_op(t.token(),
				 compile_expression(t.first(),m_eval),
				 compile_expression(t.second(),m_eval),"ne");
  case TOK_UNARY_MINUS:
    {
      JITScalar val(compile_expression(t.first(),m_eval));
      // Widen logical values first: in 1-bit arithmetic 0 - 1 wraps back to 1.
      if (val->getType() == Type::Int1Ty)
	val = cast(val,Type::Int32Ty,false,ip);
      return BinaryOperator::create(Instruction::Sub,
				    Constant::getNullValue(val->getType()),
				    val,"",ip);
    }
  case TOK_UNARY_PLUS:
    return compile_expression(t.first(),m_eval);
  case '~':
    {
      // Logical NOT of "nonzero is true". A plain cast to i1 would truncate
      // (so ~2 would be true) or use fp->int conversion (so ~0.5 would be true).
      JITScalar val(compile_expression(t.first(),m_eval));
      const Type* vt = val->getType();
      if (vt->isFloatingPoint())
	val = new FCmpInst(FCmpInst::FCMP_UNE,val,Constant::getNullValue(vt),"",ip);
      else if (vt != Type::Int1Ty)
	val = new ICmpInst(ICmpInst::ICMP_NE,val,Constant::getNullValue(vt),"",ip);
      return BinaryOperator::create(Instruction::Xor,val,bool_const(1),"",ip);
    }
  case '^':              throw Exception("^ is not currently handled by the JIT compiler");
  case TOK_DOTPOWER:     throw Exception(".^ is not currently handled by the JIT compiler");
  case '\'':             throw Exception("' is not currently handled by the JIT compiler");
  case TOK_DOTTRANSPOSE: throw Exception(".' is not currently handled by the JIT compiler");
  case '@':              throw Exception("@ is not currently handled by the JIT compiler");
  default:               throw Exception("Unrecognized expression!");
  }
}
// Dispatch one statement node to its compiler. Assignments, for loops, if
// statements and expression statements are compiled, and nested function
// definitions are ignored. Every other statement kind throws.
void JITVM::compile_statement_type(tree t, Interpreter *m_eval) {
  // Keyword of a statement rejected with the generic "not currently handled"
  // message; set by the switch below.
  const char* unhandled = NULL;
  switch(t.token()) {
  case '=':           compile_assignment(t,m_eval);         return;
  case TOK_FOR:       compile_for_block(t,m_eval);          return;
  case TOK_IF:        compile_if_statement(t,m_eval);       return;
  case TOK_EXPR:      compile_expression(t.first(),m_eval); return;
  case TOK_NEST_FUNC:                                       return;
  case TOK_MULTI:     throw Exception("multi function calls do not JIT compile");
  case TOK_SPECIAL:   throw Exception("special function calls do not JIT compile");
  case TOK_WHILE:     throw Exception("nested while loops do not JIT compile");
  case TOK_BREAK:      unhandled = "break";      break;
  case TOK_CONTINUE:   unhandled = "continue";   break;
  case TOK_DBSTEP:     unhandled = "dbstep";     break;
  case TOK_DBTRACE:    unhandled = "dbtrace";    break;
  case TOK_RETURN:     unhandled = "return";     break;
  case TOK_SWITCH:     unhandled = "switch";     break;
  case TOK_TRY:        unhandled = "try";        break;
  case TOK_QUIT:       unhandled = "quit";       break;
  case TOK_RETALL:     unhandled = "retall";     break;
  case TOK_KEYBOARD:   unhandled = "keyboard";   break;
  case TOK_GLOBAL:     unhandled = "global";     break;
  case TOK_PERSISTENT: unhandled = "persistent"; break;
  default:
    throw Exception("Unrecognized statement type");
  }
  throw Exception(string(unhandled) + " is not currently handled by the JIT compiler");
}
// Compile one statement wrapper node. A TOK_STATEMENT wrapper ("verbose"
// statement; presumably one whose result would be echoed) around an
// expression, special call, multi-call or assignment is rejected.
void JITVM::compile_statement(tree t, Interpreter *m_eval) {
  tree stmt(t.first());
  if (t.is(TOK_STATEMENT)) {
    if (stmt.is(TOK_EXPR) || stmt.is(TOK_SPECIAL) ||
	stmt.is(TOK_MULTI) || stmt.is('='))
      throw Exception("JIT compiler doesn't work with verbose statements");
  }
  compile_statement_type(stmt,m_eval);
}
// Compile each child statement of block `t`, in order.
void JITVM::compile_block(tree t, Interpreter *m_eval) {
  const treeVector &stmts(t.children());
  treeVector::const_iterator it = stmts.begin();
  while (it != stmts.end()) {
    compile_statement(*it,m_eval);
    ++it;
  }
}
// Compile `for i = a:b, body, end`, where a and b are integer literals.
//
// Iteration is driven by a hidden i32 counter, allocated in func_prolog,
// rather than by the loop variable itself. At the top of each iteration the
// counter is copied into the loop variable. As in MATLAB `for` semantics:
//   * assigning to the loop variable inside the body does not change how
//     many iterations run;
//   * after a loop that executed, the variable holds the last value b.
// Previously the variable itself was incremented, leaving it at b+1.
// If the body never runs, the loop variable is left equal to a.
void JITVM::compile_for_block(tree t, Interpreter *m_eval) {
  if (!(t.first().is('=') && t.first().second().is(':') &&
	t.first().second().first().is(TOK_INTEGER) &&
	t.first().second().second().is(TOK_INTEGER)))
    throw Exception("For loop cannot be compiled - need integer bounds");
  string loop_start(t.first().second().first().text());
  string loop_stop(t.first().second().second().text());
  string loop_index(t.first().first().text());
  // Allocate a slot for the loop index register
  JITSymbolInfo* v = add_argument_scalar(loop_index,m_eval,ConstantInt::get(APInt(32, loop_start, 10)),true);
  JITScalar loop_index_address = v->data_value;
  // Hidden counter that actually controls the iteration.
  JITScalar counter_address = new AllocaInst(IntegerType::get(32),
					     loop_index+"_counter",func_prolog);
  JITScalar start_value = ConstantInt::get(APInt(32, loop_start, 10));
  new StoreInst(start_value, loop_index_address, false, ip);
  new StoreInst(start_value, counter_address, false, ip);
  BasicBlock *loopbody = new BasicBlock("for_body",func,0);
  BasicBlock *looptest = new BasicBlock("for_test",func,0);
  BasicBlock *loopexit = new BasicBlock("for_exit",func,0);
  new BranchInst(looptest, ip);
  // Body: publish the counter into the loop variable, then run the statements.
  ip = loopbody;
  new StoreInst(new LoadInst(counter_address, "", false, ip),
		loop_index_address, false, ip);
  compile_block(t.second(),m_eval);
  // Advance the counter. ip may now be a later block created by the body.
  JITScalar counter_value = new LoadInst(counter_address, "", false, ip);
  JITScalar next_value = BinaryOperator::create(Instruction::Add,counter_value,int32_const(1),"",ip);
  new StoreInst(next_value, counter_address, false, ip);
  new BranchInst(looptest, ip);
  // Test: continue while counter <= stop.
  counter_value = new LoadInst(counter_address, "", false, looptest);
  JITScalar loop_comparison = new ICmpInst(ICmpInst::ICMP_SLE, counter_value,
					   ConstantInt::get(APInt(32, loop_stop, 10)),
					   "", looptest);
  new BranchInst(loopbody, loopexit, loop_comparison, looptest);
  ip = loopexit;
}
void JITVM::v_resize(void* base, int argnum, int r_new) {
JITVM *this_ptr = static_cast<JITVM*>(base);
if (!this_ptr) throw Exception("vector resize failed");
this_ptr->array_inputs[argnum]->vectorResize(r_new);
this_ptr->args[3*argnum] = (void*) this_ptr->array_inputs[argnum]->getReadWriteDataPointer();
*((int*)(this_ptr->args[3*argnum+1])) = this_ptr->array_inputs[argnum]->rows();
*((int*)(this_ptr->args[3*argnum+2])) = this_ptr->array_inputs[argnum]->columns();
}
void JITVM::m_resize(void* base, int argnum, int r_new, int c_new) {
JITVM *this_ptr = static_cast<JITVM*>(base);
if (!this_ptr) throw Exception("matrix resize failed");
Dimensions newDim(r_new,c_new);
this_ptr->array_inputs[argnum]->resize(newDim);
this_ptr->args[3*argnum] = (void*) this_ptr->array_inputs[argnum]->getReadWriteDataPointer();
*((int*)(this_ptr->args[3*argnum+1])) = this_ptr->array_inputs[argnum]->rows();
*((int*)(this_ptr->args[3*argnum+2])) = this_ptr->array_inputs[argnum]->columns();
}
// Build LLVM IR for the loop tree `t` into a fresh module (member M) and
// store the generated entry point in member `func`.
//
// The generated function is named "initArray" (external linkage, C calling
// convention) and has the signature
//   i32 initArray(i8**  inputs,                              // argument slot array
//                 void (*v_resize)(i8*, i32, i32),            // vector resize callback
//                 void (*m_resize)(i8*, i32, i32, i32),       // matrix resize callback
//                 i8*   this_ptr)                             // opaque JITVM pointer
// run() supplies these as `args`, &JITVM::v_resize, &JITVM::m_resize and `this`.
// The i32 result is the value in the "return_code" stack slot, which starts at 0.
// run() treats a result of 1 as "Index exceeds variable dimensions".
//
// Block layout: func_prolog (entry; holds the return_code alloca) branches to
// func_body (the compiled loop), and the final insertion point branches to
// func_epilog, which returns the return code.
//
// NOTE(review): M is overwritten with `new Module` without releasing any previous
// module, and the whole module is printed to std::cout on every call (debug output).
void JITVM::compile(tree t, Interpreter *m_eval) {
// The signature for the compiled function is:
// int initArray(void** inputs, v_resize_fn, m_resize_fn, void* this_ptr);
M = new Module("test");
// NOTE(review): presumably declares the JIT-callable helper functions in M -- confirm.
initialize_JIT_functions();
// InitializeJITFunctions(M);
std::vector<const Type*> DispatchFuncArgs;
// LLVM has no void*, so i8* / i8** stand in for void* / void**.
PointerType* void_pointer = PointerType::get(IntegerType::get(8));
PointerType* void_pointer_pointer = PointerType::get(void_pointer);
PointerType* int32_pointer = PointerType::get(IntegerType::get(32));
// void (i8*, i32, i32) -- mirrors JITVM::v_resize(void*, int, int)
std::vector<const Type*> vResizeFuncArgs;
vResizeFuncArgs.push_back(void_pointer); //this pointer
vResizeFuncArgs.push_back(IntegerType::get(32)); //array index
vResizeFuncArgs.push_back(IntegerType::get(32)); //new row count
vResizeFuncTy = llvm::FunctionType::get(Type::VoidTy,
vResizeFuncArgs,false,
(ParamAttrsList *) 0);
// void (i8*, i32, i32, i32) -- mirrors JITVM::m_resize(void*, int, int, int)
std::vector<const Type*> mResizeFuncArgs;
mResizeFuncArgs.push_back(void_pointer); //this pointer
mResizeFuncArgs.push_back(IntegerType::get(32)); //array index
mResizeFuncArgs.push_back(IntegerType::get(32)); //new row count
mResizeFuncArgs.push_back(IntegerType::get(32)); //new col count
mResizeFuncTy = llvm::FunctionType::get(Type::VoidTy,
mResizeFuncArgs,false,
(ParamAttrsList *) 0);
// Parameter list of the generated entry point; the order must match the
// GVargs vector built in run().
DispatchFuncArgs.push_back(void_pointer_pointer); //argument array
DispatchFuncArgs.push_back(PointerType::get(vResizeFuncTy)); //vector resize func
DispatchFuncArgs.push_back(PointerType::get(mResizeFuncTy)); //matrix resize func
DispatchFuncArgs.push_back(void_pointer); //this pointer
llvm::FunctionType* DispatchFuncType = llvm::FunctionType::get(IntegerType::get(32),
DispatchFuncArgs,
false,
(ParamAttrsList *)0);
func = new Function(DispatchFuncType,
GlobalValue::ExternalLinkage,
"initArray", M);
func->setCallingConv(CallingConv::C);
// Bind the incoming LLVM arguments to member Values used by the code generator.
// Note: this local `args` (an LLVM argument iterator) shadows the member `args`
// (the runtime void** slot array used by run()/v_resize/m_resize).
Function::arg_iterator args = func->arg_begin();
ptr_inputs = args++;
ptr_inputs->setName("inputs");
v_resize_func_ptr = args++;
v_resize_func_ptr->setName("v_resize_func");
m_resize_func_ptr = args++;
m_resize_func_ptr->setName("m_resize_func");
this_ptr = args++;
this_ptr->setName("this_ptr");
ip = 0;
argument_count = 0;
// Blocks are appended in creation order, so func_prolog becomes the entry block.
func_prolog = new BasicBlock("func_prolog",func,0);
func_body = new BasicBlock("func_body",func,0);
func_epilog = new BasicBlock("func_epilog",func,0);
// return_code slot: 0 unless the generated code stores a different value.
return_val = new AllocaInst(IntegerType::get(32),"return_code",func_prolog);
new StoreInst(int32_const(0),return_val,false,func_prolog);
ip = func_body;
compile_for_block(t,m_eval);
// Terminators are added only after the body is generated -- presumably so that
// body compilation can still append instructions (e.g. allocas) to func_prolog.
// `ip` now points at whatever block the loop code finished in.
new BranchInst(func_body,func_prolog);
new BranchInst(func_epilog,ip);
new ReturnInst(new LoadInst(return_val, "", false, func_epilog),func_epilog);
// Debug dump of the generated IR.
std::cout << (*M);
// Disabled optimization pipeline, kept for reference only.
#if 0
if (0) {
PassManager PM;
PM.add(new TargetData(M));
PM.add(createVerifierPass()); // Verify that input is correct
PM.add((Pass*)createLowerSetJmpPass()); // Lower llvm.setjmp/.longjmp
// If the -strip-debug command line option was specified, do it.
PM.add((Pass*)createRaiseAllocationsPass()); // call %malloc -> malloc inst
PM.add((Pass*)createCFGSimplificationPass()); // Clean up disgusting code
PM.add((Pass*)createPromoteMemoryToRegisterPass());// Kill useless allocas
PM.add((Pass*)createGlobalOptimizerPass()); // Optimize out global vars
PM.add((Pass*)createGlobalDCEPass()); // Remove unused fns and globs
PM.add((Pass*)createIPConstantPropagationPass());// IP Constant Propagation
PM.add((Pass*)createDeadArgEliminationPass()); // Dead argument elimination
PM.add((Pass*)createInstructionCombiningPass()); // Clean up after IPCP & DAE
PM.add((Pass*)createCFGSimplificationPass()); // Clean up after IPCP & DAE
PM.add((Pass*)createPruneEHPass()); // Remove dead EH info
PM.add((Pass*)createFunctionInliningPass()); // Inline small functions
PM.add((Pass*)createArgumentPromotionPass()); // Scalarize uninlined fn args
PM.add((Pass*)createTailDuplicationPass()); // Simplify cfg by copying code
PM.add((Pass*)createInstructionCombiningPass()); // Cleanup for scalarrepl.
PM.add((Pass*)createCFGSimplificationPass()); // Merge & remove BBs
PM.add((Pass*)createScalarReplAggregatesPass()); // Break up aggregate allocas
PM.add((Pass*)createInstructionCombiningPass()); // Combine silly seq's
PM.add((Pass*)createCondPropagationPass()); // Propagate conditionals
PM.add((Pass*)createTailCallEliminationPass()); // Eliminate tail calls
PM.add((Pass*)createCFGSimplificationPass()); // Merge & remove BBs
PM.add((Pass*)createReassociatePass()); // Reassociate expressions
PM.add((Pass*)createLoopRotatePass());
PM.add((Pass*)createLICMPass()); // Hoist loop invariants
PM.add((Pass*)createLoopUnswitchPass()); // Unswitch loops.
PM.add((Pass*)createInstructionCombiningPass()); // Clean up after LICM/reassoc
PM.add((Pass*)createIndVarSimplifyPass()); // Canonicalize indvars
PM.add((Pass*)createLoopUnrollPass()); // Unroll small loops
PM.add((Pass*)createInstructionCombiningPass()); // Clean up after the unroller
PM.add((Pass*)createLoadValueNumberingPass()); // GVN for load instructions
PM.add((Pass*)createGCSEPass()); // Remove common subexprs
PM.add((Pass*)createSCCPPass()); // Constant prop with SCCP
// Run instcombine after redundancy elimination to exploit opportunities
// opened up by them.
PM.add((Pass*)createInstructionCombiningPass());
PM.add((Pass*)createCondPropagationPass()); // Propagate conditionals
PM.add((Pass*)createDeadStoreEliminationPass()); // Delete dead stores
PM.add((Pass*)createAggressiveDCEPass()); // SSA based 'Aggressive DCE'
PM.add((Pass*)createCFGSimplificationPass()); // Merge & remove BBs
PM.add((Pass*)createSimplifyLibCallsPass()); // Library Call Optimizations
PM.add((Pass*)createDeadTypeEliminationPass()); // Eliminate dead types
PM.add((Pass*)createConstantMergePass()); // Merge dup global constants
PM.run(*M);
}
std::cout << *M;
#endif
}
// Execute the function produced by compile() against the interpreter's variables.
//
// For every JIT symbol, the matching interpreter variable is looked up. Scalars
// that do not exist yet are created as 1x1 arrays of the inferred type. Three
// slots of `args` are then filled, starting at 3*argument_index:
//   [0] the array's read/write data pointer,
//   [1] pointer to a heap int holding its row count,
//   [2] pointer to a heap int holding its column count.
// The compiled code may call v_resize/m_resize, which rewrite these slots
// through array_inputs[argument_index].
//
// Throws Exception if a missing variable is not a scalar, if a scalar cannot be
// created, if the execution engine cannot be created, or if the compiled code
// returns 1 (index out of range).
//
// NOTE(review): `args` and the per-argument int cells are malloc'ed on every call
// and are not released here; confirm which code owns and frees them.
void JITVM::run(Interpreter *m_eval) {
  // Collect the list of arguments
  stringVector argumentList(symbols.getCompletions(""));
  // Allocate the argument array (three slots per argument)
  args = (void**) malloc(sizeof(void*)*argumentList.size()*3);
  // array_inputs is written by index below and read by index in v_resize/m_resize,
  // so it must actually contain one element per argument. reserve() only changed
  // capacity, which made those operator[] accesses out of bounds (undefined behavior).
  array_inputs.resize(argumentList.size());
  // For each argument in the array, retrieve it from the interpreter
  for (size_t i=0;i<argumentList.size();i++) {
    JITSymbolInfo* v = symbols.findSymbol(argumentList[i]);
    if (!v) continue;
    ArrayReference ptr(m_eval->getContext()->lookupVariable(argumentList[i]));
    if (!ptr.valid()) {
      // Only scalars may be introduced by the loop; create them as 1x1 arrays.
      if (!v->is_scalar) throw Exception("cannot create array types in the loop");
      m_eval->getContext()->insertVariable(argumentList[i],
                                           Array(v->inferred_type,
                                                 Dimensions(1,1),
                                                 Array::allocateArray(v->inferred_type,1)));
      ptr = m_eval->getContext()->lookupVariable(argumentList[i]);
      if (!ptr.valid()) throw Exception("unable to create variable " + argumentList[i]);
    }
    args[3*v->argument_index] = (void*) ptr->getReadWriteDataPointer();
    args[3*v->argument_index+1] = (int*) malloc(sizeof(int));
    *((int*)(args[3*v->argument_index+1])) = ptr->rows();
    args[3*v->argument_index+2] = (int*) malloc(sizeof(int));
    *((int*)(args[3*v->argument_index+2])) = ptr->columns();
    array_inputs[v->argument_index] = &(*ptr);
  }
  // std::ofstream p("jit.bc");
  // WriteBitcodeToFile(M,p);
  // p.close();
  // return;
  ExistingModuleProvider* MP = new ExistingModuleProvider(M);
  ExecutionEngine* EE = ExecutionEngine::create(MP, false);
  if (!EE) throw Exception("unable to create JIT execution engine");
  // Argument order must match the "initArray" signature built in compile().
  std::vector<GenericValue> GVargs;
  GVargs.push_back(GenericValue(args));
  GVargs.push_back(GenericValue((void*) &v_resize));
  GVargs.push_back(GenericValue((void*) &m_resize));
  GVargs.push_back(GenericValue((void*) this));
  // Release the engine even if the call throws (e.g. an Exception raised by a
  // resize callback).
  GenericValue gv;
  try {
    gv = EE->runFunction(func,GVargs);
  } catch (...) {
    delete EE;
    throw;
  }
  delete EE;
  // A return code of 1 is the compiled code's out-of-range signal.
  if (gv.IntVal == 1)
    throw Exception("Index exceeds variable dimensions");
}
#endif
// syntax highlighted by Code2HTML, v. 0.9.1