comparison Source/Data/Linq/Builder/ExpressionBuilder.cs @ 0:f990fcb411a9

Копия текущей версии из github
author cin
date Thu, 27 Mar 2014 21:46:09 +0400
parents
children f7d63a092920
comparison
equal deleted inserted replaced
-1:000000000000 0:f990fcb411a9
1 using System;
2 using System.Collections;
3 using System.Collections.Generic;
4 using System.Data;
5 using System.Linq;
6 using System.Linq.Expressions;
7 using System.Reflection;
8
9 namespace BLToolkit.Data.Linq.Builder
10 {
11 using BLToolkit.Linq;
12 using Common;
13 using Data.Sql;
14 using Data.Sql.SqlProvider;
15 using Mapping;
16 using Reflection;
17
18 public partial class ExpressionBuilder
19 {
20 #region Sequence
21
// Lock guarding replacement of the shared builder list (_sequenceBuilders); Build<T> takes it
// when it re-sorts the list.
static readonly object _sync = new object();

// Shared list of sequence builders. BuildSequence, ConvertSequence and IsSequence probe the
// builders in list order and use the first one whose CanBuild returns true, so the order is
// significant. Build<T> replaces this list with a copy sorted by BuildCounter (descending) when
// BuildSequence detects that the usage counters are out of order.
static List<ISequenceBuilder> _sequenceBuilders = new List<ISequenceBuilder>
{
    new TableBuilder (),
    new SelectBuilder (),
    new SelectManyBuilder (),
    new WhereBuilder (),
    new OrderByBuilder (),
    new GroupByBuilder (),
    new JoinBuilder (),
    new TakeSkipBuilder (),
    new DefaultIfEmptyBuilder(),
    new DistinctBuilder (),
    new FirstSingleBuilder (),
    new AggregationBuilder (),
    new ScalarSelectBuilder (),
    new CountBuilder (),
    new PassThroughBuilder (),
    new TableAttributeBuilder(),
    new InsertBuilder (),
    new InsertBuilder.Into (),
    new InsertBuilder.Value (),
    new InsertOrUpdateBuilder(),
    new UpdateBuilder (),
    new UpdateBuilder.Set (),
    new DeleteBuilder (),
    new ContainsBuilder (),
    new AllAnyBuilder (),
    new ConcatUnionBuilder (),
    new IntersectBuilder (),
    new CastBuilder (),
    new OfTypeBuilder (),
    new AsUpdatableBuilder (),
};
57
/// <summary>
/// Registers an additional sequence builder. The builder is appended to the end of the shared
/// builder list, so it is probed after the builders already registered.
/// </summary>
/// <param name="builder">The builder to register.</param>
/// <exception cref="ArgumentNullException"><paramref name="builder"/> is null.</exception>
/// <remarks>
/// The shared list is replaced (copy-on-write) instead of being mutated in place, under the same
/// lock Build&lt;T&gt; uses when it re-sorts the list. An in-place Add could break a concurrent
/// foreach over the list ("collection was modified"). It could also be lost when a concurrent
/// re-sort replaced the list with a copy taken before the Add.
/// </remarks>
public static void AddBuilder(ISequenceBuilder builder)
{
    if (builder == null)
        throw new ArgumentNullException("builder");

    lock (_sync)
    {
        var builders = new List<ISequenceBuilder>(_sequenceBuilders.Count + 1);

        builders.AddRange(_sequenceBuilders);
        builders.Add(builder);

        _sequenceBuilders = builders;
    }
}
62
63 #endregion
64
65 #region Init
66
// Query object under construction; Build<T> initializes it and returns it cast to Query<T>.
readonly Query _query;
// Snapshot of the shared builder list taken when this builder is created. Build<T> replaces the
// shared list when re-sorting, which does not affect this snapshot.
readonly List<ISequenceBuilder> _builders = _sequenceBuilders;
// Set by BuildSequence when builder usage counters appear out of order; Build<T> then re-sorts
// the shared builder list.
private bool _reorder;
// Accessors computed in the constructor from the original expression via GetExpressionAccessors.
readonly Dictionary<Expression,Expression> _expressionAccessors;
// Holder-field access expressions created by ConvertWhere for hoisted subquery expressions;
// created lazily.
private HashSet<Expression> _subQueryExpressions;

// SQL parameters collected while building; passed to Query.Init by Build<T>.
readonly public List<ParameterAccessor> CurrentSqlParameters = new List<ParameterAccessor>();

#if FW4 || SILVERLIGHT

readonly public List<ParameterExpression> BlockVariables = new List<ParameterExpression>();
readonly public List<Expression> BlockExpressions = new List<Expression>();
public bool IsBlockDisable;

#else
// NOTE(review): presumably block expressions are always disabled here because Expression.Block
// is only available on .NET 4 / Silverlight builds — confirm.
public bool IsBlockDisable = true;
#endif

// Cycle guard for ExposeExpression: queryable expressions that have already been inlined.
// Only non-null while the constructor runs.
readonly HashSet<Expression> _visitedExpressions;
86
/// <summary>
/// Creates a builder for <paramref name="expression"/>. It computes the expression accessors and
/// runs the expression-tree pre-processing (ConvertExpressionTree), storing the result in
/// <see cref="Expression"/>.
/// </summary>
/// <param name="query">Query object that Build&lt;T&gt; later initializes and returns.</param>
/// <param name="dataContext">Data context information (source of MappingSchema and the SQL provider).</param>
/// <param name="expression">The original LINQ expression.</param>
/// <param name="compiledParameters">
/// Lambda parameters that ConvertParameters rewrites into reads from ParametersParam, or null.
/// </param>
public ExpressionBuilder(
    Query query,
    IDataContextInfo dataContext,
    Expression expression,
    ParameterExpression[] compiledParameters)
{
    _query = query;
    _expressionAccessors = expression.GetExpressionAccessors(ExpressionParam);

    // These must be assigned before ConvertExpressionTree runs: ConvertParameters reads
    // CompiledParameters, and the SqlProvider property reads DataContextInfo.
    CompiledParameters = compiledParameters;
    DataContextInfo = dataContext;
    OriginalExpression = expression;

    // The cycle guard used by ExposeExpression exists only for the duration of the conversion.
    _visitedExpressions = new HashSet<Expression>();
    Expression = ConvertExpressionTree(expression);
    _visitedExpressions = null;
}
104
105 #endregion
106
107 #region Public Members
108
/// <summary>Data context information supplied to the constructor.</summary>
public readonly IDataContextInfo DataContextInfo;
/// <summary>The query expression exactly as supplied to the constructor.</summary>
public readonly Expression OriginalExpression;
/// <summary>The query expression after ConvertExpressionTree pre-processing.</summary>
public readonly Expression Expression;
/// <summary>Lambda parameters supplied to the constructor, or null.</summary>
public readonly ParameterExpression[] CompiledParameters;
/// <summary>List of build contexts (populated outside this part of the class).</summary>
public readonly List<IBuildContext> Contexts = new List<IBuildContext>();

private ISqlProvider _sqlProvider;

/// <summary>
/// SQL provider for the data context; created from DataContextInfo on first access and reused afterwards.
/// </summary>
public ISqlProvider SqlProvider
{
    get
    {
        if (_sqlProvider == null)
            _sqlProvider = DataContextInfo.CreateSqlProvider();

        return _sqlProvider;
    }
}

// Shared parameter expressions referenced by generated expression trees.
// ParametersParam receives the argument array read by ConvertParameters' rewrites;
// ExpressionParam is passed to GetExpressionAccessors in the constructor.
public static readonly ParameterExpression ContextParam = Expression.Parameter(typeof(QueryContext), "context");
public static readonly ParameterExpression DataContextParam = Expression.Parameter(typeof(IDataContext), "dctx");
public static readonly ParameterExpression DataReaderParam = Expression.Parameter(typeof(IDataReader), "rd");
public static readonly ParameterExpression ParametersParam = Expression.Parameter(typeof(object[]), "ps");
public static readonly ParameterExpression ExpressionParam = Expression.Parameter(typeof(Expression), "expr");

/// <summary>Mapping schema of the data context.</summary>
public MappingSchema MappingSchema
{
    get
    {
        var info = DataContextInfo;
        return info.MappingSchema;
    }
}
131
132 #endregion
133
134 #region Builder SQL
135
/// <summary>
/// Builds the root sequence from the converted <see cref="Expression"/>. If BuildSequence flagged
/// the builder usage counters as out of order, the shared builder list is re-sorted. The query is
/// then initialized with the collected SQL parameters, and the sequence completes the query.
/// </summary>
/// <typeparam name="T">Element type of the resulting query.</typeparam>
/// <returns>The query passed to the constructor, cast to Query&lt;T&gt;.</returns>
internal Query<T> Build<T>()
{
    var rootInfo = new BuildInfo((IBuildContext)null, Expression, new SqlQuery());
    var sequence = BuildSequence(rootInfo);

    if (_reorder)
    {
        lock (_sync)
        {
            _reorder = false;

            // Most used builders first, so they are probed earliest by later builds.
            var byUsage = _sequenceBuilders.OrderByDescending(b => b.BuildCounter);
            _sequenceBuilders = byUsage.ToList();
        }
    }

    _query.Init(sequence, CurrentSqlParameters);

    var infoParameter = Expression.Parameter(typeof(Query<T>), "info");
    var typedQuery = (Query<T>)_query;

    sequence.BuildQuery(typedQuery, infoParameter);

    return typedQuery;
}
155
/// <summary>
/// Builds a sequence context for <paramref name="buildInfo"/> with the first registered builder
/// whose CanBuild accepts it, and increments that builder's usage counter.
/// </summary>
/// <param name="buildInfo">Build information; its Expression is unwrapped in place.</param>
/// <returns>The sequence context produced by the selected builder.</returns>
/// <exception cref="LinqException">No registered builder can handle the expression.</exception>
/// <remarks>
/// If the selected builder's counter now exceeds the counter of the builder probed just before
/// it, the list is out of order and _reorder is set so that Build&lt;T&gt; re-sorts the shared list.
/// The first builder has no predecessor and therefore never requests a re-sort. Previously it was
/// compared with its own pre-increment counter, which forced an unnecessary sort under the global
/// lock whenever the first builder was used. Tracking the predecessor as a nullable value also
/// avoids indexing into an empty builder list.
/// </remarks>
[JetBrains.Annotations.NotNull]
public IBuildContext BuildSequence(BuildInfo buildInfo)
{
    buildInfo.Expression = buildInfo.Expression.Unwrap();

    // Usage counter of the previously probed builder; null while probing the first one.
    int? previousCounter = null;

    foreach (var builder in _builders)
    {
        if (builder.CanBuild(this, buildInfo))
        {
            var sequence = builder.BuildSequence(this, buildInfo);

            lock (builder)
                builder.BuildCounter++;

            // Lifted comparison: evaluates to false when there is no predecessor.
            _reorder = _reorder || previousCounter < builder.BuildCounter;

            return sequence;
        }

        previousCounter = builder.BuildCounter;
    }

    throw new LinqException("Sequence '{0}' cannot be converted to SQL.", buildInfo.Expression);
}
182
/// <summary>
/// Lets the first builder that can handle <paramref name="buildInfo"/> convert the sequence
/// with respect to <paramref name="param"/>.
/// </summary>
/// <param name="buildInfo">Build information; its Expression is unwrapped in place.</param>
/// <param name="param">Parameter representing the sequence element.</param>
/// <returns>The conversion result produced by the selected builder.</returns>
/// <exception cref="LinqException">No registered builder can handle the expression.</exception>
public SequenceConvertInfo ConvertSequence(BuildInfo buildInfo, ParameterExpression param)
{
    buildInfo.Expression = buildInfo.Expression.Unwrap();

    var builder = _builders.FirstOrDefault(b => b.CanBuild(this, buildInfo));

    if (builder == null)
        throw new LinqException("Sequence '{0}' cannot be converted to SQL.", buildInfo.Expression);

    return builder.Convert(this, buildInfo, param);
}
193
/// <summary>
/// Asks the first builder that can handle <paramref name="buildInfo"/> whether the expression
/// is a sequence.
/// </summary>
/// <param name="buildInfo">Build information; its Expression is unwrapped in place.</param>
/// <returns>The selected builder's answer, or false when no builder can handle the expression.</returns>
public bool IsSequence(BuildInfo buildInfo)
{
    buildInfo.Expression = buildInfo.Expression.Unwrap();

    var builder = _builders.FirstOrDefault(b => b.CanBuild(this, buildInfo));

    return builder != null && builder.IsSequence(this, buildInfo);
}
204
205 #endregion
206
207 #region ConvertExpression
208
// Sequence element parameter ("cp") created by ConvertExpressionTree and passed to ConvertSequence.
public ParameterExpression SequenceParameter;

/// <summary>
/// Pre-processes the query expression. It rewrites compiled-query parameters (ConvertParameters),
/// inlines member and constant indirections (ExposeExpression) and applies LINQ rewrites
/// (OptimizeExpression). It then lets the matching sequence builder convert the result.
/// </summary>
/// <param name="expression">The expression to convert.</param>
/// <returns>The converted expression, or the optimized expression if ConvertSequence returns null.</returns>
/// <exception cref="InvalidOperationException">
/// The converted sequence changed type and the original expression is not a queryable call
/// returning IQueryable&lt;T&gt;/IEnumerable&lt;T&gt;.
/// </exception>
Expression ConvertExpressionTree(Expression expression)
{
    var expr = ConvertParameters(expression);

    expr = ExposeExpression (expr);
    expr = OptimizeExpression(expr);

    var paramType = expr.Type;
    var isQueryable = false;

    // For a static call recognised by IsQueryable() whose type is IQueryable<T> or
    // IEnumerable<T>, the sequence parameter is typed as the element type T instead.
    // Note: this inspects the ORIGINAL expression, not the processed expr.
    if (expression.NodeType == ExpressionType.Call)
    {
        var call = (MethodCallExpression)expression;

        if (call.IsQueryable() && call.Object == null && call.Arguments.Count > 0 && call.Type.IsGenericType)
        {
            var type = call.Type.GetGenericTypeDefinition();

            if (type == typeof(IQueryable<>) || type == typeof(IEnumerable<>))
            {
                var arg = call.Type.GetGenericArguments();

                if (arg.Length == 1)
                {
                    paramType = arg[0];
                    isQueryable = true;
                }
            }
        }
    }

    SequenceParameter = Expression.Parameter(paramType, "cp");

    var sequence = ConvertSequence(new BuildInfo((IBuildContext)null, expr, new SqlQuery()), SequenceParameter);

    if (sequence != null)
    {
        if (sequence.Expression.Type != expr.Type)
        {
            if (isQueryable)
            {
                // Map converted elements back with Select(path => replacement), using the
                // ExpressionsToReplace entry whose path is a parameter.
                // NOTE(review): SingleOrDefault returns null when no such entry exists, which would
                // throw NullReferenceException below. The cast also assumes expr is still a method
                // call after ExposeExpression/OptimizeExpression.
                var p = sequence.ExpressionsToReplace.SingleOrDefault(s => s.Path.NodeType == ExpressionType.Parameter);

                return Expression.Call(
                    ((MethodCallExpression)expr).Method.DeclaringType,
                    "Select",
                    new[] { p.Path.Type, paramType },
                    sequence.Expression,
                    Expression.Lambda(p.Expr, (ParameterExpression)p.Path));
            }

            throw new InvalidOperationException();
        }

        return sequence.Expression;
    }

    return expr;
}
270
271 #region ConvertParameters
272
/// <summary>
/// Rewrites references to the lambda parameters in CompiledParameters into reads from the
/// ParametersParam object array: parameter CompiledParameters[i] with i &gt; 0 becomes
/// (T)ps[i], where T is the parameter's type. Other nodes are left unchanged.
/// </summary>
/// <param name="expression">Expression to rewrite.</param>
/// <returns>The rewritten expression.</returns>
/// <remarks>
/// Index 0 is deliberately not rewritten. NOTE(review): presumably that slot is the data-context
/// parameter, which is handled elsewhere — confirm.
/// The index found by the first Array.IndexOf is reused for the array access; previously a second,
/// identical linear scan was performed for every matching parameter node.
/// </remarks>
Expression ConvertParameters(Expression expression)
{
    return expression.Convert(expr =>
    {
        if (expr.NodeType == ExpressionType.Parameter && CompiledParameters != null)
        {
            var idx = Array.IndexOf(CompiledParameters, (ParameterExpression)expr);

            if (idx > 0)
                return
                    Expression.Convert(
                        Expression.ArrayIndex(
                            ParametersParam,
                            Expression.Constant(idx)),
                        expr.Type);
        }

        return expr;
    });
}
299
300 #endregion
301
302 #region ExposeExpression
303
/// <summary>
/// Inlines expressions hidden behind members and constants, recursively:
/// <list type="bullet">
/// <item>A member access whose member has an applicable MethodExpressionAttribute (see
/// ConvertMethodExpression) is replaced by the attribute lambda's body. Every parameter node in
/// that body is replaced by the member's target object. The result is wrapped in a
/// ChangeTypeExpression when its type differs from the member's type.</item>
/// <item>A constant holding an IQueryable that is not an ITable is replaced by that queryable's
/// own Expression, at most once per distinct expression (tracked in _visitedExpressions).</item>
/// </list>
/// </summary>
/// <param name="expression">Expression to expose.</param>
/// <returns>The exposed expression.</returns>
/// <remarks>
/// NOTE(review): _visitedExpressions is only non-null while the constructor runs. Reaching the
/// constant branch after construction would throw NullReferenceException — confirm there is no
/// such caller.
/// </remarks>
Expression ExposeExpression(Expression expression)
{
    return expression.Convert(expr =>
    {
        switch (expr.NodeType)
        {
            case ExpressionType.MemberAccess:
                {
                    var me = (MemberExpression)expr;
                    var l = ConvertMethodExpression(me.Member);

                    if (l != null)
                    {
                        var body = l.Body.Unwrap();
                        // Bind the lambda's parameter(s) to the object the member is accessed on.
                        var ex = body.Convert2(wpi => new ExpressionHelper.ConvertInfo(wpi.NodeType == ExpressionType.Parameter ? me.Expression : wpi));

                        if (ex.Type != expr.Type)
                            ex = new ChangeTypeExpression(ex, expr.Type);

                        return ExposeExpression(ex);
                    }

                    break;
                }

            case ExpressionType.Constant :
                {
                    var c = (ConstantExpression)expr;

                    // Fix Mono behaviour.
                    //
                    //if (c.Value is IExpressionQuery)
                    // return ((IQueryable)c.Value).Expression;

                    // Inline nested queryables, but leave tables as constants.
                    if (c.Value is IQueryable && !(c.Value is ITable))
                    {
                        var e = ((IQueryable)c.Value).Expression;

                        // Cycle guard: expose each queryable expression only once.
                        if (!_visitedExpressions.Contains(e))
                        {
                            _visitedExpressions.Add(e);
                            return ExposeExpression(e);
                        }
                    }

                    break;
                }
        }

        return expr;
    });
}
356
357 #endregion
358
359 #region OptimizeExpression
360
private MethodInfo[] _enumerableMethods;

/// <summary>
/// Public methods of <see cref="System.Linq.Enumerable"/>; fetched by reflection on first access
/// and cached for this builder.
/// </summary>
public MethodInfo[] EnumerableMethods
{
    get
    {
        if (_enumerableMethods == null)
            _enumerableMethods = typeof(Enumerable).GetMethods();

        return _enumerableMethods;
    }
}

private MethodInfo[] _queryableMethods;

/// <summary>
/// Public methods of <see cref="System.Linq.Queryable"/>; fetched by reflection on first access
/// and cached for this builder.
/// </summary>
public MethodInfo[] QueryableMethods
{
    get
    {
        if (_queryableMethods == null)
            _queryableMethods = typeof(Queryable).GetMethods();

        return _queryableMethods;
    }
}
372
// Memoized OptimizeExpression results, keyed by the source expression instance.
readonly Dictionary<Expression,Expression> _optimizedExpressions = new Dictionary<Expression,Expression>();

/// <summary>
/// Applies OptimizeExpressionImpl to every node of <paramref name="expression"/> via Convert and
/// caches the result, so repeated requests for the same expression return the cached rewrite.
/// </summary>
/// <param name="expression">Expression to optimize.</param>
/// <returns>The optimized expression.</returns>
Expression OptimizeExpression(Expression expression)
{
    Expression optimized;

    if (!_optimizedExpressions.TryGetValue(expression, out optimized))
    {
        optimized = expression.Convert(OptimizeExpressionImpl);
        _optimizedExpressions[expression] = optimized;
    }

    return optimized;
}
386
/// <summary>
/// Single-node rewrite applied by OptimizeExpression through Convert.
/// </summary>
/// <param name="expr">The node to rewrite.</param>
/// <returns>The rewritten node, or <paramref name="expr"/> when no rewrite applies.</returns>
/// <remarks>
/// Member access:
/// <list type="bullet">
/// <item>A collection <c>Count</c> property is replaced by Enumerable.Count(source).</item>
/// <item>Without compiled parameters, IQueryable-typed members are expanded with ConvertIQueriable
/// and, if that changed them, re-processed by ConvertExpressionTree.</item>
/// <item>Otherwise ConvertSubquery is applied.</item>
/// </list>
/// Method call:
/// <list type="bullet">
/// <item>Query operators recognised by IsQueryable() are normalised: Where, GroupBy, SelectMany and
/// Select; the predicate overloads of Count/LongCount/First/Single; the selector overloads of
/// Min/Max/Sum/Average; and ElementAt.</item>
/// <item>Other methods with a MethodExpressionAttribute are inlined and optimized again.</item>
/// <item>Without compiled parameters, IQueryable-returning methods that are not table functions
/// are expanded with ConvertIQueriable.</item>
/// <item>Finally ConvertSubquery is applied.</item>
/// </list>
/// Count detection now also covers generic collections. Count read through an IList&lt;T&gt;- or
/// ICollection&lt;T&gt;-typed member is declared on ICollection&lt;T&gt;, which neither implements
/// the non-generic ICollection nor lists IList&lt;&gt; among its interfaces. Such accesses, and
/// generic-only collections such as HashSet&lt;T&gt;, were previously not rewritten.
/// </remarks>
Expression OptimizeExpressionImpl(Expression expr)
{
    switch (expr.NodeType)
    {
        case ExpressionType.MemberAccess:
            {
                var me = (MemberExpression)expr;

                // Replace Count with Count()
                //
                if (me.Member.Name == "Count")
                {
                    var declaringType = me.Member.DeclaringType;

                    // Count declared on a type implementing the non-generic ICollection (e.g. List<T>).
                    var isList = typeof(ICollection).IsAssignableFrom(declaringType);

                    // Count declared on ICollection<T> itself, or on a type implementing it (every
                    // IList<T> implements ICollection<T>). GetInterfaces() does not return the type
                    // itself, so the declaring type is checked explicitly.
                    if (!isList)
                        isList = new[] { declaringType }
                            .Concat(declaringType.GetInterfaces())
                            .Any(t => t.IsGenericType && t.GetGenericTypeDefinition() == typeof(ICollection<>));

                    if (isList)
                    {
                        var mi = EnumerableMethods
                            .First(m => m.Name == "Count" && m.GetParameters().Length == 1)
                            .MakeGenericMethod(TypeHelper.GetElementType(me.Expression.Type));

                        return Expression.Call(null, mi, me.Expression);
                    }
                }

                if (CompiledParameters == null && TypeHelper.IsSameOrParent(typeof(IQueryable), expr.Type))
                {
                    var ex = ConvertIQueriable(expr);

                    if (ex != expr)
                        return ConvertExpressionTree(ex);
                }

                return ConvertSubquery(expr);
            }

        case ExpressionType.Call :
            {
                var call = (MethodCallExpression)expr;

                if (call.IsQueryable())
                {
                    switch (call.Method.Name)
                    {
                        case "Where" : return ConvertWhere (call);
                        case "GroupBy" : return ConvertGroupBy (call);
                        case "SelectMany" : return ConvertSelectMany(call);
                        case "Select" : return ConvertSelect (call);
                        case "LongCount" :
                        case "Count" :
                        case "Single" :
                        case "SingleOrDefault" :
                        case "First" :
                        case "FirstOrDefault" : return ConvertPredicate (call);
                        case "Min" :
                        case "Max" : return ConvertSelector (call, true);
                        case "Sum" :
                        case "Average" : return ConvertSelector (call, false);
                        case "ElementAt" :
                        case "ElementAtOrDefault" : return ConvertElementAt (call);
                    }
                }
                else
                {
                    var l = ConvertMethodExpression(call.Method);

                    if (l != null)
                        return OptimizeExpression(ConvertMethod(call, l));

                    if (CompiledParameters == null && TypeHelper.IsSameOrParent(typeof(IQueryable), expr.Type))
                    {
                        var attr = GetTableFunctionAttribute(call.Method);

                        // Table functions stay as calls; other IQueryable-returning methods are expanded.
                        if (attr == null)
                        {
                            var ex = ConvertIQueriable(expr);

                            if (ex != expr)
                                return ConvertExpressionTree(ex);
                        }
                    }
                }

                return ConvertSubquery(expr);
            }
    }

    return expr;
}
479
/// <summary>
/// Returns the replacement lambda declared for <paramref name="mi"/> through
/// MethodExpressionAttribute, or null when the member has no applicable attribute.
/// </summary>
/// <param name="mi">Method or member to look up.</param>
/// <returns>The lambda produced by the attribute's factory method, or null.</returns>
/// <remarks>
/// Attribute selection: the first attribute whose SqlProvider equals the current provider's Name
/// wins. Otherwise the last attribute with a null SqlProvider is used.
/// The lambda is obtained by compiling and invoking a call to the static method named by
/// attr.MethodName on mi.DeclaringType. For a generic method, MethodName is first used as a
/// format string with the generic argument type names. If that changes the name, the formatted
/// name is called without type arguments; otherwise the generic arguments are supplied as type
/// arguments. A new delegate is compiled on every call (no caching here).
/// </remarks>
LambdaExpression ConvertMethodExpression(MemberInfo mi)
{
    var attrs = mi.GetCustomAttributes(typeof(MethodExpressionAttribute), true);

    if (attrs.Length == 0)
        return null;

    MethodExpressionAttribute attr = null;

    foreach (MethodExpressionAttribute a in attrs)
    {
        // A provider-specific attribute takes precedence.
        if (a.SqlProvider == SqlProvider.Name)
        {
            attr = a;
            break;
        }

        // Provider-neutral fallback (the last one seen wins).
        if (a.SqlProvider == null)
            attr = a;
    }

    if (attr != null)
    {
        Expression expr;

        if (mi is MethodInfo && ((MethodInfo)mi).IsGenericMethod)
        {
            var method = (MethodInfo)mi;
            var args = method.GetGenericArguments();
            var names = args.Select(t => t.Name).ToArray();
            var name = string.Format(attr.MethodName, names);

            // A name that contained format placeholders already encodes the type arguments.
            if (name != attr.MethodName)
                expr = Expression.Call(mi.DeclaringType, name, Array<Type>.Empty);
            else
                expr = Expression.Call(mi.DeclaringType, name, args);
        }
        else
        {
            expr = Expression.Call(mi.DeclaringType, attr.MethodName, Array<Type>.Empty);
        }

        var call = Expression.Lambda<Func<LambdaExpression>>(Expression.Convert(expr, typeof(LambdaExpression)));

        return call.Compile()();
    }

    return null;
}
529
/// <summary>
/// Walks down the member-access / instance-call chain of <paramref name="expr"/> (for example
/// <c>x.First().Name</c>) to its root. If the root is a static Single, SingleOrDefault, First or
/// FirstOrDefault call recognised by IsQueryable(), the whole expression is rewritten by
/// ConvertSingleOrFirst. Otherwise <paramref name="expr"/> is returned unchanged.
/// </summary>
/// <param name="expr">Expression to inspect.</param>
/// <returns>The rewritten or original expression.</returns>
Expression ConvertSubquery(Expression expr)
{
    var current = expr;

    while (current != null)
    {
        if (current.NodeType == ExpressionType.MemberAccess)
        {
            current = ((MemberExpression)current).Expression;
            continue;
        }

        if (current.NodeType != ExpressionType.Call)
            return expr;

        var call = (MethodCallExpression)current;

        if (call.Object != null)
        {
            current = call.Object;
            continue;
        }

        // Static call: this is the root of the chain.
        if (call.IsQueryable())
        {
            var name = call.Method.Name;

            if (name == "Single" || name == "SingleOrDefault" || name == "First" || name == "FirstOrDefault")
                return ConvertSingleOrFirst(expr, call);
        }

        return expr;
    }

    return expr;
}
567
/// <summary>
/// Moves a member chain applied to the result of a Single/First-style call into a projection.
/// <c>source.First().Prop</c> becomes <c>source.Select(p =&gt; p.Prop).First()</c>.
/// </summary>
/// <param name="expr">The whole expression; it contains <paramref name="call"/>.</param>
/// <param name="call">The Single/SingleOrDefault/First/FirstOrDefault call at the root of <paramref name="expr"/>.</param>
/// <returns>The rewritten expression.</returns>
Expression ConvertSingleOrFirst(Expression expr, MethodCallExpression call)
{
    var param = Expression.Parameter(call.Type, "p");
    // The projection body: expr with the call replaced by the element parameter.
    var selector = expr.Convert(e => e == call ? param : e);
    // The overload of the same operator that takes only the source.
    var method = GetQueriableMethodInfo(call, (m, _) => m.Name == call.Method.Name && m.GetParameters().Length == 1);
    // Select(source, selector) without the index parameter: the selector delegate has 2 generic
    // arguments (TSource, TResult). For Queryable it is wrapped in Expression<...>.
    var select = call.Method.DeclaringType == typeof(Enumerable) ?
        EnumerableMethods
            .Where(m => m.Name == "Select" && m.GetParameters().Length == 2)
            .First(m => m.GetParameters()[1].ParameterType.GetGenericArguments().Length == 2) :
        QueryableMethods
            .Where(m => m.Name == "Select" && m.GetParameters().Length == 2)
            .First(m => m.GetParameters()[1].ParameterType.GetGenericArguments()[0].GetGenericArguments().Length == 2);

    // Optimizing turns e.g. First(source, predicate) into First(Where(source, predicate)), so that
    // Arguments[0] below is the full source. NOTE(review): the cast assumes the result is still a call.
    call = (MethodCallExpression)OptimizeExpression(call);
    select = select.MakeGenericMethod(call.Type, expr.Type);
    method = method.MakeGenericMethod(expr.Type);

    return Expression.Call(null, method,
        Expression.Call(null, select,
            call.Arguments[0],
            Expression.Lambda(selector, param)));
}
590
591 #endregion
592
593 #region ConvertWhere
594
/// <summary>
/// Normalises Where(sequence, x =&gt; body), optimizing the sequence and the predicate.
/// </summary>
/// <param name="method">The Where call.</param>
/// <returns>The rewritten call, or <paramref name="method"/> when nothing changed.</returns>
/// <remarks>
/// Some calls inside the predicate are hoisted into a preceding Select:
/// <list type="bullet">
/// <item>aggregation calls (AggregationBuilder.MethodNames) whose source, after skipping Select
/// calls, is a method call;</item>
/// <item>count calls (CountBuilder.MethodNames) whose source is a method call.</item>
/// </list>
/// Each hoisted call is stored in a nested ExpressionHoder object (fields "p" = previous value,
/// "ex" = hoisted expression), and the predicate reads it back through the holder. A final Select
/// unwraps the holders back to the original element. The holder-field access expressions are
/// recorded in _subQueryExpressions.
/// The indexed overload (predicate with two parameters) is returned unchanged.
/// </remarks>
Expression ConvertWhere(MethodCallExpression method)
{
    var sequence = OptimizeExpression(method.Arguments[0]);
    var predicate = OptimizeExpression(method.Arguments[1]);
    var lambda = (LambdaExpression)predicate.Unwrap();
    var lparam = lambda.Parameters[0];
    var lbody = lambda.Body;

    if (lambda.Parameters.Count > 1)
        return method;

    // Collect the calls to hoist out of the predicate.
    var exprs = new List<Expression>();

    lbody.Visit(ex =>
    {
        if (ex.NodeType == ExpressionType.Call)
        {
            var call = (MethodCallExpression)ex;

            if (call.Arguments.Count > 0)
            {
                var arg = call.Arguments[0];

                if (call.IsQueryable(AggregationBuilder.MethodNames))
                {
                    while (arg.NodeType == ExpressionType.Call && ((MethodCallExpression) arg).Method.Name == "Select")
                        arg = ((MethodCallExpression) arg).Arguments[0];

                    if (arg.NodeType == ExpressionType.Call)
                        exprs.Add(ex);
                }
                else if (call.IsQueryable(CountBuilder.MethodNames))
                {
                    //while (arg.NodeType == ExpressionType.Call && ((MethodCallExpression) arg).Method.Name == "Select")
                    // arg = ((MethodCallExpression) arg).Arguments[0];

                    if (arg.NodeType == ExpressionType.Call)
                        exprs.Add(ex);
                }
            }
        }
    });

    Expression expr = null;

    if (exprs.Count > 0)
    {
        // Build new Holder<...Holder<T, e0>..., eN>{ p = <previous>, ex = eN }, nesting one level per hoisted call.
        expr = lparam;

        foreach (var ex in exprs)
        {
            var type = typeof(ExpressionHoder<,>).MakeGenericType(expr.Type, ex.Type);
            // NOTE(review): relies on GetFields() returning (p, ex) in declaration order; reflection
            // does not guarantee field order — consider looking the fields up by name.
            var fields = type.GetFields();

            expr = Expression.MemberInit(
                Expression.New(type),
                Expression.Bind(fields[0], expr),
                Expression.Bind(fields[1], ex));
        }

        var dic = new Dictionary<Expression, Expression>();
        var parm = Expression.Parameter(expr.Type, lparam.Name);

        // The i-th hoisted call lives at parm(.p x (Count-1-i)).ex.
        for (var i = 0; i < exprs.Count; i++)
        {
            Expression ex = parm;

            for (var j = i; j < exprs.Count - 1; j++)
                ex = Expression.PropertyOrField(ex, "p");

            ex = Expression.PropertyOrField(ex, "ex");

            dic.Add(exprs[i], ex);

            if (_subQueryExpressions == null)
                _subQueryExpressions = new HashSet<Expression>();
            _subQueryExpressions.Add(ex);
        }

        // Replace hoisted calls in the predicate body by their holder-field accesses...
        var newBody = lbody.Convert(ex =>
        {
            Expression e;
            return dic.TryGetValue(ex, out e) ? e : ex;
        });

        // ...and the original element parameter by parm.p.p...p (one .p per hoisted call).
        var nparm = exprs.Aggregate<Expression,Expression>(parm, (c,t) => Expression.PropertyOrField(c, "p"));

        newBody = newBody.Convert(ex => ex == lparam ? nparm : ex);

        predicate = Expression.Lambda(newBody, parm);

        var methodInfo = GetMethodInfo(method, "Select");

        methodInfo = methodInfo.MakeGenericMethod(lparam.Type, expr.Type);
        sequence = Expression.Call(methodInfo, sequence, Expression.Lambda(expr, lparam));
    }

    if (sequence != method.Arguments[0] || predicate != method.Arguments[1])
    {
        // Rebuild Where for the (possibly new) element type of the sequence.
        var methodInfo = method.Method.GetGenericMethodDefinition();
        var genericType = sequence.Type.GetGenericArguments()[0];
        var newMethod = methodInfo.MakeGenericMethod(genericType);

        method = Expression.Call(newMethod, sequence, predicate);

        if (exprs.Count > 0)
        {
            // Unwrap the holders back to the original element: Select(h => h.p.p...p).
            var parameter = Expression.Parameter(expr.Type, lparam.Name);

            methodInfo = GetMethodInfo(method, "Select");
            methodInfo = methodInfo.MakeGenericMethod(expr.Type, lparam.Type);
            method = Expression.Call(methodInfo, method,
                Expression.Lambda(
                    exprs.Aggregate((Expression)parameter, (current,_) => Expression.PropertyOrField(current, "p")),
                    parameter));
        }
    }

    return method;
}
715
716 #endregion
717
718 #region ConvertGroupBy
719
/// <summary>
/// Row shape used by GroupByHelper's "wrap in subquery" templates. The computed group key (Key)
/// is projected next to the source element (Element), and the rows are then grouped by Key.
/// </summary>
/// <typeparam name="TKey">Group key type.</typeparam>
/// <typeparam name="TElement">Source element type.</typeparam>
public class GroupSubQuery<TKey,TElement>
{
    public TKey Key;
    public TElement Element;
}
725
// Non-generic facade over GroupByHelper<,,,>, so ConvertGroupBy can drive a helper created for
// runtime-determined type arguments. Methods with the suffix Q build Queryable calls; methods with
// the suffix E build Enumerable calls.
interface IGroupByHelper
{
    // Stores the pieces of the original GroupBy call that the templates are filled with.
    void Set(bool wrapInSubQuery, Expression sourceExpression, LambdaExpression keySelector, LambdaExpression elementSelector, LambdaExpression resultSelector);

    Expression AddElementSelectorQ ();
    Expression AddElementSelectorE ();
    Expression AddResultQ ();
    Expression AddResultE ();
    Expression WrapInSubQueryQ ();
    Expression WrapInSubQueryE ();
    Expression WrapInSubQueryResultQ();
    Expression WrapInSubQueryResultE();
}
739
// Rewrites GroupBy calls through expression "templates". Each method writes the target shape as
// a lambda (source, key, e, r) => ... The placeholder parameters are then substituted by Convert:
// source -> the source expression; key -> the key selector body; e -> the element selector body
// (or the element itself); r -> the result selector body. GetLambda locates the template lambdas
// whose parameters the selector bodies are rebound to; the index paths are tied to the exact
// template shapes.
class GroupByHelper<TSource,TKey,TElement,TResult> : IGroupByHelper
{
    bool _wrapInSubQuery;
    Expression _sourceExpression;
    LambdaExpression _keySelector;
    LambdaExpression _elementSelector;
    LambdaExpression _resultSelector;

    public void Set(
        bool wrapInSubQuery,
        Expression sourceExpression,
        LambdaExpression keySelector,
        LambdaExpression elementSelector,
        LambdaExpression resultSelector)
    {
        _wrapInSubQuery = wrapInSubQuery;
        _sourceExpression = sourceExpression;
        _keySelector = keySelector;
        _elementSelector = elementSelector;
        _resultSelector = resultSelector;
    }

    // source.GroupBy(key) => source.GroupBy(key, x => x): adds an identity element selector (Queryable).
    public Expression AddElementSelectorQ()
    {
        Expression<Func<IQueryable<TSource>,TKey,TElement,TResult,IQueryable<IGrouping<TKey,TSource>>>> func = (source,key,e,r) => source
            .GroupBy(keyParam => key, _ => _)
            ;

        var body = func.Body.Unwrap();
        var keyArg = GetLambda(body, 1).Parameters[0]; // .GroupBy(keyParam

        return Convert(func, keyArg, null, null);
    }

    // Enumerable counterpart of AddElementSelectorQ.
    public Expression AddElementSelectorE()
    {
        Expression<Func<IEnumerable<TSource>,TKey,TElement,TResult,IEnumerable<IGrouping<TKey,TSource>>>> func = (source,key,e,r) => source
            .GroupBy(keyParam => key, _ => _)
            ;

        var body = func.Body.Unwrap();
        var keyArg = GetLambda(body, 1).Parameters[0]; // .GroupBy(keyParam

        return Convert(func, keyArg, null, null);
    }

    // GroupBy with a result selector => GroupBy(key, elem).Select(g => result(g.Key, g)) (Queryable).
    public Expression AddResultQ()
    {
        Expression<Func<IQueryable<TSource>,TKey,TElement,TResult,IQueryable<TResult>>> func = (source,key,e,r) => source
            .GroupBy(keyParam => key, elemParam => e)
            .Select (resParam => r)
            ;

        var body = func.Body.Unwrap();
        var keyArg = GetLambda(body, 0, 1).Parameters[0]; // .GroupBy(keyParam
        var elemArg = GetLambda(body, 0, 2).Parameters[0]; // .GroupBy(..., elemParam
        var resArg = GetLambda(body, 1). Parameters[0]; // .Select (resParam

        return Convert(func, keyArg, elemArg, resArg);
    }

    // Enumerable counterpart of AddResultQ.
    public Expression AddResultE()
    {
        Expression<Func<IEnumerable<TSource>,TKey,TElement,TResult,IEnumerable<TResult>>> func = (source,key,e,r) => source
            .GroupBy(keyParam => key, elemParam => e)
            .Select (resParam => r)
            ;

        var body = func.Body.Unwrap();
        var keyArg = GetLambda(body, 0, 1).Parameters[0]; // .GroupBy(keyParam
        var elemArg = GetLambda(body, 0, 2).Parameters[0]; // .GroupBy(..., elemParam
        var resArg = GetLambda(body, 1). Parameters[0]; // .Select (resParam

        return Convert(func, keyArg, elemArg, resArg);
    }

    // Computes the key in a projection first, then groups by it:
    // source.Select(x => new GroupSubQuery { Key = key(x), Element = x }).GroupBy(_ => _.Key, s => elem(s.Element)) (Queryable).
    public Expression WrapInSubQueryQ()
    {
        Expression<Func<IQueryable<TSource>,TKey,TElement,TResult,IQueryable<IGrouping<TKey,TElement>>>> func = (source,key,e,r) => source
            .Select(selectParam => new GroupSubQuery<TKey,TSource>
            {
                Key = key,
                Element = selectParam
            })
            .GroupBy(_ => _.Key, elemParam => e)
            ;

        var body = func.Body.Unwrap();
        var keyArg = GetLambda(body, 0, 1).Parameters[0]; // .Select (selectParam
        var elemArg = GetLambda(body, 2). Parameters[0]; // .GroupBy(..., elemParam

        return Convert(func, keyArg, elemArg, null);
    }

    // Enumerable counterpart of WrapInSubQueryQ.
    public Expression WrapInSubQueryE()
    {
        Expression<Func<IEnumerable<TSource>,TKey,TElement,TResult,IEnumerable<IGrouping<TKey,TElement>>>> func = (source,key,e,r) => source
            .Select(selectParam => new GroupSubQuery<TKey,TSource>
            {
                Key = key,
                Element = selectParam
            })
            .GroupBy(_ => _.Key, elemParam => e)
            ;

        var body = func.Body.Unwrap();
        var keyArg = GetLambda(body, 0, 1).Parameters[0]; // .Select (selectParam
        var elemArg = GetLambda(body, 2). Parameters[0]; // .GroupBy(..., elemParam

        return Convert(func, keyArg, elemArg, null);
    }

    // WrapInSubQueryQ followed by the result-selector projection (Queryable).
    public Expression WrapInSubQueryResultQ()
    {
        Expression<Func<IQueryable<TSource>,TKey,TElement,TResult,IQueryable<TResult>>> func = (source,key,e,r) => source
            .Select(selectParam => new GroupSubQuery<TKey,TSource>
            {
                Key = key,
                Element = selectParam
            })
            .GroupBy(_ => _.Key, elemParam => e)
            .Select (resParam => r)
            ;

        var body = func.Body.Unwrap();
        var keyArg = GetLambda(body, 0, 0, 1).Parameters[0]; // .Select (selectParam
        var elemArg = GetLambda(body, 0, 2). Parameters[0]; // .GroupBy(..., elemParam
        var resArg = GetLambda(body, 1). Parameters[0]; // .Select (resParam

        return Convert(func, keyArg, elemArg, resArg);
    }

    // Enumerable counterpart of WrapInSubQueryResultQ.
    public Expression WrapInSubQueryResultE()
    {
        Expression<Func<IEnumerable<TSource>,TKey,TElement,TResult,IEnumerable<TResult>>> func = (source,key,e,r) => source
            .Select(selectParam => new GroupSubQuery<TKey,TSource>
            {
                Key = key,
                Element = selectParam
            })
            .GroupBy(_ => _.Key, elemParam => e)
            .Select (resParam => r)
            ;

        var body = func.Body.Unwrap();
        var keyArg = GetLambda(body, 0, 0, 1).Parameters[0]; // .Select (selectParam
        var elemArg = GetLambda(body, 0, 2). Parameters[0]; // .GroupBy(..., elemParam
        var resArg = GetLambda(body, 1). Parameters[0]; // .Select (resParam

        return Convert(func, keyArg, elemArg, resArg);
    }

    // Substitutes the template placeholders (func.Parameters[0..3] = source, key, e, r) with the
    // stored expressions. Selector parameters are rebound to the given template lambda parameters.
    Expression Convert(
        LambdaExpression func,
        ParameterExpression keyArg,
        ParameterExpression elemArg,
        ParameterExpression resArg)
    {
        var body = func.Body.Unwrap();
        var expr = body.Convert(ex =>
        {
            if (ex == func.Parameters[0])
                return _sourceExpression;

            if (ex == func.Parameters[1])
                return _keySelector.Body.Convert(e => e == _keySelector.Parameters[0] ? keyArg : e);

            if (ex == func.Parameters[2])
            {
                Expression obj = elemArg;

                // In the wrapped form the element parameter is a GroupSubQuery row.
                if (_wrapInSubQuery)
                    obj = Expression.PropertyOrField(elemArg, "Element");

                if (_elementSelector == null)
                    return obj;

                return _elementSelector.Body.Convert(e => e == _elementSelector.Parameters[0] ? obj : e);
            }

            // Result selector (key, elements) => ...: key -> g.Key, elements -> g.
            if (ex == func.Parameters[3])
                return _resultSelector.Body.Convert(e =>
                {
                    if (e == _resultSelector.Parameters[0])
                        return Expression.PropertyOrField(resArg, "Key");

                    if (e == _resultSelector.Parameters[1])
                        return resArg;

                    return e;
                });

            return ex;
        });

        return expr;
    }
}
938
/// <summary>
/// Follows a path of method-call arguments. For each index in <paramref name="n"/>, it takes that
/// argument of the current call expression and unwraps it. The final expression is returned as a
/// lambda.
/// </summary>
/// <param name="expression">Starting method-call expression.</param>
/// <param name="n">Argument indexes to follow, outermost first.</param>
/// <returns>The lambda found at the end of the path.</returns>
static LambdaExpression GetLambda(Expression expression, params int[] n)
{
    var current = expression;

    for (var k = 0; k < n.Length; k++)
    {
        var call = (MethodCallExpression)current;
        current = call.Arguments[n[k]].Unwrap();
    }

    return (LambdaExpression)current;
}
945
/// <summary>
/// Normalises a GroupBy call with GroupByHelper templates.
/// <list type="bullet">
/// <item>When the last argument is not a lambda (e.g. an equality-comparer overload), the call is returned unchanged.</item>
/// <item>If the key selector contains a computed expression (IsExpression), the key is computed in
/// a preceding projection (the "wrap in subquery" templates).</item>
/// <item>A result selector is turned into a following Select.</item>
/// <item>GroupBy(key) without an element selector gets an identity element selector.</item>
/// <item>GroupBy(key, elem) that needs neither wrapping nor a result selector is returned unchanged.</item>
/// </list>
/// </summary>
/// <param name="method">The GroupBy call.</param>
/// <returns>The rewritten expression, or <paramref name="method"/>.</returns>
Expression ConvertGroupBy(MethodCallExpression method)
{
    if (method.Arguments[method.Arguments.Count - 1].Unwrap().NodeType != ExpressionType.Lambda)
        return method;

    // Map generic parameter names (TSource, TKey, TElement, TResult) to the actual type arguments;
    // the presence of TElement/TResult identifies the overload.
    var types = method.Method.GetGenericMethodDefinition().GetGenericArguments()
        .Zip(method.Method.GetGenericArguments(), (n, t) => new { n = n.Name, t })
        .ToDictionary(_ => _.n, _ => _.t);

    var sourceExpression = OptimizeExpression(method.Arguments[0].Unwrap());
    var keySelector = (LambdaExpression)OptimizeExpression(method.Arguments[1].Unwrap());
    var elementSelector = types.ContainsKey("TElement") ? (LambdaExpression)OptimizeExpression(method.Arguments[2].Unwrap()) : null;
    var resultSelector = types.ContainsKey("TResult") ?
        (LambdaExpression)OptimizeExpression(method.Arguments[types.ContainsKey("TElement") ? 3 : 2].Unwrap()) : null;

    var needSubQuery = null != ConvertExpression(keySelector.Body.Unwrap()).Find(IsExpression);

    if (!needSubQuery && resultSelector == null && elementSelector != null)
        return method;

    // Missing TElement/TResult default to TSource.
    var gtype = typeof(GroupByHelper<,,,>).MakeGenericType(
        types["TSource"],
        types["TKey"],
        types.ContainsKey("TElement") ? types["TElement"] : types["TSource"],
        types.ContainsKey("TResult") ? types["TResult"] : types["TSource"]);

    var helper =
        //Expression.Lambda<Func<IGroupByHelper>>(
        // Expression.Convert(Expression.New(gtype), typeof(IGroupByHelper)))
        //.Compile()();
        (IGroupByHelper)Activator.CreateInstance(gtype);

    helper.Set(needSubQuery, sourceExpression, keySelector, elementSelector, resultSelector);

    if (method.Method.DeclaringType == typeof(Queryable))
    {
        if (!needSubQuery)
            return resultSelector == null ? helper.AddElementSelectorQ() : helper.AddResultQ();

        return resultSelector == null ? helper.WrapInSubQueryQ() : helper.WrapInSubQueryResultQ();
    }
    else
    {
        if (!needSubQuery)
            return resultSelector == null ? helper.AddElementSelectorE() : helper.AddResultE();

        return resultSelector == null ? helper.WrapInSubQueryE() : helper.WrapInSubQueryResultE();
    }
}
995
/// <summary>
/// Predicate passed to Find by ConvertGroupBy to detect a computed key selector.
/// Conversions, object/array construction and parameters are not considered computed.
/// A member access is computed only when its member has a function attribute
/// (GetFunctionAttribute). Every other node type is considered computed.
/// </summary>
/// <param name="ex">Node to classify.</param>
/// <returns>True when the node is a computed expression.</returns>
bool IsExpression(Expression ex)
{
    var nodeType = ex.NodeType;

    if (nodeType == ExpressionType.MemberAccess)
        return GetFunctionAttribute(((MemberExpression)ex).Member) != null;

    return
        nodeType != ExpressionType.Convert &&
        nodeType != ExpressionType.ConvertChecked &&
        nodeType != ExpressionType.MemberInit &&
        nodeType != ExpressionType.New &&
        nodeType != ExpressionType.NewArrayBounds &&
        nodeType != ExpressionType.NewArrayInit &&
        nodeType != ExpressionType.Parameter;
}
1021
1022 #endregion
1023
1024 #region ConvertSelectMany
1025
// Non-generic facade over SelectManyHelper<,>, so ConvertSelectMany can drive a helper created
// for runtime-determined type arguments. The suffix Q builds a Queryable call; the suffix E builds
// an Enumerable call.
interface ISelectManyHelper
{
    // Stores the source expression and the collection selector of the original call.
    void Set(Expression sourceExpression, LambdaExpression colSelector);

    Expression AddElementSelectorQ();
    Expression AddElementSelectorE();
}
1033
// Rewrites SelectMany(source, x => col) into SelectMany(source, x => col, (s, c) => c) through an
// expression template. The placeholders source/col are substituted by the stored source expression
// and the collection selector body, which is rebound to the template lambda's parameter.
class SelectManyHelper<TSource,TCollection> : ISelectManyHelper
{
    Expression _sourceExpression;
    LambdaExpression _colSelector;

    public void Set(Expression sourceExpression, LambdaExpression colSelector)
    {
        _sourceExpression = sourceExpression;
        _colSelector = colSelector;
    }

    // Queryable form of the template.
    public Expression AddElementSelectorQ()
    {
        Expression<Func<IQueryable<TSource>,IEnumerable<TCollection>,IQueryable<TCollection>>> func = (source,col) => source
            .SelectMany(colParam => col, (s,c) => c)
            ;

        var body = func.Body.Unwrap();
        var colArg = GetLambda(body, 1).Parameters[0]; // .SelectMany(colParam

        return Convert(func, colArg);
    }

    // Enumerable form of the template.
    public Expression AddElementSelectorE()
    {
        Expression<Func<IEnumerable<TSource>,IEnumerable<TCollection>,IEnumerable<TCollection>>> func = (source,col) => source
            .SelectMany(colParam => col, (s,c) => c)
            ;

        var body = func.Body.Unwrap();
        var colArg = GetLambda(body, 1).Parameters[0]; // .SelectMany(colParam

        return Convert(func, colArg);
    }

    // Substitutes func.Parameters[0] (source) and func.Parameters[1] (col) in the template body.
    Expression Convert(LambdaExpression func, ParameterExpression colArg)
    {
        var body = func.Body.Unwrap();
        var expr = body.Convert(ex =>
        {
            if (ex == func.Parameters[0])
                return _sourceExpression;

            if (ex == func.Parameters[1])
                return _colSelector.Body.Convert(e => e == _colSelector.Parameters[0] ? colArg : e);

            return ex;
        });

        return expr;
    }
}
1086
/// <summary>
/// Rewrites the two-argument SelectMany(source, x =&gt; collection) into the three-argument form
/// with an explicit result selector (s, c) =&gt; c, using SelectManyHelper.
/// </summary>
/// <param name="method">The SelectMany call.</param>
/// <returns>The rewritten expression, or <paramref name="method"/> when the shape does not apply.</returns>
/// <remarks>
/// Other shapes are returned unchanged: overloads with a result selector, the indexed
/// (x, i) =&gt; ... overload, and calls whose collection selector is not an inline lambda (for
/// example a delegate variable passed to Enumerable.SelectMany). The last case previously threw
/// InvalidCastException. The guard mirrors the lambda check in ConvertGroupBy.
/// </remarks>
Expression ConvertSelectMany(MethodCallExpression method)
{
    if (method.Arguments.Count != 2)
        return method;

    var selector = method.Arguments[1].Unwrap() as LambdaExpression;

    if (selector == null || selector.Parameters.Count != 1)
        return method;

    // Map generic parameter names (TSource, TResult) to the actual type arguments.
    var types = method.Method.GetGenericMethodDefinition().GetGenericArguments()
        .Zip(method.Method.GetGenericArguments(), (n, t) => new { n = n.Name, t })
        .ToDictionary(_ => _.n, _ => _.t);

    var sourceExpression = OptimizeExpression(method.Arguments[0].Unwrap());
    var colSelector = (LambdaExpression)OptimizeExpression(method.Arguments[1].Unwrap());

    var gtype = typeof(SelectManyHelper<,>).MakeGenericType(types["TSource"], types["TResult"]);
    var helper = (ISelectManyHelper)Activator.CreateInstance(gtype);

    helper.Set(sourceExpression, colSelector);

    return method.Method.DeclaringType == typeof(Queryable) ?
        helper.AddElementSelectorQ() :
        helper.AddElementSelectorE();
}
1112
1113 #endregion
1114
1115 #region ConvertPredicate
1116
// Rewrites Op(source, predicate) as Op(Where(source, predicate)), where Op is the one-parameter
// overload with the same name from the same family (Enumerable or Queryable).
// Calls with any argument count other than two are returned unchanged.
Expression ConvertPredicate(MethodCallExpression method)
{
    if (method.Arguments.Count != 2)
        return method;

    var operatorName = method.Method.Name;

    var operatorDefinition = GetQueriableMethodInfo(method, (m,_) => m.Name == operatorName && m.GetParameters().Length == 1);
    var whereDefinition    = GetMethodInfo(method, "Where");

    var elementType = method.Method.GetGenericArguments()[0];

    var whereMethod    = whereDefinition.MakeGenericMethod(elementType);
    var operatorMethod = operatorDefinition.MakeGenericMethod(elementType);

    var filtered = Expression.Call(null, whereMethod,
        OptimizeExpression(method.Arguments[0]),
        OptimizeExpression(method.Arguments[1]));

    return Expression.Call(null, operatorMethod, filtered);
}
1135
1136 #endregion
1137
1138 #region ConvertSelector
1139
// Rewrites Op(source, selector) (e.g. an aggregate with a selector) as Op(Select(source, selector)),
// using the one-parameter overload of Op from the same family.
// Overload choice for Op:
//   * If isGeneric is set and the call is a Queryable one, the first one-parameter overload with
//     the same name is used.
//   * Otherwise, the overload whose element type equals the selector's result type is used.
//   * If none matches, a second pass (isDefault == true) also accepts an overload whose element
//     type is a generic parameter.
Expression ConvertSelector(MethodCallExpression method, bool isGeneric)
{
    if (method.Arguments.Count != 2)
        return method;

    var useFirstOverload = isGeneric && method.Method.DeclaringType == typeof(Queryable);

    var types        = GetMethodGenericTypes(method);
    var resultType   = types[1];
    var operatorName = method.Method.Name;

    var selectMethod = GetMethodInfo(method, "Select");
    var operatorMethod = GetQueriableMethodInfo(method, (m, isDefault) =>
    {
        if (m.Name != operatorName)
            return false;

        var ps = m.GetParameters();

        if (ps.Length != 1)
            return false;

        if (useFirstOverload)
            return true;

        // Element type of the single sequence parameter.
        var elementType = ps[0].ParameterType.GetGenericArguments()[0];

        return elementType == resultType || isDefault && elementType.IsGenericParameter;
    });

    selectMethod = selectMethod.MakeGenericMethod(types[0], resultType);

    if (operatorMethod.IsGenericMethodDefinition)
        operatorMethod = operatorMethod.MakeGenericMethod(resultType);

    var projection = Expression.Call(null, selectMethod, method.Arguments[0], method.Arguments[1]);

    return Expression.Call(null, operatorMethod, OptimizeExpression(projection));
}
1180
1181 #endregion
1182
1183 #region ConvertSelect
1184
// Fuses two stacked projections:
//   Select(Select(src, x => x.Member), y => f(y))  ->  Select(src, x => f(x.Member))
// This applies only when the inner selector's (unwrapped) body is a member access.
// The original call is returned unchanged when:
//   * either selector has more than one parameter, or is not an inline lambda;
//   * the source is not a call to a method with the same name.
Expression ConvertSelect(MethodCallExpression method)
{
    var sequence = OptimizeExpression(method.Arguments[0]);

    // A delegate-typed variable or constant selector cannot be inlined. The previous hard cast
    // threw InvalidCastException here instead of leaving the call alone.
    var lambda1 = method.Arguments[1].Unwrap() as LambdaExpression;

    if (lambda1 == null)
        return method;

    var lambda = (LambdaExpression)OptimizeExpression(lambda1);

    if (lambda1.Parameters.Count > 1 ||
        sequence.NodeType != ExpressionType.Call ||
        ((MethodCallExpression)sequence).Method.Name != method.Method.Name)
    {
        return method;
    }

    var innerCall = (MethodCallExpression)sequence;

    // Same reasoning as above for the inner Select's selector.
    var slambda = innerCall.Arguments[1].Unwrap() as LambdaExpression;

    if (slambda == null)
        return method;

    var sbody = slambda.Body.Unwrap();

    if (slambda.Parameters.Count > 1 || sbody.NodeType != ExpressionType.MemberAccess)
        return method;

    var types1 = GetMethodGenericTypes(innerCall);
    var types2 = GetMethodGenericTypes(method);

    // Substitute the inner member access for the outer selector's parameter and re-root the
    // lambda on the inner selector's parameter.
    var expr = Expression.Call(null,
        GetMethodInfo(method, "Select").MakeGenericMethod(types1[0], types2[1]),
        innerCall.Arguments[0],
        Expression.Lambda(
            lambda.Body.Convert(ex => ex == lambda.Parameters[0] ? sbody : ex),
            slambda.Parameters[0]));

    return expr;
}
1216
1217 #endregion
1218
1219 #region ConvertIQueriable
1220
// Replaces a member-access or method-call expression with the expression obtained from
// _query.GetIQueryable.
// Steps:
//   1. Build a Func<Expression,IQueryable> accessor for the expression and register it via
//      _query.AddQueryableAccessors.
//   2. Record accessors for the resulting expression in _expressionAccessors.
// Returns the input unchanged when:
//   * the accessor would still depend on other free parameters;
//   * the resolved call targets the same method as the input.
// Throws InvalidOperationException for any other node type.
Expression ConvertIQueriable(Expression expression)
{
    var isSupported =
        expression.NodeType == ExpressionType.MemberAccess ||
        expression.NodeType == ExpressionType.Call;

    if (!isSupported)
        throw new InvalidOperationException();

    var exprParameter = Expression.Parameter(typeof(Expression), "exp");
    var accessors     = expression.GetExpressionAccessors(exprParameter);
    var rewritten     = ReplaceParameter(accessors, expression, _ => {});

    // Any parameter other than exprParameter left in the tree means it cannot stand alone.
    if (rewritten.Find(e => e.NodeType == ExpressionType.Parameter && e != exprParameter) != null)
        return expression;

    var getter = Expression.Lambda<Func<Expression,IQueryable>>(
        Expression.Convert(rewritten, typeof(IQueryable)),
        new [] { exprParameter });

    var slot = _query.AddQueryableAccessors(expression, getter);

    Expression accessor;
    _expressionAccessors.TryGetValue(expression, out accessor);

    var accessorArgument = accessor ?? Expression.Constant(null, typeof(Expression));

    var path = Expression.Call(
        Expression.Constant(_query),
        ReflectionHelper.Expressor<Query>.MethodExpressor(a => a.GetIQueryable(0, null)),
        new[] { Expression.Constant(slot), accessorArgument });

    var resolved = _query.GetIQueryable(slot, expression);

    var bothCalls = expression.NodeType == ExpressionType.Call && resolved.NodeType == ExpressionType.Call;

    if (bothCalls && ((MethodCallExpression)expression).Method == ((MethodCallExpression)resolved).Method)
        return expression;

    foreach (var pair in resolved.GetExpressionAccessors(path))
    {
        if (_expressionAccessors.ContainsKey(pair.Key))
            continue;

        _expressionAccessors.Add(pair.Key, pair.Value);
    }

    return resolved;
}
1265
1266 #endregion
1267
1268 #region ConvertElementAt
1269
// Rewrites ElementAt(source, index) as First(Skip(source, index)). Any other method name (e.g.
// ElementAtOrDefault) maps to FirstOrDefault(Skip(source, index)).
// Skip selection:
//   * An index that unwraps to a lambda uses LinqExtensions.Skip.
//   * Any other index uses the Skip method from the call's own family.
Expression ConvertElementAt(MethodCallExpression method)
{
    var sequence   = OptimizeExpression(method.Arguments[0]);
    var index      = OptimizeExpression(method.Arguments[1]).Unwrap();
    var sourceType = method.Method.GetGenericArguments()[0];

    MethodInfo skipDefinition = index.NodeType == ExpressionType.Lambda
        ? ReflectionHelper.Expressor<object>
            .MethodExpressor(o => LinqExtensions.Skip<object>(null, null))
            .GetGenericMethodDefinition()
        : GetQueriableMethodInfo(method, (mi,_) => mi.Name == "Skip");

    var skip = skipDefinition.MakeGenericMethod(sourceType);

    var firstName = method.Method.Name == "ElementAt" ? "First" : "FirstOrDefault";
    var first     = GetQueriableMethodInfo(method, (mi,_) => mi.Name == firstName && mi.GetParameters().Length == 1)
        .MakeGenericMethod(sourceType);

    var skipped = Expression.Call(null, skip, sequence, index);

    return Expression.Call(null, first, skipped);
}
1297
1298 #endregion
1299
1300 #region Helpers
1301
// Looks up a method in EnumerableMethods when the call is declared on Enumerable, otherwise in
// QueryableMethods.
// Lookup passes:
//   1. The predicate is tried with isDefault == false; the first match is returned.
//   2. Only if nothing matched, it is re-run with isDefault == true. A match is then required:
//      First throws InvalidOperationException if there is none.
MethodInfo GetQueriableMethodInfo(MethodCallExpression method, Func<MethodInfo,bool,bool> predicate)
{
    if (method.Method.DeclaringType == typeof(Enumerable))
    {
        var enumerableMatch = EnumerableMethods.FirstOrDefault(m => predicate(m, false));
        return enumerableMatch ?? EnumerableMethods.First(m => predicate(m, true));
    }

    var queryableMatch = QueryableMethods.FirstOrDefault(m => predicate(m, false));
    return queryableMatch ?? QueryableMethods.First(m => predicate(m, true));
}
1308
// Finds the two-parameter overload called `name` whose selector has exactly two type arguments.
// Enumerable: the selector is Func<T1,T2>. Queryable: the selector is Expression<Func<T1,T2>>.
// This excludes the (element, index) overloads. First throws InvalidOperationException if nothing
// matches.
MethodInfo GetMethodInfo(MethodCallExpression method, string name)
{
    if (method.Method.DeclaringType == typeof(Enumerable))
    {
        return EnumerableMethods.First(m =>
            m.Name == name &&
            m.GetParameters().Length == 2 &&
            m.GetParameters()[1].ParameterType.GetGenericArguments().Length == 2);
    }

    return QueryableMethods.First(m =>
        m.Name == name &&
        m.GetParameters().Length == 2 &&
        m.GetParameters()[1].ParameterType.GetGenericArguments()[0].GetGenericArguments().Length == 2);
}
1319
// Returns the type arguments of the delegate passed as the call's second parameter.
// Enumerable operators take Func<...> directly. Other operators (Queryable) take
// Expression<Func<...>>, so the Expression<> wrapper is peeled off first.
static Type[] GetMethodGenericTypes(MethodCallExpression method)
{
    var selectorType = method.Method.GetParameters()[1].ParameterType;

    if (method.Method.DeclaringType != typeof(Enumerable))
        selectorType = selectorType.GetGenericArguments()[0];

    return selectorType.GetGenericArguments();
}
1326
1327 #endregion
1328
1329 #endregion
1330 }
1331 }