Skip to content

Instantly share code, notes, and snippets.

@peerpalo
Created November 30, 2021 08:40
Show Gist options
  • Save peerpalo/b83598e2c98413a6f5af92fe4800985b to your computer and use it in GitHub Desktop.
Save peerpalo/b83598e2c98413a6f5af92fe4800985b to your computer and use it in GitHub Desktop.
/// <summary>
/// Reverse-proxy pipeline middleware that, for the cluster identified by
/// <c>CustomProxyConfigProvider.AGGREGATE_ID</c>, sends the incoming GET request to every
/// available destination in parallel and answers with a single JSON object keyed by
/// destination id. Requests routed to any other cluster are handed to the next middleware.
/// </summary>
public class AggregationMiddleware
{
    // Upper bound on how long a single destination may take to answer with response headers.
    private static readonly TimeSpan DefaultTimeout = TimeSpan.FromSeconds(100);

    /// <summary>Event id attached to every log entry written by this middleware.</summary>
    public static readonly EventId Aggregating = new EventId(1, "Aggregating");

    private readonly RequestDelegate _next;
    private readonly ILogger _logger;
    private readonly IHttpForwarder _forwarder;

    /// <summary>Creates the middleware.</summary>
    /// <exception cref="ArgumentNullException">Any argument is <c>null</c>.</exception>
    public AggregationMiddleware(RequestDelegate next, ILogger<AggregationMiddleware> logger, IHttpForwarder forwarder)
    {
        _next = next ?? throw new ArgumentNullException(nameof(next));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
        _forwarder = forwarder ?? throw new ArgumentNullException(nameof(forwarder));
    }

    /// <summary>
    /// Handles one request: aggregates when the route targets the aggregation cluster,
    /// otherwise invokes the next middleware.
    /// </summary>
    /// <param name="context">The current request context.</param>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> is <c>null</c>.</exception>
    /// <exception cref="InvalidOperationException">
    /// The available destinations are not set, the aggregated request is not a GET,
    /// or a destination has no model.
    /// </exception>
    public async Task Invoke(HttpContext context)
    {
        _ = context ?? throw new ArgumentNullException(nameof(context));

        var reverseProxyFeature = context.GetReverseProxyFeature();
        var destinations = reverseProxyFeature.AvailableDestinations
            ?? throw new InvalidOperationException($"The {nameof(IReverseProxyFeature)} Destinations collection was not set.");
        var route = context.GetRouteModel();
        var cluster = route.Cluster!;

        // Anything that is not the aggregation cluster continues down the normal pipeline.
        if (cluster.ClusterId != CustomProxyConfigProvider.AGGREGATE_ID)
        {
            await _next(context);
            return;
        }

        if (!HttpMethods.IsGet(context.Request.Method))
        {
            throw new InvalidOperationException("Aggregation is valid only for GET methods");
        }

        // Identical for every destination, so resolved once outside the parallel loop.
        var transformer = route.Transformer;
        var clusterModel = reverseProxyFeature.Cluster;
        var httpClient = clusterModel.HttpClient;
        var requestConfig = clusterModel.Config.HttpRequest ?? ForwarderRequestConfig.Empty;

        var destinationResponses = new ConcurrentBag<(DestinationState Destination, HttpResponseMessage Response)>();
        try
        {
            await Parallel.ForEachAsync(destinations, context.RequestAborted, async (destination, cancellationToken) =>
            {
                var destinationModel = destination.Model;
                if (destinationModel is null)
                {
                    throw new InvalidOperationException($"Chosen destination has no model set: '{destination.DestinationId}'");
                }

                // NOTE(review): request creation still reads the shared HttpContext from parallel
                // workers; confirm the configured transformer only reads request state.
                using var destinationRequest = await CreateRequestMessageAsync(
                    context, destinationModel.Config.Address, transformer, requestConfig);

                // Cancelled when the loop is cancelled (client abort or a sibling failure) or
                // when this destination exceeds the timeout.
                using var timeoutSource = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
                timeoutSource.CancelAfter(DefaultTimeout);

                HttpResponseMessage destinationResponse = null;
                try
                {
                    destinationResponse = await httpClient.SendAsync(destinationRequest, timeoutSource.Token);
                }
                catch (Exception requestException)
                {
                    // A failed destination contributes a null entry; the shared response is
                    // deliberately not touched from inside the parallel loop.
                    ReportRequestFailure(requestException, cancellationToken, timeoutSource.Token);
                }

                destinationResponses.Add((destination, destinationResponse));
            });

            await WriteAggregateAsync(context, transformer, destinationResponses.ToArray());
        }
        finally
        {
            // Release every destination response (and its connection), including the ones
            // whose body was never read.
            foreach (var (_, response) in destinationResponses)
            {
                response?.Dispose();
            }
        }
    }

    /// <summary>
    /// Applies the response transforms and writes the aggregated result: 200 with a JSON object
    /// when at least one destination answered 200 OK, otherwise 502 with no body.
    /// </summary>
    private static async Task WriteAggregateAsync(
        HttpContext context,
        HttpTransformer transformer,
        IReadOnlyCollection<(DestinationState Destination, HttpResponseMessage Response)> responses)
    {
        // Run sequentially: the transforms receive the shared HttpContext, which must not be
        // mutated concurrently. Failed destinations are passed as null, once each.
        foreach (var (_, response) in responses)
        {
            await transformer.TransformResponseAsync(context, response);
        }

        // The body written below is the aggregate (or empty), so a Content-Length that a
        // transform may have copied from a destination response would not match it.
        context.Response.ContentLength = null;

        if (!responses.Any(r => r.Response?.StatusCode == HttpStatusCode.OK))
        {
            context.Response.StatusCode = (int)HttpStatusCode.BadGateway;
            return;
        }

        var body = new Dictionary<string, dynamic>();
        foreach (var (destination, response) in responses)
        {
            // Destinations that failed or did not answer 200 OK are reported as null.
            dynamic result = null;
            if (response?.StatusCode == HttpStatusCode.OK)
            {
                result = await response.Content.ReadFromJsonAsync<dynamic>(context.RequestAborted);
            }

            body.Add(destination.DestinationId, result);
        }

        context.Response.StatusCode = (int)HttpStatusCode.OK;
        await context.Response.WriteAsJsonAsync(body, context.RequestAborted);
    }

    /// <summary>
    /// Builds the outgoing GET request for one destination: runs the request transforms,
    /// falls back to the default destination URI when no transform set one, and selects the
    /// HTTP version and version policy.
    /// </summary>
    private async Task<HttpRequestMessage> CreateRequestMessageAsync(HttpContext context, string destinationPrefix, HttpTransformer transformer, ForwarderRequestConfig requestConfig)
    {
        var destinationRequest = new HttpRequestMessage { Method = HttpMethod.Get };

        var upgradeFeature = context.Features.Get<IHttpUpgradeFeature>();
        var upgradeHeader = context.Request.Headers[HeaderNames.Upgrade].ToString();
        var isUpgradeRequest = (upgradeFeature?.IsUpgradableRequest ?? false)
            // Mitigate https://github.com/microsoft/reverse-proxy/issues/255, IIS considers all requests upgradeable.
            && (string.Equals("WebSocket", upgradeHeader, StringComparison.OrdinalIgnoreCase)
                // https://github.com/microsoft/reverse-proxy/issues/467 for kubernetes APIs
                || upgradeHeader.StartsWith("SPDY/", StringComparison.OrdinalIgnoreCase));

        await transformer.TransformRequestAsync(context, destinationRequest, destinationPrefix);

        // Allow someone to custom build the request uri, otherwise provide a default for them.
        var request = context.Request;
        destinationRequest.RequestUri ??= RequestUtilities.MakeDestinationAddress(destinationPrefix, request.Path, request.QueryString);

        // Upgrade requests are pinned to HTTP/1.1; everything else follows the cluster's request config.
        destinationRequest.Version = isUpgradeRequest
            ? HttpVersion.Version11
            : (requestConfig?.Version ?? HttpVersion.Version20);
        destinationRequest.VersionPolicy = isUpgradeRequest
            ? HttpVersionPolicy.RequestVersionOrLower
            : (requestConfig?.VersionPolicy ?? HttpVersionPolicy.RequestVersionOrLower);

        _logger.LogInformation(Aggregating, "Proxying to {targetUrl} {version} {versionPolicy}", destinationRequest.RequestUri!.AbsoluteUri, HttpProtocol.GetHttpProtocol(destinationRequest.Version), destinationRequest.VersionPolicy);
        return destinationRequest;
    }

    /// <summary>
    /// Classifies a failed destination call and logs it. It does not touch the HTTP response
    /// and does not set an error feature on the context; the failure is only logged.
    /// </summary>
    /// <param name="requestException">The exception thrown while sending the request.</param>
    /// <param name="loopCancellation">Token of the parallel loop iteration.</param>
    /// <param name="requestCancellation">Token passed to the send call (loop token plus timeout).</param>
    /// <returns>The error category that was logged.</returns>
    private ForwarderError ReportRequestFailure(Exception requestException, CancellationToken loopCancellation, CancellationToken requestCancellation)
    {
        ForwarderError error;
        if (requestException is OperationCanceledException)
        {
            // Only the timeout can cancel the request token while the loop token is still live.
            var timedOut = !loopCancellation.IsCancellationRequested && requestCancellation.IsCancellationRequested;
            error = timedOut ? ForwarderError.RequestTimedOut : ForwarderError.RequestCanceled;
        }
        else
        {
            // We couldn't communicate with the destination.
            error = ForwarderError.Request;
        }

        _logger.LogError(Aggregating, requestException, "{Error}", error);
        return error;
    }
}
@peerpalo
Copy link
Author

Hi @maxiptah, I think that this solution is still used by my old company. As you wrote, my concerns are exactly the same as those stated in the GitHub issue and, as far as I can see, there isn't yet a "good way" to do it with YARP directly inside a middleware.

Please, feel free to use this snippet if it's useful to you.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment