Mercurial > pub > bltoolkit
view Source/Data/Linq/Builder/UpdateBuilder.cs @ 9:1e85f66cf767 default tip
update bltoolkit
author | nickolay |
---|---|
date | Thu, 05 Apr 2018 20:53:26 +0300 |
parents | f990fcb411a9 |
children |
line wrap: on
line source
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;

namespace BLToolkit.Data.Linq.Builder
{
    using BLToolkit.Linq;
    using Data.Sql;
    using Reflection;

    /// <summary>
    /// Method-call builder for the queryable <c>Update</c> method. Translates the call into
    /// SET items on the sequence's <c>SqlQuery.Update</c> clause and marks the query as
    /// <see cref="QueryType.Update"/>. The nested <see cref="Set"/> builder handles the
    /// queryable <c>Set</c> method, which appends a single SET item.
    /// </summary>
    class UpdateBuilder : MethodCallBuilder
    {
        #region Update

        /// <summary>Accepts only queryable calls named <c>Update</c>.</summary>
        protected override bool CanBuildMethodCall(ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo)
        {
            return methodCall.IsQueryable("Update");
        }

        /// <summary>
        /// Builds the source sequence from the first argument, then fills the update clause
        /// according to the number of arguments of the <c>Update</c> overload being called.
        /// </summary>
        /// <returns>An <see cref="UpdateContext"/> wrapping the (possibly filtered) sequence.</returns>
        protected override IBuildContext BuildMethodCall(ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo)
        {
            var sequence = builder.BuildSequence(new BuildInfo(buildInfo, methodCall.Arguments[0]));

            switch (methodCall.Arguments.Count)
            {
                case 1 : // int Update<T>(this IUpdateable<T> source)
                    CheckAssociation(sequence);
                    break;

                case 2 : // int Update<T>(this IQueryable<T> source, Expression<Func<T,T>> setter)
                    {
                        CheckAssociation(sequence);

                        BuildSetter(
                            builder,
                            buildInfo,
                            (LambdaExpression)methodCall.Arguments[1].Unwrap(),
                            sequence,
                            sequence.SqlQuery.Update.Items,
                            sequence);

                        break;
                    }

                case 3 :
                    {
                        var expr = methodCall.Arguments[1].Unwrap();

                        if (expr is LambdaExpression)
                        {
                            // int Update<T>(this IQueryable<T> source, Expression<Func<T,bool>> predicate, Expression<Func<T,T>> setter)
                            //
                            sequence = builder.BuildWhere(buildInfo.Parent, sequence, (LambdaExpression)expr, false);

                            CheckAssociation(sequence);

                            BuildSetter(
                                builder,
                                buildInfo,
                                (LambdaExpression)methodCall.Arguments[2].Unwrap(),
                                sequence,
                                sequence.SqlQuery.Update.Items,
                                sequence);
                        }
                        else
                        {
                            // static int Update<TSource,TTarget>(this IQueryable<TSource> source, Table<TTarget> target, Expression<Func<TSource,TTarget>> setter)
                            //
                            // The target table is built against its own, separate SqlQuery.
                            var into = builder.BuildSequence(new BuildInfo(buildInfo, expr, new SqlQuery()));

                            sequence.ConvertToIndex(null, 0, ConvertFlags.All);
                            sequence.SqlQuery.ResolveWeakJoins(new List<ISqlTableSource>());
                            sequence.SqlQuery.Select.Columns.Clear();

                            BuildSetter(
                                builder,
                                buildInfo,
                                (LambdaExpression)methodCall.Arguments[2].Unwrap(),
                                into,
                                sequence.SqlQuery.Update.Items,
                                sequence);

                            var sql = sequence.SqlQuery;

                            // Rebuild the select list so it holds exactly the SET value expressions.
                            sql.Select.Columns.Clear();

                            foreach (var item in sql.Update.Items)
                                sql.Select.Columns.Add(new SqlQuery.Column(sql, item.Expression));

                            sql.Update.Table = ((TableBuilder.TableContext)into).SqlTable;
                        }

                        break;
                    }
            }

            sequence.SqlQuery.QueryType = QueryType.Update;

            return new UpdateContext(buildInfo.Parent, sequence);
        }

        /// <summary>
        /// For a scalar <see cref="SelectContext"/>, sets <c>SqlQuery.Update.Table</c> explicitly:
        /// to the associated table when the selected expression is an association, otherwise to
        /// the selected table when that table is not the first source of the FROM clause.
        /// Does nothing for any other kind of sequence.
        /// </summary>
        static void CheckAssociation(IBuildContext sequence)
        {
            var ctx = sequence as SelectContext;

            if (ctx == null || !ctx.IsScalar)
                return;

            var res = ctx.IsExpression(null, 0, RequestFor.Association);

            if (res.Result && res.Context is TableBuilder.AssociatedTableContext)
            {
                var atc = (TableBuilder.AssociatedTableContext)res.Context;
                sequence.SqlQuery.Update.Table = atc.SqlTable;
                return;
            }

            res = ctx.IsExpression(null, 0, RequestFor.Table);

            if (res.Result && res.Context is TableBuilder.TableContext)
            {
                var tc = (TableBuilder.TableContext)res.Context;

                if (sequence.SqlQuery.From.Tables.Count == 0 || sequence.SqlQuery.From.Tables[0].Source != tc.SqlQuery)
                    sequence.SqlQuery.Update.Table = tc.SqlTable;
            }
        }

        /// <summary>Always returns <c>null</c>.</summary>
        protected override SequenceConvertInfo Convert(
            ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo, ParameterExpression param)
        {
            return null;
        }

        #endregion

        #region Helpers

        /// <summary>
        /// Returns the property that corresponds to <paramref name="member"/> when it is an
        /// accessor method; returns <paramref name="member"/> unchanged otherwise.
        /// </summary>
        static MemberInfo NormalizeMember(MemberInfo member)
        {
            if (member is MethodInfo)
                return TypeHelper.GetPropertyByMethod((MethodInfo)member);

            return member;
        }

        /// <summary>
        /// Converts a member access into the column it maps to in <paramref name="context"/>.
        /// </summary>
        /// <exception cref="LinqException">The member does not resolve to any column.</exception>
        static SqlInfo ResolveSetColumn(IBuildContext context, Expression memberAccess, MemberInfo member)
        {
            var column = context.ConvertToSql(memberAccess, 1, ConvertFlags.Field);

            // Fail with a descriptive error instead of an IndexOutOfRangeException on column[0].
            if (column.Length == 0)
                throw new LinqException("Member '{0}.{1}' is not a table column.", member.DeclaringType.Name, member.Name);

            return column[0];
        }

        /// <summary>
        /// Attaches the enum converter of <paramref name="member"/> (looked up through its
        /// declaring type's <see cref="TypeAccessor"/>) to <paramref name="value"/>.
        /// </summary>
        static void ApplyEnumConverter(ExpressionBuilder builder, SqlValueBase value, MemberInfo member)
        {
            var memberAccessor = TypeAccessor.GetAccessor(member.DeclaringType)[member.Name];
            value.SetEnumConverter(memberAccessor, builder.MappingSchema);
        }

        /// <summary>
        /// Strips <c>Convert</c>/<c>ConvertChecked</c> nodes from the body of a <c>Set</c>
        /// extractor lambda and returns the underlying member access.
        /// </summary>
        /// <exception cref="LinqException">
        /// The body is not a member access rooted at the lambda's first parameter.
        /// </exception>
        static MemberExpression GetSetTarget(LambdaExpression extract)
        {
            var ext = extract.Body;

            while (ext.NodeType == ExpressionType.Convert || ext.NodeType == ExpressionType.ConvertChecked)
                ext = ((UnaryExpression)ext).Operand;

            if (ext.NodeType != ExpressionType.MemberAccess || ext.GetRootObject() != extract.Parameters[0])
                throw new LinqException("Member expression expected for the 'Set' statement.");

            return (MemberExpression)ext;
        }

        /// <summary>
        /// Translates a setter lambda into SET items appended to <paramref name="items"/>.
        /// </summary>
        /// <param name="builder">Expression builder used for SQL conversion.</param>
        /// <param name="buildInfo">Build info whose <c>Parent</c> becomes the parent of the temporary expression context.</param>
        /// <param name="setter">Setter lambda; an object initializer body is walked member by member.</param>
        /// <param name="into">Context in which the target columns are resolved.</param>
        /// <param name="items">List receiving the generated SET expressions.</param>
        /// <param name="sequence">Source sequence in which the assigned values are resolved.</param>
        /// <exception cref="LinqException">
        /// A non-initializer body produced a value without a member, or a target member is not a column.
        /// </exception>
        internal static void BuildSetter(
            ExpressionBuilder builder,
            BuildInfo buildInfo,
            LambdaExpression setter,
            IBuildContext into,
            List<SqlQuery.SetExpression> items,
            IBuildContext sequence)
        {
            var path = Expression.Parameter(setter.Body.Type, "p");
            var ctx  = new ExpressionContext(buildInfo.Parent, sequence, setter);

            if (setter.Body.NodeType == ExpressionType.MemberInit)
            {
                var ex = (MemberInitExpression)setter.Body;

                // NOTE(review): the parent is read after ctx has been constructed, whereas ParseSet
                // reads it before constructing its context, and the else branch below does not
                // restore the parent at all. Confirm against ExpressionContext's constructor.
                var p = sequence.Parent;

                BuildSetter(builder, into, items, ctx, ex, path);

                builder.ReplaceParent(ctx, p);
            }
            else
            {
                var sqlInfo = ctx.ConvertToSql(setter.Body, 0, ConvertFlags.All);

                foreach (var info in sqlInfo)
                {
                    // The message mentions "insert" although this is the update builder;
                    // it is kept verbatim because callers or tests may match on it.
                    if (info.Members.Count == 0)
                        throw new LinqException("Object initializer expected for insert statement.");

                    if (info.Members.Count != 1)
                        throw new InvalidOperationException();

                    // Normalized like every other setter path so that an accessor method
                    // does not break MakeMemberAccess or the PropertyInfo cast below.
                    var member = NormalizeMember(info.Members[0]);
                    var pe     = Expression.MakeMemberAccess(path, member);
                    var column = ResolveSetColumn(into, pe, member);
                    var expr   = info.Sql;

                    if (expr is SqlParameter)
                    {
                        var type = member.MemberType == MemberTypes.Field ?
                            ((FieldInfo)   member).FieldType :
                            ((PropertyInfo)member).PropertyType;

                        if (TypeHelper.IsEnumOrNullableEnum(type))
                        {
                            var memberAccessor = TypeAccessor.GetAccessor(member.DeclaringType)[member.Name];
                            ((SqlParameter)expr).SetEnumConverter(memberAccessor, builder.MappingSchema);
                        }
                    }

                    items.Add(new SqlQuery.SetExpression(column.Sql, expr));
                }
            }
        }

        /// <summary>
        /// Walks the bindings of an object initializer and appends one SET item per assigned
        /// member. A nested initializer whose target is not itself a field in
        /// <paramref name="into"/> is descended into recursively, extending <paramref name="path"/>.
        /// </summary>
        /// <exception cref="InvalidOperationException">A binding is not a plain member assignment.</exception>
        /// <exception cref="LinqException">A target member is not a column.</exception>
        static void BuildSetter(
            ExpressionBuilder builder,
            IBuildContext into,
            List<SqlQuery.SetExpression> items,
            IBuildContext ctx,
            MemberInitExpression expression,
            Expression path)
        {
            foreach (var binding in expression.Bindings)
            {
                var ma = binding as MemberAssignment;

                if (ma == null)
                    throw new InvalidOperationException();

                var member = NormalizeMember(binding.Member);
                var pe     = Expression.MakeMemberAccess(path, member);

                if (ma.Expression is MemberInitExpression && !into.IsExpression(pe, 1, RequestFor.Field).Result)
                {
                    BuildSetter(builder, into, items, ctx, (MemberInitExpression)ma.Expression, pe);
                }
                else
                {
                    var column = ResolveSetColumn(into, pe, member);
                    var expr   = builder.ConvertToSqlExpression(ctx, ma.Expression, false);

                    if (expr is SqlValueBase && TypeHelper.IsEnumOrNullableEnum(ma.Expression.Type))
                        ApplyEnumConverter(builder, (SqlValueBase)expr, member);

                    items.Add(new SqlQuery.SetExpression(column.Sql, expr));
                }
            }
        }

        /// <summary>
        /// Handles <c>Set(extract, update)</c> where <paramref name="update"/> is a lambda:
        /// resolves the target column and appends a SET item whose value is the converted lambda body.
        /// </summary>
        /// <param name="builder">Expression builder used for SQL conversion.</param>
        /// <param name="buildInfo">Build info whose <c>Parent</c> becomes the parent of the temporary expression context.</param>
        /// <param name="extract">Lambda selecting the member to assign.</param>
        /// <param name="update">Lambda producing the new value.</param>
        /// <param name="select">Context in which both lambdas are resolved.</param>
        /// <param name="table">When not null, the column is looked up by name in this table's fields.</param>
        /// <param name="items">List receiving the generated SET expression.</param>
        /// <exception cref="LinqException">
        /// <paramref name="extract"/> is not a member access on its parameter, or the member is not a column.
        /// </exception>
        internal static void ParseSet(
            ExpressionBuilder builder,
            BuildInfo buildInfo,
            LambdaExpression extract,
            LambdaExpression update,
            IBuildContext select,
            SqlTable table,
            List<SqlQuery.SetExpression> items)
        {
            var body   = GetSetTarget(extract);
            var member = NormalizeMember(body.Member);

            // Dotted member path below the root expression, ignoring nullable value members.
            // string.Join yields an empty name for an empty path instead of throwing
            // "Sequence contains no elements" the way an unseeded Aggregate does.
            var names = body.GetMembers()
                .Skip(1)
                .Select(ex =>
                {
                    var me = ex as MemberExpression;
                    return me == null ? null : NormalizeMember(me.Member);
                })
                .Where(m => m != null && !TypeHelper.IsNullableValueMember(m))
                .Select(m => m.Name)
                .ToArray();

            var name = string.Join(".", names);

            if (table != null && !table.Fields.ContainsKey(name))
                throw new LinqException("Member '{0}.{1}' is not a table column.", member.DeclaringType.Name, name);

            var column = table != null ?
                table.Fields[name] :
                ResolveSetColumn(select, body, member).Sql;

            // Remember the current parent so it can be put back once the value has been converted.
            var sp   = select.Parent;
            var ctx  = new ExpressionContext(buildInfo.Parent, select, update);
            var expr = builder.ConvertToSqlExpression(ctx, update.Body, false);

            builder.ReplaceParent(ctx, sp);

            if (expr is SqlValueBase && TypeHelper.IsEnumOrNullableEnum(update.Body.Type))
                ApplyEnumConverter(builder, (SqlValueBase)expr, member);

            items.Add(new SqlQuery.SetExpression(column, expr));
        }

        /// <summary>
        /// Handles <c>Set(extract, value)</c> where <paramref name="update"/> is a plain value
        /// expression: registers it in <c>builder.AsParameters</c> unless its type is constant,
        /// then appends a SET item for the target column.
        /// </summary>
        /// <param name="builder">Expression builder used for SQL conversion.</param>
        /// <param name="buildInfo">Not used by this overload.</param>
        /// <param name="extract">Lambda selecting the member to assign.</param>
        /// <param name="update">Expression producing the new value.</param>
        /// <param name="select">Context in which the column and the value are resolved.</param>
        /// <param name="items">List receiving the generated SET expression.</param>
        /// <exception cref="LinqException">
        /// <paramref name="extract"/> is not a member access on its parameter, or the member is not a column.
        /// </exception>
        internal static void ParseSet(
            ExpressionBuilder builder,
            BuildInfo buildInfo,
            LambdaExpression extract,
            Expression update,
            IBuildContext select,
            List<SqlQuery.SetExpression> items)
        {
            if (!ExpressionHelper.IsConstant(update.Type) && !builder.AsParameters.Contains(update))
                builder.AsParameters.Add(update);

            var body   = GetSetTarget(extract);
            var member = NormalizeMember(body.Member);
            var column = ResolveSetColumn(select, body, member);
            var expr   = builder.ConvertToSql(select, update, false, false);

            if (expr is SqlValueBase && TypeHelper.IsEnumOrNullableEnum(update.Type))
                ApplyEnumConverter(builder, (SqlValueBase)expr, member);

            items.Add(new SqlQuery.SetExpression(column.Sql, expr));
        }

        #endregion

        #region UpdateContext

        /// <summary>
        /// Terminal context of an update query: it configures the query as a non-query command
        /// and rejects every expression-related request with <see cref="InvalidOperationException"/>.
        /// </summary>
        class UpdateContext : SequenceContextBase
        {
            public UpdateContext(IBuildContext parent, IBuildContext sequence)
                : base(parent, sequence, null)
            {
            }

            public override void BuildQuery<T>(Query<T> query, ParameterExpression queryParameter)
            {
                query.SetNonQueryQuery();
            }

            public override Expression BuildExpression(Expression expression, int level)
            {
                throw new InvalidOperationException();
            }

            public override SqlInfo[] ConvertToSql(Expression expression, int level, ConvertFlags flags)
            {
                throw new InvalidOperationException();
            }

            public override SqlInfo[] ConvertToIndex(Expression expression, int level, ConvertFlags flags)
            {
                throw new InvalidOperationException();
            }

            public override IsExpressionResult IsExpression(Expression expression, int level, RequestFor requestFlag)
            {
                throw new InvalidOperationException();
            }

            public override IBuildContext GetContext(Expression expression, int level, BuildInfo buildInfo)
            {
                throw new InvalidOperationException();
            }
        }

        #endregion

        #region Set

        /// <summary>
        /// Method-call builder for the queryable <c>Set</c> method. Appends one SET item to the
        /// sequence's update clause and returns the same sequence.
        /// </summary>
        internal class Set : MethodCallBuilder
        {
            /// <summary>Accepts only queryable calls named <c>Set</c>.</summary>
            protected override bool CanBuildMethodCall(ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo)
            {
                return methodCall.IsQueryable("Set");
            }

            /// <summary>
            /// Dispatches to the matching <c>ParseSet</c> overload depending on whether the
            /// third argument is a lambda or a plain value expression.
            /// </summary>
            protected override IBuildContext BuildMethodCall(ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo)
            {
                var sequence = builder.BuildSequence(new BuildInfo(buildInfo, methodCall.Arguments[0]));
                var extract  = (LambdaExpression)methodCall.Arguments[1].Unwrap();
                var update   = methodCall.Arguments[2].Unwrap();

                if (update.NodeType == ExpressionType.Lambda)
                    ParseSet(
                        builder,
                        buildInfo,
                        extract,
                        (LambdaExpression)update,
                        sequence,
                        sequence.SqlQuery.Update.Table,
                        sequence.SqlQuery.Update.Items);
                else
                    ParseSet(
                        builder,
                        buildInfo,
                        extract,
                        update,
                        sequence,
                        sequence.SqlQuery.Update.Items);

                return sequence;
            }

            /// <summary>Always returns <c>null</c>.</summary>
            protected override SequenceConvertInfo Convert(
                ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo, ParameterExpression param)
            {
                return null;
            }
        }

        #endregion
    }
}