Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
RedirectingHandler
using System;
using System.Collections.Generic;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
namespace Knapcode.Http.Handlers
{
    /// <summary>
    /// A delegating handler that handles HTTP redirects (301, 302, 303, 307, and 308).
    /// </summary>
    /// <remarks>
    /// A 301 turns a POST into a GET; a 302 or 303 turns every method except HEAD into a GET.
    /// Only 307 and 308 replay the original request body. Request headers and properties are
    /// copied to every redirected request, except that the Authorization header is only sent to
    /// targets on the same origin as the original request (an http-to-https upgrade on the
    /// default ports is treated as the same origin).
    /// </remarks>
    public class RedirectingHandler : DelegatingHandler
    {
        /// <summary>
        /// The property key used to access the list of responses.
        /// </summary>
        public const string HistoryPropertyKey = "Knapcode.Http.Handlers.RedirectingHandler.ResponseHistory";

        // Header carrying credentials; it is not forwarded to a different origin.
        private const string AuthorizationHeaderName = "Authorization";

        // 308 (Permanent Redirect) is cast because older HttpStatusCode enums have no member for it.
        private static readonly ISet<HttpStatusCode> RedirectStatusCodes = new HashSet<HttpStatusCode>(new[]
        {
            HttpStatusCode.MovedPermanently,
            HttpStatusCode.Found,
            HttpStatusCode.SeeOther,
            HttpStatusCode.TemporaryRedirect,
            (HttpStatusCode) 308
        });

        // Redirects whose follow-up request keeps the buffered request body.
        private static readonly ISet<HttpStatusCode> KeepRequestBodyRedirectStatusCodes = new HashSet<HttpStatusCode>(new[]
        {
            HttpStatusCode.TemporaryRedirect,
            (HttpStatusCode) 308
        });

        /// <summary>
        /// Initializes a new instance of the <see cref="RedirectingHandler"/> class.
        /// </summary>
        public RedirectingHandler()
        {
            AllowAutoRedirect = true;
            MaxAutomaticRedirections = 50;
            DisableInnerAutoRedirect = true;
            DownloadContentOnRedirect = true;
            KeepResponseHistory = true;
        }

        /// <summary>
        /// Gets or sets a value that indicates whether the handler should follow redirection responses.
        /// The default is <c>true</c>.
        /// </summary>
        public bool AllowAutoRedirect { get; set; }

        /// <summary>
        /// Gets or sets the maximum number of redirects that the handler follows. The default is 50.
        /// </summary>
        public int MaxAutomaticRedirections { get; set; }

        /// <summary>
        /// Gets or sets a value indicating whether the response body should be downloaded before each redirection.
        /// The default is <c>true</c>.
        /// </summary>
        public bool DownloadContentOnRedirect { get; set; }

        /// <summary>
        /// Gets or sets a value indicating whether automatic redirection should be disabled on every inner
        /// <see cref="RedirectingHandler"/> and on a terminal <see cref="HttpClientHandler"/> in the handler chain.
        /// The default is <c>true</c>.
        /// </summary>
        public bool DisableInnerAutoRedirect { get; set; }

        /// <summary>
        /// Gets or sets a value indicating whether the response history should be saved to the <see cref="HttpResponseMessage.RequestMessage"/> properties with the key of <see cref="HistoryPropertyKey"/>.
        /// The default is <c>true</c>.
        /// </summary>
        public bool KeepResponseHistory { get; set; }

        /// <summary>
        /// Sends <paramref name="request"/> through the inner handler and, while <see cref="AllowAutoRedirect"/>
        /// is enabled and fewer than <see cref="MaxAutomaticRedirections"/> redirects have been followed,
        /// follows redirect responses that carry exactly one usable Location header.
        /// </summary>
        /// <param name="request">The initial request.</param>
        /// <param name="cancellationToken">Token passed to every inner send.</param>
        /// <returns>
        /// The last response received. It may itself be a redirect when the redirect limit was reached
        /// or its Location could not be turned into a URI.
        /// </returns>
        protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
        {
            if (DisableInnerAutoRedirect)
            {
                DisableInnerRedirects();
            }

            // Buffer the request body (and its headers) so it can be replayed on 307/308 redirects.
            byte[] bodyBytes = null;
            KeyValuePair<string, string[]>[] bodyHeaders = null;
            if (AllowAutoRedirect && request.Content != null)
            {
                bodyBytes = await request.Content.ReadAsByteArrayAsync().ConfigureAwait(false);
                bodyHeaders = SnapshotHeaders(request.Content.Headers);
            }

            // Snapshot the request headers; they are re-applied to every redirected request.
            KeyValuePair<string, string[]>[] requestHeaders = SnapshotHeaders(request.Headers);

            // Credentials in the Authorization header are scoped to the origin of the original request.
            Uri originalRequestUri = request.RequestUri;

            HttpRequestMessage currentRequest = request;
            HttpResponseMessage response = await base.SendAsync(request, cancellationToken).ConfigureAwait(false);
            var responses = new List<HttpResponseMessage>();
            int redirectCount = 0;

            while (AllowAutoRedirect && redirectCount < MaxAutomaticRedirections)
            {
                // Prefer the request reported by the inner handler; fall back to the one we sent.
                HttpRequestMessage previousRequest = response.RequestMessage ?? currentRequest;

                Uri nextRequestUri;
                if (!TryGetRedirectUri(response, previousRequest.RequestUri, out nextRequestUri))
                {
                    break;
                }

                if (DownloadContentOnRedirect && response.Content != null)
                {
                    await response.Content.ReadAsByteArrayAsync().ConfigureAwait(false);
                }

                HttpMethod nextMethod = GetRedirectMethod(response.StatusCode, previousRequest.Method);

                // Once dropped (any redirect other than 307/308), the body stays dropped for later hops.
                if (!KeepRequestBodyRedirectStatusCodes.Contains(response.StatusCode))
                {
                    bodyBytes = null;
                }

                // Each hop gets its own content instance so disposing one request cannot affect another.
                var nextRequest = new HttpRequestMessage(nextMethod, nextRequestUri)
                {
                    Content = CreateBufferedContent(bodyBytes, bodyHeaders),
                    Version = previousRequest.Version
                };

                bool stripAuthorization = ShouldStripAuthorization(originalRequestUri, nextRequestUri);
                foreach (var header in requestHeaders)
                {
                    if (stripAuthorization &&
                        string.Equals(header.Key, AuthorizationHeaderName, StringComparison.OrdinalIgnoreCase))
                    {
                        continue;
                    }

                    nextRequest.Headers.Add(header.Key, header.Value);
                }

                foreach (var pair in previousRequest.Properties)
                {
                    nextRequest.Properties.Add(pair.Key, pair.Value);
                }

                if (KeepResponseHistory)
                {
                    responses.Add(response);
                }
                else
                {
                    // Nothing can reach this response once we move on, so release its content now.
                    response.Dispose();
                }

                currentRequest = nextRequest;
                response = await base.SendAsync(nextRequest, cancellationToken).ConfigureAwait(false);
                redirectCount++;
            }

            // Save the history to the final request message properties.
            if (KeepResponseHistory && response.RequestMessage != null)
            {
                responses.Add(response);

                // Indexer rather than Add: the key may already be present (set by a nested
                // RedirectingHandler or copied from an earlier hop), and Add would throw.
                response.RequestMessage.Properties[HistoryPropertyKey] = responses;
            }

            return response;
        }

        /// <summary>
        /// Walks the chain of <see cref="DelegatingHandler"/>s below this one, turning off
        /// <see cref="AllowAutoRedirect"/> on every <see cref="RedirectingHandler"/> and
        /// <see cref="HttpClientHandler.AllowAutoRedirect"/> on a terminal <see cref="HttpClientHandler"/>.
        /// </summary>
        private void DisableInnerRedirects()
        {
            HttpMessageHandler handler = InnerHandler;
            var delegatingHandler = handler as DelegatingHandler;
            while (delegatingHandler != null)
            {
                var redirectingHandler = delegatingHandler as RedirectingHandler;
                if (redirectingHandler != null)
                {
                    redirectingHandler.AllowAutoRedirect = false;
                }

                handler = delegatingHandler.InnerHandler;
                delegatingHandler = handler as DelegatingHandler;
            }

            // HttpClientHandler throws InvalidOperationException when its properties are set after it
            // has started sending requests, so only assign when a change is actually needed.
            var httpClientHandler = handler as HttpClientHandler;
            if (httpClientHandler != null && httpClientHandler.AllowAutoRedirect)
            {
                httpClientHandler.AllowAutoRedirect = false;
            }
        }

        /// <summary>
        /// Resolves the target of a redirect response.
        /// </summary>
        /// <param name="response">The response to inspect.</param>
        /// <param name="previousRequestUri">URI of the request that produced <paramref name="response"/>; used as the base for scheme-relative and relative locations.</param>
        /// <param name="redirectUri">The absolute URI to follow, or <c>null</c>.</param>
        /// <returns>
        /// <c>true</c> when the status code is a redirect, there is exactly one non-blank Location value,
        /// and it resolves to an absolute URI; otherwise <c>false</c>.
        /// </returns>
        private static bool TryGetRedirectUri(HttpResponseMessage response, Uri previousRequestUri, out Uri redirectUri)
        {
            redirectUri = null;

            IEnumerable<string> values;
            if (!RedirectStatusCodes.Contains(response.StatusCode) ||
                !response.Headers.TryGetValues("Location", out values))
            {
                return false;
            }

            string[] locations = values.ToArray();
            if (locations.Length != 1 || string.IsNullOrWhiteSpace(locations[0]))
            {
                return false;
            }

            string location = locations[0].Trim();
            bool hasAbsoluteBase = previousRequestUri != null && previousRequestUri.IsAbsoluteUri;

            // Credit where credit is due: https://github.com/kennethreitz/requests/blob/master/requests/sessions.py
            // Allow redirection without a scheme ("//host/path") by borrowing the previous request's scheme.
            if (location.StartsWith("//", StringComparison.Ordinal))
            {
                if (!hasAbsoluteBase)
                {
                    return false;
                }

                location = previousRequestUri.Scheme + ":" + location;
            }

            // Parse "/path" explicitly as relative: with UriKind.RelativeOrAbsolute some runtimes on
            // non-Windows platforms treat it as an implicit file path (file:///path).
            UriKind kind = location.StartsWith("/", StringComparison.Ordinal)
                ? UriKind.Relative
                : UriKind.RelativeOrAbsolute;

            Uri parsed;
            if (!Uri.TryCreate(location, kind, out parsed))
            {
                // Malformed Location: hand the redirect response back to the caller instead of throwing.
                return false;
            }

            if (parsed.IsAbsoluteUri)
            {
                redirectUri = parsed;
                return true;
            }

            // Allow relative redirects, resolved against the previous request's URI.
            return hasAbsoluteBase && Uri.TryCreate(previousRequestUri, parsed, out redirectUri);
        }

        /// <summary>
        /// Returns the method for the follow-up request: 301 turns POST into GET, 302 and 303 turn
        /// every method except HEAD into GET, and all other cases keep <paramref name="method"/>.
        /// </summary>
        private static HttpMethod GetRedirectMethod(HttpStatusCode statusCode, HttpMethod method)
        {
            bool switchToGet =
                (statusCode == HttpStatusCode.Moved && method == HttpMethod.Post) ||
                ((statusCode == HttpStatusCode.Found || statusCode == HttpStatusCode.SeeOther) && method != HttpMethod.Head);

            return switchToGet ? HttpMethod.Get : method;
        }

        /// <summary>
        /// Decides whether the Authorization header must be dropped when redirecting from the origin of
        /// <paramref name="originalUri"/> to <paramref name="nextUri"/>. The header is dropped when the host
        /// changes, or when the scheme or port changes, except for an http:80 to https:443 upgrade.
        /// </summary>
        private static bool ShouldStripAuthorization(Uri originalUri, Uri nextUri)
        {
            if (originalUri == null || !originalUri.IsAbsoluteUri)
            {
                // Without a known origin we cannot prove the target is the same one; be conservative.
                return true;
            }

            if (!string.Equals(originalUri.Host, nextUri.Host, StringComparison.OrdinalIgnoreCase))
            {
                return true;
            }

            bool isStandardHttpsUpgrade =
                originalUri.Scheme == Uri.UriSchemeHttp && originalUri.Port == 80 &&
                nextUri.Scheme == Uri.UriSchemeHttps && nextUri.Port == 443;
            if (isStandardHttpsUpgrade)
            {
                return false;
            }

            return !string.Equals(originalUri.Scheme, nextUri.Scheme, StringComparison.OrdinalIgnoreCase) ||
                   originalUri.Port != nextUri.Port;
        }

        /// <summary>
        /// Copies headers into an array so later changes to the source collection do not affect the copy.
        /// </summary>
        private static KeyValuePair<string, string[]>[] SnapshotHeaders(IEnumerable<KeyValuePair<string, IEnumerable<string>>> headers)
        {
            return headers
                .Select(h => new KeyValuePair<string, string[]>(h.Key, h.Value.ToArray()))
                .ToArray();
        }

        /// <summary>
        /// Creates a fresh <see cref="ByteArrayContent"/> over <paramref name="body"/> carrying
        /// <paramref name="headers"/>, or returns <c>null</c> when there is no body to send.
        /// </summary>
        private static HttpContent CreateBufferedContent(byte[] body, IEnumerable<KeyValuePair<string, string[]>> headers)
        {
            if (body == null)
            {
                return null;
            }

            var content = new ByteArrayContent(body);
            if (headers != null)
            {
                foreach (var header in headers)
                {
                    content.Headers.Add(header.Key, header.Value);
                }
            }

            return content;
        }
    }
}
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;
using FluentAssertions;
using Knapcode.Http.Handlers;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Moq;
using Moq.Protected;
namespace Knapcode.Http.Tests.Handlers
{
    /// <summary>
    /// Unit tests for <see cref="RedirectingHandler"/> driven through an <see cref="HttpClient"/>
    /// whose terminal handler is a Moq stub that returns queued responses.
    /// </summary>
    [TestClass]
    public class RedirectingHandlerTests
    {
        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithTooManyRedirects_StopsRedirecting()
        {
            // ARRANGE
            const HttpStatusCode statusCode = HttpStatusCode.TemporaryRedirect;
            var client = GetHttpClient(
                configure: handler => handler.MaxAutomaticRedirections = 5,
                redirectCount: 6,
                statusCode: statusCode);
            var request = GetRequest();

            // ACT
            HttpResponseMessage httpResponseMessage = await client.SendAsync(request);

            // ASSERT
            httpResponseMessage.StatusCode.Should().Be(statusCode);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithNoLocationHeader_DoesNotRedirect()
        {
            // ARRANGE
            const HttpStatusCode statusCode = HttpStatusCode.TemporaryRedirect;
            var client = GetHttpClient(
                redirectUri: new Uri(string.Empty, UriKind.Relative),
                statusCode: statusCode);
            var request = GetRequest();

            // ACT
            HttpResponseMessage httpResponseMessage = await client.SendAsync(request);

            // ASSERT
            httpResponseMessage.StatusCode.Should().Be(statusCode);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithDisabledRedirect_DoesNotRedirect()
        {
            // ARRANGE
            const HttpStatusCode statusCode = HttpStatusCode.TemporaryRedirect;
            var client = GetHttpClient(
                configure: handler => handler.AllowAutoRedirect = false,
                statusCode: statusCode);
            var request = GetRequest();

            // ACT
            HttpResponseMessage httpResponseMessage = await client.SendAsync(request);

            // ASSERT
            httpResponseMessage.StatusCode.Should().Be(statusCode);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithDisabledHistory_DoesNotKeepHistory()
        {
            // ARRANGE
            var client = GetHttpClient(configure: handler => handler.KeepResponseHistory = false);
            var request = GetRequest();

            // ACT
            HttpResponseMessage httpResponseMessage = await client.SendAsync(request);

            // ASSERT
            httpResponseMessage.RequestMessage.Properties.ContainsKey(RedirectingHandler.HistoryPropertyKey).Should().BeFalse();
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithContentHeaders_CopiesContentHeaders()
        {
            // ARRANGE
            var client = GetHttpClient(HttpStatusCode.TemporaryRedirect);
            var request = GetRequest();
            request.Method = HttpMethod.Post;
            request.Content = new StringContent("foo");
            const string headerKey = "X-Foo";
            const string headerValue = "bar";
            request.Content.Headers.Add(headerKey, headerValue);

            // ACT
            HttpResponseMessage httpResponseMessage = await client.SendAsync(request);

            // ASSERT
            string[] values = httpResponseMessage.RequestMessage.Content.Headers.GetValues(headerKey).ToArray();
            values.Should().HaveCount(1);
            values.Should().BeEquivalentTo(headerValue);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithRedirects_KeepsHistory()
        {
            // ARRANGE
            const int redirectCount = 5;
            const HttpStatusCode statusCode = HttpStatusCode.TemporaryRedirect;
            var client = GetHttpClient(redirectCount: redirectCount, statusCode: statusCode);
            var request = GetRequest();

            // ACT
            HttpResponseMessage httpResponseMessage = await client.SendAsync(request);

            // ASSERT
            httpResponseMessage.RequestMessage.Properties.ContainsKey(RedirectingHandler.HistoryPropertyKey).Should().BeTrue();
            object value = httpResponseMessage.RequestMessage.Properties[RedirectingHandler.HistoryPropertyKey];
            value.Should().BeAssignableTo<IEnumerable<HttpResponseMessage>>();
            HttpResponseMessage[] responses = ((IEnumerable<HttpResponseMessage>) value).ToArray();
            responses.Should().HaveCount(redirectCount + 1);
            responses.Take(redirectCount).Should().OnlyContain(r => r.StatusCode == statusCode);
            responses.Skip(redirectCount).Take(1).Should().OnlyContain(r => r.StatusCode == HttpStatusCode.OK);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithHeaders_CopiesHeaders()
        {
            // ARRANGE
            var client = GetHttpClient();
            var request = GetRequest();
            const string headerKey = "X-Foo";
            const string headerValue = "bar";
            request.Headers.Add(headerKey, headerValue);

            // ACT
            HttpResponseMessage httpResponseMessage = await client.SendAsync(request);

            // ASSERT
            string[] values = httpResponseMessage.RequestMessage.Headers.GetValues(headerKey).ToArray();
            values.Should().HaveCount(1);
            values.Should().BeEquivalentTo(headerValue);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithNoSchemeInRedirect_UsesRequestUriScheme()
        {
            // ARRANGE
            var client = GetHttpClient(redirectUri: new Uri("//www.example.com/2", UriKind.Relative));
            var request = GetRequest();
            request.RequestUri = new Uri("https://www.example.com/1");

            // ACT
            HttpResponseMessage httpResponseMessage = await client.SendAsync(request);

            // ASSERT
            httpResponseMessage.RequestMessage.RequestUri.Should().Be(new Uri("https://www.example.com/2"));
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithRelativeRedirect_ResolvedAgainstRequestUri()
        {
            // ARRANGE
            var client = GetHttpClient(redirectUri: new Uri("../c/e/../d.txt", UriKind.Relative));
            var request = GetRequest();
            request.RequestUri = new Uri("https://www.example.com/1/2/3/4.txt");

            // ACT
            HttpResponseMessage httpResponseMessage = await client.SendAsync(request);

            // ASSERT
            httpResponseMessage.RequestMessage.RequestUri.Should().Be(new Uri("https://www.example.com/1/2/c/d.txt"));
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithPostAnd301_MakesGetRequest()
        {
            await SendAsync_WithRedirect_MakesSubsequentRequest(
                HttpMethod.Post,
                "foo",
                HttpStatusCode.Moved,
                HttpMethod.Get,
                null);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithPutAnd301_DuplicatesRequestWithoutContent()
        {
            await SendAsync_WithRedirect_MakesSubsequentRequest(
                HttpMethod.Put,
                "foo",
                HttpStatusCode.Moved,
                HttpMethod.Put,
                null);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithDeleteAnd301_DuplicatesRequestWithoutContent()
        {
            await SendAsync_WithRedirect_MakesSubsequentRequest(
                HttpMethod.Delete,
                "foo",
                HttpStatusCode.Moved,
                HttpMethod.Delete,
                null);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithGetAnd301_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Get, null, HttpStatusCode.Moved);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithHeadAnd301_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Head, null, HttpStatusCode.Moved);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithPostAnd302_MakesGetRequest()
        {
            await SendAsync_WithRedirect_MakesSubsequentRequest(
                HttpMethod.Post,
                "foo",
                HttpStatusCode.Found,
                HttpMethod.Get,
                null);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithPutAnd302_MakesGetRequest()
        {
            await SendAsync_WithRedirect_MakesSubsequentRequest(
                HttpMethod.Put,
                "foo",
                HttpStatusCode.Found,
                HttpMethod.Get,
                null);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithDeleteAnd302_MakesGetRequest()
        {
            await SendAsync_WithRedirect_MakesSubsequentRequest(
                HttpMethod.Delete,
                "foo",
                HttpStatusCode.Found,
                HttpMethod.Get,
                null);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithGetAnd302_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Get, null, HttpStatusCode.Found);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithHeadAnd302_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Head, null, HttpStatusCode.Found);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithPostAnd303_MakesGetRequest()
        {
            await SendAsync_WithRedirect_MakesSubsequentRequest(
                HttpMethod.Post,
                "foo",
                HttpStatusCode.SeeOther,
                HttpMethod.Get,
                null);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithPutAnd303_MakesGetRequest()
        {
            await SendAsync_WithRedirect_MakesSubsequentRequest(
                HttpMethod.Put,
                "foo",
                HttpStatusCode.SeeOther,
                HttpMethod.Get,
                null);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithDeleteAnd303_MakesGetRequest()
        {
            await SendAsync_WithRedirect_MakesSubsequentRequest(
                HttpMethod.Delete,
                "foo",
                HttpStatusCode.SeeOther,
                HttpMethod.Get,
                null);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithGetAnd303_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Get, null, HttpStatusCode.SeeOther);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithHeadAnd303_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Head, null, HttpStatusCode.SeeOther);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithPostAnd307_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Post, "foo", HttpStatusCode.TemporaryRedirect);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithPutAnd307_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Put, "foo", HttpStatusCode.TemporaryRedirect);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithDeleteAnd307_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Delete, "foo", HttpStatusCode.TemporaryRedirect);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithGetAnd307_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Get, null, HttpStatusCode.TemporaryRedirect);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithHeadAnd307_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Head, null, HttpStatusCode.TemporaryRedirect);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithPostAnd308_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Post, "foo", (HttpStatusCode)308);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithPutAnd308_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Put, "foo", (HttpStatusCode)308);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithDeleteAnd308_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Delete, "foo", (HttpStatusCode)308);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithGetAnd308_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Get, null, (HttpStatusCode)308);
        }

        [TestMethod, TestCategory("Unit")]
        public async Task SendAsync_WithHeadAnd308_DuplicatesRequest()
        {
            await SendAsync_WithRedirect_DuplicatesRequest(HttpMethod.Head, null, (HttpStatusCode)308);
        }

        /// <summary>
        /// Asserts that a redirect replays the same method and content as the initial request.
        /// </summary>
        private static async Task SendAsync_WithRedirect_DuplicatesRequest(HttpMethod httpMethod, string content, HttpStatusCode httpStatusCode)
        {
            await SendAsync_WithRedirect_MakesSubsequentRequest(
                httpMethod,
                content,
                httpStatusCode,
                httpMethod,
                content);
        }

        /// <summary>
        /// Sends a request with <paramref name="initialMethod"/> and optional <paramref name="initialContent"/>
        /// through a single <paramref name="httpStatusCode"/> redirect, then asserts the follow-up request's
        /// method and content (<c>null</c> <paramref name="expectedContent"/> means "no content").
        /// </summary>
        private static async Task SendAsync_WithRedirect_MakesSubsequentRequest(HttpMethod initialMethod, string initialContent, HttpStatusCode httpStatusCode, HttpMethod expectedMethod, string expectedContent)
        {
            // Only attach a body when the scenario has one, so GET/HEAD cases really start without content.
            StubbedHttpContent content = initialContent == null ? null : new StubbedHttpContent(initialContent);

            await SendAsync_WithRedirect_MakesNewRequest(
                initialMethod,
                content,
                httpStatusCode,
                async request =>
                {
                    request.Method.Should().Be(expectedMethod);
                    if (expectedContent == null)
                    {
                        request.Content.Should().BeNull();
                    }
                    else
                    {
                        content.Should().NotBeNull("a body can only be replayed if the initial request had one");

                        // The handler must buffer the original content exactly once.
                        content.SerializeToStreamAsyncCalls.Should().Be(1);
                        request.Content.Should().NotBeNull();
                        string actualContent = await request.Content.ReadAsStringAsync();
                        actualContent.Should().Be(expectedContent);
                    }
                });
        }

        /// <summary>
        /// Sends a request through one redirect of <paramref name="statusCode"/>, asserts the final response
        /// is 200 OK, and awaits <paramref name="validateRequest"/> on the final request message.
        /// </summary>
        /// <remarks>
        /// The validator is a <see cref="Func{T, TResult}"/> returning a <see cref="Task"/> (not an
        /// <see cref="Action{T}"/>), so async assertions are awaited and their failures fail the test.
        /// </remarks>
        private static async Task SendAsync_WithRedirect_MakesNewRequest(HttpMethod initialMethod, HttpContent content, HttpStatusCode statusCode, Func<HttpRequestMessage, Task> validateRequest)
        {
            // ARRANGE
            var client = GetHttpClient(statusCode);
            var request = GetRequest();
            request.Method = initialMethod;
            request.Content = content;

            // ACT
            HttpResponseMessage httpResponseMessage = await client.SendAsync(request);

            // ASSERT
            httpResponseMessage.StatusCode.Should().Be(HttpStatusCode.OK);
            await validateRequest(httpResponseMessage.RequestMessage);
        }

        private static HttpResponseMessage GetOkHttpResponseMessage()
        {
            return new HttpResponseMessage(HttpStatusCode.OK);
        }

        private static HttpResponseMessage GetRedirectHttpResponseMessage(HttpStatusCode statusCode, Uri redirectUri)
        {
            var response = new HttpResponseMessage(statusCode);
            response.Headers.Location = redirectUri;
            return response;
        }

        /// <summary>
        /// Creates a handler mock that answers each send with the next queued response, stamping the
        /// request onto it the way a real transport would.
        /// </summary>
        private static Mock<HttpMessageHandler> GetHttpMessageHandlerMock(Queue<HttpResponseMessage> responseQueue)
        {
            var mock = new Mock<HttpMessageHandler>();
            mock
                .Protected()
                .Setup<Task<HttpResponseMessage>>("SendAsync", ItExpr.IsAny<HttpRequestMessage>(), ItExpr.IsAny<CancellationToken>())
                .Returns<HttpRequestMessage, CancellationToken>((request, token) =>
                {
                    HttpResponseMessage response = responseQueue.Dequeue();
                    response.RequestMessage = request;
                    return Task.FromResult(response);
                });

            return mock;
        }

        private static HttpRequestMessage GetRequest()
        {
            return new HttpRequestMessage(HttpMethod.Get, "http://www.example.com/1");
        }

        /// <summary>
        /// Builds a client whose inner handler returns <paramref name="redirectCount"/> redirects of
        /// <paramref name="statusCode"/> to <paramref name="redirectUri"/> followed by a 200 OK.
        /// </summary>
        private static HttpClient GetHttpClient(HttpStatusCode statusCode = HttpStatusCode.Moved, Uri redirectUri = null, int redirectCount = 1, Action<RedirectingHandler> configure = null)
        {
            if (redirectUri == null)
            {
                redirectUri = new Uri("http://www.example.com/2", UriKind.Absolute);
            }

            var responses = Enumerable
                .Range(0, redirectCount)
                .Select(i => GetRedirectHttpResponseMessage(statusCode, redirectUri))
                .Concat(new[]
                {
                    GetOkHttpResponseMessage()
                });

            var responseQueue = new Queue<HttpResponseMessage>(responses);
            Mock<HttpMessageHandler> httpMessageHandlerMock = GetHttpMessageHandlerMock(responseQueue);
            var handler = new RedirectingHandler { InnerHandler = httpMessageHandlerMock.Object };
            if (configure != null)
            {
                configure(handler);
            }

            var client = new HttpClient(handler);
            return client;
        }

        /// <summary>
        /// <see cref="StringContent"/> that counts how often it is serialized.
        /// </summary>
        private class StubbedHttpContent : StringContent
        {
            public StubbedHttpContent(string content)
                : base(content)
            {
            }

            public int SerializeToStreamAsyncCalls { get; private set; }

            protected override Task SerializeToStreamAsync(Stream stream, TransportContext context)
            {
                SerializeToStreamAsyncCalls++;
                return base.SerializeToStreamAsync(stream, context);
            }
        }
    }
}
@SteveL-MSFT

This comment has been minimized.

Copy link

@SteveL-MSFT SteveL-MSFT commented Mar 7, 2017

What is the license for this code? I'd like to use it in https://github.com/powershell/powershell. Thanks!

@joelverhagen

This comment has been minimized.

Copy link
Owner Author

@joelverhagen joelverhagen commented Jan 10, 2021

For future readers, the license is MIT.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment