diff --git a/src/Libraries/Microsoft.Extensions.Http.Resilience/Polly/HttpRetryStrategyOptionsExtensions.cs b/src/Libraries/Microsoft.Extensions.Http.Resilience/Polly/HttpRetryStrategyOptionsExtensions.cs new file mode 100644 index 00000000000..85168988c7b --- /dev/null +++ b/src/Libraries/Microsoft.Extensions.Http.Resilience/Polly/HttpRetryStrategyOptionsExtensions.cs @@ -0,0 +1,66 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics.CodeAnalysis; +using System.Linq; +using System.Net.Http; +using System.Threading.Tasks; +using Microsoft.Shared.DiagnosticIds; +using Microsoft.Shared.Diagnostics; +using Polly; + +namespace Microsoft.Extensions.Http.Resilience; + +/// +/// Extensions for . +/// +[Experimental(diagnosticId: DiagnosticIds.Experiments.Resilience, UrlFormat = DiagnosticIds.UrlFormat)] +public static class HttpRetryStrategyOptionsExtensions +{ +#if !NET8_0_OR_GREATER + private static readonly HttpMethod _connect = new("CONNECT"); + private static readonly HttpMethod _patch = new("PATCH"); +#endif + + /// + /// Disables retry attempts for POST, PATCH, PUT, DELETE, and CONNECT HTTP methods. + /// + /// The retry strategy options. + public static void DisableForUnsafeHttpMethods(this HttpRetryStrategyOptions options) + { + options.DisableFor( + HttpMethod.Delete, HttpMethod.Post, HttpMethod.Put, +#if !NET8_0_OR_GREATER + _connect, _patch); +#else + HttpMethod.Connect, HttpMethod.Patch); +#endif + } + + /// + /// Disables retry attempts for the given list of HTTP methods. + /// + /// The retry strategy options. + /// The list of HTTP methods. + public static void DisableFor(this HttpRetryStrategyOptions options, params HttpMethod[] methods) + { + _ = Throw.IfNullOrEmpty(methods); + + var shouldHandle = Throw.IfNullOrMemberNull(options, options?.ShouldHandle); + + options.ShouldHandle = async args => + { + var result = await shouldHandle(args).ConfigureAwait(args.Context.ContinueOnCapturedContext); + + if (result && + args.Outcome.Result is HttpResponseMessage response && + response.RequestMessage is HttpRequestMessage request) + { + return !methods.Contains(request.Method); + } + + return result; + }; + } +} + diff --git a/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Polly/HttpRetryStrategyOptionsExtensionsTests.cs b/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Polly/HttpRetryStrategyOptionsExtensionsTests.cs new file mode 100644 index 00000000000..4d43c020d1f --- /dev/null +++ b/test/Libraries/Microsoft.Extensions.Http.Resilience.Tests/Polly/HttpRetryStrategyOptionsExtensionsTests.cs @@ -0,0 +1,115 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; +using System.Net.Http; +using System.Threading.Tasks; +using Polly; +using Polly.Retry; +using Xunit; + +namespace Microsoft.Extensions.Http.Resilience.Test.Polly; + +public class HttpRetryStrategyOptionsExtensionsTests +{ + [Fact] + public void DisableFor_RetryOptionsIsNull_Throws() + { + Assert.Throws(() => ((HttpRetryStrategyOptions)null!).DisableFor(HttpMethod.Get)); + } + + [Fact] + public void DisableFor_HttpMethodsIsNull_Throws() + { + Assert.Throws(() => new HttpRetryStrategyOptions().DisableFor(null!)); + } + + [Fact] + public void DisableFor_HttpMethodsIsEmptry_Throws() + { + Assert.Throws(() => new HttpRetryStrategyOptions().DisableFor([])); + } + + [Fact] + public void DisableFor_ShouldHandleIsNull_Throws() + { + var options = new HttpRetryStrategyOptions { ShouldHandle = null! }; + Assert.Throws(() => options.DisableFor(HttpMethod.Get)); + } + + [Theory] + [InlineData("POST", false)] + [InlineData("DELETE", false)] + [InlineData("GET", true)] + public async Task DisableFor_PositiveScenario(string httpMethod, bool shouldHandle) + { + var options = new HttpRetryStrategyOptions { ShouldHandle = _ => PredicateResult.True() }; + options.DisableFor(HttpMethod.Post, HttpMethod.Delete); + + using var request = new HttpRequestMessage { Method = new HttpMethod(httpMethod) }; + using var response = new HttpResponseMessage { RequestMessage = request }; + + Assert.Equal(shouldHandle, await options.ShouldHandle(CreatePredicateArguments(response))); + } + + [Fact] + public async Task DisableFor_RespectsOriginalShouldHandlePredicate() + { + var options = new HttpRetryStrategyOptions { ShouldHandle = _ => PredicateResult.False() }; + options.DisableFor(HttpMethod.Post); + + using var request = new HttpRequestMessage { Method = HttpMethod.Get }; + using var response = new HttpResponseMessage { RequestMessage = request }; + + Assert.False(await options.ShouldHandle(CreatePredicateArguments(response))); + } + + [Fact] + public async Task DisableFor_ResponseMessageIsNull_DoesNotDisableRetries() + { + var options = new HttpRetryStrategyOptions { ShouldHandle = _ => PredicateResult.True() }; + options.DisableFor(HttpMethod.Post); + + Assert.True(await options.ShouldHandle(CreatePredicateArguments(null))); + } + + [Fact] + public async Task DisableFor_RequestMessageIsNull_DoesNotDisableRetries() + { + var options = new HttpRetryStrategyOptions { ShouldHandle = _ => PredicateResult.True() }; + options.DisableFor(HttpMethod.Post); + + using var response = new HttpResponseMessage { RequestMessage = null }; + + Assert.True(await options.ShouldHandle(CreatePredicateArguments(response))); + } + + [Theory] + [InlineData("POST", false)] + [InlineData("DELETE", false)] + [InlineData("PUT", false)] + [InlineData("PATCH", false)] + [InlineData("CONNECT", false)] + [InlineData("GET", true)] + [InlineData("HEAD", true)] + [InlineData("TRACE", true)] + [InlineData("OPTIONS", true)] + public async Task DisableForUnsafeHttpMethods_PositiveScenario(string httpMethod, bool shouldHandle) + { + var options = new HttpRetryStrategyOptions { ShouldHandle = _ => PredicateResult.True() }; + options.DisableForUnsafeHttpMethods(); + + using var request = new HttpRequestMessage { Method = new HttpMethod(httpMethod) }; + using var response = new HttpResponseMessage { RequestMessage = request }; + + Assert.Equal(shouldHandle, await options.ShouldHandle(CreatePredicateArguments(response))); + } + + private static RetryPredicateArguments CreatePredicateArguments(HttpResponseMessage? response) + { + return new RetryPredicateArguments( + ResilienceContextPool.Shared.Get(), + Outcome.FromResult(response), + attemptNumber: 1); + } +}