From 1a1467c1ce7e3cf65125cef7fbfeb8481e6ead40 Mon Sep 17 00:00:00 2001 From: Christopher Scott Date: Mon, 8 Apr 2024 14:11:42 -0500 Subject: [PATCH] Increase IMDS retry count to 5 (#43249) --- eng/Packages.Data.props | 6 ++--- .../Azure.Identity/src/CredentialPipeline.cs | 24 +++++++++++++++++-- .../Credentials/ManagedIdentityCredential.cs | 4 ++-- .../src/DefaultAzureCredentialFactory.cs | 5 +--- .../tests/ImdsManagedIdentitySourceTests.cs | 2 +- .../tests/ManagedIdentityCredentialTests.cs | 4 ++-- 6 files changed, 31 insertions(+), 14 deletions(-) diff --git a/eng/Packages.Data.props b/eng/Packages.Data.props index 03a060b94889..7b111debb1c9 100644 --- a/eng/Packages.Data.props +++ b/eng/Packages.Data.props @@ -153,13 +153,13 @@ - - + + - + diff --git a/sdk/identity/Azure.Identity/src/CredentialPipeline.cs b/sdk/identity/Azure.Identity/src/CredentialPipeline.cs index f6d0a1b15fbd..0b7cfa3307aa 100644 --- a/sdk/identity/Azure.Identity/src/CredentialPipeline.cs +++ b/sdk/identity/Azure.Identity/src/CredentialPipeline.cs @@ -26,9 +26,29 @@ public CredentialPipeline(HttpPipeline httpPipeline, ClientDiagnostics diagnosti Diagnostics = diagnostics; } - public static CredentialPipeline GetInstance(TokenCredentialOptions options) + public static CredentialPipeline GetInstance(TokenCredentialOptions options, bool IsManagedIdentityCredential = false) { - return options is null ? s_singleton.Value : new CredentialPipeline(options); + return options switch + { + _ when IsManagedIdentityCredential => configureOptionsForManagedIdentity(options), + not null => new CredentialPipeline(options), + _ => s_singleton.Value, + + }; + } + + private static CredentialPipeline configureOptionsForManagedIdentity(TokenCredentialOptions options) + { + var clonedOptions = options switch + { + DefaultAzureCredentialOptions dac => dac.Clone(), + _ => options?.Clone() ?? new TokenCredentialOptions(), + }; + // Set the custom retry policy + clonedOptions.Retry.MaxRetries = 5; + clonedOptions.RetryPolicy ??= new DefaultAzureCredentialImdsRetryPolicy(clonedOptions.Retry); + clonedOptions.IsChainedCredential = clonedOptions is DefaultAzureCredentialOptions; + return new CredentialPipeline(clonedOptions); } public HttpPipeline HttpPipeline { get; } diff --git a/sdk/identity/Azure.Identity/src/Credentials/ManagedIdentityCredential.cs b/sdk/identity/Azure.Identity/src/Credentials/ManagedIdentityCredential.cs index 8aab9e375bb2..d91cb589133d 100644 --- a/sdk/identity/Azure.Identity/src/Credentials/ManagedIdentityCredential.cs +++ b/sdk/identity/Azure.Identity/src/Credentials/ManagedIdentityCredential.cs @@ -41,7 +41,7 @@ protected ManagedIdentityCredential() /// /// Options to configure the management of the requests sent to Microsoft Entra ID. public ManagedIdentityCredential(string clientId = null, TokenCredentialOptions options = null) - : this(new ManagedIdentityClient(new ManagedIdentityClientOptions { ClientId = clientId, Pipeline = CredentialPipeline.GetInstance(options), Options = options })) + : this(new ManagedIdentityClient(new ManagedIdentityClientOptions { ClientId = clientId, Pipeline = CredentialPipeline.GetInstance(options, IsManagedIdentityCredential: true), Options = options })) { _logAccountDetails = options?.Diagnostics?.IsAccountIdentifierLoggingEnabled ?? false; } @@ -55,7 +55,7 @@ public ManagedIdentityCredential(string clientId = null, TokenCredentialOptions /// /// Options to configure the management of the requests sent to Microsoft Entra ID. public ManagedIdentityCredential(ResourceIdentifier resourceId, TokenCredentialOptions options = null) - : this(new ManagedIdentityClient(new ManagedIdentityClientOptions { ResourceIdentifier = resourceId, Pipeline = CredentialPipeline.GetInstance(options), Options = options })) + : this(new ManagedIdentityClient(new ManagedIdentityClientOptions { ResourceIdentifier = resourceId, Pipeline = CredentialPipeline.GetInstance(options, IsManagedIdentityCredential: true), Options = options })) { _logAccountDetails = options?.Diagnostics?.IsAccountIdentifierLoggingEnabled ?? false; _clientId = resourceId.ToString(); diff --git a/sdk/identity/Azure.Identity/src/DefaultAzureCredentialFactory.cs b/sdk/identity/Azure.Identity/src/DefaultAzureCredentialFactory.cs index 2d89ce61719c..2d6817a97c6b 100644 --- a/sdk/identity/Azure.Identity/src/DefaultAzureCredentialFactory.cs +++ b/sdk/identity/Azure.Identity/src/DefaultAzureCredentialFactory.cs @@ -121,16 +121,13 @@ public virtual TokenCredential CreateWorkloadIdentityCredential() public virtual TokenCredential CreateManagedIdentityCredential() { var options = Options.Clone(); - // Set the custom retry policy - options.Retry.MaxRetries = 4; - options.RetryPolicy ??= new DefaultAzureCredentialImdsRetryPolicy(options.Retry); options.IsChainedCredential = true; var miOptions = new ManagedIdentityClientOptions { ResourceIdentifier = options.ManagedIdentityResourceId, ClientId = options.ManagedIdentityClientId, - Pipeline = CredentialPipeline.GetInstance(options), + Pipeline = CredentialPipeline.GetInstance(options, IsManagedIdentityCredential: true), Options = options, InitialImdsConnectionTimeout = TimeSpan.FromSeconds(1), ExcludeTokenExchangeManagedIdentitySource = options.ExcludeWorkloadIdentityCredential diff --git a/sdk/identity/Azure.Identity/tests/ImdsManagedIdentitySourceTests.cs b/sdk/identity/Azure.Identity/tests/ImdsManagedIdentitySourceTests.cs index 44755917422d..0cbe779d0bac 100644 --- a/sdk/identity/Azure.Identity/tests/ImdsManagedIdentitySourceTests.cs +++ b/sdk/identity/Azure.Identity/tests/ImdsManagedIdentitySourceTests.cs @@ -120,7 +120,7 @@ public void ManagedIdentityCredentialUsesDefaultTimeoutAndRetries() Assert.ThrowsAsync(async () => await cred.GetTokenAsync(new(new[] { "test" }))); - var expectedTimeouts = new TimeSpan?[] { null, null, null, null }; + var expectedTimeouts = new TimeSpan?[] { null, null, null, null, null, null }; CollectionAssert.AreEqual(expectedTimeouts, networkTimeouts); } diff --git a/sdk/identity/Azure.Identity/tests/ManagedIdentityCredentialTests.cs b/sdk/identity/Azure.Identity/tests/ManagedIdentityCredentialTests.cs index 8ab5725778fe..5253768da832 100644 --- a/sdk/identity/Azure.Identity/tests/ManagedIdentityCredentialTests.cs +++ b/sdk/identity/Azure.Identity/tests/ManagedIdentityCredentialTests.cs @@ -683,13 +683,13 @@ public async Task VerifyMsiUnavailableOnIMDSRequestFailedExcpetion() { using var environment = new TestEnvVar(new() { { "MSI_ENDPOINT", null }, { "MSI_SECRET", null }, { "IDENTITY_ENDPOINT", null }, { "IDENTITY_HEADER", null }, { "AZURE_POD_IDENTITY_AUTHORITY_HOST", "http://169.254.169.001/" } }); - var options = new TokenCredentialOptions() { Retry = { MaxRetries = 0, NetworkTimeout = TimeSpan.FromMilliseconds(100) } }; + var options = new TokenCredentialOptions() { Retry = { MaxRetries = 0, NetworkTimeout = TimeSpan.FromMilliseconds(100), MaxDelay = TimeSpan.Zero } }; var credential = InstrumentClient(new ManagedIdentityCredential(options: options)); var ex = Assert.ThrowsAsync(async () => await credential.GetTokenAsync(new TokenRequestContext(MockScopes.Default))); - Assert.That(ex.Message, Does.Contain(ImdsManagedIdentitySource.NoResponseError)); + Assert.That(ex.Message, Does.Contain(ImdsManagedIdentitySource.AggregateError)); await Task.CompletedTask; }