I implemented a transformer from scratch, with different sampling methods, beam search, and caching. I built it while working through https://www.arena.education/.
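As a minimal sketch of what one of the sampling methods might look like (temperature scaling plus top-k filtering over next-token logits); the function name sample_next_token and its defaults are illustrative assumptions, not the actual code in this repo:

import torch as t

def sample_next_token(logits: t.Tensor, temperature: float = 1.0, top_k: int = 50) -> int:
    # logits: 1-D tensor of next-token scores, shape (vocab_size,)
    scaled = logits / temperature
    # keep only the top_k highest-scoring tokens before normalising
    topk_vals, topk_idx = t.topk(scaled, k=top_k)
    probs = t.softmax(topk_vals, dim=-1)
    # sample within the top-k set, then map back to a vocabulary index
    choice = t.multinomial(probs, num_samples=1)
    return int(topk_idx[choice])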
# %% This file should
import torch as t
from torch import nn
import torch.utils.data as data
from torchvision import datasets, transforms
from dataclasses import dataclass
from typing import Literal
from tqdm import tqdm
import matplotlib.pyplot as plt
from IPython import display
#lang typed/racket
(require typed/rackunit)
; Data definitions
(define-type ExprC (U NumC IdC AppC IfC LamTC StringC))
(struct NumC ([n : Real]) #:transparent)
(struct IdC ([s : Symbol]) #:transparent)
(struct AppC ([fun : ExprC] [arg : (Listof ExprC)]) #:transparent)
(struct IfC ([test : ExprC] [then : ExprC] [else : ExprC]) #:transparent)
(struct LamTC ([args : (Listof Symbol)] [arg-types : (Listof ty)] [body : ExprC]) #:transparent)