Skip to content

Instantly share code, notes, and snippets.

@dtodt
Last active February 23, 2023 18:18
Show Gist options
  • Star 19 You must be signed in to star a gist
  • Fork 6 You must be signed in to fork a gist
  • Save dtodt/2b62a18e87375682167027bb7feb6752 to your computer and use it in GitHub Desktop.
Save dtodt/2b62a18e87375682167027bb7feb6752 to your computer and use it in GitHub Desktop.
Retrofit2@2.5.0 - Retry Adapter Factory - Call & CompletableFuture
package com.company.retrofit2.annotation;
import java.lang.annotation.Documented;
import java.lang.annotation.Retention;
import java.lang.annotation.Target;
import static java.lang.annotation.ElementType.METHOD;
import static java.lang.annotation.RetentionPolicy.RUNTIME;
/**
 * Marks a Retrofit service method whose asynchronous calls should be re-sent when they do not succeed.
 * <p>
 * When present on a method, {@link #max()} takes precedence over the default retry count configured on the
 * call adapter factory. A value of zero or less disables retrying for that method.
 *
 * @author dtodt
 */
@Retention(RUNTIME)
@Target(METHOD)
@Documented
public @interface Retry {

    /**
     * Upper bound on the number of retries performed after the first attempt; {@code 3} unless overridden.
     *
     * @return the maximum number of retries for the annotated method
     */
    int max() default 3;
}
package com.company.retrofit2;
import com.company.retrofit2.annotation.Retry;
import lombok.AllArgsConstructor;
import lombok.NoArgsConstructor;
import lombok.RequiredArgsConstructor;
import lombok.extern.slf4j.Slf4j;
import okhttp3.Request;
import retrofit2.*;
import java.io.IOException;
import java.lang.annotation.Annotation;
import java.lang.reflect.Type;
import java.util.Arrays;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * A Retrofit {@link CallAdapter.Factory} that transparently re-enqueues asynchronous calls that did not succeed.
 * <p>
 * The retry count comes from the {@link Retry} annotation on the service method when present. Otherwise it comes
 * from the factory-wide default given to the constructor (0, i.e. no retries, with the no-arg constructor). A call
 * is retried when it fails with an exception or when the server answers with a non-successful status code. After
 * the first attempt it is re-sent at most that many times.
 * <p>
 * NOTE THAT: only asynchronous execution is retried, i.e. {@link Call#enqueue(Callback)} and adapters built on it
 * such as {@link java.util.concurrent.CompletableFuture}. {@link Call#execute()} is passed straight through.
 * <p>
 * When the retries are exhausted by exceptions, the callback receives a {@link TimeoutException} whose cause is the
 * last failure. Cancelled calls are never retried; their failure is forwarded unchanged.
 *
 * @author dtodt
 */
@Slf4j
@NoArgsConstructor
@AllArgsConstructor
public class RetryCallAdapterFactory extends CallAdapter.Factory {
    /**
     * Default number of retries for service methods without a {@link Retry} annotation.
     */
    private int maxRetries = 0;

    @Override
    public CallAdapter<?, ?> get(final Type returnType, final Annotation[] annotations, final Retrofit retrofit) {
        final Retry retry = getRetry(annotations);
        // A method-level @Retry overrides the factory-wide default.
        final int retries = retry != null ? retry.max() : maxRetries;
        log.debug("Starting a CallAdapter with {} retries.", retries);
        return new RetryCallAdapter<>(
                retrofit.nextCallAdapter(this, returnType, annotations),
                retries
        );
    }

    /**
     * Returns the first {@link Retry} annotation among {@code annotations}, or {@code null} if there is none.
     */
    private static Retry getRetry(final Annotation[] annotations) {
        // Sequential on purpose: a method carries only a few annotations, so a parallel stream is pure overhead.
        return Arrays.stream(annotations)
                .filter(Retry.class::isInstance)
                .map(Retry.class::cast)
                .findFirst()
                .orElse(null);
    }

    /**
     * Wraps the next adapter in the chain, handing it a {@link RetryingCall} when retries are enabled.
     */
    @RequiredArgsConstructor
    private static final class RetryCallAdapter<R, T> implements CallAdapter<R, T> {
        private final CallAdapter<R, T> delegated;
        private final int maxRetries;

        @Override
        public Type responseType() {
            return delegated.responseType();
        }

        @Override
        public T adapt(final Call<R> call) {
            // Zero or negative retry counts leave the call untouched.
            return delegated.adapt(maxRetries > 0 ? new RetryingCall<>(call, maxRetries) : call);
        }
    }

    /**
     * A {@link Call} whose asynchronous execution is retried. Synchronous {@link #execute()} is not retried.
     */
    @RequiredArgsConstructor
    private static final class RetryingCall<R> implements Call<R> {
        private final Call<R> delegated;
        private final int maxRetries;
        /**
         * The attempt currently on the wire: the original call or one of its clones. Updated by RetryCallback.
         */
        private volatile Call<R> inFlight;
        /**
         * Set by {@link #cancel()} so that no further retries are scheduled.
         */
        private volatile boolean canceled;

        @Override
        public Response<R> execute() throws IOException {
            return delegated.execute();
        }

        @Override
        public void enqueue(final Callback<R> callback) {
            inFlight = delegated;
            delegated.enqueue(new RetryCallback<>(this, callback, maxRetries));
        }

        @Override
        public boolean isExecuted() {
            return delegated.isExecuted();
        }

        @Override
        public void cancel() {
            canceled = true;
            delegated.cancel();
            // A retry clone may be the attempt currently running; cancel it too so the request really stops.
            final Call<R> current = inFlight;
            if (current != null) {
                current.cancel();
            }
        }

        @Override
        public boolean isCanceled() {
            return canceled || delegated.isCanceled();
        }

        @Override
        public Call<R> clone() {
            return new RetryingCall<>(delegated.clone(), maxRetries);
        }

        @Override
        public Request request() {
            return delegated.request();
        }
    }

    /**
     * Callback that re-enqueues a fresh clone of the original call until it succeeds, is cancelled, or runs out of
     * retries. The user callback is notified once, with the {@link RetryingCall} it was enqueued on.
     */
    @RequiredArgsConstructor
    private static final class RetryCallback<T> implements Callback<T> {
        private final RetryingCall<T> owner;
        private final Callback<T> callback;
        private final int maxRetries;
        private final AtomicInteger retryCount = new AtomicInteger(0);

        @Override
        public void onResponse(final Call<T> call, final Response<T> response) {
            if (!response.isSuccessful() && !owner.isCanceled() && retryCount.incrementAndGet() <= maxRetries) {
                log.debug("Call with no success result code: {}", response.code());
                retryCall();
            } else {
                callback.onResponse(owner, response);
            }
        }

        @Override
        public void onFailure(final Call<T> call, final Throwable t) {
            log.debug("Call failed with message: {}", t.getLocalizedMessage(), t);
            if (owner.isCanceled()) {
                // The caller gave up on this call: never resend it, and report the cancellation failure as-is.
                callback.onFailure(owner, t);
            } else if (retryCount.incrementAndGet() <= maxRetries) {
                retryCall();
            } else if (maxRetries > 0) {
                log.debug("No retries left sending timeout up.");
                final TimeoutException timeout =
                        new TimeoutException(String.format("No retries left after %s attempts.", maxRetries));
                // TimeoutException has no (message, cause) constructor; keep the last failure for diagnostics.
                timeout.initCause(t);
                callback.onFailure(owner, timeout);
            } else {
                callback.onFailure(owner, t);
            }
        }

        /**
         * Enqueues a fresh clone of the original call, publishing it so that {@link RetryingCall#cancel()} can reach it.
         */
        private void retryCall() {
            log.warn("{}/{} Retrying...", retryCount.get(), maxRetries);
            final Call<T> next = owner.delegated.clone();
            owner.inFlight = next;
            // Re-check after publishing: a concurrent cancel() may have read the previous attempt and missed `next`.
            if (owner.isCanceled()) {
                next.cancel();
            }
            next.enqueue(this);
        }
    }
}
package com.company.retrofit2;
import com.company.retrofit2.annotation.Retry;
import lombok.extern.slf4j.Slf4j;
import okhttp3.OkHttpClient;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import retrofit2.*;
import retrofit2.converter.scalars.ScalarsConverterFactory;
import retrofit2.http.GET;
import java.io.IOException;
import java.net.ConnectException;
import java.net.SocketTimeoutException;
import java.time.Duration;
import java.util.concurrent.*;
import static org.junit.Assert.fail;
/**
 * {@link RetryCallAdapterFactory} validations and examples, definition: {@link FakeRestService}.
 * <p>
 * The factory under test is built with a default of 2 retries. The service methods override it with
 * {@link Retry} where annotated. The mock server's request count tells how many attempts were made.
 */
@Slf4j
public class RetryCallAdapterFactoryTest {
    private static final String FAIL = "Must throw exception";
    private static final String OK = "OK";
    private static final String PATH = "/fake";
    // Expected request counts, named E<retries>O<total attempts> (first attempt + retries).
    private static final int E1O2 = 2;
    private static final int E2O3 = 3;
    private static final int E3O4 = 4;
    private static final int A200 = 200;
    private static final int A401 = 401;
    private static final int A403 = 403;
    private static final int A500 = 500;
    private final MockResponse a200Response = new MockResponse().setResponseCode(A200).setBody(OK);
    private final MockResponse a401Response = new MockResponse().setResponseCode(A401);
    private final MockResponse a403Response = new MockResponse().setResponseCode(A403);
    private final MockResponse a500Response = new MockResponse().setResponseCode(A500);
    private MockWebServer webServer;
    private FakeRestService fakeService;
    private CountDownLatch latch;

    @Before
    public void setUp() throws Exception {
        latch = new CountDownLatch(1);
        webServer = new MockWebServer();
        webServer.start();
        fakeService = getFakeService();
    }

    @After
    public void tearDown() throws Exception {
        fakeService = null;
        webServer.shutdown();
    }

    @Test
    public void successOnFirstRequest() throws ExecutionException, InterruptedException {
        webServer.enqueue(a200Response);
        final CompletableFuture<String> future = fakeService.futureGet();
        Assert.assertEquals(OK, future.get());
        Assert.assertEquals(1, webServer.getRequestCount());
    }

    @Test
    public void successOnLastRequest() throws ExecutionException, InterruptedException {
        webServer.enqueue(a403Response);
        webServer.enqueue(a401Response);
        webServer.enqueue(a200Response);
        final CompletableFuture<String> future = fakeService.futureGet();
        Assert.assertEquals(OK, future.get());
        Assert.assertEquals(E2O3, webServer.getRequestCount());
    }

    @Test
    public void failureWith500() {
        webServer.enqueue(a500Response);
        webServer.enqueue(a500Response);
        webServer.enqueue(a500Response);
        assertFailsWith(fakeService.futureGet(), HttpException.class);
        Assert.assertEquals(E2O3, webServer.getRequestCount());
    }

    @Test
    public void failureWithSocketConnection() {
        // No response enqueued: every attempt hits the client read timeout.
        assertFailsWith(fakeService.futureGet(), TimeoutException.class);
        Assert.assertEquals(E2O3, webServer.getRequestCount());
    }

    @Test
    public void failureWithNoConnection() throws IOException {
        webServer.close();
        assertFailsWith(fakeService.futureGet(), TimeoutException.class);
        Assert.assertEquals(0, webServer.getRequestCount());
    }

    @Test
    public void callSyncFailureWith500() throws IOException {
        webServer.enqueue(a500Response);
        final Call<String> call = fakeService.callGet();
        final Response<String> response = call.execute();
        Assert.assertNotNull(response);
        Assert.assertNull(response.body());
        // Synchronous execution is not retried.
        Assert.assertEquals(1, webServer.getRequestCount());
    }

    @Test(expected = SocketTimeoutException.class)
    public void callSyncFailureWithSocketConnection() throws IOException {
        final Call<String> call = fakeService.callGet();
        try {
            call.execute();
            fail(FAIL);
        } catch (IOException e) {
            Assert.assertEquals(1, webServer.getRequestCount());
            throw e;
        }
    }

    @Test(expected = ConnectException.class)
    public void callSyncFailureWithNoConnection() throws IOException {
        webServer.close();
        final Call<String> call = fakeService.callGet();
        try {
            call.execute();
            fail(FAIL);
        } catch (IOException e) {
            Assert.assertEquals(0, webServer.getRequestCount());
            throw e;
        }
    }

    @Test
    public void callAsyncFailureWith500() throws InterruptedException {
        webServer.enqueue(a500Response);
        webServer.enqueue(a500Response);
        webServer.enqueue(a500Response);
        webServer.enqueue(a500Response);
        final StringBuilder result = new StringBuilder();
        final Call<String> call = fakeService.callGet();
        call.enqueue(getFakeCallback(result));
        final boolean ended = latch.await(1, TimeUnit.MINUTES);
        Assert.assertTrue(ended);
        Assert.assertEquals("", result.toString());
        Assert.assertEquals(E3O4, webServer.getRequestCount());
    }

    @Test
    public void callAsyncFailureWithSocketConnection() throws InterruptedException {
        final StringBuilder result = new StringBuilder();
        final Call<String> call = fakeService.callGet();
        call.enqueue(getFakeCallback(result));
        final boolean ended = latch.await(1, TimeUnit.MINUTES);
        Assert.assertTrue(ended);
        Assert.assertEquals("", result.toString());
        Assert.assertEquals(E3O4, webServer.getRequestCount());
    }

    @Test
    public void callAsyncFailureWithNoConnection() throws IOException, InterruptedException {
        webServer.close();
        final StringBuilder result = new StringBuilder();
        final Call<String> call = fakeService.callGet();
        call.enqueue(getFakeCallback(result));
        final boolean ended = latch.await(1, TimeUnit.MINUTES);
        Assert.assertTrue(ended);
        Assert.assertEquals("", result.toString());
        Assert.assertEquals(0, webServer.getRequestCount());
    }

    @Test
    public void futureResponseFailureWith500() throws ExecutionException, InterruptedException {
        webServer.enqueue(a500Response);
        webServer.enqueue(a500Response);
        final Response<String> response = fakeService.futureResponseGet().get();
        Assert.assertNotNull(response);
        Assert.assertEquals(A500, response.code());
        Assert.assertNull(response.body());
        Assert.assertEquals(E1O2, webServer.getRequestCount());
    }

    @Test
    public void futureResponseFailureWithSocketConnection() {
        assertFailsWith(fakeService.futureResponseGet(), TimeoutException.class);
        Assert.assertEquals(E1O2, webServer.getRequestCount());
    }

    @Test
    public void futureResponseFailureWithNoConnection() throws IOException {
        webServer.close();
        assertFailsWith(fakeService.futureResponseGet(), TimeoutException.class);
        Assert.assertEquals(0, webServer.getRequestCount());
    }

    /**
     * Waits for {@code future} and asserts that it completed exceptionally with a cause of exactly type
     * {@code expected}.
     */
    private static void assertFailsWith(final Future<?> future, final Class<? extends Throwable> expected) {
        try {
            future.get();
            fail(FAIL);
        } catch (final ExecutionException e) {
            Assert.assertEquals(expected, e.getCause().getClass());
        } catch (final InterruptedException e) {
            // Restore the flag and fail loudly instead of dereferencing the (null) cause of an interruption.
            Thread.currentThread().interrupt();
            throw new AssertionError("Interrupted while waiting for the call to finish", e);
        }
    }

    private FakeRestService getFakeService() {
        return getRetrofitInstance().create(FakeRestService.class);
    }

    /**
     * Builds a Retrofit client against the mock server with 2-second timeouts and a default of 2 retries.
     */
    private Retrofit getRetrofitInstance() {
        final Duration seconds = Duration.ofSeconds(2);
        final OkHttpClient client = new OkHttpClient()
                .newBuilder()
                .readTimeout(seconds)
                .connectTimeout(seconds)
                .writeTimeout(seconds)
                .build();
        return new Retrofit.Builder()
                .baseUrl(webServer.url("/"))
                .client(client)
                .addCallAdapterFactory(new RetryCallAdapterFactory(2))
                .addConverterFactory(ScalarsConverterFactory.create())
                .build();
    }

    /**
     * Returns a callback that appends successful bodies to {@code result} and releases {@link #latch} either way.
     */
    private Callback<String> getFakeCallback(final StringBuilder result) {
        return new Callback<String>() {
            @Override
            public void onResponse(Call<String> call, Response<String> response) {
                if (response.isSuccessful()) {
                    result.append(response.body());
                }
                latch.countDown();
            }

            @Override
            public void onFailure(Call<String> call, Throwable t) {
                latch.countDown();
            }
        };
    }

    /**
     * Retrofit rest service representation.
     */
    private interface FakeRestService {
        /**
         * More likely used on sync methods when you want to validate the response code.
         */
        @Retry
        @GET(PATH)
        Call<String> callGet();

        /**
         * More likely used on async or sync methods when you just need the response value.
         */
        @GET(PATH)
        CompletableFuture<String> futureGet();

        /**
         * (RECOMMENDED)
         * More likely used on async or sync methods when you want to validate the response code.
         */
        @Retry(max = 1)
        @GET(PATH)
        CompletableFuture<Response<String>> futureResponseGet();
    }
}
@emtiazahmed
Copy link

Hi,

Thank you for writing this up. Did you by any chance implement support for exponential backoff in this?

@dtodt
Copy link
Author

dtodt commented Sep 13, 2019 via email

@jbaezml
Copy link

jbaezml commented Oct 23, 2019

Hi! How can I use this with synchronous API calls only? Your tests only use async calls, and the execute method does not retry!

@cah-parvendan-somasundar

Is a Kotlin version available?

@emouawad
Copy link

About time somebody puts this into a library

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment