Skip to content

Commit

Permalink
Fix #61 Drop if and while that has IsCancellationRequested and replac…
Browse files Browse the repository at this point in the history
…e condition with true if IsCancellationRequested is negated
  • Loading branch information
virzak committed Mar 19, 2024
1 parent 60cc83e commit fa72a97
Show file tree
Hide file tree
Showing 7 changed files with 112 additions and 21 deletions.
1 change: 1 addition & 0 deletions .vscode/spellright.dict
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ impl
func
Nuget
numerics
retval
zomp
91 changes: 71 additions & 20 deletions src/Zomp.SyncMethodGenerator/AsyncToSyncRewriter.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using System.Globalization;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;
using static Microsoft.CodeAnalysis.CSharp.SyntaxFactory;

namespace Zomp.SyncMethodGenerator;

Expand Down Expand Up @@ -31,6 +30,7 @@ internal sealed class AsyncToSyncRewriter(SemanticModel semanticModel) : CSharpS
private const string Delay = "Delay";
private const string Span = "Span";
private const string CompletedTask = "CompletedTask";
private const string IsCancellationRequested = "IsCancellationRequested";
private static readonly HashSet<string> Drops = [IProgressInterface, CancellationTokenType];
private static readonly HashSet<string> InterfacesToDrop = [IProgressInterface, IAsyncResultInterface];
private static readonly Dictionary<string, string?> Replacements = new()
Expand Down Expand Up @@ -108,7 +108,7 @@ private enum SpecialMethod
var ext = whenNotNull switch
{
InvocationExpressionSyntax ies => GetExtensionExprSymbol(ies),
ConditionalAccessExpressionSyntax { Expression: InvocationExpressionSyntax ies } caes => GetExtensionExprSymbol(ies),
ConditionalAccessExpressionSyntax { Expression: InvocationExpressionSyntax ies } => GetExtensionExprSymbol(ies),
_ => null,
};

Expand Down Expand Up @@ -727,6 +727,11 @@ List<SyntaxTrivia> RemoveFirstEndIf(SyntaxTriviaList list)
retVal = retVal.WithElse(SyntaxFactory.ElseClause(SyntaxFactory.Block()));
}

if (ChecksIfNegatedIsCancellationRequested(node.Condition))
{
retVal = retVal.WithCondition(LiteralExpression(SyntaxKind.TrueLiteralExpression));
}

return retVal;
}

Expand Down Expand Up @@ -1024,6 +1029,19 @@ bool ShouldRemoveArgumentLocal(ArgumentSyntax arg, int index)
return @base.WithModifiers(StripAsyncModifier(@base.Modifiers)).WithTriviaFrom(@base);
}

/// <inheritdoc/>
public override SyntaxNode? VisitWhileStatement(WhileStatementSyntax node)
{
var @base = (WhileStatementSyntax)base.VisitWhileStatement(node)!;

if (ChecksIfNegatedIsCancellationRequested(node.Condition))
{
return @base.WithCondition(LiteralExpression(SyntaxKind.TrueLiteralExpression));
}

return @base;
}

/// <inheritdoc/>
public override SyntaxNode? VisitLocalDeclarationStatement(LocalDeclarationStatementSyntax node)
{
Expand Down Expand Up @@ -1464,9 +1482,10 @@ private static bool ShouldRemoveType(ITypeSymbol symbol)
_ => throw new NotSupportedException($"Can't process {symbol}"),
};

private static bool ShouldRemoveArgument(ISymbol symbol) => symbol switch
private static bool ShouldRemoveArgument(ISymbol symbol, bool isNegated = false) => symbol switch
{
IPropertySymbol { Name: CompletedTask } ps => ps.Type.ToString() is TaskType or ValueTaskType,
IPropertySymbol { Name: IsCancellationRequested } ps => ps.ContainingType.ToString() is CancellationTokenType && !isNegated,
IMethodSymbol ms =>
IsSpecialMethod(ms) == SpecialMethod.None
&& ((ShouldRemoveType(ms.ReturnType) && ms.MethodKind != MethodKind.LocalFunction)
Expand All @@ -1477,6 +1496,17 @@ private static bool ShouldRemoveType(ITypeSymbol symbol)
private static MemberAccessExpressionSyntax AppendSpan(ExpressionSyntax @base)
=> MemberAccessExpression(SyntaxKind.SimpleMemberAccessExpression, @base, IdentifierName(Span));

private static ExpressionSyntax RemoveParentheses(ExpressionSyntax condition)
{
var node = condition;
while (node is ParenthesizedExpressionSyntax p)
{
node = p.Expression;
}

return node;
}

private static bool IsTaskExtension(IMethodSymbol methodSymbol)
{
var returnType = GetNameWithoutTypeParams(methodSymbol.ReturnType);
Expand Down Expand Up @@ -1850,8 +1880,14 @@ private TypeSyntax ProcessSyntaxUsingSymbol(TypeSyntax typeSyntax)
_ => false,
};

private bool HasSymbolAndShouldBeRemoved(ExpressionSyntax expr)
=> GetSymbol(expr) is ISymbol symbol && ShouldRemoveArgument(symbol);
private bool ChecksIfNegatedIsCancellationRequested(ExpressionSyntax condition)
=> RemoveParentheses(condition) is PrefixUnaryExpressionSyntax { OperatorToken.RawKind: (int)SyntaxKind.ExclamationToken } pe
&& RemoveParentheses(pe.Operand) is MemberAccessExpressionSyntax { Name.Identifier.ValueText: IsCancellationRequested } mae
&& GetSymbol(mae) is { } t
&& GetNameWithoutTypeParams(t.ContainingType) is CancellationTokenType;

private bool HasSymbolAndShouldBeRemoved(ExpressionSyntax expr, bool isNegated = false)
=> GetSymbol(expr) is ISymbol symbol && ShouldRemoveArgument(symbol, isNegated);

private bool DropInvocation(InvocationExpressionSyntax invocation)
{
Expand Down Expand Up @@ -1915,28 +1951,41 @@ private bool ShouldRemoveArrowExpression(ArrowExpressionClauseSyntax? arrowNulla
return ExpressionStatement((ExpressionSyntax)Visit(result).WithoutTrivia());
}

private bool ShouldRemoveArgument(ExpressionSyntax expr) => expr switch
private bool ShouldRemoveArgument(ExpressionSyntax expr)
=> ShouldRemoveArgument(expr, new());

private bool ShouldRemoveArgument(ExpressionSyntax expr, RemoveArgumentContext context) => expr switch
{
ElementAccessExpressionSyntax ee => ShouldRemoveArgument(ee.Expression),
BinaryExpressionSyntax be => ShouldRemoveArgument(be.Left) || ShouldRemoveArgument(be.Right),
CastExpressionSyntax ce => HasSymbolAndShouldBeRemoved(expr) || ShouldRemoveArgument(ce.Expression),
ParenthesizedExpressionSyntax pe => ShouldRemoveArgument(pe.Expression),
IdentifierNameSyntax id => !id.Identifier.ValueText.EndsWithAsync() && HasSymbolAndShouldBeRemoved(id),
ElementAccessExpressionSyntax ee => ShouldRemoveArgument(ee.Expression, context),
BinaryExpressionSyntax be => ShouldRemoveArgument(be.Left, context) || ShouldRemoveArgument(be.Right, context),
CastExpressionSyntax ce => HasSymbolAndShouldBeRemoved(expr) || ShouldRemoveArgument(ce.Expression, context),
ParenthesizedExpressionSyntax pe => ShouldRemoveArgument(pe.Expression, context),
IdentifierNameSyntax id => !id.Identifier.ValueText.EndsWithAsync() && HasSymbolAndShouldBeRemoved(id, context.IsNegated),
InvocationExpressionSyntax ie => DropInvocation(ie),
ConditionalExpressionSyntax ce => ShouldRemoveArgument(ce.WhenTrue) && ShouldRemoveArgument(ce.WhenFalse),
MemberAccessExpressionSyntax mae => ShouldRemoveArgument(mae.Name),
PostfixUnaryExpressionSyntax pue => ShouldRemoveArgument(pue.Operand),
PrefixUnaryExpressionSyntax pue => ShouldRemoveArgument(pue.Operand),
ObjectCreationExpressionSyntax oe => ShouldRemoveArgument(oe.Type) || ShouldRemoveObjectCreation(oe),
ConditionalExpressionSyntax ce => ShouldRemoveArgument(ce.WhenTrue, context) && ShouldRemoveArgument(ce.WhenFalse, context),
MemberAccessExpressionSyntax mae => ShouldRemoveArgument(mae.Name, context),
PostfixUnaryExpressionSyntax pue => ShouldRemoveArgument(pue.Operand, context),
PrefixUnaryExpressionSyntax pue => ProcessUnaryExpression(pue, context),
ObjectCreationExpressionSyntax oe => ShouldRemoveArgument(oe.Type, context) || ShouldRemoveObjectCreation(oe),
ImplicitObjectCreationExpressionSyntax oe => ShouldRemoveObjectCreation(oe),
ConditionalAccessExpressionSyntax cae => ShouldRemoveArgument(cae.Expression),
AwaitExpressionSyntax ae => ShouldRemoveArgument(ae.Expression),
AssignmentExpressionSyntax ae => ShouldRemoveArgument(ae.Right),
ConditionalAccessExpressionSyntax cae => ShouldRemoveArgument(cae.Expression, context),
AwaitExpressionSyntax ae => ShouldRemoveArgument(ae.Expression, context),
AssignmentExpressionSyntax ae => ShouldRemoveArgument(ae.Right, context),
GenericNameSyntax gn => HasSymbolAndShouldBeRemoved(gn),
LiteralExpressionSyntax le => ShouldRemoveLiteral(le),
_ => false,
};

private bool ProcessUnaryExpression(PrefixUnaryExpressionSyntax pue, RemoveArgumentContext context)
{
if (pue.OperatorToken.IsKind(SyntaxKind.ExclamationToken))
{
context = context with { IsNegated = !context.IsNegated };
}

return ShouldRemoveArgument(pue.Operand, context);
}

private bool ShouldRemoveLiteral(LiteralExpressionSyntax literalExpression)
=> literalExpression.Token.IsKind(SyntaxKind.DefaultKeyword)
&& semanticModel.GetTypeInfo(literalExpression).Type is INamedTypeSymbol { Name: "ValueTask", IsGenericType: false } t
Expand Down Expand Up @@ -2012,4 +2061,6 @@ public SyntaxList<StatementSyntax> PostProcess(SyntaxList<StatementSyntax> state
}

private sealed record ExtensionExprSymbol(InvocationExpressionSyntax InvocationExpression, IMethodSymbol ReducedFrom, ITypeSymbol ReturnType);

private sealed record RemoveArgumentContext(bool IsNegated = false);
}
28 changes: 28 additions & 0 deletions tests/Generator.Tests/IsCancellationRequestedTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
namespace Generator.Tests;

public class IsCancellationRequestedTests
{
[Fact]
public Task WhileNotCancelled() => $$"""
while (((!((ct.IsCancellationRequested)))))
{
await Task.Delay(120000, ct);
}
""".Verify(sourceType: SourceType.MethodBody);

[Fact]
public Task IfCancelled() => $$"""
if (((((ct.IsCancellationRequested)))))
{
await Task.Delay(120000, ct);
}
""".Verify(sourceType: SourceType.MethodBody);

[Fact]
public Task IfNotCancelled() => $$"""
if (((!((ct.IsCancellationRequested)))))
{
await Task.Delay(120000, ct);
}
""".Verify(sourceType: SourceType.MethodBody);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
//HintName: Test.Class.MethodAsync.g.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
//HintName: Test.Class.MethodAsync.g.cs
if (true)
{
global::System.Threading.Thread.Sleep(120000);
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
//HintName: Test.Class.MethodAsync.g.cs
while (true)
{
global::System.Threading.Thread.Sleep(120000);
}
2 changes: 1 addition & 1 deletion tests/Generator.Tests/SourceType.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
public enum SourceType
{
/// <summary>
/// Listing of a body of a block method.
/// Listing of a body of a block method. A single parameter is passed: CancellationToken ct.
/// </summary>
MethodBody,

Expand Down

0 comments on commit fa72a97

Please sign in to comment.