Skip to content

Instantly share code, notes, and snippets.

@lorenzosinisi
Last active May 6, 2023 15:57
Show Gist options
  • Star 2 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lorenzosinisi/bb928554d665bdc53aada98c3710b0d5 to your computer and use it in GitHub Desktop.
Save lorenzosinisi/bb928554d665bdc53aada98c3710b0d5 to your computer and use it in GitHub Desktop.

Transformer in Elixir – highly experimental WIP

This code is still in development and not yet finished. It probably won't work yet, but I am using it to learn how to build a transformer from scratch.

This gist is shared to help with this tweet https://twitter.com/LorenzoSinisi/status/1652756858459881473

MiniGPT - Elixir

# Notebook dependencies: Nx for tensors, EXLA as the compiled backend (set as
# Nx's default below), Req to download the corpus when no local copy exists.
# kino_bumblebee is not called directly in the code shown here.
# NOTE(review): if `Axon` is referenced anywhere in this notebook, it is not
# declared here; it presumably arrives transitively via kino_bumblebee -> bumblebee.
# Confirm, or list it explicitly.
Mix.install(
  [
    {:nx, "~> 0.5.3"},
    {:req, "~> 0.3.6"},
    {:kino_bumblebee, "~> 0.3.0"},
    {:exla, "~> 0.5.1"}
  ],
  # Run every Nx operation on EXLA unless a backend is chosen explicitly.
  config: [nx: [default_backend: EXLA.Backend]]
)

Load tinyshakespeare/input.txt

# Load the tinyshakespeare corpus into `text` (a UTF-8 string). A local
# ./input.txt is preferred; otherwise the file is downloaded from GitHub.
file_path = Path.absname("./input.txt")

input_url =
  "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"

text =
  if File.exists?(file_path) do
    IO.puts("File loaded from memory: #{file_path}")
    File.read!(file_path)
  else
    IO.puts("File loaded from git: #{input_url}")
    # Req.get!/1 returns a %Req.Response{} struct; the corpus is its body.
    # Without `.body` the tokenizer would receive the struct instead of a string.
    Req.get!(input_url).body
  end
defmodule Minidecoder do
  @moduledoc """
  Character-level tokenizer built when this module is compiled.

  The vocabulary is the sorted set of distinct codepoints in the script-level
  `text` binding. Ids are assigned in that sorted order, starting at 0.
  """

  # `text` is read from the enclosing script scope while the module body compiles.
  @chars text |> String.codepoints() |> Enum.uniq() |> Enum.sort()
  @vocab_size length(@chars)
  @stoi @chars |> Enum.with_index() |> Map.new()
  @itos Map.new(@stoi, fn {ch, id} -> {id, ch} end)

  @doc "Returns the number of distinct characters (token ids) in the vocabulary."
  def vocab_size, do: @vocab_size

  @doc """
  Encodes a string into a list of integer token ids.

  Raises `KeyError` if `text` contains a character that is not in the vocabulary.
  """
  def encode(text) do
    text |> String.codepoints() |> Enum.map(&Map.fetch!(@stoi, &1))
  end

  @doc """
  Decodes a list of integer token ids back into a string.

  Raises `KeyError` for an id outside `0..vocab_size() - 1`.
  """
  def decode(ids) do
    Enum.map_join(ids, &Map.fetch!(@itos, &1))
  end

  @doc "Encodes `text` and wraps the resulting ids in an `Nx` tensor."
  def tensor(text) do
    text |> encode() |> Nx.tensor()
  end
end

# Tokenize the whole corpus into a 1-D tensor of character ids.
data = text |> Minidecoder.encode() |> Nx.tensor()

Use 90% of the data for training and 10% for validation

# Split point: the first ~90% of the token ids become training data.
n = round(0.9 * Nx.size(data))
# Training split: ids at positions [0, n).
train_data = Nx.slice(data, [0], [n])
# Validation split: ids at positions [n, size), i.e. the remaining ~10%.
val_data = Nx.slice(data, [n], [Nx.size(data) - n])
{train_data, val_data}
block_size = 8
# x holds the first `block_size` training ids. y holds the first
# `block_size + 1`, so y[t + 1] is the id that follows the context x[0..t].
x = Nx.slice(train_data, [0], [block_size])
y = Nx.slice(train_data, [0], [block_size + 1])

# Pair every prefix of x (as an inspected list) with the id that follows it.
# Enum.map already yields the pairs in prefix order, so no map round trip is
# needed; that round trip relied on map key ordering and would merge equal keys.
Enum.map(0..(block_size - 1), fn t ->
  context = Nx.slice(x, [0], [t + 1])
  target = Nx.slice(y, [t + 1], [1])

  {context |> Nx.to_list() |> inspect(charlists: :as_lists),
   target |> Nx.to_list() |> inspect(charlists: :as_lists)}
end)
batch_size = 4
block_size = 8

# Samples `batch_size` random windows from the chosen split. Returns
# {inputs, targets}, both stacked to {batch_size, block_size}; targets are the
# inputs shifted one position to the right. Any split other than :train
# draws from the validation data.
get_batch = fn split ->
  source = if split == :train, do: train_data, else: val_data

  # Random window starts, chosen so that start + 1 + block_size <= size.
  starts =
    {batch_size}
    |> Nx.random_uniform(0, Nx.size(source) - block_size)
    |> Nx.to_list()

  inputs = starts |> Enum.map(&Nx.slice(source, [&1], [block_size])) |> Nx.stack()
  targets = starts |> Enum.map(&Nx.slice(source, [&1 + 1], [block_size])) |> Nx.stack()
  {inputs, targets}
end

get_batch.(:train)
defmodule BigramLanguageModel do
  @moduledoc """
  Minimal bigram language model.

  The "model" is a plain `{vocab_size, vocab_size}` tensor (see `create/1`).
  `forward/3` looks up row `i` for each input id `i` and uses that row as the
  logits of the next token.
  """

  defstruct [:params, :state, :opts]

  @doc "Returns an `{x, y}` tensor of f32 zeros."
  def zeroes(x, y) do
    Nx.broadcast(0.0, {x, y})
  end

  @doc "Creates a `{vocab_size, vocab_size}` table of `Nx.random_uniform/1` values."
  def create(vocab_size) do
    Nx.random_uniform({vocab_size, vocab_size})
  end

  @doc """
  Keeps only the logits of the last position.

  Takes `{batch, seq, vocab}` logits and returns `{batch, 1, vocab}`.
  """
  def slice_last_token_logits(logits) do
    {batch_size, seq_length, vocab_size} = logits.shape
    Nx.slice(logits, [0, seq_length - 1, 0], [batch_size, 1, vocab_size])
  end

  @doc """
  Computes the softmax of `t` along `axis`.

  The per-slice maximum is subtracted before `exp` so that large logits do not
  overflow to infinity/NaN. Mathematically, the result is unchanged.
  """
  def softmax(t, axis) do
    shifted = Nx.subtract(t, Nx.reduce_max(t, axes: [axis], keep_axes: true))
    exp_t = Nx.exp(shifted)
    Nx.divide(exp_t, Nx.sum(exp_t, axes: [axis], keep_axes: true))
  end

  @doc """
  Greedily appends `max_new_tokens` ids to every row of `idx`.

  `idx` is `{batch, seq}`; the result is `{batch, seq + max_new_tokens}`.
  Returns `idx` unchanged when `max_new_tokens < 1`.
  """
  def generate(model, idx, max_new_tokens) do
    # The `//1` step makes the range empty for max_new_tokens < 1. A plain
    # 1..0 counts down and would run the loop twice.
    Enum.reduce(1..max_new_tokens//1, idx, fn _step, idx ->
      {logits, _loss} = forward(model, idx)
      probs = logits |> slice_last_token_logits() |> softmax(-1)
      # argmax over {batch, 1, vocab} yields {batch, 1}, ready to append.
      Nx.concatenate([idx, Nx.argmax(probs, axis: -1)], axis: -1)
    end)
  end

  @doc """
  Runs the model on `idx` (`{batch, seq}` ids).

  Returns `{logits, loss}`. `logits` is `{batch, seq, vocab}`. `loss` is `nil`
  without targets; otherwise it is the mean cross-entropy (from logits) of the
  `{batch, seq}` integer `targets`.
  """
  def forward(model, idx, targets \\ nil)

  def forward(model, idx, nil), do: {Nx.take(model, idx), nil}

  def forward(model, idx, targets) do
    logits = Nx.take(model, idx)
    {b, t, c} = Nx.shape(logits)
    flat_logits = Nx.reshape(logits, {b * t, c})
    # take_along_axis needs indices of the same rank as the logits: {b * t, 1}.
    flat_targets = Nx.reshape(targets, {b * t, 1})

    loss =
      flat_logits
      |> log_softmax(-1)
      |> Nx.take_along_axis(flat_targets, axis: 1)
      |> Nx.mean()
      |> Nx.negate()

    {logits, loss}
  end

  # Numerically stable log-softmax along `axis`.
  defp log_softmax(t, axis) do
    shifted = Nx.subtract(t, Nx.reduce_max(t, axes: [axis], keep_axes: true))
    Nx.subtract(shifted, Nx.log(Nx.sum(Nx.exp(shifted), axes: [axis], keep_axes: true)))
  end
end
# Size the bigram table from the corpus rather than hard-coding it, so every
# token id produced by Minidecoder has a row.
vocab_size = text |> String.codepoints() |> Enum.uniq() |> length()
model = BigramLanguageModel.create(vocab_size)

{xb, yb} = get_batch.(:train)
{_logits, loss} = BigramLanguageModel.forward(model, xb, yb)
IO.inspect(loss, label: "loss")

max_new_tokens = 1
# One list of ids per batch row; each row is an independent sequence.
result = model |> BigramLanguageModel.generate(xb, max_new_tokens) |> Nx.to_list()
# Decode rows separately; flattening first would splice the rows into one string.
Enum.map(result, &Minidecoder.decode/1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment