Last active
November 30, 2018 07:19
-
-
Save tomstitt/6fcf222b903037172a3ef8f0ef6af5e9 to your computer and use it in GitHub Desktop.
Context manager that redirects stdout/stderr at the file-descriptor level, so output from non-Python code (e.g. C extensions) is routed through the (possibly replaced) sys.stdout/sys.stderr objects.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import codecs
import os
import select
import sys
from threading import Thread
class redirect_output():
    """Context manager that redirects the stdout/stderr file descriptors.

    Entering the context points the file descriptors behind ``sys.stdout``
    and ``sys.stderr`` (as captured at construction time) at pipes.  A
    reader thread forwards everything arriving on those pipes to whatever
    object is bound to ``sys.stdout`` / ``sys.stderr`` at the moment of the
    write, so code that writes to the descriptors directly (e.g. C
    extensions) goes through those Python objects.  Replacing
    ``sys.stdout`` inside the ``with`` block therefore captures that output
    as well.

    While the context is active, ``sys.stdout``/``sys.stderr`` (and
    ``sys.__stdout__``/``sys.__stderr__``) are replaced by objects that
    write to duplicates of the original descriptors.

    Descriptor-level output is decoded as UTF-8 (undecodable bytes become
    U+FFFD) before it is forwarded, because replacement streams such as
    ``io.StringIO`` only accept ``str``.  NOTE(review): assumes non-Python
    output is UTF-8 -- confirm for your platform.

    Args:
        polling_frequency: seconds the reader thread waits in ``select``
            before re-checking whether it should stop (an interval,
            despite the name).
        debug: when true, write progress messages to the original stderr.
    """

    # Maximum number of bytes taken from a pipe per os.read call.
    max_read_size = 8192

    class __redir_obj():
        """File-like stand-in for one standard stream while it is redirected.

        ``start`` saves a duplicate of the stream's descriptor and points the
        descriptor at a new pipe; ``write`` and ``fileno`` use the duplicate,
        i.e. the stream's original destination.
        """

        def __init__(self, outfile, buffervalue=0):
            self.fd = outfile.fileno()
            self.originalfile = outfile
            # Stored only; nothing in this class reads it.
            self.buffervalue = buffervalue

        def start(self):
            """Duplicate ``self.fd`` and point ``self.fd`` at a new pipe."""
            self.dup_fd = os.dup(self.fd)
            self.r_pipe, self.w_pipe = os.pipe()
            os.dup2(self.w_pipe, self.fd)

        def stop(self):
            """Point ``self.fd`` back at its original target, then close the
            pipe and the saved duplicate."""
            os.dup2(self.dup_fd, self.fd)
            os.close(self.r_pipe)
            os.close(self.w_pipe)
            # Release the duplicate made in start(); otherwise every use of
            # the context manager leaks one descriptor per stream.
            os.close(self.dup_fd)

        def write(self, s):
            """Write *s* to the original destination.

            ``bytes`` are written as-is; anything else is converted with
            ``str(s).encode()``.
            """
            data = s if isinstance(s, bytes) else str(s).encode()
            # os.write returns how many bytes it actually wrote, which can be
            # fewer than requested; keep writing until everything is out.
            remaining = memoryview(data)
            while remaining:
                remaining = remaining[os.write(self.dup_fd, remaining):]

        def flush(self):
            # Writes go straight to the descriptor; nothing is buffered here.
            pass

        def fileno(self):
            return self.dup_fd

    def __forward_ready(self, pipes, timeout):
        """Read once from each pipe that becomes readable within *timeout*.

        *pipes* maps a pipe read-end to ``(name, decoder)``, where *name* is
        ``"stdout"`` or ``"stderr"``.  Decoded text is written to the object
        currently bound to ``sys.<name>``.  A pipe that reports EOF has its
        decoder flushed and is removed from *pipes*.

        Returns:
            True if any pipe was readable, False otherwise.
        """
        ready, _, _ = select.select(list(pipes), [], [], timeout)
        for fd in ready:
            name, decoder = pipes[fd]
            data = os.read(fd, self.max_read_size)
            if data:
                text = decoder.decode(data)
            else:
                # EOF: os.read returns b"" here; stop watching this pipe so
                # select does not keep reporting it as readable.
                if self.debug:
                    self.err_obj.write("got EOF on " + name + "\n")
                del pipes[fd]
                text = decoder.decode(b"", final=True)
            if text:
                # Look the stream up at write time so a replacement made
                # inside the with-block receives the output.
                getattr(sys, name).write(text)
        return bool(ready)

    def __read_loop(self):
        """Reader-thread body: forward pipe output until asked to stop.

        Polls every ``polling_frequency`` seconds while ``thread_running``
        is true.  After __exit__ clears the flag, it keeps reading without
        blocking until no pipe has pending data, so output written just
        before the context ends is not dropped when the pipes are closed.
        """
        if self.debug:
            self.err_obj.write("thread started\n")
        # An incremental decoder keeps a multi-byte character that is split
        # across two reads intact.
        make_decoder = codecs.getincrementaldecoder("utf-8")
        pipes = {
            self.err_obj.r_pipe: ("stderr", make_decoder(errors="replace")),
            self.out_obj.r_pipe: ("stdout", make_decoder(errors="replace")),
        }
        while self.thread_running and pipes:
            self.__forward_ready(pipes, self.polling_frequency)
        # Drain whatever is already in the pipes.  If another thread keeps
        # writing to the redirected descriptors, this continues while data
        # keeps arriving.
        while pipes and self.__forward_ready(pipes, 0):
            pass
        # Emit any incomplete trailing byte sequence held by the decoders.
        for name, decoder in pipes.values():
            tail = decoder.decode(b"", final=True)
            if tail:
                getattr(sys, name).write(tail)
        if self.debug:
            self.err_obj.write("thread terminated\n")

    def __init__(self, polling_frequency=.05, debug=False):
        self.out_obj = self.__redir_obj(sys.stdout, 1)
        self.err_obj = self.__redir_obj(sys.stderr)
        # Saved so __exit__ can restore them after __enter__ overwrites them.
        self.__stdout = sys.__stdout__
        self.__stderr = sys.__stderr__
        self.polling_frequency = polling_frequency
        self.debug = debug

    def __enter__(self):
        """Redirect both descriptors, install the replacement objects on
        ``sys`` and start the reader thread.

        Returns:
            This context manager instance.
        """
        # The current Python-level streams may hold buffered data meant for
        # the original descriptors; flush it before the descriptors move.
        sys.stdout.flush()
        sys.stderr.flush()
        self.out_obj.start()
        self.err_obj.start()
        sys.stdout = self.out_obj
        sys.stderr = self.err_obj
        sys.__stdout__ = self.out_obj
        sys.__stderr__ = self.err_obj
        if self.debug:
            self.err_obj.write("starting thread... ")
        self.thread = Thread(target=self.__read_loop)
        self.thread_running = True
        self.thread.start()
        return self

    def __exit__(self, *args, **kwargs):
        """Stop the reader thread, restore ``sys`` and the descriptors.

        Exceptions raised inside the with-block are not suppressed.
        """
        # Whatever is bound to sys.stdout/sys.stderr now (possibly a
        # replacement set inside the with-block) may hold buffered data.
        sys.stdout.flush()
        sys.stderr.flush()
        if self.debug:
            self.err_obj.write("joining thread... ")
        self.thread_running = False
        try:
            self.thread.join()
        finally:
            # Restore the Python-level streams before the duplicate fds they
            # currently write to are closed by stop().
            sys.stdout = self.out_obj.originalfile
            sys.stderr = self.err_obj.originalfile
            sys.__stdout__ = self.__stdout
            sys.__stderr__ = self.__stderr
            self.out_obj.stop()
            self.err_obj.stop()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment