/* Copyright (C) 2004 - 2006 db4objects Inc. http://www.db4o.com
This file is part of the db4o open source object database.
db4o is free software; you can redistribute it and/or modify it under
the terms of version 2 of the GNU General Public License as published
by the Free Software Foundation and as clarified by db4objects' GPL
interpretation policy, available at
http://www.db4o.com/about/company/legalpolicies/gplinterpretation/
Alternatively you can write to db4objects, Inc., 1900 S Norfolk Street,
Suite 350, San Mateo, CA 94403, USA.
db4o is distributed in the hope that it will be useful, but WITHOUT ANY
WARRANTY; without even the implied warranty of MERCHANTABILITY or
FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
for more details.
You should have received a copy of the GNU General Public License along
with this program; if not, write to the Free Software Foundation, Inc.,
59 Temple Place - Suite 330, Boston, MA 02111-1307, USA. */
using Cecil.FlowAnalysis.Impl.CodeStructure;
namespace Db4oTools.NativeQueries
{
using System;
using System.Collections;
using System.Reflection;
using Mono.Cecil;
using Cecil.FlowAnalysis;
using Cecil.FlowAnalysis.ActionFlow;
using Cecil.FlowAnalysis.CodeStructure;
using com.db4o.nativequery.expr;
using com.db4o.nativequery.expr.cmp;
using com.db4o.nativequery.expr.cmp.field;
using com.db4o.inside.query;
/// <summary>
/// Builds a com.db4o.nativequery.expr tree out of
/// a predicate method definition.
/// </summary>
public class QueryExpressionBuilder : ExpressionBuilder
{
/// <summary>
/// Returns the query expression for a reflected predicate method, consulting
/// the expression cache first and populating it after a successful build.
/// </summary>
/// <param name="method">The predicate method; must not be null.</param>
/// <returns>The expression tree built from the method's definition.</returns>
/// <exception cref="ArgumentNullException">When <paramref name="method"/> is null.</exception>
public override Expression FromMethod(System.Reflection.MethodBase method)
{
    if (method == null)
    {
        throw new ArgumentNullException("method");
    }

    Expression cached = (Expression)_expressionCachingStrategy.Get(method);
    if (cached != null)
    {
        return cached;
    }

    // Locate the Cecil definitions that correspond to the reflected method.
    string assemblyPath = GetAssemblyLocation(method);
    AssemblyDefinition assemblyDef = GetAssembly(assemblyPath);
    TypeDefinition declaringTypeDef = FindTypeDefinition(assemblyDef.MainModule, method.DeclaringType);
    if (null == declaringTypeDef)
    {
        UnsupportedPredicate(string.Format("Unable to load type '{0}' from assembly '{1}'", method.DeclaringType.FullName, assemblyPath));
    }

    MethodDefinition definition = declaringTypeDef.Methods.GetMethod(method.Name, GetParameterTypes(method));
    if (null == definition)
    {
        UnsupportedPredicate(string.Format("Unable to load the definition of '{0}' from assembly '{1}'", method, assemblyPath));
    }

    Expression built = AdjustBoxedValueTypes(FromMethodDefinition(definition));
    _expressionCachingStrategy.Add(method, built);
    return built;
}
/// <summary>
/// Runs a <see cref="BoxedValueTypeProcessor"/> over the expression and
/// returns the same expression instance.
/// </summary>
private Expression AdjustBoxedValueTypes(Expression expression)
{
    BoxedValueTypeProcessor processor = new BoxedValueTypeProcessor();
    expression.Accept(processor);
    return expression;
}
/// <summary>
/// Returns the assembly definition for the given file location, loading it
/// through AssemblyFactory and caching it when it is not cached yet.
/// </summary>
private static AssemblyDefinition GetAssembly(string location)
{
    AssemblyDefinition cached = (AssemblyDefinition)_assemblyCachingStrategy.Get(location);
    if (null != cached)
    {
        return cached;
    }

    AssemblyDefinition loaded = AssemblyFactory.GetAssembly(location);
    _assemblyCachingStrategy.Add(location, loaded);
    return loaded;
}
/// <summary>
/// Collects the declared parameter types of a method, in declaration order.
/// </summary>
private static Type[] GetParameterTypes(MethodBase method)
{
    ParameterInfo[] infos = method.GetParameters();
    Type[] result = new Type[infos.Length];
    int index = 0;
    foreach (ParameterInfo info in infos)
    {
        result[index] = info.ParameterType;
        index++;
    }
    return result;
}
/// <summary>
/// Finds the Cecil type definition for a reflected type, using the nested
/// type lookup for nested types and a full-name lookup otherwise.
/// </summary>
private static TypeDefinition FindTypeDefinition(ModuleDefinition module, Type type)
{
    if (IsNested(type))
    {
        return FindNestedTypeDefinition(module, type);
    }
    return FindTypeDefinition(module, type.FullName);
}
/// <summary>
/// Tells whether a type is declared inside another type.
/// </summary>
/// <remarks>
/// Checks the declaring type rather than individual visibility flags so that
/// nested types of every accessibility (including protected and
/// protected internal ones) are recognized.
/// </remarks>
private static bool IsNested(Type type)
{
    return type.DeclaringType != null;
}
/// <summary>
/// Finds the definition of a nested type by first resolving its declaring
/// type and then searching that type's nested types by simple name.
/// </summary>
/// <returns>
/// The matching definition, or null when either the declaring type or the
/// nested type can not be found in the module.
/// </returns>
private static TypeDefinition FindNestedTypeDefinition(ModuleDefinition module, Type type)
{
    TypeDefinition declaringType = FindTypeDefinition(module, type.DeclaringType);
    if (null == declaringType)
    {
        // Report "not found" instead of failing with a NullReferenceException.
        return null;
    }

    foreach (TypeDefinition td in declaringType.NestedTypes)
    {
        if (td.Name == type.Name)
            return td;
    }
    return null;
}
/// <summary>
/// Looks a type definition up in the module's type collection by full name.
/// </summary>
private static TypeDefinition FindTypeDefinition(ModuleDefinition module, string fullName)
{
    TypeDefinition found = module.Types[fullName];
    return found;
}
/// <summary>
/// Returns the fully qualified name of the module that contains the
/// method's declaring type.
/// </summary>
private static string GetAssemblyLocation(MethodBase method)
{
    System.Reflection.Module owningModule = method.DeclaringType.Module;
    return owningModule.FullyQualifiedName;
}
/// <summary>
/// Builds a query expression from a Cecil method definition.
/// </summary>
/// <remarks>
/// The method must take exactly one argument, contain no exception handlers
/// and return a boolean; otherwise the predicate is reported as unsupported.
/// </remarks>
/// <param name="method">The predicate method definition; must not be null.</param>
/// <returns>The expression produced by visiting the method's query expression.</returns>
/// <exception cref="ArgumentNullException">When <paramref name="method"/> is null.</exception>
public static Expression FromMethodDefinition(MethodDefinition method)
{
    if (method == null)
    {
        throw new ArgumentNullException("method");
    }
    if (1 != method.Parameters.Count)
    {
        UnsupportedPredicate("A predicate must take a single argument.");
    }
    if (0 != method.Body.ExceptionHandlers.Count)
    {
        UnsupportedPredicate("A predicate can not contain exception handlers.");
    }
    if (method.DeclaringType.Module.Import(typeof(bool)) != method.ReturnType.ReturnType)
    {
        UnsupportedPredicate("A predicate must have a boolean return type.");
    }

    IExpression queryExpression = GetQueryExpression(method);
    if (null == queryExpression)
    {
        UnsupportedPredicate("No expression found.");
    }

    Visitor builder = new Visitor(method);
    queryExpression.Accept(builder);
    return builder.Expression;
}
/// <summary>
/// Creates the control flow graph and the action flow graph for the method
/// and extracts the query expression from the latter.
/// </summary>
private static IExpression GetQueryExpression(IMethodDefinition method)
{
    return GetQueryExpression(
        FlowGraphFactory.CreateActionFlowGraph(
            FlowGraphFactory.CreateControlFlowGraph(method)));
}
/// <summary>
/// Always throws an <see cref="UnsupportedPredicateException"/> carrying the
/// given reason.
/// </summary>
private static void UnsupportedPredicate(string reason)
{
    UnsupportedPredicateException failure = new UnsupportedPredicateException(reason);
    throw failure;
}
/// <summary>
/// Reports the given expression as unsupported, including its printed form
/// in the failure message.
/// </summary>
private static void UnsupportedExpression(IExpression node)
{
    string printed = ExpressionPrinter.ToString(node);
    UnsupportedPredicate("Unsupported expression: " + printed);
}
/// <summary>
/// Walks the action flow graph from its first block and returns the
/// expression that the method returns.
/// </summary>
/// <remarks>
/// Assignments to local variables are recorded (each variable may be assigned
/// only once) so that a returned variable reference can be replaced by the
/// expression assigned to it. Invocation blocks, conditional branches, action
/// types this walk does not know and control flow that revisits a block are
/// reported as unsupported.
/// </remarks>
/// <returns>
/// The returned expression; null when the returned variable was never
/// assigned or when the walk runs off the end of the graph.
/// </returns>
private static IExpression GetQueryExpression(IActionFlowGraph afg)
{
    Hashtable variables = new Hashtable();
    // Blocks already walked; revisiting one means the flow loops back on itself.
    Hashtable visitedBlocks = new Hashtable();
    IActionBlock block = afg.Blocks[0];
    while (block != null)
    {
        if (visitedBlocks.Contains(block))
        {
            // Without this guard an unconditional branch cycle would never terminate.
            UnsupportedPredicate("Loops are not supported.");
        }
        visitedBlocks.Add(block, null);

        switch (block.ActionType)
        {
            case ActionType.Invoke:
                UnsupportedExpression(((IInvokeActionBlock)block).Expression);
                break;
            case ActionType.ConditionalBranch:
                UnsupportedPredicate("Conditional blocks are not supported.");
                break;
            case ActionType.Branch:
                block = ((IBranchActionBlock)block).Target;
                break;
            case ActionType.Assign:
                {
                    IAssignActionBlock assignBlock = (IAssignActionBlock)block;
                    IAssignExpression assign = assignBlock.AssignExpression;
                    IVariableReferenceExpression variable = assign.Target as IVariableReferenceExpression;
                    if (null == variable)
                    {
                        UnsupportedExpression(assign);
                    }
                    else
                    {
                        // Only single assignment per variable is supported.
                        if (variables.Contains(variable.Variable.Index))
                            UnsupportedExpression(assign.Expression);
                        variables.Add(variable.Variable.Index, assign.Expression);
                        block = assignBlock.Next;
                    }
                    break;
                }
            case ActionType.Return:
                {
                    IExpression expression = ((IReturnActionBlock)block).Expression;
                    IVariableReferenceExpression variable = expression as IVariableReferenceExpression;
                    return null == variable
                        ? expression
                        : (IExpression)variables[variable.Variable.Index];
                }
            default:
                // An unhandled action type would leave 'block' unchanged and spin forever.
                UnsupportedPredicate("Unsupported action block: " + block.ActionType);
                break;
        }
    }
    return null;
}
class Visitor : AbstractCodeStructureVisitor
{
// Single-slot "stack": holds the value stored by the most recent Push until Pop clears it.
object _current;
// Nesting depth of candidate method bodies being visited (see EnterCandidateMethod / LeaveCandidateMethod).
private int _insideCandidate = 0;
// Assembly definitions keyed by their full assembly name (see RegisterAssembly / LookupAssembly).
Hashtable _assemblies = new Hashtable();
// Methods currently being visited, innermost last; consulted by AssertMethodCanBeVisited.
IList _methodDefinitionStack = new ArrayList();
/// <summary>
/// Creates a visitor for the given top level predicate method: the method is
/// entered on the method stack and its assembly is registered for lookups.
/// </summary>
public Visitor(IMethodDefinition topLevelMethod)
{
    EnterMethodDefinition(topLevelMethod);

    IAssemblyDefinition owningAssembly = topLevelMethod.DeclaringType.Module.Assembly;
    RegisterAssembly(owningAssembly);
}
/// <summary>
/// Pushes the method onto the stack of methods currently being visited.
/// </summary>
private void EnterMethodDefinition(IMethodDefinition method)
{
    this._methodDefinitionStack.Add(method);
}
/// <summary>
/// Pops the innermost method from the stack of methods being visited; in
/// debug builds asserts that it is the method passed in.
/// </summary>
private void LeaveMethodDefinition(IMethodDefinition method)
{
    int top = _methodDefinitionStack.Count - 1;
    object innermost = _methodDefinitionStack[top];
    System.Diagnostics.Debug.Assert(method == innermost);
    _methodDefinitionStack.RemoveAt(top);
}
/// <summary>
/// Registers an assembly so it can be looked up by its assembly name
/// string later.
/// </summary>
/// <remarks>
/// Registering an assembly whose full name is already known replaces the
/// previous entry instead of throwing; this can happen when a reference name
/// used for the lookup differs from the full name of the loaded definition.
/// </remarks>
/// <param name="assembly">The assembly definition to register.</param>
private void RegisterAssembly(IAssemblyDefinition assembly)
{
    // The indexer overwrites, whereas Hashtable.Add throws on a duplicate key.
    _assemblies[assembly.Name.FullName] = assembly;
}
/// <summary>
/// Returns the registered assembly definition for the given full assembly
/// name, or null when none is registered under that name.
/// </summary>
private AssemblyDefinition LookupAssembly(string fullName)
{
    object registered = _assemblies[fullName];
    return (AssemblyDefinition)registered;
}
/// <summary>
/// The value currently held by the visitor, cast to an expression.
/// </summary>
public Expression Expression
{
    get
    {
        Expression result = (Expression)_current;
        return result;
    }
}
/// <summary>
/// True while at least one candidate method body is being visited.
/// </summary>
private bool InsideCandidate
{
    get
    {
        return 0 < _insideCandidate;
    }
}
/// <summary>
/// Assignment expressions are not supported inside a query expression.
/// </summary>
public override void Visit(IAssignExpression node)
{
    IExpression unsupported = node;
    UnsupportedExpression(unsupported);
}
/// <summary>
/// Variable references are not supported inside a query expression.
/// </summary>
public override void Visit(IVariableReferenceExpression node)
{
    IExpression unsupported = node;
    UnsupportedExpression(unsupported);
}
/// <summary>
/// Bare argument references are not supported inside a query expression.
/// </summary>
public override void Visit(IArgumentReferenceExpression node)
{
    IExpression unsupported = node;
    UnsupportedExpression(unsupported);
}
/// <summary>
/// Handles unary expressions: a logical "not" visits its operand and negates
/// the result; every other unary operator is unsupported.
/// </summary>
public override void Visit(IUnaryExpression node)
{
    if (node.Operator != UnaryOperator.Not)
    {
        UnsupportedExpression(node);
        return;
    }

    Visit(node.Operand);
    Negate();
}
/// <summary>
/// Translates a binary expression into a comparison or a logical
/// and/or expression.
/// </summary>
/// <remarks>
/// Inequality, "greater than or equal" and "less than or equal" are expressed
/// as the negation of EQUALS, SMALLER and GREATER respectively.
/// </remarks>
public override void Visit(IBinaryExpression node)
{
    ComparisonOperator comparison = ComparisonOperator.EQUALS;
    bool negated = false;

    switch (node.Operator)
    {
        case BinaryOperator.LogicalOr:
            {
                Expression orLeft = Convert(node.Left);
                Expression orRight = Convert(node.Right);
                Push(new OrExpression(orLeft, orRight));
                return;
            }
        case BinaryOperator.LogicalAnd:
            {
                Expression andLeft = Convert(node.Left);
                Expression andRight = Convert(node.Right);
                Push(new AndExpression(andLeft, andRight));
                return;
            }
        case BinaryOperator.ValueEquality:
            comparison = ComparisonOperator.EQUALS;
            break;
        case BinaryOperator.ValueInequality:
            comparison = ComparisonOperator.EQUALS;
            negated = true;
            break;
        case BinaryOperator.LessThan:
            comparison = ComparisonOperator.SMALLER;
            break;
        case BinaryOperator.GreaterThan:
            comparison = ComparisonOperator.GREATER;
            break;
        case BinaryOperator.GreaterThanOrEqual:
            comparison = ComparisonOperator.SMALLER;
            negated = true;
            break;
        case BinaryOperator.LessThanOrEqual:
            comparison = ComparisonOperator.GREATER;
            negated = true;
            break;
        default:
            UnsupportedExpression(node);
            return;
    }

    PushComparison(node.Left, node.Right, comparison);
    if (negated)
    {
        Negate();
    }
}
/// <summary>
/// Replaces the current expression by its negation. Negating a
/// <c>NotExpression</c> unwraps it instead of nesting another one.
/// </summary>
private void Negate()
{
    Expression operand = (Expression)Pop();
    NotExpression alreadyNegated = operand as NotExpression;
    if (alreadyNegated == null)
    {
        Push(new NotExpression(operand));
    }
    else
    {
        Push(alreadyNegated.Expr());
    }
}
/// <summary>
/// Visits both operands and pushes a comparison between them.
/// </summary>
/// <remarks>
/// The comparison always has the field value on the left. When the right
/// operand is a candidate field value the operands are swapped; for a
/// non-symmetric operator the operator is then mirrored (SMALLER becomes
/// GREATER and vice versa) so the comparison keeps its meaning, e.g.
/// <c>5 &lt; c.x</c> becomes <c>c.x &gt; 5</c>. Negating instead would yield
/// <c>c.x &gt;= 5</c>, which differs when both sides are equal. Non-symmetric
/// operators without a mirrored counterpart are reported as unsupported.
/// </remarks>
/// <param name="lhs">Left operand as written in the predicate.</param>
/// <param name="rhs">Right operand as written in the predicate.</param>
/// <param name="op">Operator relating <paramref name="lhs"/> to <paramref name="rhs"/>.</param>
private void PushComparison(IExpression lhs, IExpression rhs, ComparisonOperator op)
{
    Visit(lhs);
    object left = Pop();
    Visit(rhs);
    object right = Pop();

    // Source expressions travel with their values so failure messages stay accurate.
    IExpression leftSource = lhs;
    IExpression rightSource = rhs;

    if (IsCandidateFieldValue(right))
    {
        object temp = left;
        left = right;
        right = temp;
        leftSource = rhs;
        rightSource = lhs;
        if (!op.IsSymmetric())
        {
            op = MirrorOperator(op, rhs);
        }
    }

    AssertType(left, typeof(FieldValue), leftSource);
    AssertType(right, typeof(ComparisonOperand), rightSource);
    Push(new ComparisonExpression((FieldValue)left, (ComparisonOperand)right, op));
}

/// <summary>
/// Returns the operator that expresses the same relation with its operands
/// exchanged; throws for operators that have no such counterpart.
/// </summary>
private static ComparisonOperator MirrorOperator(ComparisonOperator op, IExpression source)
{
    if (op.Equals(ComparisonOperator.SMALLER)) return ComparisonOperator.GREATER;
    if (op.Equals(ComparisonOperator.GREATER)) return ComparisonOperator.SMALLER;

    UnsupportedPredicate(
        string.Format("Unsupported expression: {0}. The operator can not be applied with the candidate field as its argument.",
            ExpressionPrinter.ToString(source)));
    return op; // not reached, UnsupportedPredicate always throws
}
/// <summary>
/// Tells whether the object is a field value whose root is the candidate.
/// </summary>
private bool IsCandidateFieldValue(object o)
{
    FieldValue fieldValue = o as FieldValue;
    return fieldValue != null && fieldValue.Root() is CandidateFieldRoot;
}
/// <summary>
/// Dispatches a method invocation to the operator, System.String or regular
/// method handling, in that order of precedence.
/// </summary>
public override void Visit(IMethodInvocationExpression node)
{
    IMethodReferenceExpression targetRef = node.Target as IMethodReferenceExpression;
    if (null == targetRef)
    {
        UnsupportedExpression(node);
    }

    IMethodReference invoked = targetRef.Method;
    if (IsOperator(invoked))
    {
        ProcessOperatorMethodInvocation(node, invoked);
    }
    else if (IsSystemString(invoked.DeclaringType))
    {
        ProcessStringMethod(node, targetRef);
    }
    else
    {
        ProcessRegularMethodInvocation(node, targetRef);
    }
}
/// <summary>
/// Tells whether the type reference's full name is "System.String".
/// </summary>
private static bool IsSystemString(TypeReference type)
{
    return string.Equals(type.FullName, "System.String");
}
/// <summary>
/// Translates calls to the single-string-argument methods Contains,
/// StartsWith, EndsWith and Equals of System.String into comparisons; any
/// other string method is unsupported.
/// </summary>
private void ProcessStringMethod(IMethodInvocationExpression node, IMethodReferenceExpression methodRef)
{
    IMethodReference invoked = methodRef.Method;
    bool takesSingleString = invoked.Parameters.Count == 1
        && IsSystemString(invoked.Parameters[0].ParameterType);
    if (!takesSingleString)
    {
        UnsupportedExpression(methodRef);
    }

    ComparisonOperator comparison;
    switch (invoked.Name)
    {
        case "Contains":
            comparison = ComparisonOperator.CONTAINS;
            break;
        case "StartsWith":
            comparison = ComparisonOperator.STARTSWITH;
            break;
        case "EndsWith":
            comparison = ComparisonOperator.ENDSWITH;
            break;
        case "Equals":
            comparison = ComparisonOperator.EQUALS;
            break;
        default:
            UnsupportedExpression(methodRef);
            return;
    }

    PushComparison(methodRef.Target, node.Arguments[0], comparison);
}
/// <summary>
/// Handles a call to a parameterless method by visiting the body of the
/// invoked method.
/// </summary>
/// <remarks>
/// A call on "this" is only allowed while inside a candidate method. For a
/// target that is neither "this" nor an argument, the target is first
/// converted to a field value and pushed.
/// </remarks>
private void ProcessRegularMethodInvocation(IMethodInvocationExpression node, IMethodReferenceExpression methodRef)
{
    if (0 != node.Arguments.Count)
    {
        UnsupportedExpression(node);
    }

    IExpression receiver = methodRef.Target;
    CodeElementType receiverKind = receiver.CodeElementType;
    if (receiverKind == CodeElementType.ThisReferenceExpression)
    {
        if (!InsideCandidate)
        {
            UnsupportedExpression(node);
        }
    }
    else if (receiverKind != CodeElementType.ArgumentReferenceExpression)
    {
        Push(ToFieldValue(receiver));
    }

    ProcessCandidateMethodInvocation(node, methodRef);
}
/// <summary>
/// Translates a call to an overloaded comparison operator into a comparison
/// of its two arguments; other operator methods are unsupported.
/// </summary>
/// <remarks>
/// op_Inequality, op_GreaterThanOrEqual and op_LessThanOrEqual are expressed
/// as the negation of EQUALS, SMALLER and GREATER respectively.
/// </remarks>
private void ProcessOperatorMethodInvocation(IMethodInvocationExpression node, IMethodReference method)
{
    ComparisonOperator comparison;
    bool negated = false;

    switch (method.Name)
    {
        case "op_Equality":
            comparison = ComparisonOperator.EQUALS;
            break;
        case "op_Inequality":
            comparison = ComparisonOperator.EQUALS;
            negated = true;
            break;
        // XXX: check if the operations below are really supported for the
        // data types in question
        case "op_GreaterThanOrEqual":
            comparison = ComparisonOperator.SMALLER;
            negated = true;
            break;
        case "op_LessThanOrEqual":
            comparison = ComparisonOperator.GREATER;
            negated = true;
            break;
        case "op_LessThan":
            comparison = ComparisonOperator.SMALLER;
            break;
        case "op_GreaterThan":
            comparison = ComparisonOperator.GREATER;
            break;
        default:
            UnsupportedExpression(node);
            return;
    }

    PushComparison(node.Arguments[0], node.Arguments[1], comparison);
    if (negated)
    {
        Negate();
    }
}
/// <summary>
/// Resolves the invoked method, extracts its query expression and visits it
/// as a candidate method. The candidate method is always left again, even
/// when visiting fails.
/// </summary>
private void ProcessCandidateMethodInvocation(IMethodInvocationExpression node, IMethodReferenceExpression methodRef)
{
    IMethodDefinition callee = GetMethodDefinition(methodRef);
    if (null == callee)
    {
        UnsupportedExpression(node);
    }
    AssertMethodCanBeVisited(node, callee);

    IExpression calleeExpression = GetQueryExpression(callee);
    if (null == calleeExpression)
    {
        UnsupportedExpression(node);
    }

    EnterCandidateMethod(callee);
    try
    {
        Visit(calleeExpression);
    }
    finally
    {
        LeaveCandidateMethod(callee);
    }
}
/// <summary>
/// Rejects the invocation when the method is already on the stack of methods
/// being visited.
/// </summary>
private void AssertMethodCanBeVisited(IMethodInvocationExpression node, IMethodDefinition method)
{
    bool alreadyVisiting = _methodDefinitionStack.Contains(method);
    if (alreadyVisiting)
    {
        UnsupportedExpression(node);
    }
}
/// <summary>
/// Returns the referenced method itself when it already is a definition;
/// otherwise loads the definition from the assembly that declares it.
/// </summary>
private IMethodDefinition GetMethodDefinition(IMethodReferenceExpression methodRef)
{
    IMethodDefinition local = methodRef.Method as IMethodDefinition;
    if (local == null)
    {
        return LoadExternalMethodDefinition(methodRef);
    }
    return local;
}
/// <summary>
/// Loads the definition of a method declared in another assembly.
/// </summary>
/// <returns>
/// The method definition, or null when the declaring type is not found in
/// the main module of its assembly. A null result is reported as an
/// unsupported expression by ProcessCandidateMethodInvocation.
/// </returns>
private IMethodDefinition LoadExternalMethodDefinition(IMethodReferenceExpression methodRef)
{
    IMethodReference method = methodRef.Method;
    IAssemblyDefinition assemblyDef = GetContainingAssembly(method.DeclaringType);
    ITypeDefinition type = assemblyDef.MainModule.Types[method.DeclaringType.FullName];
    if (null == type)
    {
        // Signal "not found" rather than failing with a NullReferenceException.
        return null;
    }
    return type.Methods.GetMethod(method.Name, method.Parameters);
}
/// <summary>
/// Returns the definition of the assembly named by the type's scope,
/// loading and registering it when it has not been registered yet.
/// </summary>
/// <remarks>
/// The assembly is loaded through reflection to find the file that contains
/// the type. The type name has '/' replaced by '+' before the reflection
/// lookup, as ResolveTypeReference in this file does (presumably because
/// Cecil separates nested type names with '/'). A type that can not be
/// found is reported as an unsupported predicate.
/// </remarks>
private IAssemblyDefinition GetContainingAssembly(ITypeReference type)
{
    AssemblyNameReference scope = (AssemblyNameReference)type.Scope;
    string assemblyName = scope.FullName;
    AssemblyDefinition definition = LookupAssembly(assemblyName);
    if (null == definition)
    {
        Assembly assembly = Assembly.Load(assemblyName);
        Type runtimeType = assembly.GetType(type.FullName.Replace('/', '+'));
        if (null == runtimeType)
        {
            UnsupportedPredicate(string.Format("Unable to load type '{0}' from assembly '{1}'", type.FullName, assemblyName));
        }
        string location = runtimeType.Module.FullyQualifiedName;
        definition = QueryExpressionBuilder.GetAssembly(location);
        RegisterAssembly(definition);
    }
    return definition;
}
/// <summary>
/// Enters a candidate method: pushes it on the method stack and increases
/// the candidate nesting depth.
/// </summary>
private void EnterCandidateMethod(IMethodDefinition method)
{
    EnterMethodDefinition(method);
    _insideCandidate++;
}
/// <summary>
/// Leaves a candidate method: decreases the candidate nesting depth and pops
/// the method from the method stack.
/// </summary>
private void LeaveCandidateMethod(IMethodDefinition method)
{
    _insideCandidate--;
    LeaveMethodDefinition(method);
}
/// <summary>
/// Tells whether the method looks like a binary operator overload: it has no
/// "this", its name starts with "op_" and it takes two parameters.
/// </summary>
private static bool IsOperator(IMethodReference method)
{
    if (method.HasThis)
    {
        return false;
    }
    if (!method.Name.StartsWith("op_"))
    {
        return false;
    }
    return 2 == method.Parameters.Count;
}
/// <summary>
/// Pushes a field value for a field reference. The root of the field value
/// depends on what the field is accessed on.
/// </summary>
/// <remarks>
/// - On an argument: the candidate root.
/// - On "this" outside a candidate method: the predicate root.
/// - On "this" inside a candidate method: the field value currently on the
///   stack if there is one, otherwise the candidate root.
/// - On a method invocation or another field reference: the field value
///   obtained by visiting that target.
/// Any other target is unsupported.
/// </remarks>
public override void Visit(IFieldReferenceExpression node)
{
    IExpression owner = node.Target;
    CodeElementType ownerKind = owner.CodeElementType;

    if (ownerKind == CodeElementType.ArgumentReferenceExpression)
    {
        Push(new FieldValue(CandidateFieldRoot.INSTANCE, node.Field.Name, node));
        return;
    }

    if (ownerKind == CodeElementType.ThisReferenceExpression)
    {
        if (!InsideCandidate)
        {
            Push(new FieldValue(PredicateFieldRoot.INSTANCE, node.Field.Name, node));
            return;
        }
        if (_current == null)
        {
            Push(new FieldValue(CandidateFieldRoot.INSTANCE, node.Field.Name, node));
            return;
        }
        FieldValue enclosing = PopFieldValue(node);
        Push(new FieldValue(enclosing, node.Field.Name, node));
        return;
    }

    if (ownerKind == CodeElementType.MethodInvocationExpression
        || ownerKind == CodeElementType.FieldReferenceExpression)
    {
        FieldValue parent = ToFieldValue(owner);
        Push(new FieldValue(parent, node.Field.Name, node));
        return;
    }

    UnsupportedExpression(node);
}
/// <summary>
/// Pushes a constant value wrapping the literal's value.
/// </summary>
public override void Visit(ILiteralExpression node)
{
    ConstValue constant = new ConstValue(node.Value);
    Push(constant);
}
/// <summary>
/// Converts an operand of a logical expression into a query expression.
/// </summary>
Expression Convert(IExpression node)
{
    Expression converted = ReconstructNullComparisonIfNecessary(node);
    return converted;
}
/// <summary>
/// Visits the node and returns the resulting expression. When the node
/// yields a bare field value instead of an expression, the comparison the
/// field value stands for is reconstructed.
/// </summary>
/// <remarks>
/// A bare System.Boolean field becomes <c>field == true</c>; any other bare
/// field becomes <c>!(field == null)</c>. Treating a boolean field as a null
/// check would not test its truth value.
/// </remarks>
private Expression ReconstructNullComparisonIfNecessary(IExpression node)
{
    Visit(node);
    object top = Pop();
    FieldValue fieldValue = top as FieldValue;
    if (fieldValue == null)
    {
        AssertType(top, typeof(Expression), node);
        return (Expression)top;
    }

    if (IsBooleanField(fieldValue))
    {
        return new ComparisonExpression(
            fieldValue,
            new ConstValue(true),
            ComparisonOperator.EQUALS);
    }

    return
        new NotExpression(
            new ComparisonExpression(
                fieldValue,
                new ConstValue(null),
                ComparisonOperator.EQUALS));
}

/// <summary>
/// Tells whether the field value is tagged with a field reference whose
/// field type is System.Boolean.
/// </summary>
private static bool IsBooleanField(FieldValue fieldValue)
{
    FieldReferenceExpression fieldRef = fieldValue.Tag() as FieldReferenceExpression;
    if (fieldRef == null) return false;
    TypeReference fieldType = fieldRef.Field.FieldType;
    return fieldType.FullName == "System.Boolean";
}
/// <summary>
/// Visits the node and pops the result, which must be a field value.
/// </summary>
FieldValue ToFieldValue(IExpression node)
{
    Visit(node);
    FieldValue result = PopFieldValue(node);
    return result;
}
/// <summary>
/// Pops the current value, requiring it to be a field value.
/// </summary>
private FieldValue PopFieldValue(IExpression node)
{
    object popped = Pop(node, typeof(FieldValue));
    return (FieldValue)popped;
}
/// <summary>
/// Stores the value as the current one; the slot is expected to be empty.
/// </summary>
void Push(object value)
{
    bool slotIsEmpty = _current == null;
    Assert(slotIsEmpty, "expression stack must be empty before Push");
    _current = value;
}
/// <summary>
/// Pops the current value and checks that it is of the expected type.
/// </summary>
/// <param name="node">Source expression, used for the failure message.</param>
/// <param name="expectedType">Type the popped value must be assignable to.</param>
object Pop(IExpression node, System.Type expectedType)
{
    object popped = Pop();
    AssertType(popped, expectedType, node);
    return popped;
}
/// <summary>
/// Reports an unsupported predicate unless the value is an instance of the
/// expected type.
/// </summary>
/// <remarks>
/// A null value (for instance when a visited expression did not push
/// anything) is reported the same way instead of failing with a
/// NullReferenceException.
/// </remarks>
/// <param name="value">The value taken from the stack; may be null.</param>
/// <param name="expectedType">Type the value must be assignable to.</param>
/// <param name="sourceExpression">Expression the value was produced from, used in the message.</param>
private static void AssertType(object value, Type expectedType, IExpression sourceExpression)
{
    Type actualType = null == value ? null : value.GetType();
    if (null == actualType || !expectedType.IsAssignableFrom(actualType))
    {
        object actualDescription = null == actualType ? (object)"null" : actualType;
        UnsupportedPredicate(
            string.Format("Unsupported expression: {0}. Unexpected type on stack. Expected: {1}, Got: {2}.",
                ExpressionPrinter.ToString(sourceExpression), expectedType, actualDescription));
    }
}
/// <summary>
/// Returns the current value and clears the slot; the slot is expected to be
/// occupied.
/// </summary>
object Pop()
{
    bool slotIsOccupied = _current != null;
    Assert(slotIsOccupied, "expression stack is empty");
    object popped = _current;
    _current = null;
    return popped;
}
/// <summary>
/// Forwards to <c>System.Diagnostics.Debug.Assert</c>.
/// </summary>
private void Assert(bool condition, string message)
{
    System.Diagnostics.Debug.Assert(
        condition,
        message);
}
}
}
/// <summary>
/// Expression visitor that adjusts constants compared against value type
/// fields: a value type constant compared against an enum field is converted
/// to the enum type.
/// </summary>
internal class BoxedValueTypeProcessor : TraversingExpressionVisitor
{
    /// <summary>
    /// Adjusts the right hand constant of a comparison when the left hand
    /// field is of a value type; other comparisons are left untouched.
    /// </summary>
    public override void Visit(ComparisonExpression expression)
    {
        TypeReference fieldType = GetFieldType(expression.Left());
        if (!fieldType.IsValueType) return;

        ConstValue constValue = expression.Right() as ConstValue;
        if (constValue == null) return;

        AdjustConstValue(fieldType, constValue);
    }

