Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PoC TLS resume on Linux client #64369

Merged
merged 17 commits into from
Mar 25, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Collections.Concurrent;
using System.Collections.Generic;
using System.Diagnostics;
using System.Globalization;
Expand All @@ -20,7 +21,9 @@ internal static partial class OpenSsl
{
private const string DisableTlsResumeCtxSwitch = "System.Net.Security.DisableTlsResume";
private const string DisableTlsResumeEnvironmentVariable = "DOTNET_SYSTEM_NET_SECURITY_DISABLETLSRESUME";
private const SslProtocols FakeAlpnSslProtocol = (SslProtocols)1; // used to distinguish server sessions with ALPN
private static readonly IdnMapping s_idnMapping = new IdnMapping();
private static readonly ConcurrentDictionary<SslProtocols, SafeSslContextHandle> s_clientSslContexts = new ConcurrentDictionary<SslProtocols, SafeSslContextHandle>();

#region internal methods
internal static SafeChannelBindingHandle? QueryChannelBinding(SafeSslHandle context, ChannelBindingKind bindingType)
Expand Down Expand Up @@ -58,6 +61,14 @@ private static bool DisableTlsResume
return disableTlsResume != 0;
}

// Resume does not work properly on older OpenSSL versions.
// This may be revisited but for now enable it only on 1.1.1 and above.
if (Interop.OpenSsl.OpenSslVersionNumber() < 0x10101000)
{
s_disableTlsResume = 1;
return true;
}

// First check for the AppContext switch, giving it priority over the environment variable.
if (AppContext.TryGetSwitch(DisableTlsResumeCtxSwitch, out bool value))
{
Expand All @@ -79,7 +90,7 @@ private static bool DisableTlsResume
private static SslProtocols CalculateEffectiveProtocols(SslAuthenticationOptions sslAuthenticationOptions)
{
// make sure low bit is not set since we use it in context dictionary to distinguish use with ALPN
Debug.Assert(((int)sslAuthenticationOptions.EnabledSslProtocols & 1) == 0);
Debug.Assert((sslAuthenticationOptions.EnabledSslProtocols & FakeAlpnSslProtocol) == 0);
SslProtocols protocols = sslAuthenticationOptions.EnabledSslProtocols & ~((SslProtocols)1);

if (!Interop.Ssl.Capabilities.Tls13Supported)
Expand Down Expand Up @@ -124,7 +135,7 @@ private static SslProtocols CalculateEffectiveProtocols(SslAuthenticationOptions
}

// This essentially wraps SSL_CTX* aka SSL_CTX_new + setting
internal static SafeSslContextHandle AllocateSslContext(SafeFreeSslCredentials credential, SslAuthenticationOptions sslAuthenticationOptions, SslProtocols protocols, bool enableResume)
internal static unsafe SafeSslContextHandle AllocateSslContext(SafeFreeSslCredentials credential, SslAuthenticationOptions sslAuthenticationOptions, SslProtocols protocols, bool enableResume)
{
SafeX509Handle? certHandle = credential.CertHandle;
SafeEvpPKeyHandle? certKeyHandle = credential.CertKeyHandle;
Expand Down Expand Up @@ -161,16 +172,13 @@ internal static SafeSslContextHandle AllocateSslContext(SafeFreeSslCredentials c

Debug.Assert(cipherSuites == null || (cipherSuites.Length >= 1 && cipherSuites[cipherSuites.Length - 1] == 0));

unsafe
fixed (byte* cipherListStr = cipherList)
fixed (byte* cipherSuitesStr = cipherSuites)
{
fixed (byte* cipherListStr = cipherList)
fixed (byte* cipherSuitesStr = cipherSuites)
if (!Ssl.SslCtxSetCiphers(sslCtx, cipherListStr, cipherSuitesStr))
{
if (!Ssl.SslCtxSetCiphers(sslCtx, cipherListStr, cipherSuitesStr))
{
Crypto.ErrClearError();
throw new PlatformNotSupportedException(SR.Format(SR.net_ssl_encryptionpolicy_notsupported, sslAuthenticationOptions.EncryptionPolicy));
}
Crypto.ErrClearError();
throw new PlatformNotSupportedException(SR.Format(SR.net_ssl_encryptionpolicy_notsupported, sslAuthenticationOptions.EncryptionPolicy));
}
}

Expand All @@ -183,14 +191,27 @@ internal static SafeSslContextHandle AllocateSslContext(SafeFreeSslCredentials c
// https://www.openssl.org/docs/manmaster/ssl/SSL_shutdown.html
Ssl.SslCtxSetQuietShutdown(sslCtx);

Ssl.SslCtxSetCaching(sslCtx, enableResume ? 1 : 0);

if (sslAuthenticationOptions.IsServer && sslAuthenticationOptions.ApplicationProtocols != null && sslAuthenticationOptions.ApplicationProtocols.Count != 0)
if (enableResume)
{
unsafe
if (sslAuthenticationOptions.IsServer)
{
Interop.Ssl.SslCtxSetAlpnSelectCb(sslCtx, &AlpnServerSelectCallback, IntPtr.Zero);
Ssl.SslCtxSetCaching(sslCtx, 1, null, null);
}
else
{
int result = Ssl.SslCtxSetCaching(sslCtx, 1, &NewSessionCallback, &RemoveSessionCallback);
Debug.Assert(result == 1);
sslCtx.EnableSessionCache();
}
}
else
{
Ssl.SslCtxSetCaching(sslCtx, 0, null, null);
}

if (sslAuthenticationOptions.IsServer && sslAuthenticationOptions.ApplicationProtocols != null && sslAuthenticationOptions.ApplicationProtocols.Count != 0)
{
Interop.Ssl.SslCtxSetAlpnSelectCb(sslCtx, &AlpnServerSelectCallback, IntPtr.Zero);
}

bool hasCertificateAndKey =
Expand Down Expand Up @@ -266,25 +287,61 @@ internal static SafeSslHandle AllocateSslHandle(SafeFreeSslCredentials credentia
SafeSslContextHandle? newCtxHandle = null;
SslProtocols protocols = CalculateEffectiveProtocols(sslAuthenticationOptions);
bool hasAlpn = sslAuthenticationOptions.ApplicationProtocols != null && sslAuthenticationOptions.ApplicationProtocols.Count != 0;
bool cacheSslContext = !DisableTlsResume && sslAuthenticationOptions.EncryptionPolicy == EncryptionPolicy.RequireEncryption &&
sslAuthenticationOptions.IsServer &&
sslAuthenticationOptions.CertificateContext != null &&
sslAuthenticationOptions.CertificateContext.SslContexts != null &&
sslAuthenticationOptions.CipherSuitesPolicy == null;
bool cacheSslContext = !DisableTlsResume && sslAuthenticationOptions.EncryptionPolicy == EncryptionPolicy.RequireEncryption && sslAuthenticationOptions.CipherSuitesPolicy == null;

if (cacheSslContext)
{
if (sslAuthenticationOptions.IsClient)
{
// we don't want to try on emtpy TargetName since that is our key.
// And we don't want to mess up with client authentication. It may be possible
// but it seems safe to get full new session.
if (string.IsNullOrEmpty(sslAuthenticationOptions.TargetHost) ||
sslAuthenticationOptions.CertificateContext != null ||
sslAuthenticationOptions.CertSelectionDelegate != null)
{
cacheSslContext = false;
}
}
else
{
// Server should always have certificate
Debug.Assert(sslAuthenticationOptions.CertificateContext != null);
if (sslAuthenticationOptions.CertificateContext == null ||
sslAuthenticationOptions.CertificateContext.SslContexts == null)
{
cacheSslContext = false;
}
}
}

if (cacheSslContext)
{
sslAuthenticationOptions.CertificateContext!.SslContexts!.TryGetValue(protocols | (SslProtocols)(hasAlpn ? 1 : 0), out sslCtxHandle);
if (sslAuthenticationOptions.IsServer)
{
sslAuthenticationOptions.CertificateContext!.SslContexts!.TryGetValue(protocols | (hasAlpn ? FakeAlpnSslProtocol : SslProtocols.None), out sslCtxHandle);
}
else
{

wfurt marked this conversation as resolved.
Show resolved Hide resolved
s_clientSslContexts.TryGetValue(protocols, out sslCtxHandle);
}
}

if (sslCtxHandle == null)
{
// We did not get SslContext from cache
sslCtxHandle = newCtxHandle = AllocateSslContext(credential, sslAuthenticationOptions, protocols, cacheSslContext);

if (cacheSslContext && sslAuthenticationOptions.CertificateContext!.SslContexts!.TryAdd(protocols | (SslProtocols)(hasAlpn ? 1 : 0), newCtxHandle))
if (cacheSslContext)
{
newCtxHandle = null;
bool added = sslAuthenticationOptions.IsServer ?
sslAuthenticationOptions.CertificateContext!.SslContexts!.TryAdd(protocols | (SslProtocols)(hasAlpn ? 1 : 0), newCtxHandle) :
s_clientSslContexts.TryAdd(protocols, newCtxHandle);
if (added)
{
newCtxHandle = null;
}
}
}

Expand All @@ -303,6 +360,7 @@ internal static SafeSslHandle AllocateSslHandle(SafeFreeSslCredentials credentia
{
if (sslAuthenticationOptions.IsServer)
{
Debug.Assert(Interop.Ssl.SslGetData(sslHandle) == IntPtr.Zero);
alpnHandle = GCHandle.Alloc(sslAuthenticationOptions.ApplicationProtocols);
Interop.Ssl.SslSetData(sslHandle, GCHandle.ToIntPtr(alpnHandle));
sslHandle.AlpnHandle = alpnHandle;
Expand All @@ -316,7 +374,7 @@ internal static SafeSslHandle AllocateSslHandle(SafeFreeSslCredentials credentia
}
}

if (!sslAuthenticationOptions.IsServer)
if (sslAuthenticationOptions.IsClient)
{
// The IdnMapping converts unicode input into the IDNA punycode sequence.
string punyCode = string.IsNullOrEmpty(sslAuthenticationOptions.TargetHost) ? string.Empty : s_idnMapping.GetAscii(sslAuthenticationOptions.TargetHost!);
Expand All @@ -327,6 +385,11 @@ internal static SafeSslHandle AllocateSslHandle(SafeFreeSslCredentials credentia
Crypto.ErrClearError();
}

if (cacheSslContext && !string.IsNullOrEmpty(punyCode))
{
sslCtxHandle.TrySetSession(sslHandle, punyCode);
}

// Set client cert callback, this will interrupt the handshake with SecurityStatusPalErrorCode.CredentialsNeeded
// if server actually requests a certificate.
Ssl.SslSetClientCertCallback(sslHandle, 1);
Expand Down Expand Up @@ -624,6 +687,59 @@ private static unsafe int AlpnServerSelectCallback(IntPtr ssl, byte** outp, byte
return Ssl.SSL_TLSEXT_ERR_ALERT_FATAL;
}

[UnmanagedCallersOnly]
// Invoked from OpenSSL when new session is created.
// We attached GCHandle to the SSL so we can find back SafeSslContextHandle holding the cache.
// New session ahs refCount of 1.
// If this function return 0, OpenSSL will drop the refCount and discard the session.
wfurt marked this conversation as resolved.
Show resolved Hide resolved
// If we return 1, the ownership is transfered to us and we will need to call SessionFree().
private static unsafe int NewSessionCallback(IntPtr ssl, IntPtr session)
{
Debug.Assert(ssl != IntPtr.Zero);
Debug.Assert(session != IntPtr.Zero);

IntPtr ptr = Ssl.SslGetData(ssl);
Debug.Assert(ptr != IntPtr.Zero);
GCHandle gch = GCHandle.FromIntPtr(ptr);

SafeSslContextHandle? ctxHandle = gch.Target as SafeSslContextHandle;
// There is no relation between SafeSslContextHandle and SafeSslHandle so the handle
// may be released while the ssl session is still active.
if (ctxHandle != null && ctxHandle.TryAddSession(Ssl.SslGetServerName(ssl), session))
{
// offered session was stored in our cache.
return 1;
}

// OpenSSL will destroy session.
return 0;
Comment on lines +725 to +729
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Are these interesting from a logging perspective (if so, that can easily be in a future change)?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was thinking about logging but there are conditions that make it normal. If we hit this, there should be no functional change as we simply won't do the caching & resume.

}

[UnmanagedCallersOnly]
private static unsafe void RemoveSessionCallback(IntPtr ctx, IntPtr session)
{
Debug.Assert(ctx != IntPtr.Zero && session != IntPtr.Zero);

IntPtr ptr = Ssl.SslCtxGetData(ctx);
if (ptr == IntPtr.Zero)
{
// Same as above, SafeSslContextHandle could be released while OpenSSL still holds refferecne.
return;
}

GCHandle gch = GCHandle.FromIntPtr(ptr);
SafeSslContextHandle? ctxHandle = gch.Target as SafeSslContextHandle;
if (ctxHandle == null)
{
return;
}

//string? name = Marshal.PtrToStringAnsi(Ssl.SessionGetHostname(session));a
wfurt marked this conversation as resolved.
Show resolved Hide resolved
IntPtr name = Ssl.SessionGetHostname(session);
Debug.Assert(name != IntPtr.Zero);
ctxHandle.RemoveSession(name, session);
}

private static int BioRead(SafeBioHandle bio, byte[] buffer, int count)
{
Debug.Assert(buffer != null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,12 @@ internal static partial class Ssl
[return: MarshalAs(UnmanagedType.Bool)]
internal static partial bool SslSetTlsExtHostName(SafeSslHandle ssl, string host);

[GeneratedDllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetServerName")]
internal static unsafe partial IntPtr SslGetServerName(IntPtr ssl);

[GeneratedDllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslSetSession")]
internal static unsafe partial int SslSetSession(SafeSslHandle ssl, IntPtr session);

[GeneratedDllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGet0AlpnSelected")]
internal static partial void SslGetAlpnSelected(SafeSslHandle ssl, out IntPtr protocol, out int len);

Expand Down Expand Up @@ -145,6 +151,9 @@ internal static partial class Ssl
[GeneratedDllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetData")]
internal static partial IntPtr SslGetData(IntPtr ssl);

[GeneratedDllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslGetData")]
internal static partial IntPtr SslGetData(SafeSslHandle ssl);

[GeneratedDllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslSetData")]
internal static partial int SslSetData(SafeSslHandle ssl, IntPtr data);

Expand All @@ -163,6 +172,18 @@ internal static partial class Ssl
[GeneratedDllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_Tls13Supported")]
private static partial int Tls13SupportedImpl();

[GeneratedDllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslSessionGetHostname")]
internal static partial IntPtr SessionGetHostname(IntPtr session);

[GeneratedDllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslSessionFree")]
internal static partial void SessionFree(IntPtr session);

[GeneratedDllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslSessionSetHostname", CharSet = CharSet.Ansi)]
internal static partial int SessionSetHostname(IntPtr session, string name);

[GeneratedDllImport(Libraries.CryptoNative, EntryPoint = "CryptoNative_SslSessionSetHostname")]
internal static partial int SessionSetHostname(IntPtr session, IntPtr name);

internal static class Capabilities
{
// needs separate type (separate static cctor) to be sure OpenSSL is initialized.
Expand Down
Loading