closeable channel and threadsafe iterator wrapper
# coding=utf-8
from collections import deque
from queue import Empty, Full
from threading import Condition, RLock, Thread
from time import monotonic as now
from weakref import finalize
class ChannelClosed(Exception):
    """Exception raised when a channel has been closed."""
    pass


class Channel(object):
    """
    Like queue.Queue, but can be closed. Iterating over a channel yields items until it is
    closed and drained.

    Every item put() into a channel may be provided exactly once to a caller of its get() method,
    unless drain() is called. The structure is suitable for any number of producers and any number
    of consumers. The channel will not accept new items while the size of its queue is at or over
    `maxsize`, and will reject the item or block until space is available.

    To avoid race conditions, channels are not reusable once closed.
    """
    def __init__(self, maxsize=float('inf')):
        if maxsize < 1:
            raise ValueError('maxsize must be 1 or more')
        self.maxsize = maxsize
        self._closed = False  # set to True when the channel is flagged for closure
        self.mutex = RLock()  # lock held whenever the channel's queue is mutated
        self.is_closed = Condition(self.mutex)  # notified when the channel is closed and drained
        self.not_empty = Condition(self.mutex)  # notified when item(s) exist in the queue
        self.not_full = Condition(self.mutex)  # notified when space exists in the queue
        self._init()

    # Override these methods to implement other queueing models, as with standard queue.Queue.
    def _init(self):
        """Initialize the queue representation."""
        self.queue = deque()

    def _qsize(self):
        """Return the current size of the queue, in whatever unit. MUST be falsy when empty."""
        return len(self.queue)

    def _put(self, item):
        """Put an item into the queue."""
        self.queue.append(item)

    def _get(self):
        """Get an item from the queue."""
        return self.queue.popleft()

    def get(self, timeout=None):
        """
        Take and return a single item from the channel.

        With the default timeout of None, blocks until an item is available in the queue.
        Positive values will block for up to that many seconds waiting, and zero or negative
        values of timeout do not block. Raises queue.Empty if a timeout is reached.
        If the channel is closed and drained, raises ChannelClosed.
        """
        with self.mutex:
            # ensure there are items to get
            if timeout is None:  # blocking indefinitely
                while not self._qsize():
                    if self._closed:
                        self.is_closed.notify_all()
                        raise ChannelClosed
                    self.not_empty.wait()
            else:  # block for up to timeout seconds
                endtime = now() + timeout
                while not self._qsize():
                    if self._closed:
                        self.is_closed.notify_all()
                        raise ChannelClosed
                    remaining = endtime - now()
                    if remaining <= 0:
                        raise Empty  # timed out
                    self.not_empty.wait(remaining)
            self.not_full.notify()
            return self._get()

    def put(self, item, timeout=None):
        """
        Send an item to the channel.

        With the default timeout of None, blocks until there is space in the queue to accept it.
        Positive values will block for up to that many seconds waiting, and zero or negative
        values of timeout do not block. Raises queue.Full if a timeout is reached.
        If the channel is closed before success, raises ChannelClosed.
        """
        with self.mutex:
            # ensure the channel can accept items
            if timeout is None:  # blocking indefinitely
                while not self._closed:
                    if self._qsize() < self.maxsize:
                        break
                    self.not_full.wait()
                else:
                    raise ChannelClosed
            else:  # block for up to timeout seconds
                endtime = now() + timeout
                while not self._closed:
                    if self._qsize() < self.maxsize:
                        break
                    remaining = endtime - now()
                    if remaining <= 0:
                        raise Full
                    self.not_full.wait(remaining)
                else:
                    raise ChannelClosed
            self._put(item)
            self.not_empty.notify()

    def put_all(self, items):
        """
        Send all the items in the provided iterable to the channel, blocking until done.

        If the channel is closed before success, raises ChannelClosed.
        This is slightly faster than looping put(), as it reduces mutex thrashing.
        """
        it = iter(items)
        with self.mutex:
            # ensure the channel can accept items
            while True:
                if self._closed:
                    raise ChannelClosed
                if self._qsize() < self.maxsize:
                    break
                self.not_full.wait()
            while True:
                # add items to the queue in bulk while there is space
                for item in it:
                    self._put(item)
                    self.not_empty.notify()
                    if self._qsize() >= self.maxsize:
                        break
                else:
                    return  # ran out of items to insert; we are done
                # wait for space to be available again
                while True:
                    self.not_full.wait()
                    if self._closed:
                        raise ChannelClosed
                    if self._qsize() < self.maxsize:
                        break

    def close(self):
        """
        Flag the channel for closure. This is not reversible.

        Once this method is called, no new items can be sent to the channel.
        """
        with self.mutex:
            self._closed = True
            # awaken all threads that need to finish consuming or fail to insert
            if self._qsize():  # channel is now draining; let threads blocked on put() bail out
                self.not_full.notify_all()
            else:  # channel is now closed and dead; let threads blocked on wait_closed() and get() bail out
                self.not_empty.notify_all()
                self.is_closed.notify_all()

    def drain(self):
        """
        !Danger! Discard all items currently in the channel's queue.

        Obviously, this can lead to data loss. This method is intended for use when a channel's
        consumers have died irreplaceably, but other threads are still awaiting the closure of
        this channel for a clean shutdown.
        """
        with self.mutex:
            if self._closed:
                self.is_closed.notify_all()
            else:
                self.not_full.notify_all()
            self._init()

    def wait_closed(self, timeout=None):
        """
        Wait for the channel to be completely closed and drained.

        With the default timeout of None, blocks indefinitely. Positive values for timeout
        specify the maximum amount of time to block in seconds, and values less than or equal
        to zero do not block.
        Raises TimeoutError if the timeout is reached.
        """
        with self.mutex:
            if timeout is None:  # blocking indefinitely
                while not self._closed or self._qsize():
                    self.is_closed.wait()
            else:  # block for up to timeout seconds
                endtime = now() + timeout
                while not self._closed or self._qsize():
                    remaining = endtime - now()
                    if remaining <= 0:
                        raise TimeoutError
                    self.is_closed.wait(remaining)

    def status(self):
        """
        Non-authoritative status for the channel. Returns one of three status strings:

        * 'open' -- available for processing items (does not indicate at which end the bottleneck lies)
        * 'draining' -- close() has been called, but items still remain to be processed
        * 'closed' -- the channel is fully closed and will never reopen

        The status returned may be subject to race conditions and is meant to be advisory only;
        the channel may have since progressed to a later status.
        """
        if self._closed:
            if self._qsize():
                return 'draining'
            else:
                return 'closed'
        else:
            return 'open'

    def __iter__(self):
        """Yield values from this channel until it is closed and drained, blocking indefinitely."""
        try:
            while True:
                yield self.get()
        except ChannelClosed:
            return
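
# Illustrative usage sketch (not part of the original gist): one producer feeding two
# consumer threads through a bounded Channel. All names below are examples only.
def _example_channel_usage():
    channel = Channel(maxsize=8)

    def consume():
        # Iterating a Channel blocks for items and ends cleanly once the channel
        # has been closed and fully drained.
        for item in channel:
            print('consumed', item)

    consumers = [Thread(target=consume) for _ in range(2)]
    for t in consumers:
        t.start()
    channel.put_all(range(100))  # blocks whenever the queue is full
    channel.close()              # lets the consumers' for-loops finish
    for t in consumers:
        t.join()
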

class TaskTrackChannel(Channel):
    """A subclass of Channel that provides task completion tracking like queue.Queue."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.tasks_put = 0
        self.tasks_done = 0
        self.all_tasks_done = Condition(self.mutex)

    def _put(self, item):
        super()._put(item)
        self.tasks_put += 1

    def task_done(self):
        """Report one task as done."""
        with self.mutex:
            self.tasks_done += 1
            if self.tasks_done > self.tasks_put:
                raise ValueError('task_done() called too many times')
            elif self.tasks_done == self.tasks_put:
                self.all_tasks_done.notify_all()

    def join(self):
        """
        Block until the number of items put into the channel is the same as the number of
        tasks reported done.
        """
        with self.mutex:
            while self.tasks_done < self.tasks_put:
                self.all_tasks_done.wait()
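
# Illustrative sketch (not part of the original gist): TaskTrackChannel mirrors the
# queue.Queue task-tracking API, so a producer can join() until every item it put()
# has been acknowledged with task_done(). All names below are examples only.
def _example_task_tracking():
    channel = TaskTrackChannel(maxsize=4)

    def worker():
        for item in channel:
            print('handled', item)
            channel.task_done()  # acknowledge each item after handling it

    Thread(target=worker, daemon=True).start()
    channel.put_all(range(10))
    channel.join()   # returns once task_done() has been called once per item put
    channel.close()
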

class IterProvider(object):
    """
    Provide a generator to multiple threads.

    Every time this object is asked for a new iterator, it spawns a worker thread that draws from
    a fresh iterator returned by the provided generator function; the iterator returned here can
    then be shared safely by any number of threads.
    Each object produced by the generator is guaranteed to be handed out exactly once.
    """

    def __init__(self, generator, queue_length=16):
        self.generator = generator
        self.queue_length = queue_length

    def __iter__(self):
        class Yielder(object):
            def __init__(self, generator, queue_length):
                channel = self.channel = Channel(maxsize=queue_length)

                # must not hold a reference to self, to prevent the thread from keeping the
                # iterator alive
                def work():
                    for thing in generator():
                        try:
                            channel.put(thing)
                        except ChannelClosed:
                            # channel was closed by someone else
                            # print("thread ending after channel closed externally")  # Debug
                            return
                    channel.close()
                    # print("thread ending after generator exhausted")  # Debug

                Thread(target=work).start()
                finalize(self, channel.close)

            def __next__(self):
                try:
                    return self.channel.get()
                except ChannelClosed:
                    raise StopIteration

            def __iter__(self):
                return self

        return Yielder(self.generator, self.queue_length)
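
# Illustrative sketch (not part of the original gist): sharing one IterProvider iterator
# between several consumer threads, each item being delivered to exactly one of them.
# All names below are examples only.
def _example_iter_provider():
    def numbers():
        yield from range(20)

    provider = IterProvider(numbers, queue_length=4)
    shared = iter(provider)  # spawns one worker thread feeding an internal Channel

    def consume(label):
        for n in shared:
            print('consumer', label, 'took', n)

    threads = [Thread(target=consume, args=(i,)) for i in range(3)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()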