# RNN Transducer in ~100 lines of NumPy code. Paper: https://arxiv.org/abs/1211.3711
from dataclasses import dataclass
import numpy as np
vocab = [None, 'a', 'b', 'c']
null_idx = 0
V = len(vocab)
@dataclass
class LSTMWeights:
    xi: np.ndarray  # (V-1, V)
    hi: np.ndarray  # (V, V)
    si: np.ndarray  # (V, V)
    xf: np.ndarray  # (V-1, V)
    hf: np.ndarray  # (V, V)
    sf: np.ndarray  # (V, V)
    xs: np.ndarray  # (V-1, V)
    hs: np.ndarray  # (V, V)
    xo: np.ndarray  # (V-1, V)
    ho: np.ndarray  # (V, V)
    sh: np.ndarray  # (V, V)
@dataclass
class RNNTransducerWeights:
    trans_f: LSTMWeights
    trans_b: LSTMWeights
    pred: LSTMWeights
def one_hot(x: list[int], embed_size: int) -> np.ndarray:
    # transforms a sequence of token indexes into a sequence of vectors
    # each vector has size embed_size and is 1 only at the index of the original number, 0 elsewhere
    # example: one_hot([1,2], 4) -> [[0,1,0,0], [0,0,1,0]]
    seq = np.zeros((len(x), embed_size))
    seq[np.arange(len(x)), x] = 1
    return seq
def softmax(x: np.ndarray) -> np.ndarray:
    # forces x to be a valid probability distribution
    # i.e. sum(x) is 1 and each x_i is between 0 and 1
    exp_x = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return exp_x / np.sum(exp_x, axis=-1, keepdims=True)
def sigmoid(x: np.ndarray) -> np.ndarray:
    # forces each element of x to be between 0 and 1
    return 1 / (1 + np.exp(-x))
def lstm(seq: np.ndarray, hidden_size: int, W: LSTMWeights) -> np.ndarray:
    # standard LSTM with peephole connections (W.si, W.sf, W.sh); returns the final hidden state
    input_gate, forget_gate, output_gate, state, hidden = [np.zeros(hidden_size) for _ in range(5)]
    for x in seq:
        input_gate = sigmoid(W.xi.T @ x + W.hi.T @ hidden + W.si.T @ state)
        forget_gate = sigmoid(W.xf.T @ x + W.hf.T @ hidden + W.sf.T @ state)
        state = forget_gate * state + input_gate * np.tanh(W.xs.T @ x + W.hs.T @ hidden)
        output_gate = sigmoid(W.xo.T @ x + W.ho.T @ hidden + W.sh.T @ state)
        hidden = output_gate * np.tanh(state)
    return hidden
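# Example: a minimal sketch of running the LSTM over a short one-hot sequence.
# `random_lstm_weights` is a hypothetical helper (the gist does not prescribe an
# initialization); any weights with the shapes documented on LSTMWeights work.
def random_lstm_weights(rng: np.random.Generator) -> LSTMWeights:
    names = ['xi', 'hi', 'si', 'xf', 'hf', 'sf', 'xs', 'hs', 'xo', 'ho', 'sh']
    # x* weights map the (V-1)-dim input; the rest map the V-dim hidden/state vectors
    shapes = {n: ((V - 1, V) if n.startswith('x') else (V, V)) for n in names}
    return LSTMWeights(**{n: rng.normal(scale=0.1, size=shapes[n]) for n in names})
# hidden = lstm(one_hot([0, 1], V - 1), V, random_lstm_weights(np.random.default_rng(0)))
# hidden has shape (V,), i.e. one value per vocab entry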
def prediction_network(seq_y: np.ndarray, W: LSTMWeights) -> np.ndarray:
    return lstm(seq_y, V, W)
def transcription_network(seq_x: np.ndarray, W_f: LSTMWeights, W_b: LSTMWeights) -> np.ndarray:
    # bidirectional: sum of a forward pass and a backward pass over the input
    return lstm(seq_x, V, W_f) + lstm(seq_x[::-1], V, W_b)
def joiner_network(seq_x: np.ndarray, seq_y: np.ndarray, W: RNNTransducerWeights) -> np.ndarray:
    # additive joiner: combine transcription and prediction hidden states into logits
    return transcription_network(seq_x, W.trans_f, W.trans_b) + prediction_network(seq_y, W.pred)
def rnn_transducer(seq_x: np.ndarray, seq_y: list[int], W: RNNTransducerWeights) -> np.ndarray:
    # Pr(k | seq_x, seq_y): distribution over the next output token (including null)
    # shift non-null token ids 1..V-1 down to 0..V-2 to fit the (V-1)-dim embedding
    seq_y_one_hot = one_hot([y - 1 for y in seq_y], len(vocab) - 1)
    logits = joiner_network(seq_x, seq_y_one_hot, W)
    return softmax(logits)
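# Example: a single forward pass, assuming the hypothetical random_lstm_weights
# helper above and a 2-frame, (V-1)-dim dummy input. The result is a length-V
# probability distribution over {null, 'a', 'b', 'c'}.
# rng = np.random.default_rng(0)
# W_demo = RNNTransducerWeights(trans_f=random_lstm_weights(rng),
#                               trans_b=random_lstm_weights(rng),
#                               pred=random_lstm_weights(rng))
# rnn_transducer(rng.normal(size=(2, V - 1)), [1, 2], W_demo)  # sums to 1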
@dataclass
class Hypothesis:
    seq: list[int]
    logp: float
def decode_beam_search(input_seq: np.ndarray, W: RNNTransducerWeights, beam_size: int = 2) -> str:
    # beam search decoding, following Algorithm 1 in the paper
    B = [Hypothesis([], 0)]
    for t in range(len(input_seq)):
        A = B
        B = []
        # prefix boosting: fold into Pr(y) the probability of reaching y by
        # extending any proper prefix y_hat that is also in A
        for y in A:
            boost_p = 0
            for y_hat in A:
                if len(y_hat.seq) < len(y.seq) and y_hat.seq == y.seq[:len(y_hat.seq)]:
                    y_hat_to_y_logp = y_hat.logp
                    for u in range(len(y_hat.seq), len(y.seq)):
                        log_probs = np.log(rnn_transducer(input_seq[:t + 1], y.seq[:u], W))
                        y_hat_to_y_logp += log_probs[y.seq[u]]
                    boost_p += np.exp(y_hat_to_y_logp)
            if boost_p:
                y.logp = np.log(np.exp(y.logp) + boost_p)
        # main decoding loop: expand the most probable hypothesis in A until B
        # holds beam_size hypotheses more probable than anything left in A
        y_star_idx, y_star = max(enumerate(A), key=lambda item: item[1].logp)
        while len([y for y in B if y.logp > y_star.logp]) < beam_size:
            del A[y_star_idx]
            log_probs = np.log(rnn_transducer(input_seq[:t + 1], y_star.seq, W))
            for k in range(len(log_probs)):
                if k == null_idx:
                    # emitting null finalizes this hypothesis for timestep t
                    B.append(Hypothesis(y_star.seq, y_star.logp + log_probs[k]))
                else:
                    A.append(Hypothesis(y_star.seq + [k], y_star.logp + log_probs[k]))
            y_star_idx, y_star = max(enumerate(A), key=lambda item: item[1].logp)
        # only keep the top `beam_size` elements of B
        B = sorted(B, key=lambda hyp: hyp.logp, reverse=True)[:beam_size]
    # take the best length-normalized hypothesis (guarding against the empty one)
    best_hyp = max(B, key=lambda hyp: hyp.logp / max(len(hyp.seq), 1))
    # de-tokenize
    return ''.join(vocab[idx] for idx in best_hyp.seq)
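# Example: an end-to-end beam-search decode on random features. A minimal
# sketch: with random (untrained) weights the decoded string is arbitrary; a
# real model would first be trained with the forward-backward procedure from
# the paper. Uses the hypothetical random_lstm_weights helper defined above.
# rng = np.random.default_rng(0)
# W = RNNTransducerWeights(*(random_lstm_weights(rng) for _ in range(3)))
# print(decode_beam_search(rng.normal(size=(4, V - 1)), W))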