Skip to content

Commit

Permalink
Enable request transforms to reject requests #1701 (#1923)
Browse files Browse the repository at this point in the history
  • Loading branch information
Tratcher authored Nov 14, 2022
1 parent 74bef16 commit c39fcf8
Show file tree
Hide file tree
Showing 11 changed files with 232 additions and 10 deletions.
2 changes: 2 additions & 0 deletions docs/docfx/articles/transforms.md
Original file line number Diff line number Diff line change
Expand Up @@ -959,6 +959,8 @@ Only header1 and header2 are copied from the proxy response.

All request transforms must derive from the abstract base class [RequestTransform](xref:Yarp.ReverseProxy.Transforms.RequestTransform). These can freely modify the proxy `HttpRequestMessage`. Avoid reading or modifying the request body as this may disrupt the proxying flow. Consider also adding a parametrized extension method on `TransformBuilderContext` for discoverability and easy of use.

A request transform may conditionally produce an immediate response such as for error conditions. This prevents any remaining transforms from running and the request from being proxied. This is indicated by setting the `HttpResponse.StatusCode` to a value other than 200, or calling `HttpResponse.StartAsync()`, or writing to the `HttpResponse.Body` or `BodyWriter`.

### ResponseTransform

All response transforms must derive from the abstract base class [ResponseTransform](xref:Yarp.ReverseProxy.Transforms.ResponseTransform). These can freely modify the client `HttpResponse`. Avoid reading or modifying the response body as this may disrupt the proxying flow. Consider also adding a parametrized extension method on `TransformBuilderContext` for discoverability and easy of use.
Expand Down
34 changes: 33 additions & 1 deletion samples/ReverseProxy.Auth.Sample/Startup.cs
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System.Net.Http.Headers;
using Microsoft.AspNetCore.Authentication;
using Microsoft.AspNetCore.Authentication.Cookies;
using Microsoft.AspNetCore.Builder;
using Microsoft.Extensions.Configuration;
using Microsoft.Extensions.DependencyInjection;
using Yarp.ReverseProxy.Transforms;

namespace Yarp.Sample
{
Expand All @@ -31,8 +34,37 @@ public void ConfigureServices(IServiceCollection services)
// Required to supply the authentication UI in Views/*
services.AddRazorPages();

services.AddSingleton<TokenService>();

services.AddReverseProxy()
.LoadFromConfig(_configuration.GetSection("ReverseProxy"));
.LoadFromConfig(_configuration.GetSection("ReverseProxy"))
.AddTransforms(transformBuilderContext => // Add transforms inline
{
// For each route+cluster pair decide if we want to add transforms, and if so, which?
// This logic is re-run each time a route is rebuilt.
// Only do this for routes that require auth.
if (string.Equals("myPolicy", transformBuilderContext.Route.AuthorizationPolicy))
{
transformBuilderContext.AddRequestTransform(async transformContext =>
{
// AuthN and AuthZ will have already been completed after request routing.
var ticket = await transformContext.HttpContext.AuthenticateAsync(CookieAuthenticationDefaults.AuthenticationScheme);
var tokenService = transformContext.HttpContext.RequestServices.GetRequiredService<TokenService>();
var token = await tokenService.GetAuthTokenAsync(ticket.Principal);
// Reject invalid requests
if (string.IsNullOrEmpty(token))
{
var response = transformContext.HttpContext.Response;
response.StatusCode = 401;
return;
}
transformContext.ProxyRequest.Headers.Authorization = new AuthenticationHeaderValue("Bearer", token);
});
}
}); ;

services.AddAuthentication(CookieAuthenticationDefaults.AuthenticationScheme)
.AddCookie();
Expand Down
22 changes: 22 additions & 0 deletions samples/ReverseProxy.Auth.Sample/TokenService.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT License.

using System;
using System.Security.Claims;
using System.Threading.Tasks;

namespace Yarp.Sample
{
internal class TokenService
{
internal Task<string> GetAuthTokenAsync(ClaimsPrincipal user)
{
// we only have tokens for bob
if (string.Equals("Bob", user.Identity.Name))
{
return Task.FromResult(Guid.NewGuid().ToString());
}
return Task.FromResult<string>(null);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

<form action="Login" method="post">
<input hidden name="returnurl" type="text" value="@ViewData["ReturnUrl"]" /><br />
<input name="Name" type="text" value="My Name" /><br />
<input name="Name" type="text" value="Bob" /><br />
<input name="myClaimValue" type="text" value="green" /><br />
<input type="submit">
<div><b>Note:</b>The authorization policy will check for the value of "green", other values should pass authentication, but not authorize for specific routes</div>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,7 @@ public void ValidateRoute(TransformRouteValidationContext context)
public void ValidateCluster(TransformClusterValidationContext context)
{
// Check all clusters for a custom property and validate the associated transform data.
string value = null;
if (context.Cluster.Metadata?.TryGetValue("CustomMetadata", out value) ?? false)
if (context.Cluster.Metadata?.TryGetValue("CustomMetadata", out var value) ?? false)
{
if (string.IsNullOrEmpty(value))
{
Expand Down
32 changes: 30 additions & 2 deletions src/ReverseProxy/Forwarder/HttpForwarder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,11 @@ public async ValueTask<ForwarderError> SendAsync(
_ = requestConfig ?? throw new ArgumentNullException(nameof(requestConfig));
_ = transformer ?? throw new ArgumentNullException(nameof(transformer));

if (RequestUtilities.IsResponseSet(context.Response))
{
throw new InvalidOperationException("The request cannot be forwarded, the response has already started");
}

// HttpClient overload for SendAsync changes response behavior to fully buffered which impacts performance
// See discussion in https://github.com/microsoft/reverse-proxy/issues/458
if (httpClient is HttpClient)
Expand All @@ -116,6 +121,15 @@ public async ValueTask<ForwarderError> SendAsync(
var (destinationRequest, requestContent) = await CreateRequestMessageAsync(
context, destinationPrefix, transformer, requestConfig, isStreamingRequest, activityCancellationSource);

// Transforms generated a response, do not proxy.
if (RequestUtilities.IsResponseSet(context.Response))
{
Log.NotProxying(_logger, context.Response.StatusCode);
return ForwarderError.None;
}

Log.Proxying(_logger, destinationRequest, isStreamingRequest);

// :: Step 4: Send the outgoing request using HttpClient
HttpResponseMessage destinationResponse;
try
Expand Down Expand Up @@ -282,6 +296,12 @@ public async ValueTask<ForwarderError> SendAsync(
// :: Step 3: Copy request headers Client --► Proxy --► Destination
await transformer.TransformRequestAsync(context, destinationRequest, destinationPrefix);

// The transformer generated a response, do not forward.
if (RequestUtilities.IsResponseSet(context.Response))
{
return (destinationRequest, requestContent);
}

if (isUpgradeRequest)
{
RestoreUpgradeHeaders(context, destinationRequest);
Expand All @@ -291,8 +311,6 @@ public async ValueTask<ForwarderError> SendAsync(
var request = context.Request;
destinationRequest.RequestUri ??= RequestUtilities.MakeDestinationAddress(destinationPrefix, request.Path, request.QueryString);

Log.Proxying(_logger, destinationRequest, isStreamingRequest);

if (requestConfig?.AllowResponseBuffering != true)
{
context.Features.Get<IHttpResponseBodyFeature>()?.DisableBuffering();
Expand Down Expand Up @@ -765,6 +783,11 @@ private static class Log
EventIds.ForwardingError,
"{error}: {message}");

private static readonly Action<ILogger, int, Exception?> _notProxying = LoggerMessage.Define<int>(
LogLevel.Information,
EventIds.NotForwarding,
"Not Proxying, a {statusCode} response was set by the transforms.");

public static void ResponseReceived(ILogger logger, HttpResponseMessage msg)
{
_responseReceived(logger, msg.Version, (int)msg.StatusCode, null);
Expand All @@ -782,6 +805,11 @@ public static void Proxying(ILogger logger, HttpRequestMessage msg, bool isStrea
}
}

public static void NotProxying(ILogger logger, int statusCode)
{
_notProxying(logger, statusCode, null);
}

public static void ErrorProxying(ILogger logger, ForwarderError error, Exception ex)
{
_proxyError(logger, error, GetMessage(error), ex);
Expand Down
3 changes: 3 additions & 0 deletions src/ReverseProxy/Forwarder/HttpTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ private static bool IsBodylessStatusCode(HttpStatusCode statusCode) =>
/// See <see cref="RequestUtilities.MakeDestinationAddress(string, PathString, QueryString)"/> for constructing a custom request Uri.
/// The string parameter represents the destination URI prefix that should be used when constructing the RequestUri.
/// The headers are copied by the base implementation, excluding some protocol headers like HTTP/2 pseudo headers (":authority").
/// This method may be overridden to conditionally produce a response, such as for error conditions, and prevent the request from
/// being proxied. This is indicated by setting the `HttpResponse.StatusCode` to a value other than 200, or calling `HttpResponse.StartAsync()`,
/// or writing to the `HttpResponse.Body` or `BodyWriter`.
/// </summary>
/// <param name="httpContext">The incoming request.</param>
/// <param name="proxyRequest">The outgoing proxy request.</param>
Expand Down
6 changes: 6 additions & 0 deletions src/ReverseProxy/Forwarder/RequestUtilities.cs
Original file line number Diff line number Diff line change
Expand Up @@ -404,4 +404,10 @@ static StringValues ToArray(in HeaderStringValues values)
values = default;
return false;
}

internal static bool IsResponseSet(HttpResponse response)
{
return response.StatusCode != StatusCodes.Status200OK
|| response.HasStarted;
}
}
6 changes: 6 additions & 0 deletions src/ReverseProxy/Transforms/Builder/StructuredTransformer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,12 @@ public override async ValueTask TransformRequestAsync(HttpContext httpContext, H
foreach (var requestTransform in RequestTransforms)
{
await requestTransform.ApplyAsync(transformContext);

// The transform generated a response, do not apply further transforms and do not forward.
if (RequestUtilities.IsResponseSet(httpContext.Response))
{
return;
}
}

// Allow a transform to directly set a custom RequestUri.
Expand Down
1 change: 1 addition & 0 deletions src/ReverseProxy/Utilities/EventIds.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,5 @@ internal static class EventIds
public static readonly EventId ResponseReceived = new EventId(56, "ResponseReceived");
public static readonly EventId DelegationQueueReset = new EventId(57, "DelegationQueueReset");
public static readonly EventId Http10RequestVersionDetected = new EventId(58, "Http10RequestVersionDetected");
public static readonly EventId NotForwarding = new EventId(59, "NotForwarding");
}
131 changes: 127 additions & 4 deletions test/ReverseProxy.Tests/Forwarder/HttpForwarderTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -358,6 +358,125 @@ public async Task TransformRequestAsync_ReplaceBody()
events.AssertContainProxyStages();
}

[Fact]
public async Task TransformRequestAsync_SetsStatus_ShortCircuits()
{
var events = TestEventListener.Collect();

var httpContext = new DefaultHttpContext();
httpContext.Request.Method = "POST";
httpContext.Request.Protocol = "HTTP/2";

var destinationPrefix = "https://localhost/";

var transforms = new DelegateHttpTransforms()
{
CopyRequestHeaders = true,
OnRequest = (context, request, destination) =>
{
context.Response.StatusCode = 401;
return Task.CompletedTask;
}
};

var sut = CreateProxy();
var client = MockHttpHandler.CreateClient(
(HttpRequestMessage request, CancellationToken cancellationToken) =>
{
throw new NotImplementedException();
});

var proxyError = await sut.SendAsync(httpContext, destinationPrefix, client, ForwarderRequestConfig.Empty, transforms);

Assert.Equal(ForwarderError.None, proxyError);
Assert.Equal(StatusCodes.Status401Unauthorized, httpContext.Response.StatusCode);

AssertProxyStartStop(events, destinationPrefix, httpContext.Response.StatusCode);
events.AssertContainProxyStages(new ForwarderStage[0]);
}

[Fact]
public async Task TransformRequestAsync_StartsResponse_ShortCircuits()
{
var events = TestEventListener.Collect();

var httpContext = new DefaultHttpContext();
var responseBody = new TestResponseBody();
httpContext.Features.Set<IHttpResponseFeature>(responseBody);
httpContext.Features.Set<IHttpResponseBodyFeature>(responseBody);
httpContext.Request.Method = "POST";
httpContext.Request.Protocol = "HTTP/2";

var destinationPrefix = "https://localhost/";

var transforms = new DelegateHttpTransforms()
{
CopyRequestHeaders = true,
OnRequest = (context, request, destination) =>
{
return context.Response.StartAsync();
}
};

var sut = CreateProxy();
var client = MockHttpHandler.CreateClient(
(HttpRequestMessage request, CancellationToken cancellationToken) =>
{
throw new NotImplementedException();
});

var proxyError = await sut.SendAsync(httpContext, destinationPrefix, client, ForwarderRequestConfig.Empty, transforms);

Assert.Equal(ForwarderError.None, proxyError);
Assert.Equal(StatusCodes.Status200OK, httpContext.Response.StatusCode);
Assert.True(httpContext.Response.HasStarted);

AssertProxyStartStop(events, destinationPrefix, httpContext.Response.StatusCode);
events.AssertContainProxyStages(new ForwarderStage[0]);
}

[Fact]
public async Task TransformRequestAsync_WritesToResponse_ShortCircuits()
{
var events = TestEventListener.Collect();

var httpContext = new DefaultHttpContext();
var resultStream = new MemoryStream();
var responseBody = new TestResponseBody(resultStream);
httpContext.Features.Set<IHttpResponseFeature>(responseBody);
httpContext.Features.Set<IHttpResponseBodyFeature>(responseBody);
httpContext.Request.Method = "POST";
httpContext.Request.Protocol = "HTTP/2";

var destinationPrefix = "https://localhost/";

var transforms = new DelegateHttpTransforms()
{
CopyRequestHeaders = true,
OnRequest = (context, request, destination) =>
{
return context.Response.Body.WriteAsync(Encoding.UTF8.GetBytes("Hello World")).AsTask();
}
};

var sut = CreateProxy();
var client = MockHttpHandler.CreateClient(
(HttpRequestMessage request, CancellationToken cancellationToken) =>
{
throw new NotImplementedException();
});

var proxyError = await sut.SendAsync(httpContext, destinationPrefix, client, ForwarderRequestConfig.Empty, transforms);

Assert.Equal(ForwarderError.None, proxyError);
Assert.Equal(StatusCodes.Status200OK, httpContext.Response.StatusCode);
Assert.True(httpContext.Response.HasStarted);
Assert.Equal("Hello World", Encoding.UTF8.GetString(resultStream.ToArray()));

AssertProxyStartStop(events, destinationPrefix, httpContext.Response.StatusCode);
events.AssertContainProxyStages(new ForwarderStage[0]);
}

// Tests proxying an upgradeable request.
[Theory]
[InlineData("WebSocket")]
Expand Down Expand Up @@ -1887,11 +2006,10 @@ public async Task ResponseBodyCancelledAfterStart_Aborted()
var httpContext = new DefaultHttpContext();
httpContext.Request.Method = "GET";
httpContext.Request.Host = new HostString("example.com:3456");
var responseBody = new TestResponseBody() { HasStarted = true };
var responseBody = new TestResponseBody();
httpContext.Features.Set<IHttpResponseFeature>(responseBody);
httpContext.Features.Set<IHttpResponseBodyFeature>(responseBody);
httpContext.Features.Set<IHttpRequestLifetimeFeature>(responseBody);
httpContext.RequestAborted = new CancellationToken(canceled: true);

var destinationPrefix = "https://localhost:123/";
var sut = CreateProxy();
Expand All @@ -1900,7 +2018,11 @@ public async Task ResponseBodyCancelledAfterStart_Aborted()
{
var message = new HttpResponseMessage()
{
Content = new StreamContent(new MemoryStream(new byte[1]))
Content = new StreamContent(new CallbackReadStream((_, _) =>
{
responseBody.HasStarted = true;
throw new TaskCanceledException();
}))
};
message.Headers.AcceptRanges.Add("bytes");
return Task.FromResult(message);
Expand Down Expand Up @@ -2828,7 +2950,8 @@ public Task SendFileAsync(string path, long offset, long? count, CancellationTok

public Task StartAsync(CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
OnStart();
return Task.CompletedTask;
}

public override ValueTask WriteAsync(ReadOnlyMemory<byte> buffer, CancellationToken cancellationToken = default)
Expand Down

0 comments on commit c39fcf8

Please sign in to comment.