Created September 9, 2021 09:33
-
-
Save danieldk/d54cdbdeb7f2c1863e7cf6e1394b9b02 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import Optional, Callable, Any, Tuple | |
from ..model import Model | |
from ..util import use_nvtx_range | |
def with_nvtx_range(
    layer: Model,
    name: Optional[str] = None,
    *,
    forward_color: int = -1,
    backprop_color: int = -1,
):
    """Instrument a layer so its forward and backprop passes are NVTX ranges.

    The layer's forward function (``layer._func``) is replaced in place by a
    wrapper. The wrapper runs the original forward function inside an NVTX
    range labelled ``"<name> forward"``. It also wraps the backprop callback
    that the original forward returns, so that the callback runs inside a
    range labelled ``"<name> backprop"``. This makes both passes visible when
    profiling CUDA code.

    layer (Model): The layer to instrument. It is modified in place.
    name (Optional[str]): Prefix for the range labels. When ``None``, the
        layer's own ``name`` is used.
    forward_color (int): Colour argument handed to ``use_nvtx_range`` for the
        forward range.
    backprop_color (int): Colour argument handed to ``use_nvtx_range`` for
        the backprop range.
    RETURNS (Model): The very same ``layer`` object that was passed in.
    """
    range_prefix = name if name is not None else layer.name
    # Capture the current forward before it is overwritten below, so the
    # wrapper delegates to the original implementation.
    wrapped_forward = layer._func

    def forward(model: Model, X: Any, is_train: bool) -> Tuple[Any, Callable]:
        forward_label = f"{range_prefix} forward"
        with use_nvtx_range(forward_label, forward_color):
            output, inner_backprop = wrapped_forward(model, X, is_train=is_train)

        def backprop(dY: Any) -> Any:
            # The backprop range is opened only when the callback is actually
            # invoked, not when the forward pass runs.
            backprop_label = f"{range_prefix} backprop"
            with use_nvtx_range(backprop_label, backprop_color):
                return inner_backprop(dY)

        return output, backprop

    layer._func = forward
    return layer
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.