diff --git a/sdk/identity/Azure.Identity/src/CredentialPipeline.cs b/sdk/identity/Azure.Identity/src/CredentialPipeline.cs index f6d0a1b15fbd5..0b7cfa3307aa0 100644 --- a/sdk/identity/Azure.Identity/src/CredentialPipeline.cs +++ b/sdk/identity/Azure.Identity/src/CredentialPipeline.cs @@ -26,9 +26,29 @@ public CredentialPipeline(HttpPipeline httpPipeline, ClientDiagnostics diagnosti Diagnostics = diagnostics; } - public static CredentialPipeline GetInstance(TokenCredentialOptions options) + public static CredentialPipeline GetInstance(TokenCredentialOptions options, bool IsManagedIdentityCredential = false) { - return options is null ? s_singleton.Value : new CredentialPipeline(options); + return options switch + { + _ when IsManagedIdentityCredential => configureOptionsForManagedIdentity(options), + not null => new CredentialPipeline(options), + _ => s_singleton.Value, + + }; + } + + private static CredentialPipeline configureOptionsForManagedIdentity(TokenCredentialOptions options) + { + var clonedOptions = options switch + { + DefaultAzureCredentialOptions dac => dac.Clone(), + _ => options?.Clone() ?? new TokenCredentialOptions(), + }; + // Set the custom retry policy + clonedOptions.Retry.MaxRetries = 5; + clonedOptions.RetryPolicy ??= new DefaultAzureCredentialImdsRetryPolicy(clonedOptions.Retry); + clonedOptions.IsChainedCredential = clonedOptions is DefaultAzureCredentialOptions; + return new CredentialPipeline(clonedOptions); } public HttpPipeline HttpPipeline { get; } diff --git a/sdk/identity/Azure.Identity/src/Credentials/ManagedIdentityCredential.cs b/sdk/identity/Azure.Identity/src/Credentials/ManagedIdentityCredential.cs index 8aab9e375bb23..d91cb589133db 100644 --- a/sdk/identity/Azure.Identity/src/Credentials/ManagedIdentityCredential.cs +++ b/sdk/identity/Azure.Identity/src/Credentials/ManagedIdentityCredential.cs @@ -41,7 +41,7 @@ protected ManagedIdentityCredential() /// /// Options to configure the management of the requests sent to Microsoft Entra ID. public ManagedIdentityCredential(string clientId = null, TokenCredentialOptions options = null) - : this(new ManagedIdentityClient(new ManagedIdentityClientOptions { ClientId = clientId, Pipeline = CredentialPipeline.GetInstance(options), Options = options })) + : this(new ManagedIdentityClient(new ManagedIdentityClientOptions { ClientId = clientId, Pipeline = CredentialPipeline.GetInstance(options, IsManagedIdentityCredential: true), Options = options })) { _logAccountDetails = options?.Diagnostics?.IsAccountIdentifierLoggingEnabled ?? false; } @@ -55,7 +55,7 @@ public ManagedIdentityCredential(string clientId = null, TokenCredentialOptions /// /// Options to configure the management of the requests sent to Microsoft Entra ID. public ManagedIdentityCredential(ResourceIdentifier resourceId, TokenCredentialOptions options = null) - : this(new ManagedIdentityClient(new ManagedIdentityClientOptions { ResourceIdentifier = resourceId, Pipeline = CredentialPipeline.GetInstance(options), Options = options })) + : this(new ManagedIdentityClient(new ManagedIdentityClientOptions { ResourceIdentifier = resourceId, Pipeline = CredentialPipeline.GetInstance(options, IsManagedIdentityCredential: true), Options = options })) { _logAccountDetails = options?.Diagnostics?.IsAccountIdentifierLoggingEnabled ?? false; _clientId = resourceId.ToString(); diff --git a/sdk/identity/Azure.Identity/src/DefaultAzureCredentialFactory.cs b/sdk/identity/Azure.Identity/src/DefaultAzureCredentialFactory.cs index a4c9b7164eb66..2d6817a97c6bb 100644 --- a/sdk/identity/Azure.Identity/src/DefaultAzureCredentialFactory.cs +++ b/sdk/identity/Azure.Identity/src/DefaultAzureCredentialFactory.cs @@ -121,16 +121,13 @@ public virtual TokenCredential CreateWorkloadIdentityCredential() public virtual TokenCredential CreateManagedIdentityCredential() { var options = Options.Clone(); - // Set the custom retry policy - options.Retry.MaxRetries = 5; - options.RetryPolicy ??= new DefaultAzureCredentialImdsRetryPolicy(options.Retry); options.IsChainedCredential = true; var miOptions = new ManagedIdentityClientOptions { ResourceIdentifier = options.ManagedIdentityResourceId, ClientId = options.ManagedIdentityClientId, - Pipeline = CredentialPipeline.GetInstance(options), + Pipeline = CredentialPipeline.GetInstance(options, IsManagedIdentityCredential: true), Options = options, InitialImdsConnectionTimeout = TimeSpan.FromSeconds(1), ExcludeTokenExchangeManagedIdentitySource = options.ExcludeWorkloadIdentityCredential diff --git a/sdk/identity/Azure.Identity/tests/ImdsManagedIdentitySourceTests.cs b/sdk/identity/Azure.Identity/tests/ImdsManagedIdentitySourceTests.cs index 44755917422dc..0cbe779d0bacb 100644 --- a/sdk/identity/Azure.Identity/tests/ImdsManagedIdentitySourceTests.cs +++ b/sdk/identity/Azure.Identity/tests/ImdsManagedIdentitySourceTests.cs @@ -120,7 +120,7 @@ public void ManagedIdentityCredentialUsesDefaultTimeoutAndRetries() Assert.ThrowsAsync(async () => await cred.GetTokenAsync(new(new[] { "test" }))); - var expectedTimeouts = new TimeSpan?[] { null, null, null, null }; + var expectedTimeouts = new TimeSpan?[] { null, null, null, null, null, null }; CollectionAssert.AreEqual(expectedTimeouts, networkTimeouts); } diff --git a/sdk/identity/Azure.Identity/tests/ManagedIdentityCredentialTests.cs b/sdk/identity/Azure.Identity/tests/ManagedIdentityCredentialTests.cs index 8ab5725778fe5..5253768da832f 100644 --- a/sdk/identity/Azure.Identity/tests/ManagedIdentityCredentialTests.cs +++ b/sdk/identity/Azure.Identity/tests/ManagedIdentityCredentialTests.cs @@ -683,13 +683,13 @@ public async Task VerifyMsiUnavailableOnIMDSRequestFailedExcpetion() { using var environment = new TestEnvVar(new() { { "MSI_ENDPOINT", null }, { "MSI_SECRET", null }, { "IDENTITY_ENDPOINT", null }, { "IDENTITY_HEADER", null }, { "AZURE_POD_IDENTITY_AUTHORITY_HOST", "http://169.254.169.001/" } }); - var options = new TokenCredentialOptions() { Retry = { MaxRetries = 0, NetworkTimeout = TimeSpan.FromMilliseconds(100) } }; + var options = new TokenCredentialOptions() { Retry = { MaxRetries = 0, NetworkTimeout = TimeSpan.FromMilliseconds(100), MaxDelay = TimeSpan.Zero } }; var credential = InstrumentClient(new ManagedIdentityCredential(options: options)); var ex = Assert.ThrowsAsync(async () => await credential.GetTokenAsync(new TokenRequestContext(MockScopes.Default))); - Assert.That(ex.Message, Does.Contain(ImdsManagedIdentitySource.NoResponseError)); + Assert.That(ex.Message, Does.Contain(ImdsManagedIdentitySource.AggregateError)); await Task.CompletedTask; }