Mercurial > pub > bltoolkit
diff Source/Data/Linq/Builder/UpdateBuilder.cs @ 0:f990fcb411a9
Копия текущей версии из github
author | cin |
---|---|
date | Thu, 27 Mar 2014 21:46:09 +0400 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/Source/Data/Linq/Builder/UpdateBuilder.cs Thu Mar 27 21:46:09 2014 +0400 @@ -0,0 +1,443 @@ +using System; +using System.Collections.Generic; +using System.Linq; +using System.Linq.Expressions; +using System.Reflection; + +namespace BLToolkit.Data.Linq.Builder +{ + using BLToolkit.Linq; + using Data.Sql; + using Reflection; + + class UpdateBuilder : MethodCallBuilder + { + #region Update + + protected override bool CanBuildMethodCall(ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo) + { + return methodCall.IsQueryable("Update"); + } + + protected override IBuildContext BuildMethodCall(ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo) + { + var sequence = builder.BuildSequence(new BuildInfo(buildInfo, methodCall.Arguments[0])); + + switch (methodCall.Arguments.Count) + { + case 1 : // int Update<T>(this IUpdateable<T> source) + CheckAssociation(sequence); + break; + + case 2 : // int Update<T>(this IQueryable<T> source, Expression<Func<T,T>> setter) + { + CheckAssociation(sequence); + + BuildSetter( + builder, + buildInfo, + (LambdaExpression)methodCall.Arguments[1].Unwrap(), + sequence, + sequence.SqlQuery.Update.Items, + sequence); + break; + } + + case 3 : + { + var expr = methodCall.Arguments[1].Unwrap(); + + if (expr is LambdaExpression) + { + // int Update<T>(this IQueryable<T> source, Expression<Func<T,bool>> predicate, Expression<Func<T,T>> setter) + // + sequence = builder.BuildWhere(buildInfo.Parent, sequence, (LambdaExpression)methodCall.Arguments[1].Unwrap(), false); + + CheckAssociation(sequence); + + BuildSetter( + builder, + buildInfo, + (LambdaExpression)methodCall.Arguments[2].Unwrap(), + sequence, + sequence.SqlQuery.Update.Items, + sequence); + } + else + { + // static int Update<TSource,TTarget>(this IQueryable<TSource> source, Table<TTarget> target, Expression<Func<TSource,TTarget>> setter) + // + var into = builder.BuildSequence(new BuildInfo(buildInfo, expr, new SqlQuery())); + + sequence.ConvertToIndex(null, 0, ConvertFlags.All); + sequence.SqlQuery.ResolveWeakJoins(new List<ISqlTableSource>()); + sequence.SqlQuery.Select.Columns.Clear(); + + BuildSetter( + builder, + buildInfo, + (LambdaExpression)methodCall.Arguments[2].Unwrap(), + into, + sequence.SqlQuery.Update.Items, + sequence); + + var sql = sequence.SqlQuery; + + sql.Select.Columns.Clear(); + + foreach (var item in sql.Update.Items) + sql.Select.Columns.Add(new SqlQuery.Column(sql, item.Expression)); + + sql.Update.Table = ((TableBuilder.TableContext)into).SqlTable; + } + + break; + } + } + + sequence.SqlQuery.QueryType = QueryType.Update; + + return new UpdateContext(buildInfo.Parent, sequence); + } + + static void CheckAssociation(IBuildContext sequence) + { + var ctx = sequence as SelectContext; + + if (ctx != null && ctx.IsScalar) + { + var res = ctx.IsExpression(null, 0, RequestFor.Association); + + if (res.Result && res.Context is TableBuilder.AssociatedTableContext) + { + var atc = (TableBuilder.AssociatedTableContext)res.Context; + sequence.SqlQuery.Update.Table = atc.SqlTable; + } + else + { + res = ctx.IsExpression(null, 0, RequestFor.Table); + + if (res.Result && res.Context is TableBuilder.TableContext) + { + var tc = (TableBuilder.TableContext)res.Context; + + if (sequence.SqlQuery.From.Tables.Count == 0 || sequence.SqlQuery.From.Tables[0].Source != tc.SqlQuery) + sequence.SqlQuery.Update.Table = tc.SqlTable; + } + } + } + } + + protected override SequenceConvertInfo Convert( + ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo, ParameterExpression param) + { + return null; + } + + #endregion + + #region Helpers + + internal static void BuildSetter( + ExpressionBuilder builder, + BuildInfo buildInfo, + LambdaExpression setter, + IBuildContext into, + List<SqlQuery.SetExpression> items, + IBuildContext sequence) + { + var path = Expression.Parameter(setter.Body.Type, "p"); + var ctx = new ExpressionContext(buildInfo.Parent, sequence, setter); + + if (setter.Body.NodeType == ExpressionType.MemberInit) + { + var ex = (MemberInitExpression)setter.Body; + var p = sequence.Parent; + + BuildSetter(builder, into, items, ctx, ex, path); + + builder.ReplaceParent(ctx, p); + } + else + { + var sqlInfo = ctx.ConvertToSql(setter.Body, 0, ConvertFlags.All); + + foreach (var info in sqlInfo) + { + if (info.Members.Count == 0) + throw new LinqException("Object initializer expected for insert statement."); + + if (info.Members.Count != 1) + throw new InvalidOperationException(); + + var member = info.Members[0]; + var pe = Expression.MakeMemberAccess(path, member); + var column = into.ConvertToSql(pe, 1, ConvertFlags.Field); + var expr = info.Sql; + + if (expr is SqlParameter) + { + var type = member.MemberType == MemberTypes.Field ? + ((FieldInfo) member).FieldType : + ((PropertyInfo)member).PropertyType; + + if (TypeHelper.IsEnumOrNullableEnum(type)) + { + var memberAccessor = TypeAccessor.GetAccessor(member.DeclaringType)[member.Name]; + ((SqlParameter)expr).SetEnumConverter(memberAccessor, builder.MappingSchema); + } + } + + items.Add(new SqlQuery.SetExpression(column[0].Sql, expr)); + } + } + } + + static void BuildSetter( + ExpressionBuilder builder, + IBuildContext into, + List<SqlQuery.SetExpression> items, + IBuildContext ctx, + MemberInitExpression expression, + Expression path) + { + foreach (var binding in expression.Bindings) + { + var member = binding.Member; + + if (member is MethodInfo) + member = TypeHelper.GetPropertyByMethod((MethodInfo)member); + + if (binding is MemberAssignment) + { + var ma = binding as MemberAssignment; + var pe = Expression.MakeMemberAccess(path, member); + + if (ma.Expression is MemberInitExpression && !into.IsExpression(pe, 1, RequestFor.Field).Result) + { + BuildSetter( + builder, + into, + items, + ctx, + (MemberInitExpression)ma.Expression, Expression.MakeMemberAccess(path, member)); + } + else + { + var column = into.ConvertToSql(pe, 1, ConvertFlags.Field); + var expr = builder.ConvertToSqlExpression(ctx, ma.Expression, false); + + if (expr is SqlValueBase && TypeHelper.IsEnumOrNullableEnum(ma.Expression.Type)) + { + var memberAccessor = TypeAccessor.GetAccessor(ma.Member.DeclaringType)[ma.Member.Name]; + ((SqlValueBase)expr).SetEnumConverter(memberAccessor, builder.MappingSchema); + } + + items.Add(new SqlQuery.SetExpression(column[0].Sql, expr)); + } + } + else + throw new InvalidOperationException(); + } + } + + internal static void ParseSet( + ExpressionBuilder builder, + BuildInfo buildInfo, + LambdaExpression extract, + LambdaExpression update, + IBuildContext select, + SqlTable table, + List<SqlQuery.SetExpression> items) + { + var ext = extract.Body; + + while (ext.NodeType == ExpressionType.Convert || ext.NodeType == ExpressionType.ConvertChecked) + ext = ((UnaryExpression)ext).Operand; + + if (ext.NodeType != ExpressionType.MemberAccess || ext.GetRootObject() != extract.Parameters[0]) + throw new LinqException("Member expression expected for the 'Set' statement."); + + var body = (MemberExpression)ext; + var member = body.Member; + + if (member is MethodInfo) + member = TypeHelper.GetPropertyByMethod((MethodInfo)member); + + var members = body.GetMembers(); + var name = members + .Skip(1) + .Select(ex => + { + var me = ex as MemberExpression; + + if (me == null) + return null; + + var m = me.Member; + + if (m is MethodInfo) + m = TypeHelper.GetPropertyByMethod((MethodInfo)m); + + return m; + }) + .Where(m => m != null && !TypeHelper.IsNullableValueMember(m)) + .Select(m => m.Name) + .Aggregate((s1,s2) => s1 + "." + s2); + + if (table != null && !table.Fields.ContainsKey(name)) + throw new LinqException("Member '{0}.{1}' is not a table column.", member.DeclaringType.Name, name); + + var column = table != null ? + table.Fields[name] : + select.ConvertToSql( + body, 1, ConvertFlags.Field)[0].Sql; + //Expression.MakeMemberAccess(Expression.Parameter(member.DeclaringType, "p"), member), 1, ConvertFlags.Field)[0].Sql; + var sp = select.Parent; + var ctx = new ExpressionContext(buildInfo.Parent, select, update); + var expr = builder.ConvertToSqlExpression(ctx, update.Body, false); + + builder.ReplaceParent(ctx, sp); + + if (expr is SqlValueBase && TypeHelper.IsEnumOrNullableEnum(update.Body.Type)) + { + var memberAccessor = TypeAccessor.GetAccessor(body.Member.DeclaringType)[body.Member.Name]; + ((SqlValueBase)expr).SetEnumConverter(memberAccessor, builder.MappingSchema); + } + + items.Add(new SqlQuery.SetExpression(column, expr)); + } + + internal static void ParseSet( + ExpressionBuilder builder, + BuildInfo buildInfo, + LambdaExpression extract, + Expression update, + IBuildContext select, + List<SqlQuery.SetExpression> items) + { + var ext = extract.Body; + + if (!ExpressionHelper.IsConstant(update.Type) && !builder.AsParameters.Contains(update)) + builder.AsParameters.Add(update); + + while (ext.NodeType == ExpressionType.Convert || ext.NodeType == ExpressionType.ConvertChecked) + ext = ((UnaryExpression)ext).Operand; + + if (ext.NodeType != ExpressionType.MemberAccess || ext.GetRootObject() != extract.Parameters[0]) + throw new LinqException("Member expression expected for the 'Set' statement."); + + var body = (MemberExpression)ext; + var member = body.Member; + + if (member is MethodInfo) + member = TypeHelper.GetPropertyByMethod((MethodInfo)member); + + var column = select.ConvertToSql( + body, 1, ConvertFlags.Field); + //Expression.MakeMemberAccess(Expression.Parameter(member.DeclaringType, "p"), member), 1, ConvertFlags.Field); + + if (column.Length == 0) + throw new LinqException("Member '{0}.{1}' is not a table column.", member.DeclaringType.Name, member.Name); + + var expr = builder.ConvertToSql(select, update, false, false); + + if (expr is SqlValueBase && TypeHelper.IsEnumOrNullableEnum(update.Type)) + { + var memberAccessor = TypeAccessor.GetAccessor(body.Member.DeclaringType)[body.Member.Name]; + ((SqlValueBase)expr).SetEnumConverter(memberAccessor, builder.MappingSchema); + } + + items.Add(new SqlQuery.SetExpression(column[0].Sql, expr)); + } + + #endregion + + #region UpdateContext + + class UpdateContext : SequenceContextBase + { + public UpdateContext(IBuildContext parent, IBuildContext sequence) + : base(parent, sequence, null) + { + } + + public override void BuildQuery<T>(Query<T> query, ParameterExpression queryParameter) + { + query.SetNonQueryQuery(); + } + + public override Expression BuildExpression(Expression expression, int level) + { + throw new InvalidOperationException(); + } + + public override SqlInfo[] ConvertToSql(Expression expression, int level, ConvertFlags flags) + { + throw new InvalidOperationException(); + } + + public override SqlInfo[] ConvertToIndex(Expression expression, int level, ConvertFlags flags) + { + throw new InvalidOperationException(); + } + + public override IsExpressionResult IsExpression(Expression expression, int level, RequestFor requestFlag) + { + throw new InvalidOperationException(); + } + + public override IBuildContext GetContext(Expression expression, int level, BuildInfo buildInfo) + { + throw new InvalidOperationException(); + } + } + + #endregion + + #region Set + + internal class Set : MethodCallBuilder + { + protected override bool CanBuildMethodCall(ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo) + { + return methodCall.IsQueryable("Set"); + } + + protected override IBuildContext BuildMethodCall(ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo) + { + var sequence = builder.BuildSequence(new BuildInfo(buildInfo, methodCall.Arguments[0])); + var extract = (LambdaExpression)methodCall.Arguments[1].Unwrap(); + var update = methodCall.Arguments[2].Unwrap(); + + if (update.NodeType == ExpressionType.Lambda) + ParseSet( + builder, + buildInfo, + extract, + (LambdaExpression)update, + sequence, + sequence.SqlQuery.Update.Table, + sequence.SqlQuery.Update.Items); + else + ParseSet( + builder, + buildInfo, + extract, + update, + sequence, + sequence.SqlQuery.Update.Items); + + return sequence; + } + + protected override SequenceConvertInfo Convert( + ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo, ParameterExpression param) + { + return null; + } + } + + #endregion + } +}