Install MLX LM and openai:
pip install mlx-lm openai
# install cudnn so we can use FlashAttention and run fast (optional)
# https://developer.nvidia.com/cudnn-downloads
# for me, CUDA 12 (run `nvcc --version`) running on Linux x86_64 Ubuntu 22.04
wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64/cuda-keyring_1.1-1_all.deb
sudo dpkg -i cuda-keyring_1.1-1_all.deb
sudo apt-get update
sudo apt-get -y install libcudnn9-dev-cuda-12
# "install" cudnn-frontend to ~/
git clone https://github.com/NVIDIA/cudnn-frontend.git
from datasets import load_dataset
fineweb100b = load_dataset("HuggingFaceFW/fineweb-edu", name="sample-100BT", split="train")
./train_gpt2cu \
-i "dev/data/fineweb10B/fineweb_train_*.bin" \
-j "dev/data/fineweb10B/fineweb_val_*.bin" \
-o log124M \
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA RTX A6000 On | 00000000:05:00.0 Off | Off |
GPTModel(
  (tok_emb): Embedding(50257, 768)
  (pos_emb): Embedding(256, 768)
  (drop_emb): Dropout(p=0.1, inplace=False)
  (trf_blocks): Sequential(
    (0): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): Linear(in_features=768, out_features=768, bias=False)
        (W_key): Linear(in_features=768, out_features=768, bias=False)
        (W_value): Linear(in_features=768, out_features=768, bias=False)
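The layer shapes in the printout above are enough to estimate where the parameters sit. The following is a minimal sketch that only does arithmetic on the shapes shown (vocabulary 50257, embedding width 768, context length 256, bias-free attention projections); the helper names are hypothetical and not part of any library.

```python
# Parameter counts read off the layer shapes in the printout above.
# Helper names are illustrative only.

def embedding_params(num_embeddings, embedding_dim):
    # An embedding table stores one vector per index.
    return num_embeddings * embedding_dim

def linear_params(in_features, out_features, bias):
    # Weight matrix plus an optional bias vector.
    return in_features * out_features + (out_features if bias else 0)

tok_emb = embedding_params(50257, 768)   # token embedding table
pos_emb = embedding_params(256, 768)     # positional embedding table
qkv = 3 * linear_params(768, 768, bias=False)  # W_query, W_key, W_value

print(f"tok_emb: {tok_emb:,}")
print(f"pos_emb: {pos_emb:,}")
print(f"W_query + W_key + W_value (one block): {qkv:,}")
```

The token embedding table alone is far larger than the three attention projections of a single block, which is worth keeping in mind when reading the total parameter count of a small GPT-style model.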
GPTModel(
  (tok_emb): Embedding(50257, 768)
  (pos_emb): Embedding(1024, 768)
  (drop_emb): Dropout(p=0.0, inplace=False)
  (trf_blocks): Sequential(
    (0): TransformerBlock(
      (att): MultiHeadAttention(
        (W_query): LinearWithLoRA(
          (linear): Linear(in_features=768, out_features=768, bias=True)
          (lora): LoRALayer()
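The printout shows each `Linear` wrapped in a `LinearWithLoRA`, but it does not reveal the LoRA rank. As a rough sketch, assuming the common formulation with two low-rank matrices `A` (in_features x rank) and `B` (rank x out_features) and a hypothetical rank of 16, the extra trainable parameters per wrapped layer can be estimated like this:

```python
# Rough size comparison for one LoRA-wrapped 768x768 linear layer.
# ASSUMPTION: rank=16 is a placeholder; the printout above does not show it.

def base_linear_params(in_features, out_features, bias=True):
    # Parameters of the wrapped (typically frozen) Linear layer.
    return in_features * out_features + (out_features if bias else 0)

def lora_extra_params(in_features, out_features, rank):
    # A has shape (in_features, rank), B has shape (rank, out_features).
    return rank * (in_features + out_features)

base = base_linear_params(768, 768, bias=True)
extra = lora_extra_params(768, 768, rank=16)

print(f"base linear params: {base:,}")
print(f"LoRA extra params:  {extra:,}")
print(f"ratio: {extra / base:.2%}")
```

With these assumed numbers the low-rank update adds only a few percent on top of the wrapped layer, which is the usual motivation for training the LoRA matrices instead of the full weight.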
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15 Driver Version: 550.54.15 CUDA Version: 12.4 |
|-----------------------------------------+------------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+========================+======================|
| 0 NVIDIA GeForce RTX 4070 ... Off | 00000000:02:00.0 Off | N/A |
| 30% 60C P2 283W / 285W | 2162MiB / 16376MiB | 100% Default |
| | | N/A |
# single
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |
|-----------------------------------------+----------------------+----------------------+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
| | | MIG M. |
|=========================================+======================+======================|
| 0 NVIDIA RTX A6000 On | 00000000:05:00.0 Off | Off |
--- General Information for device 0 ---
Name: NVIDIA GeForce RTX 4070 Ti SUPER
Compute capability: 8.9
Clock rate: 2610000
Device copy overlap: Enabled
Kernel execution timeout : Disabled
--- Memory Information for device 0 ---
Total global mem: 16852516864
Total constant Mem: 65536
Max mem pitch: 2147483647
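The raw numbers in the device report are easier to read after unit conversion. The sketch below assumes the fields come from CUDA's `cudaDeviceProp`, where memory sizes are reported in bytes and the clock rate in kHz; the helper names are hypothetical.

```python
# Convert the raw device-property values printed above into readable units.
# ASSUMPTION: memory fields are in bytes and the clock rate is in kHz,
# as in CUDA's cudaDeviceProp struct.

def bytes_to_gib(n_bytes):
    return n_bytes / 2**30

def bytes_to_kib(n_bytes):
    return n_bytes / 2**10

def khz_to_ghz(khz):
    return khz / 1e6

total_global_mem = 16852516864   # "Total global mem"
total_const_mem = 65536          # "Total constant Mem"
clock_rate_khz = 2610000         # "Clock rate"

print(f"global memory:   {bytes_to_gib(total_global_mem):.2f} GiB")
print(f"constant memory: {bytes_to_kib(total_const_mem):.0f} KiB")
print(f"clock rate:      {khz_to_ghz(clock_rate_khz):.2f} GHz")
```

The global memory works out to a little under 16 GiB, which is consistent with a 16 GB card once the binary versus decimal units are taken into account.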