Skip to content

Instantly share code, notes, and snippets.

@cchacin
Created December 20, 2018 16:48
Show Gist options
  • Save cchacin/af90d76a7e8e2f5db5a9564be60b02d5 to your computer and use it in GitHub Desktop.
Save cchacin/af90d76a7e8e2f5db5a9564be60b02d5 to your computer and use it in GitHub Desktop.
Http2Client => OkHttp Call.Factory
package cronus.core;
import java.io.IOException;
import java.text.SimpleDateFormat;
import java.util.Arrays;
import java.util.Date;
import okhttp3.Call;
import okhttp3.Callback;
import okhttp3.Handshake;
import okhttp3.Headers;
import okhttp3.HttpUrl;
import okhttp3.Request;
import okhttp3.Response;
import okhttp3.WebSocket;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
import org.assertj.core.api.Assertions;
import org.junit.Rule;
import org.junit.Test;
/**
 * Exercises {@link Http2Client} through the OkHttp {@link Call.Factory} API against a local
 * {@link MockWebServer}. The {@link RecordedResponse} helper is adapted from OkHttp's own tests.
 */
public final class CallTest extends Assertions {
    /** Upper bound on how long {@link #executeSynchronously(Request)} waits for an enqueued call. */
    private static final long RESPONSE_TIMEOUT_SECONDS = 10;

    @Rule
    public final MockWebServer server = new MockWebServer();

    @Rule
    public final MockWebServer server2 = new MockWebServer();

    private final Call.Factory client = new Http2Client();

    @Test
    public void get() throws Exception {
        this.server.enqueue(new MockResponse()
                .setBody("abc")
                .clearHeaders()
                .addHeader("content-type: text/plain")
                .addHeader("content-length", "3"));
        final long sentAt = System.currentTimeMillis();
        final RecordedResponse recordedResponse = this.executeSynchronously("/", "User-Agent", "SyncApiTest");
        final long receivedAt = System.currentTimeMillis();
        recordedResponse.assertCode(200)
                .assertSuccessful()
                .assertHeaders(new Headers.Builder()
                        .add("content-type", "text/plain")
                        .add("content-length", "3")
                        .build())
                .assertBody("abc");
        // Timestamp assertions stay disabled until Http2Client is known to populate
        // sentRequestAtMillis/receivedResponseAtMillis:
        // .assertSentRequestAtMillis(sentAt, receivedAt)
        // .assertReceivedResponseAtMillis(sentAt, receivedAt);
        final RecordedRequest recordedRequest = this.server.takeRequest();
        assertThat(recordedRequest.getMethod()).isEqualTo("GET");
        assertThat(recordedRequest.getHeader("User-Agent")).isEqualTo("SyncApiTest");
        assertThat(recordedRequest.getBody().size()).isEqualTo(0);
        assertThat(recordedRequest.getHeader("Content-Length")).isEqualTo("0");
    }

    /**
     * Builds a GET request for {@code path} on {@link #server} and runs it through {@link #client}.
     *
     * @param path the request path, resolved against {@link #server}
     * @param headers alternating header names and values
     * @throws IllegalArgumentException if {@code headers} does not contain complete name/value pairs
     */
    private RecordedResponse executeSynchronously(final String path, final String... headers) throws IOException {
        if (headers.length % 2 != 0) {
            throw new IllegalArgumentException(
                    "Expected header name/value pairs but got " + headers.length + " strings");
        }
        final Request.Builder builder = new Request.Builder();
        builder.url(this.server.url(path));
        for (int i = 0, size = headers.length; i < size; i += 2) {
            builder.addHeader(headers[i], headers[i + 1]);
        }
        return this.executeSynchronously(builder.build());
    }

    /**
     * Enqueues {@code request} and blocks until its callback fires or
     * {@link #RESPONSE_TIMEOUT_SECONDS} elapse.
     *
     * @return the response and its body, or a {@link RecordedResponse} carrying the
     *     {@link IOException} reported via {@link Callback#onFailure} or raised while reading the body
     * @throws AssertionError if no callback arrives in time, the wait is interrupted, or the call
     *     fails with something other than an {@link IOException}
     */
    private RecordedResponse executeSynchronously(final Request request) throws IOException {
        final Call call = this.client.newCall(request);
        // The future hands the callback result to this thread with proper happens-before
        // ordering, and lets us wait for completion instead of sleeping a fixed time.
        final java.util.concurrent.CompletableFuture<Response> result = new java.util.concurrent.CompletableFuture<>();
        call.enqueue(new Callback() {
            @Override
            public void onFailure(final Call failedCall, final IOException e) {
                result.completeExceptionally(e);
            }

            @Override
            public void onResponse(final Call respondedCall, final Response response) {
                result.complete(response);
            }
        });

        final Response response;
        try {
            response = result.get(RESPONSE_TIMEOUT_SECONDS, java.util.concurrent.TimeUnit.SECONDS);
        }
        catch (final java.util.concurrent.ExecutionException e) {
            final Throwable cause = e.getCause();
            if (cause instanceof IOException) {
                return new RecordedResponse(request, null, null, null, (IOException) cause);
            }
            throw new AssertionError("Call failed unexpectedly", cause);
        }
        catch (final InterruptedException e) {
            Thread.currentThread().interrupt();
            call.cancel();
            throw new AssertionError("Interrupted while waiting for response", e);
        }
        catch (final java.util.concurrent.TimeoutException e) {
            call.cancel();
            throw new AssertionError("No response within " + RESPONSE_TIMEOUT_SECONDS + " seconds", e);
        }

        try {
            // string() consumes and closes the body.
            final String bodyString = response.body().string();
            return new RecordedResponse(request, response, null, bodyString, null);
        }
        catch (final IOException e) {
            return new RecordedResponse(request, null, null, null, e);
        }
    }

    /**
     * Captures the outcome of one call (a response and its body, or a failure) and offers fluent
     * assertions over it.
     */
    public static final class RecordedResponse extends Assertions {
        public final Request request;
        public final Response response;
        public final WebSocket webSocket;
        public final String body;
        public final IOException failure;

        public RecordedResponse(final Request request,
                                final Response response,
                                final WebSocket webSocket,
                                final String body,
                                final IOException failure) {
            this.request = request;
            this.response = response;
            this.webSocket = webSocket;
            this.body = body;
            this.failure = failure;
        }

        public RecordedResponse assertRequestUrl(final HttpUrl url) {
            assertThat(this.request.url()).isEqualTo(url);
            return this;
        }

        public RecordedResponse assertRequestMethod(final String method) {
            assertThat(this.request.method()).isEqualTo(method);
            return this;
        }

        public RecordedResponse assertRequestHeader(final String name, final String... values) {
            assertThat(this.request.headers(name)).containsExactly(values);
            return this;
        }

        public RecordedResponse assertCode(final int expectedCode) {
            assertThat(this.response.code()).isEqualTo(expectedCode);
            return this;
        }

        public RecordedResponse assertSuccessful() {
            assertThat(this.response.isSuccessful()).isTrue();
            return this;
        }

        public RecordedResponse assertNotSuccessful() {
            assertThat(this.response.isSuccessful()).isFalse();
            return this;
        }

        public RecordedResponse assertHeader(final String name, final String... values) {
            assertThat(this.response.headers(name)).containsExactly(values);
            return this;
        }

        public RecordedResponse assertHeaders(final Headers headers) {
            assertThat(this.response.headers().toMultimap()).isEqualTo(headers.toMultimap());
            return this;
        }

        public RecordedResponse assertBody(final String expectedBody) {
            assertThat(this.body).isEqualTo(expectedBody);
            return this;
        }

        public RecordedResponse assertHandshake() {
            final Handshake handshake = this.response.handshake();
            assertThat(handshake.tlsVersion()).isNotNull();
            assertThat(handshake.cipherSuite()).isNotNull();
            assertThat(handshake.peerPrincipal()).isNotNull();
            assertThat(handshake.peerCertificates()).hasSize(1);
            assertThat(handshake.localPrincipal()).isNull();
            assertThat(handshake.localCertificates()).hasSize(0);
            return this;
        }

        /**
         * Asserts that the current response was redirected and returns the prior response.
         */
        public RecordedResponse priorResponse() {
            final Response priorResponse = this.response.priorResponse();
            assertThat(priorResponse).isNotNull();
            assertThat(priorResponse.body()).isNull();
            return new RecordedResponse(priorResponse.request(), priorResponse, null, null, null);
        }

        /**
         * Asserts that the current response used the network and returns the network response.
         */
        public RecordedResponse networkResponse() {
            final Response networkResponse = this.response.networkResponse();
            assertThat(networkResponse).isNotNull();
            assertThat(networkResponse.body()).isNull();
            return new RecordedResponse(networkResponse.request(), networkResponse, null, null, null);
        }

        /**
         * Asserts that the current response didn't use the network.
         */
        public RecordedResponse assertNoNetworkResponse() {
            assertThat(this.response.networkResponse()).isNull();
            return this;
        }

        /**
         * Asserts that the current response didn't use the cache.
         */
        public RecordedResponse assertNoCacheResponse() {
            assertThat(this.response.cacheResponse()).isNull();
            return this;
        }

        /**
         * Asserts that the current response used the cache and returns the cache response.
         */
        public RecordedResponse cacheResponse() {
            final Response cacheResponse = this.response.cacheResponse();
            assertThat(cacheResponse).isNotNull();
            assertThat(cacheResponse.body()).isNull();
            return new RecordedResponse(cacheResponse.request(), cacheResponse, null, null, null);
        }

        /** Asserts that the recorded failure is an instance of at least one of the given types. */
        public RecordedResponse assertFailure(final Class<?>... allowedExceptionTypes) {
            boolean found = false;
            for (final Class<?> expectedClass : allowedExceptionTypes) {
                if (expectedClass.isInstance(this.failure)) {
                    found = true;
                    break;
                }
            }
            assertThat(found)
                    .as("Expected exception type among " + Arrays.toString(allowedExceptionTypes)
                            + " but was " + this.failure)
                    .isTrue();
            return this;
        }

        /** Asserts that a failure was recorded and its message equals one of {@code messages}. */
        public RecordedResponse assertFailure(final String... messages) {
            assertThat(this.failure).isNotNull();
            assertThat(Arrays.asList(messages)).contains(this.failure.getMessage());
            return this;
        }

        /** Asserts that a failure was recorded and its message matches at least one regex. */
        public RecordedResponse assertFailureMatches(final String... patterns) {
            assertThat(this.failure).isNotNull();
            final String message = this.failure.getMessage();
            // A null message can never match; report it instead of throwing NullPointerException.
            if (message != null) {
                for (final String pattern : patterns) {
                    if (message.matches(pattern)) {
                        return this;
                    }
                }
            }
            throw new AssertionError("Failure message " + message + " matched none of " + Arrays.toString(patterns));
        }

        public RecordedResponse assertSentRequestAtMillis(final long minimum, final long maximum) {
            this.assertDateInRange(minimum, this.response.sentRequestAtMillis(), maximum);
            return this;
        }

        public RecordedResponse assertReceivedResponseAtMillis(final long minimum, final long maximum) {
            this.assertDateInRange(minimum, this.response.receivedResponseAtMillis(), maximum);
            return this;
        }

        /** Asserts {@code minimum <= actual <= maximum}, describing all three as wall-clock times. */
        private void assertDateInRange(final long minimum, final long actual, final long maximum) {
            final String description = this.format(minimum) + " <= " + this.format(actual) + " <= " + this.format(maximum);
            assertThat(actual).as(description).isGreaterThanOrEqualTo(minimum);
            assertThat(actual).as(description).isLessThanOrEqualTo(maximum);
        }

        /** Formats epoch millis as local time of day; a new formatter per call keeps it thread-safe. */
        private String format(final long time) {
            return new SimpleDateFormat("HH:mm:ss.SSS").format(new Date(time));
        }

        public String getBody() {
            return this.body;
        }
    }
}
package cronus.core;
import java.io.IOException;
import java.io.InterruptedIOException;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.TreeSet;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.TimeUnit;
import java.util.function.Function;
import java.util.stream.Collectors;
import okhttp3.Call;
import okhttp3.Callback;
import okhttp3.Headers;
import okhttp3.Protocol;
import okhttp3.Request;
import okhttp3.Response;
import okhttp3.ResponseBody;
import okio.AsyncTimeout;
import okio.Timeout;
import static java.net.http.HttpClient.Redirect.ALWAYS;
import static java.net.http.HttpClient.Version.HTTP_2;
public class Http2Client implements Call.Factory {
private final HttpClient client;
private final ExecutorService executorService;
public Http2Client(final HttpClient client, final ExecutorService executorService) {
this.client = Objects.requireNonNull(client, "HttpClient can not be null");
this.executorService = Objects.requireNonNull(executorService, "ExecutorService can not be null");
}
public Http2Client() {
this(HttpClient.newBuilder()
.followRedirects(ALWAYS)
.version(HTTP_2)
.build(),
ForkJoinPool.commonPool());
}
@Override
public Call newCall(final Request request) {
return new Http2Call(this.client, request, this.executorService);
}
static class Http2Call implements Call {
private final HttpClient client;
private final Request originalRequest;
private final Executor executor;
private final CancellableFuture cancellableFuture;
private final AsyncTimeout timeout;
private boolean executed;
public Http2Call(final HttpClient client,
final Request originalRequest,
final Executor executor) {
this.client = client;
this.originalRequest = originalRequest;
this.executor = executor;
this.cancellableFuture = new CancellableFuture(this);
this.timeout = new AsyncTimeout() {
@Override
protected void timedOut() {
Http2Call.this.cancel();
}
};
client.connectTimeout()
.ifPresent(duration -> this.timeout.timeout(duration.toMillis(), TimeUnit.MILLISECONDS));
}
@Override
public Request request() {
return this.originalRequest;
}
@Override
public Response execute() throws IOException {
synchronized (this) {
if (this.executed) {
throw new IllegalStateException("Already Executed");
}
this.executed = true;
}
this.timeout.enter();
final HttpResponse<byte[]> httpResponse;
try {
httpResponse = this.client.send(this.toRequest(this.originalRequest),
HttpResponse.BodyHandlers.ofByteArray());
}
catch (final InterruptedException e) {
throw new RuntimeException("Interrupted", e);
}
catch (final IOException e) {
throw this.timeoutExit(e);
}
return this.fromResponse(httpResponse);
}
IOException timeoutExit(final IOException cause) {
if (!this.timeout.exit()) {
return cause;
}
final InterruptedIOException e = new InterruptedIOException("timeout");
if (cause != null) {
e.initCause(cause);
}
return e;
}
@Override
public void enqueue(final Callback responseCallback) {
synchronized (this) {
if (this.executed) {
throw new IllegalStateException("Already Executed");
}
this.executed = true;
}
final CompletableFuture<HttpResponse<byte[]>> httpResponse =
this.client.sendAsync(this.toRequest(this.originalRequest),
HttpResponse.BodyHandlers.ofByteArray());
httpResponse.handleAsync((response, throwable) -> {
if (throwable != null) {
responseCallback.onFailure(this, new IOException(throwable));
}
return this.fromResponse(response);
}).thenAcceptAsync(r -> {
try {
responseCallback.onResponse(this, r);
}
catch (final IOException e) {
responseCallback.onFailure(this, e);
}
}, this.executor);
}
@Override
public void cancel() {
this.cancellableFuture.cancel(true);
}
@Override
public boolean isExecuted() {
return this.executed;
}
@Override
public boolean isCanceled() {
return this.cancellableFuture.isCancelled();
}
@Override
public Timeout timeout() {
return null;
}
@Override
public Call clone() {
return new Http2Call(this.client, this.originalRequest, this.executor);
}
private HttpRequest toRequest(final Request request) {
final HttpRequest.BodyPublisher body;
if (request.body() == null) {
body = HttpRequest.BodyPublishers.noBody();
}
else {
body = HttpRequest.BodyPublishers.ofByteArray(request.body().toString().getBytes());
}
final HttpRequest.Builder requestBuilder = HttpRequest.newBuilder()
.uri(request.url().uri())
.version(HTTP_2);
final Map<String, Collection<String>> headers = this.filterRestrictedHeaders(request.headers().toMultimap());
if (!headers.isEmpty()) {
requestBuilder.headers(this.asString(headers));
}
switch (request.method()) {
case "GET":
return requestBuilder.GET().build();
case "POST":
return requestBuilder.POST(body).build();
case "PUT":
return requestBuilder.PUT(body).build();
case "DELETE":
return requestBuilder.DELETE().build();
default:
// fall back scenario, http implementations may restrict some methods
return requestBuilder.method(request.method(), body).build();
}
}
private Response fromResponse(final HttpResponse<byte[]> httpResponse) {
final Response.Builder builder = new Response.Builder();
final Map<String, String> h = new HashMap<>();
httpResponse.headers()
.map()
.forEach((key, value) -> value.forEach(v -> h.put(key, v)));
return builder
.body(ResponseBody.create(null, httpResponse.body()))
.request(this.originalRequest)
.code(httpResponse.statusCode())
.protocol(Protocol.HTTP_2)
.headers(Headers.of(h))
// .sentRequestAtMillis(httpResponse.headers())
.message(httpResponse.headers().firstValue("Reason-Phrase").orElse("OK"))
.build();
}
private static final Set<String> DISALLOWED_HEADERS_SET;
static {
// A case insensitive TreeSet of strings.
final TreeSet<String> treeSet = new TreeSet<>(String.CASE_INSENSITIVE_ORDER);
treeSet.addAll(Set.of("connection", "content-length", "date", "expect", "from", "host",
"origin", "referer", "upgrade", "via", "warning"));
DISALLOWED_HEADERS_SET = Collections.unmodifiableSet(treeSet);
}
private Map<String, Collection<String>> filterRestrictedHeaders(final Map<String, List<String>> headers) {
final Map<String, Collection<String>> filteredHeaders =
headers.keySet()
.stream()
.filter(headerName -> !DISALLOWED_HEADERS_SET.contains(
headerName))
.collect(Collectors.toMap(
Function.identity(),
headers::get));
filteredHeaders.computeIfAbsent("Accept", key -> List.of("*/*"));
return filteredHeaders;
}
private String[] asString(final Map<String, Collection<String>> headers) {
return headers.entrySet()
.stream()
.flatMap(entry -> entry.getValue()
.stream()
.map(value -> Arrays.asList(entry.getKey(), value))
.flatMap(List::stream)).toArray(String[]::new);
}
}
static class CancellableFuture extends CompletableFuture<Response> {
private final Call call;
CancellableFuture(final Call call) {
this.call = call;
}
@Override
public boolean cancel(final boolean mayInterruptIfRunning) {
if (mayInterruptIfRunning && !this.isDone()) {
this.call.cancel();
}
return super.cancel(mayInterruptIfRunning);
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment