Skip to content

Commit

Permalink
Add nullability information to extract method (#37851)
Browse files Browse the repository at this point in the history
Extract method will now maintain nullable reference types for parameters and return types. If the parameters or return types are determined to be non-nullable through flow state analysis, then we will adjust them to be non-null when generating. 

For return type, the adjustment is done after the method is generated. This is because the generator can and will introduce new return statements that need to be analyzed for null state. The only way a return type can be modified to non null is if it:

1. Was null annotated to start with
2. All returns are determined to return non-null by flow state analysis. 

Similarly, reference parameters can be assigned null or non-null values inside a method. We determine parameters can be adjusted if:

1. Null is not passed into the parameter.
2. Null is never assigned to the parameter.
  • Loading branch information
ryzngard authored Aug 22, 2019
1 parent 7eb2327 commit bee0632
Show file tree
Hide file tree
Showing 10 changed files with 811 additions and 113 deletions.

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
using Microsoft.CodeAnalysis.CSharp.Symbols;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.ExtractMethod;
using Microsoft.CodeAnalysis.Operations;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Roslyn.Utilities;

namespace Microsoft.CodeAnalysis.CSharp.ExtractMethod
Expand Down Expand Up @@ -129,6 +131,45 @@ protected override bool ReadOnlyFieldAllowed()
var scope = this.SelectionResult.GetContainingScopeOf<ConstructorDeclarationSyntax>();
return scope == null;
}

protected override ITypeSymbol GetSymbolType(SemanticModel semanticModel, ISymbol symbol)
{
var selectionOperation = semanticModel.GetOperation(this.SelectionResult.GetContainingScope());

switch (symbol)
{
case ILocalSymbol localSymbol when localSymbol.NullableAnnotation == NullableAnnotation.Annotated:
case IParameterSymbol parameterSymbol when parameterSymbol.NullableAnnotation == NullableAnnotation.Annotated:

// For local symbols and parameters, we can check what the flow state
// for refences to the symbols are and determine if we can change
// the nullability to a less permissive state.
var references = selectionOperation.DescendantsAndSelf()
.Where(IsSymbolReferencedByOperation);

if (AreAllReferencesNotNull(references))
{
return base.GetSymbolType(semanticModel, symbol).WithNullability(NullableAnnotation.NotAnnotated);
}

return base.GetSymbolType(semanticModel, symbol);

default:
return base.GetSymbolType(semanticModel, symbol);
}

bool AreAllReferencesNotNull(IEnumerable<IOperation> references)
=> references.All(r => semanticModel.GetTypeInfo(r.Syntax).Nullability.FlowState == NullableFlowState.NotNull);

bool IsSymbolReferencedByOperation(IOperation operation)
=> operation switch
{
ILocalReferenceOperation localReference => localReference.Local.Equals(symbol),
IParameterReferenceOperation parameterReference => parameterReference.Parameter.Equals(symbol),
IAssignmentOperation assignment => IsSymbolReferencedByOperation(assignment.Target),
_ => false
};
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
using Microsoft.CodeAnalysis.Editing;
using Microsoft.CodeAnalysis.ExtractMethod;
using Microsoft.CodeAnalysis.Formatting;
using Microsoft.CodeAnalysis.LanguageServices;
using Microsoft.CodeAnalysis.Operations;
using Microsoft.CodeAnalysis.Shared.Extensions;
using Microsoft.CodeAnalysis.Simplification;
using Roslyn.Utilities;
Expand Down Expand Up @@ -283,21 +285,21 @@ private OperationStatus CheckActiveStatements(IEnumerable<StatementSyntax> state

foreach (var statement in statements)
{
var declStatement = statement as LocalDeclarationStatementSyntax;
if (declStatement == null)
if (statement is LocalDeclarationStatementSyntax declStatement)
{
// found one
return OperationStatus.Succeeded;
}

foreach (var variable in declStatement.Declaration.Variables)
{
if (variable.Initializer != null)
foreach (var variable in declStatement.Declaration.Variables)
{
// found one
return OperationStatus.Succeeded;
if (variable.Initializer != null)
{
// found one
return OperationStatus.Succeeded;
}
}
}
else
{
return OperationStatus.Succeeded;
}
}

return OperationStatus.NoActiveStatement;
Expand Down Expand Up @@ -452,8 +454,7 @@ private StatementSyntax FixDeclarationExpressionsAndDeclarationPatterns(Statemen
// We don't have a good refactoring for this, so we just annotate the conflict
// For instance, when a local declared by a pattern declaration (`3 is int i`) is
// used outside the block we're trying to extract.
var designation = pattern.Designation as SingleVariableDesignationSyntax;
if (designation == null)
if (!(pattern.Designation is SingleVariableDesignationSyntax designation))
{
break;
}
Expand Down Expand Up @@ -652,6 +653,69 @@ protected StatementSyntax GetStatementContainingInvocationToExtractedMethodWorke

return SyntaxFactory.ExpressionStatement(callSignature);
}

protected override async Task<SemanticDocument> UpdateMethodAfterGenerationAsync(
SemanticDocument originalDocument,
OperationStatus<IMethodSymbol> methodSymbolResult,
CancellationToken cancellationToken)
{
// Only need to update for nullable reference types in return
if (methodSymbolResult.Data.ReturnType.GetNullability() != NullableAnnotation.Annotated)
{
return await base.UpdateMethodAfterGenerationAsync(originalDocument, methodSymbolResult, cancellationToken).ConfigureAwait(false);
}

var syntaxNode = originalDocument.Root.GetAnnotatedNodesAndTokens(MethodDefinitionAnnotation).FirstOrDefault().AsNode();
if (syntaxNode == null || !(syntaxNode is MethodDeclarationSyntax methodDeclaration))
{
return await base.UpdateMethodAfterGenerationAsync(originalDocument, methodSymbolResult, cancellationToken).ConfigureAwait(false);
}

var semanticModel = originalDocument.SemanticModel;
var methodOperation = semanticModel.GetOperation(methodDeclaration, cancellationToken);

var returnOperations = methodOperation.DescendantsAndSelf().OfType<IReturnOperation>();

foreach (var returnOperation in returnOperations)
{
// If thereturn statement is located in a nested local function or lambda it
// shouldn't contribute to the nullability of the extracted method's return type
if (!ReturnOperationBelongsToMethod(returnOperation.Syntax, methodOperation.Syntax))
{
continue;
}

var syntax = returnOperation.ReturnedValue?.Syntax ?? returnOperation.Syntax;
var returnTypeInfo = semanticModel.GetTypeInfo(syntax, cancellationToken);
if (returnTypeInfo.Nullability.FlowState == NullableFlowState.MaybeNull)
{
// Flow state shows that return is correctly nullable
return await base.UpdateMethodAfterGenerationAsync(originalDocument, methodSymbolResult, cancellationToken).ConfigureAwait(false);
}
}

// Return type can be updated to not be null
var newType = methodSymbolResult.Data.ReturnType.WithNullability(NullableAnnotation.NotAnnotated);

var oldRoot = await originalDocument.Document.GetSyntaxRootAsync(cancellationToken).ConfigureAwait(false);
var newRoot = oldRoot.ReplaceNode(methodDeclaration.ReturnType, newType.GenerateTypeSyntax());

var newDocument = originalDocument.Document.WithSyntaxRoot(newRoot);
return await SemanticDocument.CreateAsync(newDocument, cancellationToken).ConfigureAwait(false);

static bool ReturnOperationBelongsToMethod(SyntaxNode returnOperationSyntax, SyntaxNode methodSyntax)
{
var enclosingMethod = returnOperationSyntax.FirstAncestorOrSelf<SyntaxNode>(n => n switch
{
BaseMethodDeclarationSyntax _ => true,
AnonymousFunctionExpressionSyntax _ => true,
LocalFunctionStatementSyntax _ => true,
_ => false
});

return enclosingMethod == methodSyntax;
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ protected override async Task<OperationStatus> CheckTypeAsync(
{
var typeName = SyntaxFactory.ParseTypeName(typeParameter.Name);
var currentType = semanticModel.GetSpeculativeTypeInfo(contextNode.SpanStart, typeName, SpeculativeBindingOption.BindAsTypeOrNamespace).Type;
if (currentType == null || !currentType.Equals(typeParameter))
if (currentType == null || !AllNullabilityIgnoringSymbolComparer.Instance.Equals(currentType, typeParameter))
{
return new OperationStatus(OperationStatusFlag.BestEffort,
string.Format(FeaturesResources.Type_parameter_0_is_hidden_by_another_type_parameter_1,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

#nullable enable

using System.Linq;
using Microsoft.CodeAnalysis.CSharp.Extensions;
using Microsoft.CodeAnalysis.CSharp.Symbols;
Expand Down Expand Up @@ -52,7 +54,6 @@ public override ITypeSymbol GetContainingScopeType()
if (!node.IsExpression())
{
Contract.Fail("this shouldn't happen");
return null;
}

// special case for array initializer and explicit cast
Expand All @@ -61,7 +62,7 @@ public override ITypeSymbol GetContainingScopeType()
var variableDeclExpression = node.GetAncestorOrThis<VariableDeclarationSyntax>();
if (variableDeclExpression != null)
{
return model.GetTypeInfo(variableDeclExpression.Type).Type;
return model.GetTypeInfo(variableDeclExpression.Type).GetTypeWithAnnotatedNullability();
}
}

Expand All @@ -80,7 +81,7 @@ public override ITypeSymbol GetContainingScopeType()

if (node.Parent is CastExpressionSyntax castExpression)
{
return model.GetTypeInfo(castExpression.Type).Type;
return model.GetTypeInfo(castExpression.Type).GetTypeWithAnnotatedNullability();
}
}

Expand All @@ -98,35 +99,35 @@ private static ITypeSymbol GetRegularExpressionType(SemanticModel semanticModel,
if (info.ConvertedType == null || info.ConvertedType.IsErrorType())
{
// there is no implicit conversion involved. no need to go further
return info.Type;
return info.GetTypeWithAnnotatedNullability();
}

// always use converted type if method group
if ((!node.IsKind(SyntaxKind.ObjectCreationExpression) && semanticModel.GetMemberGroup(expression).Length > 0) ||
IsCoClassImplicitConversion(info, conv, semanticModel.Compilation.CoClassType()))
{
return info.ConvertedType;
return info.GetConvertedTypeWithAnnotatedNullability();
}

// check implicit conversion
if (conv.IsImplicit && (conv.IsConstantExpression || conv.IsEnumeration))
{
return info.ConvertedType;
return info.GetConvertedTypeWithAnnotatedNullability();
}

// use FormattableString if conversion between String and FormattableString
if (info.Type?.SpecialType == SpecialType.System_String &&
info.ConvertedType?.IsFormattableString() == true)
{
return info.ConvertedType;
return info.GetConvertedTypeWithAnnotatedNullability();
}

// always try to use type that is more specific than object type if possible.
return !info.Type.IsObjectType() ? info.Type : info.ConvertedType;
return !info.Type.IsObjectType() ? info.GetTypeWithAnnotatedNullability() : info.GetConvertedTypeWithAnnotatedNullability();
}
}

private static bool IsCoClassImplicitConversion(TypeInfo info, Conversion conversion, ISymbol coclassSymbol)
private static bool IsCoClassImplicitConversion(TypeInfo info, Conversion conversion, ISymbol? coclassSymbol)
{
if (!conversion.IsImplicit ||
info.ConvertedType == null ||
Expand Down
12 changes: 8 additions & 4 deletions src/Features/Core/Portable/ExtractMethod/Extensions.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
// Copyright (c) Microsoft. All Rights Reserved. Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

#nullable enable

using System.Collections.Generic;
using System.Linq;
using System.Threading;
Expand Down Expand Up @@ -54,7 +56,7 @@ public static OperationStatusFlag RemoveFlag(this OperationStatusFlag baseFlag,
return baseFlag & ~flagToRemove;
}

public static ITypeSymbol GetLambdaOrAnonymousMethodReturnType(this SemanticModel binding, SyntaxNode node)
public static ITypeSymbol? GetLambdaOrAnonymousMethodReturnType(this SemanticModel binding, SyntaxNode node)
{
var info = binding.GetSymbolInfo(node);
if (info.Symbol == null)
Expand All @@ -63,12 +65,12 @@ public static ITypeSymbol GetLambdaOrAnonymousMethodReturnType(this SemanticMode
}

var methodSymbol = info.Symbol as IMethodSymbol;
if (methodSymbol.MethodKind != MethodKind.AnonymousFunction)
if (methodSymbol?.MethodKind != MethodKind.AnonymousFunction)
{
return null;
}

return methodSymbol.ReturnType;
return methodSymbol.GetReturnTypeWithAnnotatedNullability();
}

public static Task<SemanticDocument> WithSyntaxRootAsync(this SemanticDocument semanticDocument, SyntaxNode root, CancellationToken cancellationToken)
Expand All @@ -89,7 +91,9 @@ public static SyntaxToken GetTokenWithAnnotation(this SemanticDocument document,
/// </summary>
public static T ResolveType<T>(this SemanticModel semanticModel, T symbol) where T : class, ITypeSymbol
{
return (T)symbol.GetSymbolKey().Resolve(semanticModel.Compilation).GetAnySymbol();
// Can be cleaned up when https://github.com/dotnet/roslyn/issues/38061 is resolved
var typeSymbol = (T)symbol.GetSymbolKey().Resolve(semanticModel.Compilation).GetAnySymbol();
return typeSymbol.WithNullability(symbol.GetNullability());
}

/// <summary>
Expand Down
Loading

0 comments on commit bee0632

Please sign in to comment.