diff Source/Data/Linq/Builder/ConcatUnionBuilder.cs @ 0:f990fcb411a9

Копия текущей версии из github
author cin
date Thu, 27 Mar 2014 21:46:09 +0400
parents
children
line wrap: on
line diff
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/Source/Data/Linq/Builder/ConcatUnionBuilder.cs	Thu Mar 27 21:46:09 2014 +0400
@@ -0,0 +1,340 @@
+using System;
+using System.Collections.Generic;
+using System.Linq;
+using System.Linq.Expressions;
+using System.Reflection;
+
+namespace BLToolkit.Data.Linq.Builder
+{
+	using BLToolkit.Linq;
+	using Data.Sql;
+	using Reflection;
+
+	class ConcatUnionBuilder : MethodCallBuilder
+	{
+		#region Builder
+
+		protected override bool CanBuildMethodCall(ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo)
+		{
+			return methodCall.Arguments.Count == 2 && methodCall.IsQueryable("Concat", "Union");
+		}
+
+		protected override IBuildContext BuildMethodCall(ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo)
+		{
+			var sequence1 = new SubQueryContext(builder.BuildSequence(new BuildInfo(buildInfo, methodCall.Arguments[0])));
+			var sequence2 = new SubQueryContext(builder.BuildSequence(new BuildInfo(buildInfo, methodCall.Arguments[1], new SqlQuery())));
+			var union     = new SqlQuery.Union(sequence2.SqlQuery, methodCall.Method.Name == "Concat");
+
+			sequence1.SqlQuery.Unions.Add(union);
+
+			return new UnionContext(sequence1, sequence2, methodCall);
+		}
+
+		protected override SequenceConvertInfo Convert(
+			ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo, ParameterExpression param)
+		{
+			return null;
+		}
+
+		#endregion
+
+		#region Context
+
+		sealed class UnionContext : SubQueryContext
+		{
+			public UnionContext(SubQueryContext sequence1, SubQueryContext sequence2, MethodCallExpression methodCall)
+				: base(sequence1)
+			{
+				_methodCall = methodCall;
+
+				_isObject =
+					sequence1.IsExpression(null, 0, RequestFor.Object).Result ||
+					sequence2.IsExpression(null, 0, RequestFor.Object).Result;
+
+				if (_isObject)
+				{
+					var type = _methodCall.Method.GetGenericArguments()[0];
+					_unionParameter = Expression.Parameter(type, "t");
+				}
+
+				Init(sequence1, sequence2);
+			}
+
+			readonly bool                          _isObject;
+			readonly MethodCallExpression          _methodCall;
+			readonly ParameterExpression           _unionParameter;
+			readonly Dictionary<MemberInfo,Member> _members = new Dictionary<MemberInfo,Member>(new MemberInfoComparer());
+
+			class Member
+			{
+				public SqlInfo          SequenceInfo;
+				public SqlInfo          SqlQueryInfo;
+				public MemberExpression MemberExpression;
+			}
+
+			class UnionMember
+			{
+				public Member  Member;
+				public SqlInfo Info1;
+				public SqlInfo Info2;
+			}
+
+			void Init(SubQueryContext sequence1, SubQueryContext sequence2)
+			{
+				var info1 = sequence1.ConvertToIndex(null, 0, ConvertFlags.All).ToList();
+				var info2 = sequence2.ConvertToIndex(null, 0, ConvertFlags.All).ToList();
+
+				if (!_isObject)
+					return;
+
+				var members = new List<UnionMember>();
+
+				foreach (var info in info1)
+				{
+					if (info.Members.Count == 0)
+						throw new InvalidOperationException();
+
+					var member = new Member
+					{
+						SequenceInfo     = info,
+						MemberExpression = Expression.MakeMemberAccess(_unionParameter, info.Members[0])
+					};
+
+					members.Add(new UnionMember { Member = member, Info1 = info });
+				}
+
+				foreach (var info in info2)
+				{
+					if (info.Members.Count == 0)
+						throw new InvalidOperationException();
+
+					var em = members.FirstOrDefault(m =>
+						m.Member.SequenceInfo != null &&
+						m.Member.SequenceInfo.CompareLastMember(info));
+
+					if (em == null)
+					{
+						var member = new Member { MemberExpression = Expression.MakeMemberAccess(_unionParameter, info.Members[0]) };
+
+						if (sequence2.IsExpression(member.MemberExpression, 1, RequestFor.Object).Result)
+							throw new LinqException("Types in {0} are constructed incompatibly.", _methodCall.Method.Name);
+
+						members.Add(new UnionMember { Member = member, Info2 = info });
+					}
+					else
+					{
+						em.Info2 = info;
+					}
+				}
+
+				sequence1.SqlQuery.Select.Columns.Clear();
+				sequence2.SqlQuery.Select.Columns.Clear();
+
+				for (var i = 0; i < members.Count; i++)
+				{
+					var member = members[i];
+
+					if (member.Info1 == null)
+					{
+						member.Info1 = new SqlInfo(member.Info2.Members)
+						{
+							Sql   = new SqlValue(null),
+							Query = sequence1.SqlQuery,
+						};
+
+						member.Member.SequenceInfo = member.Info1;
+					}
+
+					if (member.Info2 == null)
+					{
+						member.Info2 = new SqlInfo(member.Info1.Members)
+						{
+							Sql   = new SqlValue(null),
+							Query = sequence2.SqlQuery,
+						};
+					}
+
+					sequence1.SqlQuery.Select.Columns.Add(new SqlQuery.Column(sequence1.SqlQuery, member.Info1.Sql));
+					sequence2.SqlQuery.Select.Columns.Add(new SqlQuery.Column(sequence2.SqlQuery, member.Info2.Sql));
+
+					member.Member.SequenceInfo.Index = i;
+
+					_members[member.Member.MemberExpression.Member] = member.Member;
+				}
+
+				foreach (var key in sequence1.ColumnIndexes.Keys.ToList())
+					sequence1.ColumnIndexes[key] = sequence1.SqlQuery.Select.Add(key);
+
+				foreach (var key in sequence2.ColumnIndexes.Keys.ToList())
+					sequence2.ColumnIndexes[key] = sequence2.SqlQuery.Select.Add(key);
+			}
+
+			public override void BuildQuery<T>(Query<T> query, ParameterExpression queryParameter)
+			{
+				var expr   = BuildExpression(null, 0);
+				var mapper = Builder.BuildMapper<T>(expr);
+
+				query.SetQuery(mapper.Compile());
+			}
+
+			public override Expression BuildExpression(Expression expression, int level)
+			{
+				if (_isObject)
+				{
+					if (expression == null)
+					{
+						var type  = _methodCall.Method.GetGenericArguments()[0];
+						var nctor = (NewExpression)Expression.Find(e =>
+						{
+							if (e.NodeType == ExpressionType.New && e.Type == type)
+							{
+								var ne = (NewExpression)e;
+								return ne.Arguments != null && ne.Arguments.Count > 0;
+							}
+
+							return false;
+						});
+
+						Expression expr;
+
+						if (nctor != null)
+						{
+							var members = nctor.Members
+								.Select(m => m is MethodInfo ? TypeHelper.GetPropertyByMethod((MethodInfo)m) : m)
+								.ToList();
+
+							expr = Expression.New(
+								nctor.Constructor,
+								members
+									.Select(m => Expression.PropertyOrField(_unionParameter, m.Name))
+									.Cast<Expression>(),
+								members);
+						}
+						else
+						{
+							var ta = TypeAccessor.GetAccessor(type);
+
+							expr = Expression.MemberInit(
+								Expression.New(ta.Type),
+								_members
+									.Select(m => Expression.Bind(m.Value.MemberExpression.Member, m.Value.MemberExpression))
+									.Cast<MemberBinding>());
+						}
+
+						var ex = Builder.BuildExpression(this, expr);
+
+						return ex;
+					}
+
+					if (level == 0 || level == 1)
+					{
+						var levelExpression = expression.GetLevelExpression(1);
+
+						if (expression == levelExpression && !IsExpression(expression, 1, RequestFor.Object).Result)
+						{
+							var idx = ConvertToIndex(expression, level, ConvertFlags.Field);
+							var n   = idx[0].Index;
+
+							if (Parent != null)
+								n = Parent.ConvertToParentIndex(n, this);
+
+							return Builder.BuildSql(expression.Type, n);
+						}
+					}
+				}
+
+				return base.BuildExpression(expression, level);
+			}
+
+			public override IsExpressionResult IsExpression(Expression expression, int level, RequestFor testFlag)
+			{
+				if (testFlag == RequestFor.Root && expression == _unionParameter)
+					return IsExpressionResult.True;
+
+				return base.IsExpression(expression, level, testFlag);
+			}
+
+			public override SqlInfo[] ConvertToIndex(Expression expression, int level, ConvertFlags flags)
+			{
+				if (_isObject)
+				{
+					return ConvertToSql(expression, level, flags)
+						.Select(idx =>
+						{
+							if (idx.Index < 0)
+							{
+								if (idx.Index == -2)
+								{
+									SqlQuery.Select.Columns.Add(new SqlQuery.Column(SqlQuery, idx.Sql));
+									idx.Index = SqlQuery.Select.Columns.Count - 1;
+								}
+								else
+								{
+									idx.Index = SqlQuery.Select.Add(idx.Sql);
+								}
+							}
+
+							return idx;
+						})
+						.ToArray();
+				}
+
+				return base.ConvertToIndex(expression, level, flags);
+			}
+
+			public override SqlInfo[] ConvertToSql(Expression expression, int level, ConvertFlags flags)
+			{
+				if (_isObject)
+				{
+					switch (flags)
+					{
+						case ConvertFlags.All   :
+						case ConvertFlags.Key   :
+
+							if (expression == null)
+							{
+								return _members.Values
+									.Select(m => ConvertToSql(m.MemberExpression, 0, ConvertFlags.Field)[0])
+									.ToArray();
+							}
+
+							break;
+
+						case ConvertFlags.Field :
+
+							if (expression != null && (level == 0 || level == 1) && expression.NodeType == ExpressionType.MemberAccess)
+							{
+								var levelExpression = expression.GetLevelExpression(1);
+
+								if (expression == levelExpression)
+								{
+									var ma     = (MemberExpression)expression;
+									var member = _members[ma.Member];
+
+									if (member.SqlQueryInfo == null)
+									{
+										member.SqlQueryInfo = new SqlInfo(member.MemberExpression.Member)
+										{
+											Index = -2,
+											Sql   = SubQuery.SqlQuery.Select.Columns[member.SequenceInfo.Index],
+											Query = SqlQuery,
+										};
+									}
+
+									return new[] { member.SqlQueryInfo };
+								}
+							}
+
+							break;
+					}
+
+					throw new InvalidOperationException();
+				}
+
+				return base.ConvertToSql(expression, level, flags);
+			}
+		}
+
+		#endregion
+	}
+}