Skip to content

Instantly share code, notes, and snippets.

@t-kalinowski
Last active January 31, 2024 16:25
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save t-kalinowski/62e9a1bbf8d670b712082c1765be4df4 to your computer and use it in GitHub Desktop.
Save t-kalinowski/62e9a1bbf8d670b712082c1765be4df4 to your computer and use it in GitHub Desktop.
LLaMA implemented in R with TensorFlow and Keras
## Setup
# Hide all CUDA devices so TensorFlow runs on CPU (the timings recorded at the
# end of this file are CPU timings).
Sys.setenv(CUDA_VISIBLE_DEVICES='')
# Don't warn when tensors are used as `[` indices; this file does that on
# purpose (e.g. slicing the rotary tables with a tensor `position`).
options(tensorflow.extract.warn_tensors_passed_asis = FALSE)
library(dplyr, warn.conflicts = FALSE)
library(purrr)
library(glue)
library(envir)
library(tensorflow)
library(tfautograph)
library(keras)
# Select the project's Python virtualenv. This must run before Python is
# initialized (e.g. by the numpy import below).
reticulate::use_virtualenv("./.venv", required = TRUE)
# Define shared helpers in an environment attached to the search path.
attach_eval({
# numpy without automatic conversion, so loaded arrays stay Python objects
np <- reticulate::import("numpy", convert = FALSE)
import_from(withr, with_options, local_options)
import_from(keras$layers, Dense)
# XLA op used below to write k/v and new tokens into preallocated buffers
import_from(tf$compiler$tf2xla$python$xla, dynamic_update_slice)
# like list(...), but every element is given a name
nlist <- \(...) rlang::dots_list(..., .named = TRUE)
# 0-based integer sequence: 0, 1, ..., x - 1
seq_len0 <- \(x) seq.int(from = 0L, length.out = x)
})
precompute_rotarty_freqs <- function(seqlen, feature_dim, theta = 10000) {
  # Build the rotary-embedding cos/sin lookup tables.
  #
  # seqlen:      number of positions (rows) to precompute.
  # feature_dim: per-head feature size; the output's last axis has
  #              2 * (feature_dim %/% 2) entries, so presumably even.
  # theta:       base of the geometric frequency progression.
  #
  # Returns list(cos =, sin =), each shaped (1, seqlen, 1, feature_dim):
  # every frequency is repeated twice along the last axis, and unit axes
  # are added so the tables broadcast over batch_size and n_heads.
  half_dim <- feature_dim %/% 2

  # exponents 0, 1/half_dim, ..., (half_dim - 1)/half_dim
  exponents <- tf$range(start = 0, limit = 1,
                        delta = 1 / half_dim,
                        dtype = tf$float32)
  # check that the float range produced exactly half_dim steps
  tf_assert(tf$size(exponents) == half_dim)
  inv_freq <- 1 / (theta ^ exponents)

  positions <- tf$range(seqlen, dtype = tf$float32)
  # outer product of positions and inverse frequencies: (seqlen, half_dim)
  angles <- tf$einsum('a,b->ab', positions, inv_freq)

  expand_table <- function(table) {
    doubled <- tf$`repeat`(table, 2L, axis = -1L)
    doubled[tf$newaxis, , tf$newaxis, ]
  }
  list(cos = expand_table(tf$cos(angles)),
       sin = expand_table(tf$sin(angles)))
}
apply_rotary_embedding <- function(x, freqs) {
  # Rotate feature pairs of `x` by position-dependent angles.
  #
  # x:     tensor whose last axis holds the features to rotate.
  # freqs: list(cos =, sin =) tables broadcastable against `x`, as built by
  #        precompute_rotarty_freqs().
  #
  # The last axis is split into the `::2` and `2::2` strided slices. They are
  # re-interleaved as (-second, first) pairs, restored to x's shape, and
  # scaled by sin, then added to x scaled by cos.
  first <- x[all_dims(), `::2`]
  second <- x[all_dims(), `2::2`]
  paired <- tf$stack(list(-second, first), axis = -1L)
  rotated <- tf$reshape(paired, tf$shape(x))
  x * freqs$cos + rotated * freqs$sin
}
make_mask <- function(seqlen, position_index = 0L, dtype = k_floatx()) {
  # Additive causal attention mask.
  #
  # Entry [i, j] is 0 when key position j is not after query position
  # i + position_index, and -Inf otherwise, so adding the mask to attention
  # scores blocks attention to later positions.
  # Returns shape (1, 1, seqlen, seqlen) to broadcast over batch and heads.
  steps <- tf$range(seqlen)
  query_pos <- steps[, tf$newaxis] + position_index
  key_pos <- steps[tf$newaxis, ]
  blocked <- query_pos < key_pos
  additive <- tf$where(blocked,
                       tf$constant(-Inf, dtype = dtype),
                       tf$constant(0, dtype = dtype))
  additive[tf$newaxis, tf$newaxis, , ]
}
RMSNorm(keras$layers$Layer) %py_class% {
  # Root-mean-square layer normalization with a learned scale `w`:
  #   output = x * rsqrt(mean(x^2 over the last axis) + eps) * w
  #
  # eps:        added to the mean square before rsqrt for numerical stability.
  # ...:        forwarded to keras$layers$Layer.
  # block_id:   chooses how `w` is initialized:
  #               NULL -> ones;
  #               >= 0 -> "7B/layers.{block_id}.{feeds_into}_norm.weight.npy";
  #               -1   -> "7B/norm.weight.npy" (final output norm).
  #             Any other value raises an error instead of silently leaving
  #             the initializer unset.
  # feeds_into: name of the sub-layer this norm precedes; used in the
  #             per-block weight file name.
  initialize <-
    function(eps = 1e-6, ..., block_id = NULL, feeds_into = NULL) {
      super$initialize(...)
      self$eps <- eps
      self$block_id <- block_id
      self$feeds_into <- feeds_into
    }

  build <- function(input_shape) {
    # `w` has size 1 on every axis except the last (feature) axis, so it
    # broadcasts over the leading axes, e.g. (1, 1, dim) for
    # (batch_size, seqlen, dim) inputs.
    w_shape <- rep(1L, length(input_shape))
    w_shape[length(input_shape)] <- as.integer(input_shape) |> tail(1L)

    import_from({self}, block_id, feeds_into)
    # Pretrained loaders read a .npy vector and prepend two unit axes; this
    # assumes rank-3 inputs so the loaded array matches w_shape.
    initializer <- if (is.null(self$block_id)) {
      "ones"
    } else if (block_id >= 0) {
      \(...) weights_path("7B/layers.{block_id}.{feeds_into}_norm.weight.npy") |>
        np$load() |> np$expand_dims(0:1)
    } else if (block_id == -1) {
      # final output norm, which is not part of a TransformerBlock
      \(...) weights_path("7B/norm.weight.npy") |>
        np$load() |> np$expand_dims(0:1)
    } else {
      stop("RMSNorm: `block_id` must be NULL, -1, or >= 0; got ",
           block_id, call. = FALSE)
    }
    self$w <- self$add_weight(shape = w_shape,
                              initializer = initializer,
                              trainable = TRUE)
  }

  rrms <- function(x) {
    # reciprocal root mean square along the last axis
    x |>
      tf$math$square() |>
      tf$reduce_mean(axis = -1L, keepdims = TRUE) |>
      tf$math$add(self$eps) |> # for numerical stability
      tf$math$rsqrt()
  }

  call <- function(x) {
    x * self$rrms(x) * self$w
  }
}
FeedForward(keras$layers$Layer) %py_class% {
  # SwiGLU feed-forward sub-layer: w2(silu(w1(x)) * w3(x)).
  #
  # hidden_dim:  nominal hidden width. Unless `multiple_of` is NULL, it is
  #              scaled by 2/3 and rounded up to a multiple of `multiple_of`
  #              (e.g. 4 * 4096 -> 11008 with the default 256).
  # multiple_of: rounding granularity for the hidden width, or NULL.
  # ...:         forwarded to keras$layers$Layer (e.g. `name`), as the other
  #              layers in this file do.
  # block_id:    if non-NULL, the w1/w2/w3 kernels are loaded (transposed)
  #              from "7B/layers.{block_id}.feed_forward.{name}.weight.npy";
  #              if NULL, kernel_initializer is left NULL (as in Attention).
  initialize <- function(hidden_dim, multiple_of = 256L, ..., block_id = NULL) {
    super$initialize(...)
    if (!is.null(multiple_of)) {
      multiple_of <- as.integer(multiple_of)
      hidden_dim <- as.integer(hidden_dim * (2 / 3))
      # round up to the next multiple of `multiple_of`, in integer arithmetic
      hidden_dim <- ((hidden_dim + multiple_of - 1L) %/% multiple_of) * multiple_of
    }
    self$hidden_dim <- hidden_dim
    self$block_id <- block_id
  }

  build <- function(input_shape) {
    output_dim <- input_shape |> as.integer() |> tail(1)

    # Kernel-initializer factory. It returns NULL unless pretrained weights
    # are selected, so it is always callable (a bare NULL here would fail
    # with "attempt to apply non-function").
    load_weight <- function(name) NULL
    if (!is.null(self$block_id)) {
      load_weight <- \(name) \(...) np$load(weights_path(
        "7B/layers.{self$block_id}.feed_forward.{name}.weight.npy"))$`T`
    }

    self$w1 <- Dense(self$hidden_dim, use_bias = FALSE,
                     kernel_initializer = load_weight("w1"))
    self$w2 <- Dense(output_dim, use_bias = FALSE,
                     kernel_initializer = load_weight("w2"))
    self$w3 <- Dense(self$hidden_dim, use_bias = FALSE,
                     kernel_initializer = load_weight("w3"))
    super$build(input_shape)
  }

  call <- function(x) {
    import_from({self}, w1, w2, w3)
    import_from(tf$nn, silu)
    # SwiGLU: silu-gated w1(x), times w3(x), projected back by w2
    x %>%
      { silu(w1(.)) * w3(.) } %>%
      w2()
  }
}
Attention(keras$layers$Layer) %py_class% {
  # Multi-head self-attention with rotary position embeddings and an
  # optional key/value cache for incremental decoding.
  #
  # head_size: features per head.
  # n_heads:   number of heads; the q/k/v/o projections each have
  #            n_heads * head_size units.
  # ...:       forwarded to keras$layers$Layer.
  # block_id:  if non-NULL, projection kernels are loaded (transposed) from
  #            "7B/layers.{block_id}.attention.{wq,wk,wv,wo}.weight.npy";
  #            if NULL, kernel_initializer is left NULL.
  initialize <- function(head_size, n_heads, ..., block_id = NULL) {
    super$initialize(...)
    self$head_size <- head_size
    self$n_heads <- n_heads

    load_weight <- if (is.null(block_id)) {
      function(name) NULL
    } else {
      \(name) \(...) np$load(weights_path(
        "7B/layers.{block_id}.attention.{name}.weight.npy"))$`T`
    }
    # Bias-free projection, optionally initialized from pretrained weight
    # `name`. Named so it does not shadow the imported `Dense`.
    make_projection <- function(name) keras$layers$Dense(
      units = n_heads * head_size,
      use_bias = FALSE,
      kernel_initializer = load_weight(name)
    )
    self$wq <- make_projection("wq")
    self$wk <- make_projection("wk")
    self$wv <- make_projection("wv")
    self$wo <- make_projection("wo")
  }

  # x:           (bsz, seqlen_q, n_features) input features.
  # freqs:       rotary cos/sin tables broadcastable against the
  #              (bsz, seqlen_q, n_heads, head_size) q/k tensors.
  # cache:       NULL, or list(k =, v =) of preallocated key/value buffers.
  # cache_index: position in the cache where this call's k/v are written;
  #              only used when `cache` is supplied.
  # mask:        optional additive mask added to the attention scores.
  # Returns the projected output, or list(output, updated_cache) when a
  # cache is supplied.
  call <- function(x, ...,
                   freqs = NULL,
                   cache = NULL,
                   cache_index = NULL,
                   mask = NULL) {
    c(batch_size, seqlen_q, n_features) %<-% tf$unstack(tf$shape(x))

    split_heads_shape <- c(batch_size, seqlen_q, self$n_heads, self$head_size)
    q <- x |> self$wq() |> tf$reshape(split_heads_shape)
    k <- x |> self$wk() |> tf$reshape(split_heads_shape)
    v <- x |> self$wv() |> tf$reshape(split_heads_shape)

    q %<>% apply_rotary_embedding(freqs) # (bsz, seqlen_q, n_heads, head_size)
    k %<>% apply_rotary_embedding(freqs) # (bsz, seqlen_q, n_heads, head_size)

    if (!is.null(cache)) {
      # Number of positions held in the cache after this write. Computed
      # only here so the no-cache path, where `cache_index` is NULL, never
      # evaluates `NULL + tensor`.
      seqlen_k <- seqlen_v <- cache_index + seqlen_q
      # append k,v to respective caches; fetch full k,v from cache
      cache$k %<>% dynamic_update_slice(k, c(0L, cache_index, 0L, 0L))
      cache$v %<>% dynamic_update_slice(v, c(0L, cache_index, 0L, 0L))
      k <- cache$k[, NA:seqlen_k, , ]
      v <- cache$v[, NA:seqlen_v, , ]
    }

    v <- tf$transpose(v, c(0L, 2L, 1L, 3L)) # (bsz, n_heads, seqlen_v, head_size)
    q <- tf$transpose(q, c(0L, 2L, 1L, 3L)) # (bsz, n_heads, seqlen_q, head_size)
    k <- tf$transpose(k, c(0L, 2L, 3L, 1L)) # (bsz, n_heads, head_size, seqlen_k)

    scores <- (q %*% k) / sqrt(self$head_size) # (bsz, n_heads, seqlen_q, seqlen_k)

    # apply causal mask, so the model can't "look ahead" during training
    if (!is.null(mask))
      scores %<>% { . + mask }

    scores <- tf$nn$softmax(scores, axis = -1L)

    # weight the values by the attention scores
    # scores (bsz, n_heads, seqlen_q, seqlen_k)
    # v      (bsz, n_heads, seqlen_v, head_size)
    output <- scores %*% v # (bsz, n_heads, seqlen_q, head_size)

    # merge the heads back into a single features axis, so the output can be
    # added to the residual stream in TransformerBlock
    output <- output |>
      tf$transpose(c(0L, 2L, 1L, 3L)) |> # (bsz, seqlen_q, n_heads, head_size)
      tf$reshape(c(batch_size, seqlen_q, # (bsz, seqlen_q, n_heads * head_size)
                   self$n_heads * self$head_size))

    # final output projection
    output <- self$wo(output) # (bsz, seqlen_q, n_heads * head_size)

    if (is.null(cache))
      output
    else
      list(output, cache)
  }
}
TransformerBlock(keras$layers$Layer) %py_class% {
  # One pre-norm decoder block:
  #   h   = x + Attention(RMSNorm(x))
  #   out = h + FeedForward(RMSNorm(h))
  #
  # attn_head_size, attn_n_heads: attention geometry; the feed-forward
  #   hidden width is derived as 4 * attn_head_size * attn_n_heads.
  # norm_eps: epsilon shared by both RMSNorm layers.
  # ...:      forwarded to keras$layers$Layer.
  # block_id: index of this block in the pretrained checkpoint (NULL for
  #           no pretrained weights); passed to every sub-layer.
  initialize <- function(attn_head_size, attn_n_heads,
                         norm_eps = k_epsilon(), ...,
                         block_id = NULL) {
    super$initialize(...)
    norm_for <- function(sublayer) {
      RMSNorm(eps = norm_eps, block_id = block_id, feeds_into = sublayer)
    }
    self$attention <- Attention(attn_head_size, attn_n_heads,
                                block_id = block_id)
    self$feed_forward <- FeedForward(
      hidden_dim = 4 * attn_head_size * attn_n_heads,
      block_id = block_id
    )
    self$attention_norm <- norm_for("attention")
    self$feed_forward_norm <- norm_for("ffn")
  }

  # x:     input features; extra `...` (freqs, mask, cache_index) go to
  #        Attention.
  # cache: optional key/value cache for this block's Attention.
  # Returns the block output, or list(output, updated_cache) with a cache.
  call <- function(x, ..., cache = NULL) {
    # attention sub-layer; with a cache, Attention returns
    # list(output, updated_cache)
    attn_out <- self$attention(self$attention_norm(x), ..., cache = cache)
    if (!is.null(cache)) {
      c(attn_out, cache) %<-% attn_out
    }
    h <- x + attn_out

    # feed-forward sub-layer, again with a residual connection
    out <- h + self$feed_forward(self$feed_forward_norm(h))

    if (is.null(cache)) out else list(out, cache)
  }
}
# Decoder-only transformer: token embedding -> stack of TransformerBlocks ->
# final RMSNorm -> projection onto the vocabulary. The embedding, the final
# norm and the output projection load pretrained "7B" weights via
# weights_path(); each block loads its own weights by block_id.
TransformerDecoder(keras$Model) %py_class% {
# vocab_size: number of tokens (embedding rows / output units).
# n_blocks:   number of TransformerBlocks; they get block_id 0..n_blocks-1.
# n_heads, head_size: attention geometry; model width is n_heads * head_size.
# norm_eps:   epsilon used by every RMSNorm.
initialize <- function(vocab_size, n_blocks, n_heads, head_size, norm_eps) {
super$initialize()
self$head_size <- head_size
self$n_heads <- n_heads
self$tok_embeddings <- keras$layers$Embedding(
input_dim = vocab_size,
output_dim = n_heads*head_size,
embeddings_initializer =
\(...) np$load(weights_path("7B/tok_embeddings.weight.npy")))
self$blocks <- lapply(seq_len0(n_blocks), function(block_id) {
TransformerBlock(attn_head_size = head_size,
attn_n_heads = n_heads,
norm_eps = norm_eps,
block_id = block_id)
})
# block_id = -1 makes RMSNorm load "7B/norm.weight.npy" (the final norm)
self$norm <- RMSNorm(block_id = -1, eps = norm_eps)
self$output_proj <- Dense(
vocab_size, use_bias = FALSE,
kernel_initializer = \(...)
np$load(weights_path("7B/output.weight.npy"))$`T`)
# Rotary cos/sin tables for positions 0..2047, each shaped
# (1, 2048, 1, head_size); 2048 is the maximum supported sequence length.
self$freqs <- precompute_rotarty_freqs(feature_dim = head_size,
seqlen = 2048L)
}
# Full-sequence forward pass without a key/value cache.
# tokens: (batch, seqlen) token ids.
# Returns output_proj scores for the last position only: (batch, vocab_size).
call <- function(tokens) {
c(bsz, seqlen) %<-% tf$unstack(tf$shape(tokens))
freqs <- self$freqs |> lapply(\(f) f[, NA:seqlen, , ])
mask <- make_mask(seqlen)
x <- tokens |>
self$tok_embeddings()
for (block in self$blocks)
x <- block(x, freqs = freqs, mask = mask)
# `-1` below selects the last position (python-style negative index);
# silence the warning about that for the rest of this call.
local_options(c(tensorflow.extract.warn_negatives_pythonic = FALSE))
x |>
self$norm() |>
_[, -1, ] |>
self$output_proj()
}
# Incremental forward pass that writes into and reads from per-block
# key/value caches.
# tokens:   (batch, seqlen) token ids. With position == 0 the whole prompt
#           seeds the cache; otherwise seqlen must be 1 (asserted below).
# cache:    list with one list(k =, v =) per block, as built by .make_cache().
# position: cache index where `tokens` are written; R numeric 0 for the
#           seeding pass, otherwise converted to an int32 tensor.
# Returns list(last-position output_proj scores, updated cache).
call_with_cache <- function(tokens, cache, position) {
c(batch_size, seqlen) %<-% tf$unstack(tf$shape(tokens))
# Sanity check: after the initial seeding of cache with the prompt, we
# should only be running inference on one token at a time.
tf_assert(position == 0 | seqlen == 1)
if(is.numeric(position) && position == 0L) {
# initial cache seeding
mask <- make_mask(seqlen)
freqs <- self$freqs |> lapply(\(f) f[, NA:seqlen, , ])
} else {
# inference with one token; no mask is needed because a single query may
# attend to every cached position
position %<>% as_tensor(dtype = "int32")
freqs <- self$freqs |> lapply(\(f) f[, position, , ])
mask <- NULL
}
blocks <- self$blocks
stopifnot(is.list(cache), length(cache) == length(blocks))
x <- tokens |>
self$tok_embeddings()
for (i in seq_along(blocks)) {
c(x, cache[[i]]) %<-% blocks[[i]](x, cache = cache[[i]],
cache_index = position,
freqs = freqs,
mask = mask)
}
local_options(c(tensorflow.extract.warn_negatives_pythonic = FALSE))
output <- x |>
self$norm() |>
_[,-1,] |>
self$output_proj()
list(output, cache)
}
# Allocate zeroed k/v caches for every block and a zero-filled int32 token
# buffer, copy the prompt into the start of the buffer, and run the seeding
# pass. `max_seqlen` is added to the prompt length and the total is capped
# at 2048 (the length of the rotary tables).
.make_cache <- function(prompt_tokens, max_seqlen = 2048L) {
c(batch_size, seqlen) %<-% tf$unstack(tf$shape(prompt_tokens))
import_from({self}, head_size, n_heads)
max_seqlen <- min(max_seqlen + seqlen, 2048L)
cache_shape <- c(batch_size, max_seqlen, n_heads, head_size)
cache <- lapply(seq_along(self$blocks), \(.) {
list(k = tf$zeros(cache_shape), v = tf$zeros(cache_shape))
})
tokens_with_preallocated_space <-
tf$zeros(c(batch_size, max_seqlen), dtype = "int32") |>
dynamic_update_slice(update = prompt_tokens, indices = c(0L, 0L))
# run first forward pass to seed cache with initial prompt
# returns list(tokens_with_preallocated_space, next_token_probs, cache);
# next_token_probs are raw output_proj scores (no softmax is applied)
c(tokens_with_preallocated_space,
self$call_with_cache(prompt_tokens, cache = cache, position = 0L))
}
# Default sampler: greedy argmax over the last axis, with a trailing axis
# added so one (int32) token per sequence can be written back.
private$sampler_fn <- \(logits) logits |>
tf$argmax(axis = -1L, output_type = "int32") |>
tf$expand_dims(-1L)
# Active binding: read to get the sampling function, assign to replace it.
sampler %<-active% function(fn) {
if(missing(fn))
private$sampler_fn
else
private$sampler_fn <- fn
}
# Autoregressively generate up to `max_len` new tokens after `prompt`
# (total length capped at 2048).
# prompt: a string / string tensor (tokenized with the global `tokenizer`)
#         or token ids; a missing batch axis is added.
# Generation stops early once any sequence in the batch samples token id 2
# (end-of-sequence). Returns the token ids, or detokenized strings when the
# prompt was given as strings.
generate <- function(prompt, max_len = 20L) {
max_len %<>% as_tensor("int32")
prompt %<>% as_tensor()
# accept either tokens or a string
if (prompt$dtype$name == "string") {
if(length(dim(prompt)) == 0) # ensure a batch dim
prompt %<>% .[tf$newaxis]
tokens <- tokenizer$tokenize(prompt)$to_tensor()
} else {
tokens <- prompt
if(length(dim(prompt)) == 1) # ensure a batch dim
tokens %<>% .[tf$newaxis, ]
}
c(batch_size, initial_prompt_len) %<-% tf$unstack(tf$shape(tokens))
max_seqlen <- min(max_len + initial_prompt_len, 2048L)
c(tokens, next_token_probs, cache) %<-% self$.make_cache(tokens, max_len)
# `i` is set up front so it is defined after the loop even if the loop
# runs zero times
i <- initial_prompt_len
autograph({
# enable `if` and `for` to accept tensors
for (i in tf$range(initial_prompt_len, max_seqlen, dtype = "int32")) {
next_token <- self$sampler(next_token_probs)
tokens %<>% dynamic_update_slice(next_token, c(0L, i))
if (any(next_token == 2L))
break # end-of-sequence token
c(next_token_probs, cache) %<-%
self$call_with_cache(next_token, cache, i)
}
})
# NOTE(review): `i` is normally a tensor here, so this index is presumably
# passed through python-style (keeping columns up to and including i);
# confirm.
tokens %<>% .[, NA:(i+1)] # drop unused preallocated space
if(prompt$dtype$name == "string")
# return string if supplied a string
tokenizer$detokenize(tokens)
else
tokens
}
}
# ---- load
# Resolve a file inside the local LLaMA weights directory.
#
# rel_path: path relative to `root`. It may contain glue `{...}` placeholders,
#           which are interpolated in the *caller's* environment (e.g.
#           "7B/layers.{block_id}.attention.wq.weight.npy").
# root:     weights directory. Defaults to the "llama.weights_root" option,
#           falling back to the original hard-coded location.
# Returns the normalized path. Errors if the file does not exist.
weights_path <- function(rel_path,
                         root = getOption(
                           "llama.weights_root",
                           "~/github/facebookresearch/llama/weights/LLaMA/"
                         )) {
  normalizePath(
    file.path(
      root,
      glue::glue(rel_path, .envir = parent.frame())
    ),
    mustWork = TRUE
  )
}
# Hyperparameters of the 7B checkpoint (dim, n_layers, n_heads, norm_eps).
params <- jsonlite::read_json(weights_path("7B/params.json"))
# Provides SentencepieceTokenizer.
tf_text <- reticulate::import("tensorflow_text")
tokenizer_path <- weights_path("tokenizer.model")
# Read the serialized SentencePiece model and close the file handle
# afterwards, instead of leaving it open until garbage collection.
tokenizer_model <- local({
  model_file <- tf$io$gfile$GFile(tokenizer_path, "rb")
  on.exit(model_file$close(), add = TRUE)
  model_file$read()
})
tokenizer <- tf_text$SentencepieceTokenizer(
  tokenizer_model,
  add_bos = TRUE, add_eos = FALSE
)
llama <- TransformerDecoder(vocab_size = tokenizer$vocab_size(),
                            n_blocks = params$n_layers,
                            n_heads = params$n_heads,
                            head_size = params$dim %/% params$n_heads,
                            norm_eps = params$norm_eps)
prompt <- "The best way to attract bees"

# Smoke test: tokenize the global `prompt` and let the global `llama` model
# generate with max_len = 17. Then detokenize the result and print it
# wrapped at 60 characters.
test_generate <- function() {
  prompt_tokens <- tokenizer$tokenize(prompt)
  generated <- llama$generate(prompt_tokens, as_tensor(17L))
  generated_text <- as.character(tokenizer$detokenize(generated))
  writeLines(strwrap(generated_text, 60))
}
test_generate()
## expected output with the argmax() sampler:
# The best way to attract bees to your garden is to plant a
# variety of flowers that bloom at different times.
# Timings on CPU:
print(system.time(test_generate()))
# user system elapsed
# 99.562 0.149 89.057
# Compile to XLA
llama$generate %<>% tf_function(jit_compile = TRUE)
# First call includes tracing time
print(system.time(test_generate()))
# user system elapsed
# 64.944 0.809 55.314
# Second call is pure graph mode (calls test_generate() again; there is no
# global `generate()` function)
print(system.time(test_generate()))
# user system elapsed
# 28.754 0.120 18.453
@t-kalinowski
Copy link
Author

For details and explanations, see the associated blog post: https://blogs.rstudio.com/ai/posts/2023-05-25-llama-tensorflow-keras/

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment