view Source/Data/Linq/Builder/ConcatUnionBuilder.cs @ 6:11b6da379593

Исправлена странная ошибка при использовании OfType<...>().Where(...)
author cin
date Mon, 05 Dec 2016 05:50:52 +0300
parents f990fcb411a9
children
line wrap: on
line source

using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

namespace BLToolkit.Data.Linq.Builder
{
	using BLToolkit.Linq;
	using Data.Sql;
	using Reflection;

	class ConcatUnionBuilder : MethodCallBuilder
	{
		#region Builder

		/// <summary>
		/// Tells the builder infrastructure whether this builder handles <paramref name="methodCall"/>.
		/// </summary>
		/// <returns>
		/// <c>true</c> when the call has exactly two arguments and
		/// <c>IsQueryable("Concat", "Union")</c> reports a match; otherwise <c>false</c>.
		/// </returns>
		protected override bool CanBuildMethodCall(ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo)
		{
			// Reject anything that is not a two-argument call before looking at the method name.
			if (methodCall.Arguments.Count != 2)
				return false;

			return methodCall.IsQueryable("Concat", "Union");
		}

		/// <summary>
		/// Builds both operands of the call as sub-query contexts, registers the second
		/// operand's query in the <c>Unions</c> list of the first one and wraps the pair
		/// in a <see cref="UnionContext"/>.
		/// </summary>
		protected override IBuildContext BuildMethodCall(ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo)
		{
			// First operand: built from the incoming build info.
			var firstInfo = new BuildInfo(buildInfo, methodCall.Arguments[0]);
			var first     = new SubQueryContext(builder.BuildSequence(firstInfo));

			// Second operand: built into a brand new SqlQuery.
			var secondInfo = new BuildInfo(buildInfo, methodCall.Arguments[1], new SqlQuery());
			var second     = new SubQueryContext(builder.BuildSequence(secondInfo));

			// The flag passed to SqlQuery.Union is true for Concat and false for Union.
			var isConcat = methodCall.Method.Name == "Concat";

			first.SqlQuery.Unions.Add(new SqlQuery.Union(second.SqlQuery, isConcat));

			return new UnionContext(first, second, methodCall);
		}

		/// <summary>
		/// Always returns <c>null</c>.
		/// </summary>
		/// <remarks>
		/// NOTE(review): presumably <c>null</c> tells the base builder that no sequence
		/// conversion is needed for Concat/Union - confirm against MethodCallBuilder.
		/// </remarks>
		protected override SequenceConvertInfo Convert(ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo, ParameterExpression param)
		{
			return null;
		}

		#endregion

		#region Context

		sealed class UnionContext : SubQueryContext
		{
			/// <summary>
			/// Creates the context on top of <paramref name="sequence1"/> and aligns the
			/// columns of both sequences (see <c>Init</c>).
			/// </summary>
			/// <param name="sequence1">First operand; also used as the base sub-query.</param>
			/// <param name="sequence2">Second operand.</param>
			/// <param name="methodCall">The Concat/Union call being translated.</param>
			public UnionContext(SubQueryContext sequence1, SubQueryContext sequence2, MethodCallExpression methodCall)
				: base(sequence1)
			{
				_methodCall = methodCall;

				// The second sequence is only queried when the first one is not an object,
				// exactly as with a short-circuiting "||".
				var firstIsObject = sequence1.IsExpression(null, 0, RequestFor.Object).Result;

				_isObject = firstIsObject || sequence2.IsExpression(null, 0, RequestFor.Object).Result;

				// The synthetic parameter "t" is typed with the method's first generic argument
				// and is only created for object results.
				if (_isObject)
					_unionParameter = Expression.Parameter(_methodCall.Method.GetGenericArguments()[0], "t");

				Init(sequence1, sequence2);
			}

			// True when at least one of the two sequences reports RequestFor.Object
			// for its root expression (computed once in the constructor).
			readonly bool                          _isObject;
			// The Concat/Union call this context was built for.
			readonly MethodCallExpression          _methodCall;
			// Synthetic parameter "t" of the element type; stays null when _isObject is false.
			readonly ParameterExpression           _unionParameter;
			// Union members keyed by their MemberInfo (compared with MemberInfoComparer); filled by Init.
			readonly Dictionary<MemberInfo,Member> _members = new Dictionary<MemberInfo,Member>(new MemberInfoComparer());

			// Per-member bookkeeping for the union result.
			class Member
			{
				// Column info on the first sequence's side (either taken from sequence1 or
				// synthesized in Init); its Index is set to the member's column position.
				public SqlInfo          SequenceInfo;
				// Info exposed by this context; created lazily in ConvertToSql.
				public SqlInfo          SqlQueryInfo;
				// Access to the member on the synthetic union parameter.
				public MemberExpression MemberExpression;
			}

			// Temporary pairing used by Init to line up one member across both sequences.
			class UnionMember
			{
				// The member being aligned.
				public Member  Member;
				// Column info from the first sequence; null until found or synthesized.
				public SqlInfo Info1;
				// Column info from the second sequence; null until found or synthesized.
				public SqlInfo Info2;
			}

			// Converts both sequences to indexed columns and, for object results, rebuilds
			// the select lists of both queries so that every member occupies the same
			// column position on each side. A member present on only one side is paired
			// with a "new SqlValue(null)" column on the other side.
			//
			// Throws InvalidOperationException when a column info carries no members, and
			// LinqException when a member found only in sequence2 is itself an object.
			void Init(SubQueryContext sequence1, SubQueryContext sequence2)
			{
				// NOTE(review): both conversions run even when _isObject is false and the
				// results are then unused - presumably they are needed for their effect on
				// the select lists; confirm before reordering.
				var info1 = sequence1.ConvertToIndex(null, 0, ConvertFlags.All).ToList();
				var info2 = sequence2.ConvertToIndex(null, 0, ConvertFlags.All).ToList();

				if (!_isObject)
					return;

				var members = new List<UnionMember>();

				// Pass 1: every column of the first sequence becomes a union member,
				// addressed through its first member on the union parameter.
				foreach (var info in info1)
				{
					if (info.Members.Count == 0)
						throw new InvalidOperationException();

					var member = new Member
					{
						SequenceInfo     = info,
						MemberExpression = Expression.MakeMemberAccess(_unionParameter, info.Members[0])
					};

					members.Add(new UnionMember { Member = member, Info1 = info });
				}

				// Pass 2: match each column of the second sequence against the members
				// collected so far (via CompareLastMember); unmatched ones are appended.
				foreach (var info in info2)
				{
					if (info.Members.Count == 0)
						throw new InvalidOperationException();

					var em = members.FirstOrDefault(m =>
						m.Member.SequenceInfo != null &&
						m.Member.SequenceInfo.CompareLastMember(info));

					if (em == null)
					{
						var member = new Member { MemberExpression = Expression.MakeMemberAccess(_unionParameter, info.Members[0]) };

						// A member that exists only on the second side must not be an object itself.
						if (sequence2.IsExpression(member.MemberExpression, 1, RequestFor.Object).Result)
							throw new LinqException("Types in {0} are constructed incompatibly.", _methodCall.Method.Name);

						members.Add(new UnionMember { Member = member, Info2 = info });
					}
					else
					{
						em.Info2 = info;
					}
				}

				// Both select lists are rebuilt from scratch in the order of 'members'.
				sequence1.SqlQuery.Select.Columns.Clear();
				sequence2.SqlQuery.Select.Columns.Clear();

				for (var i = 0; i < members.Count; i++)
				{
					var member = members[i];

					// Member missing from the first sequence: use a null-valued placeholder there.
					if (member.Info1 == null)
					{
						member.Info1 = new SqlInfo(member.Info2.Members)
						{
							Sql   = new SqlValue(null),
							Query = sequence1.SqlQuery,
						};

						member.Member.SequenceInfo = member.Info1;
					}

					// Member missing from the second sequence: same placeholder on that side.
					if (member.Info2 == null)
					{
						member.Info2 = new SqlInfo(member.Info1.Members)
						{
							Sql   = new SqlValue(null),
							Query = sequence2.SqlQuery,
						};
					}

					sequence1.SqlQuery.Select.Columns.Add(new SqlQuery.Column(sequence1.SqlQuery, member.Info1.Sql));
					sequence2.SqlQuery.Select.Columns.Add(new SqlQuery.Column(sequence2.SqlQuery, member.Info2.Sql));

					// Column i on both sides now belongs to this member.
					member.Member.SequenceInfo.Index = i;

					_members[member.Member.MemberExpression.Member] = member.Member;
				}

				// The select lists were rebuilt, so the cached column indexes of both
				// sequences are recomputed through Select.Add.
				// NOTE(review): presumably Select.Add returns the index of an existing
				// matching column or appends a new one - confirm.
				foreach (var key in sequence1.ColumnIndexes.Keys.ToList())
					sequence1.ColumnIndexes[key] = sequence1.SqlQuery.Select.Add(key);

				foreach (var key in sequence2.ColumnIndexes.Keys.ToList())
					sequence2.ColumnIndexes[key] = sequence2.SqlQuery.Select.Add(key);
			}

			/// <summary>
			/// Builds the root expression of the union, turns it into a mapper for
			/// <typeparamref name="T"/> and installs the compiled mapper on <paramref name="query"/>.
			/// </summary>
			public override void BuildQuery<T>(Query<T> query, ParameterExpression queryParameter)
			{
				var mapper = Builder.BuildMapper<T>(BuildExpression(null, 0));

				query.SetQuery(mapper.Compile());
			}

			/// <summary>
			/// Builds the expression that reads <paramref name="expression"/> from the union result.
			/// </summary>
			/// <remarks>
			/// For non-object unions the base implementation is used. For object unions:
			/// a <c>null</c> expression yields an expression that constructs the element type
			/// from the members of the synthetic union parameter; a level 0/1 member that is
			/// not itself an object is read directly from its column; everything else falls
			/// through to the base implementation.
			/// </remarks>
			public override Expression BuildExpression(Expression expression, int level)
			{
				if (!_isObject)
					return base.BuildExpression(expression, level);

				if (expression == null)
				{
					var elementType = _methodCall.Method.GetGenericArguments()[0];

					// Look in this context's expression for "new ElementType(args...)" with at least one argument.
					var ctorCall = (NewExpression)Expression.Find(e =>
					{
						if (e.NodeType != ExpressionType.New || e.Type != elementType)
							return false;

						var candidate = (NewExpression)e;

						return candidate.Arguments != null && candidate.Arguments.Count > 0;
					});

					Expression objectExpr;

					if (ctorCall == null)
					{
						// No such constructor call: create the object and bind every known member.
						var accessor = TypeAccessor.GetAccessor(elementType);

						objectExpr = Expression.MemberInit(
							Expression.New(accessor.Type),
							_members
								.Select(pair => Expression.Bind(pair.Value.MemberExpression.Member, pair.Value.MemberExpression))
								.Cast<MemberBinding>());
					}
					else
					{
						// Re-run the same constructor, feeding it member accesses on the union parameter.
						// Accessor methods are mapped back to their properties first.
						var ctorMembers = ctorCall.Members
							.Select(mi => mi is MethodInfo ? TypeHelper.GetPropertyByMethod((MethodInfo)mi) : mi)
							.ToList();

						objectExpr = Expression.New(
							ctorCall.Constructor,
							ctorMembers
								.Select(mi => Expression.PropertyOrField(_unionParameter, mi.Name))
								.Cast<Expression>(),
							ctorMembers);
					}

					return Builder.BuildExpression(this, objectExpr);
				}

				if (level == 0 || level == 1)
				{
					// Only a direct member of the union root that is not an object itself is handled here.
					if (expression == expression.GetLevelExpression(1) && !IsExpression(expression, 1, RequestFor.Object).Result)
					{
						var columnIndex = ConvertToIndex(expression, level, ConvertFlags.Field)[0].Index;

						if (Parent != null)
							columnIndex = Parent.ConvertToParentIndex(columnIndex, this);

						return Builder.BuildSql(expression.Type, columnIndex);
					}
				}

				return base.BuildExpression(expression, level);
			}

			/// <summary>
			/// Reports the synthetic union parameter as a root expression; every other
			/// request is answered by the base implementation.
			/// </summary>
			public override IsExpressionResult IsExpression(Expression expression, int level, RequestFor testFlag)
			{
				if (testFlag != RequestFor.Root || expression != _unionParameter)
					return base.IsExpression(expression, level, testFlag);

				return IsExpressionResult.True;
			}

			/// <summary>
			/// Converts <paramref name="expression"/> to column infos and makes sure each one
			/// has a column index in this context's select list.
			/// </summary>
			/// <remarks>
			/// Non-object unions use the base implementation. For object unions every info
			/// returned by <c>ConvertToSql</c> with a negative index is given a column:
			/// index -2 (the marker assigned by <c>ConvertToSql</c> to freshly created member
			/// infos) appends a new column; any other negative index goes through <c>Select.Add</c>.
			/// </remarks>
			public override SqlInfo[] ConvertToIndex(Expression expression, int level, ConvertFlags flags)
			{
				if (!_isObject)
					return base.ConvertToIndex(expression, level, flags);

				var infos  = ConvertToSql(expression, level, flags);
				var result = new SqlInfo[infos.Length];

				for (var i = 0; i < infos.Length; i++)
				{
					var info = infos[i];

					if (info.Index == -2)
					{
						SqlQuery.Select.Columns.Add(new SqlQuery.Column(SqlQuery, info.Sql));
						info.Index = SqlQuery.Select.Columns.Count - 1;
					}
					else if (info.Index < 0)
					{
						info.Index = SqlQuery.Select.Add(info.Sql);
					}

					result[i] = info;
				}

				return result;
			}

			/// <summary>
			/// Converts <paramref name="expression"/> to SQL infos for this union.
			/// </summary>
			/// <remarks>
			/// Non-object unions use the base implementation. For object unions:
			/// <list type="bullet">
			/// <item><c>All</c>/<c>Key</c> with a <c>null</c> expression return one info per known member;</item>
			/// <item><c>Field</c> with a level 0/1 member access on the union root returns the
			/// member's info, created on first use with the marker index -2.</item>
			/// </list>
			/// </remarks>
			/// <exception cref="LinqException">
			/// The accessed member is not one of the union's members.
			/// </exception>
			/// <exception cref="InvalidOperationException">
			/// The request is not one of the supported combinations above.
			/// </exception>
			public override SqlInfo[] ConvertToSql(Expression expression, int level, ConvertFlags flags)
			{
				if (!_isObject)
					return base.ConvertToSql(expression, level, flags);

				switch (flags)
				{
					case ConvertFlags.All   :
					case ConvertFlags.Key   :

						if (expression == null)
						{
							return _members.Values
								.Select(m => ConvertToSql(m.MemberExpression, 0, ConvertFlags.Field)[0])
								.ToArray();
						}

						break;

					case ConvertFlags.Field :

						if (expression != null && (level == 0 || level == 1) && expression.NodeType == ExpressionType.MemberAccess)
						{
							var levelExpression = expression.GetLevelExpression(1);

							if (expression == levelExpression)
							{
								var ma = (MemberExpression)expression;

								// A member that never made it into the union's column set used to
								// surface as a bare KeyNotFoundException from the dictionary indexer;
								// report it as a query error that names the offending expression.
								Member member;

								if (!_members.TryGetValue(ma.Member, out member))
									throw new LinqException("Expression '{0}' is not a field.", expression.ToString());

								if (member.SqlQueryInfo == null)
								{
									// Index -2 marks an info that has no column in this query yet.
									member.SqlQueryInfo = new SqlInfo(member.MemberExpression.Member)
									{
										Index = -2,
										Sql   = SubQuery.SqlQuery.Select.Columns[member.SequenceInfo.Index],
										Query = SqlQuery,
									};
								}

								return new[] { member.SqlQueryInfo };
							}
						}

						break;
				}

				throw new InvalidOperationException();
			}
		}

		#endregion
	}
}