Skip to content

Instantly share code, notes, and snippets.



Created Jan 24, 2019
What would you like to do?
MKL_NUM_THREADS Python context manager: temporarily change the number of threads used by Intel MKL
import ctypes
class MKLThreads(object):
    """Context manager that temporarily sets the number of Intel MKL threads.

    Usage::

        with MKLThreads(1):
            ...  # code calling into MKL runs with one thread here

    On exit (normal or via an exception) the thread count that was in effect
    when the block was entered is restored. The MKL runtime library is loaded
    lazily through ``ctypes`` on first use and cached on the class.
    """

    # Cached ctypes handle to the MKL runtime (mkl_rt); shared by all
    # instances and populated on first use by ``_mkl``.
    _mkl_rt = None

    # Library names tried in order. The original code tried a Linux name and
    # then the Windows DLL; the versioned/macOS names cover other common MKL
    # distributions. NOTE(review): extend if your MKL install uses another name.
    _LIB_NAMES = (
        'libmkl_rt.so',
        'libmkl_rt.so.2',
        'libmkl_rt.dylib',
        'mkl_rt.dll',
        'mkl_rt.2.dll',
    )

    @classmethod
    def _mkl(cls):
        """Return the cached ``ctypes.CDLL`` handle to mkl_rt, loading it if needed.

        Raises:
            OSError: if none of the candidate library names can be loaded.
        """
        if cls._mkl_rt is None:
            errors = []
            for name in cls._LIB_NAMES:
                try:
                    cls._mkl_rt = ctypes.CDLL(name)
                    break
                except OSError as err:
                    errors.append('%s: %s' % (name, err))
            else:
                raise OSError('could not load the MKL runtime library; tried '
                              + '; '.join(errors))
        return cls._mkl_rt

    @classmethod
    def get_max_threads(cls):
        """Return MKL's current maximum thread count (``mkl_get_max_threads``)."""
        return cls._mkl().mkl_get_max_threads()

    @classmethod
    def set_num_threads(cls, n):
        """Set the number of threads MKL uses to *n*.

        Raises:
            TypeError: if *n* is not exactly an ``int``.
        """
        # Explicit check instead of ``assert`` so validation survives
        # ``python -O``. Exact-type comparison is kept from the original, so
        # ``bool`` (a subclass of int) is still rejected.
        if type(n) is not int:
            raise TypeError('number of threads must be an int, got %r' % (n,))
        # The exported lowercase ``mkl_set_num_threads`` symbol is called with
        # its argument by reference, as in the common ctypes recipe for mkl_rt.
        # NOTE(review): confirm against the MKL version in use.
        cls._mkl().mkl_set_num_threads(ctypes.byref(ctypes.c_int(n)))

    def __init__(self, num_threads):
        """Prepare to switch MKL to *num_threads* threads inside the ``with`` block."""
        self._n = num_threads
        # Reading the current value here also loads MKL eagerly, so a missing
        # library is reported at construction time.
        self._saved_n = self.get_max_threads()

    def __enter__(self):
        # Re-read the current count here instead of relying only on the value
        # captured in __init__. This way the exact value in effect on entry is
        # restored, even if the instance is reused or the count changed after
        # construction.
        self._saved_n = self.get_max_threads()
        self.set_num_threads(self._n)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Restore unconditionally, including when the block raised. Returning
        # None lets any exception propagate.
        self.set_num_threads(self._saved_n)
from unittest import TestCase
from mkl import MKLThreads
class TestMKLThreads(TestCase):
    """Tests for the MKLThreads context manager (need MKL with more than one thread)."""

    def test_context(self):
        """The thread count is 1 inside the block and restored afterwards."""
        n = MKLThreads.get_max_threads()
        if n <= 1:
            # An unmet environment precondition, not a defect in the code
            # under test: skip instead of reporting a spurious failure.
            self.skipTest("must run on multi core to test")
        with MKLThreads(1):
            self.assertEqual(MKLThreads.get_max_threads(), 1)
        self.assertEqual(MKLThreads.get_max_threads(), n)

    def test_context_restores_on_exception(self):
        """The previous thread count is restored even if the block raises."""
        n = MKLThreads.get_max_threads()
        if n <= 1:
            self.skipTest("must run on multi core to test")
        with self.assertRaises(RuntimeError):
            with MKLThreads(1):
                raise RuntimeError("boom")
        self.assertEqual(MKLThreads.get_max_threads(), n)

This comment has been minimized.

Copy link
Owner Author

@technic technic commented Jan 25, 2019

When using this together with the multiprocessing module, you should call set_num_threads inside each child process, for example:

import multiprocessing as mp
def init():
    """Pool initializer: limit MKL in each worker process to two threads."""
    # The thread count has to be set inside each child process, so it is done
    # here (run once per worker via ``initializer=``) rather than in the parent.
    # Assumes MKLThreads is importable in this module.
    MKLThreads.set_num_threads(2)


if __name__ == "__main__":
    # The guard is required with the "spawn" start method (the default on
    # Windows and macOS), where every worker re-imports the main module.
    # max(1, ...) avoids Pool(0), which raises ValueError on a single-CPU host.
    pool = mp.Pool(max(1, mp.cpu_count() // 2), initializer=init)

This creates a pool with half as many worker processes as there are CPUs, each limited to two MKL threads, so together they use all cores.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment