## Setup
# Session-wide configuration and imports used by the rest of the script.
# Turns off the `tensorflow.extract.warn_tensors_passed_asis` option.
options(tensorflow.extract.warn_tensors_passed_asis = FALSE)
library(dplyr, warn.conflicts = FALSE)
# Bind reticulate to the project-local virtualenv; `required = TRUE` makes a
# missing environment an error instead of a silent fallback.
reticulate::use_virtualenv("./.venv", required = TRUE)
# `convert = FALSE` keeps numpy results as Python objects rather than R arrays.
np <- reticulate::import("numpy", convert = FALSE)
# NOTE(review): `import_from()`, `tf` and `keras` are used from here on, but no
# library()/import that provides them is visible in this copy of the file --
# presumably envir, tensorflow and keras are attached elsewhere; confirm.
import_from(withr, with_options, local_options)
import_from(keras$layers, Dense)
import_from(tf$compiler$tf2xla$python$xla, dynamic_update_slice)
# Collect `...` into a list via rlang::dots_list() with `.named = TRUE`,
# so entries supplied without a name receive an automatic one.
nlist <- function(...) {
  rlang::dots_list(..., .named = TRUE)
}
# Zero-based counterpart of seq_len(): returns 0, 1, ..., x - 1
# (an empty vector when x is 0). Used to produce 0-based block ids.
# NOTE(review): the body was garbled in this copy (`\(x) = 0L, length.out = x)`);
# restored as a seq.int() call -- confirm against the original source.
seq_len0 <- \(x) seq.int(from = 0L, length.out = x)
# Precompute the cos/sin tables used for rotary position embeddings.
#
# Args:
#   seqlen:      number of positions to tabulate (positions 0 .. seqlen - 1).
#   feature_dim: size of the feature axis that will be rotated; one frequency
#                is generated per pair of features (feature_dim %/% 2 total).
#   theta:       base of the frequency progression.
#
# Returns:
#   list(cos = , sin = ), each with two singleton axes inserted so the tables
#   broadcast against (batch, seqlen, n_heads, head_size) tensors.
#
# NOTE: the misspelled name ("rotarty") is kept because callers use it.
precompute_rotarty_freqs <- function(seqlen, feature_dim, theta = 10000) {
  repeat_each_twice <- function(x) {
    tf$`repeat`(x, 2L, axis = -1L)
  }

  t <- tf$range(seqlen, dtype = tf$float32)

  # Exponents 0, 1/n, 2/n, ... (< 1) with n = feature_dim %/% 2.
  freqs <- tf$range(start = 0, limit = 1,
                    delta = 1 / (feature_dim %/% 2),
                    dtype = tf$float32)
  # A floating-point step can yield one element too many or too few; fail
  # loudly instead of silently producing mis-sized tables.
  tf_assert(tf$size(freqs) == feature_dim %/% 2)

  freqs <- 1 / (theta ^ freqs)

  # outer product; (seqlen, head_size/2)
  freqs <- tf$einsum('a,b->ab', t, freqs)

  # prep to recycle across head_size axis and
  # broadcast across batch_size and n_heads axes
  list(cos = tf$cos(freqs),
       sin = tf$sin(freqs)) |>
    lapply(repeat_each_twice) |>
    lapply(\(m) m[tf$newaxis, , tf$newaxis, ]) # (1, seqlen, 1, head_size)
}
# Apply rotary position embeddings to `x`.
#
# Args:
#   x:     tensor whose last axis holds the features to rotate.
#   freqs: list with `cos` and `sin` tensors that broadcast against `x`
#          (as produced by precompute_rotarty_freqs()).
#
# Returns:
#   x * cos + rotate_every_two(x) * sin, same shape as `x`.
apply_rotary_embedding <- function(x, freqs) {
  # Map every adjacent feature pair (a, b) on the last axis to (-b, a).
  rotate_every_two <- function(x) {
    # Strided slices of the last axis; with the package's default 1-based
    # extraction `::2` starts at the first feature and `2::2` at the second.
    x1 <- x[all_dims(), `::2`]
    x2 <- x[all_dims(), `2::2`]
    # Stack on a new trailing axis, then fold it back so the pairs interleave.
    x_ <- tf$stack(list(-x2, x1), axis = -1L)
    tf$reshape(x_, tf$shape(x))
  }

  (x * freqs$cos) + (rotate_every_two(x) * freqs$sin)
}
# Build an additive causal attention mask.
#
# Args:
#   seqlen:         side length of the square mask.
#   position_index: offset added to the row (query) index before comparing.
#   dtype:          dtype of the returned mask values.
#
# Returns:
#   Tensor of shape (1, 1, seqlen, seqlen) holding -Inf where the (offset) row
#   index is smaller than the column index and 0 elsewhere, so it can be added
#   to attention scores before the softmax.
make_mask <- function(seqlen, position_index = 0L, dtype = k_floatx()) {
  x <- tf$range(seqlen)
  i <- x[, tf$newaxis] + position_index  # row index, as a column vector
  j <- x[tf$newaxis, ]                   # column index, as a row vector

  mask <- tf$where(i < j,
                   tf$constant(-Inf, dtype = dtype),
                   tf$constant(0, dtype = dtype))

  # broadcast over batch and n_heads axes
  mask[tf$newaxis, tf$newaxis, , ] # (1, 1, seqlen_q, seqlen_q)
}
# Root-mean-square normalization layer with a learned per-feature scale.
#
# `block_id` / `feeds_into` select which pretrained scale vector to load:
#   block_id = NULL -> no pretrained weights
#   block_id >= 0   -> "7B/layers.{block_id}.{feeds_into}_norm.weight.npy"
#   block_id == -1  -> "7B/norm.weight.npy" (final output norm)
#
# NOTE(review): this block was reconstructed from a copy with missing lines.
# Restored pieces (closing braces, `super$initialize(...)`, the "ones" default
# initializer, and the final `tf$math$rsqrt()` step) are marked below --
# confirm against the original source.
RMSNorm(keras$layers$Layer) %py_class% {
  initialize <-
    function(eps = 1e-6, ..., block_id = NULL, feeds_into = NULL) {
      super$initialize(...)  # NOTE(review): restored
      self$eps <- eps
      self$block_id <- block_id
      self$feeds_into <- feeds_into
    }

  build <- function(input_shape) {
    # input_shape == (batch_size, seqlen, params$dim)
    # self$w will broadcast over batch_size and seqlen dims.
    # w_shape == (1, 1, params$dim)
    w_shape <- rep(1L, length(input_shape))
    w_shape[length(input_shape)] <- as.integer(input_shape) |> tail(1L)

    # Pick an initializer that loads the pretrained weights when
    # `block_id` (and `feeds_into`) were supplied. The bare names are needed
    # so the glue templates inside weights_path() can resolve them.
    import_from({self}, block_id, feeds_into)
    initializer <- if (is.null(self$block_id)) {
      # NOTE(review): this branch was empty in the damaged copy; a unit scale
      # is assumed as the untrained default.
      "ones"
    } else if (block_id >= 0) {
      \(...) weights_path("7B/layers.{block_id}.{feeds_into}_norm.weight.npy") |>
        np$load() |> np$expand_dims(0:1)
    } else if (block_id == -1) {
      # load weights for the final output norm, which is not part of a TransformerBlock
      \(...) weights_path("7B/norm.weight.npy") |>
        np$load() |> np$expand_dims(0:1)
    }

    self$w <- self$add_weight(shape = w_shape,
                              initializer = initializer,
                              trainable = TRUE)
  }

  rrms <- function(x) {
    # reciprocal root mean square along the last axis
    x %>%
      tf$math$square() %>%
      tf$reduce_mean(axis = -1L, keepdims = TRUE) %>%
      tf$math$add(self$eps) %>% # for numerical stability
      tf$math$rsqrt()           # NOTE(review): restored final step
  }

  call <- function(x) {
    x * self$rrms(x) * self$w
  }
}
# SwiGLU feed-forward layer: w2(silu(w1(x)) * w3(x)).
#
# Args (initialize):
#   hidden_dim:  nominal hidden width. Unless `multiple_of` is NULL it is
#                scaled by 2/3 and rounded up to a multiple of `multiple_of`.
#   multiple_of: rounding granularity for the hidden width, or NULL to use
#                `hidden_dim` as given.
#   block_id:    0-based transformer block index used to locate pretrained
#                weights; NULL means no pretrained weights are loaded.
#
# NOTE(review): this block was reconstructed from a copy with missing lines.
# Restored pieces are marked below -- confirm against the original source.
FeedForward(keras$layers$Layer) %py_class% {
  initialize <- function(hidden_dim, multiple_of = 256L, ..., block_id = NULL) {
    super$initialize(...)  # NOTE(review): restored

    if (!is.null(multiple_of)) {
      hidden_dim <- hidden_dim %>%
        { as.integer( . * (2/3)) } %>%
        { (. + multiple_of - 1) %/% multiple_of } %>%
        { . * multiple_of }
    }

    self$hidden_dim <- hidden_dim
    self$block_id <- block_id
  }

  build <- function(input_shape) {
    output_dim <- input_shape |> as.integer() |> tail(1)

    # `load_weight(name)` must always be callable because it is invoked for
    # every projection below. The damaged copy assigned a bare NULL here, which
    # would fail on `load_weight("w1")`; a function returning NULL matches the
    # convention used by the Attention layer in this file.
    if (is.null(self$block_id)) {
      load_weight <- function(name) NULL
    } else {
      # NOTE(review): the path template and the transpose were cut off in the
      # damaged copy. The template follows the naming of the norm weights
      # ("7B/layers.{block_id}.<module>..."); the transpose assumes the .npy
      # files store kernels as (out, in) -- confirm both.
      load_weight <- \(name) \(...) np$load(weights_path(
        "7B/layers.{self$block_id}.feed_forward.{name}.weight.npy"))$`T`
    }

    self$w1 <- Dense(self$hidden_dim, use_bias = FALSE,
                     kernel_initializer = load_weight("w1"))
    self$w2 <- Dense(output_dim, use_bias = FALSE,
                     kernel_initializer = load_weight("w2"))
    self$w3 <- Dense(self$hidden_dim, use_bias = FALSE,
                     kernel_initializer = load_weight("w3"))
  }

  call <- function(x) {
    import_from({self}, w1, w2, w3)
    import_from(tf$nn, silu)

    x %>%
      { silu(w1(.)) * w3(.) } %>% # SwiGLU
      w2()                        # NOTE(review): restored output projection
  }
}
# Multi-head self-attention with rotary position embeddings and an optional
# key/value cache for incremental decoding.
#
# call() always returns list(output, cache); `cache` is passed through
# unchanged (possibly NULL) when no cache was supplied.
#
# NOTE(review): this block was reconstructed from a copy with missing lines
# (closing braces, the `else` of the weight loader, the weight file path).
# Restored pieces are marked below -- confirm against the original source.
Attention(keras$layers$Layer) %py_class% {
  initialize <- function(head_size, n_heads, ..., block_id = NULL) {
    super$initialize(...)  # NOTE(review): restored

    self$head_size <- head_size
    self$n_heads <- n_heads

    if (is.null(block_id)) {
      load_weight <- function(name) NULL
    } else {
      # NOTE(review): the path template and the transpose were cut off in the
      # damaged copy. The template follows the naming of the norm weights
      # ("7B/layers.{block_id}.<module>..."); the transpose assumes the .npy
      # files store kernels as (out, in) -- confirm both.
      load_weight <- \(name) \(...) np$load(weights_path(
        "7B/layers.{block_id}.attention.{name}.weight.npy"))$`T`
    }

    # Local constructor for the four bias-free projections (shadows the
    # imported `Dense` inside this method only).
    Dense <- function(name) keras$layers$Dense(
      units = n_heads * head_size,
      use_bias = FALSE,
      kernel_initializer = load_weight(name)
    )

    self$wq <- Dense("wq")
    self$wk <- Dense("wk")
    self$wv <- Dense("wv")
    self$wo <- Dense("wo")
  }

  call <- function(x, ...,
                   freqs = NULL,
                   cache = NULL,
                   cache_index = NULL,
                   mask = NULL) {
    c(batch_size, seqlen_q, n_features) %<-% tf$unstack(tf$shape(x))

    split_heads_shape <- c(batch_size, seqlen_q, self$n_heads, self$head_size)

    q <- x |> self$wq() |> tf$reshape(split_heads_shape)
    k <- x |> self$wk() |> tf$reshape(split_heads_shape)
    v <- x |> self$wv() |> tf$reshape(split_heads_shape)

    q %<>% apply_rotary_embedding(freqs) # (bsz, seqlen_q, n_heads, head_size)
    k %<>% apply_rotary_embedding(freqs) # (bsz, seqlen_q, n_heads, head_size)

    if (!is.null(cache)) {
      # The cached lengths are only meaningful (and `cache_index` only
      # non-NULL) on the cached path, so compute them here rather than
      # unconditionally; the uncached path must not evaluate
      # `NULL + seqlen_q`.
      seqlen_k <- seqlen_v <- cache_index + seqlen_q

      # append k,v to respective caches; fetch full k,v from cache
      cache$k %<>% dynamic_update_slice(k, c(0L, cache_index, 0L, 0L))
      cache$v %<>% dynamic_update_slice(v, c(0L, cache_index, 0L, 0L))
      k <- cache$k[, NA:seqlen_k, , ]
      v <- cache$v[, NA:seqlen_v, , ]
    }

    v <- tf$transpose(v, c(0L, 2L, 1L, 3L)) # (bsz, n_heads, seqlen_v, head_size)
    q <- tf$transpose(q, c(0L, 2L, 1L, 3L)) # (bsz, n_heads, seqlen_q, head_size)
    k <- tf$transpose(k, c(0L, 2L, 3L, 1L)) # (bsz, n_heads, head_size, seqlen_k)

    scores <- (q %*% k) / sqrt(self$head_size) # (bsz, n_heads, seqlen_q, seqlen_k)

    # apply causal mask, so the model can't "look ahead" during training
    if (!is.null(mask))
      scores %<>% { . + mask }

    scores <- tf$nn$softmax(scores, axis = -1L)

    # adjust values tensor with attention scores
    # scores (bsz, n_heads, seqlen_q, seqlen_k)
    # v      (bsz, n_heads, seqlen_v, head_size)
    output <- scores %*% v # (bsz, n_heads, seqlen_q, head_size)

    # combine heads back into a single features dim,
    # so Attention output_shape==input_shape
    # (needed so that you can add residuals in TransformerBlock)
    output <- output |>
      tf$transpose(c(0L, 2L, 1L, 3L)) |> # (bsz, seqlen_q, n_heads, head_size)
      tf$reshape(c(batch_size, seqlen_q, # (bsz, seqlen_q, n_heads * head_size)
                   self$n_heads * self$head_size))

    # final trainable linear projection
    output <- self$wo(output) # (bsz, seqlen_q, n_heads * head_size)

    list(output, cache)
  }
}
# One transformer block: pre-norm attention and pre-norm SwiGLU feed-forward,
# each wrapped in a residual connection.
#
# call() returns the block output `x` when no cache is in use, otherwise
# list(x, cache) with the updated attention cache.
#
# NOTE(review): this block was reconstructed from a copy with missing lines.
# Restored pieces (closing braces, `super$initialize(...)`, and the
# `self$feed_forward()` pipe step) are marked below -- confirm against the
# original source.
TransformerBlock(keras$layers$Layer) %py_class% {
  initialize <- function(attn_head_size, attn_n_heads,
                         norm_eps = k_epsilon(), ...,
                         block_id = NULL) {
    super$initialize(...)  # NOTE(review): restored

    self$attention <- Attention(attn_head_size, attn_n_heads,
                                block_id = block_id)

    self$feed_forward <- FeedForward(
      hidden_dim = 4 * attn_head_size * attn_n_heads,
      block_id = block_id)

    self$attention_norm <- RMSNorm(eps = norm_eps, block_id = block_id,
                                   feeds_into = "attention")
    self$feed_forward_norm <- RMSNorm(eps = norm_eps, block_id = block_id,
                                      feeds_into = "ffn")
  }

  call <- function(x, ..., cache = NULL) {
    # norm and attention; `...` forwards freqs / mask / cache_index
    x2 <- x |>
      self$attention_norm() |>
      self$attention(..., cache = cache)

    # Attention always returns list(output, cache); unpack it
    c(x2, cache) %<-% x2

    x <- x + x2 # add residual

    # norm and swiglu projection
    x2 <- x %>%
      self$feed_forward_norm() %>%
      self$feed_forward()  # NOTE(review): restored

    x <- x + x2 # residual again

    if (is.null(cache)) x else list(x, cache)
  }
}
# Decoder model (subclass of keras$Model): token embedding, a list of
# TransformerBlock layers, a final RMSNorm and an output projection sized to
# `vocab_size`. Besides the plain forward pass (`call`) it defines a cached
# forward pass (`call_with_cache`), a cache allocator (`.make_cache`), a
# pluggable `sampler`, and a `generate()` loop.
#
# NOTE(review): this copy of the class is damaged -- closing braces are absent
# and several pipes / calls end without their final line. The gaps are flagged
# inline below; the code itself is left untouched and must be restored from
# the original source before it will parse.
TransformerDecoder(keras$Model) %py_class% {
# Create the sub-layers. Embedding weights are read from a .npy file through
# weights_path(); block ids are 0-based (seq_len0).
# NOTE(review): no `super$initialize()` call is visible here -- confirm.
initialize <- function(vocab_size, n_blocks, n_heads, head_size, norm_eps) {
self$head_size <- head_size
self$n_heads <- n_heads
self$tok_embeddings <- keras$layers$Embedding(
input_dim = vocab_size,
output_dim = n_heads*head_size,
embeddings_initializer =
\(...) np$load(weights_path("7B/tok_embeddings.weight.npy")))
self$blocks <- lapply(seq_len0(n_blocks), function(block_id) {
TransformerBlock(attn_head_size = head_size,
attn_n_heads = n_heads,
norm_eps = norm_eps,
block_id = block_id)
# NOTE(review): the `})` closing the lapply() callback appears to be missing.
# block_id = -1 selects the final output-norm weights in RMSNorm.
self$norm <- RMSNorm(block_id = -1, eps = norm_eps)
self$output_proj <- Dense(
vocab_size, use_bias = FALSE,
kernel_initializer = \(...)
# NOTE(review): the initializer body and the closing `)` of Dense() appear to
# be missing here.
# Rotary cos/sin tables for 2048 positions, sliced per call.
self$freqs <- precompute_rotarty_freqs(feature_dim = head_size,
seqlen = 2048L)
# NOTE(review): the closing `}` of initialize() appears to be missing.
# Uncached forward pass over a full token sequence.
call <- function(tokens) {
c(bsz, seqlen) %<-% tf$unstack(tf$shape(tokens))
# Keep only the first `seqlen` positions of the rotary tables.
freqs <- self$freqs |> lapply(\(f) f[, NA:seqlen, , ])
mask <- make_mask(seqlen)
x <- tokens |>
# NOTE(review): the right-hand side of this pipe is missing (presumably the
# token-embedding layer) -- confirm.
for (block in self$blocks)
x <- block(x, freqs = freqs, mask = mask)
# The option name suggests negative indices are read Python-style (from the
# end); the warning about that is disabled for the rest of this call.
local_options(c(tensorflow.extract.warn_negatives_pythonic = FALSE))
x |>
self$norm() |>
_[, -1, ] |>
# NOTE(review): the last pipe step (presumably the output projection) and the
# closing `}` of call() appear to be missing.
# Forward pass that reads and updates a per-block key/value cache.
# Returns list(output, cache).
call_with_cache <- function(tokens, cache, position) {
c(batch_size, seqlen) %<-% tf$unstack(tf$shape(tokens))
# Sanity check: after the initial seeding of cache with the prompt, we
# should only be running inference on one token at a time.
tf_assert(position == 0 | seqlen == 1)
if(is.numeric(position) && position == 0L) {
# initial cache seeding
mask <- make_mask(seqlen)
freqs <- self$freqs |> lapply(\(f) f[, NA:seqlen, , ])
} else {
# inference with one token
position %<>% as_tensor(dtype = "int32")
freqs <- self$freqs |> lapply(\(f) f[, position, , ])
mask <- NULL
# NOTE(review): the `}` closing the else branch appears to be missing.
blocks <- self$blocks
stopifnot(is.list(cache), length(cache) == length(blocks))
x <- tokens |>
# NOTE(review): the right-hand side of this pipe is missing (presumably the
# token-embedding layer) -- confirm.
for (i in seq_along(blocks)) {
# Each block returns list(x, cache) when given a cache.
c(x, cache[[i]]) %<-% blocks[[i]](x, cache = cache[[i]],
cache_index = position,
freqs = freqs,
mask = mask)
# NOTE(review): the `}` closing the for loop appears to be missing.
local_options(c(tensorflow.extract.warn_negatives_pythonic = FALSE))
output <- x |>
self$norm() |>
_[,-1,] |>
# NOTE(review): the last pipe step (presumably the output projection) is
# missing, as is the closing `}` after the return value below.
list(output, cache)
# Allocate zero-filled k/v caches (one list(k, v) per block) and run the first
# forward pass on the prompt.
.make_cache <- function(prompt_tokens, max_seqlen = 2048L) {
c(batch_size, seqlen) %<-% tf$unstack(tf$shape(prompt_tokens))
import_from({self}, head_size, n_heads)
# Cap the total length at 2048, the number of precomputed rotary positions.
max_seqlen <- min(max_seqlen + seqlen, 2048L)
cache_shape <- c(batch_size, max_seqlen, n_heads, head_size)
cache <- lapply(seq_along(self$blocks), \(.) {
list(k = tf$zeros(cache_shape), v = tf$zeros(cache_shape))
# NOTE(review): the `})` closing the lapply() callback appears to be missing.
# Zero-padded int32 token buffer with the prompt written at the start.
tokens_with_preallocated_space <-
tf$zeros(c(batch_size, max_seqlen), dtype = "int32") |>
dynamic_update_slice(update = prompt_tokens, indices = c(0L, 0L))
# run first forward pass to seed cache with initial prompt
# return (prompt_tokens, next_token_probs, cache)
# NOTE(review): the line below has an unbalanced `)`; the enclosing call that
# combines the token buffer with this result appears to be missing.
self$call_with_cache(prompt_tokens, cache = cache, position = 0L))
# Default sampler: argmax over the last axis, as int32.
private$sampler_fn <- \(logits) logits |>
tf$argmax(axis = -1L, output_type = "int32") |>
# NOTE(review): the last pipe step of the default sampler is missing.
# Active binding exposing the sampler function; assigning to it replaces
# `private$sampler_fn`.
# NOTE(review): only the setter line is visible; a getter branch and the
# closing `}` appear to be missing.
sampler %<-active% function(fn) {
private$sampler_fn <- fn
# Generate up to `max_len` new tokens after `prompt`, which may be a string
# tensor or a tensor of token ids.
generate <- function(prompt, max_len = 20L) {
max_len %<>% as_tensor("int32")
prompt %<>% as_tensor()
# accept either tokens or a string
if (prompt$dtype$name == "string") {
if(length(dim(prompt)) == 0) # ensure a batch dim
prompt %<>% .[tf$newaxis]
tokens <- tokenizer$tokenize(prompt)$to_tensor()
} else {
tokens <- prompt
if(length(dim(prompt)) == 1) # ensure a batch dim
tokens %<>% .[tf$newaxis, ]
# NOTE(review): the `}` closing the else branch appears to be missing.
c(batch_size, initial_prompt_len) %<-% tf$unstack(tf$shape(tokens))
max_seqlen <- min(max_len + initial_prompt_len, 2048L)
c(tokens, next_token_probs, cache) %<-% self$.make_cache(tokens, max_len)
i <- initial_prompt_len
# enable `if` and `for` to accept tensors
# NOTE(review): the construct this comment refers to (a wrapper around the
# loop below) is not visible in this copy -- confirm.
for (i in tf$range(initial_prompt_len, max_seqlen, dtype = "int32")) {
next_token <- self$sampler(next_token_probs)
# Write the sampled token into the preallocated buffer at position i.
tokens %<>% dynamic_update_slice(next_token, c(0L, i))
if (any(next_token == 2L))
break # end-of-sequence token
c(next_token_probs, cache) %<-%
self$call_with_cache(next_token, cache, i)
# NOTE(review): the `}` closing the for loop appears to be missing.
tokens %<>% .[, NA:(i+1)] # drop unused preallocated space
if(prompt$dtype$name == "string")
# return string if supplied a string
# NOTE(review): the return expressions of generate() and the closing braces of
# generate() and of the class appear to be missing.
# ---- load
# Resolve a weights-file path template to a normalized path that must exist.
#
# Args:
#   rel_path: path template; glue `{...}` expressions in it are evaluated in
#             the caller's frame, so callers can write e.g.
#             "7B/layers.{block_id}.ffn_norm.weight.npy".
#
# Returns:
#   The normalized path as a string; errors if the file does not exist.
#
# NOTE(review): this copy had a dangling `mustWork = TRUE` and no closing
# brace. It is restored here as the `mustWork` argument of normalizePath().
# The original may also have prefixed a weights directory (e.g. through
# file.path()) that is not visible in this copy; as written, paths resolve
# relative to the working directory -- confirm.
weights_path <- function(rel_path) {
  path <- glue::glue(rel_path, .envir = parent.frame())
  normalizePath(path, mustWork = TRUE)
}
# Model hyperparameters stored next to the 7B weights.
params <- jsonlite::read_json(weights_path("7B/params.json"))

# SentencePiece tokenizer; the serialized model is read as raw bytes.
tf_text <- reticulate::import("tensorflow_text")
tokenizer_path <- weights_path("tokenizer.model")
# NOTE(review): this call was left unclosed with a trailing comma in the
# damaged copy, so a further argument line may have been lost. It is closed
# here with only the arguments that are visible -- confirm.
tokenizer <- tf_text$SentencepieceTokenizer(
  tf$io$gfile$GFile(tokenizer_path, "rb")$read(),
  add_bos = TRUE, add_eos = FALSE
)

# Instantiate the model; pretrained weights are pulled in by the layers'
# initializers.
llama <- TransformerDecoder(vocab_size = tokenizer$vocab_size(),
                            n_blocks = params$n_layers,
                            n_heads = params$n_heads,
                            head_size = params$dim %/% params$n_heads,
                            norm_eps = params$norm_eps)

prompt <- "The best way to attract bees"
# Smoke test: tokenize the global `prompt`, generate with max_len = 17,
# detokenize, and print the text wrapped at 60 columns.
# Relies on the globals `prompt`, `tokenizer` and `llama`.
# (The closing brace was missing in the damaged copy and is restored here.)
test_generate <- function() {
  prompt |>
    tokenizer$tokenize() |>
    llama$generate(as_tensor(17L)) |>
    tokenizer$detokenize() |>
    as.character() |>
    strwrap(60) |> writeLines()
}
## expected output with the argmax() sampler:
# The best way to attract bees to your garden is to plant a
# variety of flowers that bloom at different times.
# Timings on CPU:
# user system elapsed
# 99.562 0.149 89.057
# Wrap `generate` in a tf_function with jit_compile = TRUE and store the
# wrapped version back on the model.
llama$generate <- tf_function(llama$generate, jit_compile = TRUE)
# First call includes tracing time
# user system elapsed
# 64.944 0.809 55.314
# Second call is pure graph mode
# user system elapsed
# 28.754 0.120 18.453
