Skip to content

Instantly share code, notes, and snippets.

@joelverhagen
Last active January 16, 2024 05:04
Show Gist options
Save joelverhagen/3be85bc0d5733756befa to your computer and use it in GitHub Desktop.
RedirectingHandler
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
namespace Knapcode.Http.Handlers
{
    /// <summary>
    /// A delegating handler that handles HTTP redirects (301, 302, 303, 307, and 308).
    /// </summary>
    public class RedirectingHandler : DelegatingHandler
    {
        /// <summary>
        /// The property key used to access the list of responses.
        /// </summary>
        public const string HistoryPropertyKey = "Knapcode.Http.Handlers.RedirectingHandler.ResponseHistory";

        /// <summary>
        /// Status codes that are followed when the response carries exactly one non-blank Location header.
        /// </summary>
        private static readonly ISet<HttpStatusCode> RedirectStatusCodes = new HashSet<HttpStatusCode>(new[]
        {
            HttpStatusCode.MovedPermanently,
            HttpStatusCode.Found,
            HttpStatusCode.SeeOther,
            HttpStatusCode.TemporaryRedirect,
            (HttpStatusCode) 308
        });

        /// <summary>
        /// Status codes for which the request body is re-sent to the new location. Every other redirect drops it.
        /// (308 is written as a cast, presumably because the targeted framework's enum lacks a named member.)
        /// </summary>
        private static readonly ISet<HttpStatusCode> KeepRequestBodyRedirectStatusCodes = new HashSet<HttpStatusCode>(new[]
        {
            HttpStatusCode.TemporaryRedirect,
            (HttpStatusCode) 308
        });

        /// <summary>
        /// Initializes a new instance of the <see cref="RedirectingHandler"/> class.
        /// </summary>
        public RedirectingHandler()
        {
            AllowAutoRedirect = true;
            MaxAutomaticRedirections = 50;
            DisableInnerAutoRedirect = true;
            DownloadContentOnRedirect = true;
            KeepResponseHistory = true;
        }

        /// <summary>
        /// Gets or sets a value that indicates whether the handler should follow redirection responses.
        /// </summary>
        public bool AllowAutoRedirect { get; set; }

        /// <summary>
        /// Gets or sets the maximum number of redirects that the handler follows.
        /// </summary>
        public int MaxAutomaticRedirections { get; set; }

        /// <summary>
        /// Gets or sets a value indicating whether the response body should be downloaded before each redirection.
        /// </summary>
        public bool DownloadContentOnRedirect { get; set; }

        /// <summary>
        /// Gets or sets a value indicating whether redirections performed by inner <see cref="HttpClientHandler"/> and
        /// <see cref="RedirectingHandler"/> instances should be disabled, so that only this handler follows redirects.
        /// </summary>
        public bool DisableInnerAutoRedirect { get; set; }

        /// <summary>
        /// Gets or sets a value indicating whether the response history (every response received, including the final
        /// one) should be saved to the <see cref="HttpResponseMessage.RequestMessage"/> properties with the key of
        /// <see cref="HistoryPropertyKey"/>.
        /// </summary>
        public bool KeepResponseHistory { get; set; }

        /// <summary>
        /// Sends <paramref name="request"/>. While <see cref="AllowAutoRedirect"/> is enabled, it follows up to
        /// <see cref="MaxAutomaticRedirections"/> redirect responses.
        /// </summary>
        /// <param name="request">The initial request.</param>
        /// <param name="cancellationToken">Token passed to every inner send.</param>
        /// <returns>
        /// The last response received. That is a non-redirect response, or a redirect that was not followed:
        /// redirects disabled, limit reached, missing/ambiguous Location, or a Location that is not a valid URI.
        /// </returns>
        protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
        {
            if (DisableInnerAutoRedirect)
            {
                DisableInnerRedirects();
            }

            // buffer the request body, to allow re-use in redirects
            HttpContent requestBody = null;
            if (AllowAutoRedirect && request.Content != null)
            {
                requestBody = await BufferContentAsync(request.Content).ConfigureAwait(false);
            }

            // make a copy of the request headers; every follow-up request receives them
            KeyValuePair<string, string[]>[] requestHeaders = request
                .Headers
                .Select(p => new KeyValuePair<string, string[]>(p.Key, p.Value.ToArray()))
                .ToArray();

            // send the initial request
            HttpResponseMessage response = await base.SendAsync(request, cancellationToken).ConfigureAwait(false);

            // Prefer the request message reported by the response (an inner handler may have replaced it), but fall
            // back to the one we sent so a response without RequestMessage does not cause a NullReferenceException.
            HttpRequestMessage currentRequest = response.RequestMessage ?? request;

            var responses = new List<HttpResponseMessage>();
            int redirectCount = 0;
            string locationString;
            while (AllowAutoRedirect && redirectCount < MaxAutomaticRedirections && TryGetRedirectLocation(response, out locationString))
            {
                if (DownloadContentOnRedirect && response.Content != null)
                {
                    await response.Content.ReadAsByteArrayAsync().ConfigureAwait(false);
                }

                Uri nextRequestUri;
                if (!TryResolveRedirectUri(currentRequest.RequestUri, locationString, out nextRequestUri))
                {
                    // The Location header is not a valid URI: hand the redirect response back to the caller.
                    break;
                }

                // Credit where credit is due: https://github.com/kennethreitz/requests/blob/master/requests/sessions.py
                // override previous method: 301 + POST, and 302/303 + anything but HEAD, become GET
                HttpMethod nextMethod = currentRequest.Method;
                if ((response.StatusCode == HttpStatusCode.Moved && nextMethod == HttpMethod.Post) ||
                    (response.StatusCode == HttpStatusCode.Found && nextMethod != HttpMethod.Head) ||
                    (response.StatusCode == HttpStatusCode.SeeOther && nextMethod != HttpMethod.Head))
                {
                    nextMethod = HttpMethod.Get;
                }

                // Only 307/308 keep the body; once dropped it stays dropped for all later hops.
                if (!KeepRequestBodyRedirectStatusCodes.Contains(response.StatusCode))
                {
                    requestBody = null;
                }

                HttpRequestMessage nextRequest = CreateRedirectRequest(currentRequest, nextMethod, nextRequestUri, requestBody, requestHeaders);

                // keep a history of all responses, or release the intermediate one since nobody else can reach it
                if (KeepResponseHistory)
                {
                    responses.Add(response);
                }
                else
                {
                    response.Dispose();
                }

                // send the next request
                response = await base.SendAsync(nextRequest, cancellationToken).ConfigureAwait(false);
                currentRequest = response.RequestMessage ?? nextRequest;
                redirectCount++;
            }

            // save the history to the request message properties
            if (KeepResponseHistory && response.RequestMessage != null)
            {
                responses.Add(response);

                // Assign instead of Add: a nested RedirectingHandler may already have stored its own history under
                // this key on the same request message, and Add would throw ArgumentException.
                response.RequestMessage.Properties[HistoryPropertyKey] = responses;
            }

            return response;
        }

        /// <summary>
        /// Turns off automatic redirects on every nested <see cref="RedirectingHandler"/>. It also turns them off on
        /// the innermost handler when that handler is an <see cref="HttpClientHandler"/>.
        /// </summary>
        private void DisableInnerRedirects()
        {
            HttpMessageHandler innerHandler = InnerHandler;
            var delegatingHandler = innerHandler as DelegatingHandler;
            while (delegatingHandler != null)
            {
                var redirectingHandler = delegatingHandler as RedirectingHandler;
                if (redirectingHandler != null)
                {
                    redirectingHandler.AllowAutoRedirect = false;
                }

                innerHandler = delegatingHandler.InnerHandler;
                delegatingHandler = innerHandler as DelegatingHandler;
            }

            // HttpClientHandler's setters throw InvalidOperationException once the handler has started sending.
            // Assigning unconditionally would therefore break the second request sent through this pipeline.
            var httpClientHandler = innerHandler as HttpClientHandler;
            if (httpClientHandler != null && httpClientHandler.AllowAutoRedirect)
            {
                httpClientHandler.AllowAutoRedirect = false;
            }
        }

        /// <summary>
        /// Reads <paramref name="content"/> fully and returns an in-memory copy carrying the same content headers.
        /// The copy can be attached to the follow-up requests.
        /// </summary>
        private static async Task<HttpContent> BufferContentAsync(HttpContent content)
        {
            byte[] buffer = await content.ReadAsByteArrayAsync().ConfigureAwait(false);
            var copy = new ByteArrayContent(buffer);
            foreach (var header in content.Headers)
            {
                copy.Headers.Add(header.Key, header.Value);
            }

            return copy;
        }

        /// <summary>
        /// Resolves a Location header value against the URI of the request that produced the redirect.
        /// </summary>
        /// <param name="previousRequestUri">The absolute URI of the request that was redirected.</param>
        /// <param name="location">The trimmed Location header value.</param>
        /// <param name="nextRequestUri">The absolute URI to request next, when this method returns true.</param>
        /// <returns>False when <paramref name="location"/> cannot be parsed as a URI reference.</returns>
        private static bool TryResolveRedirectUri(Uri previousRequestUri, string location, out Uri nextRequestUri)
        {
            // allow redirection without a scheme ("//host/path")
            if (location.StartsWith("//", StringComparison.Ordinal))
            {
                location = previousRequestUri.Scheme + ":" + location;
            }

            // A single leading "/" denotes a path-absolute relative reference. Parse it as relative explicitly, because
            // UriKind.RelativeOrAbsolute treats such strings as implicit file paths (file:///...) on Unix.
            UriKind kind = location.StartsWith("/", StringComparison.Ordinal) ? UriKind.Relative : UriKind.RelativeOrAbsolute;
            if (!Uri.TryCreate(location, kind, out nextRequestUri))
            {
                return false;
            }

            // allow relative redirects
            if (!nextRequestUri.IsAbsoluteUri)
            {
                nextRequestUri = new Uri(previousRequestUri, nextRequestUri);
            }

            return true;
        }

        /// <summary>
        /// Builds the follow-up request. It copies the snapshot of the original request headers. It also copies the
        /// version and properties of the previous request, except <see cref="HistoryPropertyKey"/>, which each handler
        /// sets itself.
        /// </summary>
        private static HttpRequestMessage CreateRedirectRequest(
            HttpRequestMessage previousRequest,
            HttpMethod method,
            Uri requestUri,
            HttpContent content,
            IEnumerable<KeyValuePair<string, string[]>> headers)
        {
            var nextRequest = new HttpRequestMessage(method, requestUri)
            {
                Content = content,
                Version = previousRequest.Version
            };

            foreach (var header in headers)
            {
                nextRequest.Headers.Add(header.Key, header.Value);
            }

            foreach (var pair in previousRequest.Properties)
            {
                if (!string.Equals(pair.Key, HistoryPropertyKey, StringComparison.Ordinal))
                {
                    nextRequest.Properties.Add(pair.Key, pair.Value);
                }
            }

            return nextRequest;
        }

        /// <summary>
        /// Returns the trimmed Location value when <paramref name="response"/> has a followable redirect status and
        /// exactly one non-blank Location header value.
        /// </summary>
        private static bool TryGetRedirectLocation(HttpResponseMessage response, out string location)
        {
            location = null;

            IEnumerable<string> values;
            if (!RedirectStatusCodes.Contains(response.StatusCode) ||
                !response.Headers.TryGetValues("Location", out values))
            {
                return false;
            }

            string[] locations = values.ToArray();
            if (locations.Length != 1 || string.IsNullOrWhiteSpace(locations[0]))
            {
                return false;
            }

            location = locations[0].Trim();
            return true;
        }
    }
}
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using FluentAssertions;
using Knapcode.Http.Handlers;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Moq;
using Moq.Protected;
namespace Knapcode.Http.Tests.Handlers
{
    [TestClass]
    public class RedirectingHandlerTests
    {
        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithTooManyRedirects_StopsRedirecting()
        {
            // ARRANGE
            const HttpStatusCode statusCode = HttpStatusCode.TemporaryRedirect;
            var client = GetHttpClient(
                configure: handler => handler.MaxAutomaticRedirections = 5,
                redirectCount: 6,
                statusCode: statusCode);
            var request = GetRequest();

            // ACT
            HttpResponseMessage httpResponseMessage = await client.SendAsync(request);

            // ASSERT
            httpResponseMessage.StatusCode.Should().Be(statusCode);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithNoLocationHeader_DoesNotRedirect()
        {
            // ARRANGE
            const HttpStatusCode statusCode = HttpStatusCode.TemporaryRedirect;
            var client = GetHttpClient(
                redirectUri: new Uri(string.Empty, UriKind.Relative),
                statusCode: statusCode);
            var request = GetRequest();

            // ACT
            HttpResponseMessage httpResponseMessage = await client.SendAsync(request);

            // ASSERT
            httpResponseMessage.StatusCode.Should().Be(statusCode);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithDisabledRedirect_DoesNotRedirect()
        {
            // ARRANGE
            const HttpStatusCode statusCode = HttpStatusCode.TemporaryRedirect;
            var client = GetHttpClient(
                configure: handler => handler.AllowAutoRedirect = false,
                statusCode: statusCode);
            var request = GetRequest();

            // ACT
            HttpResponseMessage httpResponseMessage = await client.SendAsync(request);

            // ASSERT
            httpResponseMessage.StatusCode.Should().Be(statusCode);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithDisabledHistory_DoesNotKeepHistory()
        {
            // ARRANGE
            var client = GetHttpClient(configure: handler => handler.KeepResponseHistory = false);
            var request = GetRequest();

            // ACT
            HttpResponseMessage httpResponseMessage = await client.SendAsync(request);

            // ASSERT
            httpResponseMessage.RequestMessage.Properties.ContainsKey(RedirectingHandler.HistoryPropertyKey).Should().BeFalse();
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithContentHeaders_CopiesContentHeaders()
        {
            // ARRANGE
            var client = GetHttpClient(HttpStatusCode.TemporaryRedirect);
            var request = GetRequest();
            request.Method = HttpMethod.Post;
            request.Content = new StringContent("foo");
            const string headerKey = "X-Foo";
            const string headerValue = "bar";
            request.Content.Headers.Add(headerKey, headerValue);

            // ACT
            HttpResponseMessage httpResponseMessage = await client.SendAsync(request);

            // ASSERT
            string[] values = httpResponseMessage.RequestMessage.Content.Headers.GetValues(headerKey).ToArray();
            values.Should().HaveCount(1);
            values.Should().BeEquivalentTo(headerValue);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithRedirects_KeepsHistory()
        {
            // ARRANGE
            const int redirectCount = 5;
            const HttpStatusCode statusCode = HttpStatusCode.TemporaryRedirect;
            var client = GetHttpClient(redirectCount: redirectCount, statusCode: statusCode);
            var request = GetRequest();

            // ACT
            HttpResponseMessage httpResponseMessage = await client.SendAsync(request);

            // ASSERT
            httpResponseMessage.RequestMessage.Properties.ContainsKey(RedirectingHandler.HistoryPropertyKey).Should().BeTrue();
            object value = httpResponseMessage.RequestMessage.Properties[RedirectingHandler.HistoryPropertyKey];
            value.Should().BeAssignableTo<IEnumerable<HttpResponseMessage>>();
            HttpResponseMessage[] responses = ((IEnumerable<HttpResponseMessage>) value).ToArray();
            responses.Should().HaveCount(redirectCount + 1);
            responses.Take(redirectCount).Should().OnlyContain(r => r.StatusCode == statusCode);
            responses.Skip(redirectCount).Take(1).Should().OnlyContain(r => r.StatusCode == HttpStatusCode.OK);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithNestedRedirectingHandlers_KeepsOuterHistory()
        {
            // ARRANGE
            // Both handlers keep history; the inner one must not make the outer one fail with a duplicate key.
            var responseQueue = new Queue<HttpResponseMessage>(new[]
            {
                GetRedirectHttpResponseMessage(HttpStatusCode.TemporaryRedirect, new Uri("http://www.example.com/2")),
                GetOkHttpResponseMessage()
            });
            var innerHandler = new RedirectingHandler { InnerHandler = GetHttpMessageHandlerMock(responseQueue).Object };
            var client = new HttpClient(new RedirectingHandler { InnerHandler = innerHandler });
            var request = GetRequest();

            // ACT
            HttpResponseMessage httpResponseMessage = await client.SendAsync(request);

            // ASSERT
            httpResponseMessage.StatusCode.Should().Be(HttpStatusCode.OK);
            innerHandler.AllowAutoRedirect.Should().BeFalse();
            object value = httpResponseMessage.RequestMessage.Properties[RedirectingHandler.HistoryPropertyKey];
            ((IEnumerable<HttpResponseMessage>) value).Should().HaveCount(2);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithHeaders_CopiesHeaders()
        {
            // ARRANGE
            var client = GetHttpClient();
            var request = GetRequest();
            const string headerKey = "X-Foo";
            const string headerValue = "bar";
            request.Headers.Add(headerKey, headerValue);

            // ACT
            HttpResponseMessage httpResponseMessage = await client.SendAsync(request);

            // ASSERT
            string[] values = httpResponseMessage.RequestMessage.Headers.GetValues(headerKey).ToArray();
            values.Should().HaveCount(1);
            values.Should().BeEquivalentTo(headerValue);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithNoSchemeInRedirect_UsesRequestUriScheme()
        {
            // ARRANGE
            var client = GetHttpClient(redirectUri: new Uri("//www.example.com/2", UriKind.Relative));
            var request = GetRequest();
            request.RequestUri = new Uri("https://www.example.com/1");

            // ACT
            HttpResponseMessage httpResponseMessage = await client.SendAsync(request);

            // ASSERT
            httpResponseMessage.RequestMessage.RequestUri.Should().Be(new Uri("https://www.example.com/2"));
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithPathAbsoluteRedirect_ResolvedAgainstRequestUri()
        {
            // ARRANGE
            var client = GetHttpClient(redirectUri: new Uri("/2", UriKind.Relative));
            var request = GetRequest();
            request.RequestUri = new Uri("https://www.example.com/1/a.txt");

            // ACT
            HttpResponseMessage httpResponseMessage = await client.SendAsync(request);

            // ASSERT
            httpResponseMessage.RequestMessage.RequestUri.Should().Be(new Uri("https://www.example.com/2"));
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithRelativeRedirect_ResolvedAgainstRequestUri()
        {
            // ARRANGE
            var client = GetHttpClient(redirectUri: new Uri("../c/e/../d.txt", UriKind.Relative));
            var request = GetRequest();
            request.RequestUri = new Uri("https://www.example.com/1/2/3/4.txt");

            // ACT
            HttpResponseMessage httpResponseMessage = await client.SendAsync(request);

            // ASSERT
            httpResponseMessage.RequestMessage.RequestUri.Should().Be(new Uri("https://www.example.com/1/2/c/d.txt"));
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithPostAnd301_MakesGetRequest()
        {
            await SendAsync_WithRedirect_MakesSubsequentRequest(
                HttpMethod.Post,
                "foo",
                HttpStatusCode.Moved,
                HttpMethod.Get,
                null);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithPutAnd301_DuplicatesRequestWithoutContent()
        {
            await SendAsync_WithRedirect_MakesSubsequentRequest(
                HttpMethod.Put,
                "foo",
                HttpStatusCode.Moved,
                HttpMethod.Put,
                null);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithDeleteAnd301_DuplicatesRequestWithoutContent()
        {
            await SendAsync_WithRedirect_MakesSubsequentRequest(
                HttpMethod.Delete,
                "foo",
                HttpStatusCode.Moved,
                HttpMethod.Delete,
                null);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithGetAnd301_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Get, null, HttpStatusCode.Moved);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithHeadAnd301_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Head, null, HttpStatusCode.Moved);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithPostAnd302_MakesGetRequest()
        {
            await SendAsync_WithRedirect_MakesSubsequentRequest(
                HttpMethod.Post,
                "foo",
                HttpStatusCode.Found,
                HttpMethod.Get,
                null);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithPutAnd302_MakesGetRequest()
        {
            await SendAsync_WithRedirect_MakesSubsequentRequest(
                HttpMethod.Put,
                "foo",
                HttpStatusCode.Found,
                HttpMethod.Get,
                null);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithDeleteAnd302_MakesGetRequest()
        {
            await SendAsync_WithRedirect_MakesSubsequentRequest(
                HttpMethod.Delete,
                "foo",
                HttpStatusCode.Found,
                HttpMethod.Get,
                null);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithGetAnd302_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Get, null, HttpStatusCode.Found);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithHeadAnd302_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Head, null, HttpStatusCode.Found);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithPostAnd303_MakesGetRequest()
        {
            await SendAsync_WithRedirect_MakesSubsequentRequest(
                HttpMethod.Post,
                "foo",
                HttpStatusCode.SeeOther,
                HttpMethod.Get,
                null);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithPutAnd303_MakesGetRequest()
        {
            await SendAsync_WithRedirect_MakesSubsequentRequest(
                HttpMethod.Put,
                "foo",
                HttpStatusCode.SeeOther,
                HttpMethod.Get,
                null);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithDeleteAnd303_MakesGetRequest()
        {
            await SendAsync_WithRedirect_MakesSubsequentRequest(
                HttpMethod.Delete,
                "foo",
                HttpStatusCode.SeeOther,
                HttpMethod.Get,
                null);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithGetAnd303_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Get, null, HttpStatusCode.SeeOther);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithHeadAnd303_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Head, null, HttpStatusCode.SeeOther);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithPostAnd307_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Post, "foo", HttpStatusCode.TemporaryRedirect);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithPutAnd307_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Put, "foo", HttpStatusCode.TemporaryRedirect);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithDeleteAnd307_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Delete, "foo", HttpStatusCode.TemporaryRedirect);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithGetAnd307_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Get, null, HttpStatusCode.TemporaryRedirect);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithHeadAnd307_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Head, null, HttpStatusCode.TemporaryRedirect);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithPostAnd308_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Post, "foo", (HttpStatusCode)308);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithPutAnd308_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Put, "foo", (HttpStatusCode)308);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithDeleteAnd308_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Delete, "foo", (HttpStatusCode)308);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithGetAnd308_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Get, null, (HttpStatusCode)308);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithHeadAnd308_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Head, null, (HttpStatusCode)308);
        }

        /// <summary>
        /// Asserts that a single redirect with <paramref name="httpStatusCode"/> re-sends the same method and content.
        /// </summary>
        private static async Task SendAsync_WithRedirect_DuplicatesRequest(HttpMethod httpMethod, string content, HttpStatusCode httpStatusCode)
        {
            await SendAsync_WithRedirect_MakesSubsequentRequest(
                httpMethod,
                content,
                httpStatusCode,
                httpMethod,
                content);
        }

        /// <summary>
        /// Sends one request that is redirected once. It then asserts the method and content of the follow-up request.
        /// A null <paramref name="initialContent"/> sends the request without a body. A null
        /// <paramref name="expectedContent"/> asserts the follow-up has no body.
        /// </summary>
        private static async Task SendAsync_WithRedirect_MakesSubsequentRequest(HttpMethod initialMethod, string initialContent, HttpStatusCode httpStatusCode, HttpMethod expectedMethod, string expectedContent)
        {
            // Leave the content null when none is requested, so body-less requests really are sent without a body.
            StubbedHttpContent content = initialContent == null ? null : new StubbedHttpContent(initialContent);

            await SendAsync_WithRedirect_MakesNewRequest(
                initialMethod,
                content,
                httpStatusCode,
                async request =>
                {
                    request.Method.Should().Be(expectedMethod);
                    if (expectedContent == null)
                    {
                        request.Content.Should().BeNull();
                    }
                    else
                    {
                        // the original content must be read exactly once (when the handler buffers it)
                        content.SerializeToStreamAsyncCalls.Should().Be(1);
                        request.Content.Should().NotBeNull();
                        string actualContent = await request.Content.ReadAsStringAsync();
                        actualContent.Should().Be(expectedContent);
                    }
                });
        }

        /// <summary>
        /// Sends a request through a handler that answers with one <paramref name="statusCode"/> redirect and then 200.
        /// It then awaits <paramref name="validateRequest"/> on the final request message. The validator is a
        /// <see cref="Func{T, TResult}"/> returning a task, so async assertions are observed rather than becoming
        /// async void.
        /// </summary>
        private static async Task SendAsync_WithRedirect_MakesNewRequest(HttpMethod initialMethod, HttpContent content, HttpStatusCode statusCode, Func<HttpRequestMessage, Task> validateRequest)
        {
            // ARRANGE
            var client = GetHttpClient(statusCode);
            var request = GetRequest();
            request.Method = initialMethod;
            request.Content = content;

            // ACT
            HttpResponseMessage httpResponseMessage = await client.SendAsync(request);

            // ASSERT
            httpResponseMessage.StatusCode.Should().Be(HttpStatusCode.OK);
            await validateRequest(httpResponseMessage.RequestMessage);
        }

        private static HttpResponseMessage GetOkHttpResponseMessage()
        {
            return new HttpResponseMessage(HttpStatusCode.OK);
        }

        private static HttpResponseMessage GetRedirectHttpResponseMessage(HttpStatusCode statusCode, Uri redirectUri)
        {
            var response = new HttpResponseMessage(statusCode);
            response.Headers.Location = redirectUri;
            return response;
        }

        /// <summary>
        /// Creates a handler mock that answers each request with the next queued response. It attaches the request to
        /// that response.
        /// </summary>
        private static Mock<HttpMessageHandler> GetHttpMessageHandlerMock(Queue<HttpResponseMessage> responseQueue)
        {
            var mock = new Mock<HttpMessageHandler>();
            mock
                .Protected()
                .Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
                .Returns<HttpRequestMessage, CancellationToken>((request, token) =>
                {
                    HttpResponseMessage response = responseQueue.Dequeue();
                    response.RequestMessage = request;
                    return Task.FromResult(response);
                });
            return mock;
        }

        private static HttpRequestMessage GetRequest()
        {
            return new HttpRequestMessage(HttpMethod.Get, "http://www.example.com/1");
        }

        /// <summary>
        /// Builds a client whose <see cref="RedirectingHandler"/> sees <paramref name="redirectCount"/> redirects to
        /// <paramref name="redirectUri"/> (default http://www.example.com/2), followed by a 200 response.
        /// </summary>
        private static HttpClient GetHttpClient(HttpStatusCode statusCode = HttpStatusCode.Moved, Uri redirectUri = null, int redirectCount = 1, Action<RedirectingHandler> configure = null)
        {
            if (redirectUri == null)
            {
                redirectUri = new Uri("http://www.example.com/2", UriKind.Absolute);
            }

            var responses = Enumerable
                .Range(0, redirectCount)
                .Select(i => GetRedirectHttpResponseMessage(statusCode, redirectUri))
                .Concat(new[]
                {
                    GetOkHttpResponseMessage()
                });
            var responseQueue = new Queue<HttpResponseMessage>(responses);
            Mock<HttpMessageHandler> httpMessageHandlerMock = GetHttpMessageHandlerMock(responseQueue);
            var handler = new RedirectingHandler { InnerHandler = httpMessageHandlerMock.Object };
            if (configure != null)
            {
                configure(handler);
            }

            var client = new HttpClient(handler);
            return client;
        }

        /// <summary>
        /// String content that counts how many times it is serialized.
        /// </summary>
        private class StubbedHttpContent : StringContent
        {
            public StubbedHttpContent(string content)
                : base(content)
            {
            }

            public int SerializeToStreamAsyncCalls { get; private set; }

            protected override Task SerializeToStreamAsync(Stream stream, TransportContext context)
            {
                SerializeToStreamAsyncCalls++;
                return base.SerializeToStreamAsync(stream, context);
            }
        }
    }
}
@SteveL-MSFT
Copy link

What is the license for this code? I'd like to use it in https://github.com/powershell/powershell. Thanks!

@joelverhagen
Copy link
Author

For future readers, the license is MIT.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment