@rpeloff
Created December 30, 2022 18:50
TensorFlow 2.2 implementation of Mixture-of-Experts with switch routing
import math
from functools import wraps
import tensorflow as tf
import tf2_profiler
def uncapped_switch_layer(
inputs,
experts,
router_weights,
capacity_factor,
epsilon=None,
training=None,
mask=None,
):
"""A data-parallel and expert-parallel distributed MoE switch layer.
inputs must have shape [batch_size, tokens_per_sequence, d_model]
NOTE: this layer does not support model-parallel splitting of experts over multiple replicas.
Based on the paper pseudo-code and mesh-tensorflow implementation:
- cf. https://arxiv.org/pdf/2101.03961.pdf
- cf. https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/moe.py
"""
inputs_shape_dynamic = tf.shape(inputs)
inputs_shape_static = inputs.shape
input_shape = [
inputs_shape_static[i] or inputs_shape_dynamic[i]
for i in range(len(inputs_shape_static))
]
batch_size = input_shape[0]
tokens_per_sequence = input_shape[1]
d_inputs = input_shape[2]
n_experts = len(experts)
# this replica in a data-parallel strategy will route tokens_per_batch tokens to the correct experts
tokens_per_batch = batch_size * tokens_per_sequence
# nominal capacity per expert (computed for parity with switch_layer below, but not enforced in this uncapped variant)
expert_capacity = (
tf.cast(tokens_per_batch, tf.float32) * capacity_factor / n_experts
)
# reshape inputs to set up expert dispatching
# inputs shape: [batch_size, tokens_per_sequence, d_inputs] -> [tokens_per_batch, d_inputs]
inputs = tf.reshape(inputs, [tokens_per_batch, d_inputs])
# route each token to its top-1 expert
# - expert_index: the expert each token is routed to
#                 shape: [tokens_per_batch]
# - expert_gate: the top-1 router probability for each token
#                 shape: [tokens_per_batch]
expert_index, expert_gate, aux_loss = uncapped_router(
inputs,
router_weights,
n_experts,
expert_capacity,
epsilon=epsilon,
training=training,
mask=mask,
)
expert_inputs_indices = [
tf.where(tf.equal(expert_index, i)) for i in range(n_experts)
]
expert_inputs_list = [
tf.gather_nd(inputs, expert_indices) for expert_indices in expert_inputs_indices
]
expert_outputs_list = [
expert(expert_input)
for expert, expert_input in zip(experts, expert_inputs_list)
]
expert_outputs = tf.concat(expert_outputs_list, axis=0)
d_model = tf.shape(expert_outputs)[-1]
combine_outputs_indices = tf.concat(expert_inputs_indices, axis=0)
combine_outputs_shape = tf.cast(
[tokens_per_batch, d_model], combine_outputs_indices.dtype
)
expert_outputs_combined = tf.scatter_nd(
combine_outputs_indices, expert_outputs, combine_outputs_shape
)
expert_outputs_combined *= expert_gate[..., tf.newaxis]
# matrix multiply inputs and dispatch tensor to assign tokens to the correct expert
# bd,bec->ecd => tf.transpose(dispatch_tensor, perm=[1,2,0]) @ inputs
# expert_inputs shape: [n_experts, expert_capacity, d_inputs]
# expert_inputs = tf.linalg.einsum("bd,bec->ecd", inputs, dispatch_tensor)
# dispatch inputs to experts
# expert_inputs_list = tf.unstack(expert_inputs, axis=0)
# expert_outputs_list = [
# expert(expert_input) for expert, expert_input in zip(experts, expert_inputs_list)
# ]
# expert_outputs = tf.stack(expert_outputs_list, axis=0)
# convert expert outputs back to the input shape and multiply by the routing probability
# expert_outputs_combined shape: [tokens_per_batch, d_model]
# expert_outputs_combined = tf.linalg.einsum("ecd,bec->bd", expert_outputs, combine_tensor)
# reshape from the flattened [tokens_per_batch, d_model] layout back to the original input layout
# outputs shape: [batch_size, tokens_per_sequence, d_model]
outputs = tf.reshape(
expert_outputs_combined, [batch_size, tokens_per_sequence, d_model]
)
return outputs, aux_loss
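# Toy illustration (made-up values) of the tf.where / tf.gather_nd / tf.scatter_nd
# dispatch pattern used above: 4 tokens, d_model = 1, and only expert 0 shown.
#
#   inputs = tf.constant([[1.0], [2.0], [3.0], [4.0]])
#   expert_index = tf.constant([0, 1, 0, 1])
#   indices_0 = tf.where(tf.equal(expert_index, 0))           # [[0], [2]]
#   expert_0_inputs = tf.gather_nd(inputs, indices_0)         # [[1.0], [3.0]]
#   expert_0_outputs = expert_0_inputs * 10.0                 # stand-in for expert(x)
#   combined = tf.scatter_nd(indices_0, expert_0_outputs, tf.constant([4, 1], tf.int64))
#   # combined == [[10.0], [0.0], [30.0], [0.0]]; rows belonging to expert 1 remain
#   # zero until that expert's outputs are scattered into the same positions.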
def uncapped_router(
inputs,
router_weights,
n_experts,
expert_capacity,
epsilon=None,
training=None,
mask=None,
):
"""Produce dispatch/combine tensors used for sending/receiving tokens to/from highest probability experts.
inputs must have shape [tokens_per_batch, d_model]
Based on the paper pseudo-code and mesh-tensorflow implementation:
- cf. https://arxiv.org/pdf/2101.03961.pdf
- cf. https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/moe.py
"""
router_logits = router_weights(inputs) # shape: [tokens_per_batch, n_experts]
if training is True and epsilon is not None:
# add uniform noise in [1 - epsilon, 1 + epsilon] to the logits for exploration across experts (the constant +1 shift is shared by all logits and cancels in the softmax)
epsilon = tf.cast(epsilon, tf.float32)
router_logits += tf.random.uniform(
tf.shape(router_logits), minval=1 - epsilon, maxval=1 + epsilon
)
# cast the softmax input logits to tf.float32 for numerical stability under mixed precision
router_logits = tf.cast(router_logits, tf.float32)
# probabilities for each token of which expert it should be sent to
router_probs = tf.nn.softmax(router_logits, axis=-1)
# get the top-1 expert for each token
# - expert_gate: the top-1 probability for each token
# - expert_index: the expert each token is going to be routed to
expert_gate, expert_index = tf.math.top_k(router_probs, k=1)
# squeeze expert_gate and expert_index for top-1 shape
expert_gate = tf.squeeze(expert_gate, axis=-1) # shape: [tokens_per_batch]
expert_index = tf.squeeze(expert_index, axis=-1) # shape: [tokens_per_batch]
# expert_mask shape: [tokens_per_batch, n_experts]
expert_mask = tf.one_hot(expert_index, depth=n_experts, dtype=router_probs.dtype)
# compute the load balancing loss
aux_loss = load_balance_loss(router_probs, expert_mask)
# # experts have a fixed capacity (i.e. batch size per expert), ensure we do not exceed it!
# # construct tensor indicating the position of each token in the receiving expert
# # position_in_expert shape: [tokens_per_batch, n_experts]
# position_in_expert = tf.math.cumsum(expert_mask, exclusive=True, axis=0) * expert_mask
# # keep only the tokens that fit within each expert (i.e. position in receiving expert lower than expert_capacity)
# expert_mask *= tf.cast(tf.math.less(position_in_expert, tf.cast(expert_capacity, tf.float32)), tf.float32)
# # mask out the experts that have overflowed the expert capacity
# expert_mask_flat = tf.math.reduce_sum(expert_mask, axis=-1)
# expert_gate *= expert_mask_flat
# expert_index *= tf.cast(expert_mask_flat, tf.int32)
# construct one-hot tensor indicating the position of each token in the receiving expert
# position_in_expert_one_hot = tf.one_hot(
# tf.cast(position_in_expert, tf.int32),
# depth=tf.cast(expert_capacity, tf.int32),
# dtype=tf.float32,
# )
# combine tensor used for combining expert outputs and scaling with router probability
# combine_tensor shape: [tokens_per_batch, n_experts, expert_capacity]
# combine_tensor = expert_gate[..., tf.newaxis] * expert_mask
# combine_tensor = combine_tensor[..., tf.newaxis] * position_in_expert_one_hot
# convert outputs back to inputs dtype
# combine_tensor = tf.cast(combine_tensor, inputs.dtype)
# create binary dispatch tensor that is 1 if the token gets routed to the corresponding expert
# dispatch_tensor shape: [tokens_per_batch, n_experts, expert_capacity]
# dispatch_tensor = tf.cast(tf.cast(combine_tensor, tf.bool), inputs.dtype)
# return dispatch_tensor, combine_tensor, aux_loss
return expert_index, expert_gate, aux_loss
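# Toy example (hypothetical probabilities) of the top-1 routing above, with
# 3 tokens and 2 experts:
#
#   router_probs = tf.constant([[0.9, 0.1], [0.2, 0.8], [0.6, 0.4]])
#   expert_gate, expert_index = tf.math.top_k(router_probs, k=1)
#   # expert_gate  -> [[0.9], [0.8], [0.6]]   (top-1 probability per token)
#   # expert_index -> [[0],   [1],   [0]]     (chosen expert per token)
#   # after tf.squeeze: expert_gate = [0.9, 0.8, 0.6], expert_index = [0, 1, 0]
#   # tf.one_hot(expert_index, depth=2) -> [[1., 0.], [0., 1.], [1., 0.]]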
# class Router(tf.keras.layers.Layer):
# def __init__(self, n_experts, epsilon, **kwargs):
# super().__init__(**kwargs)
# self.n_experts = n_experts
# self.epsilon = epsilon
# self.router_weights = tf.keras.layers.Dense(self.n_experts, use_bias=False, activation=None)
# def call(self, inputs, training=None, mask=None):
# return router(inputs, self.router_weights, self.n_experts, 8, epsilon=self.epsilon, training=training, mask=mask)
class SwitchMoE(tf.keras.layers.Layer):
def __init__(self, experts, capacity_factor, epsilon=None, **kwargs):
super().__init__(**kwargs)
self.experts = experts
self.capacity_factor = capacity_factor
self.epsilon = epsilon
# self.router = Router(len(self.experts), self.epsilon)
self.router_weights = tf.keras.layers.Dense(
len(self.experts), use_bias=False, activation=None
)
def call(self, inputs, training=None, mask=None):
return uncapped_switch_layer(
inputs,
self.experts,
self.router_weights,
self.capacity_factor,
self.epsilon,
training=training,
mask=mask,
)
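# Minimal usage sketch for SwitchMoE (hypothetical shapes; see main() below for a
# device-placed example):
#
#   experts = [tf.keras.layers.Dense(32, use_bias=False) for _ in range(4)]
#   moe = SwitchMoE(experts, capacity_factor=1.0)
#   tokens = tf.random.normal((2, 8, 32))    # [batch_size, tokens_per_sequence, d_model]
#   outputs, aux_loss = moe(tokens)          # outputs shape: [2, 8, 32]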
def switch_layer(
inputs,
experts,
router_weights,
capacity_factor,
epsilon=None,
training=None,
mask=None,
):
"""A data-parallel and expert-parallel distributed MoE switch layer.
inputs must have shape [batch_size, tokens_per_sequence, d_model]
NOTE: this layer does not support model-parallel splitting of experts over multiple replicas.
Based on the paper pseudo-code and mesh-tensorflow implementation:
- cf. https://arxiv.org/pdf/2101.03961.pdf
- cf. https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/moe.py
"""
inputs_shape_dynamic = tf.shape(inputs)
inputs_shape_static = inputs.shape
input_shape = [
inputs_shape_static[i] or inputs_shape_dynamic[i]
for i in range(len(inputs_shape_static))
]
batch_size = input_shape[0]
tokens_per_sequence = input_shape[1]
d_inputs = input_shape[2]
n_experts = len(experts)
# this replica in a data-parallel strategy will route tokens_per_batch tokens to the correct experts
tokens_per_batch = batch_size * tokens_per_sequence
# each expert will have shape [expert_capacity, d_inputs]
expert_capacity = (
tf.cast(tokens_per_batch, tf.float32) * capacity_factor / n_experts
)
# reshape inputs to set up expert dispatching
# inputs shape: [batch_size, tokens_per_sequence, d_inputs] -> [tokens_per_batch, d_inputs]
inputs = tf.reshape(inputs, [tokens_per_batch, d_inputs])
# compute dispatch and combine tensors used to route tokens to and from experts
# - dispatch_tensor: used for routing tokens to the correct expert
# shape: [tokens_per_batch, n_experts, expert_capacity]
# - combine_tensor: used for combining expert outputs and scaling with router probability
# shape: [tokens_per_batch, n_experts, expert_capacity]
# dispatch_tensor, combine_tensor, aux_loss = router(
# inputs, training=training, mask=mask)
dispatch_tensor, combine_tensor, aux_loss = router(
inputs,
router_weights,
n_experts,
expert_capacity,
epsilon=epsilon,
training=training,
mask=mask,
)
# matrix multiply inputs and dispatch tensor to assign tokens to the correct expert
# bd,bec->ecd => tf.transpose(dispatch_tensor, perm=[1,2,0]) @ inputs
# expert_inputs shape: [n_experts, expert_capacity, d_inputs]
expert_inputs = tf.linalg.einsum("bd,bec->ecd", inputs, dispatch_tensor)
# dispatch inputs to experts
expert_inputs_list = tf.unstack(expert_inputs, axis=0)
expert_outputs_list = [
expert(expert_input)
for expert, expert_input in zip(experts, expert_inputs_list)
]
expert_outputs = tf.stack(expert_outputs_list, axis=0)
# convert expert outputs back to the input shape and multiply by the routing probability
# expert_outputs_combined shape: [tokens_per_batch, d_model]
expert_outputs_combined = tf.linalg.einsum(
"ecd,bec->bd", expert_outputs, combine_tensor
)
# reshape from the flattened [tokens_per_batch, d_model] layout back to the original input layout
# outputs shape: [batch_size, tokens_per_sequence, d_model]
d_model = tf.shape(expert_outputs)[-1]
outputs = tf.reshape(
expert_outputs_combined, [batch_size, tokens_per_sequence, d_model]
)
return outputs, aux_loss
def router(
inputs,
router_weights,
n_experts,
expert_capacity,
epsilon=None,
training=None,
mask=None,
):
"""Produce dispatch/combine tensors used for sending/receiving tokens to/from highest probability experts.
inputs must have shape [tokens_per_batch, d_model]
Based on the paper pseudo-code and mesh-tensorflow implementation:
- cf. https://arxiv.org/pdf/2101.03961.pdf
- cf. https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/moe.py
"""
router_logits = router_weights(inputs) # shape: [tokens_per_batch, n_experts]
if training is True and epsilon is not None:
# add uniform noise in [1 - epsilon, 1 + epsilon] to the logits for exploration across experts (the constant +1 shift is shared by all logits and cancels in the softmax)
epsilon = tf.cast(epsilon, tf.float32)
router_logits += tf.random.uniform(
tf.shape(router_logits), minval=1 - epsilon, maxval=1 + epsilon
)
# cast the softmax input logits to tf.float32 for numerical stability under mixed precision
router_logits = tf.cast(router_logits, tf.float32)
# probabilities for each token of which expert it should be sent to
router_probs = tf.nn.softmax(router_logits, axis=-1)
# get the top-1 expert for each token
# - expert_gate: the top-1 probability for each token
# - expert_index: the expert each token is going to be routed to
expert_gate, expert_index = tf.math.top_k(router_probs, k=1)
# squeeze expert_gate and expert_index for top-1 shape
expert_gate = tf.squeeze(expert_gate, axis=-1) # shape: [tokens_per_batch]
expert_index = tf.squeeze(expert_index, axis=-1) # shape: [tokens_per_batch]
# expert_mask shape: [tokens_per_batch, n_experts]
expert_mask = tf.one_hot(expert_index, depth=n_experts, dtype=router_probs.dtype)
# compute the load balancing loss
aux_loss = load_balance_loss(router_probs, expert_mask)
# experts have a fixed capacity (i.e. a maximum number of tokens per expert), ensure we do not exceed it!
# construct tensor indicating the position of each token in the receiving expert
# position_in_expert shape: [tokens_per_batch, n_experts]
position_in_expert = (
tf.math.cumsum(expert_mask, exclusive=True, axis=0) * expert_mask
)
# keep only the tokens that fit within each expert (i.e. position in receiving expert lower than expert_capacity)
expert_mask *= tf.cast(
tf.math.less(position_in_expert, tf.cast(expert_capacity, tf.float32)),
tf.float32,
)
# mask out the experts that have overflowed the expert capacity
expert_mask_flat = tf.math.reduce_sum(expert_mask, axis=-1)
expert_gate *= expert_mask_flat
# construct one-hot tensor indicating the position of each token in the receiving expert
position_in_expert_one_hot = tf.one_hot(
tf.cast(position_in_expert, tf.int32),
depth=tf.cast(expert_capacity, tf.int32),
dtype=tf.float32,
)
# combine tensor used for combining expert outputs and scaling with router probability
# combine_tensor shape: [tokens_per_batch, n_experts, expert_capacity]
combine_tensor = expert_gate[..., tf.newaxis] * expert_mask
combine_tensor = combine_tensor[..., tf.newaxis] * position_in_expert_one_hot
# convert outputs back to inputs dtype
combine_tensor = tf.cast(combine_tensor, inputs.dtype)
# create binary dispatch tensor that is 1 if the token gets routed to the corresponding expert
# dispatch_tensor shape: [tokens_per_batch, n_experts, expert_capacity]
dispatch_tensor = tf.cast(tf.cast(combine_tensor, tf.bool), inputs.dtype)
return dispatch_tensor, combine_tensor, aux_loss
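# Toy illustration (hypothetical routing) of the capacity mechanism above: 4 tokens
# all selecting expert 0, 2 experts, expert_capacity = 2.
#
#   expert_mask = tf.constant([[1., 0.], [1., 0.], [1., 0.], [1., 0.]])
#   position_in_expert = tf.math.cumsum(expert_mask, exclusive=True, axis=0) * expert_mask
#   # -> [[0., 0.], [1., 0.], [2., 0.], [3., 0.]]
#   expert_mask *= tf.cast(tf.math.less(position_in_expert, 2.0), tf.float32)
#   # -> [[1., 0.], [1., 0.], [0., 0.], [0., 0.]]
#   # the third and fourth tokens overflow expert 0's capacity, so their gate and
#   # combine weights become zero and they contribute nothing to this layer's output.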
def load_balance_loss(router_probs, expert_mask):
"""Calculate load-balancing loss to ensure diverse expert routing."""
# router_probs [tokens_per_batch, num_experts] is the probability assigned for
# each expert per token. expert_mask [tokens_per_batch, num_experts] contains
# the expert with the highest router probability in one-hot format.
num_experts = tf.shape(expert_mask)[-1]
# Get the fraction of tokens routed to each expert.
# density is a vector of length num_experts that sums to 1.
density = tf.reduce_mean(expert_mask, axis=0)
# Get fraction of probability mass assigned to each expert from the router
# across all tokens. density_proxy is a vector of length num_experts that sums to 1.
density_proxy = tf.reduce_mean(router_probs, axis=0)
# Want both vectors to have uniform allocation (1/num experts) across all
# num_expert elements. The two vectors will be pushed towards uniform allocation
# when the dot product is minimized.
loss = tf.reduce_mean(density_proxy * density) * tf.cast(
(num_experts**2), tf.dtypes.float32
)
return loss
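# Worked example: with 2 experts, perfectly balanced routing, and fully confident
# router probabilities, density == density_proxy == [0.5, 0.5], so
#   loss = mean([0.25, 0.25]) * 2**2 = 0.25 * 4 = 1.0,
# the value the auxiliary loss encourages; skewed routing pushes the loss above 1.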
def _on_device(device=None, name="layer"):
def decorator(cls_method):
@wraps(cls_method)
def cls_method_wrapper(*args, **kwargs):
if device:
with tf.device(device):
outputs = cls_method(*args, **kwargs)
if outputs is not None:
tf.print(
f"Inputs of {name} with shape {[arg.shape for arg in args if hasattr(arg, 'shape')]} and device placement {device} on device: {[arg.device for arg in args if hasattr(arg, 'device')]}"
)
tf.print(
f"Output of {name} with shape {outputs.shape} and device placement {device} on device: {outputs.device}"
)
return outputs
return cls_method(*args, **kwargs)
return cls_method_wrapper
return decorator
def OnDevice(layer_cls, device=None):
"""Simple class decorator that executes layer build and call methods on a specific device."""
# acts like tf.keras.layers.Wrapper without subclassing and will not be saved with the model
assert isinstance(layer_cls, tf.keras.layers.Layer)
layer_cls.build = _on_device(device, layer_cls.name)(layer_cls.build)
layer_cls.call = _on_device(device, layer_cls.name)(layer_cls.call)
return layer_cls
def build_simple_model_parallelism_model(input_shape, d_hidden, devices):
assert len(devices) >= 2
inputs = tf.keras.Input(shape=input_shape)
dense_1 = OnDevice(
tf.keras.layers.Dense(d_hidden, use_bias=False), device=devices[0]
)
outputs_1 = dense_1(inputs)
dense_2 = OnDevice(
tf.keras.layers.Dense(d_hidden, use_bias=False), device=devices[1]
)
outputs_2 = dense_2(inputs)
add = tf.keras.layers.Add()
outputs = add([outputs_1, outputs_2])
model = tf.keras.Model(inputs=inputs, outputs=outputs)
return model
def build_simple_expert_parallelism_model(input_shape, d_hidden, devices):
assert len(devices) >= 2
inputs = tf.keras.Input(shape=input_shape)
expert_1 = OnDevice(
tf.keras.layers.Dense(d_hidden, use_bias=False, name="simple_expert_1"),
device=devices[0],
)
expert_2 = OnDevice(
tf.keras.layers.Dense(d_hidden, use_bias=False, name="simple_expert_2"),
device=devices[1],
)
# split half of sequence tokens to one expert and remaining tokens to second expert
token_split = input_shape[0] // 2
expert_1_inputs = inputs[:, :token_split]
expert_2_inputs = inputs[:, token_split:]
expert_1_outputs = expert_1(expert_1_inputs)
expert_2_outputs = expert_2(expert_2_inputs)
combine_experts = tf.keras.layers.Concatenate(axis=1)
outputs = combine_experts([expert_1_outputs, expert_2_outputs])
model = tf.keras.Model(inputs=inputs, outputs=outputs)
return model
def build_simple_switch_layer_expert_parallelism_model(
input_shape, d_hidden, n_experts, capacity_factor, devices, batch_size=None
):
assert len(devices) >= 2
inputs = tf.keras.Input(shape=input_shape, batch_size=batch_size)
experts_per_device = math.ceil(n_experts / len(devices))
experts = [
OnDevice(
tf.keras.layers.Dense(
d_hidden, use_bias=False, name=f"switch_expert_{i+1}"
),
device=devices[i // experts_per_device],
)
for i in range(n_experts)
]
switch_moe = SwitchMoE(experts, capacity_factor, epsilon=None)
outputs, aux_loss = switch_moe(inputs)
model = tf.keras.Model(inputs=inputs, outputs=outputs)
return model
def matmul_flops(m, p, q, exact=False):
"""Compute floating point operations of a matrix multiplication AB where A[m,p] and B[p,q]."""
if exact:
return (2 * p - 1) * m * q
# cf. https://github.com/tensorflow/tensorflow/pull/19792#issuecomment-415607267
return 2 * m * p * q
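# Quick sanity check using the dimensions from main() below (seq_length=8,
# d_input=32, d_hidden=128):
#   matmul_flops(8, 32, 128)             = 2 * 8 * 32 * 128    = 65,536
#   matmul_flops(8, 32, 128, exact=True) = (2*32 - 1) * 8 * 128 = 64,512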
def setup_tf():
"""Can only be called at start of program."""
tf.debugging.set_log_device_placement(True)
tf.config.set_soft_device_placement(False)
# specify 2 virtual CPUs to simulate multi-device training
# source: https://www.tensorflow.org/api_docs/python/tf/config/set_logical_device_configuration
physical_devices = tf.config.list_physical_devices("CPU")
assert len(physical_devices) == 1, "Expected exactly one physical CPU device"
tf.config.set_logical_device_configuration(
physical_devices[0],
[
tf.config.LogicalDeviceConfiguration(),
tf.config.LogicalDeviceConfiguration(),
],
)
logical_devices = tf.config.list_logical_devices("CPU")
devices = [device.name for device in logical_devices]
print(f"Logical CPU devices: {devices}")
return devices
def create_distibuted_strategy(devices):
"""Created a simple data-parallel mirrored strategy."""
return tf.distribute.MirroredStrategy(devices=devices)
def create_dummy_dataset(n_samples, input_shape, batch_size):
return (
tf.data.Dataset.from_tensors(
(tf.random.normal([n_samples, *input_shape])),
)
.unbatch()
.batch(batch_size)
)
def main():
batch_size = 3
seq_length = 8
d_input = 32
d_hidden = 128
n_experts = 2
capacity_factor = 1.0
input_shape = (seq_length, d_input)
devices = setup_tf()
n_devices = len(devices)
# Benchmark example with simple model-parallelism
# -----------------------------------------------
print("Building simple model")
model = build_simple_model_parallelism_model(
input_shape, d_hidden=d_hidden, devices=devices
)
print("Check device op placement")
x = tf.random.uniform((batch_size, *input_shape))
y = model(x)
print("Compute FLOPS per example")
expected_flops = matmul_flops(seq_length, d_input, d_hidden) # dense_1
expected_flops += matmul_flops(seq_length, d_input, d_hidden) # dense_2
tf2_profiler.profile_flops(model, input_shape=input_shape, batch_size=1)
print(f"Expected float ops: {expected_flops:,}")
print()
# Benchmark 2 expert MoE model example
# ------------------------------------
print("Building 2 expert MoE model")
model = build_simple_expert_parallelism_model(
input_shape, d_hidden=d_hidden, devices=devices
)
print("Check device op placement")
x = tf.random.uniform((batch_size, *input_shape))
y = model(x)
print("Compute FLOPS per example")
expected_flops = matmul_flops(seq_length, d_input, d_hidden) // 2 # expert_1
expected_flops += matmul_flops(seq_length, d_input, d_hidden) // 2 # expert_2
tf2_profiler.profile_flops(model, input_shape=input_shape, batch_size=1)
print(f"Expected float ops: {expected_flops:,}")
print()
# Benchmark switch-layer 2 expert MoE model example
# --------------------------------------------------
print("Building switch-layer 2 expert MoE model")
model = build_simple_switch_layer_expert_parallelism_model(
input_shape,
d_hidden=d_hidden,
n_experts=n_experts,
capacity_factor=capacity_factor,
devices=devices,
)
print("Check device op placement")
x = tf.random.uniform((batch_size, *input_shape))
y = model(x)
print("Compute FLOPS per example")
expected_flops = matmul_flops(seq_length, d_input, n_experts) # expert routing
expected_flops += matmul_flops(
seq_length, d_input, d_hidden
) # mixture of experts switch gating
single_sample_model = build_simple_switch_layer_expert_parallelism_model(
input_shape,
d_hidden=d_hidden,
n_experts=n_experts,
capacity_factor=capacity_factor,
devices=devices,
batch_size=1,
)
tf2_profiler.profile_flops(
single_sample_model, input_shape=input_shape, batch_size=1
)
print(f"Expected MatMul float ops: {expected_flops:,}")
print()
# TODO: show that data is distributed across cores and that experts placed on
# specific cores cause data to be copied across cores to the location of each expert
strategy = create_distibuted_strategy(devices)
with strategy.scope():
model = build_simple_switch_layer_expert_parallelism_model(
input_shape,
d_hidden=d_hidden,
n_experts=n_experts,
capacity_factor=capacity_factor,
devices=devices,
)
dataset = create_dummy_dataset(batch_size * 3, input_shape, batch_size)
dist_dataset = strategy.experimental_distribute_dataset(dataset)
def distributed_model_fn(inputs):
return model(inputs)
@tf.function
def distributed_model(dist_inputs):
per_replica_outputs = strategy.run(distributed_model_fn, args=(dist_inputs,))
return per_replica_outputs
print("\nCheck device op placement")
for dist_inputs in dist_dataset:
print(f"Dist inputs: {dist_inputs}\n")
print(f"Dist outputs: {distributed_model(dist_inputs)}\n")
break
...
print("\nBuilding 2 expert MoE model with data-parallelism")
# TODO: show that data is distributed across cores and that experts placed on
# specific cores cause data to be copied across cores to the location of each expert
strategy = create_distibuted_strategy(devices)
with strategy.scope():
model = build_simple_expert_parallelism_model(
input_shape, d_hidden=d_hidden, devices=devices
)
dataset = create_dummy_dataset(batch_size * 3, input_shape, batch_size)
dist_dataset = strategy.experimental_distribute_dataset(dataset)
def distributed_model_fn(inputs):
return model(inputs)
@tf.function
def distributed_model(dist_inputs):
per_replica_outputs = strategy.run(distributed_model_fn, args=(dist_inputs,))
return per_replica_outputs
print("\nCheck device op placement")
for dist_inputs in dist_dataset:
print(f"Dist inputs: {dist_inputs}\n")
print(f"Dist outputs: {distributed_model(dist_inputs)}\n")
break
print("Compute FLOPS per example")
expected_flops = (
matmul_flops(seq_length, d_input, d_hidden) // 2
) # dense_1 half input
expected_flops += (
matmul_flops(seq_length, d_input, d_hidden) // 2
) # dense_2 half input
tf2_profiler.profile_flops(model, input_shape=input_shape, batch_size=1)
print(f"Expected float ops: {expected_flops:,}")
n_experts = 4
capacity_factor = 1.0
experts_per_device = math.ceil(n_experts / n_devices)
experts = [
OnDevice(
tf.keras.layers.Dense(d_input, use_bias=False),
device=devices[i // experts_per_device],
)
for i in range(n_experts)
]
inputs = tf.random.normal((batch_size, *input_shape))
router_weights = tf.keras.layers.Dense(
len(experts), use_bias=False, activation=None
)
outputs, aux_loss = uncapped_switch_layer(
inputs,
experts,
router_weights,
capacity_factor,
epsilon=None,
training=None,
mask=None,
)
if __name__ == "__main__":
main()
# --- tf2_profiler.py (imported above via `import tf2_profiler`) ---
# sourced from https://github.com/tensorflow/tensorflow/issues/32809#issuecomment-849439287
import tensorflow as tf
from tensorflow.python.profiler.model_analyzer import profile
from tensorflow.python.profiler.option_builder import ProfileOptionBuilder
def profile_flops(
model, input_shape=None, batch_size=1, input_signature=None, cmd="op"
):
"""Profile floating point operations of TensorFlow Keras model."""
if not input_signature:
if not input_shape:
input_shape = model.input_shape[1:]
input_shape = (batch_size,) + input_shape
input_signature = [tf.TensorSpec(shape=input_shape)]
forward_pass = tf.function(model.call, input_signature=input_signature)
graph_info = profile(
forward_pass.get_concrete_function().graph,
options=ProfileOptionBuilder.float_operation(),
cmd=cmd,
)
# NOTE: `profile` counts a multiply and an accumulate as two separate flops; counting fused multiply-accumulate ops would give half this number
print(f"Total float ops: {graph_info.total_float_ops:,}")
def profile_flops_and_params_v1_compat(model_func):
def wrapper(*args, **kwargs):
session = tf.compat.v1.Session()
graph = tf.compat.v1.get_default_graph()
with graph.as_default():
with session.as_default():
result = model_func(*args, **kwargs)
run_meta = tf.compat.v1.RunMetadata()
opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
flops = tf.compat.v1.profiler.profile(
graph=graph, run_meta=run_meta, cmd="op", options=opts
)
opts = (
tf.compat.v1.profiler.ProfileOptionBuilder.trainable_variables_parameter()
)
params = tf.compat.v1.profiler.profile(
graph, run_meta=run_meta, cmd="op", options=opts
)
print(f"Total float ops: {flops.total_float_ops:,}")
print(f"Total parameters: {params.total_parameters}")
return result
return wrapper