Skip to content

Instantly share code, notes, and snippets.

@erikcs
Last active January 12, 2017 19:23
Show Gist options
  • Save erikcs/3afcc7d6682b05d71e59a99afcf0cbfe to your computer and use it in GitHub Desktop.
Save erikcs/3afcc7d6682b05d71e59a99afcf0cbfe to your computer and use it in GitHub Desktop.
scikit-learn random_state check
# rcheck.py: check sklearn test files for class/function invocations that
# lack the optional `random_state` argument. For each file in which such an
# invocation is found, prints the file name, then each class/function name
# followed by the line number(s) of the offending invocations.
#
# Usage (from scikit-learn base directory)
# $ find * -type f | grep "/test_.*py$" | python rcheck.py
import fileinput
import inspect
import importlib
import re
ARGNAME = 'random_state'
def get_estimators(mod_name):
    """Return names of callables in a module that take a `random_state` arg.

    The module `mod_name` is imported and every attribute listed by
    ``dir(module)`` is inspected. An attribute name is kept when the
    signature of the attribute has a parameter whose name equals the
    module-level ``ARGNAME`` constant.

    Parameters
    ----------
    mod_name : str
        Dotted module name passed to ``importlib.import_module``.

    Returns
    -------
    list of str
        Matching attribute names, in the order ``dir()`` yields them.

    Raises
    ------
    ImportError
        Propagated from ``importlib.import_module`` if the import fails.
    """
    module = importlib.import_module(mod_name)
    random_estimators = []
    for el in dir(module):
        obj = getattr(module, el)
        try:
            params = inspect.signature(obj).parameters
        except Exception:
            # Best-effort scan: non-callables and objects without an
            # introspectable signature are skipped. `Exception` (rather than
            # a bare `except`) lets KeyboardInterrupt/SystemExit propagate.
            continue
        if ARGNAME in params:
            random_estimators.append(el)
    return random_estimators
def process_file(f_name, random_estimators):
    """Report invocations in `f_name` that lack a `random_state` argument.

    For each name in `random_estimators`, the text of `f_name` is searched
    for ``name(...)`` and every occurrence whose argument text does not
    contain ``ARGNAME`` is recorded by line number. The file name is printed
    once, before the first estimator that has a missing invocation, followed
    by one tab-indented line per such estimator with its line numbers.

    The search is a simple regex that matches up to the first closing
    parenthesis, so it will miss the argument when the invocation contains
    nested parentheses (e.g. another function call) or when the estimator is
    invoked indirectly (through a class factory, or with ``**kwargs``).

    Parameters
    ----------
    f_name : str
        Path of the source file to scan.
    random_estimators : list of str
        Names of callables to look for.

    Returns
    -------
    int
        Total number of invocations found without ``ARGNAME``.
    """
    with open(f_name, 'r') as f:
        text = f.read()

    nmissing = 0
    header_printed = False
    for estimator in random_estimators:
        # The negative lookbehind keeps `PCA` from matching inside
        # `KernelPCA(`; re.escape guards against regex metacharacters.
        pattern = r'(?<!\w)' + re.escape(estimator) + r'\([^)]*'
        linenums = [
            text.count('\n', 0, m.start()) + 1
            for m in re.finditer(pattern, text)
            if ARGNAME not in m.group(0)
        ]
        if not linenums:
            continue
        # Print the file name before the first report, whichever estimator
        # that happens to be (not only when it is the first in the list).
        if not header_printed:
            print(f_name)
            header_printed = True
        print("\t{0} {1}".format(estimator, linenums))
        nmissing += len(linenums)
    return nmissing
def module_name_from_path(f_name):
    """Convert a relative source path such as ``pkg/tests/test_x.py`` into
    the dotted module name ``pkg.tests.test_x``.

    A leading ``./`` is dropped and only a trailing ``.py`` suffix is
    removed; ``.py`` occurring elsewhere in the path is left alone.
    """
    path = f_name
    if path.startswith('./'):
        path = path[2:]
    if path.endswith('.py'):
        path = path[:-len('.py')]
    return path.replace('/', '.')


def main():
    """Read file names (one per line) via ``fileinput``, check each file,
    and print a summary of missing `random_state` arguments.
    """
    nmissing = 0
    nfiles = 0
    for line in fileinput.input():
        f_name = line.rstrip()
        if not f_name:
            # Blank lines would otherwise yield an empty module name.
            continue
        random_estimators = get_estimators(module_name_from_path(f_name))
        nmissing += process_file(f_name, random_estimators)
        nfiles += 1
    print("")
    print("Found {0} missing in {1} files".format(nmissing, nfiles))


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment