Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Annotate System.Net.Security for nullable reference types #32541

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.


#nullable enable
using System;
using System.Collections.Generic;
using System.Diagnostics;
Expand Down Expand Up @@ -301,8 +301,8 @@ internal static void SslSetTargetName(SafeSslHandle sslHandle, string targetName

internal static unsafe void SslCtxSetAlpnProtos(SafeSslHandle ctx, List<SslApplicationProtocol> protocols)
{
SafeCreateHandle cfProtocolsRefs = null;
SafeCreateHandle[] cfProtocolsArrayRef = null;
SafeCreateHandle? cfProtocolsRefs = null;
SafeCreateHandle[]? cfProtocolsArrayRef = null;
try
{
if (protocols.Count == 1 && protocols[0] == SslApplicationProtocol.Http2)
Expand Down Expand Up @@ -353,7 +353,7 @@ internal static unsafe void SslCtxSetAlpnProtos(SafeSslHandle ctx, List<SslAppli
}
}

internal static byte[] SslGetAlpnSelected(SafeSslHandle ssl)
internal static byte[]? SslGetAlpnSelected(SafeSslHandle ssl)
{
SafeCFDataHandle protocol;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

#nullable enable
using System;

internal static partial class Interop
Expand All @@ -14,7 +15,7 @@ internal SslException()
{
}

internal SslException(int errorCode, string message)
internal SslException(int errorCode, string? message)
: base(message)
{
HResult = errorCode;
Expand All @@ -26,7 +27,7 @@ internal static partial class AppleCrypto
{
internal static Exception CreateExceptionForOSStatus(int osStatus)
{
string msg = GetSecErrorString(osStatus);
string? msg = GetSecErrorString(osStatus);

// msg might be null, but that's OK
return new SslException(osStatus, msg);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

#nullable enable
using System;
using System.Collections.Generic;
using System.Diagnostics;
Expand Down Expand Up @@ -43,14 +44,14 @@ public GssApiException(Status majorStatus, Status minorStatus, string helpText)
_minorStatus = minorStatus;
}

private static string GetGssApiDisplayStatus(Status majorStatus, Status minorStatus, string helpText)
private static string GetGssApiDisplayStatus(Status majorStatus, Status minorStatus, string? helpText)
{
string majorError = GetGssApiDisplayStatus(majorStatus, isMinor: false);
string? majorError = GetGssApiDisplayStatus(majorStatus, isMinor: false);
string errorMessage;

if (minorStatus != Status.GSS_S_COMPLETE)
{
string minorError = GetGssApiDisplayStatus(minorStatus, isMinor: true);
string? minorError = GetGssApiDisplayStatus(minorStatus, isMinor: true);
errorMessage = (majorError != null && minorError != null) ?
SR.Format(SR.net_gssapi_operation_failed_detailed, majorError, minorError) :
SR.Format(SR.net_gssapi_operation_failed, majorStatus.ToString("x"), minorStatus.ToString("x"));
Expand All @@ -70,7 +71,7 @@ private static string GetGssApiDisplayStatus(Status majorStatus, Status minorSta
return errorMessage;
}

private static string GetGssApiDisplayStatus(Status status, bool isMinor)
private static string? GetGssApiDisplayStatus(Status status, bool isMinor)
{
GssBuffer displayBuffer = default(GssBuffer);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

#nullable enable
using System;
using System.Diagnostics;
using System.Runtime.InteropServices;
Expand Down Expand Up @@ -79,9 +80,9 @@ internal static extern Status InitSecContext(
SafeGssCredHandle initiatorCredHandle,
ref SafeGssContextHandle contextHandle,
bool isNtlmOnly,
SafeGssNameHandle targetName,
SafeGssNameHandle? targetName,
uint reqFlags,
byte[] inputBytes,
byte[]? inputBytes,
int inputLength,
ref GssBuffer token,
out uint retFlags,
Expand All @@ -95,9 +96,9 @@ internal static extern Status InitSecContext(
bool isNtlmOnly,
IntPtr cbt,
int cbtSize,
SafeGssNameHandle targetName,
SafeGssNameHandle? targetName,
uint reqFlags,
byte[] inputBytes,
byte[]? inputBytes,
int inputLength,
ref GssBuffer token,
out uint retFlags,
Expand All @@ -108,7 +109,7 @@ internal static extern Status AcceptSecContext(
out Status minorStatus,
SafeGssCredHandle acceptorCredHandle,
ref SafeGssContextHandle acceptContextHandle,
byte[] inputBytes,
byte[]? inputBytes,
int inputLength,
ref GssBuffer token,
out uint retFlags,
Expand All @@ -122,13 +123,13 @@ internal static extern Status DeleteSecContext(
[DllImport(Interop.Libraries.NetSecurityNative, EntryPoint="NetSecurityNative_GetUser")]
internal static extern Status GetUser(
out Status minorStatus,
SafeGssContextHandle acceptContextHandle,
SafeGssContextHandle? acceptContextHandle,
ref GssBuffer token);

[DllImport(Interop.Libraries.NetSecurityNative, EntryPoint="NetSecurityNative_Wrap")]
private static extern Status Wrap(
out Status minorStatus,
SafeGssContextHandle contextHandle,
SafeGssContextHandle? contextHandle,
bool isEncrypt,
byte[] inputBytes,
int offset,
Expand All @@ -138,15 +139,15 @@ private static extern Status Wrap(
[DllImport(Interop.Libraries.NetSecurityNative, EntryPoint="NetSecurityNative_Unwrap")]
private static extern Status Unwrap(
out Status minorStatus,
SafeGssContextHandle contextHandle,
SafeGssContextHandle? contextHandle,
byte[] inputBytes,
int offset,
int count,
ref GssBuffer outBuffer);

internal static Status WrapBuffer(
out Status minorStatus,
SafeGssContextHandle contextHandle,
SafeGssContextHandle? contextHandle,
bool isEncrypt,
byte[] inputBytes,
int offset,
Expand All @@ -162,7 +163,7 @@ internal static Status WrapBuffer(

internal static Status UnwrapBuffer(
out Status minorStatus,
SafeGssContextHandle contextHandle,
SafeGssContextHandle? contextHandle,
byte[] inputBytes,
int offset,
int count,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.

#nullable enable
using System;
using System.Buffers;
using System.Collections.Generic;
Expand All @@ -27,13 +28,13 @@ internal static partial class OpenSsl
private static readonly IdnMapping s_idnMapping = new IdnMapping();

#region internal methods
internal static SafeChannelBindingHandle QueryChannelBinding(SafeSslHandle context, ChannelBindingKind bindingType)
internal static SafeChannelBindingHandle? QueryChannelBinding(SafeSslHandle context, ChannelBindingKind bindingType)
{
Debug.Assert(
bindingType != ChannelBindingKind.Endpoint,
"Endpoint binding should be handled by EndpointChannelBindingToken");

SafeChannelBindingHandle bindingHandle;
SafeChannelBindingHandle? bindingHandle;
switch (bindingType)
{
case ChannelBindingKind.Unique:
Expand All @@ -50,9 +51,9 @@ internal static SafeChannelBindingHandle QueryChannelBinding(SafeSslHandle conte
return bindingHandle;
}

internal static SafeSslHandle AllocateSslContext(SslProtocols protocols, SafeX509Handle certHandle, SafeEvpPKeyHandle certKeyHandle, EncryptionPolicy policy, SslAuthenticationOptions sslAuthenticationOptions)
internal static SafeSslHandle AllocateSslContext(SslProtocols protocols, SafeX509Handle? certHandle, SafeEvpPKeyHandle? certKeyHandle, EncryptionPolicy policy, SslAuthenticationOptions sslAuthenticationOptions)
{
SafeSslHandle context = null;
SafeSslHandle? context = null;

// Always use SSLv23_method, regardless of protocols. It supports negotiating to the highest
// mutually supported version and can thus handle any of the set protocols, and we then use
Expand Down Expand Up @@ -122,12 +123,12 @@ internal static SafeSslHandle AllocateSslContext(SslProtocols protocols, SafeX50
// https://www.openssl.org/docs/manmaster/ssl/SSL_shutdown.html
Ssl.SslCtxSetQuietShutdown(innerContext);

byte[] cipherList =
byte[]? cipherList =
CipherSuitesPolicyPal.GetOpenSslCipherList(sslAuthenticationOptions.CipherSuitesPolicy, protocols, policy);

Debug.Assert(cipherList == null || (cipherList.Length >= 1 && cipherList[cipherList.Length - 1] == 0));

byte[] cipherSuites =
byte[]? cipherSuites =
CipherSuitesPolicyPal.GetOpenSslCipherSuites(sslAuthenticationOptions.CipherSuitesPolicy, protocols, policy);

Debug.Assert(cipherSuites == null || (cipherSuites.Length >= 1 && cipherSuites[cipherSuites.Length - 1] == 0));
Expand All @@ -151,7 +152,7 @@ internal static SafeSslHandle AllocateSslContext(SslProtocols protocols, SafeX50

if (hasCertificateAndKey)
{
SetSslCertificate(innerContext, certHandle, certKeyHandle);
SetSslCertificate(innerContext, certHandle!, certKeyHandle!);
eiriktsarpalis marked this conversation as resolved.
Show resolved Hide resolved
}

if (sslAuthenticationOptions.IsServer && sslAuthenticationOptions.RemoteCertRequired)
Expand Down Expand Up @@ -189,7 +190,7 @@ internal static SafeSslHandle AllocateSslContext(SslProtocols protocols, SafeX50
if (!sslAuthenticationOptions.IsServer)
{
// The IdnMapping converts unicode input into the IDNA punycode sequence.
string punyCode = s_idnMapping.GetAscii(sslAuthenticationOptions.TargetHost);
string punyCode = s_idnMapping.GetAscii(sslAuthenticationOptions.TargetHost!);

// Similar to windows behavior, set SNI on openssl by default for client context, ignore errors.
if (!Ssl.SslSetTlsExtHostName(context, punyCode))
Expand All @@ -203,10 +204,10 @@ internal static SafeSslHandle AllocateSslContext(SslProtocols protocols, SafeX50
bool hasCertReference = false;
try
{
certHandle.DangerousAddRef(ref hasCertReference);
certHandle!.DangerousAddRef(ref hasCertReference);
using (X509Certificate2 cert = new X509Certificate2(certHandle.DangerousGetHandle()))
{
X509Chain chain = null;
X509Chain? chain = null;
try
{
chain = TLSCertificateExtensions.BuildNewChain(cert, includeClientApplicationPolicy: false);
Expand All @@ -222,7 +223,7 @@ internal static SafeSslHandle AllocateSslContext(SslProtocols protocols, SafeX50
int elementsCount = chain.ChainElements.Count;
for (int i = 0; i < elementsCount; i++)
{
chain.ChainElements[i].Certificate.Dispose();
chain.ChainElements[i].Certificate!.Dispose();
}

chain.Dispose();
Expand All @@ -233,7 +234,7 @@ internal static SafeSslHandle AllocateSslContext(SslProtocols protocols, SafeX50
finally
{
if (hasCertReference)
certHandle.DangerousRelease();
certHandle!.DangerousRelease();
}
}

Expand All @@ -253,15 +254,15 @@ internal static SafeSslHandle AllocateSslContext(SslProtocols protocols, SafeX50
return context;
}

internal static bool DoSslHandshake(SafeSslHandle context, ReadOnlySpan<byte> input, out byte[] sendBuf, out int sendCount)
internal static bool DoSslHandshake(SafeSslHandle context, ReadOnlySpan<byte> input, out byte[]? sendBuf, out int sendCount)
{
sendBuf = null;
sendCount = 0;
Exception handshakeException = null;
Exception? handshakeException = null;

if (input.Length > 0)
{
if (Ssl.BioWrite(context.InputBio, ref MemoryMarshal.GetReference(input), input.Length) != input.Length)
if (Ssl.BioWrite(context.InputBio!, ref MemoryMarshal.GetReference(input), input.Length) != input.Length)
{
// Make sure we clear out the error that is stored in the queue
throw Crypto.CreateOpenSslCryptographicException();
Expand All @@ -271,7 +272,7 @@ internal static bool DoSslHandshake(SafeSslHandle context, ReadOnlySpan<byte> in
int retVal = Ssl.SslDoHandshake(context);
if (retVal != 1)
{
Exception innerError;
Exception? innerError;
Ssl.SslErrorCode error = GetSslError(context, retVal, out innerError);

if ((retVal != -1) || (error != Ssl.SslErrorCode.SSL_ERROR_WANT_READ))
Expand All @@ -283,14 +284,14 @@ internal static bool DoSslHandshake(SafeSslHandle context, ReadOnlySpan<byte> in
}
}

sendCount = Crypto.BioCtrlPending(context.OutputBio);
sendCount = Crypto.BioCtrlPending(context.OutputBio!);
if (sendCount > 0)
{
sendBuf = new byte[sendCount];

try
{
sendCount = BioRead(context.OutputBio, sendBuf, sendCount);
sendCount = BioRead(context.OutputBio!, sendBuf, sendCount);
}
catch (Exception) when (handshakeException != null)
{
Expand Down Expand Up @@ -330,7 +331,7 @@ internal static int Encrypt(SafeSslHandle context, ReadOnlySpan<byte> input, ref
errorCode = Ssl.SslErrorCode.SSL_ERROR_NONE;

int retVal;
Exception innerError = null;
Exception? innerError = null;

lock (context)
{
Expand Down Expand Up @@ -359,14 +360,14 @@ internal static int Encrypt(SafeSslHandle context, ReadOnlySpan<byte> input, ref
}
else
{
int capacityNeeded = Crypto.BioCtrlPending(context.OutputBio);
int capacityNeeded = Crypto.BioCtrlPending(context.OutputBio!);

if (output == null || output.Length < capacityNeeded)
{
output = new byte[capacityNeeded];
}

retVal = BioRead(context.OutputBio, output, capacityNeeded);
retVal = BioRead(context.OutputBio!, output, capacityNeeded);

if (retVal <= 0)
{
Expand All @@ -386,8 +387,8 @@ internal static int Decrypt(SafeSslHandle context, byte[] outBuffer, int offset,
#endif
errorCode = Ssl.SslErrorCode.SSL_ERROR_NONE;

int retVal = BioWrite(context.InputBio, outBuffer, offset, count);
Exception innerError = null;
int retVal = BioWrite(context.InputBio!, outBuffer, offset, count);
Exception? innerError = null;

lock (context)
{
Expand Down Expand Up @@ -561,7 +562,7 @@ private static int BioWrite(SafeBioHandle bio, byte[] buffer, int offset, int co
return bytes;
}

private static Ssl.SslErrorCode GetSslError(SafeSslHandle context, int result, out Exception innerError)
private static Ssl.SslErrorCode GetSslError(SafeSslHandle context, int result, out Exception? innerError)
{
ErrorInfo lastErrno = Sys.GetLastErrorInfo(); // cache it before we make more P/Invoke calls, just in case we need it

Expand Down Expand Up @@ -633,17 +634,17 @@ internal static SslException CreateSslException(string message)

internal sealed class SslException : Exception
{
public SslException(string inputMessage)
public SslException(string? inputMessage)
: base(inputMessage)
{
}

public SslException(string inputMessage, Exception ex)
public SslException(string? inputMessage, Exception? ex)
: base(inputMessage, ex)
{
}

public SslException(string inputMessage, int error)
public SslException(string? inputMessage, int error)
: this(inputMessage)
{
HResult = error;
Expand Down
Loading