@asford
Last active May 9, 2018 04:57
Enabling pytest-cov in pytorch backward pass.
import threading

import pytest


@pytest.fixture
def pytorch_backward_coverage(cov):
    """Torch hook to enable coverage in backward pass.

    The `cov` fixture is provided by pytest-cov.

    Returns a hook function used to enable coverage tracing during
    pytorch backward passes. Torch runs all backward passes in a
    non-main thread, not spawned by the standard 'threading'
    interface, so coverage does not trace the thread.

    Example:
        result = custom_func(input)
        # enable the hook
        result.register_hook(pytorch_backward_coverage)
        # call backward via sum so hook fires before custom_op backward
        result.sum().backward()
    """
    if cov:
        # The thread that sets up the fixture is already traced by coverage.
        cov.collector.added_tracers = {threading.get_ident()}

        def add_tracer(_):
            # Runs on whichever thread fires the tensor hook.
            tid = threading.get_ident()
            if tid not in cov.collector.added_tracers:
                print(f"pytorch backward trace: {tid}")
                cov.collector.added_tracers.add(tid)
                # Private coverage.py API; may change between versions.
                cov.collector._start_tracer()

    else:
        # No coverage object available: return a no-op hook.
        def add_tracer(_):
            pass

    return add_tracer
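
The once-per-thread guard in the hook can be checked without torch or coverage installed. The following is a minimal sketch under stated assumptions: `FakeCollector` and `make_hook` are hypothetical names introduced here for illustration, not part of coverage.py or pytest-cov, and an ordinary `threading.Thread` stands in for the backward-pass thread.

```python
import threading


class FakeCollector:
    """Hypothetical stand-in for the coverage collector (not a real coverage.py class)."""

    def __init__(self):
        # Mirror the fixture: the creating thread counts as already traced.
        self.added_tracers = {threading.get_ident()}
        self.start_calls = []

    def _start_tracer(self):
        # Record the requesting thread instead of installing a real tracer.
        self.start_calls.append(threading.get_ident())


def make_hook(collector):
    """Build a hook with the same guard logic as the fixture above."""

    def add_tracer(_):
        tid = threading.get_ident()
        if tid not in collector.added_tracers:
            collector.added_tracers.add(tid)
            collector._start_tracer()

    return add_tracer


collector = FakeCollector()
hook = make_hook(collector)

# Main thread is already registered, so this call does nothing.
hook(None)


def worker():
    # Two calls on the same new thread should start only one tracer.
    hook(None)
    hook(None)


thread = threading.Thread(target=worker)
thread.start()
thread.join()
```

After the worker finishes, `collector.start_calls` holds a single entry for the worker thread, which is the behaviour the fixture relies on when the same hook fires on several tensors during one backward pass.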