Last active
January 12, 2017 19:23
-
-
Save erikcs/3afcc7d6682b05d71e59a99afcf0cbfe to your computer and use it in GitHub Desktop.
scikit-learn random_state check
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# rcheck.py: check scikit-learn test files for class/function invocations
# that lack the optional `random_state` argument. For each file in which a
# missing argument is found, prints the class/function name followed by the
# line number(s) of the offending invocations.
# Note: every listed test file is imported (so its top-level code runs) in
# order to discover which of its callables accept `random_state`.
#
# Usage (from the scikit-learn base directory):
# $ find * -type f | grep "/test_.*py$" | python rcheck.py

import fileinput
import importlib
import inspect
import os
import re

# Name of the keyword argument whose presence is checked at every call site.
ARGNAME = 'random_state'


def get_estimators(mod_name, argname=None):
    """Return the names of callables in module `mod_name` that accept `argname`.

    The module is imported with ``importlib.import_module`` and every name
    listed by ``dir(module)`` is inspected. A name is kept when
    ``inspect.signature`` of the object has a parameter called `argname`.
    Names are returned in ``dir()`` order.

    Parameters
    ----------
    mod_name : str
        Dotted module name to import and inspect.
    argname : str, optional
        Parameter name to look for. Defaults to the module-level ARGNAME.

    Raises
    ------
    Whatever ``importlib.import_module`` or ``getattr`` raise; only
    signature-introspection failures are skipped.
    """
    if argname is None:
        argname = ARGNAME
    module = importlib.import_module(mod_name)
    random_estimators = []
    for name in dir(module):
        obj = getattr(module, name)
        try:
            params = inspect.signature(obj).parameters
        except (TypeError, ValueError):
            # Not callable (TypeError) or no introspectable signature, e.g.
            # some builtins (ValueError): such objects cannot be checked, so
            # skip them. KeyboardInterrupt/SystemExit are no longer swallowed.
            continue
        if argname in params:
            random_estimators.append(name)
    return random_estimators


def process_file(f_name, random_estimators, argname=None):
    """Report calls in source file `f_name` that omit the `argname` argument.

    For each name in `random_estimators`, every occurrence of ``name(`` that
    starts at an identifier boundary is located. The text from there up to
    the first closing parenthesis is then searched for `argname` (defaults
    to the module-level ARGNAME). Occurrences lacking it are counted.

    Output: the file name is printed once, before the first estimator with
    missing arguments. Then one tab-indented line is printed per such
    estimator, giving its name and the list of 1-based line numbers.

    Limitations of this plain-text scan:
    - the call text ends at the first ')', so a call containing nested
      parentheses (e.g. ``Est(n=len(x), random_state=0)``) may be reported
      even though it does pass `argname` (a false positive);
    - indirect invocations (class factories, ``**kwargs``, aliases) are not
      seen;
    - occurrences inside comments or string literals are treated as calls.

    Returns
    -------
    int
        Total number of offending calls found in the file.
    """
    if argname is None:
        argname = ARGNAME
    msg = "\t{0} {1}"
    # Python source files are UTF-8 by default (PEP 3120); don't depend on
    # the platform's locale encoding.
    with open(f_name, 'r', encoding='utf-8') as f:
        text = f.read()

    nmissing = 0
    header_printed = False
    for estimator in random_estimators:
        # The leading \b stops e.g. `KMeans` from also matching inside
        # `MiniBatchKMeans(`. re.escape guards against regex metacharacters.
        pattern = re.compile(r'\b' + re.escape(estimator) + r'\([^)]*')
        linenums = [
            text.count('\n', 0, m.start()) + 1
            for m in pattern.finditer(text)
            if argname not in m.group()
        ]
        if not linenums:
            continue
        # Print the header before the first reported estimator, whichever
        # position it has in `random_estimators`.
        if not header_printed:
            print(f_name)
            header_printed = True
        print(msg.format(estimator, linenums))
        nmissing += len(linenums)
    return nmissing


def module_name_from_path(path):
    """Convert a relative source path into a dotted module name.

    e.g. 'sklearn/tests/test_base.py' -> 'sklearn.tests.test_base'.
    Only the final extension is stripped. The original blanket
    ``replace('.py', '')`` also mangled directories whose names start with
    "py" (``sklearn/pyutils`` -> ``sklearnutils``). ``normpath`` drops a
    leading './' and normalises separators so ``os.sep`` can be replaced.
    """
    root, _ext = os.path.splitext(os.path.normpath(path))
    return root.replace(os.sep, '.')


def main():
    """Check every file path read via ``fileinput`` and print a summary.

    Paths are read one per line from the files named on the command line,
    or from stdin when none are given (``fileinput.input()`` semantics).
    """
    nmissing = 0
    nfiles = 0
    for line in fileinput.input():
        f_name = line.rstrip()
        if not f_name:
            # Tolerate blank lines; import_module('') would raise.
            continue
        random_estimators = get_estimators(module_name_from_path(f_name))
        nmissing += process_file(f_name, random_estimators)
        nfiles += 1
    print()
    print("Found {0} missing in {1} files".format(nmissing, nfiles))


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment