# Gist by @bquast, created April 22, 2023
library(matrixStats)  # provides rowVars(), used in layer_norm()
# Softmax function (subtract the max for numerical stability)
softmax <- function(x) {
  exp_x <- exp(x - max(x))
  exp_x / sum(exp_x)
}
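# Illustrative check on a toy vector (not part of the model): the softmax
# returns positive weights that sum to 1
softmax(c(1, 2, 3))       # ~0.09 0.24 0.67
sum(softmax(c(1, 2, 3)))  # 1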
# Scaled dot-product attention: softmax(Q K^T / sqrt(d_k)) V
scaled_dot_product_attention <- function(Q, K, V, mask = NULL) {
  dk <- ncol(K)
  scores <- Q %*% t(K) / sqrt(dk)
  # Masked positions get a large negative score so they vanish after softmax
  if (!is.null(mask)) {
    scores <- scores * mask + (1 - mask) * (-1e10)
  }
  # apply() returns one column per input row, so transpose back after softmax
  attention_weights <- t(apply(scores, 1, softmax))
  # Safeguard: replace any infinite weights with zero
  if (any(is.infinite(attention_weights))) {
    attention_weights[is.infinite(attention_weights)] <- 0
  }
  output <- attention_weights %*% V
  return(list(output = output, attention_weights = attention_weights))
}
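# Illustrative sanity check with toy shapes (3 positions, dimension 4);
# Q_toy, K_toy and V_toy are made-up inputs just for this example
Q_toy <- matrix(rnorm(3 * 4), 3, 4)
K_toy <- matrix(rnorm(3 * 4), 3, 4)
V_toy <- matrix(rnorm(3 * 4), 3, 4)
attn <- scaled_dot_product_attention(Q_toy, K_toy, V_toy)
dim(attn$output)                 # 3 x 4
rowSums(attn$attention_weights)  # each row of weights sums to 1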
# Multi-head attention: project Q, K, V, split into heads, attend, concatenate
multi_head_attention <- function(Q, K, V, d_model, num_heads, mask = NULL) {
  depth <- d_model / num_heads
  # Random projection weights (a trained model would learn these)
  WQ <- matrix(rnorm(d_model * d_model), d_model, d_model)
  WK <- matrix(rnorm(d_model * d_model), d_model, d_model)
  WV <- matrix(rnorm(d_model * d_model), d_model, d_model)
  Q <- Q %*% WQ
  K <- K %*% WK
  V <- V %*% WV
  # Split the projected matrices column-wise into one block per head
  Qs <- lapply(1:num_heads, function(i) Q[, ((i - 1) * depth + 1):(i * depth)])
  Ks <- lapply(1:num_heads, function(i) K[, ((i - 1) * depth + 1):(i * depth)])
  Vs <- lapply(1:num_heads, function(i) V[, ((i - 1) * depth + 1):(i * depth)])
  outputs <- lapply(1:num_heads, function(i) {
    scaled_dot_product_attention(Qs[[i]], Ks[[i]], Vs[[i]], mask)
  })
  # Concatenate the per-head outputs and apply the final output projection
  concat_attention <- do.call(cbind, lapply(outputs, function(x) x[[1]]))
  WO <- matrix(rnorm(d_model * d_model), d_model, d_model)
  output <- concat_attention %*% WO
  return(output)
}
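# Illustrative example with small, made-up dimensions: d_model = 8 split
# across 2 heads of depth 4, over 5 positions
x_small <- matrix(rnorm(5 * 8), 5, 8)
mha_out <- multi_head_attention(x_small, x_small, x_small, d_model = 8, num_heads = 2)
dim(mha_out)  # 5 x 8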
# Position-wise feed-forward network: ReLU(x W1 + b1) W2 + b2
feed_forward <- function(x, dff, d_model) {
  W1 <- matrix(rnorm(d_model * dff), d_model, dff)
  b1 <- matrix(rnorm(1 * dff), 1, dff)
  W2 <- matrix(rnorm(dff * d_model), dff, d_model)
  b2 <- matrix(rnorm(1 * d_model), 1, d_model)
  # pmax(..., 0) is the ReLU; the bias rows are repeated to match nrow(x)
  hidden <- pmax(x %*% W1 + matrix(rep(b1, nrow(x)), nrow(x), ncol(b1), byrow = TRUE), 0)
  output <- hidden %*% W2 + matrix(rep(b2, nrow(x)), nrow(x), ncol(b2), byrow = TRUE)
  return(output)
}
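# Illustrative example with made-up sizes: expand 8-dimensional inputs to a
# hidden width of 16, then project back to 8
ff_out <- feed_forward(matrix(rnorm(5 * 8), 5, 8), dff = 16, d_model = 8)
dim(ff_out)  # 5 x 8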
# Layer normalization: scale each row to zero mean and unit variance
layer_norm <- function(x, epsilon = 1e-6) {
  mean_x <- rowMeans(x)
  std_x <- sqrt(rowVars(x) + epsilon)
  # Column-major recycling applies the per-row statistics to each column
  norm_x <- (x - mean_x) / std_x
  return(norm_x)
}
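# Illustrative check on a toy matrix: after normalization every row has
# mean ~0 and standard deviation ~1
ln_out <- layer_norm(matrix(rnorm(4 * 6), 4, 6))
round(rowMeans(ln_out), 6)       # ~0 for each row
round(sqrt(rowVars(ln_out)), 2)  # ~1 for each row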
# Transformer layer: self-attention and feed-forward sublayers, each with a
# residual connection followed by layer normalization
transformer_layer <- function(x, d_model, num_heads, dff, mask = NULL) {
  attn_output <- multi_head_attention(x, x, x, d_model, num_heads, mask)
  x1 <- layer_norm(x + attn_output)
  ff_output <- feed_forward(x1, dff, d_model)
  x2 <- layer_norm(x1 + ff_output)
  return(x2)
}
# Example usage: 50 positions, model dimension 512, 8 heads, feed-forward width 2048
x <- matrix(rnorm(50 * 512), 50, 512)
d_model <- 512
num_heads <- 8
dff <- 2048
output <- transformer_layer(x, d_model, num_heads, dff)
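# The layer preserves the input shape: one d_model-sized vector per position
dim(output)  # 50 x 512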