Skip to content

Instantly share code, notes, and snippets.

@mzakyalvan
Created February 23, 2020 08:54
Show Gist options
  • Save mzakyalvan/4fe1734d989404027ccc1bfcde0f51dc to your computer and use it in GitHub Desktop.
Save mzakyalvan/4fe1734d989404027ccc1bfcde0f51dc to your computer and use it in GitHub Desktop.
package com.tiket.poc.redis.track;
import java.time.Duration;
import javax.validation.constraints.NotBlank;
import javax.validation.constraints.NotNull;
import reactor.core.publisher.Mono;
/**
 * Tracks the processing state of requests, each identified by a tracking key.
 *
 * <p>A tracked request is started, then completed either successfully or with an error, after
 * which its tracking data is retained for some duration before it expires.
 *
 * @param <I> request type
 * @param <O> response type
 * @author zakyalvan
 */
public interface RequestTracker<I, O> {

  /**
   * Starts (or records another attempt of) tracking of a request.
   *
   * @param correlation client-supplied correlation identifier stored with the tracking data
   * @param request request to track
   * @return the tracking data of the request after this attempt has been recorded
   */
  Mono<TrackingData<I, O>> startTracking(@NotBlank String correlation, @NotNull I request);

  /**
   * Marks the tracked request as successfully completed, using the implementation's default
   * retention for completed requests.
   *
   * @param key tracking key of the request
   * @return the tracking data after completion; empty if nothing is tracked under {@code key}
   */
  Mono<TrackingData<I, O>> completeTracking(@NotBlank String key);

  /**
   * Marks the tracked request as successfully completed and keeps its tracking data for
   * {@code retention}.
   *
   * @param key tracking key of the request
   * @param retention how long the completed tracking data is kept
   * @return the tracking data after completion; empty if nothing is tracked under {@code key}
   */
  Mono<TrackingData<I, O>> completeTracking(@NotBlank String key, @NotNull Duration retention);

  /**
   * Marks the tracked request as failed, using the implementation's default retention for failed
   * requests.
   *
   * @param key tracking key of the request
   * @param error error the request failed with
   * @return the tracking data after completion; empty if nothing is tracked under {@code key}
   */
  Mono<TrackingData<I, O>> completeTracking(@NotBlank String key, @NotNull Throwable error);

  /**
   * Marks the tracked request as failed and keeps its tracking data for {@code retention}.
   *
   * @param key tracking key of the request
   * @param error error the request failed with
   * @param retention how long the failed tracking data is kept
   * @return the tracking data after completion; empty if nothing is tracked under {@code key}
   */
  Mono<TrackingData<I, O>> completeTracking(@NotBlank String key, @NotNull Throwable error, @NotNull Duration retention);

  /**
   * Stops tracking of a request and discards its tracking data.
   *
   * @param key tracking key of the request
   * @return the tracking data as it was before removal; empty if nothing was tracked under
   *     {@code key}
   */
  Mono<TrackingData<I, O>> removeTracking(@NotBlank String key);
}
// ======================================
package com.tiket.poc.redis.track;
import static com.tiket.poc.redis.track.RequestState.COMPLETED_STATE_VALUE;
import static com.tiket.poc.redis.track.RequestState.FAILED_STATE_VALUE;
import static com.tiket.poc.redis.track.RequestState.PROGRESSED_STATE_VALUE;
import static com.tiket.poc.redis.track.TrackingData.COMPLETED_TIME_FIELD;
import static com.tiket.poc.redis.track.TrackingData.REQUEST_OBJECT_FIELD;
import static com.tiket.poc.redis.track.TrackingData.REQUEST_STATE_FIELD;
import static com.tiket.poc.redis.track.TrackingData.RESPONSE_OBJECT_FIELD;
import static com.tiket.poc.redis.track.TrackingData.CLIENT_CORRELATION_FIELD;
import static com.tiket.poc.redis.track.TrackingData.TOTAL_ATTEMPTS_FIELD;
import static com.tiket.poc.redis.track.TrackingData.TRACKED_TIME_FIELD;
import static com.tiket.poc.redis.track.TrackingData.TRACKING_KEY_FIELD;
import static java.nio.charset.StandardCharsets.UTF_8;
import com.fasterxml.jackson.databind.ObjectMapper;
import java.nio.ByteBuffer;
import java.time.Duration;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import javax.validation.constraints.NotNull;
import lombok.extern.slf4j.Slf4j;
import org.springframework.dao.DataAccessException;
import org.springframework.data.redis.connection.RedisConnection;
import org.springframework.data.redis.core.RedisCallback;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.hash.HashMapper;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.validation.annotation.Validated;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;
import reactor.util.function.Tuple2;
/**
 * Redis-backed {@link RequestTracker}. Each tracked request is stored as one Redis hash under the
 * key {@code "tiket.request.track:" + <lower-cased simple name of the request type> + ":" +
 * <tracking key>}, where the tracking key is produced by the configured {@link KeyFactory}.
 *
 * <p>Redis is accessed through the synchronous {@link RedisTemplate#execute(RedisCallback)}. Each
 * public method wraps that call in a {@link Mono}, so nothing touches Redis until subscription.
 *
 * @param <I> request type
 * @param <O> response type
 * @author zakyalvan
 */
@Slf4j
@Validated
public class DefaultRequestTracker<I, O> implements RequestTracker<I, O> {

  private static final String DEFAULT_KEY_PREFIX = "tiket.request.track:";

  private final String keyPrefix;
  private final KeyFactory<I> keyFactory;
  private final RedisTemplate<Object, Object> redisTemplate;
  private final HashMapper<TrackingData<I, O>, byte[], byte[]> hashMapper;

  /**
   * How long {@link TrackingData} in {@link RequestState#PROGRESSED} will be retained. Applied as
   * the key expiry on the first start attempt of a request.
   */
  @NotNull
  private Duration progressedRetention = StartTrackingCallback.DEFAULT_PROGRESSED_RETENTION;

  /**
   * How long {@link TrackingData} in {@link RequestState#COMPLETED} will be retained.
   */
  @NotNull
  private Duration completedRetention = Duration.ofSeconds(60);

  /**
   * How long {@link TrackingData} in {@link RequestState#FAILED} will be retained. A zero
   * retention makes Redis drop the key as soon as the failure has been recorded.
   */
  @NotNull
  private Duration failedRetention = Duration.ZERO;

  /**
   * Creates a tracker.
   *
   * @param keyFactory derives the tracking key from a request
   * @param redisTemplate template used to execute the Redis callbacks
   * @param objectMapper used to deserialize request/response objects stored in the hash
   * @param requestType request class; its lower-cased simple name is part of the key prefix
   * @param responseType response class
   * @throws IllegalArgumentException if any argument is {@code null}
   */
  public DefaultRequestTracker(KeyFactory<I> keyFactory, RedisTemplate<Object, Object> redisTemplate,
      ObjectMapper objectMapper, Class<I> requestType, Class<O> responseType) {
    Assert.notNull(keyFactory, "Key factory must be provided");
    Assert.notNull(redisTemplate, "Redis template must be provided");
    Assert.notNull(objectMapper, "Object mapper must be provided");
    Assert.notNull(requestType, "Request type must be provided");
    Assert.notNull(responseType, "Response type must be provided");
    // Locale.ROOT: default-locale lower-casing changes the key under e.g. a Turkish locale ("I" -> "ı").
    this.keyPrefix = DEFAULT_KEY_PREFIX
        .concat(requestType.getSimpleName().toLowerCase(java.util.Locale.ROOT))
        .concat(":");
    this.keyFactory = keyFactory;
    this.redisTemplate = redisTemplate;
    this.hashMapper = new RequestTrackingHashMapper<>(objectMapper, requestType, responseType);
  }

  /**
   * {@inheritDoc}
   *
   * <p>The first attempt for a key sets the key's expiry to the progressed retention; every
   * attempt increments the stored attempt counter.
   */
  @Override
  public Mono<TrackingData<I, O>> startTracking(String clientCorrelation, I requestObject) {
    return keyFactory.createKey(requestObject)
        .map(trackingKey -> new StartTrackingCallback<>(keyPrefix, trackingKey, clientCorrelation,
            hashMapper, requestObject, progressedRetention))
        // NOTE(review): execute() blocks the thread that emits the key; publishOn only moves the
        // downstream signals. Consider subscribeOn with a scheduler intended for blocking work.
        .flatMap(trackCallback -> Mono.fromCallable(() -> redisTemplate.execute(trackCallback)))
        .publishOn(Schedulers.parallel());
  }

  /** Completes successfully, retaining the data for the completed retention. */
  @Override
  public Mono<TrackingData<I, O>> completeTracking(String trackingKey) {
    return completeTracking(trackingKey, completedRetention);
  }

  /** Completes successfully, retaining the data for {@code retentionDuration}. */
  @Override
  public Mono<TrackingData<I, O>> completeTracking(String trackingKey, Duration retentionDuration) {
    return completeTracking(trackingKey, null, retentionDuration);
  }

  /** Completes as failed, retaining the data for the failed retention. */
  @Override
  public Mono<TrackingData<I, O>> completeTracking(String trackingKey, Throwable thrownError) {
    return completeTracking(trackingKey, thrownError, failedRetention);
  }

  /**
   * Completes the tracking: failed when {@code thrownError} is non-null, completed otherwise. Only
   * a request still in {@link RequestState#PROGRESSED} has its state, completion time and expiry
   * updated.
   */
  @Override
  public Mono<TrackingData<I, O>> completeTracking(String trackingKey, Throwable thrownError, Duration retentionDuration) {
    return Mono.fromCallable(() -> new CompleteTrackingCallback<>(keyPrefix, trackingKey, hashMapper, thrownError, retentionDuration))
        .flatMap(completeCallback -> Mono.fromCallable(() -> redisTemplate.execute(completeCallback)));
  }

  /** Reads, then deletes, the tracking hash; emits the data read (empty if none existed). */
  @Override
  public Mono<TrackingData<I, O>> removeTracking(String trackingKey) {
    return Mono.fromCallable(() -> new RemoveTrackingCallback<>(keyPrefix, trackingKey, hashMapper))
        .flatMap(removeCallback -> Mono.fromCallable(() -> redisTemplate.execute(removeCallback)));
  }

  /**
   * {@link RedisCallback} recording a start attempt of a request.
   *
   * @param <I> request type
   * @param <O> response type
   */
  static class StartTrackingCallback<I, O> implements RedisCallback<TrackingData<I, O>> {

    /** Expiry applied on the first attempt when no explicit retention is given. */
    static final Duration DEFAULT_PROGRESSED_RETENTION = Duration.ofSeconds(50);

    private final String keyPrefix;
    private final String trackingKey;
    private final String startCorrelation;
    private final HashMapper<TrackingData<I, O>, byte[], byte[]> hashMapper;
    // NOTE(review): validated but never written to Redis, so REQUEST_OBJECT_FIELD is never
    // populated by this callback — confirm whether the request should be persisted.
    private final I requestObject;
    private final Duration progressedRetention;

    StartTrackingCallback(String keyPrefix, String trackingKey, String startCorrelation,
        HashMapper<TrackingData<I, O>, byte[], byte[]> hashMapper, I requestObject) {
      this(keyPrefix, trackingKey, startCorrelation, hashMapper, requestObject, DEFAULT_PROGRESSED_RETENTION);
    }

    StartTrackingCallback(String keyPrefix, String trackingKey, String startCorrelation,
        HashMapper<TrackingData<I, O>, byte[], byte[]> hashMapper, I requestObject,
        Duration progressedRetention) {
      Assert.hasText(keyPrefix, "Key prefix must be provided");
      Assert.hasText(trackingKey, "Tracking key must be provided");
      Assert.hasText(startCorrelation, "Tracking start correlation must be provided");
      Assert.notNull(hashMapper, "Tracking hash mapper must be provided");
      Assert.notNull(requestObject, "Request object must be provided");
      Assert.notNull(progressedRetention, "Progressed retention must be provided");
      this.keyPrefix = keyPrefix;
      this.trackingKey = trackingKey;
      this.startCorrelation = startCorrelation;
      this.hashMapper = hashMapper;
      this.requestObject = requestObject;
      this.progressedRetention = progressedRetention;
    }

    /**
     * Sets correlation and PROGRESSED state only if absent, increments the attempt counter, and on
     * the first attempt sets expiry, tracking key and tracked time.
     *
     * @return the hash contents after the update, mapped to {@link TrackingData}
     */
    @Override
    public TrackingData<I, O> doInRedis(RedisConnection connection) throws DataAccessException {
      byte[] redisKey = keyPrefix.concat(trackingKey).getBytes(UTF_8);
      connection.hSetNX(redisKey, CLIENT_CORRELATION_FIELD, startCorrelation.getBytes(UTF_8));
      connection.hSetNX(redisKey, REQUEST_STATE_FIELD, PROGRESSED_STATE_VALUE);
      // hIncrBy returns null when the connection is pipelined/queued; guard before unboxing.
      Long totalAttempts = connection.hIncrBy(redisKey, TOTAL_ATTEMPTS_FIELD, 1);
      if (totalAttempts != null && totalAttempts == 1L) {
        connection.pExpire(redisKey, progressedRetention.toMillis());
        connection.hSetNX(redisKey, TRACKING_KEY_FIELD, trackingKey.getBytes(UTF_8));
        connection.hSetNX(redisKey, TRACKED_TIME_FIELD, LocalDateTime.now().toString().getBytes(UTF_8));
      }
      return hashMapper.fromHash(connection.hGetAll(redisKey));
    }
  }

  /**
   * {@link RedisCallback} moving a PROGRESSED request to COMPLETED or FAILED.
   *
   * @param <I> request type
   * @param <O> response type
   */
  static class CompleteTrackingCallback<I, O> implements RedisCallback<TrackingData<I, O>> {
    private final String keyPrefix;
    private final String trackingKey;
    private final HashMapper<TrackingData<I, O>, byte[], byte[]> hashMapper;
    private final Throwable thrownError;
    private final Duration trackingRetention;

    CompleteTrackingCallback(String keyPrefix, String trackingKey,
        HashMapper<TrackingData<I, O>, byte[], byte[]> hashMapper,
        Duration trackingRetention) {
      this(keyPrefix, trackingKey, hashMapper, null, trackingRetention);
    }

    CompleteTrackingCallback(String keyPrefix, String trackingKey,
        HashMapper<TrackingData<I, O>, byte[], byte[]> hashMapper, Throwable thrownError,
        Duration trackingRetention) {
      Assert.hasText(keyPrefix, "Tracking key prefix must be provided");
      Assert.hasText(trackingKey, "Request tracking key must be provided");
      Assert.notNull(hashMapper, "Tracking hash mapper must be provided");
      this.keyPrefix = keyPrefix;
      this.trackingKey = trackingKey;
      this.hashMapper = hashMapper;
      this.thrownError = thrownError;
      this.trackingRetention = trackingRetention;
    }

    /**
     * If the request is still PROGRESSED, sets its final state and completion time and applies the
     * retention (when non-null) as the key expiry.
     *
     * @return the hash contents after the update, or {@code null} if nothing is stored
     */
    @Override
    public TrackingData<I, O> doInRedis(RedisConnection connection) throws DataAccessException {
      byte[] redisKey = keyPrefix.concat(trackingKey).getBytes(UTF_8);
      boolean inProgress = Arrays.equals(connection.hGet(redisKey, REQUEST_STATE_FIELD), PROGRESSED_STATE_VALUE);
      if (inProgress) {
        connection.hSet(redisKey, REQUEST_STATE_FIELD, thrownError != null ? FAILED_STATE_VALUE : COMPLETED_STATE_VALUE);
        connection.hSetNX(redisKey, COMPLETED_TIME_FIELD, LocalDateTime.now().toString().getBytes(UTF_8));
      }
      // Snapshot BEFORE expiring: a zero retention makes Redis delete the key right away, which
      // previously made this callback return nothing for the final state it had just written.
      Map<byte[], byte[]> hashValues = connection.hGetAll(redisKey);
      if (inProgress && trackingRetention != null) {
        connection.pExpire(redisKey, trackingRetention.toMillis());
      }
      return CollectionUtils.isEmpty(hashValues) ? null : hashMapper.fromHash(hashValues);
    }
  }

  /**
   * {@link RedisCallback} reading and deleting a tracking hash.
   *
   * @param <I> request type
   * @param <O> response type
   */
  static class RemoveTrackingCallback<I, O> implements RedisCallback<TrackingData<I, O>> {
    private final String keyPrefix;
    private final String trackingKey;
    private final HashMapper<TrackingData<I, O>, byte[], byte[]> hashMapper;

    public RemoveTrackingCallback(String keyPrefix, String trackingKey,
        HashMapper<TrackingData<I, O>, byte[], byte[]> hashMapper) {
      Assert.hasText(keyPrefix, "Tracking key prefix must be provided");
      Assert.hasText(trackingKey, "Request tracking key must be provided");
      Assert.notNull(hashMapper, "Tracking hash mapper must be provided");
      this.keyPrefix = keyPrefix;
      this.trackingKey = trackingKey;
      this.hashMapper = hashMapper;
    }

    /**
     * Reads the hash, then deletes the key.
     *
     * @return the hash contents before deletion, or {@code null} if nothing was stored
     */
    @Override
    public TrackingData<I, O> doInRedis(RedisConnection connection) throws DataAccessException {
      byte[] redisKey = keyPrefix.concat(trackingKey).getBytes(UTF_8);
      Map<byte[], byte[]> hashValues = connection.hGetAll(redisKey);
      TrackingData<I, O> requestTracking = CollectionUtils.isEmpty(hashValues) ? null : hashMapper.fromHash(hashValues);
      connection.del(redisKey);
      return requestTracking;
    }
  }

  /**
   * {@link TrackingData}'s {@link HashMapper}.
   *
   * @param <I> request type
   * @param <O> response type
   */
  static class RequestTrackingHashMapper<I, O> implements HashMapper<TrackingData<I, O>, byte[], byte[]> {
    private static final DateTimeFormatter DEFAULT_TIME_FORMAT = DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSS");

    private final ObjectMapper objectMapper;
    private final Class<I> requestType;
    private final Class<O> responseType;

    public RequestTrackingHashMapper(ObjectMapper objectMapper, Class<I> requestType,
        Class<O> responseType) {
      Assert.notNull(objectMapper, "Object mapper must be provided");
      Assert.notNull(requestType, "Request type must be provided");
      Assert.notNull(responseType, "Response type must be provided");
      this.objectMapper = objectMapper;
      this.requestType = requestType;
      this.responseType = responseType;
    }

    /**
     * Converts tracking data into hash entries (key, state, request object, attempts, tracked time).
     *
     * @throws IllegalStateException if any value cannot be serialized
     */
    @Override
    public Map<byte[], byte[]> toHash(TrackingData<I, O> object) {
      try {
        Map<byte[], byte[]> hashValues = new HashMap<>();
        hashValues.put(TRACKING_KEY_FIELD, object.getKey().getBytes(UTF_8));
        hashValues.put(REQUEST_STATE_FIELD, object.getRequestState().name().getBytes(UTF_8));
        hashValues.put(REQUEST_OBJECT_FIELD, objectMapper.writeValueAsBytes(object.getRequestObject()));
        // Decimal text, not raw int bytes: fromHash parses it as text and Redis HINCRBY
        // only accepts integer strings.
        hashValues.put(TOTAL_ATTEMPTS_FIELD, String.valueOf(1).getBytes(UTF_8));
        hashValues.put(TRACKED_TIME_FIELD, object.getTrackedTime().format(DEFAULT_TIME_FORMAT).getBytes(UTF_8));
        return hashValues;
      }
      catch (Exception e) {
        throw new IllegalStateException("Can not create hash of given object, see stack traces", e);
      }
    }

    /**
     * Builds tracking data from hash entries; unknown fields and empty values are ignored.
     *
     * @return the tracking data, or {@code null} if {@code hash} is null or empty
     * @throws IllegalStateException if a value cannot be parsed
     */
    @Override
    public TrackingData<I, O> fromHash(Map<byte[], byte[]> hash) {
      if (CollectionUtils.isEmpty(hash)) {
        return null;
      }
      TrackingData.Builder<I, O> trackingBuilder = TrackingData.builder();
      for (Map.Entry<byte[], byte[]> entry : hash.entrySet()) {
        byte[] hashField = entry.getKey();
        byte[] hashValue = entry.getValue();
        if (hashValue == null || hashValue.length == 0) {
          continue;
        }
        try {
          if (Arrays.equals(hashField, TRACKING_KEY_FIELD)) {
            trackingBuilder.key(new String(hashValue, UTF_8));
          } else if (Arrays.equals(hashField, REQUEST_STATE_FIELD)) {
            trackingBuilder.requestState(RequestState.valueOf(new String(hashValue, UTF_8)));
          } else if (Arrays.equals(hashField, CLIENT_CORRELATION_FIELD)) {
            trackingBuilder.clientCorrelation(new String(hashValue, UTF_8));
          } else if (Arrays.equals(hashField, REQUEST_OBJECT_FIELD)) {
            trackingBuilder.requestObject(objectMapper.readValue(hashValue, requestType));
          } else if (Arrays.equals(hashField, RESPONSE_OBJECT_FIELD)) {
            trackingBuilder.responseObject(objectMapper.readValue(hashValue, responseType));
          } else if (Arrays.equals(hashField, TOTAL_ATTEMPTS_FIELD)) {
            trackingBuilder.totalAttempts(Integer.valueOf(new String(hashValue, UTF_8)));
          } else if (Arrays.equals(hashField, TRACKED_TIME_FIELD)) {
            trackingBuilder.trackedTime(LocalDateTime.parse(new String(hashValue, UTF_8), DateTimeFormatter.ISO_LOCAL_DATE_TIME));
          } else if (Arrays.equals(hashField, COMPLETED_TIME_FIELD)) {
            trackingBuilder.completedTime(LocalDateTime.parse(new String(hashValue, UTF_8), DateTimeFormatter.ISO_LOCAL_DATE_TIME));
          }
        }
        catch (Exception e) {
          throw new IllegalStateException("Can not create object from hash, see stack traces", e);
        }
      }
      return trackingBuilder.build();
    }
  }
}
// ======================================
package com.tiket.poc.redis.track;
import java.nio.charset.StandardCharsets;
/**
 * Processing state of a tracked request.
 *
 * <p>Each {@code *_STATE_VALUE} constant holds the UTF-8 bytes of the matching constant's
 * {@link #name()}, i.e. the form in which the state is written to and compared against Redis.
 * The arrays are shared and must not be modified.
 *
 * @author zakyalvan
 */
public enum RequestState {
  /**
   * Request still in progress.
   */
  PROGRESSED,
  /**
   * Request already completed successfully.
   */
  COMPLETED,
  /**
   * Request already completed, but failed.
   */
  FAILED;

  /** UTF-8 bytes of {@code "PROGRESSED"}. */
  public static final byte[] PROGRESSED_STATE_VALUE = PROGRESSED.name().getBytes(StandardCharsets.UTF_8);
  /** UTF-8 bytes of {@code "COMPLETED"}; must differ from {@link #PROGRESSED_STATE_VALUE}. */
  public static final byte[] COMPLETED_STATE_VALUE = COMPLETED.name().getBytes(StandardCharsets.UTF_8);
  /** UTF-8 bytes of {@code "FAILED"}. */
  public static final byte[] FAILED_STATE_VALUE = FAILED.name().getBytes(StandardCharsets.UTF_8);
}
// ======================================
package com.tiket.poc.redis.track;
import reactor.core.publisher.Mono;
/**
 * Derives the tracking key of a request. Requests mapped to the same key share the same tracking
 * data.
 *
 * @param <R> request type
 * @author zakyalvan
 */
@FunctionalInterface
public interface KeyFactory<R> {

  /**
   * Creates the tracking key for the given request.
   *
   * @param request request to derive the key from
   * @return publisher emitting the tracking key
   */
  Mono<String> createKey(R request);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment