Last active
February 5, 2023 23:59
-
-
Save jaykru/cbed528751c83c4fce01e49b2d9b1d03 to your computer and use it in GitHub Desktop.
Attempt at an energy minimization network for quadgrams
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"--2023-01-30 21:34:21-- https://raw.githubusercontent.com/karpathy/makemore/master/names.txt\r\n", | |
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... " | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"2606:50c0:8001::154, 2606:50c0:8003::154, 2606:50c0:8002::154, ...\r\n", | |
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8001::154|:443... connected.\r\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"HTTP request sent, awaiting response... " | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"200 OK\r\n", | |
"Length: 228145 (223K) [text/plain]\r\n", | |
"Saving to: ‘names.txt.5’\r\n", | |
"\r\n", | |
"\r", | |
"names.txt.5 0%[ ] 0 --.-KB/s " | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"\r", | |
"names.txt.5 100%[===================>] 222.80K --.-KB/s in 0.09s \r\n", | |
"\r\n", | |
"2023-01-30 21:34:22 (2.40 MB/s) - ‘names.txt.5’ saved [228145/228145]\r\n", | |
"\r\n" | |
] | |
} | |
], | |
"source": [ | |
"# Setup code\n", | |
"import torch\n", | |
"import torch.nn.functional as F\n", | |
"import matplotlib.pyplot as plt # for making figures\n", | |
"# torch.set_default_tensor_type('torch.cuda.FloatTensor')\n", | |
"\n", | |
"# download the names.txt file from github\n", | |
"!wget https://raw.githubusercontent.com/karpathy/makemore/master/names.txt\n", | |
"\n", | |
"words = open('names.txt', 'r').read().splitlines()\n", | |
"\n", | |
"# build the vocabulary of characters and mappings to/from integers\n", | |
"chars = sorted(list(set(''.join(words))))\n", | |
"stoi = {s:i+1 for i,s in enumerate(chars)}\n", | |
"stoi['.'] = 0\n", | |
"itos = {i:s for s,i in stoi.items()}\n", | |
"\n", | |
"# build the dataset\n", | |
"block_size = 3 # context length: how many characters do we take to predict the next one?\n", | |
"\n", | |
"def build_dataset(words):\n", | |
" X, Y = [], []\n", | |
" for w in words:\n", | |
"\n", | |
" #print(w)\n", | |
" context = [0] * block_size\n", | |
" for ch in w + '.':\n", | |
" ix = stoi[ch]\n", | |
" X.append(context)\n", | |
" Y.append(ix)\n", | |
" #print(''.join(itos[i] for i in context), '--->', itos[ix])\n", | |
" context = context[1:] + [ix] # crop and append\n", | |
"\n", | |
" X = torch.tensor(X)\n", | |
" Y = torch.tensor(Y)\n", | |
" return X, Y\n", | |
"\n", | |
"import random\n", | |
"random.seed(42)\n", | |
"random.shuffle(words)\n", | |
"n1 = int(0.8*len(words))\n", | |
"n2 = int(0.9*len(words))\n", | |
"\n", | |
"Xtr, Ytr = build_dataset(words[:n1])\n", | |
"Xdev, Ydev = build_dataset(words[n1:n2])\n", | |
"Xte, Yte = build_dataset(words[n2:])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [], | |
"source": [ | |
"# Network structure\n", | |
"g = torch.Generator().manual_seed(2147483647) # for reproducibility\n", | |
"C = torch.randn((27, 10), generator=g)\n", | |
"W1 = torch.randn((40, 200), generator=g)\n", | |
"b1 = torch.randn(200, generator=g)\n", | |
"W2 = torch.randn((200), generator=g)\n", | |
"b2 = torch.randn(1, generator=g)\n", | |
"parameters = [C, W1, b1, W2, b2]\n", | |
"\n", | |
"sum(p.nelement() for p in parameters) # number of parameters in total\n", | |
"\n", | |
"for p in parameters:\n", | |
" p.requires_grad = True" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/jck/.local/lib/python3.10/site-packages/torch/autograd/__init__.py:197: UserWarning: HIP initialization: Unexpected error from hipGetDeviceCount(). Did you run some cuda functions before calling NumHipDevices() that might have already set an error? Error 101: hipErrorInvalidDevice (Triggered internally at ../c10/hip/HIPFunctions.cpp:110.)\n", | |
" Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"0 loss: 16.510089874267578\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"1000 loss: 10.871416091918945\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"2000 loss: 7.956619739532471\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"3000 loss: 8.065092086791992\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"4000 loss: 6.2484846115112305\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"5000 loss: 5.066016674041748\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"6000 loss: 5.336665153503418\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"7000 loss: 5.228311061859131\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"8000 loss: 4.683386325836182\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"9000 loss: 4.035111427307129\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"10000 loss: 5.001573085784912\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"11000 loss: 3.6375625133514404\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"12000 loss: 4.359611511230469\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"13000 loss: 3.442596912384033\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"14000 loss: 3.5331358909606934\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"15000 loss: 3.7668650150299072\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"16000 loss: 3.753387451171875\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"17000 loss: 3.27445125579834\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"18000 loss: 3.042189598083496\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"19000 loss: 2.6472935676574707\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"20000 loss: 3.506399631500244\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"21000 loss: 3.478891611099243\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"22000 loss: 3.038069725036621\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"23000 loss: 2.371244192123413\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"24000 loss: 2.6314568519592285\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"25000 loss: 2.9546217918395996\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"26000 loss: 2.4968345165252686\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"27000 loss: 3.300184726715088\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"28000 loss: 2.9361672401428223\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"29000 loss: 3.5350775718688965\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"30000 loss: 2.3028242588043213\n" | |
] | |
}, | |
{ | |
"ename": "KeyboardInterrupt", | |
"evalue": "", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", | |
"Cell \u001b[0;32mIn[4], line 41\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m p \u001b[38;5;129;01min\u001b[39;00m parameters:\n\u001b[1;32m 40\u001b[0m p\u001b[38;5;241m.\u001b[39mgrad \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m---> 41\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 43\u001b[0m \u001b[38;5;66;03m# update\u001b[39;00m\n\u001b[1;32m 44\u001b[0m lr \u001b[38;5;241m=\u001b[39m lrs[i]\n", | |
"File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/_tensor.py:488\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 478\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 479\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 480\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m 481\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 486\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m 487\u001b[0m )\n\u001b[0;32m--> 488\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 489\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m 490\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n", | |
"File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/autograd/__init__.py:197\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 192\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m 194\u001b[0m \u001b[38;5;66;03m# The reason we repeat same the comment below is that\u001b[39;00m\n\u001b[1;32m 195\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 196\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 197\u001b[0m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 198\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 199\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n", | |
"\u001b[0;31mKeyboardInterrupt\u001b[0m: " | |
] | |
} | |
], | |
"source": [ | |
"# Training\n", | |
"lre = torch.linspace(-3, 0, 200000)\n", | |
"lrs = 10**lre\n", | |
"\n", | |
"lri = []\n", | |
"lossi = []\n", | |
"stepi = []\n", | |
"\n", | |
"char_ints = [stoi[c] for c in chars]\n", | |
"\n", | |
"for i in range(200000):\n", | |
" # minibatch construct\n", | |
" ix = torch.randint(0, Xtr.shape[0], (32,))\n", | |
" emb = C[torch.cat((Xtr,Ytr.view(-1,1)), dim=1)][ix]\n", | |
" h = torch.tanh(emb.view(-1, 40) @ W1 + b1) # (32, 200)\n", | |
" energies = h @ W2 + b2 # 32 dim'l vector of energies\n", | |
" next_chars = torch.tensor([0] + char_ints).expand(32,-1)\n", | |
"\n", | |
" xs = Xtr[ix]\n", | |
" xs, y = xs.unsqueeze(1).expand(-1,27,-1), next_chars.unsqueeze(2)\n", | |
" seqs_vec = torch.cat((xs,y), dim=2)\n", | |
"\n", | |
" # seqs should be a (32, 26, 4) tensor of whose last layer of rows are the rows\n", | |
" # of Xtr with an additional element drawn from nextn_chars\n", | |
" seqs = seqs_vec # torch.stack([torch.stack([torch.cat((row, torch.tensor([c])),dim=0) for c in [0] + char_ints]) for row in Xtr[ix]])\n", | |
"\n", | |
" # forward pass\n", | |
" embs = C[seqs]\n", | |
"\n", | |
" hs = torch.tanh(embs.view(-1,40) @ W1 + b1).view(-1,200)\n", | |
" all_energies = (hs @ W2 + b2).view(-1,27)\n", | |
"\n", | |
" # probabilities for each example that the example forms a sequence. we want this to be high.\n", | |
" probs = torch.div(torch.exp(-energies), torch.sum(torch.exp(-all_energies),dim=1))\n", | |
" nll = -torch.log(probs)\n", | |
" loss = torch.mean(nll)\n", | |
"\n", | |
" # backward pass\n", | |
" for p in parameters:\n", | |
" p.grad = None\n", | |
" loss.backward()\n", | |
" \n", | |
" # update\n", | |
" lr = lrs[i]\n", | |
" for p in parameters:\n", | |
" p.data += -lr * p.grad\n", | |
"\n", | |
" # track stats\n", | |
" lri.append(lre[i])\n", | |
" stepi.append(i)\n", | |
" lossi.append(loss.log10().item())\n", | |
"\n", | |
"\n", | |
" if i % 1000 == 0:\n", | |
" print(i, \"loss: \", loss.item())\n", | |
" \n", | |
"print(\"final loss: \", loss.item())\n", | |
"\n", | |
"plt.plot(stepi, lossi)\n", | |
"\n", | |
"# visualize dimensions 0 and 1 of the embedding matrix C for all characters\n", | |
"plt.figure(figsize=(8,8))\n", | |
"plt.scatter(C[:,0].data, C[:,1].data, s=200)\n", | |
"for i in range(C.shape[0]):\n", | |
" plt.text(C[i,0].item(), C[i,1].item(), itos[i], ha=\"center\", va=\"center\", color='white')\n", | |
"plt.grid('minor')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"ename": "RuntimeError", | |
"evalue": "The expanded size of the tensor (27) must match the existing size (3) at non-singleton dimension 1. Target sizes: [-1, 27, -1]. Tensor sizes: [3, 1]", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", | |
"Cell \u001b[0;32mIn[8], line 7\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# next_chars = torch.tensor([0] + char_ints).expand(32,-1)\u001b[39;00m\n\u001b[1;32m 6\u001b[0m xs \u001b[38;5;241m=\u001b[39m Xtr[ix]\n\u001b[0;32m----> 7\u001b[0m xs, y \u001b[38;5;241m=\u001b[39m \u001b[43mxs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43munsqueeze\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexpand\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m27\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m, next_chars\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m 8\u001b[0m seqs_vec \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcat((xs,y), dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m 10\u001b[0m \u001b[38;5;66;03m# seqs should be a (32, 26, 4) tensor of whose last layer of rows are the rows\u001b[39;00m\n\u001b[1;32m 11\u001b[0m \u001b[38;5;66;03m# of Xtr with an additional element drawn from next_chars\u001b[39;00m\n", | |
"\u001b[0;31mRuntimeError\u001b[0m: The expanded size of the tensor (27) must match the existing size (3) at non-singleton dimension 1. Target sizes: [-1, 27, -1]. Tensor sizes: [3, 1]" | |
] | |
} | |
], | |
"source": [ | |
"emb = C[torch.cat((Xtr,Ytr.view(-1,1)), dim=1)][ix]\n", | |
"h = torch.tanh(emb.view(-1, 40) @ W1 + b1) # (32, 200)\n", | |
"energies = h @ W2 + b2 # 32 dim'l vector of energies\n", | |
"# next_chars = torch.tensor([0] + char_ints).expand(32,-1)\n", | |
"\n", | |
"xs = Xtr[ix]\n", | |
"xs, y = xs.unsqueeze(1).expand(-1,27,-1), next_chars.unsqueeze(2)\n", | |
"seqs_vec = torch.cat((xs,y), dim=2)\n", | |
"\n", | |
"# seqs should be a (32, 26, 4) tensor of whose last layer of rows are the rows\n", | |
"# of Xtr with an additional element drawn from next_chars\n", | |
"seqs = torch.stack([torch.stack([torch.cat((row, torch.tensor([c])),dim=0) for c in [0] + char_ints]) for row in Xtr[ix]])\n", | |
"Xtr[ix].unsqueeze(1).expand(-1,27,-1)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"ename": "RuntimeError", | |
"evalue": "The expanded size of the tensor (27) must match the existing size (3) at non-singleton dimension 1. Target sizes: [-1, 27, -1]. Tensor sizes: [3, 1]", | |
"output_type": "error", | |
"traceback": [ | |
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", | |
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)", | |
"Cell \u001b[0;32mIn[9], line 7\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# next_chars = torch.tensor([0] + char_ints).expand(32,-1)\u001b[39;00m\n\u001b[1;32m 6\u001b[0m xs \u001b[38;5;241m=\u001b[39m Xtr[ix]\n\u001b[0;32m----> 7\u001b[0m xs, y \u001b[38;5;241m=\u001b[39m \u001b[43mxs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43munsqueeze\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexpand\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m27\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m, next_chars\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m 8\u001b[0m seqs_vec \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcat((xs,y), dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m 10\u001b[0m \u001b[38;5;66;03m# seqs should be a (32, 26, 4) tensor of whose last layer of rows are the rows\u001b[39;00m\n\u001b[1;32m 11\u001b[0m \u001b[38;5;66;03m# of Xtr with an additional element drawn from next_chars\u001b[39;00m\n", | |
"\u001b[0;31mRuntimeError\u001b[0m: The expanded size of the tensor (27) must match the existing size (3) at non-singleton dimension 1. Target sizes: [-1, 27, -1]. Tensor sizes: [3, 1]" | |
] | |
} | |
], | |
"source": [ | |
"emb = C[torch.cat((Xtr,Ytr.view(-1,1)), dim=1)][ix]\n", | |
"h = torch.tanh(emb.view(-1, 40) @ W1 + b1) # (32, 200)\n", | |
"energies = h @ W2 + b2 # 32 dim'l vector of energies\n", | |
"# next_chars = torch.tensor([0] + char_ints).expand(32,-1)\n", | |
"\n", | |
"xs = Xtr[ix]\n", | |
"xs, y = xs.unsqueeze(1).expand(-1,27,-1), next_chars.unsqueeze(2)\n", | |
"seqs_vec = torch.cat((xs,y), dim=2)\n", | |
"\n", | |
"# seqs should be a (32, 26, 4) tensor of whose last layer of rows are the rows\n", | |
"# of Xtr with an additional element drawn from next_chars\n", | |
"seqs = torch.stack([torch.stack([torch.cat((row, torch.tensor([c])),dim=0) for c in [0] + char_ints]) for row in Xtr[ix]])\n", | |
"Xtr[ix].unsqueeze(1).expand(-1,27,-1)\n" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 20, | |
"metadata": { | |
"collapsed": false | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"po.\n", | |
"rakaal.\n", | |
"niu.\n", | |
"kesha.\n", | |
"asytein.\n", | |
"hioa.\n", | |
"brai.\n", | |
"kassa.\n", | |
"jeday.\n", | |
"zer.\n", | |
"jntyil.\n", | |
"lanyn.\n", | |
"doryyvetninaea.\n", | |
"konarn.\n", | |
"ineiah.\n", | |
"mh.\n", | |
"sipol.\n", | |
"lkl.\n", | |
"u.\n", | |
"bynie.\n", | |
"pyka.\n", | |
"nyyeu.\n", | |
"bsaadalaema.\n", | |
"madii.\n", | |
"oeuee.\n", | |
"adntiyn.\n", | |
"iinia.\n", | |
"dmtyaedes.\n", | |
"angenn.\n", | |
"gyl.\n", | |
"kybydknt.\n", | |
"narette.\n", | |
"arenpsgde.\n", | |
"ka.\n", | |
"b.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"kiit.\n", | |
"kamnq.\n", | |
"gavn.\n", | |
"elaeazin.\n", | |
".\n", | |
"ale.\n", | |
"ftll.\n", | |
"au.\n", | |
"lilu.\n", | |
"hada.\n", | |
"niw.\n", | |
"bern.\n", | |
"flyaneeliel.\n", | |
"dohylam.\n", | |
"orarnnn.\n", | |
"stena.\n", | |
"kysynray.\n", | |
"jok.\n", | |
"fien.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"krinlaiyaedvhaya.\n", | |
"kawie.\n", | |
"jalay.\n", | |
"ied.\n", | |
"ezn.\n", | |
"beoptnthewt.\n", | |
"arvtrhya.\n", | |
"kalmgbe.\n", | |
"jauzelia.\n", | |
"zylh.\n", | |
"z.\n", | |
"cf.\n", | |
"nidn.\n", | |
"roa.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"oawaivcakonwnylaydeiha.\n", | |
"btna.\n", | |
"ami.\n", | |
"fovt.\n", | |
"halzoa.\n", | |
"kianeah.\n", | |
"alyna.\n", | |
"asaomaan.\n", | |
"maiallo.\n", | |
"mabnieayde.\n", | |
"coracselnaem.\n", | |
"zlor.\n", | |
"tohtniy.\n", | |
"jiut.\n", | |
"tkaekai.\n", | |
"bzlennel.\n", | |
"sabicwawldya.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"maxhd.\n", | |
"zorynioh.\n", | |
"rivlnyhmieiaon.\n", | |
"i.\n", | |
"aninee.\n", | |
"maglit.\n", | |
"chni.\n", | |
"ehd.\n", | |
"asaisha.\n", | |
"b.\n", | |
"kaemunn.\n", | |
"kansva.\n", | |
"wynieanuka.\n", | |
"zo.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"nid.\n", | |
"nal.\n", | |
"ayyarolii.\n", | |
"za.\n", | |
"jhda.\n", | |
"anaqir.\n", | |
"ayaezolzaveil.\n", | |
"tres.\n", | |
"kaiha.\n", | |
"qayi.\n", | |
"fovz.\n", | |
"tuasntenkaahe.\n", | |
"roiiah.\n", | |
"anotgae.\n", | |
"ewn.\n", | |
"ldrloryny.\n", | |
"aniiah.\n", | |
"lah.\n", | |
".\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"jofvcmirn.\n", | |
"aqb.\n", | |
"toalaa.\n", | |
"gv.\n", | |
"avmotniyah.\n", | |
"mloatoi.\n", | |
"cnnnusiy.\n", | |
"naelnll.\n", | |
"javagl.\n", | |
"ooqb.\n", | |
"zoxskdrel.\n", | |
"a.\n", | |
"taxyrie.\n", | |
"abniea.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"cieneoyuis.\n", | |
"keu.\n", | |
"jamnllnsmaehi.\n", | |
"emiui.\n", | |
"avcyniarnn.\n", | |
"vayretwiyekamfkmvwei.\n", | |
"biararailon.\n", | |
"reo.\n", | |
"raeodnhany.\n", | |
"iasyuninyni.\n", | |
"amona.\n", | |
"rovhu.\n", | |
"isra.\n", | |
"hi.\n", | |
"zeynn.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"kayena.\n", | |
"eigh.\n", | |
"nil.\n", | |
"sa.\n", | |
"aneanle.\n", | |
"dnn.\n", | |
"hlsynelak.\n", | |
"cwe.\n", | |
"gd.\n", | |
"fay.\n", | |
"dessri.\n", | |
"hgn.\n", | |
"dorh.\n", | |
"sdlan.\n", | |
"abi.\n", | |
"b.\n", | |
"teasyueeln.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"laiio.\n", | |
"euseazan.\n", | |
"eorneilee.\n", | |
"jfa.\n", | |
"abfiedil.\n", | |
"atnliea.\n", | |
"soalku.\n", | |
"zi.\n", | |
"kerdyan.\n", | |
"jayza.\n", | |
"tiia.\n", | |
"p.\n", | |
"ken.\n", | |
"ldanen.\n", | |
"marinnnn.\n", | |
"xakobua.\n", | |
"jah.\n", | |
"arii.\n", | |
"deus.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"saooaidaanna.\n", | |
"tdhunestavarfiiyansyn.\n", | |
"aou.\n", | |
"biaagoles.\n", | |
"ro.\n", | |
"ervhnn.\n", | |
"zeaalaa.\n", | |
"tdin.\n", | |
"esn.\n", | |
"hninillargi.\n", | |
"lislaln.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"ita.\n", | |
"eoe.\n", | |
"maraylsinslema.\n", | |
"dixed.\n", | |
"aydrs.\n", | |
"sateisyld.\n", | |
"po.\n", | |
"rakaal.\n", | |
"niu.\n", | |
"kesha.\n", | |
"asytein.\n", | |
"hioa.\n", | |
"brai.\n", | |
"kassa.\n", | |
"jeday.\n", | |
"zer.\n", | |
"jntyil.\n", | |
"lanyn.\n", | |
"doryyvetninaea.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"konarn.\n", | |
"ineiah.\n", | |
"mh.\n", | |
"sipol.\n", | |
"lkl.\n", | |
"u.\n", | |
"bynie.\n", | |
"pyka.\n", | |
"nyyeu.\n", | |
"bsaadalaema.\n", | |
"madii.\n", | |
"oeuee.\n", | |
"adntiyn.\n", | |
"iinia.\n", | |
"dmtyaedes.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"angenn.\n", | |
"gyl.\n", | |
"kybydknt.\n", | |
"narette.\n", | |
"arenpsgde.\n", | |
"ka.\n", | |
"b.\n", | |
"kiit.\n", | |
"kamnq.\n", | |
"gavn.\n", | |
"elaeazin.\n", | |
".\n", | |
"ale.\n", | |
"ftll.\n", | |
"au.\n", | |
"lilu.\n", | |
"hada.\n", | |
"niw.\n", | |
"bern.\n", | |
"flyaneeliel.\n", | |
"dohylam.\n", | |
"orarnnn.\n" | |
] | |
}, | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"stena.\n", | |
"kysynray.\n", | |
"jok.\n", | |
"fien.\n" | |
] | |
} | |
], | |
"source": [ | |
"# sample from the model\n", | |
"g = torch.Generator() # .manual_seed(2147483647)\n", | |
"import random\n", | |
"\n", | |
"\n", | |
"for _ in range(200):\n", | |
" out = []\n", | |
" context = [0] * block_size # initialize with all ...\n", | |
" while True:\n", | |
" next_chars = torch.tensor([0] + [stoi[c] for c in chars])\n", | |
" seq_options = torch.stack([torch.tensor(context + [c]) for c in next_chars])\n", | |
" # forward pass\n", | |
" embs = C[seq_options] # (1,block_size,d)\n", | |
"\n", | |
" hs = [torch.tanh(emb.view(-1, 40) @ W1 + b1) for emb in embs] # (32, 200)\n", | |
" energies = [h @ W2 + b2 for h in hs] # (32, 27)\n", | |
" energies = torch.tensor(energies)\n", | |
"\n", | |
" probs = torch.exp(-energies) / sum(torch.exp(-energies))\n", | |
" ix = torch.multinomial(probs, num_samples=1, generator=g).item()\n", | |
" context = context[1:] + [ix]\n", | |
" out.append(ix)\n", | |
" if ix == 0:\n", | |
" break\n", | |
"\n", | |
" print(''.join(itos[i] for i in out))\n" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"argv": [ | |
"python", | |
"-m", | |
"ipykernel_launcher", | |
"-f", | |
"{connection_file}" | |
], | |
"display_name": "Python 3 (ipykernel)", | |
"env": null, | |
"interrupt_mode": "signal", | |
"language": "python", | |
"metadata": { | |
"debugger": true | |
}, | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.10.4" | |
}, | |
"name": "energy_mlp.ipynb", | |
"vscode": { | |
"interpreter": { | |
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" | |
} | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment