Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

VSTHRD114: Handle delegate and local functions #625

Merged
merged 6 commits into from
May 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,40 @@

namespace Microsoft.VisualStudio.Threading.Analyzers
{
using System.Linq;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Editing;

internal static class SyntaxGeneratorExtensions
{
/// <summary>
/// Creates a reference to a named type suitable for use in accessing a static member of the type.
/// </summary>
/// <param name="generator">The <see cref="SyntaxGenerator"/> used to create the type reference.</param>
/// <param name="typeSymbol">The named type to reference.</param>
/// <returns>A <see cref="SyntaxNode"/> representing the type reference expression.</returns>
internal static SyntaxNode TypeExpressionForStaticMemberAccess(this SyntaxGenerator generator, INamedTypeSymbol typeSymbol)
{
var qualifiedNameSyntaxKind = generator.QualifiedName(generator.IdentifierName("ignored"), generator.IdentifierName("ignored")).RawKind;
var memberAccessExpressionSyntaxKind = generator.MemberAccessExpression(generator.IdentifierName("ignored"), "ignored").RawKind;

var typeExpression = generator.TypeExpression(typeSymbol);
return QualifiedNameToMemberAccess(qualifiedNameSyntaxKind, memberAccessExpressionSyntaxKind, typeExpression, generator);

// Local function
static SyntaxNode QualifiedNameToMemberAccess(int qualifiedNameSyntaxKind, int memberAccessExpressionSyntaxKind, SyntaxNode expression, SyntaxGenerator generator)
{
if (expression.RawKind == qualifiedNameSyntaxKind)
{
var left = QualifiedNameToMemberAccess(qualifiedNameSyntaxKind, memberAccessExpressionSyntaxKind, expression.ChildNodes().First(), generator);
var right = expression.ChildNodes().Last();
return generator.MemberAccessExpression(left, right);
}

return expression;
}
}

internal static SyntaxNode? TryGetContainingDeclaration(this SyntaxGenerator generator, SyntaxNode? node, DeclarationKind? kind = null)
{
if (node is null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CodeActions;
using Microsoft.CodeAnalysis.CodeFixes;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Simplification;
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.Operations;

[ExportCodeFixProvider(LanguageNames.CSharp)]
[ExportCodeFixProvider(LanguageNames.CSharp, LanguageNames.VisualBasic)]
public class VSTHRD114AvoidReturningNullTaskCodeFix : CodeFixProvider
{
private static readonly ImmutableArray<string> ReusableFixableDiagnosticIds = ImmutableArray.Create(
Expand All @@ -28,54 +27,43 @@ public override async Task RegisterCodeFixesAsync(CodeFixContext context)
{
var semanticModel = await context.Document.GetSemanticModelAsync(context.CancellationToken).ConfigureAwait(false);
var syntaxRoot = await context.Document.GetSyntaxRootAsync(context.CancellationToken).ConfigureAwait(false);

if (!(syntaxRoot.FindNode(diagnostic.Location.SourceSpan) is LiteralExpressionSyntax nullLiteral))
{
continue;
}

var methodDeclaration = nullLiteral.FirstAncestorOrSelf<MethodDeclarationSyntax>();
if (methodDeclaration == null)
{
continue;
}

if (!(methodDeclaration.ReturnType is GenericNameSyntax genericReturnType))
{
context.RegisterCodeFix(CodeAction.Create(Strings.VSTHRD114_CodeFix_CompletedTask, ct => ApplyTaskCompletedTaskFix(ct), "CompletedTask"), diagnostic);
}
else
var nullLiteral = syntaxRoot.FindNode(diagnostic.Location.SourceSpan);
if (semanticModel.GetOperation(nullLiteral, context.CancellationToken) is ILiteralOperation { ConstantValue: { HasValue: true, Value: null } })
{
if (genericReturnType.TypeArgumentList.Arguments.Count != 1)
var typeInfo = semanticModel.GetTypeInfo(nullLiteral, context.CancellationToken);
if (typeInfo.ConvertedType is INamedTypeSymbol returnType)
{
continue;
if (returnType.IsGenericType)
{
context.RegisterCodeFix(CodeAction.Create(Strings.VSTHRD114_CodeFix_FromResult, ct => ApplyTaskFromResultFix(returnType, ct), nameof(Task.FromResult)), diagnostic);
}
else
{
context.RegisterCodeFix(CodeAction.Create(Strings.VSTHRD114_CodeFix_CompletedTask, ct => ApplyTaskCompletedTaskFix(returnType, ct), nameof(Task.CompletedTask)), diagnostic);
}
}

context.RegisterCodeFix(CodeAction.Create(Strings.VSTHRD114_CodeFix_FromResult, ct => ApplyTaskFromResultFix(genericReturnType.TypeArgumentList.Arguments[0], ct), "FromResult"), diagnostic);
}

Task<Document> ApplyTaskCompletedTaskFix(CancellationToken cancellationToken)
{
ExpressionSyntax completedTaskExpression = SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
SyntaxFactory.IdentifierName("Task"),
SyntaxFactory.IdentifierName("CompletedTask"))
.WithAdditionalAnnotations(Simplifier.Annotation);
Task<Document> ApplyTaskCompletedTaskFix(INamedTypeSymbol returnType, CancellationToken cancellationToken)
{
var generator = SyntaxGenerator.GetGenerator(context.Document);
SyntaxNode completedTaskExpression = generator.MemberAccessExpression(
generator.TypeExpressionForStaticMemberAccess(returnType),
generator.IdentifierName(nameof(Task.CompletedTask)));

return Task.FromResult(context.Document.WithSyntaxRoot(syntaxRoot.ReplaceNode(nullLiteral, completedTaskExpression)));
}
return Task.FromResult(context.Document.WithSyntaxRoot(syntaxRoot.ReplaceNode(nullLiteral, completedTaskExpression)));
}

Task<Document> ApplyTaskFromResultFix(TypeSyntax returnTypeArgument, CancellationToken cancellationToken)
{
ExpressionSyntax completedTaskExpression = SyntaxFactory.InvocationExpression(
SyntaxFactory.MemberAccessExpression(
SyntaxKind.SimpleMemberAccessExpression,
SyntaxFactory.IdentifierName("Task"),
SyntaxFactory.GenericName("FromResult").AddTypeArgumentListArguments(returnTypeArgument)))
.AddArgumentListArguments(SyntaxFactory.Argument(nullLiteral))
.WithAdditionalAnnotations(Simplifier.Annotation);
Task<Document> ApplyTaskFromResultFix(INamedTypeSymbol returnType, CancellationToken cancellationToken)
{
var generator = SyntaxGenerator.GetGenerator(context.Document);
SyntaxNode taskFromResultExpression = generator.InvocationExpression(
generator.MemberAccessExpression(
generator.TypeExpressionForStaticMemberAccess(returnType.BaseType),
generator.GenericName(nameof(Task.FromResult), returnType.TypeArguments[0])),
generator.NullLiteralExpression());

return Task.FromResult(context.Document.WithSyntaxRoot(syntaxRoot.ReplaceNode(nullLiteral, completedTaskExpression)));
return Task.FromResult(context.Document.WithSyntaxRoot(syntaxRoot.ReplaceNode(nullLiteral, taskFromResultExpression)));
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
namespace Microsoft.VisualStudio.Threading.Analyzers.Tests
{
using System.Threading.Tasks;
using Microsoft.CodeAnalysis.CSharp;
using Xunit;
using VerifyCS = CSharpCodeFixVerifier<VSTHRD114AvoidReturningNullTaskAnalyzer, CodeAnalysis.Testing.EmptyCodeFixProvider>;
using VerifyVB = VisualBasicCodeFixVerifier<VSTHRD114AvoidReturningNullTaskAnalyzer, CodeAnalysis.Testing.EmptyCodeFixProvider>;
Expand Down Expand Up @@ -192,5 +191,97 @@ public Task<object> GetTaskObj(string s)
TestCode = test,
}.RunAsync();
}

[Fact]
public async Task AsyncAnonymousDelegateReturnsNull_NoDiagnostic()
{
var test = @"
using System.Threading.Tasks;

class Test
{
public Task Foo()
{
return Task.Run<object>(async delegate
{
return null;
});
}
}
";
await new VerifyCS.Test
{
TestCode = test,
}.RunAsync();
}

[Fact]
public async Task NonAsyncAnonymousDelegateReturnsNull_Diagnostic()
{
var test = @"
using System.Threading.Tasks;

class Test
{
public void Foo()
{
Task.Run<object>(delegate
{
return [|null|];
});
}
}
";
await new VerifyCS.Test
{
TestCode = test,
}.RunAsync();
}

[Fact]
public async Task LocalFunctionNonAsyncReturnsNull_Diagnostic()
{
var csharpTest = @"
using System.Threading.Tasks;

class Test
{
public void Foo()
{
Task<object> GetTaskObj()
{
return [|null|];
}
}
}
";
await new VerifyCS.Test
{
TestCode = csharpTest,
}.RunAsync();
}

[Fact]
public async Task LocalFunctionAsyncReturnsNull_NoDiagnostic()
{
var csharpTest = @"
using System.Threading.Tasks;

class Test
{
public void Foo()
{
async Task<object> GetTaskObj()
{
return null;
}
}
}
";
await new VerifyCS.Test
{
TestCode = csharpTest,
}.RunAsync();
}
}
}
Loading