from torch._C._nvfuser import Fusion, FusionCache, FusionDefinition, DataType
import torch
import torch.nn.functional as F
import functorch
from functorch.compile import memory_efficient_fusion
from copy import deepcopy
from typing import List
import time
import functools
import random
random.seed(42)
if torch.__version__ < (1, 12, 0):
    raise RuntimeError(
        "PyTorch >= 1.12.0 required, but your environment uses torch=={}".format(
            torch.__version__
        )
    )
major, minor, _ = functorch.__version__.split(".")
if int(major) == 0 and int(minor) < 2:
    raise RuntimeError(
        "FuncTorch >= 0.2.0 required, but your environment uses functorch=={}".format(
            functorch.__version__
        )
    )
# Slide - FUSION DEFINITION
# Start with the composite definition that we would like optimized
def composite_definition(
    input1: torch.Tensor,
    input2: torch.Tensor,
    weight: torch.Tensor,
    bias1: torch.Tensor,
    bias2: torch.Tensor,
    normalization_axis: int,
    dropout_prob: float,
) -> torch.Tensor:
    bias1_out = input1 + bias1
    dropout_out = F.dropout(bias1_out, dropout_prob, training=True)
    norm_input = dropout_out + input2
    norm_output = F.layer_norm(
        norm_input, (input1.size(normalization_axis),), weight, bias2
    )
    return norm_output
# Slide - INITIALIZE TENSOR TO USE
# Setup and initialize tensors and parameters
input_size = [64, 128, 1024]
device = "cuda"
dtype = torch.float32
# Create sample inputs
input1 = torch.randn(*input_size, device=device,
                     dtype=dtype, requires_grad=True)
input2 = torch.rand_like(input1).requires_grad_()
# Precompute a grad output tensor, for this example it's the same size
# as the inputs
grad_output = torch.rand_like(input1)
# Randomly initialize the model parameters
weight = torch.nn.Parameter(torch.randn(
    input_size[2], dtype=dtype, device=device))
bias1 = torch.nn.Parameter(torch.randn(
    input_size[2], dtype=dtype, device=device))
bias2 = torch.nn.Parameter(torch.randn(
    input_size[2], dtype=dtype, device=device))
parameters = [input1, input2, weight, bias1, bias2]
# Slide - PROFILING UTILITY
# Utility to profile the different workloads
def profile_workload(forward_func, backward_func, grad_output, iteration_count=100, label=""):
    """Profile `forward_func` (and optionally a backward pass) on the GPU.

    `forward_func` takes no arguments. If `backward_func` is given it is
    called as backward_func(grad_output, *output[1:]); otherwise
    output.backward(grad_output) is used.
    """
    # Perform warm-up iterations
    for _ in range(3):
        if grad_output is not None:
            for p in parameters:
                if p.grad is not None:
                    p.grad = None
        # Run model, forward and backward
        output = forward_func()
        if grad_output is not None:
            if backward_func is not None:
                backward_func(grad_output, *output[1:])
            else:
                output.backward(grad_output)
        # Delete gradients to avoid profiling the gradient accumulation
        # for p in parameters:
        #     p.grad = None
    # Synchronize the GPU before starting the timer
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iteration_count):
        if grad_output is not None:
            for p in parameters:
                if p.grad is not None:
                    p.grad = None
        # Run model, forward and backward
        output = forward_func()
        if grad_output is not None:
            if backward_func is not None:
                backward_func(grad_output, *output[1:])
            else:
                output.backward(grad_output)
        # Delete gradients to avoid profiling the gradient accumulation
        # for p in parameters:
        #     p.grad = None
    # Synchronize the GPU before stopping the timer
    torch.cuda.synchronize()
    stop = time.perf_counter()
    iters_per_second = iteration_count / (stop - start)
    if label:
        print(label)
    print("Average iterations per second: {:.2f}".format(iters_per_second))
# Slide - RUN WITH EAGER MODE
# Run and profile eager mode execution on the composite definition of our
# operations.
func = functools.partial(
    composite_definition,
    input1,
    input2,
    weight,
    bias1,
    bias2,
    normalization_axis=2,
    dropout_prob=0.1,
)
profile_workload(
    func, None, grad_output, iteration_count=100, label="Eager Mode - Composite definition"
)
# Slide - TORCHSCRIPT
# Script the function for fusion with nvFuser
scripted_composite_definition = torch.jit.script(composite_definition)
func = functools.partial(
    scripted_composite_definition,
    input1,
    input2,
    weight,
    bias1,
    bias2,
    normalization_axis=2,
    dropout_prob=0.1,
)
profile_workload(
    func, None, grad_output, iteration_count=100, label="TorchScript - Composite definition"
)
# Slide - DYNAMIC SHAPES
# Create some dynamic shapes to run through the torch scripted model
SHAPE_COUNT = 20
dynamic_sizes = deepcopy(input_size)
inputs1: List[torch.Tensor] = []
inputs2: List[torch.Tensor] = []
grad_outputs: List[torch.Tensor] = []
# Slide - DYNAMIC SHAPES
# Create the shapes
for _ in range(SHAPE_COUNT):
    dynamic_sizes[0] = input_size[0] + random.randrange(-2, 3)
    dynamic_sizes[1] = input_size[1] + random.randrange(-2, 3)
    input = torch.randn(*dynamic_sizes, device=device,
                        dtype=dtype, requires_grad=True)
    inputs1.append(input)
    inputs2.append(torch.rand_like(input))
    grad_outputs.append(torch.rand_like(input))
# Manually perform profiling since the profiling utility expects static shapes
# Perform warm-up iterations
for _ in range(3):
    dynamic_input1 = inputs1[0]
    dynamic_input2 = inputs2[0]
    dynamic_grad_output = grad_outputs[0]
    # Run model, forward and backward
    output = scripted_composite_definition(
        dynamic_input1,
        dynamic_input2,
        weight,
        bias1,
        bias2,
        normalization_axis=2,
        dropout_prob=0.1,
    )
    output.backward(dynamic_grad_output)
iteration_count = 100
# Synchronize the GPU before starting the timer
torch.cuda.synchronize()
start = time.perf_counter()
for i in range(iteration_count):
    dynamic_input1 = inputs1[i % SHAPE_COUNT]
    dynamic_input2 = inputs2[i % SHAPE_COUNT]
    dynamic_grad_output = grad_outputs[i % SHAPE_COUNT]
    dynamic_parameters = [dynamic_input1, dynamic_input2, weight, bias1, bias2]
    # Run model, forward and backward
    output = scripted_composite_definition(
        dynamic_input1,
        dynamic_input2,
        weight,
        bias1,
        bias2,
        normalization_axis=2,
        dropout_prob=0.1,
    )
    output.backward(dynamic_grad_output)
    # Delete the gradients to avoid profiling the gradient accumulation
    for p in dynamic_parameters:
        p.grad = None
# Synchronize the GPU before stopping the timer
torch.cuda.synchronize()
stop = time.perf_counter()
iters_per_second = iteration_count / (stop - start)
print("TorchScript - Random Sizes")
print("Average iterations per second: {:.2f}".format(iters_per_second))
# Slide - THERE MUST BE AN EASIER WAY
# Change the fusion definition to be based on primitive operations that are easily modified
def primitive_definition(
    input1: torch.Tensor,
    input2: torch.Tensor,
    weight: torch.Tensor,
    bias1: torch.Tensor,
    bias2: torch.Tensor,
    normalization_axis: int,
    dropout_prob: float,
    keepdim: bool,
) -> torch.Tensor:
    bias1_out = input1 + bias1
    dropout_out = F.dropout(bias1_out, dropout_prob, training=True)
    norm_input = dropout_out + input2
    mean = norm_input.mean(normalization_axis, keepdim=keepdim)
    diff = norm_input - mean
    diff_sq = diff * diff
    var = diff_sq.mean(normalization_axis, keepdim=keepdim)
    pre_shift_scale_norm_output = (norm_input - mean) / torch.sqrt(var + 1e-12)
    norm_output = weight * pre_shift_scale_norm_output + bias2
    return norm_output
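# Optional sanity check (a quick sketch, not part of the benchmark): with
# dropout disabled, both definitions are deterministic and should agree up to
# the different eps values (1e-12 above vs. F.layer_norm's default 1e-5).
with torch.no_grad():
    composite_ref = composite_definition(
        input1, input2, weight, bias1, bias2,
        normalization_axis=2, dropout_prob=0.0)
    primitive_ref = primitive_definition(
        input1, input2, weight, bias1, bias2,
        normalization_axis=2, dropout_prob=0.0, keepdim=True)
    print("Max composite/primitive difference: {:.3e}".format(
        (composite_ref - primitive_ref).abs().max().item()))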
# Profile the definition with Eager Mode
func = functools.partial(
    primitive_definition,
    input1,
    input2,
    weight,
    bias1,
    bias2,
    normalization_axis=2,
    dropout_prob=0.1,
    keepdim=True,
)
profile_workload(
    func, None, grad_output, iteration_count=100, label="Eager Mode - Primitive Definition"
)
# Profile primitive definition with TorchScript
scripted_primitive_definition = torch.jit.script(primitive_definition)
func = functools.partial(
    scripted_primitive_definition,
    input1,
    input2,
    weight,
    bias1,
    bias2,
    normalization_axis=2,
    dropout_prob=0.1,
    keepdim=True,
)
profile_workload(
    func, None, grad_output, iteration_count=100, label="TorchScript - Primitive definition"
)
# Slide - FUNCTORCH
# Modify the definition to pull constants into the function to help FuncTorch with tracing.
# This would be automatically done if using TorchDynamo
def primitive_definition_for_memory_efficient_fusion(
    input1: torch.Tensor,
    input2: torch.Tensor,
    weight: torch.Tensor,
    bias1: torch.Tensor,
    bias2: torch.Tensor,
) -> torch.Tensor:
    bias1_out = input1 + bias1
    dropout_out = F.dropout(bias1_out, 0.1, training=True)
    norm_input = dropout_out + input2
    mean = norm_input.mean(2, keepdim=True)
    diff = norm_input - mean
    diff_sq = diff * diff
    var = diff_sq.mean(2, keepdim=True)
    pre_shift_scale_norm_output = (norm_input - mean) / torch.sqrt(var + 1e-12)
    norm_output = weight * pre_shift_scale_norm_output + bias2
    return norm_output
# Optimize the model with FuncTorch tracing and the memory efficiency
# optimization pass
memory_efficient_primitive_definition = memory_efficient_fusion(
    primitive_definition_for_memory_efficient_fusion
)
# Profile the primitive definition optimized with FuncTorch
func = functools.partial(
    memory_efficient_primitive_definition, input1, input2, weight, bias1, bias2
)
profile_workload(
    func,
    None,
    grad_output,
    iteration_count=100,
    label="FuncTorch - Primitive definition",
)
# Slide - NVFUSER PYTHON FRONTEND
# FuncTorch does not support dynamic shapes yet; however, nvFuser does through its Python frontend
# Create the verbose but direct implementation in nvFuser
def nvfuser_fusion_dropout_layer_norm_fwd(
    fd: FusionDefinition,
    normalization_axis: int,
    norm_size: int,
    input_shape: List[int],
    eps: float,
    keepDim: bool,
    dropout_prob: float
) -> None:
    input1 = fd.define_tensor(
        symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Float)
    input2 = fd.define_tensor(
        symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Float)
    weights = fd.define_tensor(
        symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
    bias1 = fd.define_tensor(
        symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
    bias2 = fd.define_tensor(
        symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
    # Dropout: keep elements where rand <= 1 - p, then rescale by 1 / (1 - p)
    prob = fd.define_constant(1. - dropout_prob)
    scale = fd.define_constant(1.0/(1 - dropout_prob))
    plus_bias = fd.ops.add(input1, bias1)
    rand = fd.ops.rand_like(plus_bias)
    dropout_mask = fd.ops.le(rand, prob)
    apply_mask = fd.ops.mul(plus_bias, dropout_mask)
    dropout = fd.ops.mul(apply_mask, scale)
    layer_norm_in = fd.ops.add(dropout, input2)
    var, mean = fd.ops.var_mean(
        layer_norm_in, axes=[normalization_axis], correction=0, keepdim=keepDim)
    eps_const = fd.define_constant(eps)
    var_eps = fd.ops.add(var, eps_const)
    invstd = fd.ops.rsqrt(var_eps)
    diff = fd.ops.sub(layer_norm_in, mean)
    pre_scale_bias = fd.ops.mul(diff, invstd)
    weights_bcast = fd.ops.broadcast_in_dim(
        weights, output_shape=input_shape, broadcast_dims=[2])
    scale = fd.ops.mul(pre_scale_bias, weights_bcast)
    bias_bcast = fd.ops.broadcast_in_dim(
        bias2, output_shape=input_shape, broadcast_dims=[2])
    out = fd.ops.add(scale, bias_bcast)
    # Also output the intermediates that the backward fusion needs
    fd.add_output(out)
    fd.add_output(dropout_mask)
    fd.add_output(layer_norm_in)
    fd.add_output(mean)
    fd.add_output(invstd)
# Create the object to store the operation definition
fs_dropout_layer_norm_fwd = Fusion()
with FusionDefinition(fs_dropout_layer_norm_fwd) as fd:
    # Make the operation definition
    nvfuser_fusion_dropout_layer_norm_fwd(
        fd, 2, input1.size()[2], input1.size(), 1e-12, True, 0.1)
# Wrap the nvFuser implementation into a callable
def python_frontend_dropout_layer_norm_fwd(
    fs: Fusion,
    input1: torch.Tensor,
    input2: torch.Tensor,
    weights: torch.Tensor,
    bias1: torch.Tensor,
    bias2: torch.Tensor
) -> List[torch.Tensor]:
    out = fs.execute([input1, input2, weights, bias1, bias2])
    return out
func_fwd = functools.partial(python_frontend_dropout_layer_norm_fwd,
                             fs_dropout_layer_norm_fwd, input1, input2, weight, bias1, bias2)
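# Optional sketch: run the forward fusion once to inspect its outputs. The
# fusion returns tensors in the order they were added above: the normalized
# output, the dropout mask, the layer-norm input, the mean, and the inverse
# standard deviation. profile_workload passes everything after the first
# output on to the backward function.
fwd_outputs = func_fwd()
print("Forward fusion returned {} tensors".format(len(fwd_outputs)))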
# nvFuser's frontend doesn't have auto-differentiation, so the backward operations also have to be defined manually.
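# For reference, the fusion below computes the standard layer norm input
# gradient, with x_hat = (x - mean) * invstd, w the affine weight, and N the
# normalized size:
#   grad_x = invstd / N * (N * (grad_y * w)
#                          - sum(grad_y * w)
#                          - x_hat * sum(grad_y * w * x_hat))
# followed by the dropout backward (mask and 1/(1-p) rescale) and the
# weight/bias gradient reductions.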
def nvfuser_fusion_dropout_layer_norm_bwd(
    fd: FusionDefinition,
    normalization_axis: int,
    norm_size: int,
    input_shape: List[int],
    eps: float,
    keepDim: bool,
    dropout_prob: float
) -> None:
    weights = fd.define_tensor(
        symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
    grad_out = fd.define_tensor(
        symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Float)
    dropout_mask = fd.define_tensor(
        symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Bool)
    layer_norm_in = fd.define_tensor(
        symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Float)
    mean = fd.define_tensor(
        symbolic_sizes=[-1, -1, 1], contiguous=[True, True, True], dtype=DataType.Float)
    invstd = fd.define_tensor(
        symbolic_sizes=[-1, -1, 1], contiguous=[True, True, True], dtype=DataType.Float)
    scale = fd.define_constant(1.0/(1. - dropout_prob))
    norm_size_const = fd.define_constant(norm_size)
    diff = fd.ops.sub(layer_norm_in, mean)
    x_hat = fd.ops.mul(diff, invstd)
    weights_bcast = fd.ops.broadcast_in_dim(
        weights, output_shape=input_shape, broadcast_dims=[2])
    grad_x_hat = fd.ops.mul(grad_out, weights_bcast)
    outa = fd.ops.mul(grad_x_hat, norm_size_const)
    outb = fd.ops.sum(grad_x_hat, axes=[normalization_axis], keepdim=keepDim)
    outc1 = fd.ops.mul(grad_x_hat, x_hat)
    outc2 = fd.ops.sum(outc1, axes=[normalization_axis], keepdim=keepDim)
    outc3 = fd.ops.mul(x_hat, outc2)
    inner1 = fd.ops.sub(outa, outb)
    inner2 = fd.ops.sub(inner1, outc3)
    recip_size = fd.ops.reciprocal(norm_size_const)
    out6 = fd.ops.mul(invstd, recip_size)
    grad_input2 = fd.ops.mul(inner2, out6)
    out7 = fd.ops.mul(grad_out, x_hat)
    grad_weights = fd.ops.sum(out7, axes=[0, 1], keepdim=False)
    grad_bias2 = fd.ops.sum(grad_out, axes=[0, 1], keepdim=False)
    grad_mask = fd.ops.mul(grad_input2, dropout_mask)
    grad_input1 = fd.ops.mul(grad_mask, scale)
    grad_bias1 = fd.ops.sum(grad_input1, axes=[0, 1], keepdim=False)
    fd.add_output(grad_input1)
    fd.add_output(grad_input2)
    fd.add_output(grad_weights)
    fd.add_output(grad_bias1)
    fd.add_output(grad_bias2)
# Create the object to store the operation definition
fs_dropout_layer_norm_bwd = Fusion()
with FusionDefinition(fs_dropout_layer_norm_bwd) as fd:
    # Make the operation definition
    nvfuser_fusion_dropout_layer_norm_bwd(
        fd, 2, input1.size()[2], input1.size(), 1e-12, True, 0.1)
# Wrap the nvFuser implementation into a callable
def python_frontend_dropout_layer_norm_bwd(
    fs: Fusion,
    weights: torch.Tensor,
    grad_out: torch.Tensor,
    dropout_mask: torch.Tensor,
    layer_norm_in: torch.Tensor,
    mean: torch.Tensor,
    invstd: torch.Tensor
) -> List[torch.Tensor]:
    out = fs.execute([weights, grad_out, dropout_mask,
                      layer_norm_in, mean, invstd])
    return out
func_bwd = functools.partial(
    python_frontend_dropout_layer_norm_bwd, fs_dropout_layer_norm_bwd, weight)
# Profile the primitive definition written directly in nvFuser
profile_workload(func_fwd, func_bwd, grad_output,
                 iteration_count=100, label="Python Frontend - Layer Norm")
# Slide - CUSTOM OPERATION
# Modify the fusion to use RMSNorm instead of LayerNorm. There's no RMSNorm in the PyTorch API, so define it in primitive operations.
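# RMSNorm normalizes by the root mean square of the input rather than
# subtracting the mean: y = weight * x / sqrt(mean(x^2) + eps), with no
# affine bias after scaling.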
def with_rms_norm(
    input1: torch.Tensor,
    input2: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    normalization_axis: int,
    dropout_prob: float,
    keepdim: bool,
) -> torch.Tensor:
    bias_out = input1 + bias
    dropout_out = F.dropout(bias_out, dropout_prob, training=True)
    norm_input = dropout_out + input2
    var = norm_input.mul(norm_input).mean(normalization_axis, keepdim)
    pre_shift_scale_norm_output = norm_input / torch.sqrt(var + 1e-12)
    norm_output = weight * pre_shift_scale_norm_output
    return norm_output
# Profile the RMSNorm fusion with Eager Mode
func = functools.partial(
    with_rms_norm,
    input1,
    input2,
    weight,
    bias1,
    normalization_axis=2,
    dropout_prob=0.1,
    keepdim=True,
)
profile_workload(func, None, grad_output, iteration_count=100,
                 label="Eager Mode - RMS Norm")
# Profile the RMSNorm fusion with TorchScript
scripted_with_rms_norm = torch.jit.script(with_rms_norm)
func = functools.partial(
    scripted_with_rms_norm,
    input1,
    input2,
    weight,
    bias1,
    normalization_axis=2,
    dropout_prob=0.1,
    keepdim=True,
)
profile_workload(func, None, grad_output, iteration_count=100,
                 label="TorchScript - RMS Norm")
# Modify the RMSNorm fusion to pull the constants in for FuncTorch tracing. This would be automatically done if using TorchDynamo.
def with_rms_norm_for_memory_efficient_fusion(
    input1: torch.Tensor, input2: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
) -> torch.Tensor:
    bias_out = input1 + bias
    dropout_out = torch.nn.functional.dropout(bias_out, 0.1)
    norm_input = dropout_out + input2
    var = norm_input.mul(norm_input).mean(2, keepdim=True)
    pre_shift_scale_norm_output = norm_input / torch.sqrt(var + 1e-12)
    norm_output = weight * pre_shift_scale_norm_output
    return norm_output
memory_efficient_rms_norm = memory_efficient_fusion(
    with_rms_norm_for_memory_efficient_fusion
)
func = functools.partial(memory_efficient_rms_norm,
                         input1, input2, weight, bias1)
# Profile the RMSNorm fusion with FuncTorch tracing.
profile_workload(func, None, grad_output, iteration_count=100,
                 label="FuncTorch - RMS Norm")
# Create the verbose but direct implementation of the RMSNorm fusion in nvFuser
def nvfuser_fusion_dropout_rms_norm_fwd(
    fd: FusionDefinition,
    normalization_axis: int,
    norm_size: int,
    input_shape: List[int],
    eps: float,
    keepDim: bool,
    dropout_prob: float
) -> None:
    input1 = fd.define_tensor(
        symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Float)
    input2 = fd.define_tensor(
        symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Float)
    weights = fd.define_tensor(
        symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
    bias = fd.define_tensor(
        symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
    prob = fd.define_constant(1. - dropout_prob)
    scale = fd.define_constant(1.0/(1. - dropout_prob))
    plus_bias = fd.ops.add(input1, bias)
    rand = fd.ops.rand_like(plus_bias)
    dropout_mask = fd.ops.le(rand, prob)
    apply_mask = fd.ops.mul(plus_bias, dropout_mask)
    dropout = fd.ops.mul(apply_mask, scale)
    rms_norm_in = fd.ops.add(dropout, input2)
    inputs_sq = fd.ops.mul(rms_norm_in, rms_norm_in)
    sum0 = fd.ops.sum(inputs_sq, axes=[normalization_axis], keepdim=keepDim)
    norm_size_const = fd.define_constant(norm_size)
    var = fd.ops.div(sum0, norm_size_const)
    eps_const = fd.define_constant(eps)
    var_eps = fd.ops.add(var, eps_const)
    sqrt_var = fd.ops.sqrt(var_eps)
    pre_scale = fd.ops.div(rms_norm_in, sqrt_var)
    weights_bcast = fd.ops.broadcast_in_dim(
        weights, output_shape=input_shape, broadcast_dims=[2])
    out = fd.ops.mul(pre_scale, weights_bcast)
    fd.add_output(out)
    fd.add_output(dropout_mask)
    fd.add_output(rms_norm_in)
    fd.add_output(sqrt_var)
# Create the object to store the operation definition
fs_dropout_rms_norm_fwd = Fusion()
with FusionDefinition(fs_dropout_rms_norm_fwd) as fd:
    # Make the operation definition
    nvfuser_fusion_dropout_rms_norm_fwd(
        fd, 2, input1.size()[2], input1.size(), 1e-12, True, 0.1)
# Wrap the nvFuser implementation into a callable
def python_frontend_dropout_rms_norm_fwd(
    fs: Fusion, input1: torch.Tensor, input2: torch.Tensor, weights: torch.Tensor, bias: torch.Tensor
) -> List[torch.Tensor]:
    out = fs.execute([input1, input2, weights, bias])
    return out
func_fwd = functools.partial(python_frontend_dropout_rms_norm_fwd,
                             fs_dropout_rms_norm_fwd, input1, input2, weight, bias1)
# nvFuser's frontend doesn't have auto-differentiation, so the backwards operations need to be defined manually.
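# For reference, the fusion below computes the RMSNorm input gradient, with
# rms = sqrt(mean(x^2) + eps) and w the affine weight:
#   grad_x = (grad_y * w) / rms - x * sum(grad_y * w * x) / (N * rms^3)
# followed by the dropout backward (mask and 1/(1-p) rescale) and the
# weight/bias gradient reductions.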
def nvfuser_fusion_dropout_rms_norm_bwd(
    fd: FusionDefinition,
    normalization_axis: int,
    norm_size: int,
    input_shape: List[int],
    eps: float,
    keepDim: bool,
    dropout_prob: float
) -> None:
    weights = fd.define_tensor(
        symbolic_sizes=[-1], contiguous=[True], dtype=DataType.Float)
    grad_out = fd.define_tensor(
        symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Float)
    dropout_mask = fd.define_tensor(
        symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Bool)
    rms_norm_in = fd.define_tensor(
        symbolic_sizes=[-1, -1, -1], contiguous=[True, True, True], dtype=DataType.Float)
    sqrt_var = fd.define_tensor(
        symbolic_sizes=[-1, -1, 1], contiguous=[True, True, True], dtype=DataType.Float)
    scale = fd.define_constant(1.0/(1. - dropout_prob))
    norm_size_const = fd.define_constant(norm_size)
    const2 = fd.define_constant(2.)
    x_hat = fd.ops.div(rms_norm_in, sqrt_var)
    weights_bcast = fd.ops.broadcast_in_dim(
        weights, output_shape=[1, 1, input_shape[2]], broadcast_dims=[2])
    grad_x_hat = fd.ops.mul(grad_out, weights_bcast)
    x_hat_1 = fd.ops.div(x_hat, sqrt_var)
    neg_grad_x_hat = fd.ops.neg(grad_x_hat)
    mul4 = fd.ops.mul(neg_grad_x_hat, x_hat_1)
    div3 = fd.ops.div(grad_x_hat, sqrt_var)
    sum1 = fd.ops.sum(mul4, axes=[normalization_axis], keepdim=keepDim)
    mul5 = fd.ops.mul(sqrt_var, const2)
    div4 = fd.ops.div(sum1, mul5)
    div4_bcast = fd.ops.broadcast_in_dim(
        div4, output_shape=input_shape, broadcast_dims=[0, 1, 2])
    div5 = fd.ops.div(div4_bcast, norm_size_const)
    mul6 = fd.ops.mul(div5, rms_norm_in)
    add3 = fd.ops.add(mul6, div3)
    grad_input2 = fd.ops.add(add3, mul6)
    grad_mask = fd.ops.mul(grad_input2, dropout_mask)
    grad_input1 = fd.ops.mul(grad_mask, scale)
    grad_bias = fd.ops.sum(grad_input1, axes=[0, 1], keepdim=False)
    mul3 = fd.ops.mul(grad_out, x_hat)
    grad_weights = fd.ops.sum(mul3, axes=[0, 1], keepdim=False)
    fd.add_output(grad_input1)
    fd.add_output(grad_input2)
    fd.add_output(grad_weights)
    fd.add_output(grad_bias)
# Create the object to store the operation definition
fs_dropout_rms_norm_bwd = Fusion()
with FusionDefinition(fs_dropout_rms_norm_bwd) as fd:
    # Make the operation definition
    nvfuser_fusion_dropout_rms_norm_bwd(
        fd, 2, input1.size()[2], input1.size(), 1e-12, True, 0.1)
# Wrap the nvFuser implementation into a callable
def python_frontend_dropout_rms_norm_bwd(
    fs: Fusion,
    weights: torch.Tensor,
    grad_out: torch.Tensor,
    dropout_mask: torch.Tensor,
    rms_norm_in: torch.Tensor,
    invstd: torch.Tensor
) -> List[torch.Tensor]:
    out = fs.execute([weights, grad_out, dropout_mask, rms_norm_in, invstd])
    return out
func_bwd = functools.partial(
    python_frontend_dropout_rms_norm_bwd, fs_dropout_rms_norm_bwd, weight)
profile_workload(func_fwd, func_bwd, grad_output,
                 iteration_count=100, label="Python Frontend - RMS Norm")