Mercurial > pub > bltoolkit
diff Source/Data/Linq/Builder/ConcatUnionBuilder.cs @ 0:f990fcb411a9
Копия текущей версии из github
author | cin |
---|---|
date | Thu, 27 Mar 2014 21:46:09 +0400 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/Source/Data/Linq/Builder/ConcatUnionBuilder.cs Thu Mar 27 21:46:09 2014 +0400 @@ -0,0 +1,340 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; + +namespace BLToolkit.Data.Linq.Builder +{ + using BLToolkit.Linq; + using Data.Sql; + using Reflection; + + class ConcatUnionBuilder : MethodCallBuilder + { + #region Builder + + protected override bool CanBuildMethodCall(ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo) + { + return methodCall.Arguments.Count == 2 && methodCall.IsQueryable("Concat", "Union"); + } + + protected override IBuildContext BuildMethodCall(ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo) + { + var sequence1 = new SubQueryContext(builder.BuildSequence(new BuildInfo(buildInfo, methodCall.Arguments[0]))); + var sequence2 = new SubQueryContext(builder.BuildSequence(new BuildInfo(buildInfo, methodCall.Arguments[1], new SqlQuery()))); + var union = new SqlQuery.Union(sequence2.SqlQuery, methodCall.Method.Name == "Concat"); + + sequence1.SqlQuery.Unions.Add(union); + + return new UnionContext(sequence1, sequence2, methodCall); + } + + protected override SequenceConvertInfo Convert( + ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo, ParameterExpression param) + { + return null; + } + + #endregion + + #region Context + + sealed class UnionContext : SubQueryContext + { + public UnionContext(SubQueryContext sequence1, SubQueryContext sequence2, MethodCallExpression methodCall) + : base(sequence1) + { + _methodCall = methodCall; + + _isObject = + sequence1.IsExpression(null, 0, RequestFor.Object).Result || + sequence2.IsExpression(null, 0, RequestFor.Object).Result; + + if (_isObject) + { + var type = _methodCall.Method.GetGenericArguments()[0]; + _unionParameter = Expression.Parameter(type, "t"); + } + + Init(sequence1, sequence2); + } + + readonly bool _isObject; + readonly MethodCallExpression _methodCall; + readonly ParameterExpression _unionParameter; + readonly Dictionary<MemberInfo,Member> _members = new Dictionary<MemberInfo,Member>(new MemberInfoComparer()); + + class Member + { + public SqlInfo SequenceInfo; + public SqlInfo SqlQueryInfo; + public MemberExpression MemberExpression; + } + + class UnionMember + { + public Member Member; + public SqlInfo Info1; + public SqlInfo Info2; + } + + void Init(SubQueryContext sequence1, SubQueryContext sequence2) + { + var info1 = sequence1.ConvertToIndex(null, 0, ConvertFlags.All).ToList(); + var info2 = sequence2.ConvertToIndex(null, 0, ConvertFlags.All).ToList(); + + if (!_isObject) + return; + + var members = new List<UnionMember>(); + + foreach (var info in info1) + { + if (info.Members.Count == 0) + throw new InvalidOperationException(); + + var member = new Member + { + SequenceInfo = info, + MemberExpression = Expression.MakeMemberAccess(_unionParameter, info.Members[0]) + }; + + members.Add(new UnionMember { Member = member, Info1 = info }); + } + + foreach (var info in info2) + { + if (info.Members.Count == 0) + throw new InvalidOperationException(); + + var em = members.FirstOrDefault(m => + m.Member.SequenceInfo != null && + m.Member.SequenceInfo.CompareLastMember(info)); + + if (em == null) + { + var member = new Member { MemberExpression = Expression.MakeMemberAccess(_unionParameter, info.Members[0]) }; + + if (sequence2.IsExpression(member.MemberExpression, 1, RequestFor.Object).Result) + throw new LinqException("Types in {0} are constructed incompatibly.", _methodCall.Method.Name); + + members.Add(new UnionMember { Member = member, Info2 = info }); + } + else + { + em.Info2 = info; + } + } + + sequence1.SqlQuery.Select.Columns.Clear(); + sequence2.SqlQuery.Select.Columns.Clear(); + + for (var i = 0; i < members.Count; i++) + { + var member = members[i]; + + if (member.Info1 == null) + { + member.Info1 = new SqlInfo(member.Info2.Members) + { + Sql = new SqlValue(null), + Query = sequence1.SqlQuery, + }; + + member.Member.SequenceInfo = member.Info1; + } + + if (member.Info2 == null) + { + member.Info2 = new SqlInfo(member.Info1.Members) + { + Sql = new SqlValue(null), + Query = sequence2.SqlQuery, + }; + } + + sequence1.SqlQuery.Select.Columns.Add(new SqlQuery.Column(sequence1.SqlQuery, member.Info1.Sql)); + sequence2.SqlQuery.Select.Columns.Add(new SqlQuery.Column(sequence2.SqlQuery, member.Info2.Sql)); + + member.Member.SequenceInfo.Index = i; + + _members[member.Member.MemberExpression.Member] = member.Member; + } + + foreach (var key in sequence1.ColumnIndexes.Keys.ToList()) + sequence1.ColumnIndexes[key] = sequence1.SqlQuery.Select.Add(key); + + foreach (var key in sequence2.ColumnIndexes.Keys.ToList()) + sequence2.ColumnIndexes[key] = sequence2.SqlQuery.Select.Add(key); + } + + public override void BuildQuery<T>(Query<T> query, ParameterExpression queryParameter) + { + var expr = BuildExpression(null, 0); + var mapper = Builder.BuildMapper<T>(expr); + + query.SetQuery(mapper.Compile()); + } + + public override Expression BuildExpression(Expression expression, int level) + { + if (_isObject) + { + if (expression == null) + { + var type = _methodCall.Method.GetGenericArguments()[0]; + var nctor = (NewExpression)Expression.Find(e => + { + if (e.NodeType == ExpressionType.New && e.Type == type) + { + var ne = (NewExpression)e; + return ne.Arguments != null && ne.Arguments.Count > 0; + } + + return false; + }); + + Expression expr; + + if (nctor != null) + { + var members = nctor.Members + .Select(m => m is MethodInfo ? TypeHelper.GetPropertyByMethod((MethodInfo)m) : m) + .ToList(); + + expr = Expression.New( + nctor.Constructor, + members + .Select(m => Expression.PropertyOrField(_unionParameter, m.Name)) + .Cast<Expression>(), + members); + } + else + { + var ta = TypeAccessor.GetAccessor(type); + + expr = Expression.MemberInit( + Expression.New(ta.Type), + _members + .Select(m => Expression.Bind(m.Value.MemberExpression.Member, m.Value.MemberExpression)) + .Cast<MemberBinding>()); + } + + var ex = Builder.BuildExpression(this, expr); + + return ex; + } + + if (level == 0 || level == 1) + { + var levelExpression = expression.GetLevelExpression(1); + + if (expression == levelExpression && !IsExpression(expression, 1, RequestFor.Object).Result) + { + var idx = ConvertToIndex(expression, level, ConvertFlags.Field); + var n = idx[0].Index; + + if (Parent != null) + n = Parent.ConvertToParentIndex(n, this); + + return Builder.BuildSql(expression.Type, n); + } + } + } + + return base.BuildExpression(expression, level); + } + + public override IsExpressionResult IsExpression(Expression expression, int level, RequestFor testFlag) + { + if (testFlag == RequestFor.Root && expression == _unionParameter) + return IsExpressionResult.True; + + return base.IsExpression(expression, level, testFlag); + } + + public override SqlInfo[] ConvertToIndex(Expression expression, int level, ConvertFlags flags) + { + if (_isObject) + { + return ConvertToSql(expression, level, flags) + .Select(idx => + { + if (idx.Index < 0) + { + if (idx.Index == -2) + { + SqlQuery.Select.Columns.Add(new SqlQuery.Column(SqlQuery, idx.Sql)); + idx.Index = SqlQuery.Select.Columns.Count - 1; + } + else + { + idx.Index = SqlQuery.Select.Add(idx.Sql); + } + } + + return idx; + }) + .ToArray(); + } + + return base.ConvertToIndex(expression, level, flags); + } + + public override SqlInfo[] ConvertToSql(Expression expression, int level, ConvertFlags flags) + { + if (_isObject) + { + switch (flags) + { + case ConvertFlags.All : + case ConvertFlags.Key : + + if (expression == null) + { + return _members.Values + .Select(m => ConvertToSql(m.MemberExpression, 0, ConvertFlags.Field)[0]) + .ToArray(); + } + + break; + + case ConvertFlags.Field : + + if (expression != null && (level == 0 || level == 1) && expression.NodeType == ExpressionType.MemberAccess) + { + var levelExpression = expression.GetLevelExpression(1); + + if (expression == levelExpression) + { + var ma = (MemberExpression)expression; + var member = _members[ma.Member]; + + if (member.SqlQueryInfo == null) + { + member.SqlQueryInfo = new SqlInfo(member.MemberExpression.Member) + { + Index = -2, + Sql = SubQuery.SqlQuery.Select.Columns[member.SequenceInfo.Index], + Query = SqlQuery, + }; + } + + return new[] { member.SqlQueryInfo }; + } + } + + break; + } + + throw new InvalidOperationException(); + } + + return base.ConvertToSql(expression, level, flags); + } + } + + #endregion + } +}