Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CSHARP-5461: Add targetSerializer parameter to Translate methods. #1596

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,22 @@

namespace MongoDB.Bson.Serialization
{
/// <summary>
/// An interface implemented by BsonClassMapSerializer.
/// </summary>
public interface IBsonClassMapSerializer
{
/// <summary>
/// Gets the class map for a BsonClassMapSerializer.
/// </summary>
public BsonClassMap ClassMap { get; }
}

/// <summary>
/// Represents a serializer for a class map.
/// </summary>
/// <typeparam name="TClass">The type of the class.</typeparam>
public sealed class BsonClassMapSerializer<TClass> : SerializerBase<TClass>, IBsonIdProvider, IBsonDocumentSerializer, IBsonPolymorphicSerializer, IHasDiscriminatorConvention
public sealed class BsonClassMapSerializer<TClass> : SerializerBase<TClass>, IBsonClassMapSerializer, IBsonIdProvider, IBsonDocumentSerializer, IBsonPolymorphicSerializer, IHasDiscriminatorConvention
{
// private fields
private readonly BsonClassMap _classMap;
Expand Down Expand Up @@ -57,6 +68,9 @@ public BsonClassMapSerializer(BsonClassMap classMap)
}

// public properties
/// <inheritdoc/>
public BsonClassMap ClassMap => _classMap;

/// <inheritdoc/>
public IDiscriminatorConvention DiscriminatorConvention => _classMap.GetDiscriminatorConvention();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,25 +17,27 @@
using MongoDB.Bson.IO;
using MongoDB.Bson.Serialization;
using MongoDB.Driver.Linq.Linq3Implementation.Ast.Expressions;
using MongoDB.Driver.Linq.Linq3Implementation.Misc;
using MongoDB.Driver.Linq.Linq3Implementation.Serializers;

namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggregationExpressionTranslators
{
internal static class ConstantExpressionToAggregationExpressionTranslator
{
public static AggregationExpression Translate(ConstantExpression constantExpression)
public static AggregationExpression Translate(ConstantExpression constantExpression, IBsonSerializer targetSerializer)
{
var constantType = constantExpression.Type;
var constantSerializer = StandardSerializers.TryGetSerializer(constantType, out var serializer) ? serializer : BsonSerializer.LookupSerializer(constantType);
return Translate(constantExpression, constantSerializer);
}
var resultSerializer = targetSerializer;
if (resultSerializer == null)
{
var constantType = constantExpression.Type;
resultSerializer = StandardSerializers.TryGetSerializer(constantType, out var serializer) ? serializer : BsonSerializer.LookupSerializer(constantType);
}

public static AggregationExpression Translate(ConstantExpression constantExpression, IBsonSerializer constantSerializer)
{
var constantValue = constantExpression.Value;
var serializedValue = constantSerializer.ToBsonValue(constantValue);
var value = constantExpression.Value;
var serializedValue = SerializationHelper.SerializeValue(resultSerializer, value);
var ast = AstExpression.Constant(serializedValue);
return new AggregationExpression(constantExpression, ast, constantSerializer);

return new AggregationExpression(constantExpression, ast, resultSerializer);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg
internal static class ExpressionToAggregationExpressionTranslator
{
// public static methods
public static AggregationExpression Translate(TranslationContext context, Expression expression)
public static AggregationExpression Translate(TranslationContext context, Expression expression, IBsonSerializer targetSerializer = null)
{
switch (expression.NodeType)
{
Expand Down Expand Up @@ -63,25 +63,25 @@ public static AggregationExpression Translate(TranslationContext context, Expres
case ExpressionType.ArrayLength:
return ArrayLengthExpressionToAggregationExpressionTranslator.Translate(context, (UnaryExpression)expression);
case ExpressionType.Call:
return MethodCallExpressionToAggregationExpressionTranslator.Translate(context, (MethodCallExpression)expression);
return MethodCallExpressionToAggregationExpressionTranslator.Translate(context, (MethodCallExpression)expression, targetSerializer);
case ExpressionType.Conditional:
return ConditionalExpressionToAggregationExpressionTranslator.Translate(context, (ConditionalExpression)expression);
case ExpressionType.Constant:
return ConstantExpressionToAggregationExpressionTranslator.Translate((ConstantExpression)expression);
return ConstantExpressionToAggregationExpressionTranslator.Translate((ConstantExpression)expression, targetSerializer);
case ExpressionType.Index:
return IndexExpressionToAggregationExpressionTranslator.Translate(context, (IndexExpression)expression);
case ExpressionType.ListInit:
return ListInitExpressionToAggregationExpressionTranslator.Translate(context, (ListInitExpression)expression);
case ExpressionType.MemberAccess:
return MemberExpressionToAggregationExpressionTranslator.Translate(context, (MemberExpression)expression);
case ExpressionType.MemberInit:
return MemberInitExpressionToAggregationExpressionTranslator.Translate(context, (MemberInitExpression)expression);
return MemberInitExpressionToAggregationExpressionTranslator.Translate(context, (MemberInitExpression)expression, targetSerializer);
case ExpressionType.Negate:
return NegateExpressionToAggregationExpressionTranslator.Translate(context, (UnaryExpression)expression);
case ExpressionType.New:
return NewExpressionToAggregationExpressionTranslator.Translate(context, (NewExpression)expression);
return NewExpressionToAggregationExpressionTranslator.Translate(context, (NewExpression)expression, targetSerializer);
case ExpressionType.NewArrayInit:
return NewArrayInitExpressionToAggregationExpressionTranslator.Translate(context, (NewArrayExpression)expression);
return NewArrayInitExpressionToAggregationExpressionTranslator.Translate(context, (NewArrayExpression)expression, targetSerializer);
case ExpressionType.Parameter:
return ParameterExpressionToAggregationExpressionTranslator.Translate(context, (ParameterExpression)expression);
case ExpressionType.TypeIs:
Expand All @@ -91,19 +91,19 @@ public static AggregationExpression Translate(TranslationContext context, Expres
throw new ExpressionNotSupportedException(expression);
}

public static AggregationExpression TranslateEnumerable(TranslationContext context, Expression expression)
public static AggregationExpression TranslateEnumerable(TranslationContext context, Expression expression, IBsonSerializer targetSerializer = null)
{
var aggregateExpression = Translate(context, expression);
var aggregateExpression = Translate(context, expression, targetSerializer);

var serializer = aggregateExpression.Serializer;
if (serializer is IWrappedEnumerableSerializer wrappedEnumerableSerializer)
var resultSerializer = aggregateExpression.Serializer;
if (resultSerializer is IWrappedEnumerableSerializer wrappedEnumerableSerializer)
{
var enumerableFieldName = wrappedEnumerableSerializer.EnumerableFieldName;
var enumerableElementSerializer = wrappedEnumerableSerializer.EnumerableElementSerializer;
var enumerableSerializer = IEnumerableSerializer.Create(enumerableElementSerializer);
var ast = AstExpression.GetField(aggregateExpression.Ast, enumerableFieldName);
var itemSerializer = wrappedEnumerableSerializer.EnumerableElementSerializer;

return new AggregationExpression(aggregateExpression.Expression, ast, enumerableSerializer);
var ast = AstExpression.GetField(aggregateExpression.Ast, enumerableFieldName);
resultSerializer = IEnumerableSerializer.Create(itemSerializer);
return new AggregationExpression(aggregateExpression.Expression, ast, resultSerializer);
}

return aggregateExpression;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,22 +28,28 @@ namespace MongoDB.Driver.Linq.Linq3Implementation.Translators.ExpressionToAggreg
{
internal static class MemberInitExpressionToAggregationExpressionTranslator
{
public static AggregationExpression Translate(TranslationContext context, MemberInitExpression expression)
public static AggregationExpression Translate(TranslationContext context, MemberInitExpression expression, IBsonSerializer targetSerializer)
{
if (expression.Type == typeof(BsonDocument))
{
return NewBsonDocumentExpressionToAggregationExpressionTranslator.Translate(context, expression);
}

return Translate(context, expression, expression.NewExpression, expression.Bindings);
return Translate(context, expression, expression.NewExpression, expression.Bindings, targetSerializer);
}

public static AggregationExpression Translate(
TranslationContext context,
Expression expression,
NewExpression newExpression,
IReadOnlyList<MemberBinding> bindings)
IReadOnlyList<MemberBinding> bindings,
IBsonSerializer targetSerializer)
{
if (targetSerializer != null)
{
return TranslateWithTargetSerializer(context, expression, newExpression, bindings, targetSerializer);
}

var constructorInfo = newExpression.Constructor; // note: can be null when using the default constructor with a struct
var constructorArguments = newExpression.Arguments;
var computedFields = new List<AstComputedField>();
Expand Down Expand Up @@ -83,8 +89,7 @@ public static AggregationExpression Translate(
foreach (var binding in bindings)
{
var memberAssignment = (MemberAssignment)binding;
var member = memberAssignment.Member;
var memberMap = FindMemberMap(expression, classMap, member.Name);
var memberMap = FindMemberMap(expression, classMap, memberInfo: memberAssignment.Member);
var valueExpression = memberAssignment.Expression;
var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression);
var memberSerializer = CoerceSourceSerializerToMemberSerializer(memberMap, valueTranslation.Serializer);
Expand All @@ -100,6 +105,70 @@ public static AggregationExpression Translate(
return new AggregationExpression(expression, ast, serializer);
}

private static AggregationExpression TranslateWithTargetSerializer(
TranslationContext context,
Expression expression,
NewExpression newExpression,
IReadOnlyList<MemberBinding> bindings,
IBsonSerializer targetSerializer)
{
var resultSerializer = targetSerializer as IBsonDocumentSerializer;
if (resultSerializer == null)
{
throw new ExpressionNotSupportedException(expression, because: $"serializer class {targetSerializer.GetType()} does not implement IBsonDocumentSerializer.");
}

var constructorInfo = newExpression.Constructor; // note: can be null when using the default constructor with a struct
var constructorArguments = newExpression.Arguments;
var computedFields = new List<AstComputedField>();

if (constructorInfo != null && constructorArguments.Count > 0)
{
var constructorParameters = constructorInfo.GetParameters();

// if the documentSerializer is a BsonClassMappedSerializer we can use the classMap and creatorMap
var classMap = (resultSerializer as IBsonClassMapSerializer)?.ClassMap;
var creatorMap = classMap == null ? null : FindMatchingCreatorMap(classMap, constructorInfo);
if (creatorMap == null && classMap != null)
{
throw new ExpressionNotSupportedException(expression, because: "no matching creator map found");
}
var creatorMapArguments = creatorMap?.Arguments?.ToArray();

for (var i = 0; i < constructorParameters.Length; i++)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we iterate over creatorMap arguments instead? In such case we do not need to do a member lookup and will use custom member if it was configured explicitly.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No because the creatorMap could be null.

We only have a creatorMap IFF the serializer is an IBsonClassMapSerializer and not all serializers are.

{
var argumentExpression = constructorArguments[i];

// if we have a classMap (and therefore a creatorMap also) use them
// otherwise fall back to matching constructor parameter names to member names
var (elementName, memberSerializer) = classMap != null ?
FindMemberElementNameAndSerializer(argumentExpression, classMap, memberInfo: creatorMapArguments[i]) :
FindMemberElementNameAndSerializer(argumentExpression, resultSerializer, constructorParameterName: constructorParameters[i].Name);

var argumentTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, argumentExpression, targetSerializer: memberSerializer);
computedFields.Add(AstExpression.ComputedField(elementName, argumentTranslation.Ast));
}
}

foreach (var binding in bindings)
{
var memberAssignment = (MemberAssignment)binding;
var member = memberAssignment.Member;
var valueExpression = memberAssignment.Expression;
if (!resultSerializer.TryGetMemberSerializationInfo(member.Name, out var memberSerializationInfo))
{
throw new ExpressionNotSupportedException(valueExpression, expression, because: $"couldn't find member {member.Name}");
}
var memberSerializer = memberSerializationInfo.Serializer;

var valueTranslation = ExpressionToAggregationExpressionTranslator.Translate(context, valueExpression, memberSerializer);
computedFields.Add(AstExpression.ComputedField(memberSerializationInfo.ElementName, valueTranslation.Ast));
}

var ast = AstExpression.ComputedDocument(computedFields);
return new AggregationExpression(expression, ast, resultSerializer);
}

private static BsonClassMap CreateClassMap(Type classType, ConstructorInfo constructorInfo, out BsonCreatorMap creatorMap)
{
BsonClassMap baseClassMap = null;
Expand Down Expand Up @@ -190,22 +259,55 @@ private static void EnsureDefaultValue(BsonMemberMap memberMap)
memberMap.SetDefaultValue(defaultValue);
}

private static BsonMemberMap FindMemberMap(Expression expression, BsonClassMap classMap, string memberName)
private static BsonCreatorMap FindMatchingCreatorMap(BsonClassMap classMap, ConstructorInfo constructorInfo)
=> classMap?.CreatorMaps.FirstOrDefault(m => m.MemberInfo.Equals(constructorInfo));

private static (string, IBsonSerializer) FindMemberElementNameAndSerializer(
Expression expression,
BsonClassMap classMap,
MemberInfo memberInfo)
{
var memberMap = FindMemberMap(expression, classMap, memberInfo);
return (memberMap.ElementName, memberMap.GetSerializer());
}

private static (string, IBsonSerializer) FindMemberElementNameAndSerializer(
Expression expression,
IBsonDocumentSerializer documentSerializer,
string constructorParameterName)
{
// case insensitive GetMember could return some false hits but TryGetMemberSerializationInfo will filter them out
var bindingFlags = BindingFlags.Instance | BindingFlags.Public | BindingFlags.FlattenHierarchy | BindingFlags.IgnoreCase;
foreach (var memberInfo in documentSerializer.ValueType.GetMember(constructorParameterName, bindingFlags))
{
if (documentSerializer.TryGetMemberSerializationInfo(memberInfo.Name, out var serializationInfo))
{
return (serializationInfo.ElementName, serializationInfo.Serializer);
}
}

throw new ExpressionNotSupportedException(expression, because: $"no matching member map found for constructor parameter: {constructorParameterName}");
}

private static BsonMemberMap FindMemberMap(
Expression expression,
BsonClassMap classMap,
MemberInfo memberInfo)
{
foreach (var memberMap in classMap.DeclaredMemberMaps)
{
if (memberMap.MemberName == memberName)
if (memberMap.MemberInfo == memberInfo)
{
return memberMap;
}
}

if (classMap.BaseClassMap != null)
{
return FindMemberMap(expression, classMap.BaseClassMap, memberName);
return FindMemberMap(expression, classMap.BaseClassMap, memberInfo);
}

throw new ExpressionNotSupportedException(expression, because: $"can't find member map: {memberName}");
throw new ExpressionNotSupportedException(expression, because: $"no member map found for member: {memberInfo.Name}");
}
}
}
Loading