Skip to content

Instantly share code, notes, and snippets.

@jkuipers
Last active November 7, 2023 08:34
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jkuipers/ea27406b3bd2b84eab65f197366cfe8e to your computer and use it in GitHub Desktop.
Save jkuipers/ea27406b3bd2b84eab65f197366cfe8e to your computer and use it in GitHub Desktop.
Configuration and code to add tracing support to Spring Cloud AWS's message listeners
@AutoConfiguration(before = io.awspring.cloud.autoconfigure.sqs.SqsAutoConfiguration.class,
afterName = "org.springframework.boot.actuate.autoconfigure.tracing.BraveAutoConfiguration")
@ConditionalOnBean(Tracing.class)
public class SqsTracingAutoConfiguration {
@Bean(name = SqsBeanNames.SQS_LISTENER_ANNOTATION_BEAN_POST_PROCESSOR_BEAN_NAME)
TracingSqsListenerAnnotationBeanPostProcessor tracingSLABPP(Tracing tracing) {
return new TracingSqsListenerAnnotationBeanPostProcessor(tracing);
}
static class TracingSqsListenerAnnotationBeanPostProcessor extends SqsListenerAnnotationBeanPostProcessor {
private final Tracing tracing;
public TracingSqsListenerAnnotationBeanPostProcessor(Tracing tracing) {
this.tracing = tracing;
}
/**
* Overrides parent method to ensure that our custom endpoint with tracing support is returned.
*/
@Override
protected Endpoint createEndpoint(SqsListener sqsListenerAnnotation) {
return new TracingSqsEndpoint.TracingSqsEndpointBuilder(tracing)
.queueNames(resolveEndpointNames(sqsListenerAnnotation.value()))
.factoryBeanName(resolveAsString(sqsListenerAnnotation.factory(), "factory"))
.id(getEndpointId(sqsListenerAnnotation.id()))
.pollTimeoutSeconds(resolveAsInteger(sqsListenerAnnotation.pollTimeoutSeconds(), "pollTimeoutSeconds"))
.maxMessagesPerPoll(resolveAsInteger(sqsListenerAnnotation.maxMessagesPerPoll(), "maxMessagesPerPoll"))
.maxConcurrentMessages(
resolveAsInteger(sqsListenerAnnotation.maxConcurrentMessages(), "maxConcurrentMessages"))
.messageVisibility(
resolveAsInteger(sqsListenerAnnotation.messageVisibilitySeconds(), "messageVisibility"))
.build();
}
}
static class TracingSqsEndpoint extends SqsEndpoint {
private final Tracing tracing;
protected TracingSqsEndpoint(SqsEndpoint.SqsEndpointBuilder builder, Tracing tracing) {
super(builder);
this.tracing = tracing;
}
@Override
protected <T> MessageListener<T> createMessageListenerInstance(InvocableHandlerMethod handlerMethod) {
return new TracingWrappers.MessageListenerWrapper<>(super.createMessageListenerInstance(handlerMethod), tracing);
}
@Override
protected <T> AsyncMessageListener<T> createAsyncMessageListenerInstance(InvocableHandlerMethod handlerMethod) {
return new TracingWrappers.AsyncMessageListenerWrapper<>(super.createAsyncMessageListenerInstance(handlerMethod), tracing);
}
static class TracingSqsEndpointBuilder extends SqsEndpoint.SqsEndpointBuilder {
private final Tracing tracing;
public TracingSqsEndpointBuilder(Tracing tracing) {
this.tracing = tracing;
}
@Override
public SqsEndpoint build() {
return new TracingSqsEndpoint(this, tracing);
}
}
}
static abstract class TracingWrappers<D> {
private static final Propagation.Getter<MessageHeaders, String> GETTER =
(headers, key) -> (String) headers.get(key);
protected D delegate;
private final TraceContext.Extractor<MessageHeaders> extractor;
private final Tracer tracer;
private final Logger errorLogger = LoggerFactory.getLogger("nl.trifork.sqs.listener");
TracingWrappers(D delegate, Tracing tracing) {
this.delegate = delegate;
this.extractor = tracing.propagation().extractor(GETTER);
this.tracer = tracing.tracer();
}
CompletableFuture<Void> doInSpan(Function<Message, CompletableFuture<Void>> caller, Message<?> message) {
TraceContextOrSamplingFlags extracted = extractor.extract(message.getHeaders());
Span span = tracer.nextSpan(extracted)
.kind(CONSUMER)
.name("on-message")
.remoteServiceName("sqs")
.start();
try (Tracer.SpanInScope ws = tracer.withSpanInScope(span)) {
return caller.apply(message);
} catch (Throwable t) {
span.error(t);
logError(message.getHeaders(), t);
throw t;
} finally {
span.finish();
}
}
private void logError(MessageHeaders headers, Throwable t) {
Integer dlqDequeues = (Integer) headers.get("DlqDequeues");
errorLogger.warn("Error processing messageId={} with receiveCount={} and dlqDequeues={} of type {}",
headers.getId(),
headers.get(SqsHeaders.MessageSystemAttributes.SQS_APPROXIMATE_RECEIVE_COUNT),
dlqDequeues != null ? dlqDequeues : 0,
headers.get(SqsHeaders.SQS_DEFAULT_TYPE_HEADER),
t);
}
static class AsyncMessageListenerWrapper<T> extends TracingWrappers<AsyncMessageListener<T>> implements AsyncMessageListener<T> {
public AsyncMessageListenerWrapper(AsyncMessageListener<T> delegate, Tracing tracing) {
super(delegate, tracing);
}
@Override
public CompletableFuture<Void> onMessage(Message<T> message) {
return doInSpan(delegate::onMessage, message);
}
@Override
public CompletableFuture<Void> onMessage(Collection<Message<T>> messages) {
return delegate.onMessage(messages);
}
}
static class MessageListenerWrapper<T> extends TracingWrappers<MessageListener<T>> implements MessageListener<T> {
MessageListenerWrapper(MessageListener<T> delegate, Tracing tracing) {
super(delegate, tracing);
}
@Override
public void onMessage(Message<T> message) {
doInSpan(msg -> {
delegate.onMessage(msg);
return null;
}, message);
}
@Override
public void onMessage(Collection<Message<T>> messages) {
delegate.onMessage(messages);
}
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment