Created
December 2, 2014 15:26
-
-
Save marram/3e69fba12d82bc1501aa to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class PullQueueWorker(Command):
    """Command that leases tasks from a pull queue, deletes them, and hands
    their deserialized payloads to a channel in batches.

    @num_tasks: The number of tasks to execute at a time.
    """
    # The queue that this work is invoked from. This is a PUSH queue.
    # It is used by the base_class to determine how to schedule a command.
    # It is different from the queue that the worker reads the tasks from
    queue_name = None
    # The PULL queue tasks are leased from: a single queue name, or an
    # iterable of shard queue names (one is chosen at random in validate()).
    pull_queue_name = None
    # The target machine
    target = None
    # Lease tasks by tag or not
    lease_by_tag = False
    # Callable returning the channel that receives the combined payload batch.
    channel = None
    __metaclass__ = abc.ABCMeta
    num_tasks = 40
    # Lease duration in seconds.
    lease_time = 600

    def __init__(self, num_tasks=None, lease_time=None, force_pull_queue_name=None):
        self.num_tasks = int(num_tasks or self.num_tasks)
        self.lease_time = int(lease_time or self.lease_time)
        self.force_pull_queue_name = force_pull_queue_name

    def validate(self):
        """Resolve which pull queue to read from and latch onto it."""
        if self.force_pull_queue_name and self.force_pull_queue_name != "None":
            queue_name = self.force_pull_queue_name
        elif not isinstance(self.pull_queue_name, basestring):
            # Then we have shard queues; pick one at random to spread load.
            assert hasattr(self.pull_queue_name, "__iter__")
            queue_name = random.choice(self.pull_queue_name)
        else:
            queue_name = self.pull_queue_name
        logging.info("Latching onto queue %s", queue_name)
        self.pull_queue = Queue(queue_name)

    def _make_taskqueue_params(self):
        params = super(PullQueueWorker, self)._make_taskqueue_params()
        # Run the workers on this backend.
        if self.target:
            params["target"] = self.target
        return params

    def process_tasks(self, tasks):
        """Deserialize each task payload (expected to be a JSON list) and
        deliver the combined batch through the channel.

        Tasks whose payload fails to deserialize are logged and skipped so one
        bad task cannot poison the whole batch.
        """
        payloads = []
        for task in tasks:
            try:
                payloads.extend(json.loads(task.payload))
            except Exception as e:
                logging.error(e)
                logging.error("Failed to deserialize task body. Perhaps a bad one?")
        self.channel().communicate_message(payloads)

    def lease_tasks(self):
        """Lease up to self.num_tasks tasks from the pull queue.

        The request is split into async RPCs of at most MAX_TASKS_PER_LEASE
        tasks each, issued concurrently and then collected.
        """
        leasing_function = self.pull_queue.lease_tasks_async
        if self.lease_by_tag:
            leasing_function = self.pull_queue.lease_tasks_by_tag_async
        max_tasks = taskqueue.MAX_TASKS_PER_LEASE
        rpcs = []
        for lease_count in miscutils.split_count(self.num_tasks, max_tasks):
            rpcs.append(leasing_function(self.lease_time, lease_count))
        tasks = []
        for rpc in rpcs:
            tasks.extend(rpc.get_result() or [])
        return tasks

    def execute(self):
        """Lease, delete, and process tasks until the queue is drained or a
        transient error forces a backoff redirect."""
        while True:
            try:
                logging.info("Leasing tasks from %s", self.pull_queue.name)
                tasks = self.lease_tasks()
                if not tasks:
                    # Backoff using a non-logging response. The cron will make
                    # sure that there are enough workers spawned
                    logging.info("Found no tasks. Bailing with a 302 to retry ")
                    if hasattr(self, "response"):
                        return self._redirect()
                    else:
                        return
                logging.info("Leased %d tasks", len(tasks))
                # Ensure to delete the tasks. Pull queue tasks are not purged
                # after execution. Delete the tasks first, since it seems the
                # queue implementation is more flaky than the external services
                rpc = taskqueue.create_rpc(deadline=60)
                self.pull_queue.delete_tasks_async(tasks, rpc=rpc)
                rpc.get_result()
                logging.info("deleted %d tasks", len(tasks))
                self.process_tasks(tasks)
                logging.info("processed %d tasks", len(tasks))
            except (TransientError, DeadlineExceededError) as e:
                logging.error(type(e))
                logging.error(e)
                # This is a transient error, so we need to backoff.
                # This is an error that does not get logged in the error
                # counters in the admin page
                return self._redirect()
class PullQueueBoss(Command):
    """This is executed in periodic cron.

    Inspects the pull queue backlog and spawns enough worker_class commands
    to drain it, bounded by max_workers.
    """
    worker_class = PullQueueWorker
    max_workers = 10
    # Target number of backlogged messages each worker should handle.
    max_tasks_per_worker = 300

    def __init__(self, max_workers=None, max_tasks_per_worker=None):
        self.max_workers = int(max_workers or self.max_workers)
        self.max_tasks_per_worker = int(max_tasks_per_worker or self.max_tasks_per_worker)

    def validate(self):
        # We create pointers to the queues here so that we can test
        # self.calculate_needed_workers effectively.
        self.worker_queue = Queue(self.worker_class.queue_name)
        # HACK for now ... what this needs to do is run this for each shard.
        # With sharded pull queues, only the first shard's stats are inspected.
        pull_queue_name = self.worker_class.pull_queue_name
        if isinstance(pull_queue_name, basestring):
            self.pull_queue = Queue(pull_queue_name)
        else:
            self.pull_queue = Queue(pull_queue_name[0])

    def calculate_needed_workers(self):
        """Return how many additional workers to start.

        Returns 0 when the pull queue is empty or the worker pool is already
        at capacity; otherwise at least 1.
        """
        worker_queue_stats, pull_queue_stats = tuple(
            QueueStatistics.fetch([self.worker_queue, self.pull_queue]))
        message_count = pull_queue_stats.tasks
        logging.info("Pull queue has %d messages", message_count)
        if not message_count:
            return 0
        current_workers = worker_queue_stats.in_flight or 0
        logging.info("Current workers: %d", current_workers)
        if current_workers >= self.max_workers:
            return 0
        # Use true division before ceil: under Python 2 integer division the
        # ceil was a no-op, so e.g. 301 messages at 300 per worker yielded
        # 1 worker instead of 2.
        needed_workers = min(
            self.max_workers,
            max(1, int(math.ceil(message_count / float(self.max_tasks_per_worker)))))
        needed_workers -= current_workers
        return max(1, needed_workers)

    def execute(self):
        needed_workers = self.calculate_needed_workers()
        # Purge stale worker tasks before scheduling a fresh batch.
        self.worker_queue.purge()
        logging.info("Starting %d workers", needed_workers)
        commands = [self.worker_class() for _ in range(needed_workers)]
        command_utils.batch_command_instances(commands)
        return True
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment