Created
March 31, 2019 19:14
-
-
Save tsangwpx/a3241a05a4289e853b57d4bab1ce9f2e to your computer and use it in GitHub Desktop.
Logging with multiprocessing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Logging with multiprocessing | |
LogRecords are sent back to the master process. | |
License: MIT | |
Author: Aaron Tsang | |
""" | |
from __future__ import annotations | |
import copy | |
import logging | |
import multiprocessing as mp | |
import re | |
import threading | |
import time | |
from types import TracebackType | |
from typing import Type | |
import tblib | |
# Numeric suffix of this process's name as parsed by process_init();
# it stays at -1 until process_init() finds such a suffix.
_worker_id = -1
class ExcInfo:
    """
    Persist the exc_info tuple as a picklable object.

    When pickled, the traceback is wrapped in ``tblib.Traceback`` and the
    ``__cause__`` / ``__context__`` chain is serialized recursively as nested
    ``ExcInfo`` objects.  When unpickled, the traceback is rebuilt and the
    chain is re-attached to the restored exception.
    """

    def __init__(self, exc_info):
        """
        :param exc_info: an exception instance, or a
            ``(type, value, traceback)`` tuple as found on a ``LogRecord``
        :raises TypeError: if *exc_info* is neither of the above
        """
        if isinstance(exc_info, BaseException):
            exc_info = (type(exc_info), exc_info, exc_info.__traceback__)
        if not isinstance(exc_info, tuple):
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            raise TypeError(
                'exc_info must be an exception or a (type, value, traceback) '
                f'tuple, not {type(exc_info).__name__}'
            )
        # (Type[BaseException], BaseException, TracebackType); any may be None
        self.cls, self.exc, self.tb = exc_info

    @property
    def exc_info(self):
        """The ``(type, value, traceback)`` tuple held by this object."""
        return self.cls, self.exc, self.tb

    def __getstate__(self):
        exc = self.exc
        # getattr() keeps an "empty" (None, None, None) exc_info picklable.
        cause = getattr(exc, '__cause__', None)
        context = getattr(exc, '__context__', None)
        suppress = getattr(exc, '__suppress_context__', False)
        return (
            self.cls,
            exc,
            # An exception that was never raised has no traceback to wrap.
            None if self.tb is None else tblib.Traceback(self.tb),
            None if cause is None else ExcInfo(cause),
            None if context is None else ExcInfo(context),
            suppress,
        )

    def __setstate__(self, state):
        cls, exc, tb, cause, context, suppress = state
        self.cls = cls
        self.exc = exc
        self.tb = None if tb is None else tb.as_traceback()
        if exc is None:
            return
        if cause is not None:
            exc.__cause__ = cause.exc
        if context is not None:
            exc.__context__ = context.exc
        # Assigning __cause__ flips __suppress_context__, so restore it last.
        exc.__suppress_context__ = suppress
        # Chained exceptions are formatted from their own __traceback__, so
        # re-attach the rebuilt traceback; only a real traceback is assignable.
        if isinstance(self.tb, TracebackType):
            exc.__traceback__ = self.tb

    def __repr__(self):
        return f'{type(self).__qualname__}({self.exc!r})'
class MasterLogHandler(logging.Handler):
    """
    Master-side handler that replays log records received from workers.

    A daemon thread reads ``LogRecord`` objects from the queue, unwraps any
    ``ExcInfo`` back into a plain exc_info tuple, and passes the record to
    ``handle()``, whose ``emit()`` forwards it to ``backend``.
    """

    def __init__(self, backend: logging.Handler, queue=None):
        """
        :param backend: handler that receives every record read from the queue
        :param queue: queue the workers write to; a new ``mp.Queue`` is
            created when omitted
        """
        super().__init__()
        self.looping = True
        self.backend = backend
        self._queue = mp.Queue() if queue is None else queue
        self.worker_handler = WorkerLogHandler(self._queue)
        self._thread = threading.Thread(target=self._loop, daemon=True)
        self._thread.start()

    def _loop(self):
        """Receive and replay records until the handler is closed."""
        while self.looping:
            try:
                record: logging.LogRecord = self._queue.get()
            except (EOFError, OSError, ValueError):
                # Expected once close() has shut the queue; anything raised
                # while still looping is a real failure and must surface.
                if self.looping:
                    raise
                break
            if isinstance(record.exc_info, ExcInfo):
                record.exc_info = record.exc_info.exc_info
            self.handle(record)

    def emit(self, record):
        """Forward *record* to the backend handler."""
        self.backend.handle(record)

    def close(self):
        """Stop the receive loop, close the queue and deregister the handler."""
        # Clear the flag first so the reader thread treats a closed-queue
        # error as a normal shutdown.
        self.looping = False
        self._queue.close()
        # logging.Handler.close() performs the base-class bookkeeping that
        # overridden close() methods are expected to invoke.
        super().close()

    def new_worker(self):
        """Return a new ``WorkerLogHandler`` bound to this handler's queue."""
        return WorkerLogHandler(self._queue)
class WorkerLogHandler(logging.Handler):
    """
    Worker-side handler that ships log records to the master through a queue.

    Only the queue is pickled with the handler, so an instance can be passed
    to another process and rebuilt there around the same queue.
    """

    def __init__(self, queue):
        """
        :param queue: queue on which records are put
        """
        super().__init__()
        self._queue: mp.Queue = queue

    def emit(self, record: logging.LogRecord):
        """
        Put a picklable copy of *record* on the queue.

        The copy is tagged with ``worker_id``, its message is merged with its
        arguments, and its exc_info is wrapped in ``ExcInfo``.  Any failure is
        reported through ``handleError()`` rather than raised to the caller.
        """
        try:
            # Work on a copy so other handlers still see the original record.
            outgoing = copy.copy(record)
            outgoing.worker_id = _worker_id
            # Merge the arguments now (as logging.handlers.QueueHandler does):
            # arbitrary args or a non-string msg may not survive pickling.
            outgoing.msg = outgoing.getMessage()
            outgoing.args = None
            if outgoing.exc_info:
                outgoing.exc_info = ExcInfo(outgoing.exc_info)
            self._queue.put_nowait(outgoing)
        except Exception:
            self.handleError(record)

    def __getstate__(self):
        return self._queue,

    def __setstate__(self, state):
        # __init__ is not run on unpickling; re-create the base handler state.
        logging.Handler.__init__(self)
        (self._queue,) = state
class MultipleLineFormatter(logging.Formatter):
    """
    Formatter that prefixes continuation lines with a run of dots.

    Every newline in the formatted message and in the formatted exception
    text is followed by dots and a space: eight dots when the format string
    uses the time, six otherwise.
    """

    _SHORT_DOTS = '\n' + '.' * 6 + ' '
    _LONG_DOTS = '\n' + '.' * 8 + ' '

    def __init__(self, fmt=None, datefmt=None):
        logging.Formatter.__init__(self, fmt, datefmt)

    def _mlstr(self, s):
        """Return *s* with each newline replaced by newline + dot prefix."""
        separator = self._LONG_DOTS if self.usesTime() else self._SHORT_DOTS
        return separator.join(s.split('\n'))

    def formatMessage(self, record):
        plain = super().formatMessage(record)
        return self._mlstr(plain)

    def formatException(self, ei):
        plain = super().formatException(ei)
        # The leading newline gives the first line a dot prefix as well;
        # lstrip() then removes that newline again.
        prefixed = self._mlstr('\n' + plain)
        return prefixed.lstrip()
def process_init(log_handler):
    """
    Configure logging in the current process around *log_handler*.

    Stores the trailing ``-<digits>`` of the process name in ``_worker_id``
    when present, then removes every root handler and installs *log_handler*
    as the only root handler at DEBUG level.
    """
    global _worker_id

    suffix = re.search(r'-(\d+)$', mp.current_process().name)
    if suffix is not None:
        _worker_id = int(suffix.group(1))

    # Empty the list in place so basicConfig() below is not a no-op.
    logging.root.handlers.clear()
    logging.basicConfig(level=logging.DEBUG, handlers=[log_handler])
def raise_exception(prefix='raise_exception', cause=None):
    """
    Raise ``ValueError(prefix)``, chained to *cause* when one is given.

    :raises ValueError: always
    """
    if not cause:
        raise ValueError(f'{prefix}')
    raise ValueError(f'{prefix}') from cause
def raise_exception2(prefix='raise_exception'):
    """
    Raise a ``ValueError`` whose explicit cause is an earlier ``ValueError``.

    The first error carries *prefix*; the second carries ``prefix + '.2'``.
    """
    try:
        raise_exception(prefix)
    except ValueError as first_error:
        chained_prefix = prefix + '.2'
        raise_exception(chained_prefix, first_error)
def raise_exc_in_exc(prefix='raise_exception'):
    """
    Raise a ``ValueError`` from inside the handler of another ``ValueError``.

    No explicit cause is passed; the second error carries
    ``prefix + '.another'``.
    """
    try:
        raise_exception(prefix)
    except ValueError:
        follow_up_prefix = prefix + '.another'
        raise_exception(follow_up_prefix)
def process_task(taskid):
    """
    Demo task: emit one log record through the ``process_task`` logger.

    Task ids 4, 5 and 6 each trigger a different exception helper and log the
    resulting ``ValueError`` with its traceback; every other id logs a plain
    warning.  Odd ids sleep briefly first.
    """
    time.sleep((taskid % 2) * 0.01)
    log = logging.getLogger('process_task')

    if taskid not in (4, 5, 6):
        log.warning('Log for task %d', taskid)
        return

    # Resolve the helpers only on the failing path.
    failing_call, label = {
        4: (raise_exception, 'taskid=4'),
        5: (raise_exception2, 'taskid=5'),
        6: (raise_exc_in_exc, 'taskid=6'),
    }[taskid]
    try:
        failing_call(label)
    except ValueError as ex:
        log.exception(ex)
def main():
    """
    Run the demo: log from the master, then from a pool of three workers.

    The root handler created by ``basicConfig`` gets a
    ``MultipleLineFormatter`` and serves as the backend of a
    ``MasterLogHandler``; the workers are initialised with a worker handler
    bound to the same queue and run eight ``process_task`` calls.
    """
    logging.basicConfig(level=logging.DEBUG)

    console: logging.Handler = logging.root.handlers[0]
    formatter = MultipleLineFormatter(
        fmt="%(asctime)s %(levelname)s:%(name)s:%(message)s",
        datefmt='%H:%M:%S',
    )
    console.setFormatter(formatter)

    log = logging.getLogger('main')
    log.warning('main() started')
    log.warning('This is a multiple-lines message\n2nd line')

    record_queue = mp.Queue()
    master = MasterLogHandler(logging.root.handlers[0], record_queue)

    with mp.Pool(3, initializer=process_init, initargs=(master.new_worker(),)) as pool:
        pending = [
            pool.apply_async(process_task, (taskid,))
            for taskid in range(8)
        ]
        # Wait for every task before the pool context exits.
        for job in pending:
            job.get()

    log.warning('main() done')


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment