Skip to content

Instantly share code, notes, and snippets.

@lqc
Created November 25, 2019 08:46
Show Gist options
  • Save lqc/4914f0f5033894c60ec28d08215edfc2 to your computer and use it in GitHub Desktop.
Save lqc/4914f0f5033894c60ec28d08215edfc2 to your computer and use it in GitHub Desktop.
WIP version of XRayWebFilter
package com.syncron.spark.idi.gateway.awsxray;
import com.amazonaws.xray.AWSXRay;
import com.amazonaws.xray.AWSXRayRecorder;
import com.amazonaws.xray.entities.Segment;
import com.amazonaws.xray.entities.TraceHeader;
import com.amazonaws.xray.entities.TraceID;
import com.amazonaws.xray.strategy.sampling.SamplingRequest;
import com.amazonaws.xray.strategy.sampling.SamplingResponse;
import com.amazonaws.xray.strategy.sampling.SamplingStrategy;
import org.springframework.boot.web.reactive.filter.OrderedWebFilter;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.http.server.reactive.ServerHttpResponse;
import org.springframework.stereotype.Component;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.WebFilterChain;
import reactor.core.publisher.Mono;
import javax.annotation.Nullable;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import static java.util.Objects.requireNonNull;
/**
 * Reactive {@link org.springframework.web.server.WebFilter} that wraps every exchange in an
 * AWS X-Ray segment.
 *
 * <p>A segment is begun when the returned {@link Mono} is subscribed, annotated with the
 * response status and content length when the downstream chain completes, and ended when the
 * chain terminates for any reason (completion, error or cancellation).
 *
 * <p>NOTE(review): the recorder keeps the "current" entity in thread-bound context, while a
 * reactive pipeline may hop threads. This filter therefore always works on the {@link Segment}
 * reference it created and re-binds it right before ending it. Downstream code that relies on
 * the recorder's implicit current segment is still not guaranteed to see the right one.
 */
@Component
public class XRayWebFilter implements OrderedWebFilter {

    private final AWSXRayRecorder recorder;

    /** Creates a filter backed by the global recorder returned by {@link AWSXRay}. */
    public XRayWebFilter() {
        this(AWSXRay.getGlobalRecorder());
    }

    /**
     * Creates a filter backed by the given recorder.
     *
     * @param recorder recorder used to begin and end segments; must not be {@code null}
     * @throws NullPointerException if {@code recorder} is {@code null}
     */
    public XRayWebFilter(AWSXRayRecorder recorder) {
        this.recorder = requireNonNull(recorder, "recorder");
    }

    /**
     * Runs the rest of the filter chain inside a segment created for this request.
     *
     * @param exchange the current server exchange
     * @param chain    the remaining filter chain
     * @return a {@code Mono} that completes when the downstream chain completes
     */
    @Override
    public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) {
        // Deferred so that the segment is only begun once somebody actually subscribes;
        // otherwise an assembled-but-never-subscribed Mono would leave a segment open forever.
        return Mono.defer(() -> {
            final Segment segment = createSegment(exchange.getRequest(), getRecorder());
            // Mono.just(..).flatMap(..) turns a synchronous exception thrown by the chain
            // into an error signal, so the callbacks below still run in that case.
            return Mono.just(exchange)
                    .flatMap(chain::filter)
                    .doOnError(segment::addException)
                    .doOnSuccess(ignored -> recordResponse(segment, exchange.getResponse()))
                    .doFinally(ignored -> endSegment(segment));
        });
    }

    /** Copies the response status flags and the {@code response} HTTP block onto the segment. */
    private void recordResponse(Segment segment, ServerHttpResponse response) {
        final HttpStatus status = response.getStatusCode();
        if (status == null || status.is5xxServerError()) {
            // A missing status is treated the same way as a server-side failure.
            segment.setFault(true);
        } else if (status.is4xxClientError()) {
            segment.setError(true);
            if (status == HttpStatus.TOO_MANY_REQUESTS) {
                segment.setThrottle(true);
            }
        }

        final Map<String, Object> responseAttributes = new HashMap<>();
        if (status != null) {
            responseAttributes.put("status", status.value());
        }
        getContentLength(response.getHeaders())
                .ifPresent(contentLength -> responseAttributes.put("content_length", contentLength));
        segment.putHttp("response", responseAttributes);
    }

    /**
     * Ends exactly the given segment.
     *
     * <p>{@code endSegment()} operates on whatever entity is bound to the calling thread, and
     * the terminal signal may arrive on a different thread than the one that began the segment.
     * Binding our own segment first makes sure the right one is closed.
     */
    private void endSegment(Segment segment) {
        final AWSXRayRecorder xray = getRecorder();
        xray.setTraceEntity(segment);
        xray.endSegment();
    }

    /**
     * Begins a segment for the request, named after the {@code Host} header.
     *
     * <p>Trace id and parent id are taken from the incoming trace header when present. When the
     * request is not sampled, a real but unsampled segment is begun if the sampling strategy
     * reports forced sampling as supported; otherwise a dummy segment is begun.
     *
     * @param request  the incoming request
     * @param recorder recorder used to begin the segment
     * @return the segment that was begun
     */
    private Segment createSegment(final ServerHttpRequest request, AWSXRayRecorder recorder) {
        final HttpHeaders requestHeaders = request.getHeaders();
        final String segmentName = getHeader(requestHeaders, HttpHeaders.HOST).orElse("<unknown_host>");
        final Optional<TraceHeader> incomingHeader = getTraceHeader(requestHeaders);
        final SamplingStrategy samplingStrategy = recorder.getSamplingStrategy();
        final SamplingResponse samplingResponse = fromSamplingStrategy(recorder, segmentName, request);
        final boolean shouldSample = decideOnSampling(recorder, samplingResponse, incomingHeader.orElse(null));
        final TraceID traceId = incomingHeader.map(TraceHeader::getRootTraceId).orElseGet(TraceID::new);
        final String parentId = incomingHeader.map(TraceHeader::getParentId).orElse(null);

        final Segment createdSegment;
        if (shouldSample) {
            createdSegment = recorder.beginSegment(segmentName, traceId, parentId);
            samplingResponse.getRuleName().ifPresent(rule -> {
                // TODO: log.debug()
                createdSegment.setRuleName(rule);
            });
        } else if (samplingStrategy.isForcedSamplingSupported()) {
            createdSegment = recorder.beginSegment(segmentName, traceId, parentId);
            createdSegment.setSampled(false);
        } else {
            createdSegment = recorder.beginDummySegment(segmentName, traceId);
        }

        recordRequestParams(createdSegment, requestHeaders);
        return createdSegment;
    }

    /** Placeholder for recording request attributes on the segment; currently does nothing. */
    private void recordRequestParams(Segment createdSegment, HttpHeaders requestHeaders) {
        // TODO
    }

    /**
     * Decides whether the request should be sampled.
     *
     * <p>An explicit decision carried by the incoming trace header ({@code SAMPLED} or
     * {@code NOT_SAMPLED}) wins. If there is no header, or the header leaves the decision open
     * ({@code REQUESTED} / {@code UNKNOWN}), the local sampling strategy's answer is used.
     *
     * @param recorder         currently unused; kept for signature stability
     * @param samplingResponse answer of the local sampling strategy
     * @param incomingHeader   trace header of the request, or {@code null} if there was none
     * @return {@code true} if the request should be sampled
     */
    private boolean decideOnSampling(
            AWSXRayRecorder recorder,
            SamplingResponse samplingResponse,
            @Nullable TraceHeader incomingHeader
    ) {
        if (incomingHeader != null) {
            final TraceHeader.SampleDecision upstreamDecision = incomingHeader.getSampled();
            if (upstreamDecision == TraceHeader.SampleDecision.SAMPLED) {
                return true;
            }
            if (upstreamDecision == TraceHeader.SampleDecision.NOT_SAMPLED) {
                return false;
            }
        }
        return samplingResponse.isSampled();
    }

    /**
     * Asks the recorder's sampling strategy whether this request should be traced.
     *
     * @param recorder    recorder providing the sampling strategy and the origin
     * @param segmentName name the segment will get
     * @param request     the incoming request
     * @return the strategy's response
     */
    private SamplingResponse fromSamplingStrategy(AWSXRayRecorder recorder, String segmentName, ServerHttpRequest request) {
        final SamplingRequest samplingRequest = new SamplingRequest(
                segmentName,
                getHeader(request.getHeaders(), HttpHeaders.HOST).orElse(null),
                request.getURI().toASCIIString(),
                request.getMethodValue(),
                recorder.getOrigin()
        );
        return recorder.getSamplingStrategy().shouldTrace(samplingRequest);
    }

    /** Returns the first value of the given header, or empty if the header is absent. */
    private static Optional<String> getHeader(HttpHeaders headers, String key) {
        return Optional.ofNullable(headers.getFirst(key));
    }

    /** Parses the X-Ray trace header of the request, if one was sent. */
    private static Optional<TraceHeader> getTraceHeader(HttpHeaders headers) {
        return getHeader(headers, TraceHeader.HEADER_KEY)
                .map(TraceHeader::fromString);
    }

    /**
     * Returns the first entry of the {@code X-Forwarded-For} header, trimmed.
     *
     * <p>Returns empty when the header is absent or its first entry is blank (for example a
     * value consisting only of commas).
     */
    private static Optional<String> getXForwardedFor(HttpHeaders headers) {
        return getHeader(headers, "X-Forwarded-For")
                .map(value -> {
                    // indexOf instead of split(",")[0]: split drops trailing empty strings,
                    // so a value like "," yields an empty array and [0] would throw.
                    final int comma = value.indexOf(',');
                    return (comma < 0 ? value : value.substring(0, comma)).trim();
                })
                .filter(first -> !first.isEmpty());
    }

    /**
     * Reads the {@code Content-Length} header as a number.
     *
     * @return the declared content length, or empty if the header is absent, blank or not a
     *         valid number
     */
    private Optional<Long> getContentLength(HttpHeaders headers) {
        return getHeader(headers, HttpHeaders.CONTENT_LENGTH)
                .map(String::trim)
                .filter(value -> !value.isEmpty())
                .flatMap(value -> {
                    try {
                        return Optional.of(Long.parseLong(value));
                    } catch (NumberFormatException ignored) {
                        // A malformed header must not break tracing; just skip the attribute.
                        return Optional.empty();
                    }
                });
    }

    /** Runs this filter before all others so the segment spans the whole exchange. */
    @Override
    public int getOrder() {
        return HIGHEST_PRECEDENCE;
    }

    private AWSXRayRecorder getRecorder() {
        return recorder;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment