benchmarking FFN with MoE
# %%
import os

os.environ['CUDA_VISIBLE_DEVICES'] = '0'  # must be set before torch initializes CUDA

from dataclasses import dataclass

import torch
import torch.nn as nn
import torch.utils.benchmark as benchmark


class FeedForward(nn.Module):
    """Transformer-style FFN block with a residual connection and LayerNorm."""

    def __init__(self, input_dim, intermediate_size):
        super().__init__()
        # first layer
        self.fc1 = nn.Linear(input_dim, intermediate_size)
        self.intermediate_act_fn = torch.nn.functional.gelu
        # second layer
        self.fc2 = nn.Linear(intermediate_size, input_dim)
        self.LayerNorm = nn.LayerNorm(input_dim)
        # self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states):
        input_tensor = hidden_states
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        # hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
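
# A quick shape sanity check (illustrative sizes): the block is residual, so
# the output shape matches the input shape.
_ffn = FeedForward(input_dim=64, intermediate_size=256)
assert _ffn(torch.rand(1, 8, 64)).shape == (1, 8, 64)
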
class MoEFFN(nn.Module):
    """Top-1 token routing over 4 expert FFNs. Assumes batch size 1."""

    def __init__(self, input_dim, intermediate_size, num_experts=4, seq_len=128) -> None:
        super().__init__()
        assert num_experts == 4, 'forward() below is unrolled for exactly 4 experts'
        self.gating = nn.Linear(input_dim, num_experts)
        self.orders = torch.arange(seq_len * 1)  # token indices for batch_size=1
        self.input_dim = input_dim
        self.seq_len = seq_len
        self.ffn1 = FeedForward(input_dim, intermediate_size)
        self.ffn2 = FeedForward(input_dim, intermediate_size)
        self.ffn3 = FeedForward(input_dim, intermediate_size)
        self.ffn4 = FeedForward(input_dim, intermediate_size)

    def forward(self, hidden_states):
        x = hidden_states.view(-1, self.input_dim)
        logits = self.gating(x)  # (B * L) x num_experts
        gate = torch.argmax(logits, dim=-1)  # (B * L, )
        orders = self.orders
        # route each token to its argmax expert, remembering its position
        ids0 = (gate == 0)
        x_l0 = self.ffn1(x[ids0])
        order_l0 = orders[ids0]
        ids1 = (gate == 1)
        x_l1 = self.ffn2(x[ids1])
        order_l1 = orders[ids1]
        ids2 = (gate == 2)
        x_l2 = self.ffn3(x[ids2])
        order_l2 = orders[ids2]
        ids3 = (gate == 3)
        x_l3 = self.ffn4(x[ids3])
        order_l3 = orders[ids3]
        x = torch.cat([x_l0, x_l1, x_l2, x_l3], dim=0)
        order = torch.cat([order_l0, order_l1, order_l2, order_l3], dim=0)
        x = x[order.argsort(0)]  # restore original token order
        x = x.view(1, self.seq_len, self.input_dim)
        return x
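
# Why `order.argsort()` restores token order: concatenating the per-expert
# outputs permutes the token axis by `order`, and the argsort of a permutation
# is its inverse. A minimal standalone check (illustrative values):
_perm = torch.tensor([2, 0, 3, 1])
assert torch.equal(_perm[_perm.argsort()], torch.arange(4))
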
def create_nncf_model(torch_module, nncf_config, train_dataloader, onnx_name):
    """Wrap the module with NNCF compression and export it to ONNX."""
    import nncf
    from nncf.config import NNCFConfig
    from nncf.config.structures import (BNAdaptationInitArgs,
                                        QuantizationRangeInitArgs)
    from nncf.torch import create_compressed_model
    from nncf.torch.initialization import PTInitializingDataLoader

    class MyInitializingDataloader(PTInitializingDataLoader):
        def get_inputs(self, dataloader_output):
            # returns (args, kwargs) for the model; the dataset yields a kwargs dict
            return (), dataloader_output

    nncf_config.register_extra_structs([
        QuantizationRangeInitArgs(MyInitializingDataloader(train_dataloader)),
        BNAdaptationInitArgs(MyInitializingDataloader(train_dataloader)),
    ])
    torch_module.train()
    compression_ctrl, compressed_model = create_compressed_model(torch_module, nncf_config)
    onnx_path = os.path.join('/tmp', f'{onnx_name}.onnx')
    compression_ctrl.export_model(onnx_path)
    return onnx_path
@dataclass
class BenchmarkResult:
    stdout: str = ''
    stderr: str = ''
    avg_latency: float = 0.
    throughput: float = 0.


def run_benchmark(cmd):
    import subprocess
    with subprocess.Popen(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) as p:
        p.wait()
        stdout = p.stdout.read().decode()
        stderr = p.stderr.read().decode()
    # benchmark_app ends its output with a summary: the last non-empty line
    # carries the throughput, the 4th-from-last the average latency, with the
    # numeric value as the second-to-last token on each line.
    lines = [line for line in stdout.split('\n') if line.strip()]
    return BenchmarkResult(
        stdout=stdout,
        stderr=stderr,
        throughput=float(lines[-1].split()[-2]),
        avg_latency=float(lines[-4].split()[-2]),
    )
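
# For reference, the parsing above assumes the tail of OpenVINO benchmark_app
# output looks roughly like the following (wording and spacing may vary across
# OpenVINO versions, so the indices are brittle):
#
#   [ INFO ] Latency:
#   [ INFO ]    Median:     1.23 ms
#   [ INFO ]    Average:    1.25 ms
#   [ INFO ]    Min:        1.10 ms
#   [ INFO ]    Max:        2.01 ms
#   [ INFO ] Throughput: 801.50 FPS
#
# lines[-1] is the throughput line, lines[-4] the average-latency line, and
# split()[-2] picks the numeric value just before the unit.
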
# %%
from copy import deepcopy

import pandas as pd
from torch.utils.data import DataLoader, Dataset

from nncf.config import NNCFConfig

num_threads = torch.get_num_threads()
print(f'Benchmarking on {num_threads} threads')
num_trials = 2000
results = []
benchmark_int8_results = []
benchmark_fp32_results = []
for seq_len in [64, 128, 256, 384, 768, 1024]:
    for input_dim in [64, 128, 256, 512, 768]:
        x = torch.rand((1, seq_len, input_dim))
        ffn_model = FeedForward(input_dim, input_dim * 4)
        # Re-sample the gating weights until no expert receives more than half
        # of the tokens, to keep the four experts reasonably balanced.
        max_expert = seq_len
        while max_expert > 0.5 * seq_len:
            moe_model = MoEFFN(input_dim, input_dim, num_experts=4, seq_len=seq_len)
            nn.init.normal_(moe_model.gating.weight)
            nn.init.zeros_(moe_model.gating.bias)
            logits = moe_model.gating(x.view(-1, input_dim))
            gate = torch.argmax(logits, dim=-1)
            max_expert = max((gate == i).sum() for i in range(4))
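        # Sizing note: the vanilla FFN uses intermediate_size = 4 * input_dim,
        # while each of the 4 experts uses intermediate_size = input_dim, so the
        # MoE matches the vanilla FFN's parameter count. Each token passes
        # through exactly one expert, so the MoE performs roughly a quarter of
        # the per-token FFN FLOPs, plus the small gating matmul.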
        @torch.no_grad()
        def ffn(x):
            return ffn_model(x)

        @torch.no_grad()
        def moe(x):
            return moe_model(x)

        t0 = benchmark.Timer(
            stmt='ffn(x)',
            globals={'x': x, 'moe': moe, 'ffn': ffn},
            num_threads=num_threads,
            description='vanilla FFN',
            label='benchmark',
            sub_label=f'seq_len{seq_len:<3d}, hidden_dim{input_dim:<3d}').timeit(num_trials)
        t1 = benchmark.Timer(
            stmt='moe(x)',
            globals={'x': x, 'moe': moe, 'ffn': ffn},
            num_threads=num_threads,
            description='MoE',
            label='benchmark',
            sub_label=f'seq_len{seq_len:<3d}, hidden_dim{input_dim:<3d}').timeit(num_trials)
        results.append(t0)
        results.append(t1)
        ## IR export
        class MyDataset(Dataset):
            def __len__(self):
                return 500

            def __getitem__(self, index):
                return {'hidden_states': torch.rand((seq_len, input_dim))}

        dataloader = DataLoader(MyDataset(), 1)
        nncf_config_int8 = NNCFConfig.from_dict(
            {
                "input_info": [
                    {
                        "sample_size": [1, seq_len, input_dim],
                    }
                ],
                "compression": {
                    "algorithm": "quantization",
                    "preset": "mixed",
                    "overflow_fix": "disable",
                    "initializer": {
                        "range": {
                            "num_init_samples": 0,
                            "type": "mean_min_max"
                        },
                        "batchnorm_adaptation": {
                            "num_bn_adaptation_samples": 0
                        },
                    },
                    "scope_overrides": {
                        "activations": {
                            "{re}.*matmul_0": {
                                "mode": "symmetric"
                            }
                        }
                    },
                    "ignored_scopes": [
                        # "{re}.*Embedding*",
                        "{re}.*__add___[0-1]",
                        "{re}.*layer_norm_0",
                        # "{re}.*matmul_1",
                        # "{re}.*__truediv__*",
                    ],
                }
            }
        )
        nncf_config_fp32 = NNCFConfig.from_dict(
            {
                "input_info": [
                    {
                        "sample_size": [1, seq_len, input_dim],
                    }
                ],
                "compression": []
            }
        )
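        # What the two configs do: the int8 config applies NNCF quantization
        # (the "mixed" preset uses symmetric weights and asymmetric activations,
        # with matmul inputs forced symmetric via scope_overrides) while keeping
        # the residual adds and LayerNorm in fp32 through ignored_scopes. The
        # fp32 config applies no compression and only exports a baseline ONNX
        # model through the same code path.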
        for datatype, nncf_config in dict(int8=nncf_config_int8, fp32=nncf_config_fp32).items():
            ffn_onnx = create_nncf_model(deepcopy(ffn_model), deepcopy(nncf_config), dataloader, f'{seq_len}x{input_dim}ffn-{datatype}')
            moe_onnx = create_nncf_model(deepcopy(moe_model), deepcopy(nncf_config), dataloader, f'{seq_len}x{input_dim}moe-{datatype}')
            benchmark_results = benchmark_int8_results if datatype == 'int8' else benchmark_fp32_results
            benchmark_result = run_benchmark(cmd=f'benchmark_app -m {ffn_onnx} -niter 3000 -hint latency -infer_precision f32')
            benchmark_results.append(
                dict(
                    model='ffn',
                    identifier=(seq_len, input_dim),
                    latency=benchmark_result.avg_latency
                )
            )
            benchmark_result = run_benchmark(cmd=f'benchmark_app -m {moe_onnx} -niter 3000 -hint latency -infer_precision f32')
            benchmark_results.append(
                dict(
                    model='moe',
                    identifier=(seq_len, input_dim),
                    latency=benchmark_result.avg_latency
                )
            )

compare = benchmark.Compare(results)
compare.print()
print(pd.DataFrame(benchmark_fp32_results))
print(pd.DataFrame(benchmark_int8_results))