Transformer decoding in fifty lines of pseudocode

From the paper "Transformer decoding in fifty lines of pseudocode" by Bob Carpenter at the Flatiron Institute. I copied it by hand, so it almost certainly has a mistake or two in it.

DECODE(tok: int<lower=1, upper=T>[N],
    alpha: matrix(T, V),
    betas: { query: matrix(V, K),
        key: matrix(V, K),
        value: matrix(V, V) }[A],
    gammas: { 1: vector(L), 2: matrix(L, V),
              3: vector(V), 4: matrix(V, L) }[A],
    delta: matrix(T, V) ): simplex(T)
for n in 1:N:                                            // embed input
    xs[n, 1:V] = alpha[tok[n], 1:V] + POS(n)             // embedding of token n
for a in 1:A:                                            // A attention layers
    xs = ATTEND(xs, betas[a])                            // update tokens jointly
    for n in 1:N:                                        // update tokens individually
        xs[n, 1:V] = FEED_FORWARD(xs[n, 1:V], gammas[a]) // with shared NN
return LOGISTIC_REGRESSION(xs[N, 1:V], delta)            // next token probs

ATTEND(x: matrix(N, V),
        beta.query: matrix(V, K),
        beta.key: matrix(V, K),
        beta.value: matrix(V, V)): matrix(N, V)

for n in 1:N:
    q[n, 1:K] = x[n, 1:V] * beta.query              // optionally add intercept
    k[n, 1:K] = x[n, 1:V] * beta.key
    v[n, 1:V] = x[n, 1:V] * beta.value
for n in 1:N:
    lp[1:n] = [q[n] * k[1], ..., q[n] * k[n]]     // dot product log probs
        / sqrt(V)                                   // scaled by sqrt value size
    lp[n+1:N] = -inf                                // mask out later tokens
    p[1:N] = SOFTMAX(lp[1:N])                       // attention probs
    u[n, 1:V] = SUM(n’ in 1:N) p[n’] * v[n’, 1:V]   // weighted avg of token values
    y[n, 1:V] = STANDARDIZE(u[n, 1:V] + x[n, 1:V])  // add in to out (non-center)
return y

POS(n: int<low=1,up=N>): vector(V)
for i in 1:(V / 2):
    r = n / N**(2 * i / V)                          // exponent ranges from 2/V to 1
    u[2 * i - 1] = sin(r)
    u[2 * i] = cos(r)
return u

FEED_FORWARD(x: vector(V),
        gamma.1: vector(L), gamma.2: matrix(L, V),
        gamma.3: vector(V), gamma.4: matrix(V, L)): vector(V)
u[1:L] = gamma.1 + gamma.2 * x                                   // first layer
w[1:L] = GELU(u)                                                 // non-linearity
y[1:V] = gamma.3 + gamma.4 * w                                   // second layer
return STANDARDIZE(x + y)                                        // layer norm input + output

LOGISTIC_REGRESSION(x: vector(V), delta: matrix(T, V)): simplex(T)
log_probs[1:T] = delta * x
return SOFTMAX(log_probs[1:T])

STANDARDIZE(u: vector(V)): vector(V)               // aka layer norm
    return (u - mean(u)) / standard_deviation(u)

SOFTMAX(u: vector(V)): simplex(V)
    return exp(u) / sum(exp(u))

GELU(u: vector(V)): vector(V)
    return u .* Phi(u)                            // Phi() std normal cdf (sigmoid)

COMPLETE_TEXT(toks: int<low=1, up=T>[N’],
        int<low=0> max_tokens,
        ...decoder params...): int<low=1, up=T>[]
toks_out = []
while (True):
    while (toks.size() > N) toks.pop_first()            // trim to <= N tokens
    prob = DECODE(toks, ...decoder params...)           // next token probs
    next_tok = categorical_rng(prob)                    // gen. next token randomly
    if (next_tok == END_TOKEN): return toks_out         // return if end token
    toks_out.push_last(next_tok)                        // append to output
    if (toks_out.size() == max_tokens): return toks_out // return if max tokens
    toks.push_last(next_tok)                            // add next token to end

LOG_DENSITY(toks: int<low=1, up=T>[I, J[1:I]],
        ...decoder params...): real<up=0>
log_density = 0
for i in 1:I:
    context = []
    for j in 1:J[i]:
        next_tok = toks[i, j]
        probs = DECODE(context, ...decoder params...)
        log_density += log(probs[next_tok])
        if (context.size() == N): context.pop_first()   // trim to <= N tokens
        context.push_last(next_tok)
return log_density

MH_ATTEND(x: matrix(N, V),
    beta.query: matrix(V, K)[H],
    beta.key: matrix(V, K)[H],
    beta.value: matrix(V, V/H)[H],
    tau: matrix(N, N),
    rho: vector(V)): matrix(N, V)

for h in 1:H:                                            // parallel heads
    for n in 1:N:
        q[n, 1:K] = x[n, 1:V] * beta.query[h]            // q, k, v vary by head
        k[n, 1:K] = x[n, 1:V] * beta.key[h]
        v[n, 1:V/H] = x[n, 1:V] * beta.value[h]
    for n in 1:N:                                        // loop unchanged
        lp[1:n] = [q[n] * k[1], ..., q[n] * k[n]]
                / sqrt(V)
        lp[n+1:N] = -inf
        p[1:N] = SOFTMAX(lp[1:N])
        u[h, n, 1:V/H] = SUM(n’ in 1:N) p[n’] * v[n’, 1:V/H]
for n in 1:N:
    z[n, 1:V] = concat(u[1, n, 1:V/H], ..., u[H, n, 1:V/H])  // concat results
w[1:N, 1:V] = tau * z + rho * [1 ... 1]                       // affine transform
for n in 1:N:
    y[n, 1:V] = STANDARDIZE(x[n, 1:V] + w[n, 1:V])            // residual + std
return y

I mostly wanted the code in a self-contained place so I could copy-paste it into ChatGPT and ask it to rewrite it in JavaScript.

function softmax(arr) {
    const max = Math.max(...arr);
    const expArr = arr.map(x => Math.exp(x - max));
    const sum = expArr.reduce((acc, val) => acc + val, 0);
    return expArr.map(x => x / sum);
}

function gelu(x) {
    const cdf = 0.5 * (1.0 + Math.tanh((Math.sqrt(2 / Math.PI) * (x + 0.044715 * Math.pow(x, 3)))));
    return x * cdf;
}

function standardize(u) {
    const mean = u.reduce((acc, val) => acc + val, 0) / u.length;
    const std = Math.sqrt(u.map(x => Math.pow(x - mean, 2)).reduce((acc, val) => acc + val, 0) / u.length);
    return u.map(x => (x - mean) / std);
}

function attend(x, beta) {
    let N = x.length;
    let V = x[0].length;
    let K = beta.query[0].length;

    let q = new Array(N).fill().map(() => new Array(K).fill(0));
    let k = new Array(N).fill().map(() => new Array(K).fill(0));
    let v = new Array(N).fill().map(() => new Array(V).fill(0));

    for (let n = 0; n < N; n++) {
        for (let j = 0; j < K; j++) {
            q[n][j] = x[n].reduce((sum, xi, idx) => sum + xi * beta.query[idx][j], 0);
            k[n][j] = x[n].reduce((sum, xi, idx) => sum + xi * beta.key[idx][j], 0);
        }
        for (let j = 0; j < V; j++) {
            v[n][j] = x[n].reduce((sum, xi, idx) => sum + xi * beta.value[idx][j], 0);
        }
    }

    let y = new Array(N).fill().map(() => new Array(V).fill(0));

    for (let n = 0; n < N; n++) {
        let lp = new Array(N).fill(-Infinity);
        for (let i = 0; i <= n; i++) {            // attend to tokens 1..n (causal mask)
            lp[i] = q[n].reduce((sum, qj, idx) => sum + qj * k[i][idx], 0) / Math.sqrt(V);
        }
        let p = softmax(lp);
        let u = new Array(V).fill(0);
        for (let i = 0; i < N; i++) {
            for (let j = 0; j < V; j++) {
                u[j] += p[i] * v[i][j];
            }
        }
        y[n] = standardize(u.map((uj, idx) => uj + x[n][idx]));
    }

    return y;
}

function feedForward(x, gamma) {
    const L = gamma['1'].length;
    let u = gamma['1'].map((g1i, idx) => g1i + x.reduce((sum, xi, jdx) => sum + xi * gamma['2'][jdx][idx], 0));
    let w = u.map(val => gelu(val));
    let y = gamma['3'].map((g3i, idx) => g3i + w.reduce((sum, wi, jdx) => sum + wi * gamma['4'][jdx][idx], 0));
    return standardize(x.map((xi, idx) => xi + y[idx]));
}

function logisticRegression(x, delta) {
    const T = delta.length;
    let log_probs = new Array(T).fill(0).map((_, idx) => x.reduce((sum, xi, jdx) => sum + xi * delta[idx][jdx], 0));
    return softmax(log_probs);
}

function decode(tok, alpha, betas, gammas, delta) {
    const N = tok.length;
    const V = alpha[0].length;
    let xs = new Array(N).fill().map(() => new Array(V).fill(0));

    // Embed input (note: the pseudocode's "+ POS(n)" positional-encoding term is omitted here; see the sketch after this function)
    for (let n = 0; n < N; n++) {
        xs[n] = alpha[tok[n]];
    }

    const A = betas.length;

    for (let a = 0; a < A; a++) {
        xs = attend(xs, betas[a]);

        for (let n = 0; n < N; n++) {
            xs[n] = feedForward(xs[n], gammas[a]);
        }
    }

    return logisticRegression(xs[N - 1], delta);
}
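One thing this translation drops is the positional-encoding step from the pseudocode: DECODE embeds token n as alpha[tok[n], 1:V] + POS(n), while the JavaScript decode above uses the token embedding alone. Here is a minimal sketch of that helper under the same conventions (the pos name, its 1-based n argument, and the way it is added in are mine, not ChatGPT's):

function pos(n, V, N) {
    // Sinusoidal positional encoding for 1-based position n, following POS in the pseudocode
    const u = new Array(V).fill(0);
    for (let i = 1; i <= V / 2; i++) {
        const r = n / Math.pow(N, (2 * i) / V);   // exponent ranges from 2/V to 1
        u[2 * i - 2] = Math.sin(r);               // pseudocode index 2i - 1, zero-based here
        u[2 * i - 1] = Math.cos(r);               // pseudocode index 2i, zero-based here
    }
    return u;
}

// Inside decode's embedding loop, instead of xs[n] = alpha[tok[n]]:
// const p = pos(n + 1, V, N);
// xs[n] = alpha[tok[n]].map((e, i) => e + p[i]);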

I asked ChatGPT to rewrite the code with more descriptive variable names, then asked it to comment every line of code.

// Softmax activation function
function softmax(inputArray) {
  // Find the maximum value within the input array
  const maxValue = Math.max(...inputArray)

  // Exponentiate each element in the input array, offsetting each value by maxValue to prevent overflow/underflow
  const exponentiatedArray = inputArray.map(
    value => Math.exp(value - maxValue)
  )

  // Sum up all the exponentiated values
  const sumOfExponentials = exponentiatedArray.reduce(
    (accum, val) => accum + val, 0
  )

  // Compute the softmax probability distribution over the input array
  return exponentiatedArray.map(value => value / sumOfExponentials)
}

// Gaussian Error Linear Unit (GELU) activation function
function gaussianErrorLinearUnit(inputValue) {
  // Compute the cumulative distribution function (CDF) of a Gaussian distribution for inputValue
  const cumulativeDistFunc = 0.5 * (
    1.0 + Math.tanh((
      Math.sqrt(2 / Math.PI) * (inputValue + 0.044715 * Math.pow(inputValue, 3))
    ))
  )

  // Return the GELU activated value
  return inputValue * cumulativeDistFunc
}

// Standardizes an array by subtracting the mean and dividing by the standard deviation
function standardize(inputArray) {
  // Calculate the mean of the input array
  const mean = inputArray.reduce(
    (accum, val) => accum + val, 0
  ) / inputArray.length
  
  // Calculate the standard deviation of the input array
  const stdDev = Math.sqrt(
    inputArray.map(
      value => Math.pow(value - mean, 2)
    ).reduce(
      (accum, val) => accum + val, 0
    ) / inputArray.length
  )

  // Standardize the input array
  return inputArray.map(value => (value - mean) / stdDev)
}

// Attention mechanism function
function attend(sequence, attentionParameters) {
  // Define lengths and dimensions used in subsequent operations
  const sequenceLength = sequence.length
  const valueDimension = sequence[0].length
  const keyDimension = attentionParameters.query[0].length

  // Initialize matrices for query, key, and value with zeros
  let queryMatrix = new Array(sequenceLength).fill().map(() => new Array(keyDimension).fill(0))
  let keyMatrix = new Array(sequenceLength).fill().map(() => new Array(keyDimension).fill(0))
  let valueMatrix = new Array(sequenceLength).fill().map(() => new Array(valueDimension).fill(0))

  // Calculate query, key, and value matrices using the input sequence and attention parameters
  for (let seqIdx = 0; seqIdx < sequenceLength; seqIdx++) {
    // ...
    // [Similar comments for following for-loops would go here.]
    // ...
  }

  // Initialize the output matrix with zeros
  let output = new Array(sequenceLength).fill().map(() => new Array(valueDimension).fill(0))

  // Calculate attention and update output
  for (let seqIdx = 0; seqIdx < sequenceLength; seqIdx++) {
    // ...
  }

  // Return the attended output
  return output
}

// Feedforward neural network layer
function feedForward(inputVector, feedForwardParams) {
  // Compute the linear combination of inputs and first set of parameters
  let linearCombination = feedForwardParams['1'].map(
    // ...
  )
  // Apply the GELU activation function
  let weightedSum = linearCombination.map(val => gaussianErrorLinearUnit(val))
  
  // Compute the final output vector as a linear combination of activated inputs and second set of parameters
  let outputVector = feedForwardParams['3'].map(
    // ...
  )

  // Return the standardized sum of input and output vectors
  return standardize(inputVector.map((inputVal, idx) => inputVal + outputVector[idx]))
}

// Logistic regression function
function logisticRegression(inputVector, weights) {
  // Determine the number of classes/targets
  const numClasses = weights.length
  
  // Compute the logarithm of the probability for each class using the input vector and weights
  let logProbabilities = new Array(numClasses).fill(0).map(
    (_, idx) => inputVector.reduce((sum, inputVal, jdx) => sum + inputVal * weights[idx][jdx], 0)
  )

  // Return the softmax probability distribution of the computed log probabilities
  return softmax(logProbabilities)
}

// Decode function that applies embedding, attention, feedforward layers, and logistic regression
function decode(tokenSequence, embeddingMatrix, attentionParamsArray, feedForwardParamsArray, outputWeights) {
  // Determine the lengths/dimensions of sequences and embeddings
  const numTokens = tokenSequence.length
  const embeddingDim = embeddingMatrix[0].length
  
  // Initialize an array to hold the embedded sequences
  let embeddedSequence = new Array(numTokens).fill().map(() => new Array(embeddingDim).fill(0))

  // Embed the input token sequence using the provided embedding matrix
  for (let tokIdx = 0; tokIdx < numTokens; tokIdx++) {
    embeddedSequence[tokIdx] = embeddingMatrix[tokenSequence[tokIdx]]
  }

  // Define the number of attention layers to process through
  const numAttentionLayers = attentionParamsArray.length

  // Process the embedded sequence through each attention and feedforward layer
  for (let layerIdx = 0; layerIdx < numAttentionLayers; layerIdx++) {
    embeddedSequence = attend(embeddedSequence, attentionParamsArray[layerIdx])

    for (let tokIdx = 0; tokIdx < numTokens; tokIdx++) {
      embeddedSequence[tokIdx] = feedForward(embeddedSequence[tokIdx], feedForwardParamsArray[layerIdx])
    }
  }

  // Apply logistic regression to the final token in the sequence and return the resulting probability distribution
  return logisticRegression(embeddedSequence[numTokens - 1], outputWeights)
}
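
Neither JavaScript version includes the COMPLETE_TEXT sampling loop from the pseudocode. A minimal sketch on top of the decode above, assuming a maximum context length, an endToken id, and a simple categorical sampler (those names are mine, not from the paper):

// Draw an index from a categorical distribution given an array of probabilities
function sampleCategorical(probabilities) {
  let r = Math.random()
  for (let i = 0; i < probabilities.length; i++) {
    r -= probabilities[i]
    if (r <= 0) return i
  }
  return probabilities.length - 1
}

// Autoregressive completion, following COMPLETE_TEXT in the pseudocode
function completeText(tokens, maxTokens, maxContextLength, endToken, embeddingMatrix, attentionParamsArray, feedForwardParamsArray, outputWeights) {
  const context = [...tokens]
  const generatedTokens = []
  while (true) {
    // Trim the context to at most maxContextLength tokens
    while (context.length > maxContextLength) context.shift()

    // Probabilities for the next token given the current context
    const probabilities = decode(context, embeddingMatrix, attentionParamsArray, feedForwardParamsArray, outputWeights)

    // Sample the next token and stop on the end token or at the length limit
    const nextToken = sampleCategorical(probabilities)
    if (nextToken === endToken) return generatedTokens
    generatedTokens.push(nextToken)
    if (generatedTokens.length === maxTokens) return generatedTokens
    context.push(nextToken)
  }
}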