TensorFlow 2.2 implementation of Mixture-of-Experts with switch routing
import math
from functools import wraps

import tensorflow as tf

import tf2_profiler

def uncapped_switch_layer(
    inputs,
    experts,
    router_weights,
    capacity_factor,
    epsilon=None,
    training=None,
    mask=None,
):
    """A data-parallel and expert-parallel distributed MoE switch layer.

    inputs must have shape [batch_size, tokens_per_sequence, d_model]

    NOTE: this layer does not support model-parallel splitting of experts over multiple replicas.

    Based on the paper pseudo-code and mesh-tensorflow implementation:
    - cf. https://arxiv.org/pdf/2101.03961.pdf
    - cf. https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/moe.py
    """
    # prefer static dimensions where known, fall back to dynamic shape otherwise
    inputs_shape_dynamic = tf.shape(inputs)
    inputs_shape_static = inputs.shape
    input_shape = [
        inputs_shape_static[i] or inputs_shape_dynamic[i]
        for i in range(len(inputs_shape_static))
    ]
    batch_size = input_shape[0]
    tokens_per_sequence = input_shape[1]
    d_inputs = input_shape[2]
    n_experts = len(experts)
    # this replica in a data-parallel strategy will route tokens_per_batch tokens to the correct experts
    tokens_per_batch = batch_size * tokens_per_sequence
    # nominal per-expert buffer size [expert_capacity, d_inputs]; computed for
    # signature parity, but the uncapped router below does not enforce it
    expert_capacity = (
        tf.cast(tokens_per_batch, tf.float32) * capacity_factor / n_experts
    )
    # reshape inputs to set up expert dispatching
    # inputs shape: [batch_size, tokens_per_sequence, d_inputs] -> [tokens_per_batch, d_inputs]
    inputs = tf.reshape(inputs, [tokens_per_batch, d_inputs])
    # route each token to its top-1 expert
    # - expert_index: the expert each token is routed to, shape [tokens_per_batch]
    # - expert_gate: the router probability of that expert, shape [tokens_per_batch]
    expert_index, expert_gate, aux_loss = uncapped_router(
        inputs,
        router_weights,
        n_experts,
        expert_capacity,
        epsilon=epsilon,
        training=training,
        mask=mask,
    )
    # gather the token positions assigned to each expert
    expert_inputs_indices = [
        tf.where(tf.equal(expert_index, i)) for i in range(n_experts)
    ]
    # dispatch a variable-sized batch of tokens to each expert
    expert_inputs_list = [
        tf.gather_nd(inputs, expert_indices) for expert_indices in expert_inputs_indices
    ]
    expert_outputs_list = [
        expert(expert_input)
        for expert, expert_input in zip(experts, expert_inputs_list)
    ]
    expert_outputs = tf.concat(expert_outputs_list, axis=0)
    d_model = tf.shape(expert_outputs)[-1]
    # scatter expert outputs back to their original token positions
    combine_outputs_indices = tf.concat(expert_inputs_indices, axis=0)
    combine_outputs_shape = tf.cast(
        [tokens_per_batch, d_model], combine_outputs_indices.dtype
    )
    expert_outputs_combined = tf.scatter_nd(
        combine_outputs_indices, expert_outputs, combine_outputs_shape
    )
    # scale each token's output by its router probability
    expert_outputs_combined *= expert_gate[..., tf.newaxis]
    # remove the tokens_per_batch shape used for local routing to match the input shape
    # outputs shape: [batch_size, tokens_per_sequence, d_model]
    outputs = tf.reshape(
        expert_outputs_combined, [batch_size, tokens_per_sequence, d_model]
    )
    return outputs, aux_loss
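

# The sketch below (added for illustration, not part of the original gist) shows
# the gather/scatter round trip that uncapped_switch_layer relies on: tf.where
# picks each expert's token positions, tf.gather_nd pulls those tokens out, and
# tf.scatter_nd puts the (here, identity) expert outputs back into their
# original slots.
def _demo_uncapped_dispatch_combine():
    tokens = tf.reshape(tf.range(8, dtype=tf.float32), [4, 2])  # 4 tokens, d_model=2
    expert_index = tf.constant([0, 1, 0, 1])  # assumed top-1 routing decisions
    indices = [tf.where(tf.equal(expert_index, i)) for i in range(2)]
    expert_inputs = [tf.gather_nd(tokens, idx) for idx in indices]
    # identity "experts": outputs equal inputs, so the scatter must rebuild tokens
    expert_outputs = tf.concat(expert_inputs, axis=0)
    combined = tf.scatter_nd(
        tf.concat(indices, axis=0),
        expert_outputs,
        tf.shape(tokens, out_type=tf.int64),
    )
    tf.debugging.assert_near(combined, tokens)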

def uncapped_router(
    inputs,
    router_weights,
    n_experts,
    expert_capacity,
    epsilon=None,
    training=None,
    mask=None,
):
    """Route each token to its highest-probability expert, without a capacity limit.

    Returns the top-1 expert index and gate probability per token, plus the
    load-balancing auxiliary loss.

    inputs must have shape [tokens_per_batch, d_model]

    Based on the paper pseudo-code and mesh-tensorflow implementation:
    - cf. https://arxiv.org/pdf/2101.03961.pdf
    - cf. https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/moe.py
    """
    router_logits = router_weights(inputs)  # shape: [tokens_per_batch, n_experts]
    if training is True and epsilon is not None:
        # multiplicative jitter in [1 - epsilon, 1 + epsilon] for exploration
        # across experts (the mesh reference applies it to the router inputs)
        epsilon = tf.cast(epsilon, tf.float32)
        router_logits *= tf.random.uniform(
            tf.shape(router_logits), minval=1 - epsilon, maxval=1 + epsilon
        )
    # cast the softmax input to tf.float32 for stability under mixed precision
    router_logits = tf.cast(router_logits, tf.float32)
    # probabilities for each token of which expert it should be sent to
    router_probs = tf.nn.softmax(router_logits, axis=-1)
    # get the top-1 expert for each token
    # - expert_gate: the top-1 probability for each token
    # - expert_index: the expert each token is going to be routed to
    expert_gate, expert_index = tf.math.top_k(router_probs, k=1)
    # squeeze expert_gate and expert_index for top-1 shape
    expert_gate = tf.squeeze(expert_gate, axis=-1)  # shape: [tokens_per_batch]
    expert_index = tf.squeeze(expert_index, axis=-1)  # shape: [tokens_per_batch]
    # expert_mask shape: [tokens_per_batch, n_experts]
    expert_mask = tf.one_hot(expert_index, depth=n_experts, dtype=router_probs.dtype)
    # compute the load balancing loss
    aux_loss = load_balance_loss(router_probs, expert_mask)
    return expert_index, expert_gate, aux_loss

class SwitchMoE(tf.keras.layers.Layer):
    """Keras layer wrapping the uncapped switch layer with its own router weights."""

    def __init__(self, experts, capacity_factor, epsilon=None, **kwargs):
        super().__init__(**kwargs)
        self.experts = experts
        self.capacity_factor = capacity_factor
        self.epsilon = epsilon
        self.router_weights = tf.keras.layers.Dense(
            len(self.experts), use_bias=False, activation=None
        )

    def call(self, inputs, training=None, mask=None):
        return uncapped_switch_layer(
            inputs,
            self.experts,
            self.router_weights,
            self.capacity_factor,
            self.epsilon,
            training=training,
            mask=mask,
        )
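

# Minimal usage sketch (added for illustration, not part of the original gist):
# wire a SwitchMoE layer with two Dense experts and run a dummy batch through it.
def _demo_switch_moe():
    experts = [tf.keras.layers.Dense(32, use_bias=False) for _ in range(2)]
    moe = SwitchMoE(experts, capacity_factor=1.0)
    x = tf.random.normal([3, 8, 32])  # [batch_size, tokens_per_sequence, d_model]
    outputs, aux_loss = moe(x)
    assert outputs.shape == (3, 8, 32)  # routed outputs match the input shape
    return outputs, aux_loss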

def switch_layer(
    inputs,
    experts,
    router_weights,
    capacity_factor,
    epsilon=None,
    training=None,
    mask=None,
):
    """A data-parallel and expert-parallel distributed MoE switch layer with fixed expert capacity.

    inputs must have shape [batch_size, tokens_per_sequence, d_model]

    NOTE: this layer does not support model-parallel splitting of experts over multiple replicas.

    Based on the paper pseudo-code and mesh-tensorflow implementation:
    - cf. https://arxiv.org/pdf/2101.03961.pdf
    - cf. https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/moe.py
    """
    # prefer static dimensions where known, fall back to dynamic shape otherwise
    inputs_shape_dynamic = tf.shape(inputs)
    inputs_shape_static = inputs.shape
    input_shape = [
        inputs_shape_static[i] or inputs_shape_dynamic[i]
        for i in range(len(inputs_shape_static))
    ]
    batch_size = input_shape[0]
    tokens_per_sequence = input_shape[1]
    d_inputs = input_shape[2]
    n_experts = len(experts)
    # this replica in a data-parallel strategy will route tokens_per_batch tokens to the correct experts
    tokens_per_batch = batch_size * tokens_per_sequence
    # each expert will receive inputs of shape [expert_capacity, d_inputs]
    expert_capacity = (
        tf.cast(tokens_per_batch, tf.float32) * capacity_factor / n_experts
    )
    # reshape inputs to set up expert dispatching
    # inputs shape: [batch_size, tokens_per_sequence, d_inputs] -> [tokens_per_batch, d_inputs]
    inputs = tf.reshape(inputs, [tokens_per_batch, d_inputs])
    # compute dispatch and combine tensors used to route tokens to and from experts
    # - dispatch_tensor: used for routing tokens to the correct expert
    #   shape: [tokens_per_batch, n_experts, expert_capacity]
    # - combine_tensor: used for combining expert outputs and scaling with router probability
    #   shape: [tokens_per_batch, n_experts, expert_capacity]
    dispatch_tensor, combine_tensor, aux_loss = router(
        inputs,
        router_weights,
        n_experts,
        expert_capacity,
        epsilon=epsilon,
        training=training,
        mask=mask,
    )
    # matrix multiply inputs and dispatch tensor to assign tokens to the correct expert
    # bd,bec->ecd => tf.transpose(dispatch_tensor, perm=[1,2,0]) @ inputs
    # expert_inputs shape: [n_experts, expert_capacity, d_inputs]
    expert_inputs = tf.linalg.einsum("bd,bec->ecd", inputs, dispatch_tensor)
    # dispatch inputs to experts
    expert_inputs_list = tf.unstack(expert_inputs, axis=0)
    expert_outputs_list = [
        expert(expert_input)
        for expert, expert_input in zip(experts, expert_inputs_list)
    ]
    expert_outputs = tf.stack(expert_outputs_list, axis=0)
    # convert expert outputs back to the input shape and multiply by the routing probability
    # expert_outputs_combined shape: [tokens_per_batch, d_model]
    expert_outputs_combined = tf.linalg.einsum(
        "ecd,bec->bd", expert_outputs, combine_tensor
    )
    # remove the tokens_per_batch shape used for local routing to match the input shape
    # outputs shape: [batch_size, tokens_per_sequence, d_model]
    d_model = tf.shape(expert_outputs)[-1]
    outputs = tf.reshape(
        expert_outputs_combined, [batch_size, tokens_per_sequence, d_model]
    )
    return outputs, aux_loss

def router(
    inputs,
    router_weights,
    n_experts,
    expert_capacity,
    epsilon=None,
    training=None,
    mask=None,
):
    """Produce dispatch/combine tensors used for sending/receiving tokens to/from the highest-probability experts.

    inputs must have shape [tokens_per_batch, d_model]

    Based on the paper pseudo-code and mesh-tensorflow implementation:
    - cf. https://arxiv.org/pdf/2101.03961.pdf
    - cf. https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/moe.py
    """
    router_logits = router_weights(inputs)  # shape: [tokens_per_batch, n_experts]
    if training is True and epsilon is not None:
        # multiplicative jitter in [1 - epsilon, 1 + epsilon] for exploration
        # across experts (the mesh reference applies it to the router inputs)
        epsilon = tf.cast(epsilon, tf.float32)
        router_logits *= tf.random.uniform(
            tf.shape(router_logits), minval=1 - epsilon, maxval=1 + epsilon
        )
    # cast the softmax input to tf.float32 for stability under mixed precision
    router_logits = tf.cast(router_logits, tf.float32)
    # probabilities for each token of which expert it should be sent to
    router_probs = tf.nn.softmax(router_logits, axis=-1)
    # get the top-1 expert for each token
    # - expert_gate: the top-1 probability for each token
    # - expert_index: the expert each token is going to be routed to
    expert_gate, expert_index = tf.math.top_k(router_probs, k=1)
    # squeeze expert_gate and expert_index for top-1 shape
    expert_gate = tf.squeeze(expert_gate, axis=-1)  # shape: [tokens_per_batch]
    expert_index = tf.squeeze(expert_index, axis=-1)  # shape: [tokens_per_batch]
    # expert_mask shape: [tokens_per_batch, n_experts]
    expert_mask = tf.one_hot(expert_index, depth=n_experts, dtype=router_probs.dtype)
    # compute the load balancing loss
    aux_loss = load_balance_loss(router_probs, expert_mask)
    # experts have a fixed capacity (i.e. batch size per expert), ensure we do not exceed it!
    # construct tensor indicating the position of each token in the receiving expert
    # position_in_expert shape: [tokens_per_batch, n_experts]
    position_in_expert = (
        tf.math.cumsum(expert_mask, exclusive=True, axis=0) * expert_mask
    )
    # keep only the tokens that fit within each expert (i.e. position in the receiving expert lower than expert_capacity)
    expert_mask *= tf.cast(
        tf.math.less(position_in_expert, tf.cast(expert_capacity, tf.float32)),
        tf.float32,
    )
    # zero the gates of tokens that have overflowed the expert capacity
    expert_mask_flat = tf.math.reduce_sum(expert_mask, axis=-1)
    expert_gate *= expert_mask_flat
    # construct one-hot tensor indicating the position of each token in the receiving expert
    position_in_expert_one_hot = tf.one_hot(
        tf.cast(position_in_expert, tf.int32),
        depth=tf.cast(expert_capacity, tf.int32),
        dtype=tf.float32,
    )
    # combine tensor used for combining expert outputs and scaling with router probability
    # combine_tensor shape: [tokens_per_batch, n_experts, expert_capacity]
    combine_tensor = expert_gate[..., tf.newaxis] * expert_mask
    combine_tensor = combine_tensor[..., tf.newaxis] * position_in_expert_one_hot
    # convert outputs back to the inputs dtype
    combine_tensor = tf.cast(combine_tensor, inputs.dtype)
    # create binary dispatch tensor that is 1 if the token gets routed to the corresponding expert
    # dispatch_tensor shape: [tokens_per_batch, n_experts, expert_capacity]
    dispatch_tensor = tf.cast(tf.cast(combine_tensor, tf.bool), inputs.dtype)
    return dispatch_tensor, combine_tensor, aux_loss
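

# Shape sketch (added for illustration, not part of the original gist): run a toy
# batch through the capped router and check the dispatch/combine einsum round trip
# used by switch_layer. With identity experts, every token that fits within the
# expert capacity comes back in its original slot, scaled by its router gate.
def _demo_capped_dispatch_combine():
    tokens_per_batch, d_model, n_experts, expert_capacity = 4, 2, 2, 2
    inputs = tf.random.normal([tokens_per_batch, d_model])
    router_weights = tf.keras.layers.Dense(n_experts, use_bias=False)
    dispatch_tensor, combine_tensor, aux_loss = router(
        inputs, router_weights, n_experts, expert_capacity
    )
    # both tensors have shape [tokens_per_batch, n_experts, expert_capacity]
    expert_inputs = tf.linalg.einsum("bd,bec->ecd", inputs, dispatch_tensor)
    outputs = tf.linalg.einsum("ecd,bec->bd", expert_inputs, combine_tensor)
    assert outputs.shape == (tokens_per_batch, d_model)
    return outputs, aux_loss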

def load_balance_loss(router_probs, expert_mask):
    """Calculate load-balancing loss to ensure diverse expert routing."""
    # router_probs [tokens_per_batch, num_experts] is the probability assigned for
    # each expert per token. expert_mask [tokens_per_batch, num_experts] contains
    # the expert with the highest router probability in one-hot format.
    num_experts = tf.shape(expert_mask)[-1]
    # Get the fraction of tokens routed to each expert.
    # density is a vector of length num_experts that sums to 1.
    density = tf.reduce_mean(expert_mask, axis=0)
    # Get the fraction of probability mass assigned to each expert from the router
    # across all tokens. density_proxy is a vector of length num_experts that sums to 1.
    density_proxy = tf.reduce_mean(router_probs, axis=0)
    # We want both vectors to have uniform allocation (1/num_experts) across all
    # num_experts elements. The two vectors are pushed towards uniform allocation
    # when their dot product is minimized.
    loss = tf.reduce_mean(density_proxy * density) * tf.cast(
        (num_experts**2), tf.dtypes.float32
    )
    return loss
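

# Numeric check (added for illustration, not part of the original gist): with
# perfectly uniform routing, density and density_proxy are both 1/num_experts,
# so the loss reaches its minimum value of 1.0.
def _demo_load_balance_loss():
    router_probs = tf.fill([4, 2], 0.5)  # 4 tokens, 2 experts, uniform probabilities
    expert_mask = tf.constant([[1.0, 0.0], [0.0, 1.0], [1.0, 0.0], [0.0, 1.0]])
    loss = load_balance_loss(router_probs, expert_mask)
    tf.debugging.assert_near(loss, 1.0)  # mean(0.5 * 0.5) * 2**2 == 1.0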

def _on_device(device=None, name="layer"):
    def decorator(cls_method):
        @wraps(cls_method)
        def cls_method_wrapper(*args, **kwargs):
            if device:
                with tf.device(device):
                    outputs = cls_method(*args, **kwargs)
                if outputs is not None:
                    tf.print(
                        f"Inputs of {name} with shape {[arg.shape for arg in args if hasattr(arg, 'shape')]} and device placement {device} on device: {[arg.device for arg in args if hasattr(arg, 'device')]}"
                    )
                    tf.print(
                        f"Output of {name} with shape {outputs.shape} and device placement {device} on device: {outputs.device}"
                    )
                return outputs
            return cls_method(*args, **kwargs)

        return cls_method_wrapper

    return decorator

def OnDevice(layer, device=None):
    """Wrap a layer instance so that its build and call methods execute on a specific device."""
    # acts like tf.keras.layers.Wrapper without subclassing and will not be saved with the model
    assert isinstance(layer, tf.keras.layers.Layer)
    layer.build = _on_device(device, layer.name)(layer.build)
    layer.call = _on_device(device, layer.name)(layer.call)
    return layer

def build_simple_model_parallelism_model(input_shape, d_hidden, devices):
    assert len(devices) >= 2
    inputs = tf.keras.Input(shape=input_shape)
    dense_1 = OnDevice(
        tf.keras.layers.Dense(d_hidden, use_bias=False), device=devices[0]
    )
    outputs_1 = dense_1(inputs)
    dense_2 = OnDevice(
        tf.keras.layers.Dense(d_hidden, use_bias=False), device=devices[1]
    )
    outputs_2 = dense_2(inputs)
    add = tf.keras.layers.Add()
    outputs = add([outputs_1, outputs_2])
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    return model

def build_simple_expert_parallelism_model(input_shape, d_hidden, devices):
    assert len(devices) >= 2
    inputs = tf.keras.Input(shape=input_shape)
    expert_1 = OnDevice(
        tf.keras.layers.Dense(d_hidden, use_bias=False, name="simple_expert_1"),
        device=devices[0],
    )
    expert_2 = OnDevice(
        tf.keras.layers.Dense(d_hidden, use_bias=False, name="simple_expert_2"),
        device=devices[1],
    )
    # send the first half of the sequence tokens to one expert and the remaining tokens to the second expert
    token_split = input_shape[0] // 2
    expert_1_inputs = inputs[:, :token_split]
    expert_2_inputs = inputs[:, token_split:]
    expert_1_outputs = expert_1(expert_1_inputs)
    expert_2_outputs = expert_2(expert_2_inputs)
    combine_experts = tf.keras.layers.Concatenate(axis=1)
    outputs = combine_experts([expert_1_outputs, expert_2_outputs])
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    return model

def build_simple_switch_layer_expert_parallelism_model(
    input_shape, d_hidden, n_experts, capacity_factor, devices, batch_size=None
):
    assert len(devices) >= 2
    inputs = tf.keras.Input(shape=input_shape, batch_size=batch_size)
    # place experts evenly across the available devices
    experts_per_device = math.ceil(n_experts / len(devices))
    experts = [
        OnDevice(
            tf.keras.layers.Dense(
                d_hidden, use_bias=False, name=f"switch_expert_{i + 1}"
            ),
            device=devices[i // experts_per_device],
        )
        for i in range(n_experts)
    ]
    switch_moe = SwitchMoE(experts, capacity_factor, epsilon=None)
    outputs, aux_loss = switch_moe(inputs)
    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    return model

def matmul_flops(m, p, q, exact=False):
    """Compute the floating point operations of a matrix multiplication AB where A is [m, p] and B is [p, q]."""
    if exact:
        return (2 * p - 1) * m * q
    # cf. https://github.com/tensorflow/tensorflow/pull/19792#issuecomment-415607267
    return 2 * m * p * q
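

# Quick sanity check (added for illustration, not part of the original gist):
# for A[2, 3] @ B[3, 4] the exact count is (2 * 3 - 1) * 2 * 4 = 40 flops, while
# the 2 * m * p * q approximation used by the TF profiler gives 48.
assert matmul_flops(2, 3, 4, exact=True) == 40
assert matmul_flops(2, 3, 4) == 48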

def setup_tf():
    """Configure TensorFlow. Can only be called at the start of the program."""
    tf.debugging.set_log_device_placement(True)
    tf.config.set_soft_device_placement(False)
    # split the physical CPU into 2 virtual CPUs to simulate multi-device training
    # source: https://www.tensorflow.org/api_docs/python/tf/config/set_logical_device_configuration
    physical_devices = tf.config.list_physical_devices("CPU")
    assert len(physical_devices) == 1, "Expected exactly one physical CPU"
    tf.config.set_logical_device_configuration(
        physical_devices[0],
        [
            tf.config.LogicalDeviceConfiguration(),
            tf.config.LogicalDeviceConfiguration(),
        ],
    )
    logical_devices = tf.config.list_logical_devices("CPU")
    devices = [device.name for device in logical_devices]
    print(f"Logical CPU devices: {devices}")
    return devices

def create_distributed_strategy(devices):
    """Create a simple data-parallel mirrored strategy."""
    return tf.distribute.MirroredStrategy(devices=devices)

def create_dummy_dataset(n_samples, input_shape, batch_size):
    return (
        tf.data.Dataset.from_tensors(
            tf.random.normal([n_samples, *input_shape]),
        )
        .unbatch()
        .batch(batch_size)
    )

def main():
    batch_size = 3
    seq_length = 8
    d_input = 32
    d_hidden = 128
    n_experts = 2
    capacity_factor = 1.0
    input_shape = (seq_length, d_input)
    devices = setup_tf()
    n_devices = len(devices)

    # Benchmark example with simple model-parallelism
    # -----------------------------------------------
    print("Building simple model")
    model = build_simple_model_parallelism_model(
        input_shape, d_hidden=d_hidden, devices=devices
    )
    print("Check device op placement")
    x = tf.random.uniform((batch_size, *input_shape))
    y = model(x)
    print("Compute FLOPS per example")
    expected_flops = matmul_flops(seq_length, d_input, d_hidden)  # dense_1
    expected_flops += matmul_flops(seq_length, d_input, d_hidden)  # dense_2
    tf2_profiler.profile_flops(model, input_shape=input_shape, batch_size=1)
    print(f"Expected float ops: {expected_flops:,}")
    print()
    # Benchmark 2 expert MoE model example
    # ------------------------------------
    print("Building 2 expert MoE model")
    model = build_simple_expert_parallelism_model(
        input_shape, d_hidden=d_hidden, devices=devices
    )
    print("Check device op placement")
    x = tf.random.uniform((batch_size, *input_shape))
    y = model(x)
    print("Compute FLOPS per example")
    expected_flops = matmul_flops(seq_length, d_input, d_hidden) // 2  # expert_1
    expected_flops += matmul_flops(seq_length, d_input, d_hidden) // 2  # expert_2
    tf2_profiler.profile_flops(model, input_shape=input_shape, batch_size=1)
    print(f"Expected float ops: {expected_flops:,}")
    print()
    # Benchmark switch-layer 2 expert MoE model example
    # -------------------------------------------------
    print("Building switch-layer 2 expert MoE model")
    model = build_simple_switch_layer_expert_parallelism_model(
        input_shape,
        d_hidden=d_hidden,
        n_experts=n_experts,
        capacity_factor=capacity_factor,
        devices=devices,
    )
    print("Check device op placement")
    x = tf.random.uniform((batch_size, *input_shape))
    y = model(x)
    print("Compute FLOPS per example")
    expected_flops = matmul_flops(seq_length, d_input, n_experts)  # expert routing
    expected_flops += matmul_flops(
        seq_length, d_input, d_hidden
    )  # mixture of experts switch gating
    # rebuild with a static batch size of 1 so the profiler can trace the graph
    single_sample_model = build_simple_switch_layer_expert_parallelism_model(
        input_shape,
        d_hidden=d_hidden,
        n_experts=n_experts,
        capacity_factor=capacity_factor,
        devices=devices,
        batch_size=1,
    )
    tf2_profiler.profile_flops(
        single_sample_model, input_shape=input_shape, batch_size=1
    )
    print(f"Expected MatMul float ops: {expected_flops:,}")
    print()
    # Benchmark switch-layer MoE model with data-parallelism
    # ------------------------------------------------------
    # TODO show that data is distributed across cores and that
    # experts placed on each core cause data to be copied
    # across cores to the location of the expert
    strategy = create_distributed_strategy(devices)
    with strategy.scope():
        model = build_simple_switch_layer_expert_parallelism_model(
            input_shape,
            d_hidden=d_hidden,
            n_experts=n_experts,
            capacity_factor=capacity_factor,
            devices=devices,
        )
    dataset = create_dummy_dataset(batch_size * 3, input_shape, batch_size)
    dist_dataset = strategy.experimental_distribute_dataset(dataset)

    def distributed_model_fn(inputs):
        return model(inputs)

    @tf.function
    def distributed_model(dist_inputs):
        per_replica_outputs = strategy.run(distributed_model_fn, args=(dist_inputs,))
        return per_replica_outputs

    print("\nCheck device op placement")
    for dist_inputs in dist_dataset:
        print(f"Dist inputs: {dist_inputs}\n")
        print(f"Dist outputs: {distributed_model(dist_inputs)}\n")
        break
print("\nBuilding 2 expert MoE model with data-parallelism") | |
# TODO show that data is distributed across cores and that | |
# experts placed on each core cause data to be copied to | |
# across cores to the location of the expert | |
strategy = create_distibuted_strategy(devices) | |
with strategy.scope(): | |
model = build_simple_expert_parallelism_model( | |
input_shape, d_hidden=d_hidden, devices=devices | |
) | |
dataset = create_dummy_dataset(batch_size * 3, input_shape, batch_size) | |
dist_dataset = strategy.experimental_distribute_dataset(dataset) | |
def distributed_model_fn(inputs): | |
return model(inputs) | |
@tf.function | |
def distributed_model(dist_inputs): | |
per_replica_outputs = strategy.run(distributed_model_fn, args=(dist_inputs,)) | |
return per_replica_outputs | |
print("\nCheck device op placement") | |
for dist_inputs in dist_dataset: | |
print(f"Dist inputs: {dist_inputs}\n") | |
print(f"Dist outputs: {distributed_model(dist_inputs)}\n") | |
break | |
print("Compute FLOPS per example") | |
expected_flops = ( | |
matmul_flops(seq_length, d_input, d_hidden) // 2 | |
) # dense_1 half input | |
expected_flops += ( | |
matmul_flops(seq_length, d_input, d_hidden) // 2 | |
) # dense_2 half input | |
tf2_profiler.profile_flops(model, input_shape=input_shape, batch_size=1) | |
print(f"Expected float ops: {expected_flops:,}") | |
    # Run the uncapped switch layer directly with 4 experts
    # -----------------------------------------------------
    n_experts = 4
    capacity_factor = 1.0
    experts_per_device = math.ceil(n_experts / n_devices)
    experts = [
        OnDevice(
            tf.keras.layers.Dense(d_input, use_bias=False),
            device=devices[i // experts_per_device],
        )
        for i in range(n_experts)
    ]
    inputs = tf.random.normal((batch_size, *input_shape))
    router_weights = tf.keras.layers.Dense(
        len(experts), use_bias=False, activation=None
    )
    outputs, aux_loss = uncapped_switch_layer(
        inputs,
        experts,
        router_weights,
        capacity_factor,
        epsilon=None,
        training=None,
        mask=None,
    )

if __name__ == "__main__":
    main()
tf2_profiler.py
# sourced from https://github.com/tensorflow/tensorflow/issues/32809#issuecomment-849439287
import tensorflow as tf
from tensorflow.python.profiler.model_analyzer import profile
from tensorflow.python.profiler.option_builder import ProfileOptionBuilder


def profile_flops(
    model, input_shape=None, batch_size=1, input_signature=None, cmd="op"
):
    """Profile the floating point operations of a TensorFlow Keras model."""
    if not input_signature:
        if not input_shape:
            input_shape = model.input_shape[1:]
        input_shape = (batch_size,) + input_shape
        input_signature = [tf.TensorSpec(shape=input_shape)]
    forward_pass = tf.function(model.call, input_signature=input_signature)
    graph_info = profile(
        forward_pass.get_concrete_function().graph,
        options=ProfileOptionBuilder.float_operation(),
        cmd=cmd,
    )
    # NOTE: `profile` counts a multiply and an accumulate as two separate flops;
    # with fused multiply-accumulate ops the count would be half this
    print(f"Total float ops: {graph_info.total_float_ops:,}")

def profile_flops_and_params_v1_compat(model_func):
    """Decorator that profiles the flops and trainable parameters of a model-building function via the TF1 profiler."""

    def wrapper(*args, **kwargs):
        session = tf.compat.v1.Session()
        graph = tf.compat.v1.get_default_graph()
        with graph.as_default():
            with session.as_default():
                result = model_func(*args, **kwargs)
                run_meta = tf.compat.v1.RunMetadata()
                opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
                flops = tf.compat.v1.profiler.profile(
                    graph=graph, run_meta=run_meta, cmd="op", options=opts
                )
                opts = (
                    tf.compat.v1.profiler.ProfileOptionBuilder.trainable_variables_parameter()
                )
                params = tf.compat.v1.profiler.profile(
                    graph, run_meta=run_meta, cmd="op", options=opts
                )
                print(f"Total float ops: {flops.total_float_ops:,}")
                print(f"Total parameters: {params.total_parameters}")
        return result

    return wrapper
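

# Example usage (added for illustration, not part of the original gist):
#
#   @profile_flops_and_params_v1_compat
#   def build_model():
#       return tf.keras.Sequential([tf.keras.layers.Dense(4, input_shape=(3,))])
#
#   model = build_model()  # prints total float ops and the parameter count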