Skip to content

Instantly share code, notes, and snippets.

@peerpalo
Created November 30, 2021 08:40
Show Gist options
  • Save peerpalo/b83598e2c98413a6f5af92fe4800985b to your computer and use it in GitHub Desktop.
Save peerpalo/b83598e2c98413a6f5af92fe4800985b to your computer and use it in GitHub Desktop.
/// <summary>
/// YARP pipeline middleware. For the cluster whose id equals
/// <c>CustomProxyConfigProvider.AGGREGATE_ID</c>, it sends the incoming GET request to every
/// available destination and answers with one JSON object keyed by destination id (a destination
/// that did not answer 200 OK maps to <c>null</c>). Requests for any other cluster are passed to
/// the next middleware unchanged.
/// </summary>
public class AggregationMiddleware
{
    /// <summary>Maximum time a single destination call may take before it is reported as timed out (504).</summary>
    private static readonly TimeSpan DefaultTimeout = TimeSpan.FromSeconds(100);

    /// <summary>Event id attached to every log entry written by this middleware.</summary>
    public static readonly EventId Aggregating = new EventId(1, "Aggregating");

    private readonly RequestDelegate _next;
    private readonly ILogger _logger;

    // NOTE(review): injected but not used by this class; kept so the DI constructor signature is unchanged.
    private readonly IHttpForwarder _forwarder;

    public AggregationMiddleware(RequestDelegate next, ILogger<AggregationMiddleware> logger, IHttpForwarder forwarder)
    {
        _next = next ?? throw new ArgumentNullException(nameof(next));
        _logger = logger ?? throw new ArgumentNullException(nameof(logger));
        _forwarder = forwarder ?? throw new ArgumentNullException(nameof(forwarder));
    }

    /// <summary>
    /// Aggregates the request across all available destinations when the matched cluster is the
    /// aggregation cluster; otherwise delegates to the next middleware.
    /// </summary>
    /// <param name="context">The current HTTP context; must already carry the YARP proxy feature.</param>
    /// <exception cref="ArgumentNullException"><paramref name="context"/> is null.</exception>
    /// <exception cref="InvalidOperationException">
    /// The proxy feature has no destinations collection, an aggregation request is not a GET,
    /// or a destination has no model.
    /// </exception>
    public async Task Invoke(HttpContext context)
    {
        ArgumentNullException.ThrowIfNull(context);

        var reverseProxyFeature = context.GetReverseProxyFeature();
        var destinations = reverseProxyFeature.AvailableDestinations
            ?? throw new InvalidOperationException($"The {nameof(IReverseProxyFeature)} Destinations collection was not set.");
        var route = context.GetRouteModel();
        var cluster = route.Cluster!;

        if (cluster.ClusterId != CustomProxyConfigProvider.AGGREGATE_ID)
        {
            await _next(context);
            return;
        }

        if (!HttpMethods.IsGet(context.Request.Method))
        {
            throw new InvalidOperationException("Aggregation is valid only for GET methods");
        }

        await AggregateAsync(context, reverseProxyFeature, destinations, route.Transformer);
    }

    /// <summary>
    /// Fans the request out to <paramref name="destinations"/> and writes the aggregated response.
    /// HttpContext is not thread-safe, so only the network sends run in parallel; building requests,
    /// applying response transforms and writing the response all happen sequentially on the request flow.
    /// </summary>
    private async Task AggregateAsync(HttpContext context, IReverseProxyFeature reverseProxyFeature,
        IEnumerable<DestinationState> destinations, HttpTransformer transformer)
    {
        var clusterModel = reverseProxyFeature.Cluster;
        var httpClient = clusterModel.HttpClient;
        var requestConfig = clusterModel.Config.HttpRequest ?? ForwarderRequestConfig.Empty;

        var calls = new List<DestinationCall>();
        try
        {
            // Phase 1 (sequential): build one outgoing request per destination.
            foreach (var destination in destinations)
            {
                var destinationModel = destination.Model
                    ?? throw new InvalidOperationException($"Chosen destination has no model set: '{destination.DestinationId}'");
                var destinationRequest = await CreateRequestMessageAsync(context, destinationModel.Config.Address, transformer, requestConfig);
                calls.Add(new DestinationCall(destination, destinationRequest));
            }

            // Phase 2 (parallel): network I/O only. Nothing in here touches HttpContext.
            await Parallel.ForEachAsync(calls, context.RequestAborted, async (call, cancellationToken) =>
            {
                using var timeoutCts = CancellationTokenSource.CreateLinkedTokenSource(cancellationToken);
                timeoutCts.CancelAfter(DefaultTimeout);
                try
                {
                    call.Response = await httpClient.SendAsync(call.Request, timeoutCts.Token);
                }
                catch (Exception requestException)
                {
                    // Recorded here and reported in phase 3, where touching HttpContext is safe.
                    call.Failure = requestException;
                }
            });

            // Phase 3 (sequential): run response transforms exactly once per destination.
            foreach (var call in calls)
            {
                if (call.Failure is not null)
                {
                    await HandleRequestFailureAsync(context, call.Failure, transformer);
                }
                else
                {
                    await transformer.TransformResponseAsync(context, call.Response);
                }
            }

            if (!calls.Any(call => call.IsOk))
            {
                context.Response.StatusCode = (int)HttpStatusCode.BadGateway;
                return;
            }

            // Built in destination order, so the JSON property order is deterministic.
            var body = new Dictionary<string, object>(calls.Count);
            foreach (var call in calls)
            {
                var result = call.IsOk
                    ? await call.Response.Content.ReadFromJsonAsync<object>(cancellationToken: context.RequestAborted)
                    : null;
                body.Add(call.Destination.DestinationId, result);
            }

            // Response transforms may have copied a destination's body headers onto our response.
            // Those headers describe that destination's payload, not the aggregated JSON written below.
            context.Response.ContentLength = null;
            context.Response.Headers.Remove(HeaderNames.ContentEncoding);
            context.Response.StatusCode = (int)HttpStatusCode.OK;
            await context.Response.WriteAsJsonAsync(body, context.RequestAborted);
        }
        finally
        {
            foreach (var call in calls)
            {
                call.Dispose();
            }
        }
    }

    /// <summary>
    /// Builds the GET request for one destination: applies the route's request transform, falls back
    /// to <c>destinationPrefix + path + query</c> when the transform did not set a URI, and picks
    /// the HTTP version from <paramref name="requestConfig"/>.
    /// </summary>
    private async Task<HttpRequestMessage> CreateRequestMessageAsync(HttpContext context, string destinationPrefix, HttpTransformer transformer, ForwarderRequestConfig requestConfig)
    {
        // Aggregation only ever issues GETs (enforced in Invoke).
        var destinationRequest = new HttpRequestMessage { Method = HttpMethod.Get };

        var upgradeFeature = context.Features.Get<IHttpUpgradeFeature>();
        var upgradeHeader = context.Request.Headers[HeaderNames.Upgrade].ToString();
        var isUpgradeRequest = (upgradeFeature?.IsUpgradableRequest ?? false)
            // Mitigate https://github.com/microsoft/reverse-proxy/issues/255, IIS considers all requests upgradeable.
            && (string.Equals("WebSocket", upgradeHeader, StringComparison.OrdinalIgnoreCase)
                // https://github.com/microsoft/reverse-proxy/issues/467 for kubernetes APIs
                || upgradeHeader.StartsWith("SPDY/", StringComparison.OrdinalIgnoreCase));

        await transformer.TransformRequestAsync(context, destinationRequest, destinationPrefix);

        // Allow someone to custom build the request uri, otherwise provide a default for them.
        var request = context.Request;
        destinationRequest.RequestUri ??= RequestUtilities.MakeDestinationAddress(destinationPrefix, request.Path, request.QueryString);

        destinationRequest.Version = isUpgradeRequest ? HttpVersion.Version11 : (requestConfig?.Version ?? HttpVersion.Version20);
        // No #if NET guard: this class already requires .NET 6 (Parallel.ForEachAsync), and the log line below reads VersionPolicy anyway.
        destinationRequest.VersionPolicy = isUpgradeRequest ? HttpVersionPolicy.RequestVersionOrLower : (requestConfig?.VersionPolicy ?? HttpVersionPolicy.RequestVersionOrLower);

        _logger.LogInformation(Aggregating, "Proxying to {targetUrl} {version} {versionPolicy}", destinationRequest.RequestUri!.AbsoluteUri, HttpProtocol.GetHttpProtocol(destinationRequest.Version), destinationRequest.VersionPolicy);
        return destinationRequest;
    }

    /// <summary>
    /// Classifies a failed destination send, logs it, sets the matching status code and runs the
    /// response transform with no destination response.
    /// </summary>
    /// <returns>The <see cref="ForwarderError"/> the failure was classified as.</returns>
    private async ValueTask<ForwarderError> HandleRequestFailureAsync(HttpContext context, Exception requestException, HttpTransformer transformer)
    {
        var (error, statusCode) = requestException switch
        {
            // The client aborted the request: nothing timed out on our side.
            OperationCanceledException when context.RequestAborted.IsCancellationRequested
                => (ForwarderError.RequestCanceled, StatusCodes.Status502BadGateway),
            // Any other cancellation comes from the per-call timeout (DefaultTimeout) or a handler-level timeout.
            OperationCanceledException
                => (ForwarderError.RequestTimedOut, StatusCodes.Status504GatewayTimeout),
            // We couldn't communicate with the destination.
            _ => (ForwarderError.Request, StatusCodes.Status502BadGateway),
        };

        ReportProxyError(context, error, requestException);
        context.Response.StatusCode = statusCode;
        await transformer.TransformResponseAsync(context, null);
        return error;
    }

    /// <summary>Logs a destination failure under the <see cref="Aggregating"/> event id.</summary>
    private void ReportProxyError(HttpContext context, ForwarderError error, Exception ex)
    {
        // NOTE(review): the original sketch also intended to set IForwarderErrorFeature on the context here.
        _logger.LogError(Aggregating, ex, "{ForwarderError}", error);
    }

    /// <summary>Per-destination state carried across the three phases of <see cref="AggregateAsync"/>; owns both HTTP messages.</summary>
    private sealed class DestinationCall : IDisposable
    {
        public DestinationCall(DestinationState destination, HttpRequestMessage request)
        {
            Destination = destination;
            Request = request;
        }

        public DestinationState Destination { get; }

        public HttpRequestMessage Request { get; }

        /// <summary>Set when the send completed; null if it threw.</summary>
        public HttpResponseMessage Response { get; set; }

        /// <summary>Set when the send threw; null otherwise.</summary>
        public Exception Failure { get; set; }

        public bool IsOk => Response?.StatusCode == HttpStatusCode.OK;

        public void Dispose()
        {
            Response?.Dispose();
            Request.Dispose();
        }
    }
}
@maxiptah
Copy link

maxiptah commented Feb 21, 2023

Hi @peerpalo ! Are you still using this middleware or did you switch to another solution? I'm facing the same problem as you and I'm thinking about using your middleware with small adaptations.
It looks like an easy way to bring aggregation into YARP, as opposed to creating an aggregation microservice between YARP and the downstream web services. Do you see any hidden problems with this approach, apart from those you already described in the GitHub issue?

@peerpalo
Copy link
Author

Hi @maxiptah, I think that this solution is still used by my old company. As you wrote, my concerns are exactly the same as stated in GitHub issue and, as far as I can see, there isn't yet a "good way" to do it with YARP directly inside a middleware.

Please, feel free to use this snippet if it's useful to you.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment