Skip to content

Commit

Permalink
Add timestamp-based expiration to cached SafeFreeCredentials (#66334)
Browse files Browse the repository at this point in the history
* Add Expiry timestamp on SafeFreeCredentials handle

* Recalculate expiration timestamp based on CertificateContext

* Fix case when user provides CertificateContext
  • Loading branch information
rzikm committed Mar 24, 2022
1 parent 7d1191e commit e97af55
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -171,18 +171,22 @@ internal abstract class SafeFreeCredentials : SafeHandle
{
#endif

internal DateTime _expiry;
internal Interop.SspiCli.CredHandle _handle; //should be always used as by ref in PInvokes parameters

protected SafeFreeCredentials() : base(IntPtr.Zero, true)
{
_handle = default;
_expiry = DateTime.MaxValue;
}

public override bool IsInvalid
{
get { return IsClosed || _handle.IsZero; }
}

public DateTime Expiry => _expiry;

#if DEBUG
public new IntPtr DangerousGetHandle()
{
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.

using System;
using System.Diagnostics;
using System.Runtime.ConstrainedExecution;
using System.Runtime.InteropServices;
Expand All @@ -18,8 +19,13 @@ internal abstract class SafeFreeCredentials : DebugSafeHandle
internal abstract class SafeFreeCredentials : SafeHandle
{
#endif
internal DateTime _expiry;

public DateTime Expiry => _expiry;

protected SafeFreeCredentials(IntPtr handle, bool ownsHandle) : base(handle, ownsHandle)
{
_expiry = DateTime.MaxValue;
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ internal void Close()
{
if (!_remoteCertificateExposed)
{
_remoteCertificate?.Dispose();
_remoteCertificate = null;
_remoteCertificate?.Dispose();
_remoteCertificate = null;
}

_securityContext?.Dispose();
Expand Down Expand Up @@ -607,9 +607,7 @@ private bool AcquireClientCredentials(ref byte[]? thumbPrint)
_sslAuthenticationOptions.CertificateContext = SslStreamCertificateContext.Create(selectedCert!);
}

_credentialsHandle = SslStreamPal.AcquireCredentialsHandle(_sslAuthenticationOptions.CertificateContext,
_sslAuthenticationOptions.EnabledSslProtocols, _sslAuthenticationOptions.EncryptionPolicy, _sslAuthenticationOptions.IsServer);

_credentialsHandle = AcquireCredentialsHandle(_sslAuthenticationOptions);
thumbPrint = guessedThumbPrint; // Delay until here in case something above threw.
}
}
Expand Down Expand Up @@ -713,14 +711,65 @@ private bool AcquireServerCredentials(ref byte[]? thumbPrint)
}
else
{
_credentialsHandle = SslStreamPal.AcquireCredentialsHandle(_sslAuthenticationOptions.CertificateContext, _sslAuthenticationOptions.EnabledSslProtocols,
_sslAuthenticationOptions.EncryptionPolicy, _sslAuthenticationOptions.IsServer);
_credentialsHandle = AcquireCredentialsHandle(_sslAuthenticationOptions);
thumbPrint = guessedThumbPrint;
}

return cachedCred;
}

private static SafeFreeCredentials AcquireCredentialsHandle(SslAuthenticationOptions sslAuthenticationOptions)
{
SafeFreeCredentials cred = SslStreamPal.AcquireCredentialsHandle(sslAuthenticationOptions.CertificateContext, sslAuthenticationOptions.EnabledSslProtocols,
sslAuthenticationOptions.EncryptionPolicy, sslAuthenticationOptions.IsServer);

if (sslAuthenticationOptions.CertificateContext != null)
{
//
// Since the SafeFreeCredentials can be cached and reused, it may happen on long running processes that some cert on
// the chain expires and all subsequent connections would send expired intermediate certificates. Find the earliest
// NotAfter timestamp on the chain and use it as expiration timestamp for the credentials.
// This provides an opportunity to recreate the credentials with an alternative (and still valid)
// certificate chain.
//
SslStreamCertificateContext certificateContext = sslAuthenticationOptions.CertificateContext;
cred._expiry = GetExpiryTimestamp(certificateContext);

if (cred._expiry < DateTime.UtcNow)
{
//
// The CertificateContext from auth options is recreated just before creating the SafeFreeCredentials. However, in case when
// it was provided by the user code, it may still contain the (now expired) certificate chain. Such expiration timestamp would
// effectively disable caching as it would lead to creating new credentials for each connection. We attempt to recover by creating
// a temporary certificate context (which builds a new chain with hopefully more recent chain).
//
certificateContext = SslStreamCertificateContext.Create(
certificateContext.Certificate,
new X509Certificate2Collection(certificateContext.IntermediateCertificates),
trust: certificateContext.Trust);

cred._expiry = GetExpiryTimestamp(certificateContext);
}

static DateTime GetExpiryTimestamp(SslStreamCertificateContext certificateContext)
{
DateTime expiry = certificateContext.Certificate.NotAfter;

foreach (X509Certificate2 cert in certificateContext.IntermediateCertificates)
{
if (cert.NotAfter < expiry)
{
expiry = cert.NotAfter;
}
}

return expiry.ToUniversalTime();
}
}

return cred;
}

//
internal ProtocolToken NextMessage(ReadOnlySpan<byte> incomingBuffer)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,9 +134,9 @@ public bool Equals(SslCredKey other)

//SafeCredentialReference? cached;
SafeFreeCredentials? credentials = GetCachedCredential(key);
if (credentials == null || credentials.IsClosed || credentials.IsInvalid)
if (credentials == null || credentials.IsClosed || credentials.IsInvalid || credentials.Expiry < DateTime.UtcNow)
{
if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(null, $"Not found or invalid, Current Cache Coun = {s_cachedCreds.Count}");
if (NetEventSource.Log.IsEnabled()) NetEventSource.Info(null, $"Not found or invalid, Current Cache Count = {s_cachedCreds.Count}");
return null;
}

Expand Down Expand Up @@ -169,12 +169,13 @@ internal static void CacheCredential(SafeFreeCredentials creds, byte[]? thumbPri

SafeFreeCredentials? credentials = GetCachedCredential(key);

if (credentials == null || credentials.IsClosed || credentials.IsInvalid)
DateTime utcNow = DateTime.UtcNow;
if (credentials == null || credentials.IsClosed || credentials.IsInvalid || credentials.Expiry < utcNow)
{
lock (s_cachedCreds)
{
credentials = GetCachedCredential(key);
if (credentials == null || credentials.IsClosed || credentials.IsInvalid)
if (credentials == null || credentials.IsClosed || credentials.IsInvalid || credentials.Expiry < utcNow)
{
SafeCredentialReference? cached = SafeCredentialReference.CreateReference(creds);

Expand Down

0 comments on commit e97af55

Please sign in to comment.