@titu1994
Created August 2, 2023 18:34
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Benchmark script to measure Numba fp16 vs fp32 memory cost for RNNT loss.
Usage:
# Compute and evaluate the default benchmark configuration
python numba_memory_benchmark.py --results_dir "./results_fp16_vs_fp32"
# Modifying benchmark parameters
python numba_memory_benchmark.py --results_dir "./results_fp16_vs_fp32" \
-B "1,2,4,8,16,32" \
-T "200,400" \
-U "100,200" \
-V "28,1024 \
-H "640"
# Only evaluate previously computed benchmark results (without re-computation)
# Simplified results
python numba_memory_benchmark.py --results_dir "./results_fp16_vs_fp32" --no-compute
# Breakdown of results into [data - loss - total]
python numba_memory_benchmark.py --results_dir "./results_fp16_vs_fp32" --no-compute --full-results
# Calculate benchmark without allocating memory of the gradient tensor
python numba_memory_benchmark.py --results_dir "./results_fp16_vs_fp32" --no-grads
"""
import argparse
import os
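# Note: this flag is read when Numba's CUDA subsystem initializes, so it must be set
# right after `import os`, before numba (or NeMo, which imports numba) is imported.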
os.environ['NUMBA_CUDA_USE_NVIDIA_BINDING'] = "1"
import pickle
import subprocess
from typing import List, Union, Tuple
import torch
from pytorch_lightning import seed_everything
from nemo.collections.asr.modules import rnnt
from nemo.collections.asr.losses.rnnt import RNNTLoss
import numba
from nemo.core.utils import numba_utils
###################################################################################
# UTILITY FUNCTIONS
def log_system():
    """
    Log system information and whether Numba supports CUDA and fp16 or not.
    """
    print("Torch :", torch.__version__)
    print("Numba :", numba.__version__)

    # Print whether Numba CUDA / FP16 is supported
    cuda_supported = numba_utils.numba_cuda_is_supported(numba.__version__)
    fp16_supported, reason = numba_utils.is_numba_cuda_fp16_supported(return_reason=True)
    print("Numba supports CUDA:", cuda_supported)
    print("Numba supports CUDA FP16:", fp16_supported)

    if not cuda_supported:
        print("CUDA support not available. Exiting program...")
        exit(1)

    if not fp16_supported:
        print("FP16 support not available. Reason:", reason)
        print("Exiting program...")
        exit(1)

    print()

    # Print CUDA environment
    result = subprocess.run(['nvidia-smi'], capture_output=True, text=True, encoding='utf-8')
    print(result.stdout)
    print()

    result = subprocess.run(['nvidia-smi', '-L'], capture_output=True, text=True, encoding='utf-8')
    print(result.stdout)
    print()

    torch.cuda.empty_cache()
    print("GPU Memory :", torch.cuda.memory_summary())
    print()
    print()
def load_results(path):
    """Load results from a pickle file."""
    with open(path, 'rb') as f:
        results = pickle.load(f)
    return results


def save_results(results, path):
    """Save results to a pickle file."""
    with open(path, 'wb') as f:
        pickle.dump(results, f)
def print_results(results_path, simplify_results=True):
    """Display the results from the preserved run.

    Args:
        results_path: Path to the pickle file containing the results.
        simplify_results: If True, only display the total memory cost. Otherwise, display the breakdown of memory cost
            into [data - loss - total].
    """
    results = load_results(results_path)

    data_memory = [res['data_mem'] for res in results]  # type: List[monitor_cuda_mem]
    data_mem_fp16 = [data for data in data_memory if "float16" in data.scope]  # type: List[monitor_cuda_mem]
    data_mem_fp32 = [data for data in data_memory if "float32" in data.scope]  # type: List[monitor_cuda_mem]

    rnnt_memory = [res['rnnt_mem'] for res in results]  # type: List[monitor_cuda_mem]
    rnnt_mem_fp16 = [rnnt for rnnt in rnnt_memory if "float16" in rnnt.scope]  # type: List[monitor_cuda_mem]
    rnnt_mem_fp32 = [rnnt for rnnt in rnnt_memory if "float32" in rnnt.scope]  # type: List[monitor_cuda_mem]

    loss_memory = [res['loss_mem'] for res in results]  # type: List[monitor_cuda_mem]
    loss_mem_fp16 = [loss for loss in loss_memory if "float16" in loss.scope]  # type: List[monitor_cuda_mem]
    loss_mem_fp32 = [loss for loss in loss_memory if "float32" in loss.scope]  # type: List[monitor_cuda_mem]

    for data16, rnnt16, loss16, data32, rnnt32, loss32 in zip(
        data_mem_fp16, rnnt_mem_fp16, loss_mem_fp16, data_mem_fp32, rnnt_mem_fp32, loss_mem_fp32
    ):
        config = loss16.scope.replace("FP torch.float16", "").replace("FP torch.float32", "").strip()

        if simplify_results:
            fmt_str = (
                f"{config.upper():36} | "
                f"FP32 = {HumanBytes.format(max(data32.final_memory, rnnt32.final_memory, loss32.final_memory)):10} | "
                f"FP16 = {HumanBytes.format(max(data16.final_memory, rnnt16.final_memory, loss16.final_memory)):10}"
            )
        else:
            fmt_str = (
                f"{config.upper():36} | "
                f"Data + RNNT Dec+Joint 32 = {rnnt32.memory_diff_human:10} | Data + RNNT Dec+Joint 16 = {rnnt16.memory_diff_human:10} |||| "
                f"Loss Memory 32 = {loss32.memory_diff_human:10} | Loss Memory 16 = {loss16.memory_diff_human:10} |||| "
                f"FP32 Total Memory = {HumanBytes.format(max(data32.final_memory, rnnt32.final_memory, loss32.final_memory)):10} | "
                f"FP16 Total Memory = {HumanBytes.format(max(data16.final_memory, rnnt16.final_memory, loss16.final_memory)):10}"
            )

        print(fmt_str)
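# Illustrative shape of a simplified results row (values are placeholders, not measurements):
#   [B=4, T=200, U=100, V=1024, H=640]   | FP32 = <fp32 peak> | FP16 = <fp16 peak>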
"""
Utility context manager to monitor CUDA memory usage using PyTorch CUDA API.
"""
class monitor_cuda_mem:
    """
    Context manager to monitor CUDA memory usage using the PyTorch CUDA API.
    """

    _CONTEXT_DEPTH = 0
    ENABLED: bool = True  # Globally enables or disables the context manager
    EMPTY: bool = False  # If True, will perform torch.cuda.empty_cache() before and after the context
    DEVICE: int = 0  # CUDA device to monitor
    VERBOSE: bool = True  # If True, will print the scope and the amount of memory that was allocated and freed
    PRECISION: int = 4  # Number of decimal places to print for memory usage

    def __init__(
        self, scope, empty=None, enabled: bool = None, device: int = None, verbose: bool = None, precision: int = None
    ):
        self.scope = scope
        self.empty = empty if empty is not None else monitor_cuda_mem.EMPTY
        self.enabled = enabled if enabled is not None else monitor_cuda_mem.ENABLED
        self.device = device if device is not None else monitor_cuda_mem.DEVICE
        self.verbose = verbose if verbose is not None else monitor_cuda_mem.VERBOSE
        self.precision = precision if precision is not None else monitor_cuda_mem.PRECISION

        self.reset()

    def reset(self):
        self.memory_diff = None
        self.memory_diff_human = None
        self.initial_memory = None
        self.final_memory = None

    def __enter__(self):
        monitor_cuda_mem._CONTEXT_DEPTH += 1

        if self.enabled:
            if self.verbose:
                self.print_pad()
                print(f"|> {self.scope}")

            self.initial_memory = torch.cuda.max_memory_allocated(self.device)

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self.enabled:
            if self.empty:
                torch.cuda.empty_cache()

            self.final_memory = torch.cuda.max_memory_allocated(self.device)
            self.memory_diff = self.final_memory - self.initial_memory
            self.memory_diff_human = HumanBytes.format(self.memory_diff, precision=self.precision)

            if self.verbose:
                self.print_pad()
                print(f"{self.scope} |> {self.memory_diff_human}")

        monitor_cuda_mem._CONTEXT_DEPTH -= 1

    @classmethod
    def print_pad(cls):
        print('\t' * (cls._CONTEXT_DEPTH - 1), end='')
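# Illustrative usage (not executed by the benchmark); the context reports how much the
# CUDA peak-allocated counter grew inside its scope:
#
#     with monitor_cuda_mem("alloc example", verbose=True) as mem:
#         buf = torch.zeros(1_000_000, device='cuda')
#     print(mem.memory_diff_human)  # roughly 3.8 MiB for 1M fp32 values, if the peak grows here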
# Shortened form of the answer from https://stackoverflow.com/a/63839503
# Used to format bytes into human-readable format.
class HumanBytes:
    # fmt: off
    METRIC_LABELS: List[str] = ["B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"]
    BINARY_LABELS: List[str] = ["B", "KiB", "MiB", "GiB", "TiB", "PiB", "EiB", "ZiB", "YiB"]
    PRECISION_OFFSETS: List[float] = [5 * (0.1 ** x) for x in range(1, 22)]  # PREDEFINED FOR SPEED.
    PRECISION_FORMATS: List[str] = [("{}{:." + str(ratio) + "f} {}") for ratio in range(len(PRECISION_OFFSETS))]  # PREDEFINED FOR SPEED.
    # fmt: on

    @staticmethod
    def format(num: Union[int, float], metric: bool = False, precision: int = 1) -> str:
        assert isinstance(num, (int, float)), "num must be an int or float"
        assert isinstance(metric, bool), "metric must be a bool"
        assert (
            isinstance(precision, int) and precision >= 0 and precision <= len(HumanBytes.PRECISION_OFFSETS)
        ), "precision must be an int (range 0-20)"

        unit_labels = HumanBytes.METRIC_LABELS if metric else HumanBytes.BINARY_LABELS
        last_label = unit_labels[-1]
        unit_step = 1000 if metric else 1024
        unit_step_thresh = unit_step - HumanBytes.PRECISION_OFFSETS[precision]

        is_negative = num < 0
        if is_negative:  # Faster than ternary assignment or always running abs().
            num = abs(num)

        for unit in unit_labels:
            if num < unit_step_thresh:
                break
            if unit != last_label:
                num /= unit_step

        return HumanBytes.PRECISION_FORMATS[precision].format("-" if is_negative else "", num, unit)
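# Illustrative expected outputs (binary units by default, metric when requested):
#   HumanBytes.format(1536)                  -> "1.5 KiB"
#   HumanBytes.format(1536, metric=True)     -> "1.5 kB"
#   HumanBytes.format(10 ** 9, precision=2)  -> "953.67 MiB"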
###################################################################################
# DATA UTILS
# Global input variables. Used to store data for the benchmark.
global x, x_len, y, y_len
DEVICE = "cuda"
def data_gen(bs, t=200, u=100, v=1024, h=640, dtype=torch.float32):
    """
    Generate seeded data for the benchmark. Every call to this function will generate the same data for a given
    set of input parameters.

    Args:
        bs: Batch Size
        t: Audio Timesteps
        u: Text Tokens
        v: Vocabulary Size
        h: RNNT Hidden size
        dtype: torch.dtype

    Returns:
        x: Audio data of shape [bs, h, t]
        x_len: Audio length of shape [bs]
        y: Text data of shape [bs, u - 1]
        y_len: Text length of shape [bs]
    """
    # utilize global variables for input to loss
    torch.cuda.empty_cache()
    torch.manual_seed(0)

    x = torch.randn(bs, h, t, dtype=dtype, device=DEVICE, requires_grad=False)
    x_len = torch.randint(t, size=[bs], device=DEVICE, dtype=torch.int64)
    y = torch.randint(v, size=[bs, u - 1], device=DEVICE, dtype=torch.int64)
    y_len = torch.randint(u, size=[bs], device=DEVICE, dtype=torch.int64)

    # enforce some RNNT input constraints
    rand_idx = torch.randint(bs, size=[1])
    x_len[rand_idx] = t
    y_len[rand_idx] = u - 1

    return x, x_len, y, y_len
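# For example, data_gen(4, t=200, u=100, v=1024, h=640, dtype=torch.float16) returns
# x: [4, 640, 200] fp16, x_len: [4] int64, y: [4, 99] int64, y_len: [4] int64,
# i.e. roughly 4 * 640 * 200 * 2 bytes ~= 1 MiB of encoder activations.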
def str_to_int_list(string: str) -> List[int]:
    return [int(x) for x in string.split(',')] if string else []
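# e.g. str_to_int_list("1,4,8") -> [1, 4, 8]; str_to_int_list("") -> []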
###################################################################################
# MODEL UTILS
def rnnt_decoder_joint(v=1024, h=640, dtype=torch.float32, requires_grad=False) -> Tuple[rnnt.RNNTDecoder, rnnt.RNNTJoint]:
    """Build a RNNTDecoder and RNNTJoint with the given parameters."""
    seed_everything(0)

    prednet = {'pred_hidden': h, 'pred_rnn_layers': 1}
    rnnt_decoder = rnnt.RNNTDecoder(prednet, vocab_size=v)

    jointnet = {'joint_hidden': h, 'encoder_hidden': h, 'pred_hidden': h, 'activation': 'relu'}
    rnnt_joint = rnnt.RNNTJoint(jointnet, num_classes=v, fuse_loss_wer=False)

    rnnt_decoder.to(dtype=dtype, device=DEVICE)
    rnnt_joint.to(dtype=dtype, device=DEVICE)

    # Setup zero grad of params if needed
    if requires_grad:
        with torch.no_grad():
            for p in rnnt_decoder.parameters():
                p.requires_grad = True
                p.grad = torch.zeros_like(p)

            for p in rnnt_joint.parameters():
                p.requires_grad = True
                p.grad = torch.zeros_like(p)

    return rnnt_decoder, rnnt_joint
def rnnt_forward(x, x_len, y, y_len, rnnt_decoder, rnnt_joint):
    """Run the forward pass of the RNNTDecoder and RNNTJoint."""
    g, target_length, states = rnnt_decoder(targets=y, target_length=y_len)
    acts = rnnt_joint(
        encoder_outputs=x, decoder_outputs=g, encoder_lengths=x_len, transcripts=y, transcript_lengths=y_len
    )
    return acts
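# Note: the joint output `acts` holds the per-(t, u) logits over the vocabulary plus the
# blank symbol (last dimension V + 1 in NeMo's RNNTJoint); the fused Numba RNNT loss below
# consumes these directly, without an explicit log-softmax.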
def check_memory_numba(rnnt_loss, x, x_len, y, y_len, requires_grad=False):
    """Compute the RNNT Loss on the activations and check the memory consumed by the Numba kernel."""
    loss = rnnt_loss(log_probs=x, targets=y, input_lengths=x_len, target_lengths=y_len)

    if requires_grad:
        loss.sum().backward()  # compute gradients

    return loss
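# When requires_grad is True, the backward() call above also materializes the gradient of the
# joint activations, so that buffer is included in the surrounding memory measurements;
# run with --no-grads to measure the forward-only cost instead.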
def exec_closure(args):
    """
    Closure function for the benchmark. This function is called by the benchmarking script and is responsible for
    running the benchmark and returning the results.

    Returns:
        results: Path to a pickle file, containing a List of measurements from the benchmark.
    """
    # Compare takes a list of measurements which we'll save in results.
    results = []
    torch.cuda.empty_cache()

    basedir = args.results_dir
    if not os.path.exists(basedir):
        os.makedirs(basedir, exist_ok=True)

    results_path = os.path.join(basedir, f'rnnt_results_requires_grad_{str(args.requires_grad)}.pkl')

    # Parse benchmark arguments
    batchsizes = str_to_int_list(args.B)
    audio_lens = str_to_int_list(args.T)
    target_lens = str_to_int_list(args.U)
    vocab_sizes = str_to_int_list(args.V)
    hidden_sizes = str_to_int_list(args.H)

    print()
    print("*" * 100)
    print()
    print("Gradients computed :", args.requires_grad)
    print()

    # If we're not computing the benchmark, just return the results path
    if not args.compute:
        return results_path

    # Save the empty result list, thereby resetting results
    save_results(results, results_path)
    del results

    for b in batchsizes:  # 1, 4, 8, 16, 32, 64 (on 48 GB GPUs)
        for t in audio_lens:  # 200, 400, 600 (LibriSpeech with 4x and 8x stride, on 32 GB GPUs)
            for u in target_lens:  # 100, 200 (char enc, subword enc)
                for v in vocab_sizes:  # 28, 1024 (char encoding, Conformer RNNT Vocab Size)
                    for h in hidden_sizes:  # 640, 1024 (common hidden size of encoder, decoder, joint)
                        for dtype in [
                            torch.float32,
                            torch.float16,
                        ]:
                            # Access global dataset and reset
                            global x, x_len, y, y_len
                            x = None
                            x_len = None
                            y = None
                            y_len = None

                            # Reset CUDA memory stats
                            torch.cuda.empty_cache()
                            torch.cuda.reset_peak_memory_stats()
                            torch.cuda.reset_accumulated_memory_stats()
                            torch.cuda.reset_max_memory_cached()
                            torch.cuda.reset_max_memory_allocated()

                            # Setup CUDA monitor flags
                            monitor_cuda_mem.DEVICE = 0
                            monitor_cuda_mem.EMPTY = False
                            monitor_cuda_mem.PRECISION = 2
                            monitor_cuda_mem.VERBOSE = False

                            # sub_label are the rows
                            # description is the column
                            sub_label = f'[b={b}, t={t}, u={u}, v={v}, h={h}]'
                            print("Computing :", sub_label)

                            # Numba FP 32 or 16 depending on `dtype`
                            env = f'FP {dtype}'

                            # Build batch of samples with seed set
                            with monitor_cuda_mem(f'Data {dtype}', empty=False) as datagen_mem:
                                x, x_len, y, y_len = data_gen(b, t, u, v, h, dtype=dtype)

                                if args.requires_grad:
                                    x.requires_grad_(True)
                                    with torch.no_grad():
                                        x.grad = torch.zeros_like(x, dtype=dtype)

                            print("Batch data memory", datagen_mem.memory_diff_human)

                            # Build new RNNT decoder and joint
                            with monitor_cuda_mem(f'{env} {sub_label}', empty=False) as rnnt_mem:
                                dec, joint = rnnt_decoder_joint(v, h=h, dtype=dtype, requires_grad=args.requires_grad)

                                # Compute the joint activations
                                # (we don't need to perform log-softmax due to fused kernel)
                                acts = rnnt_forward(x, x_len, y, y_len, dec, joint)

                            print("RNNT Decoder+Joint memory", rnnt_mem.memory_diff_human)

                            # Compute the loss and memory cost
                            with monitor_cuda_mem(f'{env} {sub_label}', empty=False) as loss_mem:
                                blank = acts.shape[-1] - 1  # blank index == vocab size V (acts has V + 1 classes)
                                rnnt_loss = RNNTLoss(
                                    num_classes=blank,
                                    reduction='sum',
                                    loss_name='warprnnt_numba',
                                    loss_kwargs=dict(fastemit_lambda=0.0, clamp=-1.0),
                                )

                                # Compute the loss and check memory
                                # Note: We are not measuring speed, and therefore Numba JIT compile time is not measured
                                # Therefore we skip performing a warmup run of the loss function
                                unused_value_ = check_memory_numba(
                                    rnnt_loss, acts, x_len, y, y_len, requires_grad=args.requires_grad
                                )

                            result = {
                                'data_mem': datagen_mem,
                                'rnnt_mem': rnnt_mem,
                                'loss_mem': loss_mem,
                            }

                            print(f"Loss memory ({dtype})", loss_mem.memory_diff_human)
                            print(f"Peak memory allocated : {HumanBytes.format(torch.cuda.max_memory_allocated())}")
                            print()

                            # Save results to disk
                            results = load_results(results_path)
                            results.append(result)
                            save_results(results, results_path)

                            # Clean up memory for next benchmark
                            del results, unused_value_
                            del dec, joint, acts
                            del rnnt_loss, blank
                            del x, x_len, y, y_len
                            torch.cuda.empty_cache()

    return results_path
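# The results pickle is re-loaded, appended to, and re-saved after every configuration, so an
# interrupted run can still be inspected later with --no-compute.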
def parse_args():
    parser = argparse.ArgumentParser(description="RNNT Loss Benchmark")
    parser.add_argument(
        '-B', '--batch-size', dest='B', type=str, default='1,4,8,16,32', help="Batch sizes to benchmark"
    )
    parser.add_argument(
        '-T', '--audio-len', dest='T', type=str, default='200,400', help="Max audio lengths to benchmark"
    )
    parser.add_argument(
        '-U', '--text-len', dest='U', type=str, default='100,200', help="Max text sequence lengths to benchmark"
    )
    parser.add_argument('-V', '--vocab-size', dest='V', type=str, default='28,1024', help="Vocab sizes to benchmark")
    parser.add_argument('-H', '--hidden-size', dest='H', type=str, default='640', help="Hidden size of the RNNT Joint")
    parser.add_argument("--results_dir", type=str, default='./numba_fp32_vs_fp16', help="Name of results directory")
    parser.add_argument(
        "--no-compute",
        dest='compute',
        action='store_false',
        help="Skip computing results; used when only printing previously computed results.",
    )
    parser.add_argument(
        '--no-grads',
        dest='requires_grad',
        action='store_false',
        help="Skip gradient computation when comparing memory usage.",
    )
    parser.add_argument(
        '--full-results',
        dest='simplify_results',
        action='store_false',
        help="Print full results - including breakdown of data storage size and activations",
    )
    parser.set_defaults(compute=True, requires_grad=True)

    args = parser.parse_args()
    return args
def main(args):
    log_system()
    results_path = exec_closure(args)

    print("\n\n")
    print("Results::")
    print_results(results_path, simplify_results=args.simplify_results)


if __name__ == '__main__':
    args = parse_args()
    main(args)