Skip to content

Instantly share code, notes, and snippets.

@tvoinarovskyi
Created July 14, 2017 12:50
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 1 You must be signed in to fork a gist
  • Save tvoinarovskyi/05a5d083a0f96cae3e9b4c2af580be74 to your computer and use it in GitHub Desktop.
Kafka enhanced consumer using Thread Workers and consumer.pause().
from kafka import (
KafkaConsumer, TopicPartition, OffsetAndMetadata, ConsumerRebalanceListener
)
import queue
import threading
import time
import logging
log = logging.getLogger(__name__)
NUM_WORKERS = 5
class Shutdown(Exception):
    """Signals a thread that a coordinated shutdown has been requested."""
class WorkerQueue(object):
    """ Thread-safe bus between the consumer thread and worker threads.

    The consumer thread ``put()``s batches of messages keyed by partition,
    workers ``get()`` and process them, then report back via
    ``finished_processing()``. The consumer thread collects results with
    ``get_finished()`` to resume partitions and commit offsets.
    Shutdown is coordinated through ``start_shutdown()``: after it is
    called, ``put()`` and ``get()`` raise ``Shutdown``.
    """

    def __init__(self):
        self._queue = queue.Queue()
        # Maps TopicPartition -> last processed message of the batch.
        self._finished = {}
        self._lock = threading.Lock()
        self._start_shutdown = False
        # NOTE: the original also kept an unused `_processing_tps = set([])`
        # attribute; it was never read or written, so it was removed.

    def put(self, tp, messages):
        """ Enqueue a batch of `messages` for partition `tp`.

        Raises ``Shutdown`` if shutdown was already requested.
        """
        assert isinstance(tp, TopicPartition)
        self.check_shutdown()
        with self._lock:
            # The partition must have been resumed (and its offset
            # committed) before new messages are scheduled for it.
            assert tp not in self._finished
            self._queue.put_nowait((tp, messages))

    def get(self):
        """ Block until a batch is available; return ``(tp, messages)``.

        Wakes up every second to notice a requested shutdown, in which
        case ``Shutdown`` is raised instead of returning.
        """
        while True:
            self.check_shutdown()
            try:
                return self._queue.get(timeout=1)
            except queue.Empty:
                continue

    def join(self):
        """ Block until every dequeued batch was marked done. """
        return self._queue.join()

    def finished_processing(self, tp, last_message):
        """ Mark the batch for `tp` done; remember its last message. """
        with self._lock:
            self._queue.task_done()
            self._finished[tp] = last_message

    def get_finished(self):
        """ Return and atomically clear the {tp: last_message} map. """
        with self._lock:
            finished = self._finished
            self._finished = {}
            return finished

    def drop_pending(self):
        """ Remove any batches that workers have not started processing
        """
        while True:
            try:
                self._queue.get_nowait()
            except queue.Empty:
                break
            self._queue.task_done()

    def start_shutdown(self):
        """ Request shutdown and drop all not-yet-started batches. """
        self._start_shutdown = True
        # Clear all items in the queue that have not started processing
        self.drop_pending()

    def check_shutdown(self):
        """ Raise ``Shutdown`` if shutdown was requested. """
        if self._start_shutdown:
            raise Shutdown()
class RebalanceListener(ConsumerRebalanceListener):
    """ Commits offsets of already-processed batches around a rebalance.

    Partitions with in-flight batches are paused by the consumer thread;
    on revoke we wait for the workers to finish, resume those partitions
    and commit the offset right after the last processed message.
    """

    def __init__(self, worker_queue, consumer):
        self._worker_queue = worker_queue
        self._consumer = consumer

    def _commit_finished(self):
        """ Resume finished partitions and commit their offsets.

        Does nothing (and sends no empty commit request) when no batch
        finished since the last call.
        """
        finished = self._worker_queue.get_finished()
        if not finished:
            return
        paused = self._consumer.paused()
        commit_offsets = {}
        # `log.warning` — `log.warn` is deprecated.
        log.warning("Committing %d partitions on revoke", len(finished))
        for tp, last_message in finished.items():
            # A partition with in-flight work must have been paused.
            assert tp in paused
            self._consumer.resume(tp)
            # Commit the offset *after* the last processed message.
            commit_offsets[tp] = OffsetAndMetadata(
                last_message.offset + 1, "")
        self._consumer.commit(commit_offsets)

    def on_partitions_revoked(self, revoked):
        """ Commit all processed items before rebalancing partitions """
        log.info("Revoking %d partitions", len(revoked))
        self._worker_queue.drop_pending()
        # We commit before and after, as we may not succeed later.
        # `join()` can take a while.
        self._commit_finished()
        self._worker_queue.join()
        self._commit_finished()

    def on_partitions_assigned(self, assigned):
        log.info("Assigned %d partitions", len(assigned))
def worker_thread(worker_queue):
    """ Worker loop: take message batches off `worker_queue` and process.

    Runs until the queue signals shutdown by raising ``Shutdown`` from
    ``get()``. Each processed batch is reported back through
    ``finished_processing()`` so the consumer thread can commit offsets.
    """
    try:
        log.info("Starting worker thread %s", threading.get_ident())
        while True:
            tp, messages = worker_queue.get()
            # Process messages. `time.sleep` stands in for real work here.
            log.info("Processing %d messages from tp %s on tid=%s",
                     len(messages), tp, threading.get_ident())
            time.sleep(5)
            worker_queue.finished_processing(tp, messages[-1])
    except Shutdown:
        # Use logging instead of print() for consistency with the rest
        # of the module.
        log.info("Worker thread %s shutdown", threading.get_ident())
    except Exception:
        # log.exception() already records the traceback; the explicit
        # exc_info=True the original passed was redundant.
        log.exception("Unexpected error in worker")
def _build_commit_offsets(finished):
    """ Map each finished partition to the offset right after its last
    processed message, ready for ``KafkaConsumer.commit()``. """
    return {
        tp: OffsetAndMetadata(last_message.offset + 1, "")
        for tp, last_message in finished.items()
    }


def consumer_thread(worker_queue):
    """ Consumer loop: poll Kafka and dispatch batches to workers.

    A partition with an in-flight batch is ``pause()``d so ``poll()``
    stops fetching for it; once workers report completion the partition
    is ``resume()``d and its offset committed. Runs until `worker_queue`
    signals shutdown.
    """
    consumer = None
    try:
        log.info("Starting consumer thread")
        consumer = KafkaConsumer(
            group_id='my-group',
            bootstrap_servers=['localhost:9092'],
            enable_auto_commit=False,
            auto_offset_reset="earliest")
        rebalance_listener = RebalanceListener(worker_queue, consumer)
        consumer.subscribe('my-new-topic', listener=rebalance_listener)
        while True:
            worker_queue.check_shutdown()
            # You can use `max_records` to limit the number of results
            msg_pack = consumer.poll(timeout_ms=1000)
            log.info("poll() returned %s partitions, %s paused",
                     len(msg_pack), len(consumer.paused()))
            for tp, messages in msg_pack.items():
                if messages:
                    worker_queue.put(tp, messages)
                    consumer.pause(tp)
            # Unpause any finished ones and commit offsets. Skip the
            # commit entirely when nothing finished (the original sent
            # an empty commit request every iteration).
            finished = worker_queue.get_finished()
            if finished:
                paused = consumer.paused()
                for tp in finished:
                    assert tp in paused
                    consumer.resume(tp)
                consumer.commit(_build_commit_offsets(finished))
    except Shutdown:
        log.info("Starting consumer thread shutdown")
        # We have to commit all processed data on normal shutdowns
        worker_queue.join()  # Wait for all `get` data to finish processing
        last_finished = worker_queue.get_finished()
        if last_finished:
            # `log.warning` — `log.warn` is deprecated.
            log.warning(
                "Committing for last %s partitions", len(last_finished))
            consumer.commit(_build_commit_offsets(last_finished))
        log.info("Consumer thread shutdown")
    except Exception:
        log.exception("Unexpected error in consumer thread")
        worker_queue.start_shutdown()
    finally:
        # Close the consumer on every exit path: the original leaked it
        # on a clean Shutdown, and crashed with NameError here when
        # KafkaConsumer construction itself failed.
        if consumer is not None:
            consumer.close()
def main():
    """ Wire up the worker queue, the consumer thread and the worker
    threads, then run until the consumer dies or the user hits Ctrl+C.
    """
    logging.basicConfig(level=logging.INFO)
    worker_queue = WorkerQueue()
    c_thread = threading.Thread(target=consumer_thread, args=(worker_queue, ))
    c_thread.start()
    w_threads = []
    for _ in range(NUM_WORKERS):
        t = threading.Thread(target=worker_thread, args=(worker_queue, ))
        t.start()
        w_threads.append(t)
    try:
        # Spin until the consumer thread dies. The original used
        # `assert c_thread.is_alive()`, which is stripped under -O and,
        # when it fired, skipped shutdown entirely — leaving the worker
        # threads running forever.
        while c_thread.is_alive():
            time.sleep(1)
        log.error("Consumer thread died unexpectedly, shutting down...")
    except KeyboardInterrupt:
        print("Received ctrl+C, shutting down...")
    finally:
        # All shutdown will be done through WorkerQueue instance, as the
        # central bus. Runs on every exit path.
        worker_queue.start_shutdown()
        worker_queue.join()
        c_thread.join()
        for w in w_threads:
            w.join()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment