Skip to content

Instantly share code, notes, and snippets.

@aksh1618
Created December 28, 2021 09:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save aksh1618/b26cf732808092bd8a0e926adcb6b79f to your computer and use it in GitHub Desktop.
Save aksh1618/b26cf732808092bd8a0e926adcb6b79f to your computer and use it in GitHub Desktop.
A variation of Spring's ContentCachingRequestWrapper that caches JSON POST request bodies instead of form POST data.
/*
* Copyright 2021 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
import org.springframework.http.HttpMethod;
import org.springframework.http.MediaType;
import org.springframework.lang.Nullable;
import org.springframework.web.util.WebUtils;
import javax.servlet.ReadListener;
import javax.servlet.ServletInputStream;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import static org.springframework.http.MediaType.APPLICATION_JSON;
/**
 * {@link HttpServletRequest} wrapper that caches the body of JSON {@code POST} requests as it is
 * read, so that the body can be read again by later consumers.
 *
 * <p>For a {@code POST} request whose {@code Content-Type} is compatible with
 * {@code application/json}:
 * <ul>
 * <li>the first call to {@link #getInputStream()} returns a stream that copies every byte it
 * reads into an in-memory cache;</li>
 * <li>once the current stream is finished, the next call to {@link #getInputStream()} returns a
 * fresh stream that replays the cached bytes from the beginning;</li>
 * <li>{@link #getContentAsByteArray()} returns the bytes cached so far.</li>
 * </ul>
 * All other requests are passed through to the wrapped request.
 *
 * <p>The entire body is buffered in memory; no cache size limit is applied.
 * {@link #getReader()} creates one reader on first use and returns that same reader on every
 * later call, so a second consumer of {@link #getReader()} gets an already-consumed reader;
 * such consumers should use {@link #getInputStream()} to obtain a replay.
 *
 * @author Aakarshit Uppal
 * @see org.springframework.web.util.ContentCachingRequestWrapper
 */
public class JsonContentCachingRequestWrapper extends HttpServletRequestWrapper {

    /** Initial cache capacity used when the request does not declare a content length. */
    private static final int DEFAULT_INITIAL_CAPACITY = 1024;

    /**
     * Upper bound for the initial cache capacity. The declared content length comes from the
     * client, so it is not trusted for eager allocation; the buffer still grows as needed.
     */
    private static final int MAX_INITIAL_CAPACITY = 64 * 1024;

    /** Bytes read so far from the original JSON body. */
    private final ByteArrayOutputStream cachedContent;

    /** Stream most recently handed out for a JSON POST (caching stream or replay). */
    @Nullable
    private ServletInputStream inputStream;

    /** Reader created by the first {@link #getReader()} call. */
    @Nullable
    private BufferedReader reader;

    /**
     * Create a new JsonContentCachingRequestWrapper for the given servlet request.
     *
     * @param request the original servlet request
     */
    public JsonContentCachingRequestWrapper(HttpServletRequest request) {
        super(request);
        int contentLength = request.getContentLength();
        int initialCapacity = (contentLength >= 0
                ? Math.min(contentLength, MAX_INITIAL_CAPACITY)
                : DEFAULT_INITIAL_CAPACITY);
        this.cachedContent = new ByteArrayOutputStream(initialCapacity);
    }

    /**
     * Return the request body as a stream.
     *
     * <p>For JSON {@code POST} requests the first call returns a stream that caches what it reads.
     * Whenever the current stream is finished, a new stream replaying the cached content is
     * returned instead. Other requests get the wrapped request's stream.
     *
     * @return the body stream
     * @throws IOException if the wrapped request's stream cannot be obtained
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        if (!isJsonPost()) {
            return super.getInputStream();
        }
        if (this.inputStream == null) {
            this.inputStream =
                    new ContentCachingInputStream(getRequest().getInputStream(), this.cachedContent);
        }
        if (this.inputStream.isFinished()) {
            // The body has been fully consumed: hand out a replay of the cached bytes.
            this.inputStream = new CachedContentInputStream(getContentAsByteArray());
        }
        return this.inputStream;
    }

    /**
     * Return the declared request encoding.
     *
     * @return the declared encoding, or {@link WebUtils#DEFAULT_CHARACTER_ENCODING} if none
     */
    @Override
    public String getCharacterEncoding() {
        String enc = super.getCharacterEncoding();
        return (enc != null ? enc : WebUtils.DEFAULT_CHARACTER_ENCODING);
    }

    /**
     * Return a reader over {@link #getInputStream()}.
     *
     * <p>The reader is created on the first call and the same instance is returned afterwards.
     * It uses the declared request encoding. When none is declared, JSON {@code POST} bodies are
     * decoded as UTF-8 and other bodies with the {@link #getCharacterEncoding()} fallback.
     *
     * @return the body reader
     * @throws IOException if the stream cannot be obtained or the encoding is unsupported
     */
    @Override
    public BufferedReader getReader() throws IOException {
        if (this.reader == null) {
            this.reader = new BufferedReader(
                    new InputStreamReader(getInputStream(), getReaderEncoding()));
        }
        return this.reader;
    }

    /** Choose the charset name used by {@link #getReader()}. */
    private String getReaderEncoding() {
        if (super.getCharacterEncoding() == null && isJsonPost()) {
            // application/json defines no charset parameter and RFC 8259 requires UTF-8;
            // the servlet default (ISO-8859-1) would garble any non-ASCII characters.
            return StandardCharsets.UTF_8.name();
        }
        return getCharacterEncoding();
    }

    /** Whether this is a POST request whose Content-Type is compatible with application/json. */
    private boolean isJsonPost() {
        if (!HttpMethod.POST.matches(getMethod())) {
            return false;
        }
        String contentType = getContentType();
        if (contentType == null) {
            return false;
        }
        try {
            return APPLICATION_JSON.isCompatibleWith(MediaType.parseMediaType(contentType));
        } catch (IllegalArgumentException ex) {
            // InvalidMediaTypeException extends IllegalArgumentException: a malformed,
            // client-supplied Content-Type header is simply "not JSON" rather than a server error.
            return false;
        }
    }

    /**
     * Return the request content cached so far as a new byte array.
     *
     * <p>Content is cached only as it is read through {@link #getInputStream()} (or
     * {@link #getReader()}) of a JSON {@code POST} request. Until the body has been read, this
     * returns an empty array.
     *
     * @return a copy of the cached bytes
     */
    public byte[] getContentAsByteArray() {
        return this.cachedContent.toByteArray();
    }

    /** Stream over the original body that copies every byte it reads into the cache. */
    private static final class ContentCachingInputStream extends ServletInputStream {

        private final ServletInputStream delegate;

        private final ByteArrayOutputStream cache;

        /** Set once the delegate has reported end of stream (-1). */
        private boolean eofReached;

        ContentCachingInputStream(ServletInputStream delegate, ByteArrayOutputStream cache) {
            this.delegate = delegate;
            this.cache = cache;
        }

        @Override
        public int read() throws IOException {
            int ch = this.delegate.read();
            if (ch == -1) {
                this.eofReached = true;
            } else {
                this.cache.write(ch);
            }
            return ch;
        }

        @Override
        public int read(byte[] b, int off, int len) throws IOException {
            // Bulk read: avoids the byte-at-a-time default of InputStream.read(byte[], int, int).
            int count = this.delegate.read(b, off, len);
            if (count == -1) {
                this.eofReached = true;
            } else if (count > 0) {
                this.cache.write(b, off, count);
            }
            return count;
        }

        @Override
        public int available() throws IOException {
            return this.delegate.available();
        }

        @Override
        public boolean isFinished() {
            return this.eofReached || this.delegate.isFinished();
        }

        @Override
        public boolean isReady() {
            return this.delegate.isReady();
        }

        @Override
        public void setReadListener(ReadListener readListener) {
            this.delegate.setReadListener(readListener);
        }
    }

    /** In-memory stream replaying previously cached body bytes; reading never blocks. */
    private static final class CachedContentInputStream extends ServletInputStream {

        private final byte[] content;

        /** Index of the next byte to return. */
        private int position;

        @Nullable
        private ReadListener readListener;

        CachedContentInputStream(byte[] content) {
            this.content = content;
        }

        @Override
        public int available() {
            return this.content.length - this.position;
        }

        @Override
        public void close() {
            this.position = this.content.length;
        }

        @Override
        public boolean isFinished() {
            return this.position >= this.content.length;
        }

        @Override
        public boolean isReady() {
            // All data is already in memory, so a read can always proceed without blocking.
            return true;
        }

        @Override
        public void setReadListener(ReadListener readListener) {
            this.readListener = readListener;
            try {
                if (isFinished()) {
                    readListener.onAllDataRead();
                } else {
                    readListener.onDataAvailable();
                }
            } catch (IOException ex) {
                readListener.onError(ex);
            }
        }

        @Override
        public int read() throws IOException {
            if (isFinished()) {
                return -1;
            }
            // Mask to 0..255: a raw (signed) byte would break the InputStream contract and
            // make 0xFF indistinguishable from end of stream.
            int value = this.content[this.position++] & 0xFF;
            notifyIfAllDataRead();
            return value;
        }

        @Override
        public int read(byte[] b, int off, int len) throws IOException {
            if (off < 0 || len < 0 || len > b.length - off) {
                throw new IndexOutOfBoundsException();
            }
            if (len == 0) {
                return 0;
            }
            if (isFinished()) {
                return -1;
            }
            int count = Math.min(len, available());
            System.arraycopy(this.content, this.position, b, off, count);
            this.position += count;
            notifyIfAllDataRead();
            return count;
        }

        /** Tell a registered listener once the last byte has been handed out. */
        private void notifyIfAllDataRead() throws IOException {
            if (isFinished() && this.readListener != null) {
                try {
                    this.readListener.onAllDataRead();
                } catch (IOException ex) {
                    this.readListener.onError(ex);
                    throw ex;
                }
            }
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment