diff --git a/src/MongoDB.Bson/Serialization/Serializers/BsonClassMapSerializer.cs b/src/MongoDB.Bson/Serialization/Serializers/BsonClassMapSerializer.cs
index fdcd59916ed..54e1dbb54bc 100644
--- a/src/MongoDB.Bson/Serialization/Serializers/BsonClassMapSerializer.cs
+++ b/src/MongoDB.Bson/Serialization/Serializers/BsonClassMapSerializer.cs
@@ -23,11 +23,22 @@
namespace MongoDB.Bson.Serialization
{
+ ///
+ /// An interface implemented by BsonClassMapSerializer.
+ ///
+ public interface IBsonClassMapSerializer
+ {
+ ///
+ /// Gets the class map for a BsonClassMapSerializer.
+ ///
+ public BsonClassMap ClassMap { get; }
+ }
+
///
/// Represents a serializer for a class map.
///
/// The type of the class.
- public sealed class BsonClassMapSerializer : SerializerBase, IBsonIdProvider, IBsonDocumentSerializer, IBsonPolymorphicSerializer, IHasDiscriminatorConvention
+ public sealed class BsonClassMapSerializer : SerializerBase, IBsonClassMapSerializer, IBsonIdProvider, IBsonDocumentSerializer, IBsonPolymorphicSerializer, IHasDiscriminatorConvention
{
// private fields
private readonly BsonClassMap _classMap;
@@ -57,6 +68,9 @@ public BsonClassMapSerializer(BsonClassMap classMap)
}
// public properties
+ ///
+ public BsonClassMap ClassMap => _classMap;
+
///
public IDiscriminatorConvention DiscriminatorConvention => _classMap.GetDiscriminatorConvention();
diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs
index d4c74a22b09..0f08c7cfe03 100644
--- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs
+++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ConstantExpressionToAggregationExpressionTranslator.cs
@@ -17,25 +17,27 @@
using MongoDB.Bson.IO;
using MongoDB.Bson.Serialization;
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
+using MongoDB.Driver.Linq.Linq3Implementation.Misc;
using MongoDB.Driver.Linq.Linq3Implementation.Serializers;
namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators
{
internal static class ConstantExpressionToAggregationExpressionTranslator
{
- public static AggregationExpression Translate(ConstantExpression constantExpression)
+ public static AggregationExpression Translate(ConstantExpression constantExpression, IBsonSerializer targetSerializer)
{
- var constantType = constantExpression.Type;
- var constantSerializer = StandardSerializers.TryGetSerializer(constantType, out var serializer) ? serializer : BsonSerializer.LookupSerializer(constantType);
- return Translate(constantExpression, constantSerializer);
- }
+ var resultSerializer = targetSerializer;
+ if (resultSerializer == null)
+ {
+ var constantType = constantExpression.Type;
+ resultSerializer = StandardSerializers.TryGetSerializer(constantType, out var serializer) ? serializer : BsonSerializer.LookupSerializer(constantType);
+ }
- public static AggregationExpression Translate(ConstantExpression constantExpression, IBsonSerializer constantSerializer)
- {
- var constantValue = constantExpression.Value;
- var serializedValue = constantSerializer.ToBsonValue(constantValue);
+ var value = constantExpression.Value;
+ var serializedValue = SerializationHelper.SerializeValue(resultSerializer, value);
var ast = AstExpression.Constant(serializedValue);
- return new AggregationExpression(constantExpression, ast, constantSerializer);
+
+ return new AggregationExpression(constantExpression, ast, resultSerializer);
}
}
}
diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs
index abb3ba1766d..bebb6dc6f6a 100644
--- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs
+++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/ExpressionToAggregationExpressionTranslator.cs
@@ -27,7 +27,7 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg
internal static class ExpressionToAggregationExpressionTranslator
{
// public static methods
- public static AggregationExpression Translate(TranslationContext context, Expression expression)
+ public static AggregationExpression Translate(TranslationContext context, Expression expression, IBsonSerializer targetSerializer = null)
{
switch (expression.NodeType)
{
@@ -63,11 +63,11 @@ public static AggregationExpression Translate(TranslationContext context, Expres
case ExpressionType.ArrayLength:
return ArrayLengthExpressionToAggregationExpressionTranslator.Translate(context, (UnaryExpression)expression);
case ExpressionType.Call:
- return MethodCallExpressionToAggregationExpressionTranslator.Translate(context, (MethodCallExpression)expression);
+ return MethodCallExpressionToAggregationExpressionTranslator.Translate(context, (MethodCallExpression)expression, targetSerializer);
case ExpressionType.Conditional:
return ConditionalExpressionToAggregationExpressionTranslator.Translate(context, (ConditionalExpression)expression);
case ExpressionType.Constant:
- return ConstantExpressionToAggregationExpressionTranslator.Translate((ConstantExpression)expression);
+ return ConstantExpressionToAggregationExpressionTranslator.Translate((ConstantExpression)expression, targetSerializer);
case ExpressionType.Index:
return IndexExpressionToAggregationExpressionTranslator.Translate(context, (IndexExpression)expression);
case ExpressionType.ListInit:
@@ -75,13 +75,13 @@ public static AggregationExpression Translate(TranslationContext context, Expres
case ExpressionType.MemberAccess:
return MemberExpressionToAggregationExpressionTranslator.Translate(context, (MemberExpression)expression);
case ExpressionType.MemberInit:
- return MemberInitExpressionToAggregationExpressionTranslator.Translate(context, (MemberInitExpression)expression);
+ return MemberInitExpressionToAggregationExpressionTranslator.Translate(context, (MemberInitExpression)expression, targetSerializer);
case ExpressionType.Negate:
return NegateExpressionToAggregationExpressionTranslator.Translate(context, (UnaryExpression)expression);
case ExpressionType.New:
- return NewExpressionToAggregationExpressionTranslator.Translate(context, (NewExpression)expression);
+ return NewExpressionToAggregationExpressionTranslator.Translate(context, (NewExpression)expression, targetSerializer);
case ExpressionType.NewArrayInit:
- return NewArrayInitExpressionToAggregationExpressionTranslator.Translate(context, (NewArrayExpression)expression);
+ return NewArrayInitExpressionToAggregationExpressionTranslator.Translate(context, (NewArrayExpression)expression, targetSerializer);
case ExpressionType.Parameter:
return ParameterExpressionToAggregationExpressionTranslator.Translate(context, (ParameterExpression)expression);
case ExpressionType.TypeIs:
@@ -91,19 +91,19 @@ public static AggregationExpression Translate(TranslationContext context, Expres
throw new ExpressionNotSupportedException(expression);
}
- public static AggregationExpression TranslateEnumerable(TranslationContext context, Expression expression)
+ public static AggregationExpression TranslateEnumerable(TranslationContext context, Expression expression, IBsonSerializer targetSerializer = null)
{
- var aggregateExpression = Translate(context, expression);
+ var aggregateExpression = Translate(context, expression, targetSerializer);
- var serializer = aggregateExpression.Serializer;
- if (serializer is IWrappedEnumerableSerializer wrappedEnumerableSerializer)
+ var resultSerializer = aggregateExpression.Serializer;
+ if (resultSerializer is IWrappedEnumerableSerializer wrappedEnumerableSerializer)
{
var enumerableFieldName = wrappedEnumerableSerializer.EnumerableFieldName;
- var enumerableElementSerializer = wrappedEnumerableSerializer.EnumerableElementSerializer;
- var enumerableSerializer = IEnumerableSerializer.Create(enumerableElementSerializer);
- var ast = AstExpression.GetField(aggregateExpression.Ast, enumerableFieldName);
+ var itemSerializer = wrappedEnumerableSerializer.EnumerableElementSerializer;
- return new AggregationExpression(aggregateExpression.Expression, ast, enumerableSerializer);
+ var ast = AstExpression.GetField(aggregateExpression.Ast, enumerableFieldName);
+ resultSerializer = IEnumerableSerializer.Create(itemSerializer);
+ return new AggregationExpression(aggregateExpression.Expression, ast, resultSerializer);
}
return aggregateExpression;
diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs
index e1a8cd5f399..f24fe68a1a8 100644
--- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs
+++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MemberInitExpressionToAggregationExpressionTranslator.cs
@@ -28,22 +28,28 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg
{
internal static class MemberInitExpressionToAggregationExpressionTranslator
{
- public static AggregationExpression Translate(TranslationContext context, MemberInitExpression expression)
+ public static AggregationExpression Translate(TranslationContext context, MemberInitExpression expression, IBsonSerializer targetSerializer)
{
if (expression.Type == typeof(BsonDocument))
{
return NewBsonDocumentExpressionToAggregationExpressionTranslator.Translate(context, expression);
}
- return Translate(context, expression, expression.NewExpression, expression.Bindings);
+ return Translate(context, expression, expression.NewExpression, expression.Bindings, targetSerializer);
}
public static AggregationExpression Translate(
TranslationContext context,
Expression expression,
NewExpression newExpression,
- IReadOnlyList bindings)
+ IReadOnlyList bindings,
+ IBsonSerializer targetSerializer)
{
+ if (targetSerializer != null)
+ {
+ return TranslateWithTargetSerializer(context, expression, newExpression, bindings, targetSerializer);
+ }
+
var constructorInfo = newExpression.Constructor; // note: can be null when using the default constructor with a struct
var constructorArguments = newExpression.Arguments;
var computedFields = new List();
@@ -83,8 +89,7 @@ public static AggregationExpression Translate(
foreach (var binding in bindings)
{
var memberAssignment = (MemberAssignment)binding;
- var member = memberAssignment.Member;
- var memberMap = FindMemberMap(expression, classMap, member.Name);
+ var memberMap = FindMemberMap(expression, classMap, memberInfo: memberAssignment.Member);
var valueExpression = memberAssignment.Expression;
var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression);
var memberSerializer = CoerceSourceSerializerToMemberSerializer(memberMap, valueTranslation.Serializer);
@@ -100,6 +105,70 @@ public static AggregationExpression Translate(
return new AggregationExpression(expression, ast, serializer);
}
+ private static AggregationExpression TranslateWithTargetSerializer(
+ TranslationContext context,
+ Expression expression,
+ NewExpression newExpression,
+ IReadOnlyList bindings,
+ IBsonSerializer targetSerializer)
+ {
+ var resultSerializer = targetSerializer as IBsonDocumentSerializer;
+ if (resultSerializer == null)
+ {
+ throw new ExpressionNotSupportedException(expression, because: $"serializer class {targetSerializer.GetType()} does not implement IBsonDocumentSerializer.");
+ }
+
+ var constructorInfo = newExpression.Constructor; // note: can be null when using the default constructor with a struct
+ var constructorArguments = newExpression.Arguments;
+ var computedFields = new List();
+
+ if (constructorInfo != null && constructorArguments.Count > 0)
+ {
+ var constructorParameters = constructorInfo.GetParameters();
+
+ // if the documentSerializer is a BsonClassMappedSerializer we can use the classMap and creatorMap
+ var classMap = (resultSerializer as IBsonClassMapSerializer)?.ClassMap;
+ var creatorMap = classMap == null ? null : FindMatchingCreatorMap(classMap, constructorInfo);
+ if (creatorMap == null && classMap != null)
+ {
+ throw new ExpressionNotSupportedException(expression, because: "no matching creator map found");
+ }
+ var creatorMapArguments = creatorMap?.Arguments?.ToArray();
+
+ for (var i = 0; i < constructorParameters.Length; i++)
+ {
+ var argumentExpression = constructorArguments[i];
+
+ // if we have a classMap (and therefore a creatorMap also) use them
+ // otherwise fall back to matching constructor parameter names to member names
+ var (elementName, memberSerializer) = classMap != null ?
+ FindMemberElementNameAndSerializer(argumentExpression, classMap, memberInfo: creatorMapArguments[i]) :
+ FindMemberElementNameAndSerializer(argumentExpression, resultSerializer, constructorParameterName: constructorParameters[i].Name);
+
+ var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, argumentExpression, targetSerializer: memberSerializer);
+ computedFields.Add(AstExpression.ComputedField(elementName, argumentTranslation.Ast));
+ }
+ }
+
+ foreach (var binding in bindings)
+ {
+ var memberAssignment = (MemberAssignment)binding;
+ var member = memberAssignment.Member;
+ var valueExpression = memberAssignment.Expression;
+ if (!resultSerializer.TryGetMemberSerializationInfo(member.Name, out var memberSerializationInfo))
+ {
+ throw new ExpressionNotSupportedException(valueExpression, expression, because: $"couldn't find member {member.Name}");
+ }
+ var memberSerializer = memberSerializationInfo.Serializer;
+
+ var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression, memberSerializer);
+ computedFields.Add(AstExpression.ComputedField(memberSerializationInfo.ElementName, valueTranslation.Ast));
+ }
+
+ var ast = AstExpression.ComputedDocument(computedFields);
+ return new AggregationExpression(expression, ast, resultSerializer);
+ }
+
private static BsonClassMap CreateClassMap(Type classType, ConstructorInfo constructorInfo, out BsonCreatorMap creatorMap)
{
BsonClassMap baseClassMap = null;
@@ -190,11 +259,44 @@ private static void EnsureDefaultValue(BsonMemberMap memberMap)
memberMap.SetDefaultValue(defaultValue);
}
- private static BsonMemberMap FindMemberMap(Expression expression, BsonClassMap classMap, string memberName)
+ private static BsonCreatorMap FindMatchingCreatorMap(BsonClassMap classMap, ConstructorInfo constructorInfo)
+ => classMap?.CreatorMaps.FirstOrDefault(m => m.MemberInfo.Equals(constructorInfo));
+
+ private static (string, IBsonSerializer) FindMemberElementNameAndSerializer(
+ Expression expression,
+ BsonClassMap classMap,
+ MemberInfo memberInfo)
+ {
+ var memberMap = FindMemberMap(expression, classMap, memberInfo);
+ return (memberMap.ElementName, memberMap.GetSerializer());
+ }
+
+ private static (string, IBsonSerializer) FindMemberElementNameAndSerializer(
+ Expression expression,
+ IBsonDocumentSerializer documentSerializer,
+ string constructorParameterName)
+ {
+ // case insensitive GetMember could return some false hits but TryGetMemberSerializationInfo will filter them out
+ var bindingFlags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.FlattenHierarchy | BindingFlags.IgnoreCase;
+ foreach (var memberInfo in documentSerializer.ValueType.GetMember(constructorParameterName, bindingFlags))
+ {
+ if (documentSerializer.TryGetMemberSerializationInfo(memberInfo.Name, out var serializationInfo))
+ {
+ return (serializationInfo.ElementName, serializationInfo.Serializer);
+ }
+ }
+
+ throw new ExpressionNotSupportedException(expression, because: $"no matching member map found for constructor parameter: {constructorParameterName}");
+ }
+
+ private static BsonMemberMap FindMemberMap(
+ Expression expression,
+ BsonClassMap classMap,
+ MemberInfo memberInfo)
{
foreach (var memberMap in classMap.DeclaredMemberMaps)
{
- if (memberMap.MemberName == memberName)
+ if (memberMap.MemberInfo == memberInfo)
{
return memberMap;
}
@@ -202,10 +304,10 @@ private static BsonMemberMap FindMemberMap(Expression expression, BsonClassMap c
if (classMap.BaseClassMap != null)
{
- return FindMemberMap(expression, classMap.BaseClassMap, memberName);
+ return FindMemberMap(expression, classMap.BaseClassMap, memberInfo);
}
- throw new ExpressionNotSupportedException(expression, because: $"can't find member map: {memberName}");
+ throw new ExpressionNotSupportedException(expression, because: $"no member map found for member: {memberInfo.Name}");
}
}
}
diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs
index 1fe6689fbdc..eb31beaf92a 100644
--- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs
+++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodCallExpressionToAggregationExpressionTranslator.cs
@@ -14,39 +14,40 @@
*/
using System.Linq.Expressions;
+using MongoDB.Bson.Serialization;
using MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators;
namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators
{
internal static class MethodCallExpressionToAggregationExpressionTranslator
{
- public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression)
+ public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression, IBsonSerializer targetSerializer)
{
switch (expression.Method.Name)
{
case "Abs": return AbsMethodToAggregationExpressionTranslator.Translate(context, expression);
case "Add": return AddMethodToAggregationExpressionTranslator.Translate(context, expression);
case "AddToSet": return AddToSetMethodToAggregationExpressionTranslator.Translate(context, expression);
- case "Aggregate": return AggregateMethodToAggregationExpressionTranslator.Translate(context, expression);
+ case "Aggregate": return AggregateMethodToAggregationExpressionTranslator.Translate(context, expression, targetSerializer);
case "All": return AllMethodToAggregationExpressionTranslator.Translate(context, expression);
case "Any": return AnyMethodToAggregationExpressionTranslator.Translate(context, expression);
- case "AsQueryable": return AsQueryableMethodToAggregationExpressionTranslator.Translate(context, expression);
+ case "AsQueryable": return AsQueryableMethodToAggregationExpressionTranslator.Translate(context, expression, targetSerializer);
case "Average": return AverageMethodToAggregationExpressionTranslator.Translate(context, expression);
case "Ceiling": return CeilingMethodToAggregationExpressionTranslator.Translate(context, expression);
case "CompareTo": return CompareToMethodToAggregationExpressionTranslator.Translate(context, expression);
- case "Concat": return ConcatMethodToAggregationExpressionTranslator.Translate(context, expression);
+ case "Concat": return ConcatMethodToAggregationExpressionTranslator.Translate(context, expression, targetSerializer);
case "Constant": return ConstantMethodToAggregationExpressionTranslator.Translate(context, expression);
case "Contains": return ContainsMethodToAggregationExpressionTranslator.Translate(context, expression);
case "ContainsKey": return ContainsKeyMethodToAggregationExpressionTranslator.Translate(context, expression);
case "ContainsValue": return ContainsValueMethodToAggregationExpressionTranslator.Translate(context, expression);
case "CovariancePopulation": return CovariancePopulationMethodToAggregationExpressionTranslator.Translate(context, expression);
case "CovarianceSample": return CovarianceSampleMethodToAggregationExpressionTranslator.Translate(context, expression);
- case "Create": return CreateMethodToAggregationExpressionTranslator.Translate(context, expression);
+ case "Create": return CreateMethodToAggregationExpressionTranslator.Translate(context, expression, targetSerializer);
case "DateFromString": return DateFromStringMethodToAggregationExpressionTranslator.Translate(context, expression);
- case "DefaultIfEmpty": return DefaultIfEmptyMethodToAggregationExpressionTranslator.Translate(context, expression);
+ case "DefaultIfEmpty": return DefaultIfEmptyMethodToAggregationExpressionTranslator.Translate(context, expression, targetSerializer);
case "DenseRank": return DenseRankMethodToAggregationExpressionTranslator.Translate(context, expression);
case "Derivative": return DerivativeMethodToAggregationExpressionTranslator.Translate(context, expression);
- case "Distinct": return DistinctMethodToAggregationExpressionTranslator.Translate(context, expression);
+ case "Distinct": return DistinctMethodToAggregationExpressionTranslator.Translate(context, expression, targetSerializer);
case "DocumentNumber": return DocumentNumberMethodToAggregationExpressionTranslator.Translate(context, expression);
case "Equals": return EqualsMethodToAggregationExpressionTranslator.Translate(context, expression);
case "Except": return ExceptMethodToAggregationExpressionTranslator.Translate(context, expression);
@@ -122,7 +123,7 @@ public static AggregationExpression Translate(TranslationContext context, Method
case "Append":
case "Prepend":
- return AppendOrPrependMethodToAggregationExpressionTranslator.Translate(context, expression);
+ return AppendOrPrependMethodToAggregationExpressionTranslator.Translate(context, expression, targetSerializer);
case "Bottom":
case "BottomN":
@@ -140,7 +141,7 @@ public static AggregationExpression Translate(TranslationContext context, Method
case "ElementAt":
case "ElementAtOrDefault":
- return ElementAtMethodToAggregationExpressionTranslator.Translate(context, expression);
+ return ElementAtMethodToAggregationExpressionTranslator.Translate(context, expression, targetSerializer);
case "First":
case "FirstOrDefault":
diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AggregateMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AggregateMethodToAggregationExpressionTranslator.cs
index e2e2f0f1b8c..8bcbef04746 100644
--- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AggregateMethodToAggregationExpressionTranslator.cs
+++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AggregateMethodToAggregationExpressionTranslator.cs
@@ -19,6 +19,7 @@
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
using MongoDB.Driver.Linq.Linq3Implementation.Reflection;
+using MongoDB.Driver.Linq.Linq3Implementation.Serializers;
namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators
{
@@ -34,27 +35,33 @@ internal static class AggregateMethodToAggregationExpressionTranslator
QueryableMethod.AggregateWithSeedFuncAndResultSelector
};
- private static readonly MethodInfo[] __aggregateWithoutSeedMethods =
+ private static readonly MethodInfo[] __aggregateWithFuncMethods =
{
EnumerableMethod.AggregateWithFunc,
QueryableMethod.AggregateWithFunc
};
- private static readonly MethodInfo[] __aggregateWithSeedMethods =
+ private static readonly MethodInfo[] __aggregateWithSeedAndFuncMethods =
{
EnumerableMethod.AggregateWithSeedAndFunc,
+ QueryableMethod.AggregateWithSeedAndFunc
+ };
+
+ private static readonly MethodInfo[] __aggregateWithSeedAndFuncAndResultSelectorMethods =
+ {
EnumerableMethod.AggregateWithSeedFuncAndResultSelector,
- QueryableMethod.AggregateWithSeedAndFunc,
QueryableMethod.AggregateWithSeedFuncAndResultSelector
};
- private static readonly MethodInfo[] __aggregateWithSeedFuncAndResultSelectorMethods =
- {
+ private static readonly MethodInfo[] __aggregateIncludingSeedMethods =
+ {
+ EnumerableMethod.AggregateWithSeedAndFunc,
EnumerableMethod.AggregateWithSeedFuncAndResultSelector,
+ QueryableMethod.AggregateWithSeedAndFunc,
QueryableMethod.AggregateWithSeedFuncAndResultSelector
};
- public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression)
+ public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression, IBsonSerializer targetSerializer)
{
var method = expression.Method;
var arguments = expression.Arguments;
@@ -62,11 +69,12 @@ public static AggregationExpression Translate(TranslationContext context, Method
if (method.IsOneOf(__aggregateMethods))
{
var sourceExpression = arguments[0];
- var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression);
+ var sourceTargetSerializer = GetSourceTargetSerializer(method, targetSerializer);
+ var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression, sourceTargetSerializer);
NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation);
var itemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer);
- if (method.IsOneOf(__aggregateWithoutSeedMethods))
+ if (method.IsOneOf(__aggregateWithFuncMethods))
{
var funcLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[1]);
var funcParameters = funcLambda.Parameters;
@@ -75,7 +83,8 @@ public static AggregationExpression Translate(TranslationContext context, Method
var itemParameter = funcParameters[1];
var itemSymbol = context.CreateSymbolWithVarName(itemParameter, varName: "this", itemSerializer); // note: MQL uses $$this for the item being processed
var funcContext = context.WithSymbols(accumulatorSymbol, itemSymbol);
- var funcTranslation = ExpressionToAggregationExpressionTranslator.Translate(funcContext, funcLambda.Body);
+ var funcTargetSerializer = GetFuncTargetSerializer(method, targetSerializer);
+ var funcTranslation = ExpressionToAggregationExpressionTranslator.Translate(funcContext, funcLambda.Body, funcTargetSerializer);
var (sourceVarBinding, sourceAst) = AstExpression.UseVarIfNotSimple("source", sourceTranslation.Ast);
var seedVar = AstExpression.Var("seed");
@@ -95,10 +104,11 @@ public static AggregationExpression Translate(TranslationContext context, Method
return new AggregationExpression(expression, ast, itemSerializer);
}
- else if (method.IsOneOf(__aggregateWithSeedMethods))
+ else if (method.IsOneOf(__aggregateIncludingSeedMethods))
{
var seedExpression = arguments[1];
- var seedTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, seedExpression);
+ var seedTargetSerializer = GetSeedTargetSerializer(method, targetSerializer);
+ var seedTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, seedExpression, seedTargetSerializer);
var funcLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[2]);
var funcParameters = funcLambda.Parameters;
@@ -108,7 +118,8 @@ public static AggregationExpression Translate(TranslationContext context, Method
var itemParameter = funcParameters[1];
var itemSymbol = context.CreateSymbolWithVarName(itemParameter, varName: "this", itemSerializer); // note: MQL uses $$this for the item being processed
var funcContext = context.WithSymbols(accumulatorSymbol, itemSymbol);
- var funcTranslation = ExpressionToAggregationExpressionTranslator.Translate(funcContext, funcLambda.Body);
+ var funcTargetSerializer = GetFuncTargetSerializer(method, targetSerializer);
+ var funcTranslation = ExpressionToAggregationExpressionTranslator.Translate(funcContext, funcLambda.Body, funcTargetSerializer);
var ast = AstExpression.Reduce(
input: sourceTranslation.Ast,
@@ -116,13 +127,14 @@ public static AggregationExpression Translate(TranslationContext context, Method
@in: funcTranslation.Ast);
var serializer = accumulatorSerializer;
- if (method.IsOneOf(__aggregateWithSeedFuncAndResultSelectorMethods))
+ if (method.IsOneOf(__aggregateWithSeedAndFuncAndResultSelectorMethods))
{
var resultSelectorLambda = ExpressionHelper.UnquoteLambdaIfQueryableMethod(method, arguments[3]);
var resultSelectorAccumulatorParameter = resultSelectorLambda.Parameters[0];
var resultSelectorAccumulatorSymbol = context.CreateSymbol(resultSelectorAccumulatorParameter, accumulatorSerializer);
var resultSelectorContext = context.WithSymbol(resultSelectorAccumulatorSymbol);
- var resultSelectorTranslation = ExpressionToAggregationExpressionTranslator.Translate(resultSelectorContext, resultSelectorLambda.Body);
+ var resultSelectorTargetSerializer = GetResultSelectorTargetSerializer(method, targetSerializer);
+ var resultSelectorTranslation = ExpressionToAggregationExpressionTranslator.Translate(resultSelectorContext, resultSelectorLambda.Body, resultSelectorTargetSerializer);
ast = AstExpression.Let(
var: AstExpression.VarBinding(resultSelectorAccumulatorSymbol.Var, ast),
@@ -136,5 +148,57 @@ public static AggregationExpression Translate(TranslationContext context, Method
throw new ExpressionNotSupportedException(expression);
}
+
+ private static IBsonSerializer GetFuncTargetSerializer(MethodInfo method, IBsonSerializer targetSerializer)
+ {
+ if (method.IsOneOf(__aggregateWithFuncMethods, __aggregateWithSeedAndFuncMethods))
+ {
+ return targetSerializer;
+ }
+
+ return null;
+ }
+
+ private static IBsonSerializer GetResultSelectorTargetSerializer(MethodInfo method, IBsonSerializer targetSerializer)
+ {
+ if (method.IsOneOf(__aggregateWithSeedAndFuncAndResultSelectorMethods))
+ {
+ return targetSerializer;
+ }
+
+ return null;
+ }
+
+ private static IBsonSerializer GetSeedTargetSerializer(MethodInfo method, IBsonSerializer targetSerializer)
+ {
+ if (method.IsOneOf(__aggregateWithSeedAndFuncMethods))
+ {
+ return targetSerializer;
+ }
+
+ return null;
+ }
+
+ private static IBsonSerializer GetSourceTargetSerializer(MethodInfo method, IBsonSerializer targetSerializer)
+ {
+ IBsonSerializer itemSerializer = null;
+ if (method.IsOneOf(__aggregateWithFuncMethods))
+ {
+ itemSerializer = targetSerializer;
+ }
+
+ if (method.IsOneOf(__aggregateWithSeedAndFuncMethods))
+ {
+ var genericArguments = method.GetGenericArguments();
+ var sourceType = genericArguments[0];
+ var accumulateType = genericArguments[1];
+ if (sourceType == accumulateType)
+ {
+ itemSerializer = targetSerializer;
+ }
+ }
+
+ return itemSerializer == null ? null : IEnumerableSerializer.Create(itemSerializer);
+ }
}
}
diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AppendOrPrependMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AppendOrPrependMethodToAggregationExpressionTranslator.cs
index fa87251f83a..7f1b193fe3b 100644
--- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AppendOrPrependMethodToAggregationExpressionTranslator.cs
+++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AppendOrPrependMethodToAggregationExpressionTranslator.cs
@@ -15,6 +15,7 @@
using System.Linq.Expressions;
using System.Reflection;
+using MongoDB.Bson.Serialization;
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
using MongoDB.Driver.Linq.Linq3Implementation.Reflection;
@@ -38,7 +39,7 @@ internal static class AppendOrPrependMethodToAggregationExpressionTranslator
QueryableMethod.Append
};
- public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression)
+ public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression, IBsonSerializer targetSerializer)
{
var method = expression.Method;
var arguments = expression.Arguments;
@@ -48,32 +49,22 @@ public static AggregationExpression Translate(TranslationContext context, Method
var sourceExpression = arguments[0];
var elementExpression = arguments[1];
- var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression);
+ var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression, targetSerializer);
NestedAsQueryableHelper.EnsureQueryableMethodHasNestedAsQueryableSource(expression, sourceTranslation);
var itemSerializer = ArraySerializerHelper.GetItemSerializer(sourceTranslation.Serializer);
- AggregationExpression elementTranslation;
- if (elementExpression is ConstantExpression elementConstantExpression)
+ var elementTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, elementExpression, itemSerializer);
+ if (!elementTranslation.Serializer.Equals(itemSerializer))
{
- var value = elementConstantExpression.Value;
- var serializedValue = SerializationHelper.SerializeValue(itemSerializer, value);
- elementTranslation = new AggregationExpression(elementExpression, AstExpression.Constant(serializedValue), itemSerializer);
- }
- else
- {
- elementTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, elementExpression);
- if (!elementTranslation.Serializer.Equals(itemSerializer))
- {
- throw new ExpressionNotSupportedException(expression, because: "argument serializers are not compatible");
- }
+ throw new ExpressionNotSupportedException(expression, because: "argument serializers are not compatible");
}
var ast = method.IsOneOf(__appendMethods) ?
AstExpression.ConcatArrays(sourceTranslation.Ast, AstExpression.ComputedArray(elementTranslation.Ast)) :
AstExpression.ConcatArrays(AstExpression.ComputedArray(elementTranslation.Ast), sourceTranslation.Ast);
- var serializer = NestedAsQueryableSerializer.CreateIEnumerableOrNestedAsQueryableSerializer(expression.Type, itemSerializer);
+ var resultSerializer = NestedAsQueryableSerializer.CreateIEnumerableOrNestedAsQueryableSerializer(expression.Type, itemSerializer);
- return new AggregationExpression(expression, ast, serializer);
+ return new AggregationExpression(expression, ast, resultSerializer);
}
throw new ExpressionNotSupportedException(expression);
diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AsQueryableMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AsQueryableMethodToAggregationExpressionTranslator.cs
index b97dd8615b8..72a8ed788a8 100644
--- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AsQueryableMethodToAggregationExpressionTranslator.cs
+++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/AsQueryableMethodToAggregationExpressionTranslator.cs
@@ -23,7 +23,7 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg
{
internal static class AsQueryableMethodToAggregationExpressionTranslator
{
- public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression)
+ public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression, IBsonSerializer targetSerializer)
{
var method = expression.Method;
var arguments = expression.Arguments;
@@ -31,12 +31,13 @@ public static AggregationExpression Translate(TranslationContext context, Method
if (method.Is(QueryableMethod.AsQueryable))
{
var sourceExpression = arguments[0];
- var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression);
+ var sourceTargetSerializer = GetSourceTargetSerializer(targetSerializer);
+ var sourceTranslation = ExpressionToAggregationExpressionTranslator.TranslateEnumerable(context, sourceExpression, sourceTargetSerializer);
IBsonSerializer serializer;
if (sourceTranslation.Serializer is INestedAsQueryableSerializer)
{
- serializer = sourceTranslation.Serializer;
+ serializer = sourceTranslation.Serializer;
}
else
{
@@ -49,5 +50,16 @@ public static AggregationExpression Translate(TranslationContext context, Method
throw new ExpressionNotSupportedException(expression);
}
+
+ private static IBsonSerializer GetSourceTargetSerializer(IBsonSerializer targetSerializer)
+ {
+ if (targetSerializer is IBsonArraySerializer arraySerializer)
+ {
+ var itemSerializer = ArraySerializerHelper.GetItemSerializer(arraySerializer);
+ return IEnumerableSerializer.Create(itemSerializer);
+ }
+
+ return null;
+ }
}
}
diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CompareToMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CompareToMethodToAggregationExpressionTranslator.cs
index 014a337dcf4..98ae2595962 100644
--- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CompareToMethodToAggregationExpressionTranslator.cs
+++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/CompareToMethodToAggregationExpressionTranslator.cs
@@ -30,11 +30,23 @@ public static AggregationExpression Translate(TranslationContext context, Method
if (IComparableMethod.IsCompareToMethod(method))
{
var objectExpression = expression.Object;
- var objectTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, objectExpression);
var otherExpression = arguments[0];
- var otherTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, otherExpression);
+
+ AggregationExpression objectTranslation;
+ AggregationExpression otherTranslation;
+ if (objectExpression is ConstantExpression && otherExpression is not ConstantExpression)
+ {
+ otherTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, otherExpression);
+ objectTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, objectExpression, otherTranslation.Serializer);
+ }
+ else
+ {
+ objectTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, objectExpression);
+ otherTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, otherExpression, objectTranslation.Serializer);
+ }
var ast = AstExpression.Cmp(objectTranslation.Ast, otherTranslation.Ast);
- return new AggregationExpression(expression, ast, new Int32Serializer());
+
+ return new AggregationExpression(expression, ast, Int32Serializer.Instance);
}
throw new ExpressionNotSupportedException(expression);
diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ConcatMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ConcatMethodToAggregationExpressionTranslator.cs
index 7dc3f32b209..52517ccb22f 100644
--- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ConcatMethodToAggregationExpressionTranslator.cs
+++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ConcatMethodToAggregationExpressionTranslator.cs
@@ -14,16 +14,17 @@
*/
using System.Linq.Expressions;
+using MongoDB.Bson.Serialization;
namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators.MethodTranslators
{
internal static class ConcatMethodToAggregationExpressionTranslator
{
- public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression)
+ public static AggregationExpression Translate(TranslationContext context, MethodCallExpression expression, IBsonSerializer targetSerializer)
{
if (EnumerableConcatMethodToAggregationExpressionTranslator.CanTranslate(expression))
{
- return EnumerableConcatMethodToAggregationExpressionTranslator.Translate(context, expression);
+ return EnumerableConcatMethodToAggregationExpressionTranslator.Translate(context, expression, targetSerializer);
}
if (StringConcatMethodToAggregationExpressionTranslator.CanTranslate(expression, out var method, out var arguments))
diff --git a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ConstantMethodToAggregationExpressionTranslator.cs b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ConstantMethodToAggregationExpressionTranslator.cs
index cc3aeea79e3..a14f0971f85 100644
--- a/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ConstantMethodToAggregationExpressionTranslator.cs
+++ b/src/MongoDB.Driver/Linq/Linq3Implementation/Translators/ExpressionToAggregationExpressionTranslators/MethodTranslators/ConstantMethodToAggregationExpressionTranslator.cs
@@ -35,7 +35,7 @@ public static AggregationExpression Translate(TranslationContext context, Method
var valueExpression = arguments[0];
var value = valueExpression.GetConstantValue