Skip to content

Instantly share code, notes, and snippets.

@micw
Created September 17, 2015 10:20
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save micw/a8b20720db900762a29d to your computer and use it in GitHub Desktop.
Save micw/a8b20720db900762a29d to your computer and use it in GitHub Desktop.
WebMvcRequestDispatcherForwardFix
package testutils.mvc;
import static org.mockito.Matchers.anyString;

import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.net.URLDecoder;

import javax.servlet.Filter;
import javax.servlet.RequestDispatcher;
import javax.servlet.Servlet;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequestWrapper;

import org.mockito.AdditionalAnswers;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockRequestDispatcher;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.test.web.servlet.MockMvc;
import org.springframework.test.web.servlet.request.MockHttpServletRequestBuilder;
import org.springframework.test.web.servlet.request.MockMvcRequestBuilders;
import org.springframework.test.web.servlet.request.RequestPostProcessor;
public class WebMvcRequestDispatcherForwardFix
{
    /**
     * Patches {@code mockMvc} so that {@code RequestDispatcher.forward()} calls made while handling its requests
     * are executed against the MockMvc servlet and filter chain, instead of only being recorded.
     * <p>
     * A {@link RequestDispatcherReplacingRequestPostProcessor} is added to MockMvc's default request builder.
     * When no default builder is configured, a {@code GET "/"} builder is created and installed first.
     * <p>
     * This fix reads/modifies private fields of MockMvc via reflection, so changes to MockMvc's internals will
     * probably break it.
     *
     * @param mockMvc the MockMvc instance to patch
     * @throws IllegalStateException if the configured default request builder is not a
     *         {@link MockHttpServletRequestBuilder}, so no post-processor can be attached to it
     */
    public static void apply(MockMvc mockMvc)
    {
        Object configured = ReflectionTestUtils.getField(mockMvc, "defaultRequestBuilder");
        MockHttpServletRequestBuilder defaultRequestBuilder;
        if (configured == null)
        {
            // NOTE(review): presumably MockMvc merges this default into every performed request, so its
            // "/" URL is not what gets dispatched -- confirm for the Spring version in use.
            defaultRequestBuilder = MockMvcRequestBuilders.get("/");
            ReflectionTestUtils.setField(mockMvc, "defaultRequestBuilder", defaultRequestBuilder);
        }
        else if (configured instanceof MockHttpServletRequestBuilder)
        {
            defaultRequestBuilder = (MockHttpServletRequestBuilder) configured;
        }
        else
        {
            // Fail with a descriptive message rather than an opaque ClassCastException.
            throw new IllegalStateException("Cannot apply forward fix: MockMvc's default request builder is a "
                    + configured.getClass().getName() + ", not a " + MockHttpServletRequestBuilder.class.getName());
        }
        defaultRequestBuilder.with(new RequestDispatcherReplacingRequestPostProcessor(mockMvc));
    }

    /**
     * Wraps each request in a Mockito mock that delegates every call to the original request, except
     * {@code getRequestDispatcher(String)}, which returns a {@link MockRequestDispatcherWithForwardFix}.
     */
    protected static class RequestDispatcherReplacingRequestPostProcessor implements RequestPostProcessor
    {
        protected final MockMvc mockMvc;

        /**
         * @param mockMvc the MockMvc whose servlet and filters forwarded requests are dispatched to
         */
        public RequestDispatcherReplacingRequestPostProcessor(MockMvc mockMvc)
        {
            this.mockMvc = mockMvc;
        }

        /**
         * Returns a delegating mock of {@code request} whose request dispatchers perform real forwards.
         *
         * @param request the request built by MockMvc
         * @return a mock delegating to {@code request}, with {@code getRequestDispatcher} overridden
         */
        @Override
        public MockHttpServletRequest postProcessRequest(MockHttpServletRequest request)
        {
            MockHttpServletRequest mock = Mockito.mock(MockHttpServletRequest.class,
                    AdditionalAnswers.delegatesTo(request));
            // doAnswer().when() stubs without invoking the delegate's getRequestDispatcher during stubbing.
            Mockito.doAnswer(new Answer<RequestDispatcher>() {
                @Override
                public RequestDispatcher answer(InvocationOnMock invocation) throws Throwable
                {
                    // getArguments() exists in Mockito 1.x and 2+, unlike getArgumentAt(int, Class).
                    String path = (String) invocation.getArguments()[0];
                    return new MockRequestDispatcherWithForwardFix(mockMvc, path);
                }
            }).when(mock).getRequestDispatcher(anyString());
            return mock;
        }
    }

    /**
     * This version of MockRequestDispatcher handles forward() calls by:
     * - rejecting the forward if the response is already committed, and clearing uncommitted output
     * - reading the servlet and filters from MockMvc private fields
     * - creating the same MockFilterChain that MockMvc uses for normal requests
     * - setting the requestURI of the original request to the forward destination (without its query string;
     *   a query string is exposed via getQueryString() and merged into the request parameters)
     * - passing request/response to the chain
     * This should behave similarly to calling forward() within a regular web container.
     */
    protected static class MockRequestDispatcherWithForwardFix extends MockRequestDispatcher
    {
        protected final MockMvc mockMvc;
        protected final String resource;

        /**
         * @param mockMvc  the MockMvc whose servlet and filters the forward is dispatched to
         * @param resource the forward target path, optionally followed by {@code ?query}
         */
        public MockRequestDispatcherWithForwardFix(MockMvc mockMvc, String resource)
        {
            super(resource);
            this.mockMvc = mockMvc;
            this.resource = resource;
        }

        /**
         * Unwraps {@link HttpServletRequestWrapper}s (e.g. added by filters) until the underlying
         * {@link MockHttpServletRequest} is found.
         *
         * @param request the possibly-wrapped request
         * @return the underlying mock request
         * @throws IllegalArgumentException if no MockHttpServletRequest is found
         */
        protected MockHttpServletRequest getMockHttpServletRequest(ServletRequest request)
        {
            if (request instanceof MockHttpServletRequest)
            {
                return (MockHttpServletRequest) request;
            }
            if (request instanceof HttpServletRequestWrapper)
            {
                return getMockHttpServletRequest(((HttpServletRequestWrapper) request).getRequest());
            }
            throw new IllegalArgumentException("MockRequestDispatcher requires MockHttpServletRequest");
        }

        /**
         * Dispatches {@code request}/{@code response} through MockMvc's filter chain with the request URI set to
         * this dispatcher's resource.
         *
         * @throws IllegalStateException if the response is already committed
         * @throws RuntimeException      wrapping a ServletException/IOException raised by the chain
         */
        @Override
        public void forward(ServletRequest request, ServletResponse response)
        {
            // RequestDispatcher.forward contract: illegal after commit; uncommitted output is discarded.
            if (response.isCommitted())
            {
                throw new IllegalStateException("Cannot perform forward - response is already committed");
            }
            response.resetBuffer();

            Servlet servlet = (Servlet) ReflectionTestUtils.getField(mockMvc, "servlet");
            Filter[] filters = (Filter[]) ReflectionTestUtils.getField(mockMvc, "filters");
            MockFilterChain chain = new MockFilterChain(servlet, filters);

            MockHttpServletRequest mockRequest = getMockHttpServletRequest(request);
            int queryStart = resource.indexOf('?');
            if (queryStart < 0)
            {
                mockRequest.setRequestURI(resource);
            }
            else
            {
                // The request URI never contains the query string; putting it there breaks handler mapping.
                mockRequest.setRequestURI(resource.substring(0, queryStart));
                applyForwardedQueryString(mockRequest, resource.substring(queryStart + 1));
            }

            try
            {
                chain.doFilter(request, response);
            }
            catch (ServletException ex)
            {
                throw new RuntimeException("Forward to '" + resource + "' failed", ex);
            }
            catch (IOException ex)
            {
                throw new RuntimeException("Forward to '" + resource + "' failed", ex);
            }
            // RuntimeExceptions and Errors from the chain propagate unwrapped.
        }

        /**
         * Sets {@code queryString} on {@code request} and adds its decoded parameters in front of any existing
         * values with the same name, so forwarded parameters take precedence.
         */
        private static void applyForwardedQueryString(MockHttpServletRequest request, String queryString)
        {
            request.setQueryString(queryString);
            // Decode with the request's character encoding when set, otherwise UTF-8.
            String encoding = request.getCharacterEncoding() != null ? request.getCharacterEncoding() : "UTF-8";
            String[] pairs = queryString.split("&");
            // Walk backwards and prepend, so forwarded values keep their order and precede existing ones.
            for (int i = pairs.length - 1; i >= 0; i--)
            {
                String pair = pairs[i];
                if (pair.isEmpty())
                {
                    continue;
                }
                int eq = pair.indexOf('=');
                String name = decode(eq < 0 ? pair : pair.substring(0, eq), encoding);
                String value = eq < 0 ? "" : decode(pair.substring(eq + 1), encoding);
                String[] existing = request.getParameterValues(name);
                String[] merged;
                if (existing == null)
                {
                    merged = new String[] { value };
                }
                else
                {
                    merged = new String[existing.length + 1];
                    merged[0] = value;
                    System.arraycopy(existing, 0, merged, 1, existing.length);
                }
                request.setParameter(name, merged);
            }
        }

        /** URL-decodes {@code value}, converting the checked encoding exception into an unchecked one. */
        private static String decode(String value, String encoding)
        {
            try
            {
                return URLDecoder.decode(value, encoding);
            }
            catch (UnsupportedEncodingException ex)
            {
                throw new IllegalStateException("Unsupported character encoding: " + encoding, ex);
            }
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment