Created
November 25, 2019 08:46
-
-
Save lqc/4914f0f5033894c60ec28d08215edfc2 to your computer and use it in GitHub Desktop.
WIP version of XRayWebFilter
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.syncron.spark.idi.gateway.awsxray; | |
import com.amazonaws.xray.AWSXRay; | |
import com.amazonaws.xray.AWSXRayRecorder; | |
import com.amazonaws.xray.entities.Segment; | |
import com.amazonaws.xray.entities.TraceHeader; | |
import com.amazonaws.xray.entities.TraceID; | |
import com.amazonaws.xray.strategy.sampling.SamplingRequest; | |
import com.amazonaws.xray.strategy.sampling.SamplingResponse; | |
import com.amazonaws.xray.strategy.sampling.SamplingStrategy; | |
import org.springframework.boot.web.reactive.filter.OrderedWebFilter; | |
import org.springframework.http.HttpHeaders; | |
import org.springframework.http.HttpStatus; | |
import org.springframework.http.server.reactive.ServerHttpRequest; | |
import org.springframework.http.server.reactive.ServerHttpResponse; | |
import org.springframework.stereotype.Component; | |
import org.springframework.web.server.ServerWebExchange; | |
import org.springframework.web.server.WebFilterChain; | |
import reactor.core.publisher.Mono; | |
import javax.annotation.Nullable; | |
import java.util.HashMap; | |
import java.util.Map; | |
import java.util.Optional; | |
import static java.util.Objects.requireNonNull; | |
@Component | |
public class XRayWebFilter implements OrderedWebFilter { | |
private final AWSXRayRecorder recorder; | |
public XRayWebFilter() { | |
this(AWSXRay.getGlobalRecorder()); | |
} | |
public XRayWebFilter(AWSXRayRecorder recorder) { | |
this.recorder = requireNonNull(recorder, "recorder"); | |
} | |
@Override | |
public Mono<Void> filter(ServerWebExchange exchange, WebFilterChain chain) { | |
final Segment segment = createSegment(exchange.getRequest(), getRecorder()); | |
return Mono.just(exchange) | |
.flatMap(chain::filter) | |
.doOnError(segment::addException) | |
.doOnSuccess(ignore -> { | |
final ServerHttpResponse response = exchange.getResponse(); | |
final HttpStatus status = response.getStatusCode(); | |
if (status == null || status.is5xxServerError()) { | |
segment.setFault(true); | |
} else if (status.is4xxClientError()) { | |
segment.setError(true); | |
if (status == HttpStatus.TOO_MANY_REQUESTS) { | |
segment.setThrottle(true); | |
} | |
} | |
Map<String, Object> responseAttributes = new HashMap<>(); | |
if (status != null) { | |
responseAttributes.put("status", status.value()); | |
} | |
getContentLength(response.getHeaders()) | |
.ifPresent(contentLength -> responseAttributes.put("content_length", contentLength)); | |
segment.putHttp("response", responseAttributes); | |
}) | |
// NOTE: This is not correct as recorder is global and non-async! | |
.doFinally(ignore -> recorder.endSegment()); | |
} | |
private Segment createSegment(final ServerHttpRequest request, AWSXRayRecorder recorder) { | |
final HttpHeaders requestHeaders = request.getHeaders(); | |
final String segmentName = getHeader(requestHeaders, HttpHeaders.HOST).orElse("<unknown_host>"); | |
final Optional<TraceHeader> incomingHeader = getTraceHeader(requestHeaders); | |
final SamplingStrategy samplingStrategy = recorder.getSamplingStrategy(); | |
final SamplingResponse samplingResponse = fromSamplingStrategy(recorder, segmentName, request); | |
final boolean shouldSample = decideOnSampling(recorder, samplingResponse, incomingHeader.orElse(null)); | |
final TraceID traceId = incomingHeader.map(TraceHeader::getRootTraceId).orElseGet(TraceID::new); | |
final String parentId = incomingHeader.map(TraceHeader::getParentId).orElse(null); | |
final Segment createdSegment; | |
if (shouldSample) { | |
createdSegment = recorder.beginSegment(segmentName, traceId, parentId); | |
samplingResponse.getRuleName().ifPresent(rule -> { | |
// TODO: log.debug() | |
createdSegment.setRuleName(rule); | |
}); | |
} else { | |
if (samplingStrategy.isForcedSamplingSupported()) { | |
createdSegment = recorder.beginSegment(segmentName, traceId, parentId); | |
createdSegment.setSampled(false); | |
} else { | |
createdSegment = recorder.beginDummySegment(segmentName, traceId); | |
} | |
} | |
recordRequestParams(createdSegment, requestHeaders); | |
return createdSegment; | |
} | |
private void recordRequestParams(Segment createdSegment, HttpHeaders requestHeaders) { | |
// TODO | |
} | |
private boolean decideOnSampling( | |
AWSXRayRecorder recorder, | |
SamplingResponse samplingResponse, | |
@Nullable TraceHeader incomingHeader | |
) { | |
final TraceHeader.SampleDecision sampleDecision; | |
if (incomingHeader != null) { | |
sampleDecision = incomingHeader.getSampled(); | |
if (sampleDecision != TraceHeader.SampleDecision.REQUESTED && sampleDecision != TraceHeader.SampleDecision.UNKNOWN) { | |
return samplingResponse.isSampled(); | |
} | |
} | |
return samplingResponse.isSampled(); | |
} | |
private SamplingResponse fromSamplingStrategy(AWSXRayRecorder recorder, String segmentName, ServerHttpRequest request) { | |
SamplingRequest samplingRequest = new SamplingRequest( | |
segmentName, | |
getHeader(request.getHeaders(), HttpHeaders.HOST).orElse(null), | |
request.getURI().toASCIIString(), | |
request.getMethodValue(), | |
recorder.getOrigin() | |
); | |
return recorder.getSamplingStrategy().shouldTrace(samplingRequest); | |
} | |
private static Optional<String> getHeader(HttpHeaders headers, String key) { | |
return Optional.ofNullable(headers.getFirst(key)); | |
} | |
private static Optional<TraceHeader> getTraceHeader(HttpHeaders headers) { | |
return getHeader(headers, TraceHeader.HEADER_KEY) | |
.map(TraceHeader::fromString); | |
} | |
private static Optional<String> getXForwardedFor(HttpHeaders headers) { | |
return getHeader(headers, "X-Forwarded-For") | |
.map(s -> s.split(",")[0].trim()); | |
} | |
private Optional<Integer> getContentLength(HttpHeaders headers) { | |
return getHeader(headers, HttpHeaders.CONTENT_LENGTH) | |
.filter(String::isEmpty) | |
.map(Integer::parseInt); | |
} | |
@Override | |
public int getOrder() { | |
return HIGHEST_PRECEDENCE; | |
} | |
private AWSXRayRecorder getRecorder() { | |
return recorder; | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment