Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the option to have UserAssignedMsi when authentication uses KeyVault certificate #14676

Merged
merged 1 commit into from
Sep 12, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -173,10 +173,11 @@ private enum CertIdentifierType
/// <param name="certIdentifierType"></param>
/// <returns></returns>
[Theory]
[InlineData(CertIdentifierType.KeyVaultCertificateSecretIdentifier)]
[InlineData(CertIdentifierType.KeyVaultCertificateSecretIdentifier, false)]
[InlineData(CertIdentifierType.KeyVaultCertificateSecretIdentifier, true)]
[InlineData(CertIdentifierType.SubjectName)]
[InlineData(CertIdentifierType.Thumbprint)]
private async Task GetTokenUsingServicePrincipalWithCertTest(CertIdentifierType certIdentifierType)
private async Task GetTokenUsingServicePrincipalWithCertTest(CertIdentifierType certIdentifierType, bool useUserAssignedMsi = false)
{
string testCertUrl = Environment.GetEnvironmentVariable(Constants.TestCertUrlEnv);

Expand Down Expand Up @@ -208,7 +209,9 @@ private async Task GetTokenUsingServicePrincipalWithCertTest(CertIdentifierType
connectionString = $"RunAs=App;AppId={app.AppId};TenantId={_tenantId};{thumbprintOrSubjectName};CertificateStoreLocation={Constants.CurrentUserStore};";
break;
case CertIdentifierType.KeyVaultCertificateSecretIdentifier:
connectionString = $"RunAs=App;AppId={app.AppId};KeyVaultCertificateSecretIdentifier={testCertUrl};";
connectionString = useUserAssignedMsi
? $"RunAs=App;AppId={app.AppId};KeyVaultCertificateSecretIdentifier={testCertUrl};KeyVaultUserAssignedManagedIdentityId={Constants.TestUserAssignedManagedIdentityId}" //TODO: figure out real MSI to use here. Also, does the test really use MSI or does it rely on the fallback?
: $"RunAs=App;AppId={app.AppId};KeyVaultCertificateSecretIdentifier={testCertUrl};";
break;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ public class Constants
public static readonly string CertificateConnStringThumbprintCurrentUser = $"RunAs=App;AppId={TestAppId};TenantId={TenantId};CertificateThumbprint=123;CertificateStoreLocation=CurrentUser";
public static readonly string CertificateConnStringSubjectNameCurrentUser = $"RunAs=App;AppId={TestAppId};TenantId={TenantId};CertificateSubjectName=123;CertificateStoreLocation=CurrentUser";
public static readonly string CertificateConnStringKeyVaultCertificateSecretIdentifier = $"RunAs=App;AppId={TestAppId};KeyVaultCertificateSecretIdentifier=SecretIdentifier";
public static readonly string CertificateConnStringKeyVaultCertificateSecretIdentifierUserAssignedMsi = $"RunAs=App;AppId={TestAppId};KeyVaultCertificateSecretIdentifier=SecretIdentifier;KeyVaultAppId={TestUserAssignedManagedIdentityId}";
public static readonly string CertificateConnStringKeyVaultCertificateSecretIdentifierWithOptionalTenantId = $"RunAs=App;AppId={TestAppId};TenantId={TenantId};KeyVaultCertificateSecretIdentifier=SecretIdentifier";
public static readonly string ClientSecretConnString = $"RunAs=App;AppId={TestAppId};TenantId={TenantId};AppKey={ClientSecret}";
public static readonly string ConnectionStringEnvironmentVariableName = "AzureServicesAuthConnectionString";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,11 @@ public void CertValidTest()
Assert.Equal(Constants.CertificateConnStringKeyVaultCertificateSecretIdentifier, provider.ConnectionString);
Assert.IsType<ClientCertificateAzureServiceTokenProvider>(provider);

provider = AzureServiceTokenProviderFactory.Create(Constants.CertificateConnStringKeyVaultCertificateSecretIdentifierUserAssignedMsi, Constants.AzureAdInstance);
Assert.NotNull(provider);
Assert.Equal(Constants.CertificateConnStringKeyVaultCertificateSecretIdentifierUserAssignedMsi, provider.ConnectionString);
Assert.IsType<ClientCertificateAzureServiceTokenProvider>(provider);

provider = AzureServiceTokenProviderFactory.Create(Constants.CertificateConnStringKeyVaultCertificateSecretIdentifierWithOptionalTenantId, Constants.AzureAdInstance);
Assert.NotNull(provider);
Assert.Equal(Constants.CertificateConnStringKeyVaultCertificateSecretIdentifierWithOptionalTenantId, provider.ConnectionString);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public async Task ThumbprintSuccessTest()

// Create ClientCertificateAzureServiceTokenProvider instance
ClientCertificateAzureServiceTokenProvider provider = new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext);
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext);

// Get the token. This will test that ClientCertificateAzureServiceTokenProvider could fetch the cert from CurrentUser store based on thumbprint in the connection string.
var authResult = await provider.GetAuthResultAsync(Constants.KeyVaultResourceId, Constants.TenantId).ConfigureAwait(false);
Expand All @@ -64,7 +64,7 @@ public async Task ThumbprintFailTest()

// Create ClientCertificateAzureServiceTokenProvider instance
ClientCertificateAzureServiceTokenProvider provider = new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext);
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext);

// Ensure exception is thrown when getting the token
var exception = await Assert.ThrowsAsync<AzureServiceTokenProviderException>(() => provider.GetAuthResultAsync(Constants.KeyVaultResourceId, Constants.TenantId));
Expand All @@ -89,12 +89,12 @@ public void ClientIdNullOrEmptyTest()

// Create ClientCertificateAzureServiceTokenProvider instance
var exception = Assert.Throws<ArgumentNullException>(() => new ClientCertificateAzureServiceTokenProvider(null,
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext));
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext));

Assert.Contains(Constants.CannotBeNullError, exception.ToString());

exception = Assert.Throws<ArgumentNullException>(() => new ClientCertificateAzureServiceTokenProvider(string.Empty,
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext));
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext));

Assert.Contains(Constants.CannotBeNullError, exception.ToString());
}
Expand All @@ -114,12 +114,12 @@ public void StoreLocationNullOrEmptyTest()

// Create ClientCertificateAzureServiceTokenProvider instance
var exception = Assert.Throws<ArgumentNullException>(() => new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
cert.Thumbprint, CertificateIdentifierType.Thumbprint, null, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext));
cert.Thumbprint, CertificateIdentifierType.Thumbprint, null, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext));

Assert.Contains(Constants.CannotBeNullError, exception.ToString());

exception = Assert.Throws<ArgumentNullException>(() => new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
cert.Thumbprint, CertificateIdentifierType.Thumbprint, string.Empty, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext));
cert.Thumbprint, CertificateIdentifierType.Thumbprint, string.Empty, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext));

Assert.Contains(Constants.CannotBeNullError, exception.ToString());
}
Expand All @@ -135,12 +135,12 @@ public void CertSubjectNameOrThumbprintNullOrEmptyTest()

// Create ClientCertificateAzureServiceTokenProvider instance
var exception = Assert.Throws<ArgumentNullException>(() => new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
null, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext));
null, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext));

Assert.Contains(Constants.CannotBeNullError, exception.ToString());

exception = Assert.Throws<ArgumentNullException>(() => new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
string.Empty, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext));
string.Empty, CertificateIdentifierType.Thumbprint, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext));

Assert.Contains(Constants.CannotBeNullError, exception.ToString());
}
Expand All @@ -160,7 +160,7 @@ public void InvalidStoreLocationTest()

// Create ClientCertificateAzureServiceTokenProvider instance
var exception = Assert.Throws<ArgumentException>(() => new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.InvalidString, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext));
cert.Thumbprint, CertificateIdentifierType.Thumbprint, Constants.InvalidString, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext));

Assert.Contains(Constants.InvalidCertLocationError, exception.ToString());
}
Expand All @@ -177,7 +177,7 @@ public async Task SubjectNameSuccessTest()

// Create ClientCertificateAzureServiceTokenProvider instance with a subject name
ClientCertificateAzureServiceTokenProvider provider = new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
cert.Subject, CertificateIdentifierType.SubjectName, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext);
cert.Subject, CertificateIdentifierType.SubjectName, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext);

// Get the token. This will test that ClientCertificateAzureServiceTokenProvider could fetch the cert from CurrentUser store based on subject name in the connection string.
var authResult = await provider.GetAuthResultAsync(Constants.KeyVaultResourceId, string.Empty).ConfigureAwait(false);
Expand All @@ -204,7 +204,7 @@ public void CannotAcquireTokenThroughCertTest()

// Create ClientCertificateAzureServiceTokenProvider instance with a subject name
ClientCertificateAzureServiceTokenProvider provider = new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
cert.Subject, CertificateIdentifierType.SubjectName, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext);
cert.Subject, CertificateIdentifierType.SubjectName, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext);

// Get the token. This will test that ClientCertificateAzureServiceTokenProvider could fetch the cert from CurrentUser store based on subject name in the connection string.
var exception = Assert.ThrowsAsync<AzureServiceTokenProviderException>(() => provider.GetAuthResultAsync(Constants.KeyVaultResourceId, string.Empty));
Expand All @@ -226,7 +226,7 @@ public async Task CertificateNotFoundTest()
MockAuthenticationContext mockAuthenticationContext = new MockAuthenticationContext(MockAuthenticationContext.MockAuthenticationContextTestType.AcquireTokenAsyncClientCertificateSuccess);

ClientCertificateAzureServiceTokenProvider provider = new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
Guid.NewGuid().ToString(), CertificateIdentifierType.SubjectName, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext);
Guid.NewGuid().ToString(), CertificateIdentifierType.SubjectName, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext);

var exception = await Assert.ThrowsAsync<AzureServiceTokenProviderException>(() => Task.Run(() => provider.GetAuthResultAsync(Constants.KeyVaultResourceId, Constants.TenantId)));

Expand Down Expand Up @@ -257,7 +257,7 @@ public async Task KeyVaultCertificateSecretIdentifierSuccessTest(bool includeTen

// Create ClientCertificateAzureServiceTokenProvider instance with a subject name
ClientCertificateAzureServiceTokenProvider provider = new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
Constants.TestKeyVaultCertificateSecretIdentifier, CertificateIdentifierType.KeyVaultCertificateSecretIdentifier, null, Constants.AzureAdInstance, tenantIdParam, 0, mockAuthenticationContext, keyVaultClient);
Constants.TestKeyVaultCertificateSecretIdentifier, CertificateIdentifierType.KeyVaultCertificateSecretIdentifier, null, Constants.AzureAdInstance, tenantIdParam, 0, authenticationContext: mockAuthenticationContext, keyVaultClient: keyVaultClient);

// Get the token. This will test that ClientCertificateAzureServiceTokenProvider could fetch the cert from CurrentUser store based on subject name in the connection string.
var authResult = await provider.GetAuthResultAsync(Constants.ArmResourceId, string.Empty).ConfigureAwait(false);
Expand All @@ -283,7 +283,7 @@ public async Task KeyVaultCertificateNotFoundTest()

string SecretIdentifier = "https://testbedkeyvault.vault.azure.net/secrets/secret/";
ClientCertificateAzureServiceTokenProvider provider = new ClientCertificateAzureServiceTokenProvider(Constants.TestAppId,
SecretIdentifier, CertificateIdentifierType.KeyVaultCertificateSecretIdentifier, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, mockAuthenticationContext, keyVaultClient);
SecretIdentifier, CertificateIdentifierType.KeyVaultCertificateSecretIdentifier, Constants.CurrentUserStore, Constants.AzureAdInstance, Constants.TenantId, 0, authenticationContext: mockAuthenticationContext, keyVaultClient: keyVaultClient);

var exception = await Assert.ThrowsAsync<AzureServiceTokenProviderException>(() => Task.Run(() => provider.GetAuthResultAsync(Constants.ArmResourceId, Constants.TenantId)));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ internal class AzureServiceTokenProviderFactory
private const string CertificateSubjectName = "CertificateSubjectName";
private const string CertificateThumbprint = "CertificateThumbprint";
private const string KeyVaultCertificateSecretIdentifier = "KeyVaultCertificateSecretIdentifier";
private const string KeyVaultUserAssignedManagedIdentityId = "KeyVaultUserAssignedManagedIdentityId";
private const string CertificateStoreLocation = "CertificateStoreLocation";
private const string MsiRetryTimeout = "MsiRetryTimeout";

Expand Down Expand Up @@ -125,7 +126,7 @@ internal static NonInteractiveAzureServiceTokenProviderBase Create(string connec
azureAdInstance,
connectionSettings[TenantId],
0,
new AdalAuthenticationContext(httpClientFactory));
authenticationContext: new AdalAuthenticationContext(httpClientFactory));
}
else if (connectionSettings.ContainsKey(CertificateThumbprint) ||
connectionSettings.ContainsKey(CertificateSubjectName))
Expand All @@ -138,6 +139,11 @@ internal static NonInteractiveAzureServiceTokenProviderBase Create(string connec
{
ValidateMsiRetryTimeout(connectionSettings, connectionString);

var msiRetryTimeout = connectionSettings.ContainsKey(MsiRetryTimeout)
? int.Parse(connectionSettings[MsiRetryTimeout])
: 0;
connectionSettings.TryGetValue(KeyVaultUserAssignedManagedIdentityId, out var keyVaultUserAssignedManagedIdentityId);

azureServiceTokenProvider =
new ClientCertificateAzureServiceTokenProvider(
connectionSettings[AppId],
Expand All @@ -148,9 +154,8 @@ internal static NonInteractiveAzureServiceTokenProviderBase Create(string connec
connectionSettings.ContainsKey(TenantId) // tenantId can be specified in connection string or retrieved from Key Vault access token later
? connectionSettings[TenantId]
: default,
connectionSettings.ContainsKey(MsiRetryTimeout)
? int.Parse(connectionSettings[MsiRetryTimeout])
: 0,
msiRetryTimeout,
keyVaultUserAssignedManagedIdentityId,
new AdalAuthenticationContext(httpClientFactory));
}
else if (connectionSettings.ContainsKey(AppKey))
Expand Down
Loading