Skip to content

Commit

Permalink
Fix failing projection inlining with static mapper from another assem…
Browse files Browse the repository at this point in the history
…bly (#1418)
  • Loading branch information
trejjam authored Aug 3, 2024
1 parent b845ec0 commit 6e7ab62
Show file tree
Hide file tree
Showing 12 changed files with 239 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,12 @@ public static class InlineExpressionMappingBuilder
return null;
}

var semanticModel = ctx.Compilation.GetSemanticModel(methodSyntax.SyntaxTree);
var semanticModel = ctx.GetSemanticModel(methodSyntax.SyntaxTree);
if (semanticModel is null)
{
return null;
}

var inlineRewriter = new InlineExpressionRewriter(semanticModel, ctx.FindNewInstanceMapping);
var bodyExpression = (ExpressionSyntax?)body.Expression.Accept(inlineRewriter);
if (bodyExpression == null || !inlineRewriter.CanBeInlined)
Expand Down
18 changes: 18 additions & 0 deletions src/Riok.Mapperly/Descriptors/SimpleMappingBuilderContext.cs
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,24 @@ protected SimpleMappingBuilderContext(SimpleMappingBuilderContext ctx, Location?
/// </summary>
protected InlinedExpressionMappingCollection InlinedMappings { get; } = inlinedMappings;

public SemanticModel? GetSemanticModel(SyntaxTree syntaxTree)
{
if (_compilationContext.Compilation.ContainsSyntaxTree(syntaxTree))
{
return _compilationContext.Compilation.GetSemanticModel(syntaxTree);
}

foreach (var compilation in _compilationContext.NestedCompilations)
{
if (compilation.ContainsSyntaxTree(syntaxTree))
{
return compilation.GetSemanticModel(syntaxTree);
}
}

return null;
}

public virtual bool IsConversionEnabled(MappingConversionType conversionType) =>
Configuration.Mapper.EnabledConversions.HasFlag(conversionType);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ public static IncrementalValuesProvider<TSource> WhereNotNull<TSource>(this Incr
#nullable enable
}

public static IncrementalValuesProvider<TTarget> OfType<TSource, TTarget>(this IncrementalValuesProvider<TSource> source)
where TTarget : class
{
return source.Select((x, _) => x as TTarget).WhereNotNull();
}

/// <summary>
/// Registers an output node into an <see cref="IncrementalGeneratorInitializationContext"/> to output a diagnostic.
/// </summary>
Expand Down
11 changes: 9 additions & 2 deletions src/Riok.Mapperly/MapperGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Microsoft.CodeAnalysis;
using System.Collections.Immutable;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Riok.Mapperly.Abstractions;
using Riok.Mapperly.Configuration;
Expand Down Expand Up @@ -28,9 +29,15 @@ public void Initialize(IncrementalGeneratorInitializationContext context)
);
context.ReportDiagnostics(compilationDiagnostics);

var nestedCompilations = SyntaxProvider.GetNestedCompilations(context);

// build the compilation context
var compilationContext = context
.CompilationProvider.Select(static (c, _) => new CompilationContext(c, new WellKnownTypes(c), new FileNameBuilder()))
.CompilationProvider.Combine(nestedCompilations)
.Select(
static (c, _) =>
new CompilationContext(c.Left, new WellKnownTypes(c.Left), c.Right.ToImmutableArray(), new FileNameBuilder())
)
.WithTrackingName(MapperGeneratorStepNames.BuildCompilationContext);

// build the assembly default configurations
Expand Down
8 changes: 7 additions & 1 deletion src/Riok.Mapperly/Symbols/CompilationContext.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
using System.Collections.Immutable;
using Microsoft.CodeAnalysis;
using Riok.Mapperly.Descriptors;
using Riok.Mapperly.Helpers;

namespace Riok.Mapperly.Symbols;

public sealed record CompilationContext(Compilation Compilation, WellKnownTypes Types, FileNameBuilder FileNameBuilder);
public sealed record CompilationContext(
Compilation Compilation,
WellKnownTypes Types,
ImmutableArray<Compilation> NestedCompilations,
FileNameBuilder FileNameBuilder
);
37 changes: 37 additions & 0 deletions src/Riok.Mapperly/SyntaxProvider.Roslyn4.0.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#if !ROSLYN4_4_OR_GREATER

using System.Collections.Immutable;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Riok.Mapperly.Helpers;
Expand All @@ -12,6 +13,42 @@ internal static class SyntaxProvider
private static readonly SymbolDisplayFormat _fullyQualifiedFormatWithoutGlobal =
SymbolDisplayFormat.FullyQualifiedFormat.WithGlobalNamespaceStyle(SymbolDisplayGlobalNamespaceStyle.OmittedAsContaining);

public static IncrementalValueProvider<ImmutableArray<Compilation>> GetNestedCompilations(
IncrementalGeneratorInitializationContext context
)
{
return context
.GetMetadataReferencesProvider()
.OfType<MetadataReference, CompilationReference>()
.Select((x, _) => x.Compilation)
.Collect();
}

/// <summary>
/// Workaround to mitigate binary incompatibility introduced in Microsoft.CodeAnalysis=4.2
/// <link cref="https://github.com/dotnet/roslyn/issues/61333#issuecomment-1129073030"/>
/// </summary>
private static IncrementalValuesProvider<MetadataReference> GetMetadataReferencesProvider(
this IncrementalGeneratorInitializationContext context
)
{
var metadataProviderProperty =
context.GetType().GetProperty(nameof(context.MetadataReferencesProvider))
?? throw new Exception($"The property '{nameof(context.MetadataReferencesProvider)}' not found");

var metadataProvider = metadataProviderProperty.GetValue(context);

if (metadataProvider is IncrementalValuesProvider<MetadataReference> metadataValuesProvider)
return metadataValuesProvider;

if (metadataProvider is IncrementalValueProvider<MetadataReference> metadataValueProvider)
return metadataValueProvider.SelectMany(static (reference, _) => ImmutableArray.Create(reference));

throw new Exception(
$"The '{nameof(context.MetadataReferencesProvider)}' is neither an '{nameof(IncrementalValuesProvider<MetadataReference>)}<{nameof(MetadataReference)}>' nor an '{nameof(IncrementalValueProvider<MetadataReference>)}<{nameof(MetadataReference)}>.'"
);
}

public static IncrementalValuesProvider<MapperDeclaration> GetMapperDeclarations(IncrementalGeneratorInitializationContext context)
{
return context
Expand Down
12 changes: 12 additions & 0 deletions src/Riok.Mapperly/SyntaxProvider.Roslyn4.4.cs
Original file line number Diff line number Diff line change
@@ -1,13 +1,25 @@
#if ROSLYN4_4_OR_GREATER

using System.Collections.Immutable;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Riok.Mapperly.Helpers;
using Riok.Mapperly.Symbols;

namespace Riok.Mapperly;

internal static class SyntaxProvider
{
public static IncrementalValueProvider<ImmutableArray<Compilation>> GetNestedCompilations(
IncrementalGeneratorInitializationContext context
)
{
return context
.MetadataReferencesProvider.OfType<MetadataReference, CompilationReference>()
.Select((x, _) => x.Compilation)
.Collect();
}

public static IncrementalValuesProvider<MapperDeclaration> GetMapperDeclarations(IncrementalGeneratorInitializationContext context)
{
return context
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// <auto-generated />
#nullable enable
namespace Riok.Mapperly.IntegrationTests.Mapper
{
public static partial class UseExternalMapperFromAnotherAssembly
{
[global::System.CodeDom.Compiler.GeneratedCode("Riok.Mapperly", "0.0.1.0")]
public static partial global::System.Linq.IQueryable<global::Riok.Mapperly.IntegrationTests.Mapper.UseExternalMapperFromAnotherAssembly.Target> ProjectToTarget(global::System.Linq.IQueryable<global::Riok.Mapperly.IntegrationTests.Mapper.UseExternalMapperFromAnotherAssembly.Source> source)
{
#nullable disable
return System.Linq.Queryable.Select(source, x => new global::Riok.Mapperly.IntegrationTests.Mapper.UseExternalMapperFromAnotherAssembly.Target()
{
DateTime = global::Riok.Mapperly.TestDependency.Mapper.DateTimeMapper.MapToDateTimeOffset(x.DateTime),
});
#nullable enable
}

[global::System.CodeDom.Compiler.GeneratedCode("Riok.Mapperly", "0.0.1.0")]
public static partial global::Riok.Mapperly.IntegrationTests.Mapper.UseExternalMapperFromAnotherAssembly.Target MapToTarget(global::Riok.Mapperly.IntegrationTests.Mapper.UseExternalMapperFromAnotherAssembly.Source source)
{
var target = new global::Riok.Mapperly.IntegrationTests.Mapper.UseExternalMapperFromAnotherAssembly.Target();
target.DateTime = global::Riok.Mapperly.TestDependency.Mapper.DateTimeMapper.MapToDateTimeOffset(source.DateTime);
return target;
}
}
}
2 changes: 1 addition & 1 deletion test/Riok.Mapperly.Tests/Helpers/GenericTypeCheckerTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ public class Mapper
var methodNode = nodes.OfType<MethodDeclarationSyntax>().Single(x => x.Identifier.Text == "Test");
var model = compilation.GetSemanticModel(classNode.SyntaxTree);
var mapperSymbol = model.GetDeclaredSymbol(classNode) ?? throw new NullReferenceException();
var compilationContext = new CompilationContext(compilation, new WellKnownTypes(compilation), new FileNameBuilder());
var compilationContext = new CompilationContext(compilation, new WellKnownTypes(compilation), [], new FileNameBuilder());
var symbolAccessor = new SymbolAccessor(compilationContext, mapperSymbol);
var typeChecker = new GenericTypeChecker(symbolAccessor, compilationContext.Types);

Expand Down
105 changes: 105 additions & 0 deletions test/Riok.Mapperly.Tests/Mapping/UseStaticMapperTest.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
using Riok.Mapperly.Diagnostics;

namespace Riok.Mapperly.Tests.Mapping;

public class UseStaticMapperTest
Expand Down Expand Up @@ -359,4 +361,107 @@ public partial class Mapper
"""
);
}

private MapperGenerationResultAssertions ExecuteStaticGenericMapperStaticMethodFromAnotherAssemblyCompilation(
bool asCompilationReference
)
{
var testDependencySource = TestSourceBuilder.SyntaxTree(
"""
using System;
using Riok.Mapperly.Abstractions;
namespace Riok.Mapperly.TestDependency.Mapper
{
[Mapper]
public static partial class DateTimeMapper
{
public static DateTimeOffset MapToDateTimeOffset(DateTime dateTime) => new(dateTime, TimeSpan.Zero);
}
}
"""
);

using var testDependencyAssembly = TestHelper.BuildAssembly(
"Riok.Mapperly.TestDependency",
asCompilationReference,
testDependencySource
);

var source = TestSourceBuilder.CSharp(
"""
using System;
using System.Linq;
using Riok.Mapperly.Abstractions;
using Riok.Mapperly.TestDependency.Mapper;
[Mapper]
[UseStaticMapper(typeof(DateTimeMapper))]
public static partial class Mapper
{
public static partial IQueryable<Target> ProjectToTarget(IQueryable<Source> source);
public static partial Target MapToTarget(Source source);
public class Source
{
public DateTime DateTime { get; set; }
}
public class Target
{
public DateTimeOffset DateTime { get; set; }
}
}
"""
);

return TestHelper
.GenerateMapper(source, TestHelperOptions.AllowDiagnostics, additionalAssemblies: [testDependencyAssembly])
.Should();
}

/// <summary>
/// This tests a situation when your IDE runs the source generator (references are other syntax trees)
/// </summary>
[Fact]
public void UseStaticGenericMapperStaticMethodFromAnotherAssemblyAsReference()
{
var result = ExecuteStaticGenericMapperStaticMethodFromAnotherAssemblyCompilation(asCompilationReference: true);

result.HaveMethodBody(
"ProjectToTarget",
"""
#nullable disable
return System.Linq.Queryable.Select(source, x => new global::Mapper.Target()
{
DateTime = new global::System.DateTimeOffset(x.DateTime, global::System.TimeSpan.Zero),
});
#nullable enable
"""
);
}

/// <summary>
/// This tests a situation when compiler produces final assembly (references are compiled assemblies)
/// </summary>
[Fact]
public void UseStaticGenericMapperStaticMethodFromAnotherAssemblyAsCompiledAssembly()
{
var result = ExecuteStaticGenericMapperStaticMethodFromAnotherAssemblyCompilation(asCompilationReference: false);

result
.HaveDiagnostic(DiagnosticDescriptors.QueryableProjectionMappingCannotInline)
.HaveMethodBody(
"ProjectToTarget",
"""
#nullable disable
return System.Linq.Queryable.Select(source, x => new global::Mapper.Target()
{
DateTime = global::Riok.Mapperly.TestDependency.Mapper.DateTimeMapper.MapToDateTimeOffset(x.DateTime),
});
#nullable enable
"""
);
}
}
7 changes: 7 additions & 0 deletions test/Riok.Mapperly.Tests/TestAssembly.cs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,13 @@ internal TestAssembly(Compilation compilation)
MetadataReference = MetadataReference.CreateFromStream(_data);
}

private TestAssembly(MetadataReference metadataReference)
{
MetadataReference = metadataReference;
}

internal static TestAssembly CreateAsCompilationReference(Compilation compilation) => new(compilation.ToMetadataReference());

public MetadataReference MetadataReference { get; }

public void Dispose() => _data.Dispose();
Expand Down
7 changes: 5 additions & 2 deletions test/Riok.Mapperly.Tests/TestHelper.cs
Original file line number Diff line number Diff line change
Expand Up @@ -60,10 +60,13 @@ public static CSharpCompilation BuildCompilation([StringSyntax(StringSyntax.CSha
public static CSharpCompilation BuildCompilation(params SyntaxTree[] syntaxTrees) =>
BuildCompilation("Tests", NullableContextOptions.Enable, true, syntaxTrees);

public static TestAssembly BuildAssembly(string name, params SyntaxTree[] syntaxTrees)
public static TestAssembly BuildAssembly(string name, params SyntaxTree[] syntaxTrees) =>
BuildAssembly(name, asCompilationReference: false, syntaxTrees);

public static TestAssembly BuildAssembly(string name, bool asCompilationReference, params SyntaxTree[] syntaxTrees)
{
var compilation = BuildCompilation(name, NullableContextOptions.Enable, false, syntaxTrees);
return new TestAssembly(compilation);
return asCompilationReference ? TestAssembly.CreateAsCompilationReference(compilation) : new TestAssembly(compilation);
}

public static GeneratorDriver GenerateTracked(Compilation compilation)
Expand Down

0 comments on commit 6e7ab62

Please sign in to comment.