Skip to content

Commit

Permalink
Back port dotnet#1925
Browse files Browse the repository at this point in the history
  • Loading branch information
DavoudEshtehari committed Apr 26, 2023
1 parent acfdeca commit ad560a8
Showing 1 changed file with 136 additions and 54 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@
using System.Collections.Concurrent;
using System.Linq;
using System.Security;
using System.Runtime.Caching;
using System.Security.Cryptography;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Identity.Client;
Expand All @@ -23,6 +26,8 @@ public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationPro
/// </summary>
private static ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication> s_pcaMap
= new ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication>();
private static readonly MemoryCache s_accountPwCache = new(nameof(ActiveDirectoryAuthenticationProvider));
private static readonly int s_accountPwCacheTtlInHours = 2;
private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient";
private static readonly string s_defaultScopeSuffix = "/.default";
private readonly string _type = typeof(ActiveDirectoryAuthenticationProvider).Name;
Expand Down Expand Up @@ -101,7 +106,9 @@ public override void BeforeUnload(SqlAuthenticationMethod authentication)
/// <include file='../../../../../../doc/snippets/Microsoft.Data.SqlClient/ActiveDirectoryAuthenticationProvider.xml' path='docs/members[@name="ActiveDirectoryAuthenticationProvider"]/AcquireTokenAsync/*'/>
public override Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenticationParameters parameters) => Task.Run(async () =>
{
AuthenticationResult result;
CancellationTokenSource cts = new();
AuthenticationResult result = null;
string scope = parameters.Resource.EndsWith(s_defaultScopeSuffix) ? parameters.Resource : parameters.Resource + s_defaultScopeSuffix;
string[] scopes = new string[] { scope };
Expand Down Expand Up @@ -147,69 +154,84 @@ public override Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthentication
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
{
if (!string.IsNullOrEmpty(parameters.UserId))
{
result = app.AcquireTokenByIntegratedWindowsAuth(scopes)
.WithCorrelationId(parameters.ConnectionId)
.WithUsername(parameters.UserId)
.ExecuteAsync().Result;
}
else
result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false);
if (result == null)
{
result = app.AcquireTokenByIntegratedWindowsAuth(scopes)
.WithCorrelationId(parameters.ConnectionId)
.ExecuteAsync().Result;
if (!string.IsNullOrEmpty(parameters.UserId))
{
result = app.AcquireTokenByIntegratedWindowsAuth(scopes)
.WithCorrelationId(parameters.ConnectionId)
.WithUsername(parameters.UserId)
.ExecuteAsync(cancellationToken: cts.Token).Result;
}
else
{
result = app.AcquireTokenByIntegratedWindowsAuth(scopes)
.WithCorrelationId(parameters.ConnectionId)
.ExecuteAsync(cancellationToken: cts.Token).Result;
}
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result.ExpiresOn);
}
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result.ExpiresOn);
}
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryPassword)
{
SecureString password = new SecureString();
foreach (char c in parameters.Password)
password.AppendChar(c);
password.MakeReadOnly();
result = app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, password)
.WithCorrelationId(parameters.ConnectionId)
.ExecuteAsync().Result;
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result.ExpiresOn);
string pwCacheKey = GetAccountPwCacheKey(parameters);
object previousPw = s_accountPwCache.Get(pwCacheKey);
byte[] currPwHash = GetHash(parameters.Password);
if (null != previousPw &&
previousPw is byte[] previousPwBytes &&
// Only get the cached token if the current password hash matches the previously used password hash
currPwHash.SequenceEqual(previousPwBytes))
{
result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false);
}
if (result == null)
{
SecureString password = new SecureString();
foreach (char c in parameters.Password)
password.AppendChar(c);
password.MakeReadOnly();
result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, password)
.WithCorrelationId(parameters.ConnectionId)
.ExecuteAsync()
.ConfigureAwait(false);
// We cache the password hash to ensure future connection requests include a validated password
// when we check for a cached MSAL account. Otherwise, a connection request with the same username
// against the same tenant could succeed with an invalid password when we re-use the cached token.
if (!s_accountPwCache.Add(pwCacheKey, GetHash(parameters.Password), DateTime.UtcNow.AddHours(s_accountPwCacheTtlInHours)))
{
s_accountPwCache.Remove(pwCacheKey);
s_accountPwCache.Add(pwCacheKey, GetHash(parameters.Password), DateTime.UtcNow.AddHours(s_accountPwCacheTtlInHours));
}
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result.ExpiresOn);
}
}
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive ||
parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow)
{
// Fetch available accounts from 'app' instance
System.Collections.Generic.IEnumerable<IAccount> accounts = await app.GetAccountsAsync();
IAccount account;
if (!string.IsNullOrEmpty(parameters.UserId))
try
{
account = accounts.FirstOrDefault(a => parameters.UserId.Equals(a.Username, System.StringComparison.InvariantCultureIgnoreCase));
result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn);
}
else
catch (MsalUiRequiredException)
{
account = accounts.FirstOrDefault();
// An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
// for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
// or the user needs to perform two factor authentication.
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn);
}
if (null != account)
{
try
{
// If 'account' is available in 'app', we use the same to acquire token silently.
// Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync();
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn);
}
catch (MsalUiRequiredException)
{
// An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
// for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
// or the user needs to perform two factor authentication.
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn);
}
}
else
if (result == null)
{
// If no existing 'account' is found, we request user to sign in interactively.
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod);
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result.ExpiresOn);
}
}
Expand All @@ -222,11 +244,58 @@ public override Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthentication
return new SqlAuthenticationToken(result.AccessToken, result.ExpiresOn);
});

private static async Task<AuthenticationResult> TryAcquireTokenSilent(IPublicClientApplication app,
SqlAuthenticationParameters parameters,
string[] scopes,
CancellationTokenSource cts)
{
AuthenticationResult result = null;

// Fetch available accounts from 'app' instance
System.Collections.Generic.IEnumerator<IAccount> accounts = (await app.GetAccountsAsync().ConfigureAwait(false)).GetEnumerator();

IAccount account = default;
if (accounts.MoveNext())
{
if (!string.IsNullOrEmpty(parameters.UserId))
{
do
{
IAccount currentVal = accounts.Current;
if (string.Compare(parameters.UserId, currentVal.Username, StringComparison.InvariantCultureIgnoreCase) == 0)
{
account = currentVal;
break;
}
}
while (accounts.MoveNext());
}
else
{
account = accounts.Current;
}
}

if (null != account)
{
// If 'account' is available in 'app', we use the same to acquire token silently.
// Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(cancellationToken: cts.Token).ConfigureAwait(false);
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
}

return result;
}

private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(IPublicClientApplication app, string[] scopes, Guid connectionId, string userId,
SqlAuthenticationMethod authenticationMethod)
private static async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(IPublicClientApplication app,
string[] scopes,
Guid connectionId,
string userId,
SqlAuthenticationMethod authenticationMethod,
CancellationTokenSource cts,
ICustomWebUi customWebUI,
Func<DeviceCodeResult, Task> deviceCodeFlowCallback)
{
CancellationTokenSource cts = new CancellationTokenSource();
#if NETCOREAPP
/*
* On .NET Core, MSAL will start the system browser as a separate process. MSAL does not have control over this browser,
Expand All @@ -243,11 +312,11 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
{
if (authenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive)
{
if (_customWebUI != null)
if (customWebUI != null)
{
return await app.AcquireTokenInteractive(scopes)
.WithCorrelationId(connectionId)
.WithCustomWebUi(_customWebUI)
.WithCustomWebUi(customWebUI)
.WithLoginHint(userId)
.ExecuteAsync(cts.Token);
}
Expand Down Expand Up @@ -279,7 +348,7 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
else
{
AuthenticationResult result = await app.AcquireTokenWithDeviceCode(scopes,
deviceCodeResult => _deviceCodeFlowCallback(deviceCodeResult)).ExecuteAsync();
deviceCodeResult => deviceCodeFlowCallback(deviceCodeResult)).ExecuteAsync();
return result;
}
}
Expand Down Expand Up @@ -329,6 +398,19 @@ private IPublicClientApplication GetPublicClientAppInstance(PublicClientAppKey p
return clientApplicationInstance;
}

private static string GetAccountPwCacheKey(SqlAuthenticationParameters parameters)
{
return parameters.Authority + "+" + parameters.UserId;
}

private static byte[] GetHash(string input)
{
byte[] unhashedBytes = Encoding.Unicode.GetBytes(input);
SHA256 sha256 = SHA256.Create();
byte[] hashedBytes = sha256.ComputeHash(unhashedBytes);
return hashedBytes;
}

private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publicClientAppKey)
{
IPublicClientApplication publicClientApplication;
Expand Down

0 comments on commit ad560a8

Please sign in to comment.