view Source/Data/Linq/Builder/UpdateBuilder.cs @ 9:1e85f66cf767 default tip

update bltoolkit
author nickolay
date Thu, 05 Apr 2018 20:53:26 +0300
parents f990fcb411a9
children
line wrap: on
line source

using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

namespace BLToolkit.Data.Linq.Builder
{
	using BLToolkit.Linq;
	using Data.Sql;
	using Reflection;

	class UpdateBuilder : MethodCallBuilder
	{
		#region Update

		/// <summary>
		/// Tells the expression builder whether this builder handles <paramref name="methodCall"/>.
		/// </summary>
		/// <param name="builder">The active expression builder (not consulted here).</param>
		/// <param name="methodCall">The method call being examined.</param>
		/// <param name="buildInfo">Build information for the call (not consulted here).</param>
		/// <returns>The result of <c>methodCall.IsQueryable("Update")</c>.</returns>
		protected override bool CanBuildMethodCall(ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo)
		{
			var isUpdateCall = methodCall.IsQueryable("Update");

			return isUpdateCall;
		}

		/// <summary>
		/// Builds the source sequence of an "Update" call, populates the update clause according to
		/// the overload in use (selected by the number of arguments), marks the query as an update
		/// query and wraps the sequence in an <see cref="UpdateContext"/>.
		/// </summary>
		/// <param name="builder">The expression builder driving the translation.</param>
		/// <param name="methodCall">The "Update" method call expression.</param>
		/// <param name="buildInfo">Build information of the enclosing expression.</param>
		/// <returns>A new <see cref="UpdateContext"/> over the built sequence.</returns>
		protected override IBuildContext BuildMethodCall(ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo)
		{
			var args     = methodCall.Arguments;
			var sequence = builder.BuildSequence(new BuildInfo(buildInfo, args[0]));

			if (args.Count == 1)
			{
				// int Update<T>(this IUpdateable<T> source)
				CheckAssociation(sequence);
			}
			else if (args.Count == 2)
			{
				// int Update<T>(this IQueryable<T> source, Expression<Func<T,T>> setter)
				CheckAssociation(sequence);

				var setter = (LambdaExpression)args[1].Unwrap();

				BuildSetter(builder, buildInfo, setter, sequence, sequence.SqlQuery.Update.Items, sequence);
			}
			else if (args.Count == 3)
			{
				// The second argument decides the overload: a lambda is a predicate,
				// anything else is treated as the target table.
				var second    = args[1].Unwrap();
				var predicate = second as LambdaExpression;

				if (predicate != null)
				{
					// int Update<T>(this IQueryable<T> source, Expression<Func<T,bool>> predicate, Expression<Func<T,T>> setter)
					sequence = builder.BuildWhere(buildInfo.Parent, sequence, predicate, false);

					CheckAssociation(sequence);

					var setter = (LambdaExpression)args[2].Unwrap();

					BuildSetter(builder, buildInfo, setter, sequence, sequence.SqlQuery.Update.Items, sequence);
				}
				else
				{
					// static int Update<TSource,TTarget>(this IQueryable<TSource> source, Table<TTarget> target, Expression<Func<TSource,TTarget>> setter)
					//
					// The target is built into its own, fresh SqlQuery.
					var target = builder.BuildSequence(new BuildInfo(buildInfo, second, new SqlQuery()));

					sequence.ConvertToIndex(null, 0, ConvertFlags.All);
					sequence.SqlQuery.ResolveWeakJoins(new List<ISqlTableSource>());
					sequence.SqlQuery.Select.Columns.Clear();

					var setter = (LambdaExpression)args[2].Unwrap();

					BuildSetter(builder, buildInfo, setter, target, sequence.SqlQuery.Update.Items, sequence);

					var query = sequence.SqlQuery;

					// Rebuild the select list so that it holds exactly the set-expressions' values.
					query.Select.Columns.Clear();

					foreach (var setExpression in query.Update.Items)
					{
						query.Select.Columns.Add(new SqlQuery.Column(query, setExpression.Expression));
					}

					var targetTable = (TableBuilder.TableContext)target;

					query.Update.Table = targetTable.SqlTable;
				}
			}

			// Any other argument count adds nothing to the update clause.
			sequence.SqlQuery.QueryType = QueryType.Update;

			return new UpdateContext(buildInfo.Parent, sequence);
		}

		/// <summary>
		/// For a scalar <c>SelectContext</c>, redirects the update clause to the table the
		/// selection refers to: the association's table when the selection is an association,
		/// otherwise the selected table when it is not the first source of the FROM clause
		/// (or the FROM clause is empty). Other sequences are left untouched.
		/// </summary>
		/// <param name="sequence">The sequence whose update table may need to be set.</param>
		static void CheckAssociation(IBuildContext sequence)
		{
			var scalar = sequence as SelectContext;

			if (scalar == null || !scalar.IsScalar)
				return;

			var probe      = scalar.IsExpression(null, 0, RequestFor.Association);
			var associated = probe.Result ? probe.Context as TableBuilder.AssociatedTableContext : null;

			if (associated != null)
			{
				sequence.SqlQuery.Update.Table = associated.SqlTable;
				return;
			}

			// Not an association: see whether the selection is a plain table.
			probe = scalar.IsExpression(null, 0, RequestFor.Table);

			var table = probe.Result ? probe.Context as TableBuilder.TableContext : null;

			if (table == null)
				return;

			var fromTables = sequence.SqlQuery.From.Tables;

			if (fromTables.Count == 0 || fromTables[0].Source != table.SqlQuery)
				sequence.SqlQuery.Update.Table = table.SqlTable;
		}

		/// <summary>
		/// This builder performs no sequence conversion.
		/// </summary>
		/// <returns>Always <c>null</c>.</returns>
		protected override SequenceConvertInfo Convert(
			ExpressionBuilder   builder,
			MethodCallExpression methodCall,
			BuildInfo           buildInfo,
			ParameterExpression param)
		{
			SequenceConvertInfo noConversion = null;

			return noConversion;
		}

		#endregion

		#region Helpers

		/// <summary>
		/// Translates a setter lambda into set-expressions and appends them to <paramref name="items"/>.
		/// A member-init body is handled binding by binding; any other body is converted to SQL
		/// as a whole and each resulting member/value pair becomes one set-expression.
		/// </summary>
		/// <param name="builder">The expression builder driving the translation.</param>
		/// <param name="buildInfo">Build information of the enclosing expression; its parent becomes the parent of the temporary expression context.</param>
		/// <param name="setter">The setter lambda.</param>
		/// <param name="into">The context that resolves target members to columns.</param>
		/// <param name="items">The list receiving the generated set-expressions.</param>
		/// <param name="sequence">The source sequence the setter is evaluated against.</param>
		/// <exception cref="LinqException">
		/// A converted value has no member attached, or a member does not resolve to a column of <paramref name="into"/>.
		/// </exception>
		/// <exception cref="InvalidOperationException">A converted value is attached to more than one member.</exception>
		internal static void BuildSetter(
			ExpressionBuilder            builder,
			BuildInfo                    buildInfo,
			LambdaExpression             setter,
			IBuildContext                into,
			List<SqlQuery.SetExpression> items,
			IBuildContext                sequence)
		{
			// Synthetic parameter used as the root of member-access paths resolved against 'into'.
			var path = Expression.Parameter(setter.Body.Type, "p");
			var ctx  = new ExpressionContext(buildInfo.Parent, sequence, setter);

			if (setter.Body.NodeType == ExpressionType.MemberInit)
			{
				var ex  = (MemberInitExpression)setter.Body;

				// NOTE(review): ParseSet reads the sequence's parent *before* constructing its
				// ExpressionContext, here it is read afterwards. If the constructor re-parents
				// 'sequence', the ReplaceParent call below may be a no-op - confirm this is intended.
				var p   = sequence.Parent;

				BuildSetter(builder, into, items, ctx, ex, path);

				builder.ReplaceParent(ctx, p);
			}
			else
			{
				var sqlInfo = ctx.ConvertToSql(setter.Body, 0, ConvertFlags.All);

				foreach (var info in sqlInfo)
				{
					if (info.Members.Count == 0)
						throw new LinqException("Object initializer expected for insert statement.");

					if (info.Members.Count != 1)
						throw new InvalidOperationException();

					var member = info.Members[0];
					var pe     = Expression.MakeMemberAccess(path, member);
					var column = into.ConvertToSql(pe, 1, ConvertFlags.Field);

					// Same diagnostics as ParseSet: report an unmapped member instead of
					// failing with an index-out-of-range error on column[0].
					if (column.Length == 0)
						throw new LinqException("Member '{0}.{1}' is not a table column.", member.DeclaringType.Name, member.Name);

					var expr      = info.Sql;
					var parameter = expr as SqlParameter;

					if (parameter != null)
					{
						var type = member.MemberType == MemberTypes.Field ?
							((FieldInfo)   member).FieldType :
							((PropertyInfo)member).PropertyType;

						// Enum-typed parameters get a converter built from the target member's accessor.
						if (TypeHelper.IsEnumOrNullableEnum(type))
						{
							var memberAccessor = TypeAccessor.GetAccessor(member.DeclaringType)[member.Name];
							parameter.SetEnumConverter(memberAccessor, builder.MappingSchema);
						}
					}

					items.Add(new SqlQuery.SetExpression(column[0].Sql, expr));
				}
			}
		}

		/// <summary>
		/// Walks the bindings of a member-init expression and appends one set-expression per
		/// assigned column to <paramref name="items"/>. A nested member-init whose target member
		/// is not itself a field of <paramref name="into"/> is processed recursively, with the
		/// member access as the new path.
		/// </summary>
		/// <param name="builder">The expression builder driving the translation.</param>
		/// <param name="into">The context that resolves target members to columns.</param>
		/// <param name="items">The list receiving the generated set-expressions.</param>
		/// <param name="ctx">The context in which the assigned values are converted to SQL.</param>
		/// <param name="expression">The member-init expression to process.</param>
		/// <param name="path">The expression the bound members are accessed on.</param>
		/// <exception cref="InvalidOperationException">A binding is not a member assignment.</exception>
		/// <exception cref="LinqException">A member does not resolve to a column of <paramref name="into"/>.</exception>
		static void BuildSetter(
			ExpressionBuilder            builder,
			IBuildContext                into,
			List<SqlQuery.SetExpression> items,
			IBuildContext                ctx,
			MemberInitExpression         expression,
			Expression                   path)
		{
			foreach (var binding in expression.Bindings)
			{
				var member  = binding.Member;

				// A binding may carry a property accessor method; work with the property itself.
				if (member is MethodInfo)
					member = TypeHelper.GetPropertyByMethod((MethodInfo)member);

				var ma = binding as MemberAssignment;

				if (ma == null)
					throw new InvalidOperationException();

				var pe = Expression.MakeMemberAccess(path, member);

				if (ma.Expression is MemberInitExpression && !into.IsExpression(pe, 1, RequestFor.Field).Result)
				{
					BuildSetter(
						builder,
						into,
						items,
						ctx,
						(MemberInitExpression)ma.Expression,
						pe);

					continue;
				}

				var column = into.ConvertToSql(pe, 1, ConvertFlags.Field);

				// Same diagnostics as ParseSet: report an unmapped member instead of
				// failing with an index-out-of-range error on column[0].
				if (column.Length == 0)
					throw new LinqException("Member '{0}.{1}' is not a table column.", member.DeclaringType.Name, member.Name);

				var expr  = builder.ConvertToSqlExpression(ctx, ma.Expression, false);
				var value = expr as SqlValueBase;

				if (value != null && TypeHelper.IsEnumOrNullableEnum(ma.Expression.Type))
				{
					// Look the accessor up through the normalised member: the raw binding member
					// may be an accessor method, whose name is not the property name.
					var memberAccessor = TypeAccessor.GetAccessor(member.DeclaringType)[member.Name];
					value.SetEnumConverter(memberAccessor, builder.MappingSchema);
				}

				items.Add(new SqlQuery.SetExpression(column[0].Sql, expr));
			}
		}

		/// <summary>
		/// Handles a "Set" call whose new value is given as a lambda: resolves the member selected
		/// by <paramref name="extract"/> to a column, converts the body of <paramref name="update"/>
		/// to SQL and appends the resulting set-expression to <paramref name="items"/>.
		/// </summary>
		/// <param name="builder">The expression builder driving the translation.</param>
		/// <param name="buildInfo">Build information of the enclosing expression; its parent becomes the parent of the temporary expression context.</param>
		/// <param name="extract">Lambda selecting the member to set; must be a member access rooted in its own parameter (conversions around it are ignored).</param>
		/// <param name="update">Lambda producing the new value.</param>
		/// <param name="select">The sequence being updated.</param>
		/// <param name="table">When not null, the column is looked up by name in this table's fields; otherwise it is resolved through <paramref name="select"/>.</param>
		/// <param name="items">The list receiving the generated set-expression.</param>
		/// <exception cref="LinqException">
		/// <paramref name="extract"/> is not a member access on its parameter, or the member does not resolve to a column.
		/// </exception>
		internal static void ParseSet(
			ExpressionBuilder            builder,
			BuildInfo                    buildInfo,
			LambdaExpression             extract,
			LambdaExpression             update,
			IBuildContext                select,
			SqlTable                     table,
			List<SqlQuery.SetExpression> items)
		{
			var ext = extract.Body;

			// Ignore conversions wrapped around the member access.
			while (ext.NodeType == ExpressionType.Convert || ext.NodeType == ExpressionType.ConvertChecked)
				ext = ((UnaryExpression)ext).Operand;

			if (ext.NodeType != ExpressionType.MemberAccess || ext.GetRootObject() != extract.Parameters[0])
				throw new LinqException("Member expression expected for the 'Set' statement.");

			var body   = (MemberExpression)ext;
			var member = body.Member;

			if (member is MethodInfo)
				member = TypeHelper.GetPropertyByMethod((MethodInfo)member);

			// Dotted member path below the root parameter, skipping nullable value members.
			var names = body.GetMembers()
				.Skip(1)
				.Select(ex =>
				{
					var me = ex as MemberExpression;

					if (me == null)
						return null;

					var m = me.Member;

					if (m is MethodInfo)
						m = TypeHelper.GetPropertyByMethod((MethodInfo)m);

					return m;
				})
				.Where(m => m != null && !TypeHelper.IsNullableValueMember(m))
				.Select(m => m.Name)
				.ToArray();

			// string.Join yields "" for an empty path, where a seedless Aggregate would throw
			// "Sequence contains no elements"; an empty name is then reported as a missing column.
			var name = string.Join(".", names);

			if (table != null && !table.Fields.ContainsKey(name))
				throw new LinqException("Member '{0}.{1}' is not a table column.", member.DeclaringType.Name, name);

			SqlInfo[] resolved = null;

			if (table == null)
			{
				resolved = select.ConvertToSql(body, 1, ConvertFlags.Field);

				// Same diagnostics as the other ParseSet overload instead of an index-out-of-range error.
				if (resolved.Length == 0)
					throw new LinqException("Member '{0}.{1}' is not a table column.", member.DeclaringType.Name, member.Name);
			}

			var column = table != null ?
				table.Fields[name] :
				resolved[0].Sql;

			// Remember the sequence's parent before the temporary context is created,
			// and hand the temporary context's children back to it afterwards.
			var sp     = select.Parent;
			var ctx    = new ExpressionContext(buildInfo.Parent, select, update);
			var expr   = builder.ConvertToSqlExpression(ctx, update.Body, false);

			builder.ReplaceParent(ctx, sp);

			var value = expr as SqlValueBase;

			if (value != null && TypeHelper.IsEnumOrNullableEnum(update.Body.Type))
			{
				var memberAccessor = TypeAccessor.GetAccessor(member.DeclaringType)[member.Name];
				value.SetEnumConverter(memberAccessor, builder.MappingSchema);
			}

			items.Add(new SqlQuery.SetExpression(column, expr));
		}

		/// <summary>
		/// Handles a "Set" call whose new value is given as a plain expression: registers the
		/// value as a parameter when its type is not a constant type, resolves the member selected
		/// by <paramref name="extract"/> to a column through <paramref name="select"/>, converts
		/// the value to SQL and appends the set-expression to <paramref name="items"/>.
		/// </summary>
		/// <param name="builder">The expression builder driving the translation.</param>
		/// <param name="buildInfo">Build information of the enclosing expression (not consulted here).</param>
		/// <param name="extract">Lambda selecting the member to set; must be a member access rooted in its own parameter (conversions around it are ignored).</param>
		/// <param name="update">Expression producing the new value.</param>
		/// <param name="select">The sequence being updated.</param>
		/// <param name="items">The list receiving the generated set-expression.</param>
		/// <exception cref="LinqException">
		/// <paramref name="extract"/> is not a member access on its parameter, or the member does not resolve to a column.
		/// </exception>
		internal static void ParseSet(
			ExpressionBuilder            builder,
			BuildInfo                    buildInfo,
			LambdaExpression             extract,
			Expression                   update,
			IBuildContext                select,
			List<SqlQuery.SetExpression> items)
		{
			var target = extract.Body;

			// Registration happens before validation, so it takes effect even if validation fails below.
			if (!ExpressionHelper.IsConstant(update.Type))
			{
				if (!builder.AsParameters.Contains(update))
					builder.AsParameters.Add(update);
			}

			// Peel off conversions wrapped around the member access.
			for (;;)
			{
				var kind = target.NodeType;

				if (kind != ExpressionType.Convert && kind != ExpressionType.ConvertChecked)
					break;

				target = ((UnaryExpression)target).Operand;
			}

			var isMemberOfParameter =
				target.NodeType == ExpressionType.MemberAccess &&
				target.GetRootObject() == extract.Parameters[0];

			if (!isMemberOfParameter)
				throw new LinqException("Member expression expected for the 'Set' statement.");

			var access = (MemberExpression)target;
			var member = access.Member;

			if (member is MethodInfo)
				member = TypeHelper.GetPropertyByMethod((MethodInfo)member);

			var columns = select.ConvertToSql(access, 1, ConvertFlags.Field);

			if (columns.Length == 0)
				throw new LinqException("Member '{0}.{1}' is not a table column.", member.DeclaringType.Name, member.Name);

			var sqlValue = builder.ConvertToSql(select, update, false, false);
			var enumHost = sqlValue as SqlValueBase;

			if (enumHost != null && TypeHelper.IsEnumOrNullableEnum(update.Type))
			{
				var declaringType = access.Member.DeclaringType;
				var accessor      = TypeAccessor.GetAccessor(declaringType)[access.Member.Name];

				enumHost.SetEnumConverter(accessor, builder.MappingSchema);
			}

			var setExpression = new SqlQuery.SetExpression(columns[0].Sql, sqlValue);

			items.Add(setExpression);
		}

		#endregion

		#region UpdateContext

		/// <summary>
		/// Build context returned for an "Update" call. It only knows how to configure the query
		/// as a non-query; every expression-level operation throws
		/// <see cref="InvalidOperationException"/>.
		/// </summary>
		class UpdateContext : SequenceContextBase
		{
			/// <summary>Creates the context over <paramref name="sequence"/>, without a lambda.</summary>
			public UpdateContext(IBuildContext parent, IBuildContext sequence)
				: base(parent, sequence, null)
			{
			}

			/// <summary>Configures <paramref name="query"/> through <c>SetNonQueryQuery</c>.</summary>
			public override void BuildQuery<T>(Query<T> query, ParameterExpression queryParameter)
			{
				query.SetNonQueryQuery();
			}

			/// <summary>Not available on this context; always throws.</summary>
			public override Expression BuildExpression(Expression expression, int level)
			{
				throw Unsupported();
			}

			/// <summary>Not available on this context; always throws.</summary>
			public override SqlInfo[] ConvertToSql(Expression expression, int level, ConvertFlags flags)
			{
				throw Unsupported();
			}

			/// <summary>Not available on this context; always throws.</summary>
			public override SqlInfo[] ConvertToIndex(Expression expression, int level, ConvertFlags flags)
			{
				throw Unsupported();
			}

			/// <summary>Not available on this context; always throws.</summary>
			public override IsExpressionResult IsExpression(Expression expression, int level, RequestFor requestFlag)
			{
				throw Unsupported();
			}

			/// <summary>Not available on this context; always throws.</summary>
			public override IBuildContext GetContext(Expression expression, int level, BuildInfo buildInfo)
			{
				throw Unsupported();
			}

			// Single place that creates the exception thrown by the unavailable operations.
			static InvalidOperationException Unsupported()
			{
				return new InvalidOperationException();
			}
		}

		#endregion

		#region Set

		/// <summary>
		/// Builder for queryable "Set" calls: adds one set-expression to the update clause of the
		/// source sequence and returns that same sequence.
		/// </summary>
		internal class Set : MethodCallBuilder
		{
			/// <summary>Returns the result of <c>methodCall.IsQueryable("Set")</c>.</summary>
			protected override bool CanBuildMethodCall(ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo)
			{
				var isSetCall = methodCall.IsQueryable("Set");

				return isSetCall;
			}

			/// <summary>
			/// Builds the source sequence and dispatches to the matching <c>ParseSet</c> overload,
			/// depending on whether the new value (third argument) is a lambda or a plain expression.
			/// </summary>
			/// <returns>The source sequence, with the set-expression added to its update clause.</returns>
			protected override IBuildContext BuildMethodCall(ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo)
			{
				var source  = builder.BuildSequence(new BuildInfo(buildInfo, methodCall.Arguments[0]));
				var extract = (LambdaExpression)methodCall.Arguments[1].Unwrap();
				var update  = methodCall.Arguments[2].Unwrap();

				if (update.NodeType != ExpressionType.Lambda)
				{
					// The new value is a plain expression.
					ParseSet(builder, buildInfo, extract, update, source, source.SqlQuery.Update.Items);

					return source;
				}

				// The new value is computed by a lambda.
				ParseSet(
					builder,
					buildInfo,
					extract,
					(LambdaExpression)update,
					source,
					source.SqlQuery.Update.Table,
					source.SqlQuery.Update.Items);

				return source;
			}

			/// <summary>This builder performs no sequence conversion; always returns <c>null</c>.</summary>
			protected override SequenceConvertInfo Convert(
				ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo, ParameterExpression param)
			{
				SequenceConvertInfo noConversion = null;

				return noConversion;
			}
		}

		#endregion
	}
}