diff --git a/src/MongoDB.Bson/Serialization/Serializers/BsonClassMapSerializer.cs b/src/MongoDB.Bson/Serialization/Serializers/BsonClassMapSerializer.cs index fdcd59916ed..54e1dbb54bc 100644 --- a/src/MongoDB.Bson/Serialization/Serializers/BsonClassMapSerializer.cs +++ b/src/MongoDB.Bson/Serialization/Serializers/BsonClassMapSerializer.cs @@ -23,11 +23,22 @@ namespace MongoDB.Bson.Serialization { + /// + /// An interface implemented by BsonClassMapSerializer. + /// + public interface IBsonClassMapSerializer + { + /// + /// Gets the class map for a BsonClassMapSerializer. + /// + public BsonClassMap ClassMap { get; } + } + /// /// Represents a serializer for a class map. /// /// The type of the class. - public sealed class BsonClassMapSerializer : SerializerBase, IBsonIdProvider, IBsonDocumentSerializer, IBsonPolymorphicSerializer, IHasDiscriminatorConvention + public sealed class BsonClassMapSerializer : SerializerBase, IBsonClassMapSerializer, IBsonIdProvider, IBsonDocumentSerializer, IBsonPolymorphicSerializer, IHasDiscriminatorConvention { // private fields private readonly BsonClassMap _classMap; @@ -57,6 +68,9 @@ public BsonClassMapSerializer(BsonClassMap classMap) } // public properties + /// + public BsonClassMap ClassMap => _classMap; + /// public IDiscriminatorConvention DiscriminatorConvention => _classMap.GetDiscriminatorConvention(); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs index d4c74a22b09..0f08c7cfe03 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs @@ -17,25 +17,27 @@ using MongoDB.Bson.IO; using MongoDB.Bson.Serialization; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; +using MongoDB.Driver.Linq.Linq3Implementation.Misc; using MongoDB.Driver.Linq.Linq3Implementation.Serializers; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators { internal static class ConstantExpressionToAggregationExpressionTranslator { - public static AggregationExpression Translate(ConstantExpression constantExpression) + public static AggregationExpression Translate(ConstantExpression constantExpression, IBsonSerializer targetSerializer) { - var constantType = constantExpression.Type; - var constantSerializer = StandardSerializers.TryGetSerializer(constantType, out var serializer) ? serializer : BsonSerializer.LookupSerializer(constantType); - return Translate(constantExpression, constantSerializer); - } + var resultSerializer = targetSerializer; + if (resultSerializer == null) + { + var constantType = constantExpression.Type; + resultSerializer = StandardSerializers.TryGetSerializer(constantType, out var serializer) ? serializer : BsonSerializer.LookupSerializer(constantType); + } - public static AggregationExpression Translate(ConstantExpression constantExpression, IBsonSerializer constantSerializer) - { - var constantValue = constantExpression.Value; - var serializedValue = constantSerializer.ToBsonValue(constantValue); + var value = constantExpression.Value; + var serializedValue = SerializationHelper.SerializeValue(resultSerializer, value); var ast = AstExpression.Constant(serializedValue); - return new AggregationExpression(constantExpression, ast, constantSerializer); + + return new AggregationExpression(constantExpression, ast, resultSerializer); } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs index abb3ba1766d..bebb6dc6f6a 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs @@ -27,7 +27,7 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg internal static class ExpressionToAggregationExpressionTranslator { // public static methods - public static AggregationExpression Translate(TranslationContext context, Expression expression) + public static AggregationExpression Translate(TranslationContext context, Expression expression, IBsonSerializer targetSerializer = null) { switch (expression.NodeType) { @@ -63,11 +63,11 @@ public static AggregationExpression Translate(TranslationContext context, Expres case ExpressionType.ArrayLength: return ArrayLengthExpressionToAggregationExpressionTranslator.Translate(context, (UnaryExpression)expression); case ExpressionType.Call: - return MethodCallExpressionToAggregationExpressionTranslator.Translate(context, (MethodCallExpression)expression); + return MethodCallExpressionToAggregationExpressionTranslator.Translate(context, (MethodCallExpression)expression, targetSerializer); case ExpressionType.Conditional: return ConditionalExpressionToAggregationExpressionTranslator.Translate(context, (ConditionalExpression)expression); case ExpressionType.Constant: - return ConstantExpressionToAggregationExpressionTranslator.Translate((ConstantExpression)expression); + return ConstantExpressionToAggregationExpressionTranslator.Translate((ConstantExpression)expression, targetSerializer); case ExpressionType.Index: return IndexExpressionToAggregationExpressionTranslator.Translate(context, (IndexExpression)expression); case ExpressionType.ListInit: @@ -75,13 +75,13 @@ public static AggregationExpression Translate(TranslationContext context, Expres case ExpressionType.MemberAccess: return MemberExpressionToAggregationExpressionTranslator.Translate(context, (MemberExpression)expression); case ExpressionType.MemberInit: - return MemberInitExpressionToAggregationExpressionTranslator.Translate(context, (MemberInitExpression)expression); + return MemberInitExpressionToAggregationExpressionTranslator.Translate(context, (MemberInitExpression)expression, targetSerializer); case ExpressionType.Negate: return NegateExpressionToAggregationExpressionTranslator.Translate(context, (UnaryExpression)expression); case ExpressionType.New: - return NewExpressionToAggregationExpressionTranslator.Translate(context, (NewExpression)expression); + return NewExpressionToAggregationExpressionTranslator.Translate(context, (NewExpression)expression, targetSerializer); case ExpressionType.NewArrayInit: - return NewArrayInitExpressionToAggregationExpressionTranslator.Translate(context, (NewArrayExpression)expression); + return NewArrayInitExpressionToAggregationExpressionTranslator.Translate(context, (NewArrayExpression)expression, targetSerializer); case ExpressionType.Parameter: return ParameterExpressionToAggregationExpressionTranslator.Translate(context, (ParameterExpression)expression); case ExpressionType.TypeIs: @@ -91,19 +91,19 @@ public static AggregationExpression Translate(TranslationContext context, Expres throw new ExpressionNotSupportedException(expression); } - public static AggregationExpression TranslateEnumerable(TranslationContext context, Expression expression) + public static AggregationExpression TranslateEnumerable(TranslationContext context, Expression expression, IBsonSerializer targetSerializer = null) { - var aggregateExpression = Translate(context, expression); + var aggregateExpression = Translate(context, expression, targetSerializer); - var serializer = aggregateExpression.Serializer; - if (serializer is IWrappedEnumerableSerializer wrappedEnumerableSerializer) + var resultSerializer = aggregateExpression.Serializer; + if (resultSerializer is IWrappedEnumerableSerializer wrappedEnumerableSerializer) { var enumerableFieldName = wrappedEnumerableSerializer.EnumerableFieldName; - var enumerableElementSerializer = wrappedEnumerableSerializer.EnumerableElementSerializer; - var enumerableSerializer = IEnumerableSerializer.Create(enumerableElementSerializer); - var ast = AstExpression.GetField(aggregateExpression.Ast, enumerableFieldName); + var itemSerializer = wrappedEnumerableSerializer.EnumerableElementSerializer; - return new AggregationExpression(aggregateExpression.Expression, ast, enumerableSerializer); + var ast = AstExpression.GetField(aggregateExpression.Ast, enumerableFieldName); + resultSerializer = IEnumerableSerializer.Create(itemSerializer); + return new AggregationExpression(aggregateExpression.Expression, ast, resultSerializer); } return aggregateExpression; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs index e1a8cd5f399..f24fe68a1a8 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs @@ -28,22 +28,28 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class MemberInitExpressionToAggregationExpressionTranslator { - public static AggregationExpression Translate(TranslationContext context, MemberInitExpression expression) + public static AggregationExpression Translate(TranslationContext context, MemberInitExpression expression, IBsonSerializer targetSerializer) { if (expression.Type == typeof(BsonDocument)) { return NewBsonDocumentExpressionToAggregationExpressionTranslator.Translate(context, expression); } - return Translate(context, expression, expression.NewExpression, expression.Bindings); + return Translate(context, expression, expression.NewExpression, expression.Bindings, targetSerializer); } public static AggregationExpression Translate( TranslationContext context, Expression expression, NewExpression newExpression, - IReadOnlyList bindings) + IReadOnlyList bindings, + IBsonSerializer targetSerializer) { + if (targetSerializer != null) + { + return TranslateWithTargetSerializer(context, expression, newExpression, bindings, targetSerializer); + } + var constructorInfo = newExpression.Constructor; // note: can be null when using the default constructor with a struct var constructorArguments = newExpression.Arguments; var computedFields = new List(); @@ -83,8 +89,7 @@ public static AggregationExpression Translate( foreach (var binding in bindings) { var memberAssignment = (MemberAssignment)binding; - var member = memberAssignment.Member; - var memberMap = FindMemberMap(expression, classMap, member.Name); + var memberMap = FindMemberMap(expression, classMap, memberInfo: memberAssignment.Member); var valueExpression = memberAssignment.Expression; var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression); var memberSerializer = CoerceSourceSerializerToMemberSerializer(memberMap, valueTranslation.Serializer); @@ -100,6 +105,70 @@ public static AggregationExpression Translate( return new AggregationExpression(expression, ast, serializer); } + private static AggregationExpression TranslateWithTargetSerializer( + TranslationContext context, + Expression expression, + NewExpression newExpression, + IReadOnlyList bindings, + IBsonSerializer targetSerializer) + { + var resultSerializer = targetSerializer as IBsonDocumentSerializer; + if (resultSerializer == null) + { + throw new ExpressionNotSupportedException(expression, because: $"serializer class {targetSerializer.GetType()} does not implement IBsonDocumentSerializer."); + } + + var constructorInfo = newExpression.Constructor; // note: can be null when using the default constructor with a struct + var constructorArguments = newExpression.Arguments; + var computedFields = new List(); + + if (constructorInfo != null && constructorArguments.Count > 0) + { + var constructorParameters = constructorInfo.GetParameters(); + + // if the documentSerializer is a BsonClassMappedSerializer we can use the classMap and creatorMap + var classMap = (resultSerializer as IBsonClassMapSerializer)?.ClassMap; + var creatorMap = classMap == null ? null : FindMatchingCreatorMap(classMap, constructorInfo); + if (creatorMap == null && classMap != null) + { + throw new ExpressionNotSupportedException(expression, because: "no matching creator map found"); + } + var creatorMapArguments = creatorMap?.Arguments?.ToArray(); + + for (var i = 0; i < constructorParameters.Length; i++) + { + var argumentExpression = constructorArguments[i]; + + // if we have a classMap (and therefore a creatorMap also) use them + // otherwise fall back to matching constructor parameter names to member names + var (elementName, memberSerializer) = classMap != null ? + FindMemberElementNameAndSerializer(argumentExpression, classMap, memberInfo: creatorMapArguments[i]) : + FindMemberElementNameAndSerializer(argumentExpression, resultSerializer, constructorParameterName: constructorParameters[i].Name); + + var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, argumentExpression, targetSerializer: memberSerializer); + computedFields.Add(AstExpression.ComputedField(elementName, argumentTranslation.Ast)); + } + } + + foreach (var binding in bindings) + { + var memberAssignment = (MemberAssignment)binding; + var member = memberAssignment.Member; + var valueExpression = memberAssignment.Expression; + if (!resultSerializer.TryGetMemberSerializationInfo(member.Name, out var memberSerializationInfo)) + { + throw new ExpressionNotSupportedException(valueExpression, expression, because: $"couldn't find member {member.Name}"); + } + var memberSerializer = memberSerializationInfo.Serializer; + + var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression, memberSerializer); + computedFields.Add(AstExpression.ComputedField(memberSerializationInfo.ElementName, valueTranslation.Ast)); + } + + var ast = AstExpression.ComputedDocument(computedFields); + return new AggregationExpression(expression, ast, resultSerializer); + } + private static BsonClassMap CreateClassMap(Type classType, ConstructorInfo constructorInfo, out BsonCreatorMap creatorMap) { BsonClassMap baseClassMap = null; @@ -190,11 +259,44 @@ private static void EnsureDefaultValue(BsonMemberMap memberMap) memberMap.SetDefaultValue(defaultValue); } - private static BsonMemberMap FindMemberMap(Expression expression, BsonClassMap classMap, string memberName) + private static BsonCreatorMap FindMatchingCreatorMap(BsonClassMap classMap, ConstructorInfo constructorInfo) + => classMap?.CreatorMaps.FirstOrDefault(m => m.MemberInfo.Equals(constructorInfo)); + + private static (string, IBsonSerializer) FindMemberElementNameAndSerializer( + Expression expression, + BsonClassMap classMap, + MemberInfo memberInfo) + { + var memberMap = FindMemberMap(expression, classMap, memberInfo); + return (memberMap.ElementName, memberMap.GetSerializer()); + } + + private static (string, IBsonSerializer) FindMemberElementNameAndSerializer( + Expression expression, + IBsonDocumentSerializer documentSerializer, + string constructorParameterName) + { + // case insensitive GetMember could return some false hits but TryGetMemberSerializationInfo will filter them out + var bindingFlags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.FlattenHierarchy | BindingFlags.IgnoreCase; + foreach (var memberInfo in documentSerializer.ValueType.GetMember(constructorParameterName, bindingFlags)) + { + if (documentSerializer.TryGetMemberSerializationInfo(memberInfo.Name, out var serializationInfo)) + { + return (serializationInfo.ElementName, serializationInfo.Serializer); + } + } + + throw new ExpressionNotSupportedException(expression, because: $"no matching member map found for constructor parameter: {constructorParameterName}"); + } + + private static BsonMemberMap FindMemberMap( + Expression expression, + BsonClassMap classMap, + MemberInfo memberInfo) { foreach (var memberMap in classMap.DeclaredMemberMaps) { - if (memberMap.MemberName == memberName) + if (memberMap.MemberInfo == memberInfo) { return memberMap; } @@ -202,10 +304,10 @@ private static BsonMemberMap FindMemberMap(Expression expression, BsonClassMap c if (classMap.BaseClassMap != null) { - return FindMemberMap(expression, classMap.BaseClassMap, memberName); + return FindMemberMap(expression, classMap.BaseClassMap, memberInfo); } - throw new ExpressionNotSupportedException(expression, because: $"can't find member map: {memberName}"); + throw new ExpressionNotSupportedException(expression, because: $"no member map found for member: {memberInfo.Name}"); } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs index 1fe6689fbdc..eb31beaf92a 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs @@ -14,39 +14,40 @@ */ using System.Linq.Expressions; +using MongoDB.Bson.Serialization; using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators { internal static class MethodCallExpressionToAggregationExpressionTranslator { - public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression) + public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression, IBsonSerializer targetSerializer) { switch (expression.Method.Name) { case "Abs": return AbsMethodToAggregationExpressionTranslator.Translate(context, expression); case "Add": return AddMethodToAggregationExpressionTranslator.Translate(context, expression); case "AddToSet": return AddToSetMethodToAggregationExpressionTranslator.Translate(context, expression); - case "Aggregate": return AggregateMethodToAggregationExpressionTranslator.Translate(context, expression); + case "Aggregate": return AggregateMethodToAggregationExpressionTranslator.Translate(context, expression, targetSerializer); case "All": return AllMethodToAggregationExpressionTranslator.Translate(context, expression); case "Any": return AnyMethodToAggregationExpressionTranslator.Translate(context, expression); - case "AsQueryable": return AsQueryableMethodToAggregationExpressionTranslator.Translate(context, expression); + case "AsQueryable": return AsQueryableMethodToAggregationExpressionTranslator.Translate(context, expression, targetSerializer); case "Average": return AverageMethodToAggregationExpressionTranslator.Translate(context, expression); case "Ceiling": return CeilingMethodToAggregationExpressionTranslator.Translate(context, expression); case "CompareTo": return CompareToMethodToAggregationExpressionTranslator.Translate(context, expression); - case "Concat": return ConcatMethodToAggregationExpressionTranslator.Translate(context, expression); + case "Concat": return ConcatMethodToAggregationExpressionTranslator.Translate(context, expression, targetSerializer); case "Constant": return ConstantMethodToAggregationExpressionTranslator.Translate(context, expression); case "Contains": return ContainsMethodToAggregationExpressionTranslator.Translate(context, expression); case "ContainsKey": return ContainsKeyMethodToAggregationExpressionTranslator.Translate(context, expression); case "ContainsValue": return ContainsValueMethodToAggregationExpressionTranslator.Translate(context, expression); case "CovariancePopulation": return CovariancePopulationMethodToAggregationExpressionTranslator.Translate(context, expression); case "CovarianceSample": return CovarianceSampleMethodToAggregationExpressionTranslator.Translate(context, expression); - case "Create": return CreateMethodToAggregationExpressionTranslator.Translate(context, expression); + case "Create": return CreateMethodToAggregationExpressionTranslator.Translate(context, expression, targetSerializer); case "DateFromString": return DateFromStringMethodToAggregationExpressionTranslator.Translate(context, expression); - case "DefaultIfEmpty": return DefaultIfEmptyMethodToAggregationExpressionTranslator.Translate(context, expression); + case "DefaultIfEmpty": return DefaultIfEmptyMethodToAggregationExpressionTranslator.Translate(context, expression, targetSerializer); case "DenseRank": return DenseRankMethodToAggregationExpressionTranslator.Translate(context, expression); case "Derivative": return DerivativeMethodToAggregationExpressionTranslator.Translate(context, expression); - case "Distinct": return DistinctMethodToAggregationExpressionTranslator.Translate(context, expression); + case "Distinct": return DistinctMethodToAggregationExpressionTranslator.Translate(context, expression, targetSerializer); case "DocumentNumber": return DocumentNumberMethodToAggregationExpressionTranslator.Translate(context, expression); case "Equals": return EqualsMethodToAggregationExpressionTranslator.Translate(context, expression); case "Except": return ExceptMethodToAggregationExpressionTranslator.Translate(context, expression); @@ -122,7 +123,7 @@ public static AggregationExpression Translate(TranslationContext context, Method case "Append": case "Prepend": - return AppendOrPrependMethodToAggregationExpressionTranslator.Translate(context, expression); + return AppendOrPrependMethodToAggregationExpressionTranslator.Translate(context, expression, targetSerializer); case "Bottom": case "BottomN": @@ -140,7 +141,7 @@ public static AggregationExpression Translate(TranslationContext context, Method case "ElementAt": case "ElementAtOrDefault": - return ElementAtMethodToAggregationExpressionTranslator.Translate(context, expression); + return ElementAtMethodToAggregationExpressionTranslator.Translate(context, expression, targetSerializer); case "First": case "FirstOrDefault": diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AggregateMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AggregateMethodToAggregationExpressionTranslator.cs index e2e2f0f1b8c..8bcbef04746 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AggregateMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AggregateMethodToAggregationExpressionTranslator.cs @@ -19,6 +19,7 @@ using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; using MongoDB.Driver.Linq.Linq3Implementation.Misc; using MongoDB.Driver.Linq.Linq3Implementation.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators { @@ -34,27 +35,33 @@ internal static class AggregateMethodToAggregationExpressionTranslator QueryableMethod.AggregateWithSeedFuncAndResultSelector }; - private static readonly MethodInfo[] __aggregateWithoutSeedMethods = + private static readonly MethodInfo[] __aggregateWithFuncMethods = { EnumerableMethod.AggregateWithFunc, QueryableMethod.AggregateWithFunc }; - private static readonly MethodInfo[] __aggregateWithSeedMethods = + private static readonly MethodInfo[] __aggregateWithSeedAndFuncMethods = { EnumerableMethod.AggregateWithSeedAndFunc, + QueryableMethod.AggregateWithSeedAndFunc + }; + + private static readonly MethodInfo[] __aggregateWithSeedAndFuncAndResultSelectorMethods = + { EnumerableMethod.AggregateWithSeedFuncAndResultSelector, - QueryableMethod.AggregateWithSeedAndFunc, QueryableMethod.AggregateWithSeedFuncAndResultSelector }; - private static readonly MethodInfo[] __aggregateWithSeedFuncAndResultSelectorMethods = - { + private static readonly MethodInfo[] __aggregateIncludingSeedMethods = + { + EnumerableMethod.AggregateWithSeedAndFunc, EnumerableMethod.AggregateWithSeedFuncAndResultSelector, + QueryableMethod.AggregateWithSeedAndFunc, QueryableMethod.AggregateWithSeedFuncAndResultSelector }; - public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression) + public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression, IBsonSerializer targetSerializer) { var method = expression.Method; var arguments = expression.Arguments; @@ -62,11 +69,12 @@ public static AggregationExpression Translate(TranslationContext context, Method if (method.IsOneOf(__aggregateMethods)) { var sourceExpression = arguments[0]; - var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); + var sourceTargetSerializer = GetSourceTargetSerializer(method, targetSerializer); + var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression, sourceTargetSerializer); NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation); var itemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); - if (method.IsOneOf(__aggregateWithoutSeedMethods)) + if (method.IsOneOf(__aggregateWithFuncMethods)) { var funcLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]); var funcParameters = funcLambda.Parameters; @@ -75,7 +83,8 @@ public static AggregationExpression Translate(TranslationContext context, Method var itemParameter = funcParameters[1]; var itemSymbol = context.CreateSymbolWithVarName(itemParameter, varName: "this", itemSerializer); // note: MQL uses $$this for the item being processed var funcContext = context.WithSymbols(accumulatorSymbol, itemSymbol); - var funcTranslation = ExpressionToAggregationExpressionTranslator.Translate(funcContext, funcLambda.Body); + var funcTargetSerializer = GetFuncTargetSerializer(method, targetSerializer); + var funcTranslation = ExpressionToAggregationExpressionTranslator.Translate(funcContext, funcLambda.Body, funcTargetSerializer); var (sourceVarBinding, sourceAst) = AstExpression.UseVarIfNotSimple("source", sourceTranslation.Ast); var seedVar = AstExpression.Var("seed"); @@ -95,10 +104,11 @@ public static AggregationExpression Translate(TranslationContext context, Method return new AggregationExpression(expression, ast, itemSerializer); } - else if (method.IsOneOf(__aggregateWithSeedMethods)) + else if (method.IsOneOf(__aggregateIncludingSeedMethods)) { var seedExpression = arguments[1]; - var seedTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, seedExpression); + var seedTargetSerializer = GetSeedTargetSerializer(method, targetSerializer); + var seedTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, seedExpression, seedTargetSerializer); var funcLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]); var funcParameters = funcLambda.Parameters; @@ -108,7 +118,8 @@ public static AggregationExpression Translate(TranslationContext context, Method var itemParameter = funcParameters[1]; var itemSymbol = context.CreateSymbolWithVarName(itemParameter, varName: "this", itemSerializer); // note: MQL uses $$this for the item being processed var funcContext = context.WithSymbols(accumulatorSymbol, itemSymbol); - var funcTranslation = ExpressionToAggregationExpressionTranslator.Translate(funcContext, funcLambda.Body); + var funcTargetSerializer = GetFuncTargetSerializer(method, targetSerializer); + var funcTranslation = ExpressionToAggregationExpressionTranslator.Translate(funcContext, funcLambda.Body, funcTargetSerializer); var ast = AstExpression.Reduce( input: sourceTranslation.Ast, @@ -116,13 +127,14 @@ public static AggregationExpression Translate(TranslationContext context, Method @in: funcTranslation.Ast); var serializer = accumulatorSerializer; - if (method.IsOneOf(__aggregateWithSeedFuncAndResultSelectorMethods)) + if (method.IsOneOf(__aggregateWithSeedAndFuncAndResultSelectorMethods)) { var resultSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[3]); var resultSelectorAccumulatorParameter = resultSelectorLambda.Parameters[0]; var resultSelectorAccumulatorSymbol = context.CreateSymbol(resultSelectorAccumulatorParameter, accumulatorSerializer); var resultSelectorContext = context.WithSymbol(resultSelectorAccumulatorSymbol); - var resultSelectorTranslation = ExpressionToAggregationExpressionTranslator.Translate(resultSelectorContext, resultSelectorLambda.Body); + var resultSelectorTargetSerializer = GetResultSelectorTargetSerializer(method, targetSerializer); + var resultSelectorTranslation = ExpressionToAggregationExpressionTranslator.Translate(resultSelectorContext, resultSelectorLambda.Body, resultSelectorTargetSerializer); ast = AstExpression.Let( var: AstExpression.VarBinding(resultSelectorAccumulatorSymbol.Var, ast), @@ -136,5 +148,57 @@ public static AggregationExpression Translate(TranslationContext context, Method throw new ExpressionNotSupportedException(expression); } + + private static IBsonSerializer GetFuncTargetSerializer(MethodInfo method, IBsonSerializer targetSerializer) + { + if (method.IsOneOf(__aggregateWithFuncMethods, __aggregateWithSeedAndFuncMethods)) + { + return targetSerializer; + } + + return null; + } + + private static IBsonSerializer GetResultSelectorTargetSerializer(MethodInfo method, IBsonSerializer targetSerializer) + { + if (method.IsOneOf(__aggregateWithSeedAndFuncAndResultSelectorMethods)) + { + return targetSerializer; + } + + return null; + } + + private static IBsonSerializer GetSeedTargetSerializer(MethodInfo method, IBsonSerializer targetSerializer) + { + if (method.IsOneOf(__aggregateWithSeedAndFuncMethods)) + { + return targetSerializer; + } + + return null; + } + + private static IBsonSerializer GetSourceTargetSerializer(MethodInfo method, IBsonSerializer targetSerializer) + { + IBsonSerializer itemSerializer = null; + if (method.IsOneOf(__aggregateWithFuncMethods)) + { + itemSerializer = targetSerializer; + } + + if (method.IsOneOf(__aggregateWithSeedAndFuncMethods)) + { + var genericArguments = method.GetGenericArguments(); + var sourceType = genericArguments[0]; + var accumulateType = genericArguments[1]; + if (sourceType == accumulateType) + { + itemSerializer = targetSerializer; + } + } + + return itemSerializer == null ? null : IEnumerableSerializer.Create(itemSerializer); + } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AppendOrPrependMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AppendOrPrependMethodToAggregationExpressionTranslator.cs index fa87251f83a..7f1b193fe3b 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AppendOrPrependMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AppendOrPrependMethodToAggregationExpressionTranslator.cs @@ -15,6 +15,7 @@ using System.Linq.Expressions; using System.Reflection; +using MongoDB.Bson.Serialization; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; using MongoDB.Driver.Linq.Linq3Implementation.Misc; using MongoDB.Driver.Linq.Linq3Implementation.Reflection; @@ -38,7 +39,7 @@ internal static class AppendOrPrependMethodToAggregationExpressionTranslator QueryableMethod.Append }; - public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression) + public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression, IBsonSerializer targetSerializer) { var method = expression.Method; var arguments = expression.Arguments; @@ -48,32 +49,22 @@ public static AggregationExpression Translate(TranslationContext context, Method var sourceExpression = arguments[0]; var elementExpression = arguments[1]; - var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); + var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression, targetSerializer); NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation); var itemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); - AggregationExpression elementTranslation; - if (elementExpression is ConstantExpression elementConstantExpression) + var elementTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, elementExpression, itemSerializer); + if (!elementTranslation.Serializer.Equals(itemSerializer)) { - var value = elementConstantExpression.Value; - var serializedValue = SerializationHelper.SerializeValue(itemSerializer, value); - elementTranslation = new AggregationExpression(elementExpression, AstExpression.Constant(serializedValue), itemSerializer); - } - else - { - elementTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, elementExpression); - if (!elementTranslation.Serializer.Equals(itemSerializer)) - { - throw new ExpressionNotSupportedException(expression, because: "argument serializers are not compatible"); - } + throw new ExpressionNotSupportedException(expression, because: "argument serializers are not compatible"); } var ast = method.IsOneOf(__appendMethods) ? AstExpression.ConcatArrays(sourceTranslation.Ast, AstExpression.ComputedArray(elementTranslation.Ast)) : AstExpression.ConcatArrays(AstExpression.ComputedArray(elementTranslation.Ast), sourceTranslation.Ast); - var serializer = NestedAsQueryableSerializer.CreateIEnumerableOrNestedAsQueryableSerializer(expression.Type, itemSerializer); + var resultSerializer = NestedAsQueryableSerializer.CreateIEnumerableOrNestedAsQueryableSerializer(expression.Type, itemSerializer); - return new AggregationExpression(expression, ast, serializer); + return new AggregationExpression(expression, ast, resultSerializer); } throw new ExpressionNotSupportedException(expression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AsQueryableMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AsQueryableMethodToAggregationExpressionTranslator.cs index b97dd8615b8..72a8ed788a8 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AsQueryableMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AsQueryableMethodToAggregationExpressionTranslator.cs @@ -23,7 +23,7 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class AsQueryableMethodToAggregationExpressionTranslator { - public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression) + public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression, IBsonSerializer targetSerializer) { var method = expression.Method; var arguments = expression.Arguments; @@ -31,12 +31,13 @@ public static AggregationExpression Translate(TranslationContext context, Method if (method.Is(QueryableMethod.AsQueryable)) { var sourceExpression = arguments[0]; - var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); + var sourceTargetSerializer = GetSourceTargetSerializer(targetSerializer); + var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression, sourceTargetSerializer); IBsonSerializer serializer; if (sourceTranslation.Serializer is INestedAsQueryableSerializer) { - serializer = sourceTranslation.Serializer; + serializer = sourceTranslation.Serializer; } else { @@ -49,5 +50,16 @@ public static AggregationExpression Translate(TranslationContext context, Method throw new ExpressionNotSupportedException(expression); } + + private static IBsonSerializer GetSourceTargetSerializer(IBsonSerializer targetSerializer) + { + if (targetSerializer is IBsonArraySerializer arraySerializer) + { + var itemSerializer = ArraySerializerHelper.GetItemSerializer(arraySerializer); + return IEnumerableSerializer.Create(itemSerializer); + } + + return null; + } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CompareToMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CompareToMethodToAggregationExpressionTranslator.cs index 014a337dcf4..98ae2595962 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CompareToMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CompareToMethodToAggregationExpressionTranslator.cs @@ -30,11 +30,23 @@ public static AggregationExpression Translate(TranslationContext context, Method if (IComparableMethod.IsCompareToMethod(method)) { var objectExpression = expression.Object; - var objectTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, objectExpression); var otherExpression = arguments[0]; - var otherTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, otherExpression); + + AggregationExpression objectTranslation; + AggregationExpression otherTranslation; + if (objectExpression is ConstantExpression && otherExpression is not ConstantExpression) + { + otherTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, otherExpression); + objectTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, objectExpression, otherTranslation.Serializer); + } + else + { + objectTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, objectExpression); + otherTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, otherExpression, objectTranslation.Serializer); + } var ast = AstExpression.Cmp(objectTranslation.Ast, otherTranslation.Ast); - return new AggregationExpression(expression, ast, new Int32Serializer()); + + return new AggregationExpression(expression, ast, Int32Serializer.Instance); } throw new ExpressionNotSupportedException(expression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ConcatMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ConcatMethodToAggregationExpressionTranslator.cs index 7dc3f32b209..52517ccb22f 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ConcatMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ConcatMethodToAggregationExpressionTranslator.cs @@ -14,16 +14,17 @@ */ using System.Linq.Expressions; +using MongoDB.Bson.Serialization; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators { internal static class ConcatMethodToAggregationExpressionTranslator { - public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression) + public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression, IBsonSerializer targetSerializer) { if (EnumerableConcatMethodToAggregationExpressionTranslator.CanTranslate(expression)) { - return EnumerableConcatMethodToAggregationExpressionTranslator.Translate(context, expression); + return EnumerableConcatMethodToAggregationExpressionTranslator.Translate(context, expression, targetSerializer); } if (StringConcatMethodToAggregationExpressionTranslator.CanTranslate(expression, out var method, out var arguments)) diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ConstantMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ConstantMethodToAggregationExpressionTranslator.cs index cc3aeea79e3..a14f0971f85 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ConstantMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ConstantMethodToAggregationExpressionTranslator.cs @@ -35,7 +35,7 @@ public static AggregationExpression Translate(TranslationContext context, Method var valueExpression = arguments[0]; var value = valueExpression.GetConstantValue(expression); - IBsonSerializer serializer = null; + IBsonSerializer resultSerializer = null; if (method.Is(MqlMethod.ConstantWithRepresentation)) { var representationExpression = arguments[1]; @@ -43,7 +43,7 @@ public static AggregationExpression Translate(TranslationContext context, Method var registeredSerializer = BsonSerializer.LookupSerializer(valueExpression.Type); if (registeredSerializer is IRepresentationConfigurable representationConfigurableSerializer) { - serializer = representationConfigurableSerializer.WithRepresentation(representation); + resultSerializer = representationConfigurableSerializer.WithRepresentation(representation); } else { @@ -53,14 +53,14 @@ public static AggregationExpression Translate(TranslationContext context, Method else if (method.Is(MqlMethod.ConstantWithSerializer)) { var serializerExpression = arguments[1]; - serializer = serializerExpression.GetConstantValue(expression); + resultSerializer = serializerExpression.GetConstantValue(expression); } - if (serializer != null) + if (resultSerializer != null) { - var serializedValue = SerializationHelper.SerializeValue(serializer, value); + var serializedValue = SerializationHelper.SerializeValue(resultSerializer, value); var ast = AstExpression.Constant(serializedValue); - return new AggregationExpression(expression, ast, serializer); + return new AggregationExpression(expression, ast, resultSerializer); } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ContainsKeyMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ContainsKeyMethodToAggregationExpressionTranslator.cs index 62af1cf9f02..925365bd588 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ContainsKeyMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ContainsKeyMethodToAggregationExpressionTranslator.cs @@ -70,7 +70,7 @@ private static AstExpression GetKeyFieldName(TranslationContext context, Express } ThrowIfKeyIsNotRepresentedAsAString(expression, keySerializer); - var keyTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, keyExpression); + var keyTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, keyExpression, keySerializer); return keyTranslation.Ast; } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ContainsMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ContainsMethodToAggregationExpressionTranslator.cs index 68c19d997de..fc1b6bb3d6c 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ContainsMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ContainsMethodToAggregationExpressionTranslator.cs @@ -77,8 +77,9 @@ private static AggregationExpression TranslateEnumerableContains(TranslationCont { var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation); + var itemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); - var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression); + var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression, itemSerializer); var ast = AstExpression.In(valueTranslation.Ast, sourceTranslation.Ast); return new AggregationExpression(expression, ast, BooleanSerializer.Instance); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ContainsValueMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ContainsValueMethodToAggregationExpressionTranslator.cs index c2437521c6d..0d30c8b0061 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ContainsValueMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ContainsValueMethodToAggregationExpressionTranslator.cs @@ -39,7 +39,7 @@ public static AggregationExpression Translate(TranslationContext context, Method var dictionarySerializer = GetDictionarySerializer(expression, dictionaryTranslation); var dictionaryRepresentation = dictionarySerializer.DictionaryRepresentation; - var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression); + var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression, dictionarySerializer.ValueSerializer); var (valueBinding, valueAst) = AstExpression.UseVarIfNotSimple("value", valueTranslation.Ast); AstExpression ast; diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CreateMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CreateMethodToAggregationExpressionTranslator.cs index d025ee64bff..5d606b2ea12 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CreateMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CreateMethodToAggregationExpressionTranslator.cs @@ -51,7 +51,7 @@ internal static class CreateMethodToAggregationExpressionTranslator ValueTupleMethod.Create8 }; - public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression) + public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression, IBsonSerializer targetSerializer) { var method = expression.Method; var arguments = expression.Arguments; @@ -59,40 +59,48 @@ public static AggregationExpression Translate(TranslationContext context, Method if (method.IsOneOf(__tupleCreateMethods) || method.IsOneOf(__valueTupleCreateMethods)) { var tupleType = method.ReturnType; + var isValueTuple = tupleType.IsValueTuple(); + + IBsonTupleSerializer tupleTargetSerializer = null; + if (targetSerializer != null && (tupleTargetSerializer = targetSerializer as IBsonTupleSerializer) == null) + { + throw new ExpressionNotSupportedException(expression, because: $"serializer {targetSerializer.GetType()} does not implement IBsonTupleSerializer"); + } var items = new AstExpression[arguments.Count]; var itemSerializers = new IBsonSerializer[arguments.Count]; for (var i = 0; i < arguments.Count; i++) { - var valueExpression = arguments[i]; - var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression); + var itemExpression = arguments[i]; + var itemTargetSerializer = tupleTargetSerializer?.GetItemSerializer(i); + var itemTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, itemExpression, itemTargetSerializer); AstExpression item; IBsonSerializer itemSerializer; if (i < 7) { - item = valueTranslation.Ast; - itemSerializer = valueTranslation.Serializer; + item = itemTranslation.Ast; + itemSerializer = itemTranslation.Serializer; } else { - item = AstExpression.ComputedArray(valueTranslation.Ast); - itemSerializer = CreateTupleSerializer(tupleType, new[] { valueTranslation.Serializer }); + item = AstExpression.ComputedArray(itemTranslation.Ast); + itemSerializer = CreateTupleSerializer(isValueTuple, [itemTranslation.Serializer]); } items[i] = item; itemSerializers[i] = itemSerializer; } - var ast = AstExpression.ComputedArray(items); - var tupleSerializer = CreateTupleSerializer(tupleType, itemSerializers); - return new AggregationExpression(expression, ast, tupleSerializer); + + var resultSerializer = targetSerializer ?? CreateTupleSerializer(isValueTuple, itemSerializers); + return new AggregationExpression(expression, ast, resultSerializer); } throw new ExpressionNotSupportedException(expression); } - private static IBsonSerializer CreateTupleSerializer(Type tupleType, IEnumerable itemSerializers) + private static IBsonSerializer CreateTupleSerializer(bool isValueTuple, IEnumerable itemSerializers) { - return tupleType.IsTuple() ? TupleSerializer.Create(itemSerializers) : ValueTupleSerializer.Create(itemSerializers); + return isValueTuple ? ValueTupleSerializer.Create(itemSerializers) : TupleSerializer.Create(itemSerializers); } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DefaultIfEmptyMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DefaultIfEmptyMethodToAggregationExpressionTranslator.cs index 029d2f4fd7e..a9fcfdd9229 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DefaultIfEmptyMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DefaultIfEmptyMethodToAggregationExpressionTranslator.cs @@ -16,6 +16,7 @@ using System.Linq.Expressions; using System.Reflection; using MongoDB.Bson; +using MongoDB.Bson.Serialization; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; using MongoDB.Driver.Linq.Linq3Implementation.Misc; using MongoDB.Driver.Linq.Linq3Implementation.Reflection; @@ -40,31 +41,43 @@ internal static class DefaultIfEmptyMethodToAggregationExpressionTranslator QueryableMethod.DefaultIfEmptyWithDefaultValue, }; - public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression) + public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression, IBsonSerializer targetSerializer) { var method = expression.Method; var arguments = expression.Arguments; + IBsonSerializer resultSerializer = null; + IBsonSerializer itemSerializer = null; + if (targetSerializer != null) + { + if (!(targetSerializer is IBsonArraySerializer arraySerializer)) + { + throw new ExpressionNotSupportedException(expression, because: $"the target serializer {targetSerializer.GetType()} does not implement IBsonArraySerializer."); + } + + resultSerializer = targetSerializer; + itemSerializer = ArraySerializerHelper.GetItemSerializer(arraySerializer); + } + if (method.IsOneOf(__defaultIfEmptyMethods)) { var sourceExpression = arguments[0]; - var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); + var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression, targetSerializer); NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation); - var itemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); + itemSerializer ??= ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); var (sourceVarBinding, sourceAst) = AstExpression.UseVarIfNotSimple("source", sourceTranslation.Ast); AstExpression defaultValueAst; if (method.IsOneOf(__defaultIfEmptyWithDefaultValueMethods)) { var defaultValueExpression = arguments[1]; - var defaultValueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, defaultValueExpression); - defaultValueAst = AstExpression.ComputedArray(new[] { defaultValueTranslation.Ast }); + var defaultValueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, defaultValueExpression, itemSerializer); + defaultValueAst = AstExpression.ComputedArray([defaultValueTranslation.Ast]); } else { - var sourceItemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); - var defaultValue = sourceItemSerializer.ValueType.GetDefaultValue(); - var serializedDefaultValue = SerializationHelper.SerializeValue(sourceItemSerializer, defaultValue); + var defaultValue = itemSerializer.ValueType.GetDefaultValue(); + var serializedDefaultValue = SerializationHelper.SerializeValue(itemSerializer, defaultValue); defaultValueAst = AstExpression.Constant(new BsonArray { serializedDefaultValue }); } var ast = AstExpression.Let( @@ -74,8 +87,8 @@ public static AggregationExpression Translate(TranslationContext context, Method defaultValueAst, sourceAst)); - var serializer = NestedAsQueryableSerializer.CreateIEnumerableOrNestedAsQueryableSerializer(expression.Type, itemSerializer); - return new AggregationExpression(expression, ast, serializer); + resultSerializer ??= NestedAsQueryableSerializer.CreateIEnumerableOrNestedAsQueryableSerializer(expression.Type, itemSerializer); + return new AggregationExpression(expression, ast, resultSerializer); } throw new ExpressionNotSupportedException(expression); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DistinctMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DistinctMethodToAggregationExpressionTranslator.cs index d63a77e2ff8..b944ecdfbf3 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DistinctMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/DistinctMethodToAggregationExpressionTranslator.cs @@ -15,6 +15,7 @@ using System.Linq.Expressions; using System.Reflection; +using MongoDB.Bson.Serialization; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; using MongoDB.Driver.Linq.Linq3Implementation.Misc; using MongoDB.Driver.Linq.Linq3Implementation.Reflection; @@ -30,7 +31,7 @@ internal static class DistinctMethodToAggregationExpressionTranslator QueryableMethod.Distinct }; - public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression) + public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression, IBsonSerializer targetSerializer) { var method = expression.Method; var arguments = expression.Arguments; @@ -38,7 +39,7 @@ public static AggregationExpression Translate(TranslationContext context, Method if (method.IsOneOf(__distinctMethods)) { var sourceExpression = arguments[0]; - var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); + var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression, targetSerializer); NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation); var itemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ElementAtMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ElementAtMethodToAggregationExpressionTranslator.cs index c29e684be46..6899aac331a 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ElementAtMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ElementAtMethodToAggregationExpressionTranslator.cs @@ -15,9 +15,11 @@ using System.Linq.Expressions; using System.Reflection; +using MongoDB.Bson.Serialization; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; using MongoDB.Driver.Linq.Linq3Implementation.Misc; using MongoDB.Driver.Linq.Linq3Implementation.Reflection; +using MongoDB.Driver.Linq.Linq3Implementation.Serializers; using MongoDB.Driver.Support; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators @@ -38,7 +40,7 @@ internal static class ElementAtMethodToAggregationExpressionTranslator QueryableMethod.ElementAtOrDefault }; - public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression) + public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression, IBsonSerializer targetSerializer) { var method = expression.Method; var arguments = expression.Arguments; @@ -46,7 +48,8 @@ public static AggregationExpression Translate(TranslationContext context, Method if (method.IsOneOf(__elementAtMethods)) { var sourceExpression = arguments[0]; - var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression); + var sourceTargetSerializer = GetSourceTargetSerializer(targetSerializer); + var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression, sourceTargetSerializer); NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation); var itemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer); @@ -79,5 +82,10 @@ public static AggregationExpression Translate(TranslationContext context, Method throw new ExpressionNotSupportedException(expression); } + + private static IBsonSerializer GetSourceTargetSerializer(IBsonSerializer targetSerializer) + { + return targetSerializer == null ? null : NestedAsQueryableSerializer.CreateIEnumerableOrNestedAsQueryableSerializer(targetSerializer.ValueType, targetSerializer); + } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/EnumerableConcatMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/EnumerableConcatMethodToAggregationExpressionTranslator.cs index 75d9982f917..344eb7249fa 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/EnumerableConcatMethodToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/EnumerableConcatMethodToAggregationExpressionTranslator.cs @@ -15,6 +15,7 @@ using System.Linq.Expressions; using System.Reflection; +using MongoDB.Bson.Serialization; using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions; using MongoDB.Driver.Linq.Linq3Implementation.Misc; using MongoDB.Driver.Linq.Linq3Implementation.Reflection; @@ -33,7 +34,7 @@ internal static class EnumerableConcatMethodToAggregationExpressionTranslator public static bool CanTranslate(MethodCallExpression expression) => expression.Method.IsOneOf(__concatMethods); - public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression) + public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression, IBsonSerializer targetSerializer) { var method = expression.Method; var arguments = expression.Arguments; @@ -41,10 +42,10 @@ public static AggregationExpression Translate(TranslationContext context, Method if (method.IsOneOf(__concatMethods)) { var firstExpression = arguments[0]; - var firstTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, firstExpression); + var firstTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, firstExpression, targetSerializer); NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, firstTranslation); var secondExpression = arguments[1]; - var secondTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, secondExpression); + var secondTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, secondExpression, firstTranslation.Serializer); var ast = AstExpression.ConcatArrays(firstTranslation.Ast, secondTranslation.Ast); var itemSerializer = ArraySerializerHelper.GetItemSerializer(firstTranslation.Serializer); var serializer = NestedAsQueryableSerializer.CreateIEnumerableOrNestedAsQueryableSerializer(expression.Type, itemSerializer); diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewArrayInitExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewArrayInitExpressionToAggregationExpressionTranslator.cs index c6d95c36a43..329a479c700 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewArrayInitExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewArrayInitExpressionToAggregationExpressionTranslator.cs @@ -24,15 +24,34 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg { internal static class NewArrayInitExpressionToAggregationExpressionTranslator { - public static AggregationExpression Translate(TranslationContext context, NewArrayExpression expression) + public static AggregationExpression Translate(TranslationContext context, NewArrayExpression expression, IBsonSerializer targetSerializer) { - var items = new List(); + IBsonArraySerializer resultSerializer = null; IBsonSerializer itemSerializer = null; + + if (targetSerializer != null) + { + if ((resultSerializer = targetSerializer as IBsonArraySerializer) == null) + { + throw new ExpressionNotSupportedException(expression, because: $"serializer class {targetSerializer} does not implement IBsonArraySerializer"); + } + if (!resultSerializer.TryGetItemSerializationInfo(out var itemSerializationInfo)) + { + throw new ExpressionNotSupportedException(expression, because: $"serializer class {targetSerializer} returned false for TryGetItemSerializationInfo"); + } + + itemSerializer = itemSerializationInfo.Serializer; + } + + var items = new List(); foreach (var itemExpression in expression.Expressions) { - var itemTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, itemExpression); + var itemTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, itemExpression, itemSerializer); items.Add(itemTranslation.Ast); - itemSerializer ??= itemTranslation.Serializer; + if (itemSerializer == null) + { + itemSerializer = itemTranslation.Serializer; + } // make sure all items are serialized using the same serializer if (!itemTranslation.Serializer.Equals(itemSerializer)) @@ -42,14 +61,16 @@ public static AggregationExpression Translate(TranslationContext context, NewArr } var ast = AstExpression.ComputedArray(items); + if (resultSerializer == null) + { + var arrayType = expression.Type; + var itemType = arrayType.GetElementType(); + itemSerializer ??= BsonSerializer.LookupSerializer(itemType); // if the array is empty itemSerializer might be null + var arraySerializerType = typeof(ArraySerializer<>).MakeGenericType(itemType); + resultSerializer = (IBsonArraySerializer)Activator.CreateInstance(arraySerializerType, itemSerializer); + } - var arrayType = expression.Type; - var itemType = arrayType.GetElementType(); - itemSerializer ??= BsonSerializer.LookupSerializer(itemType); // if the array is empty itemSerializer will be null - var arraySerializerType = typeof(ArraySerializer<>).MakeGenericType(itemType); - var arraySerializer = (IBsonSerializer)Activator.CreateInstance(arraySerializerType, itemSerializer); - - return new AggregationExpression(expression, ast, arraySerializer); + return new AggregationExpression(expression, ast, resultSerializer); } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewExpressionToAggregationExpressionTranslator.cs index ee930fd6d5f..356161c5d7a 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewExpressionToAggregationExpressionTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/NewExpressionToAggregationExpressionTranslator.cs @@ -17,12 +17,13 @@ using System.Collections.Generic; using System.Linq.Expressions; using MongoDB.Bson; +using MongoDB.Bson.Serialization; namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators { internal static class NewExpressionToAggregationExpressionTranslator { - public static AggregationExpression Translate(TranslationContext context, NewExpression expression) + public static AggregationExpression Translate(TranslationContext context, NewExpression expression, IBsonSerializer targetSerializer) { var expressionType = expression.Type; @@ -46,7 +47,8 @@ public static AggregationExpression Translate(TranslationContext context, NewExp { return NewTupleExpressionToAggregationExpressionTranslator.Translate(context, expression); } - return MemberInitExpressionToAggregationExpressionTranslator.Translate(context, expression, expression, Array.Empty()); + + return MemberInitExpressionToAggregationExpressionTranslator.Translate(context, expression, expression, Array.Empty(), targetSerializer); } } } diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToSetStageTranslators/ExpressionToSetStageTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToSetStageTranslators/ExpressionToSetStageTranslator.cs index 018c2ae5051..803b000c4a6 100644 --- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToSetStageTranslators/ExpressionToSetStageTranslator.cs +++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToSetStageTranslators/ExpressionToSetStageTranslator.cs @@ -139,39 +139,27 @@ private static AstStage TranslateNewWithOptionalMemberInitializers(TranslationCo private static AstComputedField CreateComputedField(TranslationContext context, IBsonDocumentSerializer documentSerializer, MemberInfo member, Expression valueExpression) { string elementName; - AstExpression valueAst; + IBsonSerializer targetSerializer; if (documentSerializer.TryGetMemberSerializationInfo(member.Name, out var serializationInfo)) { elementName = serializationInfo.ElementName; - var memberSerializer = serializationInfo.Serializer; - - if (valueExpression is ConstantExpression constantValueExpression) - { - var value = constantValueExpression.Value; - var serializedValue = SerializationHelper.SerializeValue(memberSerializer, value); - valueAst = AstExpression.Constant(serializedValue); - } - else - { - var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression); - ThrowIfMemberAndValueSerializersAreNotCompatible(valueExpression, memberSerializer, valueTranslation.Serializer); - valueAst = valueTranslation.Ast; - } + targetSerializer = serializationInfo.Serializer; } else { elementName = member.Name; - var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression); - valueAst = valueTranslation.Ast; + targetSerializer = null; } - return AstExpression.ComputedField(elementName, valueAst); + var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression, targetSerializer); + ThrowIfTargetAndValueSerializersAreNotCompatible(valueExpression, targetSerializer, valueTranslation.Serializer); + + return AstExpression.ComputedField(elementName, valueTranslation.Ast); } - private static void ThrowIfMemberAndValueSerializersAreNotCompatible(Expression expression, IBsonSerializer memberSerializer, IBsonSerializer valueSerializer) + private static void ThrowIfTargetAndValueSerializersAreNotCompatible(Expression expression, IBsonSerializer targetSerializer, IBsonSerializer valueSerializer) { - // TODO: depends on CSHARP-3315 - if (!memberSerializer.Equals(valueSerializer)) + if (targetSerializer != null && !targetSerializer.Equals(valueSerializer)) { throw new ExpressionNotSupportedException(expression, because: "member and value serializers are not compatible"); } diff --git a/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs new file mode 100644 index 00000000000..d2eca6a6ab9 --- /dev/null +++ b/tests/MongoDB.Driver.Tests/Linq/Linq3Implementation/Jira/CSharp5435Tests.cs @@ -0,0 +1,232 @@ +/* Copyright 2010-present MongoDB Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +using System.Linq; +using FluentAssertions; +using MongoDB.Bson; +using MongoDB.Bson.IO; +using MongoDB.Bson.Serialization; +using MongoDB.Bson.Serialization.Attributes; +using MongoDB.Bson.Serialization.Serializers; +using MongoDB.Driver.Core.Misc; +using MongoDB.Driver.Core.TestHelpers.XunitExtensions; +using Xunit; + +namespace MongoDB.Driver.Tests.Linq.Linq3Implementation.Jira +{ + public class CSharp5435Tests : Linq3IntegrationTest + { + [Fact] + public void Test_set_ValueObject_Value_using_creator_map() + { + RequireServer.Check().Supports(Feature.UpdateWithAggregationPipeline); + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + ValueObject = new MyValue(x.ValueObject == null ? 1 : x.ValueObject.Value + 1) + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + var updateStages = + updateError.Render(new(coll.DocumentSerializer, BsonSerializer.SerializerRegistry)) + .AsBsonArray + .Cast(); + AssertStages(updateStages, "{ $set : { ValueObject : { Value : { $cond : { if : { $eq : ['$ValueObject', null] }, then : 1, else : { $add : ['$ValueObject.Value', 1] } } } } } }"); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + [Fact] + public void Test_set_ValueObject_Value_using_property_setter() + { + RequireServer.Check().Supports(Feature.UpdateWithAggregationPipeline); + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + ValueObject = new MyValue() + { + Value = x.ValueObject == null ? 1 : x.ValueObject.Value + 1 + } + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + var updateStages = + updateError.Render(new(coll.DocumentSerializer, BsonSerializer.SerializerRegistry)) + .AsBsonArray + .Cast(); + AssertStages(updateStages, "{ $set : { ValueObject : { Value : { $cond : { if : { $eq : ['$ValueObject', null] }, then : 1, else : { $add : ['$ValueObject.Value', 1] } } } } } }"); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + // [Fact] + // public void Test_set_ValueObject_to_derived_value_using_property_setter() + // { + // var coll = GetCollection(); + // var doc = new MyDocument(); + // var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + // + // var pipelineError = new EmptyPipelineDefinition() + // .Set(x => new MyDocument() + // { + // ValueObject = new MyDerivedValue() + // { + // Value = x.ValueObject == null ? 1 : x.ValueObject.Value + 1, + // B = 42 + // } + // }); + // var updateError = Builders.Update.Pipeline(pipelineError); + // + // coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + // } + + [Fact] + public void Test_set_X_using_constructor() + { + RequireServer.Check().Supports(Feature.UpdateWithAggregationPipeline); + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + X = new X(x.Y) + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + var updateStages = + updateError.Render(new(coll.DocumentSerializer, BsonSerializer.SerializerRegistry)) + .AsBsonArray + .Cast(); + AssertStages(updateStages, "{ $set : { X : { Y : '$Y' } } }"); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + [Fact] + public void Test_set_A() + { + RequireServer.Check().Supports(Feature.UpdateWithAggregationPipeline); + var coll = GetCollection(); + var doc = new MyDocument(); + var filter = Builders.Filter.Eq(x => x.Id, doc.Id); + + var pipelineError = new EmptyPipelineDefinition() + .Set(x => new MyDocument() + { + A = new [] { 2, x.A[0] } + }); + var updateError = Builders.Update.Pipeline(pipelineError); + + var updateStages = + updateError.Render(new(coll.DocumentSerializer, BsonSerializer.SerializerRegistry)) + .AsBsonArray + .Cast(); + AssertStages(updateStages, "{ $set : { A : ['2', { $arrayElemAt : ['$A', 0] }] } }"); + + coll.UpdateOne(filter, updateError, new() { IsUpsert = true }); + } + + private IMongoCollection GetCollection() + { + var collection = GetCollection("test"); + CreateCollection( + collection.Database.GetCollection("test"), + BsonDocument.Parse("{ _id : 1 }"), + BsonDocument.Parse("{ _id : 2, X : null }"), + BsonDocument.Parse("{ _id : 3, X : 3 }")); + return collection; + } + + class MyDocument + { + [BsonRepresentation(MongoDB.Bson.BsonType.ObjectId)] + public string Id { get; set; } = ObjectId.GenerateNewId().ToString(); + + public MyValue ValueObject { get; set; } + + public long Long { get; set; } + + public X X { get; set; } + + public int Y { get; set; } + + [BsonRepresentation(BsonType.String)] + public int[] A { get; set; } + } + + class MyValue + { + [BsonConstructor] + public MyValue() { } + [BsonConstructor] + public MyValue(int value) { Value = value; } + public int Value { get; set; } + } + + class MyDerivedValue : MyValue + { + public int B { get; set; } + } + + [BsonSerializer(typeof(XSerializer))] + class X + { + public X(int y) + { + Y = y; + } + public int Y { get; } + } + + class XSerializer : SerializerBase, IBsonDocumentSerializer + { + public override X Deserialize(BsonDeserializationContext context, BsonDeserializationArgs args) + { + var reader = context.Reader; + reader.ReadStartArray(); + _ = reader.ReadName(); + var y = reader.ReadInt32(); + reader.ReadEndDocument(); + + return new X(y); + } + + public override void Serialize(BsonSerializationContext context, BsonSerializationArgs args, X value) + { + var writer = context.Writer; + writer.WriteStartDocument(); + writer.WriteName("Y"); + writer.WriteInt32(value.Y); + writer.WriteEndDocument(); + } + + public bool TryGetMemberSerializationInfo(string memberName, out BsonSerializationInfo serializationInfo) + { + serializationInfo = memberName == "Y" ? new BsonSerializationInfo("Y", Int32Serializer.Instance, typeof(int)) : null; + return serializationInfo != null; + } + } + } +}