Skip to content

Instantly share code, notes, and snippets.

@tsangwpx
Created March 31, 2019 19:14
Show Gist options
  • Save tsangwpx/a3241a05a4289e853b57d4bab1ce9f2e to your computer and use it in GitHub Desktop.
Save tsangwpx/a3241a05a4289e853b57d4bab1ce9f2e to your computer and use it in GitHub Desktop.
Logging with multiprocessing
"""
Logging with multiprocessing
LogRecords are sent back to the master process
License: MIT
Author: Aaron Tsang
"""
from __future__ import annotations
import copy
import logging
import multiprocessing as mp
import re
import threading
import time
from types import TracebackType
from typing import Type
import tblib
# Worker number of the current process. It stays -1 until process_init()
# parses a trailing "-<digits>" suffix out of the process name.
_worker_id = -1
class ExcInfo:
    """
    Persist the exc_info tuple as a picklable object.

    Traceback objects cannot be pickled, so the pickled state carries a
    ``tblib.Traceback`` snapshot instead. The ``__cause__`` / ``__context__``
    chain is carried as nested ``ExcInfo`` objects and re-attached to the
    exception when the state is restored.
    """

    def __init__(self, exc_info):
        """
        :param exc_info: an exception instance, or a
            ``(type, value, traceback)`` tuple
        """
        if isinstance(exc_info, BaseException):
            exc_info = (type(exc_info), exc_info, exc_info.__traceback__)
        assert isinstance(exc_info, tuple)
        self.cls, self.exc, self.tb = exc_info  # type: (Type[BaseException], BaseException, TracebackType)

    @property
    def exc_info(self):
        """Return the ``(type, value, traceback)`` tuple."""
        return self.cls, self.exc, self.tb

    def __getstate__(self):
        """
        Build the picklable state:
        ``(cls, exc, traceback snapshot, cause, context, suppress_context)``.
        """
        exc = self.exc
        # getattr() keeps an "empty" exc_info of (None, None, None) picklable.
        cause = getattr(exc, '__cause__', None)
        context = getattr(exc, '__context__', None)
        return (
            self.cls,
            exc,
            # An exception that was never raised has no traceback to snapshot.
            None if self.tb is None else tblib.Traceback(self.tb),
            None if cause is None else ExcInfo(cause),
            None if context is None else ExcInfo(context),
            getattr(exc, '__suppress_context__', False),
        )

    def __setstate__(self, state):
        """Rebuild the exception chain and traceback from the pickled state."""
        cls, exc, tb, cause, context, suppress = state
        traceback = None if tb is None else tb.as_traceback()
        if exc is not None:
            if cause is not None:
                exc.__cause__ = cause.exc
            if context is not None:
                exc.__context__ = context.exc
            # Assigning __cause__ switches __suppress_context__ on, so the
            # recorded value has to be written back afterwards.
            exc.__suppress_context__ = suppress
            # Chained exceptions are formatted from their own __traceback__;
            # without this the cause/context tracebacks would be lost.
            exc.__traceback__ = traceback
        self.cls = cls
        self.exc = exc
        self.tb = traceback

    def __repr__(self):
        return f'{type(self).__qualname__}({self.exc!r})'
class MasterLogHandler(logging.Handler):
    """
    Replay LogRecords received from a queue through a backend handler.

    A daemon thread reads records from the queue, turns any ``ExcInfo``
    payload back into a plain exc_info tuple, and passes the record to
    ``backend``. Handlers that feed the queue are created with
    ``new_worker()``.
    """

    # Queue item that tells the listener thread to stop.
    _STOP = None
    # Longest time, in seconds, close() waits for the listener to drain.
    _JOIN_TIMEOUT = 5.0

    def __init__(self, backend: logging.Handler, queue=None):
        """
        :param backend: handler that finally processes every received record
        :param queue: queue to read records from; a new ``mp.Queue`` is
            created when omitted
        """
        super().__init__()
        self.looping = True
        self.backend = backend
        self._queue = mp.Queue() if queue is None else queue
        self._stopping = False
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()
        self.worker_handler = WorkerLogHandler(self._queue)

    def _loop(self):
        """Listener thread body: forward queued records until told to stop."""
        try:
            while self.looping:
                try:
                    record = self._queue.get()
                except (EOFError, OSError, ValueError):
                    # The queue was closed underneath us; nothing more to read.
                    break
                if record is self._STOP:
                    break
                try:
                    if isinstance(record.exc_info, ExcInfo):
                        record.exc_info = record.exc_info.exc_info
                    self.handle(record)
                except Exception:
                    # One bad record must not kill the listener thread.
                    self.handleError(record)
        finally:
            self.looping = False

    def emit(self, record):
        """Delegate the record to the backend handler."""
        self.backend.handle(record)

    def close(self):
        """
        Stop the listener thread, close the queue and unregister the handler.

        Records already queued ahead of the stop marker are still delivered.
        Calling close() more than once is harmless.
        """
        if not self._stopping:
            self._stopping = True
            try:
                self._queue.put(self._STOP)
            except (ValueError, OSError):
                # Queue already closed: the listener cannot be woken this way.
                pass
            else:
                # Joining from the listener thread itself would raise.
                if threading.current_thread() is not self._thread:
                    self._thread.join(timeout=self._JOIN_TIMEOUT)
            self.looping = False
            self._queue.close()
        super().close()

    def new_worker(self):
        """Return a new handler that sends records to this handler's queue."""
        return WorkerLogHandler(self._queue)
class WorkerLogHandler(logging.Handler):
    """
    Put a picklable copy of every LogRecord on a queue.

    The handler pickles to just its queue, so it can be handed to another
    process (for example as a pool initializer argument).
    """

    def __init__(self, queue):
        """
        :param queue: queue the records are put on
        """
        super().__init__()
        self._queue: mp.Queue = queue

    def emit(self, record: logging.LogRecord):
        """
        Queue a copy of ``record`` tagged with the current worker id.

        Failures are reported through ``handleError`` instead of being raised
        into the code that issued the logging call.
        """
        try:
            # Work on a copy so other handlers still see the original record.
            record = copy.copy(record)
            record.worker_id = _worker_id
            # Merge the arguments into the message now: msg/args may hold
            # objects that cannot be pickled for the queue.
            record.msg = record.getMessage()
            record.args = None
            if record.exc_info:
                record.exc_info = ExcInfo(record.exc_info)
            self._queue.put_nowait(record)
        except Exception:
            self.handleError(record)

    def __getstate__(self):
        """Pickle only the queue; the rest is rebuilt on unpickling."""
        return self._queue,

    def __setstate__(self, state):
        """Re-run the base initializer, then restore the queue."""
        logging.Handler.__init__(self)
        (self._queue,) = state
class MultipleLineFormatter(logging.Formatter):
    """
    Formatter that marks continuation lines with a run of dots.

    Every newline in the formatted message or exception text is followed by
    a dot prefix: eight dots when the format string uses the time, six
    otherwise.
    """

    _SHORT_DOTS = '\n...... '
    _LONG_DOTS = '\n........ '

    def __init__(self, fmt=None, datefmt=None):
        super().__init__(fmt, datefmt)

    def _mlstr(self, s):
        """Return ``s`` with the dot prefix inserted after each newline."""
        separator = self._LONG_DOTS if self.usesTime() else self._SHORT_DOTS
        return separator.join(s.split('\n'))

    def formatMessage(self, record):
        """Format the record, then prefix its continuation lines."""
        formatted = super().formatMessage(record)
        return self._mlstr(formatted)

    def formatException(self, ei):
        """Format the exception with the dot prefix on every line."""
        # The leading newline gives the first line a prefix as well; lstrip()
        # then drops that newline again.
        prefixed = self._mlstr('\n' + super().formatException(ei))
        return prefixed.lstrip()
def process_init(log_handler):
    """
    Set up logging in the current process.

    Stores the number found in a trailing ``-<digits>`` suffix of the process
    name in the module-level ``_worker_id``, then makes ``log_handler`` the
    only root handler and sets the root level to DEBUG.

    :param log_handler: handler to install on the root logger
    """
    global _worker_id

    process_name = mp.current_process().name
    suffix = re.search(r'-(\d+)$', process_name)
    if suffix is not None:
        _worker_id = int(suffix.group(1))

    # basicConfig() does nothing while the root logger still has handlers.
    logging.root.handlers.clear()
    logging.basicConfig(level=logging.DEBUG, handlers=[log_handler])
def raise_exception(prefix='raise_exception', cause=None):
    """
    Raise a ``ValueError`` whose message is ``prefix``.

    :param prefix: text of the error message
    :param cause: when truthy, the error is raised ``from`` it
    :raises ValueError: always
    """
    if not cause:
        raise ValueError(f'{prefix}')
    raise ValueError(f'{prefix}') from cause
def raise_exception2(prefix='raise_exception'):
    """
    Raise a ``ValueError`` explicitly chained to an earlier ``ValueError``.

    The first error carries ``prefix`` as its message; the second carries
    ``prefix + '.2'`` and has the first as its cause.

    :raises ValueError: always
    """
    try:
        raise_exception(prefix)
    except ValueError as first_error:
        chained_prefix = prefix + '.2'
        raise_exception(chained_prefix, first_error)
def raise_exc_in_exc(prefix='raise_exception'):
    """
    Raise a ``ValueError`` while another ``ValueError`` is being handled.

    The second error (message ``prefix + '.another'``) is raised without an
    explicit cause, from inside the ``except`` block of the first.

    :raises ValueError: always
    """
    try:
        raise_exception(prefix)
    except ValueError:
        follow_up_prefix = prefix + '.another'
        raise_exception(follow_up_prefix)
def process_task(taskid):
    """
    Demo task: log one message for ``taskid`` on the ``process_task`` logger.

    Tasks 4, 5 and 6 each trigger a different failing helper and log the
    resulting ``ValueError`` with its traceback; every other task logs a
    plain warning. Odd task ids sleep for 0.01 seconds first.
    """
    time.sleep((taskid % 2) * 0.01)
    task_logger = logging.getLogger('process_task')

    if taskid not in (4, 5, 6):
        task_logger.warning('Log for task %d', taskid)
        return

    try:
        if taskid == 4:
            raise_exception('taskid=4')
        elif taskid == 5:
            raise_exception2('taskid=5')
        else:
            raise_exc_in_exc('taskid=6')
    except ValueError as ex:
        task_logger.exception(ex)
def main():
    """
    Run the demo.

    Configures root logging with a ``MultipleLineFormatter``, starts a
    ``MasterLogHandler`` on top of the root handler, and runs eight
    ``process_task`` calls in a pool of three workers whose logging is
    initialised by ``process_init``.
    """
    logging.basicConfig(level=logging.DEBUG)
    root_handler: logging.Handler = logging.root.handlers[0]
    formatter = MultipleLineFormatter(
        fmt="%(asctime)s %(levelname)s:%(name)s:%(message)s",
        datefmt='%H:%M:%S',
    )
    root_handler.setFormatter(formatter)

    logger = logging.getLogger('main')
    logger.warning('main() started')
    logger.warning('This is a multiple-lines message\n2nd line')

    record_queue = mp.Queue()
    mlh = MasterLogHandler(logging.root.handlers[0], record_queue)
    worker_handler = mlh.new_worker()

    with mp.Pool(3, initializer=process_init, initargs=(worker_handler,)) as pool:
        pending = [
            pool.apply_async(process_task, (taskid,))
            for taskid in range(8)
        ]
        # Wait for every task before the pool is torn down.
        for async_result in pending:
            async_result.get()

    logger.warning('main() done')
# Run the demo only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment