Last active
May 9, 2018 04:57
-
-
Save asford/912ba39d55c135a510f3b982b160f434 to your computer and use it in GitHub Desktop.
Enabling pytest-cov coverage tracing in the PyTorch backward pass.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import threading | |
import pytest | |
@pytest.fixture
def pytorch_backward_coverage(cov):
    """Torch hook to enable coverage in the backward pass.

    The ``cov`` fixture is provided by pytest-cov.

    Returns a hook function used to enable coverage tracing during
    pytorch backward passes. Torch runs backward passes in a non-main
    thread that is not spawned through the standard ``threading``
    interface, so coverage does not trace that thread on its own.
    When coverage is not active (``cov`` is falsy), a no-op hook is
    returned so tests can register it unconditionally.

    Example::

        result = custom_func(input)
        # enable the hook
        result.register_hook(pytorch_backward_coverage)
        # call backward via sum so the hook fires before custom_op backward
        result.sum().backward()
    """

    def _noop_hook(_):
        pass

    if not cov:
        return _noop_hook

    # Older coverage releases expose the collector as ``cov.collector``;
    # newer ones keep it on the private ``cov._collector`` attribute.
    collector = getattr(cov, "collector", None) or getattr(cov, "_collector", None)
    if collector is None:
        # NOTE(review): no active collector to attach tracers to; degrade
        # to a no-op instead of raising AttributeError.
        return _noop_hook

    # Reuse the set of already-traced thread ids if an earlier invocation of
    # this fixture created it. Resetting it on every test would make the hook
    # call ``_start_tracer`` again on a backward thread that already has a
    # tracer, stacking an extra tracer onto the collector for each test.
    traced = getattr(collector, "added_tracers", None)
    if traced is None:
        traced = set()
        collector.added_tracers = traced
    # Record the thread running this fixture so the hook never starts a
    # second tracer on it.
    traced.add(threading.get_ident())

    def add_tracer(_):
        # Hook argument (the gradient) is ignored; returning None leaves it
        # unchanged.
        tid = threading.get_ident()
        if tid not in traced:
            print(f"pytorch backward trace: {tid}")
            traced.add(tid)
            collector._start_tracer()

    return add_tracer
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment