from torch._C._nvfuser import Fusion, FusionCache, FusionDefinition, DataType

import torch
import torch.nn.functional as F
import functorch
from functorch.compile import memory_efficient_fusion

from copy import deepcopy
from typing import List
import time
import functools
import random

random.seed(42)

if torch.__version__ < (1, 12, 0):
    raise RuntimeError(
        "PyTorch >= 1.12.0 required, but your environment uses torch=={}".format(
            torch.__version__
        )
    )

major, minor, _ = functorch.__version__.split(".")
if int(major) == 0 and int(minor) < 2:
    raise RuntimeError(
        "FuncTorch >= 0.2.0 required, but your environment uses functorch=={}".format(
            functorch.__version__
        )
    )

# Slide - FUSION DEFINITION
# Start with the composite definition that we would like optimized
def composite_definition(
    input1: torch.Tensor,
    input2: torch.Tensor,
    weight: torch.Tensor,
    bias1: torch.Tensor,
    bias2: torch.Tensor,
    normalization_axis: int,
    dropout_prob: float,
) -> torch.Tensor:
    bias1_out = input1 + bias1
    dropout_out = F.dropout(bias1_out, dropout_prob, training=True)
    norm_input = dropout_out + input2
    norm_output = F.layer_norm(
        norm_input, (input1.size(normalization_axis),), weight, bias2
    )
    return norm_output

# Slide - INITIALIZE TENSOR TO USE
# Setup and initialize tensors and parameters
input_size = [64, 128, 1024]
device = "cuda"
dtype = torch.float32

# Create sample inputs
input1 = torch.randn(*input_size, device=device,
                     dtype=dtype, requires_grad=True)
input2 = torch.rand_like(input1).requires_grad_()

# Precompute a grad output tensor; for this example it is the same size
# as the inputs
grad_output = torch.rand_like(input1)

# Randomly initialize the model parameters
weight = torch.nn.Parameter(torch.randn(
    input_size[2], dtype=dtype, device=device))
bias1 = torch.nn.Parameter(torch.randn(
    input_size[2], dtype=dtype, device=device))
bias2 = torch.nn.Parameter(torch.randn(
    input_size[2], dtype=dtype, device=device))
parameters = [input1, input2, weight, bias1, bias2]

# Slide - PROFILING UTILITY
# Utility to profile the different workloads
def profile_workload(forward_func, backward_func, grad_output, iteration_count=100, label=""):
    # Perform warm-up iterations
    for _ in range(3):
        if grad_output is not None:
            for p in parameters:
                if p.grad is not None:
                    p.grad = None
        # Run model, forward and backward
        output = forward_func()
        if grad_output is not None:
            if backward_func is not None:
                backward_func(grad_output, *output[1:])
            else:
                output.backward(grad_output)
        # delete gradients to avoid profiling the gradient accumulation
        # for p in parameters:
        #     p.grad = None

    # Synchronize the GPU before starting the timer
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iteration_count):
        if grad_output is not None:
            for p in parameters:
                if p.grad is not None:
                    p.grad = None
        # Run model, forward and backward
        output = forward_func()
        if grad_output is not None:
            if backward_func is not None:
                backward_func(grad_output, *output[1:])
            else:
                output.backward(grad_output)
        # delete gradients to avoid profiling the gradient accumulation
        # for p in parameters:
        #     p.grad = None

    # Synchronize the GPU before stopping the timer
    torch.cuda.synchronize()
    stop = time.perf_counter()
    iters_per_second = iteration_count / (stop - start)
    if label:
        print(label)
    print("Average iterations per second: {:.2f}".format(iters_per_second))

# Slide - RUN WITH EAGER MODE
# Run and profile eager mode execution on the composite definition of our
# operations.
func = functools.partial(
    composite_definition,
    input1,
    input2,
    weight,
    bias1,
    bias2,
    normalization_axis=2,
    dropout_prob=0.1,
)
profile_workload(
    func, None, grad_output, iteration_count=100, label="Eager Mode - Composite definition"
)

# Slide - TORCHSCRIPT
# Script the function for fusion with nvFuser
scripted_composite_definition = torch.jit.script(composite_definition)
func = functools.partial(
    scripted_composite_definition,
    input1,
    input2,
    weight,
    bias1,
    bias2,
    normalization_axis=2,
    dropout_prob=0.1,
)
profile_workload(
    func, None, grad_output, iteration_count=100, label="TorchScript - Composite definition"
)

# Slide - DYNAMIC SHAPES
# Create some dynamic shapes to run through the torch scripted model
SHAPE_COUNT = 20
dynamic_sizes = deepcopy(input_size)

inputs1: List[torch.Tensor] = []
inputs2: List[torch.Tensor] = []
grad_outputs: List[torch.Tensor] = []

# Slide - DYNAMIC SHAPES
# Create the shapes
for _ in range(SHAPE_COUNT):
    dynamic_sizes[0] = input_size[0] + random.randrange(-2, 3)
    dynamic_sizes[1] = input_size[1] + random.randrange(-2, 3)
    input = torch.randn(*dynamic_sizes, device=device,
                        dtype=dtype, requires_grad=True)
    inputs1.append(input)
    inputs2.append(torch.rand_like(input))
    grad_outputs.append(torch.rand_like(input))

# Manually perform profiling since the profiling utility expects static shapes
# Perform warm-up iterations
for _ in range(3):
    dynamic_input1 = inputs1[0]
    dynamic_input2 = inputs2[0]
    dynamic_grad_output = grad_outputs[0]
    # Run model, forward and backward
    output = scripted_composite_definition(
        dynamic_input1,
        dynamic_input2,
        weight,
        bias1,
        bias2,
        normalization_axis=2,
        dropout_prob=0.1,
    )
    output.backward(dynamic_grad_output)

iteration_count = 100
# Synchronize the GPU before starting the timer
torch.cuda.synchronize()
start = time.perf_counter()
for i in range(iteration_count):
    dynamic_input1 = inputs1[i % SHAPE_COUNT]
    dynamic_input2 = inputs2[i % SHAPE_COUNT]
    dynamic_grad_output = grad_outputs[i % SHAPE_COUNT]
    dynamic_parameters = [dynamic_input1, dynamic_input2, weight, bias1, bias2]
    # Run model, forward and backward
    output = scripted_composite_definition(
        dynamic_input1,
        dynamic_input2,
        weight,
        bias1,
        bias2,
        normalization_axis=2,
        dropout_prob=0.1,
    )
    output.backward(dynamic_grad_output)
    # Delete the gradients to avoid profiling the gradient accumulation
    for p in dynamic_parameters:
        p.grad = None

# Synchronize the GPU before stopping the timer
torch.cuda.synchronize()
stop = time.perf_counter()
iters_per_second = iteration_count / (stop - start)
print("TorchScript - Random Sizes")
print("Average iterations per second: {:.2f}".format(iters_per_second))

# Slide - THERE MUST BE AN EASIER WAY
# Change the fusion definition to be based on primitive operations that are
# easily modified
def primitive_definition(
    input1: torch.Tensor,
    input2: torch.Tensor,
    weight: torch.Tensor,
    bias1: torch.Tensor,
    bias2: torch.Tensor,
    normalization_axis: int,
    dropout_prob: float,
    keepdim: bool,
) -> torch.Tensor:
    bias1_out = input1 + bias1
    dropout_out = F.dropout(bias1_out, dropout_prob, training=True)
    norm_input = dropout_out + input2
    mean = norm_input.mean(normalization_axis, keepdim=keepdim)
    diff = norm_input - mean
    diff_sq = diff * diff
    var = diff_sq.mean(normalization_axis, keepdim=keepdim)
    pre_shift_scale_norm_output = (norm_input - mean) / torch.sqrt(var + 1e-12)
    norm_output = weight * pre_shift_scale_norm_output + bias2
    return norm_output
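
# Optional parity check (a minimal sketch, not part of the original benchmark):
# with dropout disabled, the primitive decomposition should agree with the
# composite F.layer_norm path. The composite version uses layer_norm's default
# eps (1e-5) while the decomposition uses 1e-12, so the tolerance is kept loose.
with torch.no_grad():
    reference = composite_definition(
        input1, input2, weight, bias1, bias2,
        normalization_axis=2, dropout_prob=0.0,
    )
    decomposed = primitive_definition(
        input1, input2, weight, bias1, bias2,
        normalization_axis=2, dropout_prob=0.0, keepdim=True,
    )
    torch.testing.assert_close(reference, decomposed, rtol=1e-3, atol=1e-3)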

# Profile the definition with Eager Mode
func = functools.partial(
    primitive_definition,
    input1,
    input2,
    weight,
    bias1,
    bias2,
    normalization_axis=2,
    dropout_prob=0.1,
    keepdim=True,
)
profile_workload(
    func, None, grad_output, iteration_count=100, label="Eager Mode - Primitive Definition"
)

# Profile primitive definition with TorchScript
scripted_primitive_definition = torch.jit.script(primitive_definition)
func = functools.partial(
    scripted_primitive_definition,
    input1,
    input2,
    weight,
    bias1,
    bias2,
    normalization_axis=2,
    dropout_prob=0.1,
    keepdim=True,
)
profile_workload(
    func, None, grad_output, iteration_count=100, label="TorchScript - Primitive definition"
)

# Slide - FUNCTORCH
# Modify the definition to pull constants into the function to help FuncTorch
# with tracing. This would be automatically done if using TorchDynamo.
def primitive_definition_for_memory_efficient_fusion(
    input1: torch.Tensor,
    input2: torch.Tensor,
    weight: torch.Tensor,
    bias1: torch.Tensor,
    bias2: torch.Tensor,
) -> torch.Tensor:
    bias1_out = input1 + bias1
    dropout_out = F.dropout(bias1_out, 0.1, training=True)
    norm_input = dropout_out + input2
    mean = norm_input.mean(2, keepdim=True)
    diff = norm_input - mean
    diff_sq = diff * diff
    var = diff_sq.mean(2, keepdim=True)
    pre_shift_scale_norm_output = (norm_input - mean) / torch.sqrt(var + 1e-12)
    norm_output = weight * pre_shift_scale_norm_output + bias2
    return norm_output


# Optimize the model with FuncTorch tracing and the memory efficiency
# optimization pass
memory_efficient_primitive_definition = memory_efficient_fusion(
    primitive_definition_for_memory_efficient_fusion
)

# Profile the primitive definition optimized with FuncTorch
func = functools.partial(
    memory_efficient_primitive_definition, input1, input2, weight, bias1, bias2
)
profile_workload(
    func,
    None,
    grad_output,
    iteration_count=100,
    label="FuncTorch - Primitive definition",
)

# Slide - NVFUSER PYTHON FRONTEND
# FuncTorch does not support dynamic shapes yet; nvFuser does, through its
# Python frontend.
# Create the verbose but direct implementation in nvFuser
def nvfuser_fusion_dropout_layer_norm_fwd(
    fd: FusionDefinition,
    normalization_axis: int,
    norm_size: int,
    input_shape: List[int],
    eps: float,
    keepDim: bool,
    dropout_prob: float
) -> None:
    input1 = fd.define_tensor(
        symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Float)
    input2 = fd.define_tensor(
        symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Float)
    weights = fd.define_tensor(
        symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
    bias1 = fd.define_tensor(
        symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
    bias2 = fd.define_tensor(
        symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
    prob = fd.define_constant(1. - dropout_prob)
    scale = fd.define_constant(1.0/(1 - dropout_prob))
    plus_bias = fd.ops.add(input1, bias1)
    rand = fd.ops.rand_like(plus_bias)
    dropout_mask = fd.ops.le(rand, prob)
    apply_mask = fd.ops.mul(plus_bias, dropout_mask)
    dropout = fd.ops.mul(apply_mask, scale)
    layer_norm_in = fd.ops.add(dropout, input2)
    var, mean = fd.ops.var_mean(
        layer_norm_in, axes=[normalization_axis], correction=0, keepdim=keepDim)
    eps_const = fd.define_constant(eps)
    var_eps = fd.ops.add(var, eps_const)
    invstd = fd.ops.rsqrt(var_eps)
    diff = fd.ops.sub(layer_norm_in, mean)
    pre_scale_bias = fd.ops.mul(diff, invstd)
    weights_bcast = fd.ops.broadcast_in_dim(
        weights, output_shape=input_shape, broadcast_dims=[2])
    scale = fd.ops.mul(pre_scale_bias, weights_bcast)
    bias_bcast = fd.ops.broadcast_in_dim(
        bias2, output_shape=input_shape, broadcast_dims=[2])
    out = fd.ops.add(scale, bias_bcast)
    fd.add_output(out)
    fd.add_output(dropout_mask)
    fd.add_output(layer_norm_in)
    fd.add_output(mean)
    fd.add_output(invstd)

# Create the object to store the operation definition
fs_dropout_layer_norm_fwd = Fusion()
with FusionDefinition(fs_dropout_layer_norm_fwd) as fd:
    # Make the operation definition
    nvfuser_fusion_dropout_layer_norm_fwd(
        fd, 2, input1.size()[2], input1.size(), 1e-12, True, 0.1)


# Wrap the nvFuser implementation into a callable
def python_frontend_dropout_layer_norm_fwd(
    fs: Fusion,
    input1: torch.Tensor,
    input2: torch.Tensor,
    weights: torch.Tensor,
    bias1: torch.Tensor,
    bias2: torch.Tensor
) -> List[torch.Tensor]:
    out = fs.execute([input1, input2, weights, bias1, bias2])
    return out


func_fwd = functools.partial(python_frontend_dropout_layer_norm_fwd,
                             fs_dropout_layer_norm_fwd, input1, input2, weight, bias1, bias2)
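
# Optional sanity check (a sketch, not part of the original benchmark): execute()
# returns tensors in the order they were registered with fd.add_output(), so index 0
# is the normalized output, index 1 is the boolean dropout mask, and the remaining
# entries are the saved intermediates consumed by the backward fusion below.
fwd_outputs = func_fwd()
assert fwd_outputs[0].shape == input1.shape
assert fwd_outputs[1].dtype == torch.bool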

# nvFuser's frontend doesn't have auto-differentiation, so the backward
# operations must also be defined manually.
def nvfuser_fusion_dropout_layer_norm_bwd(
    fd: FusionDefinition,
    normalization_axis: int,
    norm_size: int,
    input_shape: List[int],
    eps: float,
    keepDim: bool,
    dropout_prob: float
) -> None:
    weights = fd.define_tensor(
        symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
    grad_out = fd.define_tensor(
        symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Float)
    dropout_mask = fd.define_tensor(
        symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Bool)
    layer_norm_in = fd.define_tensor(
        symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Float)
    mean = fd.define_tensor(
        symbolic_sizes=[-1, -1, 1], contiguous=[True, True, True], dtype=DataType.Float)
    invstd = fd.define_tensor(
        symbolic_sizes=[-1, -1, 1], contiguous=[True, True, True], dtype=DataType.Float)
    scale = fd.define_constant(1.0/(1. - dropout_prob))
    norm_size_const = fd.define_constant(norm_size)
    diff = fd.ops.sub(layer_norm_in, mean)
    x_hat = fd.ops.mul(diff, invstd)
    weights_bcast = fd.ops.broadcast_in_dim(
        weights, output_shape=input_shape, broadcast_dims=[2])
    grad_x_hat = fd.ops.mul(grad_out, weights_bcast)
    outa = fd.ops.mul(grad_x_hat, norm_size_const)
    outb = fd.ops.sum(grad_x_hat, axes=[normalization_axis], keepdim=keepDim)
    outc1 = fd.ops.mul(grad_x_hat, x_hat)
    outc2 = fd.ops.sum(outc1, axes=[normalization_axis], keepdim=keepDim)
    outc3 = fd.ops.mul(x_hat, outc2)
    inner1 = fd.ops.sub(outa, outb)
    inner2 = fd.ops.sub(inner1, outc3)
    recip_size = fd.ops.reciprocal(norm_size_const)
    out6 = fd.ops.mul(invstd, recip_size)
    grad_input2 = fd.ops.mul(inner2, out6)
    out7 = fd.ops.mul(grad_out, x_hat)
    grad_weights = fd.ops.sum(out7, axes=[0, 1], keepdim=False)
    grad_bias2 = fd.ops.sum(grad_out, axes=[0, 1], keepdim=False)
    grad_mask = fd.ops.mul(grad_input2, dropout_mask)
    grad_input1 = fd.ops.mul(grad_mask, scale)
    grad_bias1 = fd.ops.sum(grad_input1, axes=[0, 1], keepdim=False)
    fd.add_output(grad_input1)
    fd.add_output(grad_input2)
    fd.add_output(grad_weights)
    fd.add_output(grad_bias1)
    fd.add_output(grad_bias2)

# Create the object to store the operation definition
fs_dropout_layer_norm_bwd = Fusion()
with FusionDefinition(fs_dropout_layer_norm_bwd) as fd:
    # Make the operation definition
    nvfuser_fusion_dropout_layer_norm_bwd(
        fd, 2, input1.size()[2], input1.size(), 1e-12, True, 0.1)


# Wrap the nvFuser implementation into a callable
def python_frontend_dropout_layer_norm_bwd(
    fs: Fusion,
    weights: torch.Tensor,
    grad_out: torch.Tensor,
    dropout_mask: torch.Tensor,
    layer_norm_in: torch.Tensor,
    mean: torch.Tensor,
    invstd: torch.Tensor
) -> List[torch.Tensor]:
    out = fs.execute([weights, grad_out, dropout_mask,
                      layer_norm_in, mean, invstd])
    return out


func_bwd = functools.partial(
    python_frontend_dropout_layer_norm_bwd, fs_dropout_layer_norm_bwd, weight)
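
# Optional sanity check (a sketch, not part of the original benchmark): the backward
# fusion returns gradients in fd.add_output() order (grad_input1, grad_input2,
# grad_weights, grad_bias1, grad_bias2), so their shapes should match the tensors
# they correspond to in the forward definition.
fwd_outputs = func_fwd()
bwd_outputs = func_bwd(grad_output, *fwd_outputs[1:])
assert bwd_outputs[0].shape == input1.shape   # grad_input1
assert bwd_outputs[1].shape == input2.shape   # grad_input2
assert bwd_outputs[2].shape == weight.shape   # grad_weights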

# Profile the primitive definition written directly in nvFuser
profile_workload(func_fwd, func_bwd, grad_output,
                 iteration_count=100, label="Python Frontend - Layer Norm")

# Slide - CUSTOM OPERATION
# Modify the fusion to use RMSNorm instead of LayerNorm. There is no RMSNorm in
# the PyTorch API, so define it using primitive operations.
def with_rms_norm(
    input1: torch.Tensor,
    input2: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    normalization_axis: int,
    dropout_prob: float,
    keepdim: bool,
) -> torch.Tensor:
    bias_out = input1 + bias
    dropout_out = F.dropout(bias_out, dropout_prob, training=True)
    norm_input = dropout_out + input2
    var = norm_input.mul(norm_input).mean(normalization_axis, keepdim)
    pre_shift_scale_norm_output = norm_input / torch.sqrt(var + 1e-12)
    norm_output = weight * pre_shift_scale_norm_output
    return norm_output


# Profile the RMSNorm fusion with Eager Mode
func = functools.partial(
    with_rms_norm,
    input1,
    input2,
    weight,
    bias1,
    normalization_axis=2,
    dropout_prob=0.1,
    keepdim=True,
)
profile_workload(func, None, grad_output, iteration_count=100,
                 label="Eager Mode - RMS Norm")

# Profile the RMSNorm fusion with TorchScript
scripted_with_rms_norm = torch.jit.script(with_rms_norm)
func = functools.partial(
    scripted_with_rms_norm,
    input1,
    input2,
    weight,
    bias1,
    normalization_axis=2,
    dropout_prob=0.1,
    keepdim=True,
)
profile_workload(func, None, grad_output, iteration_count=100,
                 label="TorchScript - RMS Norm")


# Modify the RMSNorm fusion to pull the constants in for FuncTorch tracing.
# This would be automatically done if using TorchDynamo.
def with_rms_norm_for_memory_efficient_fusion(
    input1: torch.Tensor, input2: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    bias_out = input1 + bias
    dropout_out = torch.nn.functional.dropout(bias_out, 0.1)
    norm_input = dropout_out + input2
    var = norm_input.mul(norm_input).mean(2, keepdim=True)
    pre_shift_scale_norm_output = norm_input / torch.sqrt(var + 1e-12)
    norm_output = weight * pre_shift_scale_norm_output
    return norm_output


memory_efficient_rms_norm = memory_efficient_fusion(
    with_rms_norm_for_memory_efficient_fusion
)
func = functools.partial(memory_efficient_rms_norm,
                         input1, input2, weight, bias1)

# Profile the RMSNorm fusion with FuncTorch tracing.
profile_workload(func, None, grad_output, iteration_count=100,
                 label="FuncTorch - RMS Norm")

# Create the verbose but direct implementation of the RMSNorm fusion in nvFuser
def nvfuser_fusion_dropout_rms_norm_fwd(
    fd: FusionDefinition,
    normalization_axis: int,
    norm_size: int,
    input_shape: List[int],
    eps: float,
    keepDim: bool,
    dropout_prob: float
) -> None:
    input1 = fd.define_tensor(
        symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Float)
    input2 = fd.define_tensor(
        symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Float)
    weights = fd.define_tensor(
        symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
    bias = fd.define_tensor(
        symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
    prob = fd.define_constant(1. - dropout_prob)
    scale = fd.define_constant(1.0/(1. - dropout_prob))
    plus_bias = fd.ops.add(input1, bias)
    rand = fd.ops.rand_like(plus_bias)
    dropout_mask = fd.ops.le(rand, prob)
    apply_mask = fd.ops.mul(plus_bias, dropout_mask)
    dropout = fd.ops.mul(apply_mask, scale)
    rms_norm_in = fd.ops.add(dropout, input2)
    inputs_sq = fd.ops.mul(rms_norm_in, rms_norm_in)
    sum0 = fd.ops.sum(inputs_sq, axes=[normalization_axis], keepdim=keepDim)
    norm_size_const = fd.define_constant(norm_size)
    var = fd.ops.div(sum0, norm_size_const)
    eps_const = fd.define_constant(eps)
    var_eps = fd.ops.add(var, eps_const)
    sqrt_var = fd.ops.sqrt(var_eps)
    pre_scale = fd.ops.div(rms_norm_in, sqrt_var)
    weights_bcast = fd.ops.broadcast_in_dim(
        weights, output_shape=input_shape, broadcast_dims=[2])
    out = fd.ops.mul(pre_scale, weights_bcast)
    fd.add_output(out)
    fd.add_output(dropout_mask)
    fd.add_output(rms_norm_in)
    fd.add_output(sqrt_var)

# Create the object to store the operation definition
fs_dropout_rms_norm_fwd = Fusion()
with FusionDefinition(fs_dropout_rms_norm_fwd) as fd:
    # Make the operation definition
    nvfuser_fusion_dropout_rms_norm_fwd(
        fd, 2, input1.size()[2], input1.size(), 1e-12, True, 0.1)


# Wrap the nvFuser implementation into a callable
def python_frontend_dropout_rms_norm_fwd(
    fs: Fusion, input1: torch.Tensor, input2: torch.Tensor, weights: torch.Tensor, bias: torch.Tensor
) -> List[torch.Tensor]:
    out = fs.execute([input1, input2, weights, bias])
    return out


func_fwd = functools.partial(python_frontend_dropout_rms_norm_fwd,
                             fs_dropout_rms_norm_fwd, input1, input2, weight, bias1)

# nvFuser's frontend doesn't have auto-differentiation, so the backward
# operations need to be defined manually.
def nvfuser_fusion_dropout_rms_norm_bwd(
    fd: FusionDefinition,
    normalization_axis: int,
    norm_size: int,
    input_shape: List[int],
    eps: float,
    keepDim: bool,
    dropout_prob: float
) -> None:
    weights = fd.define_tensor(
        symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
    grad_out = fd.define_tensor(
        symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Float)
    dropout_mask = fd.define_tensor(
        symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Bool)
    rms_norm_in = fd.define_tensor(
        symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Float)
    sqrt_var = fd.define_tensor(
        symbolic_sizes=[-1, -1, 1], contiguous=[True, True, True], dtype=DataType.Float)
    scale = fd.define_constant(1.0/(1. - dropout_prob))
    norm_size_const = fd.define_constant(norm_size)
    const2 = fd.define_constant(2.)
    x_hat = fd.ops.div(rms_norm_in, sqrt_var)
    weights_bcast = fd.ops.broadcast_in_dim(
        weights, output_shape=[1, 1, input_shape[2]], broadcast_dims=[2])
    grad_x_hat = fd.ops.mul(grad_out, weights_bcast)
    x_hat_1 = fd.ops.div(x_hat, sqrt_var)
    neg_grad_x_hat = fd.ops.neg(grad_x_hat)
    mul4 = fd.ops.mul(neg_grad_x_hat, x_hat_1)
    div3 = fd.ops.div(grad_x_hat, sqrt_var)
    sum1 = fd.ops.sum(mul4, axes=[normalization_axis], keepdim=keepDim)
    mul5 = fd.ops.mul(sqrt_var, const2)
    div4 = fd.ops.div(sum1, mul5)
    div4_bcast = fd.ops.broadcast_in_dim(
        div4, output_shape=input_shape, broadcast_dims=[0, 1, 2])
    div5 = fd.ops.div(div4_bcast, norm_size_const)
    mul6 = fd.ops.mul(div5, rms_norm_in)
    add3 = fd.ops.add(mul6, div3)
    grad_input2 = fd.ops.add(add3, mul6)
    grad_mask = fd.ops.mul(grad_input2, dropout_mask)
    grad_input1 = fd.ops.mul(grad_mask, scale)
    grad_bias = fd.ops.sum(grad_input1, axes=[0, 1], keepdim=False)
    mul3 = fd.ops.mul(grad_out, x_hat)
    grad_weights = fd.ops.sum(mul3, axes=[0, 1], keepdim=False)
    fd.add_output(grad_input1)
    fd.add_output(grad_input2)
    fd.add_output(grad_weights)
    fd.add_output(grad_bias)

# Create the object to store the operation definition
fs_dropout_rms_norm_bwd = Fusion()
with FusionDefinition(fs_dropout_rms_norm_bwd) as fd:
    # Make the operation definition
    nvfuser_fusion_dropout_rms_norm_bwd(
        fd, 2, input1.size()[2], input1.size(), 1e-12, True, 0.1)


# Wrap the nvFuser implementation into a callable
def python_frontend_dropout_rms_norm_bwd(
    fs: Fusion,
    weights: torch.Tensor,
    grad_out: torch.Tensor,
    dropout_mask: torch.Tensor,
    rms_norm_in: torch.Tensor,
    invstd: torch.Tensor
) -> List[torch.Tensor]:
    out = fs.execute([weights, grad_out, dropout_mask, rms_norm_in, invstd])
    return out


func_bwd = functools.partial(
    python_frontend_dropout_rms_norm_bwd, fs_dropout_rms_norm_bwd, weight)
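
# Optional sanity check (a sketch, not part of the original benchmark), mirroring the
# layer norm check above: verify the output and gradient shapes of the RMSNorm fusions.
rms_fwd_outputs = func_fwd()
rms_bwd_outputs = func_bwd(grad_output, *rms_fwd_outputs[1:])
assert rms_fwd_outputs[0].shape == input1.shape
assert rms_bwd_outputs[2].shape == weight.shape   # grad_weights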

profile_workload(func_fwd, func_bwd, grad_output,
                 iteration_count=100, label="Python Frontend - RMS Norm")