Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TokenExchangeManagedIdentitySource with async IO #38939

Merged
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions sdk/identity/Azure.Identity/src/MsalConfidentialClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -104,11 +104,19 @@ protected virtual async ValueTask<IConfidentialClientApplication> CreateClientCo

if (_assertionCallback != null)
{
if (_asyncAssertionCallback != null)
{
throw new InvalidOperationException($"Cannot set both {nameof(_assertionCallback)} and {nameof(_asyncAssertionCallback)}");
}
confClientBuilder.WithClientAssertion(_assertionCallback);
}

if (_asyncAssertionCallback != null)
{
if (_assertionCallback != null)
{
throw new InvalidOperationException($"Cannot set both {nameof(_assertionCallback)} and {nameof(_asyncAssertionCallback)}");
}
confClientBuilder.WithClientAssertion(_asyncAssertionCallback);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.Text;
Expand All @@ -15,12 +16,13 @@ internal class TokenExchangeManagedIdentitySource : ManagedIdentitySource
{
private TokenFileCache _tokenFileCache;
private ClientAssertionCredential _clientAssertionCredential;
private static readonly int DefaultBufferSize = 4096;

private TokenExchangeManagedIdentitySource(CredentialPipeline pipeline, string tenantId, string clientId, string tokenFilePath)
: base(pipeline)
{
_tokenFileCache = new TokenFileCache(tokenFilePath);
_clientAssertionCredential = new ClientAssertionCredential(tenantId, clientId, _tokenFileCache.GetTokenFileContents, new ClientAssertionCredentialOptions { Pipeline = pipeline });
_clientAssertionCredential = new ClientAssertionCredential(tenantId, clientId, _tokenFileCache.GetTokenFileContentsAsync, new ClientAssertionCredentialOptions { Pipeline = pipeline });
}

public static ManagedIdentitySource TryCreate(ManagedIdentityClientOptions options)
Expand All @@ -47,13 +49,10 @@ protected override Request CreateRequest(string[] scopes)
throw new NotImplementedException();
}

// Ideally this class should handle I/O asynchronously, and have a design similar to AccessTokenCache in BearerTokenAuthenticationPolicy.
// However, MSAL currently only accepts sync callbacks for client assertions so this has been radically simplified in light of this. If MSAL
// were to add support for an async callback we should update this accordingly.
// See, https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/issues/2863
private class TokenFileCache
{
private readonly object _lock = new object();
christothes marked this conversation as resolved.
Show resolved Hide resolved
private static SemaphoreSlim semaphore = new SemaphoreSlim(1, 1);
private readonly string _tokenFilePath;
private string _tokenFileContents;
private DateTimeOffset _refreshOn = DateTimeOffset.MinValue;
Expand All @@ -63,23 +62,84 @@ public TokenFileCache(string tokenFilePath)
_tokenFilePath = tokenFilePath;
}

public string GetTokenFileContents()
public async Task<string> GetTokenFileContentsAsync(CancellationToken cancellationToken)
{
if (_refreshOn <= DateTimeOffset.UtcNow)
{
lock (_lock)
try
{
if (_refreshOn <= DateTimeOffset.UtcNow)
await semaphore.WaitAsync(cancellationToken).ConfigureAwait(false);
{
_tokenFileContents = File.ReadAllText(_tokenFilePath);
if (_refreshOn <= DateTimeOffset.UtcNow)
{
_tokenFileContents = await ReadAllTextAsync(_tokenFilePath).ConfigureAwait(false);

_refreshOn = DateTimeOffset.UtcNow.AddMinutes(5);
_refreshOn = DateTimeOffset.UtcNow.AddMinutes(5);
}
}
}
finally
{
semaphore.Release();
}
}

return _tokenFileContents;
}
}

// Since File.ReadAllTextAsync is not available in netstandard2.0, the below implementation is borrowed with some modifications from
// https://github.com/dotnet/runtime/blob/8bcd03c650a85d523d542715e4e2543251f1dfa5/src/libraries/System.Private.CoreLib/src/System/IO/File.cs#L863-L906
internal static Task<string> ReadAllTextAsync(string path, CancellationToken cancellationToken = default)
=> ReadAllTextAsync(path, Encoding.UTF8, cancellationToken);

internal static Task<string> ReadAllTextAsync(string path, Encoding encoding, CancellationToken cancellationToken = default(CancellationToken))
{
Argument.AssertNotNullOrEmpty(path, nameof(path));
Argument.AssertNotNull(encoding, nameof(encoding));

return cancellationToken.IsCancellationRequested
? Task.FromCanceled<string>(cancellationToken)
: InternalReadAllTextAsync(path, encoding, cancellationToken);
}

private static async Task<string> InternalReadAllTextAsync(string path, Encoding encoding, CancellationToken cancellationToken)
{
char[] buffer = null;
StreamReader sr = AsyncStreamReader(path, encoding);
try
{
cancellationToken.ThrowIfCancellationRequested();
buffer = ArrayPool<char>.Shared.Rent(sr.CurrentEncoding.GetMaxCharCount(DefaultBufferSize));
StringBuilder sb = new StringBuilder();
int totalRead = 0;
while (true)
{
int read = await sr.ReadAsync(buffer, totalRead, DefaultBufferSize - totalRead).ConfigureAwait(false);
if (read == 0)
{
return sb.ToString();
}

sb.Append(buffer, 0, read);
totalRead += read;
}
}
finally
{
sr.Dispose();
if (buffer != null)
{
ArrayPool<char>.Shared.Return(buffer);
}
}
}

// If we use the path-taking constructors, we won't have FileOptions.Asynchronous set and
// we will have asynchronous file access faked by the thread pool. We want the real thing.
private static StreamReader AsyncStreamReader(string path, Encoding encoding)
=> new StreamReader(
new FileStream(path, FileMode.Open, FileAccess.Read, FileShare.Read, DefaultBufferSize, FileOptions.Asynchronous | FileOptions.SequentialScan),
encoding, detectEncodingFromByteOrderMarks: true);
}
}