    /// <summary>
    /// Returns the type of the field referenced by the field value's tag.
    /// </summary>
    private static TypeReference GetFieldType(FieldValue field)
    {
        return ((FieldReferenceExpression) field.Tag()).Field.FieldType;
    }

    /// <summary>
    /// Converts the constant to the field's enum type when the field is an
    /// enum and the constant is a value of a different value type.
    /// </summary>
    /// <exception cref="UnsupportedPredicateException">
    /// When the constant is null: a value type field can not be compared with
    /// null, so the predicate is reported as unsupported instead of failing
    /// with a NullReferenceException.
    /// </exception>
    private void AdjustConstValue(TypeReference typeRef, ConstValue constValue)
    {
        object value = constValue.Value();
        if (null == value)
        {
            throw new UnsupportedPredicateException(
                string.Format("Unsupported comparison of a field of value type '{0}' with null.", typeRef.FullName));
        }
        if (!value.GetType().IsValueType) return;

        System.Type type = ResolveTypeReference(typeRef);
        if (!type.IsEnum || value.GetType() == type) return;

        constValue.Value(Enum.ToObject(type, value));
    }

    /// <summary>
    /// Resolves a type reference to a runtime type by loading the assembly of
    /// its scope. '/' in the full name is replaced by '+' before the lookup
    /// (presumably translating the nested type separator to the reflection
    /// convention).
    /// </summary>
    private static Type ResolveTypeReference(TypeReference typeRef)
    {
        Assembly assembly = LoadAssembly(typeRef.Scope);
        return assembly.GetType(typeRef.FullName.Replace('/', '+'), true);
    }

    /// <summary>
    /// Loads the runtime assembly for a metadata scope: directly by name for
    /// an assembly name reference, or through the owning assembly's name for
    /// a module definition.
    /// </summary>
    /// <exception cref="UnsupportedPredicateException">
    /// When the scope is neither an assembly name reference nor a module
    /// definition.
    /// </exception>
    private static Assembly LoadAssembly(IMetadataScope scope)
    {
        IAssemblyNameReference nameRef = scope as IAssemblyNameReference;
        if (null != nameRef) return Assembly.Load(nameRef.FullName);

        IModuleDefinition moduleDef = scope as IModuleDefinition;
        if (null == moduleDef)
        {
            throw new UnsupportedPredicateException(
                string.Format("Unable to resolve the assembly for metadata scope '{0}'.", scope));
        }
        return LoadAssembly(moduleDef.Assembly.Name);
    }
}
}