Created
April 26, 2023 01:47
-
-
Save yujiepan-work/cff1c5b72aeb57c97477f7ecf0b201b1 to your computer and use it in GitHub Desktop.
Benchmarking a vanilla transformer FFN against a hard-gated Mixture-of-Experts (MoE) FFN, both in PyTorch and as NNCF-exported ONNX models under OpenVINO benchmark_app.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# %%
import torch.utils.benchmark as benchmark
import torch.nn as nn
from dataclasses import dataclass
import torch
import os
# Restrict CUDA to device 0 (must happen before CUDA initializes).
# NOTE(review): the benchmarks below run on CPU tensors, so this looks
# precautionary only — confirm whether GPU runs are ever intended.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
class FeedForward(nn.Module):
    """Standard transformer FFN block: Linear -> GELU -> Linear with a
    residual connection and LayerNorm on the output (dropout disabled)."""

    def __init__(self, input_dim, intermediate_size):
        super().__init__()
        # first layer (expansion)
        self.fc1 = nn.Linear(input_dim, intermediate_size)
        self.intermediate_act_fn = torch.nn.functional.gelu
        # second layer (projection back to model width)
        self.fc2 = nn.Linear(intermediate_size, input_dim)
        self.LayerNorm = nn.LayerNorm(input_dim)
        # self.dropout = nn.Dropout(dropout)

    def forward(self, hidden_states):
        input_tensor = hidden_states
        hidden_states = self.fc1(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        hidden_states = self.fc2(hidden_states)
        # hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states


class MoEFFN(nn.Module):
    """Hard (top-1) mixture-of-experts FFN.

    Each token is routed to exactly one of ``num_experts`` FeedForward
    experts by an argmax over a linear gating layer; the per-expert
    outputs are then scattered back into the original token order.

    Fixes vs. the original version:
    - ``num_experts`` is now honored (the gate width and the number of
      experts were both hard-coded to 4; default behavior is unchanged).
    - ``orders`` is registered as a buffer so ``.to(device)`` moves it
      together with the parameters.

    NOTE(review): the forward pass reshapes its output to
    ``(1, seq_len, input_dim)``, i.e. batch size 1 is assumed — matching
    the benchmark driver below.
    """

    def __init__(self, input_dim, intermediate_size, num_experts=4, seq_len=128) -> None:
        super().__init__()
        self.num_experts = num_experts
        # One routing logit per expert.
        self.gating = nn.Linear(input_dim, num_experts)
        # Token positions (batch 1), used to undo the routing permutation.
        self.register_buffer('orders', torch.arange(seq_len * 1))
        self.input_dim = input_dim
        self.seq_len = seq_len
        self.experts = nn.ModuleList(
            FeedForward(input_dim, intermediate_size) for _ in range(num_experts)
        )

    def forward(self, hidden_states):
        # Flatten (B, L, D) -> (B*L, D) so routing is per token.
        x = hidden_states.view(-1, self.input_dim)
        logits = self.gating(x)              # (B*L, num_experts)
        gate = torch.argmax(logits, dim=-1)  # (B*L,) hard top-1 routing
        chunks, positions = [], []
        for expert_id, expert in enumerate(self.experts):
            mask = (gate == expert_id)
            chunks.append(expert(x[mask]))         # tokens routed to this expert
            positions.append(self.orders[mask])    # their original positions
        x = torch.cat(chunks, dim=0)
        order = torch.cat(positions, dim=0)
        x = x[order.argsort(0)]  # restore original token order
        return x.view(1, self.seq_len, self.input_dim)
def create_nncf_model(torch_module, nncf_config, train_dataloader, onnx_name):
    """Wrap ``torch_module`` with NNCF compression and export it to ONNX.

    The dataloader is adapted to NNCF's initialization interface and
    registered for both range initialization and BN adaptation. Returns
    the path of the exported ONNX file under ``/tmp``.
    """
    import nncf
    from nncf.config import NNCFConfig
    from nncf.config.structures import (BNAdaptationInitArgs,
                                        QuantizationRangeInitArgs)
    from nncf.torch import create_compressed_model
    from nncf.torch.initialization import PTInitializingDataLoader

    class _InitDataLoader(PTInitializingDataLoader):
        # NNCF expects (args, kwargs); every batch item is passed as kwargs.
        def get_inputs(self, dataloader_output):
            return (), dataloader_output

    nncf_config.register_extra_structs([
        QuantizationRangeInitArgs(_InitDataLoader(train_dataloader)),
        BNAdaptationInitArgs(_InitDataLoader(train_dataloader)),
    ])
    torch_module.train()
    ctrl, _compressed = create_compressed_model(torch_module, nncf_config)
    onnx_path = os.path.join('/tmp', f'{onnx_name}.onnx')
    # print(_compressed)
    ctrl.export_model(onnx_path)
    return onnx_path
@dataclass
class BenchmarkResult:
    """Parsed result of one OpenVINO ``benchmark_app`` run."""
    stdout: str = ''         # full captured standard output
    stderr: str = ''         # full captured standard error
    avg_latency: float = 0.  # average latency in ms, from the summary
    throughput: float = 0.   # throughput in FPS, from the summary


def run_benchmark(cmd):
    """Run ``cmd`` in a shell and parse benchmark_app's final summary.

    benchmark_app ends its output with ``... Average: <value> ms``
    (4th-to-last non-blank line) and ``... Throughput: <value> FPS``
    (last non-blank line); the number is the second-to-last
    whitespace-separated token of each line.

    Fixes vs. the original version:
    - values are converted to ``float`` (they were stored as strings
      despite the float field annotations);
    - ``subprocess.run`` with ``capture_output`` replaces
      ``Popen`` + ``wait()`` + ``read()``, which can deadlock once the
      pipe buffer fills;
    - unparseable output (e.g. a failed command) yields 0.0 instead of
      raising IndexError.
    """
    import subprocess
    # NOTE: shell=True would be an injection risk for untrusted input;
    # acceptable here because cmd is constructed locally.
    proc = subprocess.run(cmd, shell=True, capture_output=True, text=True)
    lines = [line for line in proc.stdout.split('\n') if line.strip()]

    def _metric(index):
        # Second-to-last token of the selected summary line, as a float.
        try:
            return float(lines[index].split()[-2])
        except (IndexError, ValueError):
            return 0.0

    return BenchmarkResult(
        stdout=proc.stdout,
        stderr=proc.stderr,
        throughput=_metric(-1),
        avg_latency=_metric(-4),
    )
# %%
import pandas as pd
from torch.utils.data import DataLoader, Dataset
from nncf.config import NNCFConfig
from copy import deepcopy
# Use every thread torch sees so the torch timings are comparable to
# benchmark_app's default CPU configuration.
num_threads = torch.get_num_threads()
print(f'Benchmarking on {num_threads} threads')
num_trials = 2000  # timeit iterations per (seq_len, input_dim) configuration
results = []  # torch.utils.benchmark Measurements, fed to benchmark.Compare
benchmark_int8_results = []  # benchmark_app latencies for the INT8 exports
benchmark_fp32_results = []  # benchmark_app latencies for the FP32 exports
# Sweep sequence lengths and hidden sizes; for each configuration time the
# dense FFN vs the MoE FFN in PyTorch, then export both through NNCF
# (fake-quantized INT8 and plain FP32 ONNX) and time them with OpenVINO
# benchmark_app.
for seq_len in [64, 128, 256, 384, 768, 1024]:
    for input_dim in [64, 128, 256, 512, 768]:
        # Batch-size-1 random input shared by both models.
        x = torch.rand((1, seq_len, input_dim))
        # Dense baseline with the conventional 4x intermediate expansion.
        ffn_model = FeedForward(input_dim, input_dim * 4)
        # print(seq_len, input_dim)
        # Re-draw the gating weights until no single expert receives more
        # than half of the tokens, so expert load is reasonably balanced.
        # NOTE(review): each expert uses intermediate_size=input_dim (not
        # 4x), so the 4 experts together roughly match the dense FFN's
        # parameter count — presumably intentional; confirm.
        max_expert = seq_len
        while max_expert > 0.5 * seq_len:  # ensures better balance
            moe_model = MoEFFN(input_dim, input_dim, num_experts=4, seq_len=seq_len)
            nn.init.normal_(moe_model.gating.weight)
            nn.init.zeros_(moe_model.gating.bias)
            logits = moe_model.gating(x.view(-1, input_dim))
            gate = torch.argmax(logits, dim=-1)
            max_expert = max((gate == i).sum() for i in range(4))
        # Inference-only wrappers so autograd bookkeeping is excluded
        # from the timings.
        @torch.no_grad()
        def ffn(x):
            return ffn_model(x)
        @torch.no_grad()
        def moe(x):
            return moe_model(x)
        t0 = benchmark.Timer(
            stmt='ffn(x)',
            globals={'x': x, 'moe': moe, 'ffn': ffn},
            num_threads=num_threads,
            description='vanilla FFN',
            label='benchmark',
            sub_label=f'seq_len{seq_len:<3d}, hidden_dim{input_dim:<3d}').timeit(num_trials)
        t1 = benchmark.Timer(
            stmt='moe(x)',
            globals={'x': x, 'moe': moe, 'ffn': ffn},
            num_threads=num_threads,
            description='MoE',
            label='benchmark',
            sub_label=f'seq_len{seq_len:<3d}, hidden_dim{input_dim:<3d}').timeit(num_trials)
        results.append(t0)
        results.append(t1)
        ## IR export
        # Dummy dataset: the NNCF initializers below are configured with 0
        # samples, so only the sample *shape* matters, not the values.
        class MyDataset(Dataset):
            def __len__(self):
                return 500
            def __getitem__(self, index):
                return {'hidden_states': torch.rand((seq_len, input_dim))}
        dataloader = DataLoader(MyDataset(), 1)
        # INT8 config: quantize everything except the residual adds and
        # LayerNorm; matmul activations forced to symmetric mode.
        nncf_config_int8 = NNCFConfig.from_dict(
            {
                "input_info": [
                    {
                        "sample_size": [
                            1,
                            seq_len,
                            input_dim
                        ],
                    }
                ],
                "compression": {
                    "algorithm": "quantization",
                    "preset": "mixed",
                    "overflow_fix": "disable",
                    "initializer": {
                        "range": {
                            "num_init_samples": 0,
                            "type": "mean_min_max"
                        },
                        "batchnorm_adaptation": {
                            "num_bn_adaptation_samples": 0
                        },
                    },
                    "scope_overrides": {
                        "activations": {
                            "{re}.*matmul_0": {
                                "mode": "symmetric"
                            }
                        }
                    },
                    "ignored_scopes": [
                        # "{re}.*Embedding*",
                        "{re}.*__add___[0-1]",
                        "{re}.*layer_norm_0",
                        # "{re}.*matmul_1",
                        # "{re}.*__truediv__*",
                    ],
                }
            }
        )
        # FP32 config: no compression, used purely for the ONNX export path.
        nncf_config_fp32 = NNCFConfig.from_dict(
            {
                "input_info": [
                    {
                        "sample_size": [
                            1,
                            seq_len,
                            input_dim
                        ],
                    }
                ],
                "compression": []
            }
        )
        for datatype, nncf_config in dict(int8=nncf_config_int8, fp32=nncf_config_fp32).items():
            ffn_onnx = create_nncf_model(deepcopy(ffn_model), deepcopy(nncf_config), dataloader, f'{seq_len}x{input_dim}ffn-{datatype}')
            moe_onnx = create_nncf_model(deepcopy(moe_model), deepcopy(nncf_config), dataloader, f'{seq_len}x{input_dim}moe-{datatype}')
            benchmark_results = benchmark_int8_results if datatype == 'int8' else benchmark_fp32_results
            # Latency-oriented benchmark_app run; inference precision pinned
            # to f32 so any INT8 speedup comes from quantization alone.
            benchmark_result = run_benchmark(cmd=f'benchmark_app -m {ffn_onnx} -niter 3000 -hint latency -infer_precision f32')
            benchmark_results.append(
                dict(
                    model='ffn',
                    identifier=(seq_len, input_dim),
                    latency=benchmark_result.avg_latency
                )
            )
            benchmark_result = run_benchmark(cmd=f'benchmark_app -m {moe_onnx} -niter 3000 -hint latency -infer_precision f32')
            benchmark_results.append(
                dict(
                    model='moe',
                    identifier=(seq_len, input_dim),
                    latency=benchmark_result.avg_latency
                )
            )
# Final report: side-by-side torch timings, then OpenVINO latencies.
compare = benchmark.Compare(results)
compare.print()
print(pd.DataFrame(benchmark_fp32_results))
print(pd.DataFrame(benchmark_int8_results))
# %% | |
# %% | |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.