Skip to content

Instantly share code, notes, and snippets.

@jaykru
Last active February 5, 2023 23:59
Show Gist options
  • Save jaykru/cbed528751c83c4fce01e49b2d9b1d03 to your computer and use it in GitHub Desktop.
Save jaykru/cbed528751c83c4fce01e49b2d9b1d03 to your computer and use it in GitHub Desktop.
Attempt at an energy minimization network for quadgrams
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--2023-01-30 21:34:21-- https://raw.githubusercontent.com/karpathy/makemore/master/names.txt\r\n",
"Resolving raw.githubusercontent.com (raw.githubusercontent.com)... "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2606:50c0:8001::154, 2606:50c0:8003::154, 2606:50c0:8002::154, ...\r\n",
"Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8001::154|:443... connected.\r\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"HTTP request sent, awaiting response... "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"200 OK\r\n",
"Length: 228145 (223K) [text/plain]\r\n",
"Saving to: ‘names.txt.5’\r\n",
"\r\n",
"\r",
"names.txt.5 0%[ ] 0 --.-KB/s "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"names.txt.5 100%[===================>] 222.80K --.-KB/s in 0.09s \r\n",
"\r\n",
"2023-01-30 21:34:22 (2.40 MB/s) - ‘names.txt.5’ saved [228145/228145]\r\n",
"\r\n"
]
}
],
"source": [
"# Setup code\n",
"import random\n",
"\n",
"import torch\n",
"import torch.nn.functional as F\n",
"import matplotlib.pyplot as plt # for making figures\n",
"# torch.set_default_tensor_type('torch.cuda.FloatTensor')\n",
"\n",
"# download the names.txt file from github.\n",
"# -nc (no-clobber) skips the download if names.txt already exists; without it every\n",
"# re-run saves another copy (names.txt.1, names.txt.2, ...) that is never read.\n",
"!wget -nc https://raw.githubusercontent.com/karpathy/makemore/master/names.txt\n",
"\n",
"# read with a context manager so the file handle is closed\n",
"with open('names.txt', 'r') as fh:\n",
"    words = fh.read().splitlines()\n",
"\n",
"# build the vocabulary of characters and mappings to/from integers.\n",
"# '.' gets index 0 and serves both as start-of-word padding and end-of-word marker.\n",
"chars = sorted(set(''.join(words)))\n",
"stoi = {s: i + 1 for i, s in enumerate(chars)}\n",
"stoi['.'] = 0\n",
"itos = {i: s for s, i in stoi.items()}\n",
"\n",
"# build the dataset\n",
"block_size = 3 # context length: how many characters do we take to predict the next one?\n",
"\n",
"def build_dataset(words):\n",
"    \"\"\"Turn words into (context, next-character) training pairs.\n",
"\n",
"    Each word is left-padded with block_size '.' (index 0) tokens and terminated\n",
"    with '.'. Returns X with shape (N, block_size) and Y with shape (N,), both\n",
"    int64 tensors of character indices (N may be 0 for an empty word list).\n",
"    \"\"\"\n",
"    X, Y = [], []\n",
"    for w in words:\n",
"        context = [0] * block_size\n",
"        for ch in w + '.':\n",
"            ix = stoi[ch]\n",
"            X.append(context)\n",
"            Y.append(ix)\n",
"            context = context[1:] + [ix] # crop and append\n",
"    # explicit dtype/view keep the shapes and dtype right even when X is empty\n",
"    return torch.tensor(X, dtype=torch.long).view(-1, block_size), torch.tensor(Y, dtype=torch.long)\n",
"\n",
"# shuffle deterministically, then split 80% train / 10% dev / 10% test\n",
"random.seed(42)\n",
"random.shuffle(words)\n",
"n1 = int(0.8*len(words))\n",
"n2 = int(0.9*len(words))\n",
"\n",
"Xtr, Ytr = build_dataset(words[:n1])\n",
"Xdev, Ydev = build_dataset(words[n1:n2])\n",
"Xte, Yte = build_dataset(words[n2:])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"collapsed": false
},
"outputs": [],
"source": [
"# Network structure\n",
"# The network maps a window of block_size + 1 characters (a context plus one\n",
"# candidate next character) to a single scalar energy.\n",
"vocab_size = len(stoi)               # '.' plus every distinct character in words\n",
"n_embd = 10                          # embedding dimension per character\n",
"n_hidden = 200                       # hidden layer width\n",
"in_dim = (block_size + 1) * n_embd   # flattened embeddings of one window\n",
"\n",
"g = torch.Generator().manual_seed(2147483647) # for reproducibility\n",
"C = torch.randn((vocab_size, n_embd), generator=g)\n",
"W1 = torch.randn((in_dim, n_hidden), generator=g)\n",
"b1 = torch.randn(n_hidden, generator=g)\n",
"W2 = torch.randn(n_hidden, generator=g)   # 1-D, so h @ W2 yields one energy per row\n",
"b2 = torch.randn(1, generator=g)\n",
"parameters = [C, W1, b1, W2, b2]\n",
"\n",
"# a bare expression mid-cell is not displayed, so print the count explicitly\n",
"print('number of parameters:', sum(p.nelement() for p in parameters))\n",
"\n",
"for p in parameters:\n",
"    p.requires_grad = True"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/jck/.local/lib/python3.10/site-packages/torch/autograd/__init__.py:197: UserWarning: HIP initialization: Unexpected error from hipGetDeviceCount(). Did you run some cuda functions before calling NumHipDevices() that might have already set an error? Error 101: hipErrorInvalidDevice (Triggered internally at ../c10/hip/HIPFunctions.cpp:110.)\n",
" Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 loss: 16.510089874267578\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"1000 loss: 10.871416091918945\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"2000 loss: 7.956619739532471\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"3000 loss: 8.065092086791992\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"4000 loss: 6.2484846115112305\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"5000 loss: 5.066016674041748\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"6000 loss: 5.336665153503418\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"7000 loss: 5.228311061859131\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"8000 loss: 4.683386325836182\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"9000 loss: 4.035111427307129\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"10000 loss: 5.001573085784912\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"11000 loss: 3.6375625133514404\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"12000 loss: 4.359611511230469\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"13000 loss: 3.442596912384033\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"14000 loss: 3.5331358909606934\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"15000 loss: 3.7668650150299072\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"16000 loss: 3.753387451171875\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"17000 loss: 3.27445125579834\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"18000 loss: 3.042189598083496\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"19000 loss: 2.6472935676574707\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"20000 loss: 3.506399631500244\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"21000 loss: 3.478891611099243\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"22000 loss: 3.038069725036621\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"23000 loss: 2.371244192123413\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"24000 loss: 2.6314568519592285\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"25000 loss: 2.9546217918395996\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"26000 loss: 2.4968345165252686\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"27000 loss: 3.300184726715088\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"28000 loss: 2.9361672401428223\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"29000 loss: 3.5350775718688965\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"30000 loss: 2.3028242588043213\n"
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[4], line 41\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m p \u001b[38;5;129;01min\u001b[39;00m parameters:\n\u001b[1;32m 40\u001b[0m p\u001b[38;5;241m.\u001b[39mgrad \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m---> 41\u001b[0m \u001b[43mloss\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 43\u001b[0m \u001b[38;5;66;03m# update\u001b[39;00m\n\u001b[1;32m 44\u001b[0m lr \u001b[38;5;241m=\u001b[39m lrs[i]\n",
"File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/_tensor.py:488\u001b[0m, in \u001b[0;36mTensor.backward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 478\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m has_torch_function_unary(\u001b[38;5;28mself\u001b[39m):\n\u001b[1;32m 479\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m handle_torch_function(\n\u001b[1;32m 480\u001b[0m Tensor\u001b[38;5;241m.\u001b[39mbackward,\n\u001b[1;32m 481\u001b[0m (\u001b[38;5;28mself\u001b[39m,),\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 486\u001b[0m inputs\u001b[38;5;241m=\u001b[39minputs,\n\u001b[1;32m 487\u001b[0m )\n\u001b[0;32m--> 488\u001b[0m \u001b[43mtorch\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mautograd\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mbackward\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 489\u001b[0m \u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgradient\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43minputs\u001b[49m\n\u001b[1;32m 490\u001b[0m \u001b[43m\u001b[49m\u001b[43m)\u001b[49m\n",
"File \u001b[0;32m~/.local/lib/python3.10/site-packages/torch/autograd/__init__.py:197\u001b[0m, in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 192\u001b[0m retain_graph \u001b[38;5;241m=\u001b[39m create_graph\n\u001b[1;32m 194\u001b[0m \u001b[38;5;66;03m# The reason we repeat same the comment below is that\u001b[39;00m\n\u001b[1;32m 195\u001b[0m \u001b[38;5;66;03m# some Python versions print out the first line of a multi-line function\u001b[39;00m\n\u001b[1;32m 196\u001b[0m \u001b[38;5;66;03m# calls in the traceback and some print out the last line\u001b[39;00m\n\u001b[0;32m--> 197\u001b[0m \u001b[43mVariable\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_execution_engine\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mrun_backward\u001b[49m\u001b[43m(\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;66;43;03m# Calls into the C++ engine to run the backward pass\u001b[39;49;00m\n\u001b[1;32m 198\u001b[0m \u001b[43m \u001b[49m\u001b[43mtensors\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mgrad_tensors_\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mretain_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mcreate_graph\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 199\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_unreachable\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43maccumulate_grad\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n",
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"# Training\n",
"# For a context x and candidate next character c the network assigns an energy\n",
"#   E(x, c) = tanh(flatten(C[x + [c]]) @ W1 + b1) @ W2 + b2\n",
"# and p(c | x) = softmax(-E(x, .)) over all candidates; we minimize the negative\n",
"# log-likelihood of the observed next character.\n",
"max_steps = 200000\n",
"batch_size = 32\n",
"vocab_size = C.shape[0]   # number of candidate next characters\n",
"in_dim = W1.shape[0]      # (block_size + 1) * embedding dim\n",
"\n",
"# NOTE: the learning rate rises exponentially from 1e-3 to 1 over the run\n",
"# (a learning-rate sweep, not a decay); lri records the exponent per step.\n",
"lre = torch.linspace(-3, 0, max_steps)\n",
"lrs = 10**lre\n",
"\n",
"lri = []\n",
"lossi = []\n",
"stepi = []\n",
"\n",
"# Candidate next characters. stoi maps '.' -> 0 and chars[k] -> k + 1, so this is\n",
"# [0, 1, ..., vocab_size - 1]: column k of all_energies scores character k, and the\n",
"# targets in Ytr can index the columns directly.\n",
"char_ints = [stoi[c] for c in chars]\n",
"next_chars = torch.tensor([0] + char_ints)\n",
"assert torch.equal(next_chars, torch.arange(vocab_size))\n",
"\n",
"for i in range(max_steps):\n",
"    # minibatch construct (index first, so only the batch is ever embedded)\n",
"    ix = torch.randint(0, Xtr.shape[0], (batch_size,))\n",
"    xs, ys = Xtr[ix], Ytr[ix]\n",
"\n",
"    # pair each context with every candidate: (batch, vocab_size, block_size + 1)\n",
"    seqs = torch.cat((xs.unsqueeze(1).expand(-1, vocab_size, -1),\n",
"                      next_chars.view(1, -1, 1).expand(len(xs), -1, -1)), dim=2)\n",
"\n",
"    # forward pass: one energy per (example, candidate)\n",
"    embs = C[seqs]\n",
"    hs = torch.tanh(embs.view(-1, in_dim) @ W1 + b1)\n",
"    all_energies = (hs @ W2 + b2).view(-1, vocab_size)\n",
"\n",
"    # -log p(y | x) = E(x, y) + logsumexp(-E(x, .)). cross_entropy with logits -E\n",
"    # computes exactly this without exponentiating raw energies, which can\n",
"    # overflow/underflow and turn the loss into inf/nan.\n",
"    loss = F.cross_entropy(-all_energies, ys)\n",
"\n",
"    # backward pass\n",
"    for p in parameters:\n",
"        p.grad = None\n",
"    loss.backward()\n",
"\n",
"    # update\n",
"    lr = lrs[i]\n",
"    for p in parameters:\n",
"        p.data += -lr * p.grad\n",
"\n",
"    # track stats\n",
"    lri.append(lre[i])\n",
"    stepi.append(i)\n",
"    lossi.append(loss.log10().item())\n",
"\n",
"    if i % 1000 == 0:\n",
"        print(i, 'loss: ', loss.item())\n",
"\n",
"print('final loss: ', loss.item())\n",
"\n",
"plt.figure()\n",
"plt.plot(stepi, lossi)\n",
"plt.xlabel('step')\n",
"plt.ylabel('log10(loss)')\n",
"\n",
"# visualize dimensions 0 and 1 of the embedding matrix C for all characters\n",
"plt.figure(figsize=(8,8))\n",
"plt.scatter(C[:,0].data, C[:,1].data, s=200)\n",
"for i in range(C.shape[0]):\n",
"    plt.text(C[i,0].item(), C[i,1].item(), itos[i], ha='center', va='center', color='white')\n",
"plt.grid(True)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"collapsed": false
},
"outputs": [
{
"ename": "RuntimeError",
"evalue": "The expanded size of the tensor (27) must match the existing size (3) at non-singleton dimension 1. Target sizes: [-1, 27, -1]. Tensor sizes: [3, 1]",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[8], line 7\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# next_chars = torch.tensor([0] + char_ints).expand(32,-1)\u001b[39;00m\n\u001b[1;32m 6\u001b[0m xs \u001b[38;5;241m=\u001b[39m Xtr[ix]\n\u001b[0;32m----> 7\u001b[0m xs, y \u001b[38;5;241m=\u001b[39m \u001b[43mxs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43munsqueeze\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexpand\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m27\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m, next_chars\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m 8\u001b[0m seqs_vec \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcat((xs,y), dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m 10\u001b[0m \u001b[38;5;66;03m# seqs should be a (32, 26, 4) tensor of whose last layer of rows are the rows\u001b[39;00m\n\u001b[1;32m 11\u001b[0m \u001b[38;5;66;03m# of Xtr with an additional element drawn from next_chars\u001b[39;00m\n",
"\u001b[0;31mRuntimeError\u001b[0m: The expanded size of the tensor (27) must match the existing size (3) at non-singleton dimension 1. Target sizes: [-1, 27, -1]. Tensor sizes: [3, 1]"
]
}
],
"source": [
"# Sanity check: the vectorized construction of candidate windows used in training\n",
"# matches a straightforward (slow) nested-loop construction.\n",
"# Self-contained: draws its own minibatch and candidate list instead of relying on\n",
"# ix / next_chars left behind by other cells (the sampling cell rebinds ix to an int).\n",
"ix = torch.randint(0, Xtr.shape[0], (32,))\n",
"char_ints = [stoi[c] for c in chars]\n",
"n_cand = len(char_ints) + 1\n",
"next_chars = torch.tensor([0] + char_ints).expand(len(ix), -1)\n",
"\n",
"xs = Xtr[ix]\n",
"xs, y = xs.unsqueeze(1).expand(-1, n_cand, -1), next_chars.unsqueeze(2)\n",
"seqs_vec = torch.cat((xs, y), dim=2)\n",
"\n",
"# seqs is a (32, n_cand, block_size + 1) tensor: for each context row of Xtr[ix],\n",
"# one row per candidate next character c, namely the context followed by c.\n",
"seqs = torch.stack([torch.stack([torch.cat((row, torch.tensor([c])), dim=0) for c in [0] + char_ints]) for row in Xtr[ix]])\n",
"assert torch.equal(seqs_vec, seqs)\n",
"seqs_vec.shape"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"collapsed": false
},
"outputs": [
{
"ename": "RuntimeError",
"evalue": "The expanded size of the tensor (27) must match the existing size (3) at non-singleton dimension 1. Target sizes: [-1, 27, -1]. Tensor sizes: [3, 1]",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[9], line 7\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# next_chars = torch.tensor([0] + char_ints).expand(32,-1)\u001b[39;00m\n\u001b[1;32m 6\u001b[0m xs \u001b[38;5;241m=\u001b[39m Xtr[ix]\n\u001b[0;32m----> 7\u001b[0m xs, y \u001b[38;5;241m=\u001b[39m \u001b[43mxs\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43munsqueeze\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mexpand\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m27\u001b[39;49m\u001b[43m,\u001b[49m\u001b[38;5;241;43m-\u001b[39;49m\u001b[38;5;241;43m1\u001b[39;49m\u001b[43m)\u001b[49m, next_chars\u001b[38;5;241m.\u001b[39munsqueeze(\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m 8\u001b[0m seqs_vec \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mcat((xs,y), dim\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m2\u001b[39m)\n\u001b[1;32m 10\u001b[0m \u001b[38;5;66;03m# seqs should be a (32, 26, 4) tensor of whose last layer of rows are the rows\u001b[39;00m\n\u001b[1;32m 11\u001b[0m \u001b[38;5;66;03m# of Xtr with an additional element drawn from next_chars\u001b[39;00m\n",
"\u001b[0;31mRuntimeError\u001b[0m: The expanded size of the tensor (27) must match the existing size (3) at non-singleton dimension 1. Target sizes: [-1, 27, -1]. Tensor sizes: [3, 1]"
]
}
],
"source": [
"# Sanity check: the energy of each observed window (context + true next char)\n",
"# equals that character's entry in the all-candidates energy matrix, which is\n",
"# what lets the training loss be computed from all_energies alone.\n",
"# Self-contained: draws its own minibatch rather than reusing ix from another cell.\n",
"with torch.no_grad():\n",
"    ix = torch.randint(0, Xtr.shape[0], (32,))\n",
"    xs, ys = Xtr[ix], Ytr[ix]\n",
"\n",
"    # energies of the observed windows; embed only the minibatch, not all of Xtr\n",
"    emb = C[torch.cat((xs, ys.view(-1, 1)), dim=1)]\n",
"    h = torch.tanh(emb.view(len(xs), -1) @ W1 + b1)\n",
"    energies = h @ W2 + b2\n",
"\n",
"    # energies of every candidate continuation, shape (32, number of candidates)\n",
"    next_chars = torch.tensor([0] + [stoi[c] for c in chars])\n",
"    n_cand = len(next_chars)\n",
"    seqs_vec = torch.cat((xs.unsqueeze(1).expand(-1, n_cand, -1),\n",
"                          next_chars.view(1, -1, 1).expand(len(xs), -1, -1)), dim=2)\n",
"    hs = torch.tanh(C[seqs_vec].view(len(xs) * n_cand, -1) @ W1 + b1)\n",
"    all_energies = (hs @ W2 + b2).view(len(xs), n_cand)\n",
"\n",
"# candidates are exactly [0, 1, ..., n_cand - 1], so column k scores character k\n",
"assert torch.equal(next_chars, torch.arange(n_cand))\n",
"assert torch.allclose(energies, all_energies[torch.arange(len(ys)), ys], rtol=1e-4, atol=1e-4)\n",
"all_energies.shape"
]
},
{
"cell_type": "code",
"execution_count": 20,
"metadata": {
"collapsed": false
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"po.\n",
"rakaal.\n",
"niu.\n",
"kesha.\n",
"asytein.\n",
"hioa.\n",
"brai.\n",
"kassa.\n",
"jeday.\n",
"zer.\n",
"jntyil.\n",
"lanyn.\n",
"doryyvetninaea.\n",
"konarn.\n",
"ineiah.\n",
"mh.\n",
"sipol.\n",
"lkl.\n",
"u.\n",
"bynie.\n",
"pyka.\n",
"nyyeu.\n",
"bsaadalaema.\n",
"madii.\n",
"oeuee.\n",
"adntiyn.\n",
"iinia.\n",
"dmtyaedes.\n",
"angenn.\n",
"gyl.\n",
"kybydknt.\n",
"narette.\n",
"arenpsgde.\n",
"ka.\n",
"b.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"kiit.\n",
"kamnq.\n",
"gavn.\n",
"elaeazin.\n",
".\n",
"ale.\n",
"ftll.\n",
"au.\n",
"lilu.\n",
"hada.\n",
"niw.\n",
"bern.\n",
"flyaneeliel.\n",
"dohylam.\n",
"orarnnn.\n",
"stena.\n",
"kysynray.\n",
"jok.\n",
"fien.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"krinlaiyaedvhaya.\n",
"kawie.\n",
"jalay.\n",
"ied.\n",
"ezn.\n",
"beoptnthewt.\n",
"arvtrhya.\n",
"kalmgbe.\n",
"jauzelia.\n",
"zylh.\n",
"z.\n",
"cf.\n",
"nidn.\n",
"roa.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"oawaivcakonwnylaydeiha.\n",
"btna.\n",
"ami.\n",
"fovt.\n",
"halzoa.\n",
"kianeah.\n",
"alyna.\n",
"asaomaan.\n",
"maiallo.\n",
"mabnieayde.\n",
"coracselnaem.\n",
"zlor.\n",
"tohtniy.\n",
"jiut.\n",
"tkaekai.\n",
"bzlennel.\n",
"sabicwawldya.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"maxhd.\n",
"zorynioh.\n",
"rivlnyhmieiaon.\n",
"i.\n",
"aninee.\n",
"maglit.\n",
"chni.\n",
"ehd.\n",
"asaisha.\n",
"b.\n",
"kaemunn.\n",
"kansva.\n",
"wynieanuka.\n",
"zo.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"nid.\n",
"nal.\n",
"ayyarolii.\n",
"za.\n",
"jhda.\n",
"anaqir.\n",
"ayaezolzaveil.\n",
"tres.\n",
"kaiha.\n",
"qayi.\n",
"fovz.\n",
"tuasntenkaahe.\n",
"roiiah.\n",
"anotgae.\n",
"ewn.\n",
"ldrloryny.\n",
"aniiah.\n",
"lah.\n",
".\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"jofvcmirn.\n",
"aqb.\n",
"toalaa.\n",
"gv.\n",
"avmotniyah.\n",
"mloatoi.\n",
"cnnnusiy.\n",
"naelnll.\n",
"javagl.\n",
"ooqb.\n",
"zoxskdrel.\n",
"a.\n",
"taxyrie.\n",
"abniea.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"cieneoyuis.\n",
"keu.\n",
"jamnllnsmaehi.\n",
"emiui.\n",
"avcyniarnn.\n",
"vayretwiyekamfkmvwei.\n",
"biararailon.\n",
"reo.\n",
"raeodnhany.\n",
"iasyuninyni.\n",
"amona.\n",
"rovhu.\n",
"isra.\n",
"hi.\n",
"zeynn.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"kayena.\n",
"eigh.\n",
"nil.\n",
"sa.\n",
"aneanle.\n",
"dnn.\n",
"hlsynelak.\n",
"cwe.\n",
"gd.\n",
"fay.\n",
"dessri.\n",
"hgn.\n",
"dorh.\n",
"sdlan.\n",
"abi.\n",
"b.\n",
"teasyueeln.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"laiio.\n",
"euseazan.\n",
"eorneilee.\n",
"jfa.\n",
"abfiedil.\n",
"atnliea.\n",
"soalku.\n",
"zi.\n",
"kerdyan.\n",
"jayza.\n",
"tiia.\n",
"p.\n",
"ken.\n",
"ldanen.\n",
"marinnnn.\n",
"xakobua.\n",
"jah.\n",
"arii.\n",
"deus.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"saooaidaanna.\n",
"tdhunestavarfiiyansyn.\n",
"aou.\n",
"biaagoles.\n",
"ro.\n",
"ervhnn.\n",
"zeaalaa.\n",
"tdin.\n",
"esn.\n",
"hninillargi.\n",
"lislaln.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"ita.\n",
"eoe.\n",
"maraylsinslema.\n",
"dixed.\n",
"aydrs.\n",
"sateisyld.\n",
"po.\n",
"rakaal.\n",
"niu.\n",
"kesha.\n",
"asytein.\n",
"hioa.\n",
"brai.\n",
"kassa.\n",
"jeday.\n",
"zer.\n",
"jntyil.\n",
"lanyn.\n",
"doryyvetninaea.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"konarn.\n",
"ineiah.\n",
"mh.\n",
"sipol.\n",
"lkl.\n",
"u.\n",
"bynie.\n",
"pyka.\n",
"nyyeu.\n",
"bsaadalaema.\n",
"madii.\n",
"oeuee.\n",
"adntiyn.\n",
"iinia.\n",
"dmtyaedes.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"angenn.\n",
"gyl.\n",
"kybydknt.\n",
"narette.\n",
"arenpsgde.\n",
"ka.\n",
"b.\n",
"kiit.\n",
"kamnq.\n",
"gavn.\n",
"elaeazin.\n",
".\n",
"ale.\n",
"ftll.\n",
"au.\n",
"lilu.\n",
"hada.\n",
"niw.\n",
"bern.\n",
"flyaneeliel.\n",
"dohylam.\n",
"orarnnn.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"stena.\n",
"kysynray.\n",
"jok.\n",
"fien.\n"
]
}
],
"source": [
"# sample from the model\n",
"# The generator is created without manual_seed; uncomment it to pin the seed explicitly.\n",
"g = torch.Generator() # .manual_seed(2147483647)\n",
"\n",
"# every candidate next character, as indices (same order as in training)\n",
"next_chars = torch.tensor([0] + [stoi[c] for c in chars])\n",
"\n",
"with torch.no_grad():  # inference only: don't build autograd graphs\n",
"    for _ in range(200):\n",
"        out = []\n",
"        context = [0] * block_size # initialize with all ...\n",
"        while True:\n",
"            # one window per candidate c: context + [c], shape (n_cand, block_size + 1)\n",
"            seq_options = torch.cat((torch.tensor(context).expand(len(next_chars), -1),\n",
"                                     next_chars.view(-1, 1)), dim=1)\n",
"            # forward pass, batched over all candidates at once\n",
"            embs = C[seq_options]                                      # (n_cand, block_size + 1, d)\n",
"            hs = torch.tanh(embs.view(len(next_chars), -1) @ W1 + b1)  # (n_cand, hidden)\n",
"            energies = hs @ W2 + b2                                    # (n_cand,)\n",
"\n",
"            # p(c | context) = softmax(-E); torch.softmax is stable for large energies\n",
"            probs = torch.softmax(-energies, dim=0)\n",
"            pos = torch.multinomial(probs, num_samples=1, generator=g)\n",
"            ix = next_chars[pos].item()  # map sampled position back to a character index\n",
"            context = context[1:] + [ix]\n",
"            out.append(ix)\n",
"            if ix == 0:\n",
"                break\n",
"\n",
"        print(''.join(itos[i] for i in out))\n"
]
}
],
"metadata": {
"kernelspec": {
"argv": [
"python",
"-m",
"ipykernel_launcher",
"-f",
"{connection_file}"
],
"display_name": "Python 3 (ipykernel)",
"env": null,
"interrupt_mode": "signal",
"language": "python",
"metadata": {
"debugger": true
},
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.4"
},
"name": "energy_mlp.ipynb",
"vscode": {
"interpreter": {
"hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment