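# CTC loss in MLX vs. the PyTorch built-in.
# Times torch.nn.functional.ctc_loss (value and gradient) as a reference, then
# a hand-written MLX forward-algorithm implementation, and checks that the
# losses and gradients of the two agree within atol.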
import time
import torch
import mlx.core as mx
import mlx.nn as mn
import numpy as np
# 1. Generate input
T, B, C = 128, 64, 32  # time steps x batch x channels (class 0 is the CTC blank)
t = T // 2 - 4         # maximum target length
atol = 1e-5
torch.manual_seed(1)
logits = torch.randn(T, B, C).requires_grad_()
targets = torch.randint(1, C, (B, t), dtype=torch.long)  # labels 1..C-1; 0 is reserved for the blank
input_lengths = torch.randint(T // 2, T, (B,), dtype=torch.long)
target_lengths = torch.randint(t // 2, t, (B,), dtype=torch.long)
log_probs = logits.log_softmax(dim=-1)
zero = -1e+30  # finite stand-in for log(0)
print('Log-probs shape (time x batch x channels):', 'x'.join(map(str, log_probs.shape)))
# 2. Reference output
# Timing loop: repeat until ~3 s of total work has accumulated, then keep only
# the below-mean samples so warm-up outliers do not skew the average.
at = []
while sum(at) < 3:
    t1 = time.perf_counter()
    builtin_ctc = torch.nn.functional.ctc_loss(
        log_probs, targets,
        input_lengths, target_lengths,
        blank=0, reduction='mean',
    )
    builtin_ctc_grad, = torch.autograd.grad(builtin_ctc, logits, retain_graph=True)
    t2 = time.perf_counter()
    at.append(t2 - t1)
at = [s for s in sorted(at) if s <= sum(at) / len(at)]
tt = sum(at) / len(at)
print(f'Builtin time: {tt:.3f}s, value={builtin_ctc:.5f}')
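# Note: with reduction='mean', torch.nn.functional.ctc_loss divides each
# sequence's loss by its target length and then averages over the batch; the
# MLX benchmark below reproduces this with `/ mx_target_lengths` and `.mean()`.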
# 3. Plain MLX implementation
@mx.compile
def mx_ctc_loss(log_probs: mx.array, targets: mx.array, input_lengths: mx.array, target_lengths: mx.array, blank: int = 0):
    input_time_size, batch_size = log_probs.shape[:2]
    # Build the CTC extended target sequence (blank, s1, blank, s2, ..., st, blank)
    # by prepending a blank to every label; one label is wrapped around so each
    # row has an even length of 2*t + 2 and the shifted slices below stay uniform.
    targets = mx.pad(mx.expand_dims(mx.concatenate([targets, targets[:, :1]], 1), 2), [[0,0],[0,0],[1,0]], constant_values=blank).flatten(1)
    # The skip transition (s-2 -> s) is only allowed between two different labels.
    diff_labels = mx.pad(targets[:, 2:] != targets[:, :-2], [[0,0],[2,0]])
    # Per-step log-probabilities of every extended-target symbol: [T, B, 2t+2].
    log_probs_ = mx.take_along_axis(log_probs, mx.tile(targets, (input_time_size, 1, 1)), 2)
    # Forward variables, with two extra log(0) columns on the left so the three
    # shifted reads (a, b, c) below never run out of bounds.
    log_alpha = mx.full([*log_probs_.shape[:2], log_probs_.shape[2] + 2], zero, dtype=log_probs_.dtype)
    log_alpha[0, :, 2:4] = log_probs_[0, :, :2]  # a path may start at the blank or at the first label
    log_last = log_alpha[0]
    for t in range(1, input_time_size):
        # CTC forward recurrence, in probability terms:
        #   alpha_t(s) = p_t(s) * (alpha_{t-1}(s) + alpha_{t-1}(s-1) + [if allowed] alpha_{t-1}(s-2))
        a = log_last[:, 2:]    # stay on the same symbol
        b = log_last[:, 1:-1]  # advance by one position
        c_ = log_last[:, :-2]  # skip over a blank
        c = mx.where(diff_labels, c_, zero)
        lm = mx.maximum(mx.maximum(a, b), c)  # running max for a stable log-sum-exp
        log_alpha[t] = (log_last := mx.pad(log_probs_[t] + lm + mx.log(mx.exp(a - lm) + mx.exp(b - lm) + mx.exp(c - lm)), [[0,0],[2,0]], constant_values=zero))
    # A valid path ends on the last label or on the trailing blank; log-sum-exp both.
    batch_idx = mx.arange(batch_size)
    l1 = log_alpha[input_lengths - 1, batch_idx, 2 + target_lengths * 2 - 1]
    l2 = log_alpha[input_lengths - 1, batch_idx, 2 + target_lengths * 2 - 0]
    lm = mx.maximum(l1, l2)
    return -lm - mx.log(mx.exp(l1 - lm) + mx.exp(l2 - lm))
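# Worked example of the target layout (illustrative): for one row of targets
# [a, b] with blank = 0, the concatenate/pad/flatten above yields
# [0, a, 0, b, 0, a]. Its first five entries are the standard CTC extended
# sequence [blank, a, blank, b, blank]; the wrapped-around trailing `a` only
# pads the row to even length and is never read out by the final indexing.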
with mx.stream(mx.gpu):
    # Convert through NumPy; logits must be detached since it carries a grad.
    mx_logits = mx.array(logits.detach().numpy())
    mx_targets = mx.array(targets.numpy())
    mx_target_lengths = mx.array(target_lengths.numpy())
    mx_input_lengths = mx.array(input_lengths.numpy())
    at = []
    mx_ctc_loss_grad = mx.value_and_grad(lambda p, t, i, l: (mx_ctc_loss(mn.log_softmax(p, -1), t, i, l) / mx_target_lengths).mean())
    while sum(at) < 3:
        t1 = time.perf_counter()
        mlx_ctc, mlx_ctc_grad = mx_ctc_loss_grad(mx_logits, mx_targets, mx_input_lengths, mx_target_lengths)
        mx.eval(mlx_ctc, mlx_ctc_grad)  # force evaluation: MLX is lazy, so the work only happens here
        t2 = time.perf_counter()
        at.append(t2 - t1)
    at = [s for s in sorted(at) if s <= sum(at) / len(at)]
    mt = sum(at) / len(at)
    print(f'MLX time: {mt:.3f}s ({mt/tt*100:.2f}% of builtin) value={mlx_ctc.item():.5f}')
    print('Loss match', torch.allclose(builtin_ctc, torch.tensor(np.array(mlx_ctc.astype(mx.float32))), atol=atol))
    print('Grad match', torch.allclose(builtin_ctc_grad, torch.tensor(np.array(mlx_ctc_grad.astype(mx.float32))), atol=atol))
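# Note: mx.gpu assumes an Apple-silicon machine; elsewhere the same script can
# run with mx.stream(mx.cpu). If both implementations agree within atol, the
# two final checks print 'Loss match True' and 'Grad match True'.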