Skip to content

Instantly share code, notes, and snippets.

@masuidrive
Last active August 19, 2023 07:55
Show Gist options
  • Star 6 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save masuidrive/d810f2b4c0b52b041a83f5c3a80f0289 to your computer and use it in GitHub Desktop.
Save masuidrive/d810f2b4c0b52b041a83f5c3a80f0289 to your computer and use it in GitHub Desktop.
/*
https://github.com/karpathy/llama2.c/blob/master/run.c
GPT-4による解説
このプログラムは、Transformerネットワークを実装し、トークン化されたテキスト入力から次の最も適したトークンを予測します。具体的には以下のようになります:
先頭の部分は、TransformerWeightsとRunStateという2つのデータ構造とそれらの関連するメモリの管理を含みます。
Configという構造体は、トランスフォーマーネットワークのパラメータを保持します。
次に、指定されたチェックポイントファイルから重みを初期化する関数があります。この関数は、チェックポイントファイルからトランスフォーマーネットワークの重みを読み込み、適切に配置します。
さらに、各種ニューラルネットワークのブロック(関数)が存在します:
- 'accum': ベクトル'a'にベクトル'b'の要素を加えます。
- 'rmsnorm': RMS正規化を行います。
- 'softmax': ソフトマックスを適用します。
- 'matmul': 行列の積算を行います。
'transformer'関数は、全てのレイヤーを通じてトークンをプッシュします。これは、各トークンに対してアテンションとフィードフォワードネットワークを通じて情報を伝播します。
'main'関数では、特定の条件(温度とステップ数)で指定されたチェックポイントファイルを基に入力から次のトークンをサンプリングします。ここで、サンプリングは確率的(温度が0でない場合)または確定的(温度が0の場合)に行われます。
最終的には、このモデルをいくつのトークンを処理できるか、単位時間におけるトークン処理数(トークン/秒)を出力します。
*/
/* https://github.com/karpathy/llama2.c/blob/master/run.c */
/*
Inference for Llama-2 Transformer model in pure C.
Example compile: (see README for more details)
$ gcc -O3 -o run run.c -lm
Then run with:
$ ./run
*/
//各々のライブラリをインポートしています。
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <math.h>
#include <string.h>
#include <unistd.h>
#include <fcntl.h>
#include <sys/mman.h>
// TransformerとRunState構造体、関連するメモリ管理
typedef struct {
int dim; // transformerの次元
int hidden_dim; // ffnレイヤー用
int n_layers; // レイヤーの数
int n_heads; // query headの数
int n_kv_heads; // キー/値 headの数 (multiqueryのためにquery headよりも少なくてもOK)
int vocab_size; // 語彙のサイズ、通常は256 (バイトレベル)
int seq_len; // 最大シーケンス長
} Config;
typedef struct {
// トークン埋め込みテーブル
float* token_embedding_table; // (語彙サイズ, 次元)
// rmsnormsのための重み
float* rms_att_weight; // (レイヤー, 次元) rmsnormの重み
float* rms_ffn_weight; // (レイヤー, 次元)
// matmulsのための重み
float* wq; // (レイヤー, 次元, 次元)
float* wk; // (レイヤー, 次元, 次元)
float* wv; // (レイヤー, 次元, 次元)
float* wo; // (レイヤー, 次元, 次元)
// ffnのための重み
float* w1; // (レイヤー, hidden_dim, 次元)
float* w2; // (レイヤー, 次元, hidden_dim)
float* w3; // (レイヤー, hidden_dim, 次元)
// 最終rmsnorm
float* rms_final_weight; // (次元,)
// RoPE相対位置埋め込みのためのfreq_cis
float* freq_cis_real; // (シーケンス長, 次元/2)
float* freq_cis_imag; // (シーケンス長, 次元/2)
// (オプション)最終レイヤーのlogitsの分類器重み
float* wcls;
} TransformerWeights;
typedef struct {
// 現在の活性化の波
float *x; // カレントタイムスタンプの活性化 (次元,)
float *xb; // 同様、ただし内部に残留ブランチ (次元,)
float *xb2; // 便宜上追加のバッファ (次元,)
float *hb; // ffn内の隠れ次元のバッファ (hidden_dim,)
float *hb2; // ffn内の隠れ次元のバッファ (hidden_dim,)
float *q; // query (次元,)
float *k; // key (次元,)
float *v; // value (次元,)
float *att; // スコア/注目値のバッファ (n_heads, シーケンス長)
float *logits; // 出力logits
// kvキャッシュ
float* key_cache; // (レイヤー, シーケンス長, 次元)
float* value_cache; // (レイヤー, シーケンス長, 次元)
} RunState;
// RunStateの確保と初期化を行う関数
void malloc_run_state(RunState* s, Config* p) {
// valgrindを満足させるためにcallocではなくmallocを使用
s->x = calloc(p->dim, sizeof(float));
s->xb = calloc(p->dim, sizeof(float));
s->xb2 = calloc(p->dim, sizeof(float));
s->hb = calloc(p->hidden_dim, sizeof(float));
s->hb2 = calloc(p->hidden_dim, sizeof(float));
s->q = calloc(p->dim, sizeof(float));
s->k = calloc(p->dim, sizeof(float));
s->v = calloc(p->dim, sizeof(float));
s->att = calloc(p->n_heads * p->seq_len, sizeof(float));
s->logits = calloc(p->vocab_size, sizeof(float));
s->key_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float));
s->value_cache = calloc(p->n_layers * p->seq_len * p->dim, sizeof(float));
// すべてのmallocが正常に実行されていることを確認します
if (!s->x || !s->xb || !s->xb2 || !s->hb || !s->hb2 || !s->q
|| !s->k || !s->v || !s->att || !s->logits || !s->key_cache
|| !s->value_cache) {
printf("malloc failed!\n");
exit(1);
}
}
// RunStateのメモリ解放関数
void free_run_state(RunState* s) {
free(s->x);
free(s->xb);
free(s->xb2);
free(s->hb);
free(s->hb2);
free(s->q);
free(s->k);
free(s->v);
free(s->att);
free(s->logits);
free(s->key_cache);
free(s->value_cache);
}
// --------------------------------------------------------------
// 初期化:checkpointから読み出し
// checkpointから重みを初期化する関数
void checkpoint_init_weights(TransformerWeights *w, Config* p, float* f, int shared_weights) {
float* ptr = f;
w->token_embedding_table = ptr;
ptr += p->vocab_size * p->dim;
w->rms_att_weight = ptr;
ptr += p->n_layers * p->dim;
w->wq = ptr;
ptr += p->n_layers * p->dim * p->dim;
w->wk = ptr;
ptr += p->n_layers * p->dim * p->dim;
w->wv = ptr;
ptr += p->n_layers * p->dim * p->dim;
w->wo = ptr;
ptr += p->n_layers * p->dim * p->dim;
w->rms_ffn_weight = ptr;
ptr += p->n_layers * p->dim;
w->w1 = ptr;
ptr += p->n_layers * p->dim * p->hidden_dim;
w->w2 = ptr;
ptr += p->n_layers * p->hidden_dim * p->dim;
w->w3 = ptr;
ptr += p->n_layers * p->dim * p->hidden_dim;
w->rms_final_weight = ptr;
ptr += p->dim;
w->freq_cis_real = ptr;
int head_size = p->dim / p->n_heads;
ptr += p->seq_len * head_size / 2;
w->freq_cis_imag = ptr;
ptr += p->seq_len * head_size / 2;
w->wcls = shared_weights ? w->token_embedding_table : ptr;
}
// --------------------------------------------------------------
// ニューラルネットワークのブロック
// ベクトルbをベクトルaに累積する関数
void accum(float *a, float *b, int size) {
// 全ての要素に対してbをaに加算
for (int i = 0; i < size; i++) {
a[i] += b[i];
}
}
// RMS正規化を行う関数
void rmsnorm(float* o, float* x, float* weight, int size) {
// 平方和を計算
float ss = 0.0f;
for (int j = 0; j < size; j++) {
ss += x[j] * x[j];
}
ss /= size;
ss += 1e-5f;
ss = 1.0f / sqrtf(ss);
// 正規化してスケール
for (int j = 0; j < size; j++) {
o[j] = weight[j] * (ss * x[j]);
}
}
// softmax関数です。配列xの全ての要素に対してsoftmaxを計算します。
void softmax(float* x, int size) {
// 最大の値を見つける(数値安定性のため)
float max_val = x[0];
for (int i = 1; i < size; i++) {
if (x[i] > max_val) {
max_val = x[i];
}
}
// expを計算し、合計を求める
float sum = 0.0f;
for (int i = 0; i < size; i++) {
x[i] = expf(x[i] - max_val);
sum += x[i];
}
// 正規化する
for (int i = 0; i < size; i++) {
x[i] /= sum;
}
}
// 行列乗算の計算をします。入力行列xに重み行列wを掛け合わせ、結果をxoutに格納します。
void matmul(float* xout, float* x, float* w, int n, int d) {
// W (d,n) @ x (n,) -> xout (d,)
#pragma omp parallel for
for (int i = 0; i < d; i++) {
float val = 0.0f;
for (int j = 0; j < n; j++) {
val += w[i * n + j] * x[j];
}
xout[i] = val;
}
}
// この関数は、指定されたトークンと位置に基づいてTransformerモデルを前方に進行させます。
// 具体的には、埋め込みのコピー、positionの抽出、すべての層についてのループ(注意スコアの計算、softmaxの適用、
// 重み付き和の計算、残差の接続)、最終的なSoftmaxを経て、分類器によるLogitsの計算を行います。
void transformer(int token, int pos, Config* p, RunState* s, TransformerWeights* w) {
// いくつかの便利な変数を定義します。
float *x = s->x;
int dim = p->dim;
int hidden_dim = p->hidden_dim;
int head_size = dim / p->n_heads;
// トークンの埋め込みをxにコピーします。
float* content_row = &(w->token_embedding_table[token * dim]);
memcpy(x, content_row, dim*sizeof(*x));
// freq_cis_realとfreq_cis_imagの"pos"行を求めます。
float* freq_cis_real_row = w->freq_cis_real + pos * head_size / 2;
float* freq_cis_imag_row = w->freq_cis_imag + pos * head_size / 2;
// 全てのレイヤーについて順に進めます。
for(int l = 0; l < p->n_layers; l++) {
// 注目ベクトルのrmsnorm
rmsnorm(s->xb, x, w->rms_att_weight + l*dim, dim);
// 該当する位置に対するqkv matmuls
matmul(s->q, s->xb, w->wq + l*dim*dim, dim, dim);
matmul(s->k, s->xb, w->wk + l*dim*dim, dim, dim);
matmul(s->v, s->xb, w->wv + l*dim*dim, dim, dim);
// 各ヘッドに対してRoPE回転をqとkベクトルに適用します。
for (int h = 0; h < p->n_heads; h++) {
// このヘッドのqとkベクトルを取ってきます。
float* q = s->q + h * head_size;
float* k = s->k + h * head_size;
// freq_cis_realとfreq_cis_imagによるqとkの回転
for (int i = 0; i < head_size; i+=2) {
float q0 = q[i];
float q1 = q[i+1];
float k0 = k[i];
float k1 = k[i+1];
float fcr = freq_cis_real_row[i/2];
float fci = freq_cis_imag_row[i/2];
q[i] = q0 * fcr - q1 * fci;
q[i+1] = q0 * fci + q1 * fcr;
k[i] = k0 * fcr - k1 * fci;
k[i+1] = k0 * fci + k1 * fcr;
}
}
// このタイムステップ(pos)のキーと値をkvキャッシュに保存します。
int loff = l * p->seq_len * dim;
float* key_cache_row = s->key_cache + loff + pos * dim;
float* value_cache_row = s->value_cache + loff + pos * dim;
memcpy(key_cache_row, s->k, dim*sizeof(*key_cache_row));
memcpy(value_cache_row, s->v, dim*sizeof(*value_cache_row));
// multihead注意処理。全てのヘッドについて繰り返します。
#pragma omp parallel for
for (int h = 0; h < p->n_heads; h++) {
// このヘッドのクエリーベクトルを取得
float* q = s->q + h * head_size;
// このヘッドの注意スコア
float* att = s->att + h * p->seq_len;
// すべてのタイムステップについて繰り返し実行、現在のものを含む
for (int t = 0; t <= pos; t++) {
// このヘッドとこのタイムステップのキーベクトルを取得
float* k = s->key_cache + loff + t * dim + h * head_size;
// 注意スコアをqとkのドット積として計算します
float score = 0.0f;
for (int i = 0; i < head_size; i++) {
score += q[i] * k[i];
}
score /= sqrtf(head_size);
// スコアを注意バッファに保存します
att[t] = score;
}
// 0..posのスコアをsoftmax適用して注意の重みを得る。
softmax(att, pos + 1);
// 値の加重和を計算し、xbに保存します。
for (int i = 0; i < head_size; i++) {
float val = 0.0f;
for (int t = 0; t <= pos; t++) {
val += att[t] * s->value_cache[loff + t * dim + h * head_size + i];
}
s->xb[h * head_size + i] = val;
}
}
// 注意の出力を得るための最終的な行列乗算
matmul(s->xb2, s->xb, w->wo + l*dim*dim, dim, dim);
// レジデュアル接続をxにバック
accum(x, s->xb2, dim);
// フィードフォワードネットワーク用OFRMノーム
rmsnorm(s->xb, x, w->rms_ffn_weight + l*dim, dim);
// PyTorchのFFNでは次のようになります: self.w2(F.silu(self.w1(x)) * self.w3(x))
// 最初にself.w1(x)とself.w3(x)を計算します
matmul(s->hb, s->xb, w->w1 + l*dim*hidden_dim, dim, hidden_dim);
matmul(s->hb2, s->xb, w->w3 + l*dim*hidden_dim, dim, hidden_dim);
// F.silu; silu(x)=x*σ(x),where σ(x) is the logistic sigmoid
for (int i = 0; i < hidden_dim; i++) {
s->hb[i] = s->hb[i] * (1.0f / (1.0f + expf(-s->hb[i])));
}
// w3(x)との要素ごとの積
for (int i = 0; i < hidden_dim; i++) {
s->hb[i] = s->hb[i] * s->hb2[i];
}
// FFNの出力を得るための最終的なmatmul
matmul(s->xb, s->hb, w->w2 + l*dim*hidden_dim, hidden_dim, dim);
// レジデュアル接続
accum(x, s->xb, dim);
}
// 最終的なrmsnorm
rmsnorm(x, x, w->rms_final_weight, dim);
// ロジットへのクラス分類器
matmul(s->logits, x, w->wcls, p->dim, p->vocab_size);
}
// 確率配列からインデックスをサンプリングします。確率の総和は1でなければなりません。
// これにより、softmax関数の出力(確率)に基づいて次のトークンが選択されます
int sample(float* probabilities, int n) {
// sample index from probabilities, they must sum to 1
float r = (float)rand() / (float)RAND_MAX;
float cdf = 0.0f;
for (int i = 0; i < n; i++) {
cdf += probabilities[i];
if (r < cdf) {
return i;
}
}
return n - 1; // in case of rounding errors
}
// 配列 v の中で最も大きい値を持つインデックスを返します。
// これは、temperatureが0の場合に使用されます。(すなわち、確率最大のトークンだけが選択されます)
int argmax(float* v, int n) {
// return argmax of v in elements 0..n
int max_i = 0;
float max_p = v[0];
for (int i = 1; i < n; i++) {
if (v[i] > max_p) {
max_i = i;
max_p = v[i];
}
}
return max_i;
}
// ----------------------------------------------------------------------------
long time_in_ms() {
struct timespec time;
// Get the current time with nanosecond precision
if (clock_gettime(CLOCK_REALTIME, &time) == 0) {
return time.tv_sec * 1000 + time.tv_nsec / 1000000;
} else {
perror("clock_gettime");
return -1; // Return -1 to indicate an error
}
}
// C型の引数解析を行っています
// 必要な引数は 'checkpoint' (モデルの重みが保存されたファイル)です
// オプション引数には、生成するテキストの多様性を制御するための 'temperature' があり、
// また最大のステップ数も設定できます
// 乱数生成器は現在時刻でシードされます。決定論的な動作が望まれる場合は、temperatureを0.0に設定します
//
// 残りの部分は、モデルとトークナイザを読み込み、メモリ内で必要なスペースを確保し、
// トークン生成プロセスが終了するまでトークンを繰り返し生成します。
int main(int argc, char *argv[]) {
// C型の引数解析を行っています
char *checkpoint = NULL; // e.g. out/model.bin
float temperature = 0.9f; // e.g. 1.0, or 0.0
int steps = 256; // max number of steps to run for, 0: use seq_len
// 'checkpoint' is necessary arg
if (argc < 2) {
printf("Usage: %s <checkpoint_file> [temperature] [steps]\n", argv[0]);
return 1;
}
if (argc >= 2) {
checkpoint = argv[1];
}
if (argc >= 3) {
// optional temperature. 0.0 = (deterministic) argmax sampling. 1.0 = baseline
temperature = atof(argv[2]);
}
if (argc >= 4) {
steps = atoi(argv[3]);
}
// seed rng with time. if you want deterministic behavior use temperature 0.0
srand((unsigned int)time(NULL));
// read in the model.bin file
Config config;
TransformerWeights weights;
int fd = 0;
float* data = NULL;
long file_size;
{
FILE *file = fopen(checkpoint, "rb");
if (!file) {
printf("Unable to open the checkpoint file %s!\n", checkpoint);
return 1;
}
// read in the config header
if(fread(&config, sizeof(Config), 1, file) != 1) { return 1; }
// negative vocab size is hacky way of signaling unshared weights. bit yikes.
int shared_weights = config.vocab_size > 0 ? 1 : 0;
config.vocab_size = abs(config.vocab_size);
// figure out the file size
fseek(file, 0, SEEK_END); // move file pointer to end of file
file_size = ftell(file); // get the file size, in bytes
fclose(file);
// memory map the Transformer weights into the data pointer
fd = open(checkpoint, O_RDONLY); // open in read only mode
if (fd == -1) { printf("open failed!\n"); return 1; }
data = mmap(NULL, file_size, PROT_READ, MAP_PRIVATE, fd, 0);
if (data == MAP_FAILED) { printf("mmap failed!\n"); return 1; }
float* weights_ptr = data + sizeof(Config)/sizeof(float);
checkpoint_init_weights(&weights, &config, weights_ptr, shared_weights);
}
// right now we cannot run for more than config.seq_len steps
if (steps <= 0 || steps > config.seq_len) { steps = config.seq_len; }
// read in the tokenizer.bin file
char** vocab = (char**)malloc(config.vocab_size * sizeof(char*));
{
FILE *file = fopen("tokenizer.bin", "rb");
if (!file) {
printf("Unable to open the tokenizer file tokenizer.bin! Run "
"python tokenizer.py to convert tokenizer.model -> tokenizer.bin\n");
return 1;
}
int len;
for (int i = 0; i < config.vocab_size; i++) {
if(fread(&len, sizeof(int), 1, file) != 1) { return 1; }
vocab[i] = (char *)malloc(len + 1);
if(fread(vocab[i], len, 1, file) != 1) { return 1; }
vocab[i][len] = '\0'; // add the string terminating token
}
fclose(file);
}
// create and init the application RunState
RunState state;
malloc_run_state(&state, &config);
// the current position we are in
long start = time_in_ms();
int next;
int token = 1; // 1 = BOS token in Llama-2 sentencepiece
int pos = 0;
printf("<s>\n"); // explicit print the initial BOS token (=1), stylistically symmetric
while (pos < steps) {
// forward the transformer to get logits for the next token
transformer(token, pos, &config, &state, &weights);
// sample the next token
if(temperature == 0.0f) {
// greedy argmax sampling
next = argmax(state.logits, config.vocab_size);
} else {
// apply the temperature to the logits
for (int q=0; q<config.vocab_size; q++) { state.logits[q] /= temperature; }
// apply softmax to the logits to get the probabilities for next token
softmax(state.logits, config.vocab_size);
// we now want to sample from this distribution to get the next token
next = sample(state.logits, config.vocab_size);
}
printf("%s", vocab[next]);
fflush(stdout);
// advance forward
token = next;
pos++;
}
// report achieved tok/s
long end = time_in_ms();
printf("\nachieved tok/s: %f\n", steps / (double)(end-start)*1000);
// memory and file handles cleanup
free_run_state(&state);
for (int i = 0; i < config.vocab_size; i++) { free(vocab[i]); }
free(vocab);
if (data != MAP_FAILED) munmap(data, file_size);
if (fd != -1) close(fd);
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment