Skip to content

Instantly share code, notes, and snippets.

@technic
Created January 24, 2019 16:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save technic/80e8d95858b187cd8ff8677bd5cc0fbb to your computer and use it in GitHub Desktop.
Save technic/80e8d95858b187cd8ff8677bd5cc0fbb to your computer and use it in GitHub Desktop.
MKL_NUM_THREADS python context manager
import ctypes
class MKLThreads(object):
    """Context manager that temporarily changes the number of MKL threads.

    Usage::

        with MKLThreads(1):
            ...  # MKL-backed calls here are limited to one thread

    On exit, the thread count that was in effect when the ``with`` block was
    entered is restored, whether or not the block raised.

    The class-level helpers :meth:`get_max_threads` and
    :meth:`set_num_threads` can also be called directly, e.g. from a
    ``multiprocessing`` pool initializer.
    """

    # Handle to the MKL runtime library. It is loaded lazily on first use and
    # shared by every instance (it lives on the class, not the instance).
    _mkl_rt = None

    # Shared-library names tried in order. The first two keep the historical
    # lookup order (Linux, Windows). The rest cover oneMKL releases that
    # version the runtime file name, and macOS.
    _LIBRARY_NAMES = (
        'libmkl_rt.so',
        'mkl_rt.dll',
        'libmkl_rt.so.2',
        'mkl_rt.2.dll',
        'libmkl_rt.dylib',
        'libmkl_rt.2.dylib',
    )

    @classmethod
    def _load_library(cls):
        """Load and return the MKL runtime via ``ctypes.CDLL``.

        Tries every name in ``_LIBRARY_NAMES`` and then, as a last resort,
        whatever ``ctypes.util.find_library('mkl_rt')`` reports.

        Raises:
            OSError: if no candidate could be loaded.
        """
        from ctypes.util import find_library

        tried = []
        for name in cls._LIBRARY_NAMES:
            try:
                return ctypes.CDLL(name)
            except OSError:
                tried.append(name)
        # find_library may consult the system linker cache, so it is only
        # called once the cheap fixed-name attempts have all failed.
        found = find_library('mkl_rt')
        if found and found not in tried:
            try:
                return ctypes.CDLL(found)
            except OSError:
                tried.append(found)
        raise OSError(
            'could not load the MKL runtime library (tried: %s)'
            % ', '.join(tried)
        )

    @classmethod
    def _mkl(cls):
        """Return the MKL runtime library, loading it on the first call."""
        if cls._mkl_rt is None:
            cls._mkl_rt = cls._load_library()
        return cls._mkl_rt

    @classmethod
    def get_max_threads(cls):
        """Return the thread count reported by MKL's ``mkl_get_max_threads``."""
        return cls._mkl().mkl_get_max_threads()

    @classmethod
    def set_num_threads(cls, n):
        """Ask MKL to use ``n`` threads via ``mkl_set_num_threads``.

        Args:
            n: requested thread count; must be an ``int``.

        Raises:
            TypeError: if ``n`` is not an ``int`` (``bool`` is rejected too).
        """
        # Validate explicitly instead of with ``assert`` so the check is not
        # stripped under ``python -O``. bool is an int subclass, but passing
        # True/False here is almost certainly a caller mistake.
        if isinstance(n, bool) or not isinstance(n, int):
            raise TypeError(
                'number of threads must be an int, got %s' % type(n).__name__
            )
        # The count is passed by reference (a pointer to a C int).
        # NOTE(review): MKL also documents a by-value C prototype
        # (MKL_Set_Num_Threads). Confirm which calling convention the
        # lowercase mkl_rt symbol expects on each target platform.
        cls._mkl().mkl_set_num_threads(ctypes.byref(ctypes.c_int(n)))

    def __init__(self, num_threads):
        """
        Args:
            num_threads: thread count to use while the ``with`` block runs.
        """
        self._n = num_threads
        # Recorded in __enter__, not here, so that exiting restores the value
        # that was actually in effect when the block started.
        self._saved_n = None

    def __enter__(self):
        self._saved_n = self.get_max_threads()
        self.set_num_threads(self._n)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.set_num_threads(self._saved_n)
        # Implicitly returns None (falsy), so any exception raised inside the
        # ``with`` block propagates after the thread count is restored.
from unittest import TestCase
from mkl import MKLThreads
class TestMKLThreads(TestCase):
    """Tests for the MKLThreads context manager; they need a loadable MKL runtime."""

    def _require_multicore(self):
        """Return the current MKL thread count, skipping the test if it is 1 or less."""
        n = MKLThreads.get_max_threads()
        if n <= 1:
            # With a single thread, switching to 1 is a no-op, so the test
            # could not tell whether the context manager did anything.
            # Skip rather than fail, since this is an environment limitation.
            self.skipTest("must run on multi core to test")
        return n

    def test_context(self):
        """Inside ``with MKLThreads(1)`` MKL reports 1 thread; the old count returns on exit."""
        n = self._require_multicore()
        with MKLThreads(1):
            self.assertEqual(MKLThreads.get_max_threads(), 1)
        self.assertEqual(MKLThreads.get_max_threads(), n)

    def test_context_restores_on_exception(self):
        """The previous thread count is restored even when the ``with`` body raises."""
        n = self._require_multicore()
        with self.assertRaises(RuntimeError):
            with MKLThreads(1):
                raise RuntimeError("boom")
        self.assertEqual(MKLThreads.get_max_threads(), n)
@technic
Copy link
Author

technic commented Jan 25, 2019

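When using this together with the multiprocessing module, call set_num_threads inside each child process, for example from a pool initializer. A thread setting made in the parent process may not apply to the workers. For example: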

import multiprocessing as mp

# Number of MKL threads each worker process may use.
THREADS_PER_WORKER = 2


def init(num_threads=THREADS_PER_WORKER):
    """Pool initializer: limit MKL in this worker process to ``num_threads`` threads.

    Assumes ``MKLThreads`` is in scope here, e.g. via ``from mkl import MKLThreads``.
    """
    MKLThreads.set_num_threads(num_threads)


# The guard is needed with the "spawn" start method (the default on Windows
# and macOS). Under spawn, each worker re-imports this module, and without the
# guard it would re-run the pool creation, which multiprocessing rejects.
if __name__ == "__main__":
    # Keep at least one worker: on a single-core machine
    # cpu_count() // THREADS_PER_WORKER is 0, and Pool(0) raises ValueError.
    pool = mp.Pool(max(1, mp.cpu_count() // THREADS_PER_WORKER), initializer=init)

This creates a pool of cpu_count() // 2 worker processes, each limited to two MKL threads, so together the workers use roughly all available cores.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment