Mercurial > pub > bltoolkit
diff Source/Data/Linq/Builder/JoinBuilder.cs @ 0:f990fcb411a9
Копия текущей версии из github
author | cin |
---|---|
date | Thu, 27 Mar 2014 21:46:09 +0400 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/Source/Data/Linq/Builder/JoinBuilder.cs Thu Mar 27 21:46:09 2014 +0400 @@ -0,0 +1,397 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; + +namespace BLToolkit.Data.Linq.Builder +{ + using BLToolkit.Linq; + using Data.Sql; + + class JoinBuilder : MethodCallBuilder + { + protected override bool CanBuildMethodCall(ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo) + { + if (!methodCall.IsQueryable("Join", "GroupJoin") || methodCall.Arguments.Count != 5) + return false; + + var body = ((LambdaExpression)methodCall.Arguments[2].Unwrap()).Body.Unwrap(); + + if (body.NodeType == ExpressionType .MemberInit) + { + var mi = (MemberInitExpression)body; + bool throwExpr; + + if (mi.NewExpression.Arguments.Count > 0 || mi.Bindings.Count == 0) + throwExpr = true; + else + throwExpr = mi.Bindings.Any(b => b.BindingType != MemberBindingType.Assignment); + + if (throwExpr) + throw new NotSupportedException(string.Format("Explicit construction of entity type '{0}' in join is not allowed.", body.Type)); + } + + return true; + } + + protected override IBuildContext BuildMethodCall(ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo) + { + var isGroup = methodCall.Method.Name == "GroupJoin"; + var outerContext = builder.BuildSequence(new BuildInfo(buildInfo, methodCall.Arguments[0], buildInfo.SqlQuery)); + var innerContext = builder.BuildSequence(new BuildInfo(buildInfo, methodCall.Arguments[1], new SqlQuery())); + var countContext = builder.BuildSequence(new BuildInfo(buildInfo, methodCall.Arguments[1], new SqlQuery())); + + var context = new SubQueryContext(outerContext); + innerContext = isGroup ? new GroupJoinSubQueryContext(innerContext, methodCall) : new SubQueryContext(innerContext); + countContext = new SubQueryContext(countContext); + + var join = isGroup ? innerContext.SqlQuery.WeakLeftJoin() : innerContext.SqlQuery.InnerJoin(); + var sql = context.SqlQuery; + + sql.From.Tables[0].Joins.Add(join.JoinedTable); + + var selector = (LambdaExpression)methodCall.Arguments[4].Unwrap(); + + context.SetAlias(selector.Parameters[0].Name); + innerContext.SetAlias(selector.Parameters[1].Name); + + var outerKeyLambda = ((LambdaExpression)methodCall.Arguments[2].Unwrap()); + var innerKeyLambda = ((LambdaExpression)methodCall.Arguments[3].Unwrap()); + + var outerKeySelector = outerKeyLambda.Body.Unwrap(); + var innerKeySelector = innerKeyLambda.Body.Unwrap(); + + var outerParent = context. Parent; + var innerParent = innerContext.Parent; + var countParent = countContext.Parent; + + var outerKeyContext = new ExpressionContext(buildInfo.Parent, context, outerKeyLambda); + var innerKeyContext = new InnerKeyContext (buildInfo.Parent, innerContext, innerKeyLambda); + var countKeyContext = new ExpressionContext(buildInfo.Parent, countContext, innerKeyLambda); + + // Process counter. + // + var counterSql = ((SubQueryContext)countContext).SqlQuery; + + // Make join and where for the counter. + // + if (outerKeySelector.NodeType == ExpressionType.New) + { + var new1 = (NewExpression)outerKeySelector; + var new2 = (NewExpression)innerKeySelector; + + for (var i = 0; i < new1.Arguments.Count; i++) + { + var arg1 = new1.Arguments[i]; + var arg2 = new2.Arguments[i]; + + BuildJoin(builder, join, outerKeyContext, arg1, innerKeyContext, arg2, countKeyContext, counterSql); + } + } + else if (outerKeySelector.NodeType == ExpressionType.MemberInit) + { + var mi1 = (MemberInitExpression)outerKeySelector; + var mi2 = (MemberInitExpression)innerKeySelector; + + for (var i = 0; i < mi1.Bindings.Count; i++) + { + if (mi1.Bindings[i].Member != mi2.Bindings[i].Member) + throw new LinqException(string.Format("List of member inits does not match for entity type '{0}'.", outerKeySelector.Type)); + + var arg1 = ((MemberAssignment)mi1.Bindings[i]).Expression; + var arg2 = ((MemberAssignment)mi2.Bindings[i]).Expression; + + BuildJoin(builder, join, outerKeyContext, arg1, innerKeyContext, arg2, countKeyContext, counterSql); + } + } + else + { + BuildJoin(builder, join, outerKeyContext, outerKeySelector, innerKeyContext, innerKeySelector, countKeyContext, counterSql); + } + + builder.ReplaceParent(outerKeyContext, outerParent); + builder.ReplaceParent(innerKeyContext, innerParent); + builder.ReplaceParent(countKeyContext, countParent); + + if (isGroup) + { + counterSql.ParentSql = sql; + counterSql.Select.Columns.Clear(); + + var inner = (GroupJoinSubQueryContext)innerContext; + + inner.Join = join.JoinedTable; + inner.CounterSql = counterSql; + return new GroupJoinContext( + buildInfo.Parent, selector, context, inner, methodCall.Arguments[1], outerKeyLambda, innerKeyLambda); + } + + return new JoinContext(buildInfo.Parent, selector, context, innerContext) +#if DEBUG + { + MethodCall = methodCall + } +#endif + ; + } + + protected override SequenceConvertInfo Convert( + ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo, ParameterExpression param) + { + return null; + } + + static void BuildJoin( + ExpressionBuilder builder, + SqlQuery.FromClause.Join join, + IBuildContext outerKeyContext, Expression outerKeySelector, + IBuildContext innerKeyContext, Expression innerKeySelector, + IBuildContext countKeyContext, SqlQuery countSql) + { + var predicate = builder.ConvertObjectComparison( + ExpressionType.Equal, + outerKeyContext, outerKeySelector, + innerKeyContext, innerKeySelector); + + if (predicate != null) + join.JoinedTable.Condition.Conditions.Add(new SqlQuery.Condition(false, predicate)); + else + join + .Expr(builder.ConvertToSql(outerKeyContext, outerKeySelector, false)).Equal + .Expr(builder.ConvertToSql(innerKeyContext, innerKeySelector, false)); + + predicate = builder.ConvertObjectComparison( + ExpressionType.Equal, + outerKeyContext, outerKeySelector, + countKeyContext, innerKeySelector); + + if (predicate != null) + countSql.Where.SearchCondition.Conditions.Add(new SqlQuery.Condition(false, predicate)); + else + countSql.Where + .Expr(builder.ConvertToSql(outerKeyContext, outerKeySelector, false)).Equal + .Expr(builder.ConvertToSql(countKeyContext, innerKeySelector, false)); + } + + class InnerKeyContext : ExpressionContext + { + public InnerKeyContext(IBuildContext parent, IBuildContext sequence, LambdaExpression lambda) + : base(parent, sequence, lambda) + { + } + + public override SqlInfo[] ConvertToSql(Expression expression, int level, ConvertFlags flags) + { + return base + .ConvertToSql(expression, level, flags) + .Select(idx => + { + var n = SqlQuery.Select.Add(idx.Sql); + + return new SqlInfo(idx.Members) + { + Sql = SqlQuery.Select.Columns[n], + Index = n + }; + }) + .ToArray(); + } + } + + internal class JoinContext : SelectContext + { + public JoinContext(IBuildContext parent, LambdaExpression lambda, IBuildContext outerContext, IBuildContext innerContext) + : base(parent, lambda, outerContext, innerContext) + { + } + } + + internal class GroupJoinContext : JoinContext + { + public GroupJoinContext( + IBuildContext parent, + LambdaExpression lambda, + IBuildContext outerContext, + GroupJoinSubQueryContext innerContext, + Expression innerExpression, + LambdaExpression outerKeyLambda, + LambdaExpression innerKeyLambda) + : base(parent, lambda, outerContext, innerContext) + { + _innerExpression = innerExpression; + _outerKeyLambda = outerKeyLambda; + _innerKeyLambda = innerKeyLambda; + + innerContext.GroupJoin = this; + } + + readonly Expression _innerExpression; + readonly LambdaExpression _outerKeyLambda; + readonly LambdaExpression _innerKeyLambda; + private Expression _groupExpression; + + interface IGroupJoinHelper + { + Expression GetGroupJoin(GroupJoinContext context); + } + + class GroupJoinHelper<TKey,TElement> : IGroupJoinHelper + { + public Expression GetGroupJoin(GroupJoinContext context) + { + // Convert outer condition. + // + var outerParam = Expression.Parameter(context._outerKeyLambda.Body.Type, "o"); + var outerKey = context._outerKeyLambda.Body.Convert( + e => e == context._outerKeyLambda.Parameters[0] ? context.Lambda.Parameters[0] : e); + + outerKey = context.Builder.BuildExpression(context, outerKey); + + // Convert inner condition. + // + var parameters = context.Builder.CurrentSqlParameters + .Select((p,i) => new { p, i }) + .ToDictionary(_ => _.p.Expression, _ => _.i); + var paramArray = Expression.Parameter(typeof(object[]), "ps"); + + var innerKey = context._innerKeyLambda.Body.Convert(e => + { + int idx; + + if (parameters.TryGetValue(e, out idx)) + { + return + Expression.Convert( + Expression.ArrayIndex(paramArray, Expression.Constant(idx)), + e.Type); + } + + return e; + }); + + // Item reader. + // +// ReSharper disable AssignNullToNotNullAttribute + + var expr = Expression.Call( + null, + ReflectionHelper.Expressor<object>.MethodExpressor(_ => Queryable.Where(null, (Expression<Func<TElement,bool>>)null)), + context._innerExpression, + Expression.Lambda<Func<TElement,bool>>( + Expression.Equal(innerKey, outerParam), + new[] { context._innerKeyLambda.Parameters[0] })); + +// ReSharper restore AssignNullToNotNullAttribute + + var lambda = Expression.Lambda<Func<IDataContext,TKey,object[],IQueryable<TElement>>>( + Expression.Convert(expr, typeof(IQueryable<TElement>)), + Expression.Parameter(typeof(IDataContext), "ctx"), + outerParam, + paramArray); + + var itemReader = CompiledQuery.Compile(lambda); + + return Expression.Call( + null, + ReflectionHelper.Expressor<object>.MethodExpressor(_ => GetGrouping(null, null, default(TKey), null)), + new[] + { + ExpressionBuilder.ContextParam, + Expression.Constant(context.Builder.CurrentSqlParameters), + outerKey, + Expression.Constant(itemReader), + }); + } + + static IEnumerable<TElement> GetGrouping( + QueryContext context, + List<ParameterAccessor> parameterAccessor, + TKey key, + Func<IDataContext,TKey,object[],IQueryable<TElement>> itemReader) + { + return new GroupByBuilder.GroupByContext.Grouping<TKey,TElement>(key, context, parameterAccessor, itemReader); + } + } + + public override Expression BuildExpression(Expression expression, int level) + { + if (expression == Lambda.Parameters[1]) + { + if (_groupExpression == null) + { + var gtype = typeof(GroupJoinHelper<,>).MakeGenericType( + _innerKeyLambda.Body.Type, + _innerKeyLambda.Parameters[0].Type); + + var helper = (IGroupJoinHelper)Activator.CreateInstance(gtype); + + _groupExpression = helper.GetGroupJoin(this); + } + + return _groupExpression; + } + + return base.BuildExpression(expression, level); + } + } + + internal class GroupJoinSubQueryContext : SubQueryContext + { + //readonly MethodCallExpression _methodCall; + + public SqlQuery.JoinedTable Join; + public SqlQuery CounterSql; + public GroupJoinContext GroupJoin; + + public GroupJoinSubQueryContext(IBuildContext subQuery, MethodCallExpression methodCall) + : base(subQuery) + { + //_methodCall = methodCall; + } + + public override IBuildContext GetContext(Expression expression, int level, BuildInfo buildInfo) + { + if (expression == null) + return this; + + return base.GetContext(expression, level, buildInfo); + } + + Expression _counterExpression; + SqlInfo[] _counterInfo; + + public override SqlInfo[] ConvertToIndex(Expression expression, int level, ConvertFlags flags) + { + if (expression != null && expression == _counterExpression) + return _counterInfo ?? (_counterInfo = new[] + { + new SqlInfo + { + Query = CounterSql.ParentSql, + Index = CounterSql.ParentSql.Select.Add(CounterSql), + Sql = CounterSql + } + }); + + return base.ConvertToIndex(expression, level, flags); + } + + public override IsExpressionResult IsExpression(Expression expression, int level, RequestFor testFlag) + { + if (testFlag == RequestFor.GroupJoin && expression == null) + return IsExpressionResult.True; + + return base.IsExpression(expression, level, testFlag); + } + + public SqlQuery GetCounter(Expression expr) + { + Join.IsWeak = true; + + _counterExpression = expr; + + return CounterSql; + } + } + } +}