Mercurial > pub > bltoolkit
diff Source/Data/Linq/Builder/AggregationBuilder.cs @ 0:f990fcb411a9
Копия текущей версии из github
author | cin |
---|---|
date | Thu, 27 Mar 2014 21:46:09 +0400 |
parents | |
children |
line wrap: on
line diff
--- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/Source/Data/Linq/Builder/AggregationBuilder.cs Thu Mar 27 21:46:09 2014 +0400 @@ -0,0 +1,168 @@ +using System; +using System.Linq; +using System.Linq.Expressions; + +namespace BLToolkit.Data.Linq.Builder +{ + using BLToolkit.Linq; + using Data.Sql; + using Reflection; + + class AggregationBuilder : MethodCallBuilder + { + public static string[] MethodNames = new[] { "Average", "Min", "Max", "Sum" }; + + protected override bool CanBuildMethodCall(ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo) + { + return methodCall.IsQueryable(MethodNames); + } + + protected override IBuildContext BuildMethodCall(ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo) + { + var sequence = builder.BuildSequence(new BuildInfo(buildInfo, methodCall.Arguments[0])); + + if (sequence.SqlQuery.Select.IsDistinct || + sequence.SqlQuery.Select.TakeValue != null || + sequence.SqlQuery.Select.SkipValue != null || + !sequence.SqlQuery.GroupBy.IsEmpty) + { + sequence = new SubQueryContext(sequence); + } + + if (!sequence.SqlQuery.OrderBy.IsEmpty) + { + if (sequence.SqlQuery.Select.TakeValue == null && sequence.SqlQuery.Select.SkipValue == null) + sequence.SqlQuery.OrderBy.Items.Clear(); + else + sequence = new SubQueryContext(sequence); + } + + var context = new AggregationContext(buildInfo.Parent, sequence, methodCall); + var sql = sequence.ConvertToSql(null, 0, ConvertFlags.Field).Select(_ => _.Sql).ToArray(); + + if (sql.Length == 1 && sql[0] is SqlQuery) + { + var query = (SqlQuery)sql[0]; + + if (query.Select.Columns.Count == 1) + { + var join = SqlQuery.OuterApply(query); + context.SqlQuery.From.Tables[0].Joins.Add(join.JoinedTable); + sql[0] = query.Select.Columns[0]; + } + } + + context.Sql = context.SqlQuery; + context.FieldIndex = context.SqlQuery.Select.Add( + new SqlFunction(methodCall.Type, methodCall.Method.Name, sql)); + + return context; + } + + protected override SequenceConvertInfo Convert( + ExpressionBuilder builder, MethodCallExpression methodCall, BuildInfo buildInfo, ParameterExpression param) + { + return null; + } + + class AggregationContext : SequenceContextBase + { + public AggregationContext(IBuildContext parent, IBuildContext sequence, MethodCallExpression methodCall) + : base(parent, sequence, null) + { + _returnType = methodCall.Method.ReturnType; + _methodName = methodCall.Method.Name; + } + + readonly string _methodName; + readonly Type _returnType; + private SqlInfo[] _index; + + public int FieldIndex; + public ISqlExpression Sql; + + static object CheckNullValue(object value, object context) + { + if (value == null || value is DBNull) + throw new InvalidOperationException(string.Format("Function {0} returns non-nullable value, but result is NULL. Use nullable version of the function instead.", context)); + + return value; + } + + public override void BuildQuery<T>(Query<T> query, ParameterExpression queryParameter) + { + var expr = BuildExpression(FieldIndex); + var mapper = Builder.BuildMapper<object>(expr); + + query.SetElementQuery(mapper.Compile()); + } + + public override Expression BuildExpression(Expression expression, int level) + { + return BuildExpression(ConvertToIndex(expression, level, ConvertFlags.Field)[0].Index); + } + + Expression BuildExpression(int fieldIndex) + { + Expression expr; + + if (_returnType.IsClass || _methodName == "Sum" || TypeHelper.IsNullableType(_returnType)) + { + expr = Builder.BuildSql(_returnType, fieldIndex); + } + else + { + expr = Builder.BuildSql( + _returnType, + fieldIndex, + ReflectionHelper.Expressor<object>.MethodExpressor(o => CheckNullValue(o, o)), + Expression.Constant(_methodName)); + } + + return expr; + } + + public override SqlInfo[] ConvertToSql(Expression expression, int level, ConvertFlags flags) + { + switch (flags) + { + case ConvertFlags.All : + case ConvertFlags.Key : + case ConvertFlags.Field : return Sequence.ConvertToSql(expression, level + 1, flags); + } + + throw new InvalidOperationException(); + } + + public override SqlInfo[] ConvertToIndex(Expression expression, int level, ConvertFlags flags) + { + switch (flags) + { + case ConvertFlags.Field : + return _index ?? (_index = new[] + { + new SqlInfo { Query = Parent.SqlQuery, Index = Parent.SqlQuery.Select.Add(Sql), Sql = Sql, } + }); + } + + throw new InvalidOperationException(); + } + + public override IsExpressionResult IsExpression(Expression expression, int level, RequestFor requestFlag) + { + switch (requestFlag) + { + case RequestFor.Root : return new IsExpressionResult(Lambda != null && expression == Lambda.Parameters[0]); + case RequestFor.Expression : return IsExpressionResult.True; + } + + return IsExpressionResult.False; + } + + public override IBuildContext GetContext(Expression expression, int level, BuildInfo buildInfo) + { + throw new InvalidOperationException(); + } + } + } +}