Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix disposing root X.509 certificate prematurely for OCSP stapling #82116

Merged
merged 5 commits into from
Feb 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -854,8 +854,8 @@ internal static void BuildPrivatePki(
rootAuthority = new CertificateAuthority(
rootCert,
rootDistributionViaHttp ? certUrl : null,
issuerRevocationViaCrl ? cdpUrl : null,
issuerRevocationViaOcsp ? ocspUrl : null);
issuerRevocationViaCrl || (endEntityRevocationViaCrl && intermediateAuthorityCount == 0) ? cdpUrl : null,
vcsjones marked this conversation as resolved.
Show resolved Hide resolved
issuerRevocationViaOcsp || (endEntityRevocationViaOcsp && intermediateAuthorityCount == 0) ? ocspUrl : null);

CertificateAuthority issuingAuthority = rootAuthority;
intermediateAuthorities = new CertificateAuthority[intermediateAuthorityCount];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,12 @@ partial void SetNoOcspFetch(bool noOcspFetch)
_staplingForbidden = noOcspFetch;
}

partial void AddRootCertificate(X509Certificate2? rootCertificate)
partial void AddRootCertificate(X509Certificate2? rootCertificate, ref bool transferredOwnership)
{
if (IntermediateCertificates.Length == 0)
{
_ca = rootCertificate;
transferredOwnership = true;
}
else
{
Expand Down Expand Up @@ -197,6 +198,17 @@ partial void AddRootCertificate(X509Certificate2? rootCertificate)

IntPtr subject = Certificate.Handle;
IntPtr issuer = caCert.Handle;
Debug.Assert(subject != 0);
Debug.Assert(issuer != 0);

// This should not happen - but in the event that it does, we can't give null pointers when building the
// request, so skip stapling, and set it as forbidden so we don't bother looking for new stapled responses
// in the future.
if (subject == 0 || issuer == 0)
{
_staplingForbidden = true;
return null;
}

using (SafeOcspRequestHandle ocspRequest = Interop.Crypto.X509BuildOcspRequest(subject, issuer))
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,8 +96,13 @@ internal static SslStreamCertificateContext Create(
// Dispose the copy of the target cert.
chain.ChainElements[0].Certificate.Dispose();

// Dispose the last cert, if we didn't include it.
for (int i = count + 1; i < chain.ChainElements.Count; i++)
// Dispose of the certificates that we do not need. If we are holding on to the root,
// don't dispose of it.
int stopDisposingChainPosition = root is null ?
chain.ChainElements.Count :
chain.ChainElements.Count - 1;

for (int i = count + 1; i < stopDisposingChainPosition; i++)
{
chain.ChainElements[i].Certificate.Dispose();
}
Expand All @@ -109,12 +114,19 @@ internal static SslStreamCertificateContext Create(
// On Linux, AddRootCertificate will start a background download of an OCSP response,
// unless this context was built "offline", or this came from the internal Create(X509Certificate2)
ctx.SetNoOcspFetch(offline || noOcspFetch);
ctx.AddRootCertificate(root);

bool transferredOwnership = false;
ctx.AddRootCertificate(root, ref transferredOwnership);

if (!transferredOwnership)
{
root?.Dispose();
vcsjones marked this conversation as resolved.
Show resolved Hide resolved
}

return ctx;
}

partial void AddRootCertificate(X509Certificate2? rootCertificate);
partial void AddRootCertificate(X509Certificate2? rootCertificate, ref bool transferredOwnership);
partial void SetNoOcspFetch(bool noOcspFetch);

internal SslStreamCertificateContext Duplicate()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,17 +106,19 @@ public Task ConnectWithRevocation_WithCallback(bool checkRevocation)
[PlatformSpecific(TestPlatforms.Linux)]
[ConditionalTheory]
[OuterLoop("Subject to system load race conditions")]
[InlineData(false)]
[InlineData(true)]
public Task ConnectWithRevocation_StapledOcsp(bool offlineContext)
[InlineData(false, false)]
[InlineData(true, false)]
[InlineData(false, true)]
[InlineData(true, true)]
public Task ConnectWithRevocation_StapledOcsp(bool offlineContext, bool noIntermediates)
{
// Offline will only work if
// a) the revocation has been checked recently enough that it is cached, or
// b) the server stapled the response
//
// At high load, the server's background fetch might not have completed before
// this test runs.
return ConnectWithRevocation_WithCallback_Core(X509RevocationMode.Offline, offlineContext);
return ConnectWithRevocation_WithCallback_Core(X509RevocationMode.Offline, offlineContext, noIntermediates);
}

[Fact]
Expand Down Expand Up @@ -192,7 +194,8 @@ static bool CertificateValidationCallback(

private async Task ConnectWithRevocation_WithCallback_Core(
X509RevocationMode revocationMode,
bool? offlineContext = false)
bool? offlineContext = false,
bool noIntermediates = false)
{
string offlinePart = offlineContext.HasValue ? offlineContext.GetValueOrDefault().ToString().ToLower() : "null";
string serverName = $"{revocationMode.ToString().ToLower()}.{offlinePart}.server.example";
Expand All @@ -203,13 +206,15 @@ private async Task ConnectWithRevocation_WithCallback_Core(
PkiOptions.EndEntityRevocationViaOcsp | PkiOptions.CrlEverywhere,
out RevocationResponder responder,
out CertificateAuthority rootAuthority,
out CertificateAuthority intermediateAuthority,
out CertificateAuthority[] intermediateAuthorities,
out X509Certificate2 serverCert,
intermediateAuthorityCount: noIntermediates ? 0 : 1,
subjectName: serverName,
keySize: 2048,
extensions: TestHelper.BuildTlsServerCertExtensions(serverName));

X509Certificate2 issuerCert = intermediateAuthority.CloneIssuerCert();
CertificateAuthority issuingAuthority = noIntermediates ? rootAuthority : intermediateAuthorities[0];
X509Certificate2 issuerCert = issuingAuthority.CloneIssuerCert();
X509Certificate2 rootCert = rootAuthority.CloneIssuerCert();

SslClientAuthenticationOptions clientOpts = new SslClientAuthenticationOptions
Expand Down Expand Up @@ -243,71 +248,80 @@ private async Task ConnectWithRevocation_WithCallback_Core(
serverCert = temp;
}

await using (clientStream)
await using (serverStream)
using (responder)
using (rootAuthority)
using (intermediateAuthority)
using (serverCert)
using (issuerCert)
using (rootCert)
await using (SslStream tlsClient = new SslStream(clientStream))
await using (SslStream tlsServer = new SslStream(serverStream))
try
{
vcsjones marked this conversation as resolved.
Show resolved Hide resolved
intermediateAuthority.Revoke(serverCert, serverCert.NotBefore);

SslServerAuthenticationOptions serverOpts = new SslServerAuthenticationOptions();

if (offlineContext.HasValue)
await using (clientStream)
await using (serverStream)
using (responder)
using (rootAuthority)
using (serverCert)
using (issuerCert)
using (rootCert)
await using (SslStream tlsClient = new SslStream(clientStream))
await using (SslStream tlsServer = new SslStream(serverStream))
{
serverOpts.ServerCertificateContext = SslStreamCertificateContext.Create(
serverCert,
new X509Certificate2Collection(issuerCert),
offlineContext.GetValueOrDefault());
issuingAuthority.Revoke(serverCert, serverCert.NotBefore);

if (revocationMode == X509RevocationMode.Offline)
SslServerAuthenticationOptions serverOpts = new SslServerAuthenticationOptions();

if (offlineContext.HasValue)
{
if (offlineContext.GetValueOrDefault(false))
{
// Add a delay just to show we're not winning because of race conditions.
await Task.Delay(200);
}
else
serverOpts.ServerCertificateContext = SslStreamCertificateContext.Create(
serverCert,
new X509Certificate2Collection(issuerCert),
offlineContext.GetValueOrDefault());

if (revocationMode == X509RevocationMode.Offline)
{
if (!OperatingSystem.IsLinux())
if (offlineContext.GetValueOrDefault(false))
{
throw new InvalidOperationException(
"This test configuration uses reflection and is only defined for Linux.");
// Add a delay just to show we're not winning because of race conditions.
await Task.Delay(200);
}

FieldInfo pendingDownloadTaskField = typeof(SslStreamCertificateContext).GetField(
"_pendingDownload",
BindingFlags.Instance | BindingFlags.NonPublic);

if (pendingDownloadTaskField is null)
else
{
throw new InvalidOperationException("Cannot find the pending download field.");
}

Task download = (Task)pendingDownloadTaskField.GetValue(serverOpts.ServerCertificateContext);

// If it's null, it should mean it has already finished. If not, it might not have.
if (download is not null)
{
await download;
if (!OperatingSystem.IsLinux())
{
throw new InvalidOperationException(
"This test configuration uses reflection and is only defined for Linux.");
}

FieldInfo pendingDownloadTaskField = typeof(SslStreamCertificateContext).GetField(
"_pendingDownload",
BindingFlags.Instance | BindingFlags.NonPublic);

if (pendingDownloadTaskField is null)
{
throw new InvalidOperationException("Cannot find the pending download field.");
}

Task download = (Task)pendingDownloadTaskField.GetValue(serverOpts.ServerCertificateContext);

// If it's null, it should mean it has already finished. If not, it might not have.
if (download is not null)
{
await download;
}
}
}
}
else
{
serverOpts.ServerCertificate = serverCert;
}

Task serverTask = tlsServer.AuthenticateAsServerAsync(serverOpts);
Task clientTask = tlsClient.AuthenticateAsClientAsync(clientOpts);

await TestConfiguration.WhenAllOrAnyFailedWithTimeout(clientTask, serverTask);
}
else
}
finally
{
foreach (CertificateAuthority intermediateAuthority in intermediateAuthorities)
{
serverOpts.ServerCertificate = serverCert;
intermediateAuthority.Dispose();
}

Task serverTask = tlsServer.AuthenticateAsServerAsync(serverOpts);
Task clientTask = tlsClient.AuthenticateAsClientAsync(clientOpts);

await TestConfiguration.WhenAllOrAnyFailedWithTimeout(clientTask, serverTask);
}

static bool CertificateValidationCallback(
Expand Down