Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yangju2011/90f34cf6646119a05a7e85abc67c3b44 to your computer and use it in GitHub Desktop.
Save yangju2011/90f34cf6646119a05a7e85abc67c3b44 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "dccd1777",
"metadata": {},
"source": [
"Original paper: Bengio et al. 2003 A Neural Probabilistic Language Model\n",
"- challenges\n",
" - curse of dimensionality for joint probability \n",
" - context: word sequence\n",
" - similarity between words\n",
"- solution\n",
" - a learned distributed feature vector (30-dimension) to represent each word in the vocabulary of 17000 words\n",
" - each training sequence informs the model about a combinatorial number of other sentences\n",
" - input layer is a 30-dim vector per word, 3 words -> 90 neurons -> hidden layer (tanh) -> 17000 output neurons (logits) -> softmax gives P(w_i|context). Most computation happens at the output layer.\n",
"\n",
"In this notebook, we will **predict the next character based on 3 previous characters**."
]
},
{
"cell_type": "markdown",
"id": "14b30ea6",
"metadata": {},
"source": [
"# 1. string <> integer map for vocabulary"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "6ed51a65",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a72e7a3c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['emma', 'olivia', 'ava', 'isabella', 'sophia']"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"words = open('names.txt','r').read().splitlines()\n",
"words[:5]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6f531a94",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"32033"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# corpus\n",
"len(words)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e83a6060",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}\n"
]
}
],
"source": [
"# build the vocabulary of chars and mappings to/from integers\n",
"# same as in bag-of-words\n",
"chars = sorted(list(set(''.join(words))))\n",
"# string to integer\n",
"stoi = {s:i+1 for i, s in enumerate(chars)}\n",
"# special char to represent start and end\n",
"stoi['.'] = 0\n",
"# integer to string\n",
"itos = {i:s for s, i in stoi.items()}\n",
"print (itos)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "ca1a37a3",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--------------\n",
"emma\n",
"... ---> e\n",
"context before: [0, 0, 0]\n",
"output char: e\n",
"output index 5\n",
"context after: [0, 0, 5]\n",
"..e ---> m\n",
"context before: [0, 0, 5]\n",
"output char: m\n",
"output index 13\n",
"context after: [0, 5, 13]\n",
".em ---> m\n",
"context before: [0, 5, 13]\n",
"output char: m\n",
"output index 13\n",
"context after: [5, 13, 13]\n",
"emm ---> a\n",
"context before: [5, 13, 13]\n",
"output char: a\n",
"output index 1\n",
"context after: [13, 13, 1]\n",
"mma ---> .\n",
"context before: [13, 13, 1]\n",
"output char: .\n",
"output index 0\n",
"context after: [13, 1, 0]\n",
"--------------\n",
"olivia\n",
"... ---> o\n",
"context before: [0, 0, 0]\n",
"output char: o\n",
"output index 15\n",
"context after: [0, 0, 15]\n",
"..o ---> l\n",
"context before: [0, 0, 15]\n",
"output char: l\n",
"output index 12\n",
"context after: [0, 15, 12]\n",
".ol ---> i\n",
"context before: [0, 15, 12]\n",
"output char: i\n",
"output index 9\n",
"context after: [15, 12, 9]\n",
"oli ---> v\n",
"context before: [15, 12, 9]\n",
"output char: v\n",
"output index 22\n",
"context after: [12, 9, 22]\n",
"liv ---> i\n",
"context before: [12, 9, 22]\n",
"output char: i\n",
"output index 9\n",
"context after: [9, 22, 9]\n",
"ivi ---> a\n",
"context before: [9, 22, 9]\n",
"output char: a\n",
"output index 1\n",
"context after: [22, 9, 1]\n",
"via ---> .\n",
"context before: [22, 9, 1]\n",
"output char: .\n",
"output index 0\n",
"context after: [9, 1, 0]\n",
"--------------\n",
"ava\n",
"... ---> a\n",
"context before: [0, 0, 0]\n",
"output char: a\n",
"output index 1\n",
"context after: [0, 0, 1]\n",
"..a ---> v\n",
"context before: [0, 0, 1]\n",
"output char: v\n",
"output index 22\n",
"context after: [0, 1, 22]\n",
".av ---> a\n",
"context before: [0, 1, 22]\n",
"output char: a\n",
"output index 1\n",
"context after: [1, 22, 1]\n",
"ava ---> .\n",
"context before: [1, 22, 1]\n",
"output char: .\n",
"output index 0\n",
"context after: [22, 1, 0]\n",
"--------------\n",
"isabella\n",
"... ---> i\n",
"context before: [0, 0, 0]\n",
"output char: i\n",
"output index 9\n",
"context after: [0, 0, 9]\n",
"..i ---> s\n",
"context before: [0, 0, 9]\n",
"output char: s\n",
"output index 19\n",
"context after: [0, 9, 19]\n",
".is ---> a\n",
"context before: [0, 9, 19]\n",
"output char: a\n",
"output index 1\n",
"context after: [9, 19, 1]\n",
"isa ---> b\n",
"context before: [9, 19, 1]\n",
"output char: b\n",
"output index 2\n",
"context after: [19, 1, 2]\n",
"sab ---> e\n",
"context before: [19, 1, 2]\n",
"output char: e\n",
"output index 5\n",
"context after: [1, 2, 5]\n",
"abe ---> l\n",
"context before: [1, 2, 5]\n",
"output char: l\n",
"output index 12\n",
"context after: [2, 5, 12]\n",
"bel ---> l\n",
"context before: [2, 5, 12]\n",
"output char: l\n",
"output index 12\n",
"context after: [5, 12, 12]\n",
"ell ---> a\n",
"context before: [5, 12, 12]\n",
"output char: a\n",
"output index 1\n",
"context after: [12, 12, 1]\n",
"lla ---> .\n",
"context before: [12, 12, 1]\n",
"output char: .\n",
"output index 0\n",
"context after: [12, 1, 0]\n",
"--------------\n",
"sophia\n",
"... ---> s\n",
"context before: [0, 0, 0]\n",
"output char: s\n",
"output index 19\n",
"context after: [0, 0, 19]\n",
"..s ---> o\n",
"context before: [0, 0, 19]\n",
"output char: o\n",
"output index 15\n",
"context after: [0, 19, 15]\n",
".so ---> p\n",
"context before: [0, 19, 15]\n",
"output char: p\n",
"output index 16\n",
"context after: [19, 15, 16]\n",
"sop ---> h\n",
"context before: [19, 15, 16]\n",
"output char: h\n",
"output index 8\n",
"context after: [15, 16, 8]\n",
"oph ---> i\n",
"context before: [15, 16, 8]\n",
"output char: i\n",
"output index 9\n",
"context after: [16, 8, 9]\n",
"phi ---> a\n",
"context before: [16, 8, 9]\n",
"output char: a\n",
"output index 1\n",
"context after: [8, 9, 1]\n",
"hia ---> .\n",
"context before: [8, 9, 1]\n",
"output char: .\n",
"output index 0\n",
"context after: [9, 1, 0]\n"
]
}
],
"source": [
"# build the dataset and see how it works\n",
"\n",
"block_size = 3\n",
"X, Y = [], []\n",
"for w in words[:5]:\n",
" print ('--------------')\n",
" print (w)\n",
" # integer 0 -> .\n",
" # ... ---> e\n",
" context = [0] * block_size\n",
" for ch in w + '.':\n",
" ix = stoi[ch]\n",
" X.append(context)\n",
" Y.append(ix)\n",
" print (''.join(itos[i] for i in context), '--->', itos[ix])\n",
" print ('context before:', context)\n",
" print ('output char:',ch)\n",
" print ('output index',ix)\n",
"\n",
" # sliding window to the next context\n",
" context = context[1:] + [ix]\n",
" print ('context after:', context)\n",
" \n",
" \n",
"X=torch.tensor(X)\n",
"Y=torch.tensor(Y)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "8d3190c7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([32, 3]), torch.int64, torch.Size([32]), torch.int64)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X.shape, X.dtype, Y.shape, Y.dtype"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "cdf7a0f4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0, 0, 0],\n",
" [ 0, 0, 5],\n",
" [ 0, 5, 13],\n",
" [ 5, 13, 13],\n",
" [13, 13, 1]])"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X[:5]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "85ca81b4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 5, 13, 13, 1, 0])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Y[:5]"
]
},
{
"cell_type": "markdown",
"id": "c55a1fab",
"metadata": {},
"source": [
"# 2. feature vector"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "38c6aecf",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"27"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Vdim = len(stoi.keys())\n",
"Vdim"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "61fc8fec",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.4464, -0.6289],\n",
" [ 1.3162, -0.2359],\n",
" [ 1.3732, 0.4287],\n",
" [-0.0174, 0.2338],\n",
" [ 0.3223, 0.4244],\n",
" [-0.3297, -0.0899],\n",
" [-0.3791, 0.2541],\n",
" [-0.8308, 0.4593],\n",
" [ 1.5862, 0.9561],\n",
" [-0.5079, 0.7066],\n",
" [ 0.5257, 1.8360],\n",
" [-0.7879, 1.7276],\n",
" [ 0.5974, 1.3534],\n",
" [-1.2337, -0.4362],\n",
" [-1.8687, -1.5495],\n",
" [ 1.0291, -0.0078],\n",
" [ 1.7785, 0.6016],\n",
" [ 1.2043, -0.8093],\n",
" [-0.0207, 0.7163],\n",
" [ 0.0288, 1.1582],\n",
" [-1.0672, -1.6540],\n",
" [-0.9492, 0.2938],\n",
" [-0.1061, -1.3482],\n",
" [ 0.3623, -0.1373],\n",
" [-0.0161, -0.2087],\n",
" [-0.8985, -0.2570],\n",
" [-0.4422, 0.4705]])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# map each of the 27 unique chars to a 2-dim feature vector (one row of C per char)\n",
"# layer with 27 input and 2 output\n",
"\n",
"Cdim = 2\n",
"# initialize with random numbers from a normal distribution with mean `0` and variance `1`\n",
"C = torch.randn((Vdim, Cdim))\n",
"C"
]
},
{
"cell_type": "markdown",
"id": "9b184414",
"metadata": {},
"source": [
"## embedding of the integer (feature vector for a char)\n",
"1. index for C\n",
"2. one-hot encoding @ C"
]
},
{
"cell_type": "markdown",
"id": "e37da9d0",
"metadata": {},
"source": [
"### 1. index for C"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "5058e4c9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([-0.4464, -0.6289])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# preferred: direct indexing is FASTER than matrix multiplication\n",
"C[0]"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "09e5c2ad",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([-0.4464, -0.6289])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"C[stoi['.']]"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "80d34c62",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([-0.3297, -0.0899])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"C[5]"
]
},
{
"cell_type": "markdown",
"id": "6b8ef1c0",
"metadata": {},
"source": [
"### 2. one-hot encoding @ C"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "3d832521",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0])"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"F.one_hot(torch.tensor(5), num_classes=27)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "a9c49619",
"metadata": {},
"outputs": [
{
"ename": "RuntimeError",
"evalue": "expected scalar type Long but found Float",
"output_type": "error",
"traceback": [
"\u001b[0;31m-------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0mTraceback (most recent call last)",
"\u001b[0;32m<ipython-input-15-246e9ef054ac>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# one_hot is int, C is float, does not know how to multiply\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mone_hot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_classes\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m27\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mC\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m: expected scalar type Long but found Float"
]
}
],
"source": [
"# one_hot is int, C is float, does not know how to multiply\n",
"F.one_hot(torch.tensor(5), num_classes=27) @ C"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "4ac24f98",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.int64"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"F.one_hot(torch.tensor(5), num_classes=27).dtype"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "4e1dfed9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([-0.3297, -0.0899])"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# matrix multiplication to get the i-th row from C\n",
"F.one_hot(torch.tensor(5), num_classes=27).float() @ C"
]
},
{
"cell_type": "markdown",
"id": "a87d59b9",
"metadata": {},
"source": [
"## embedding of X"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "f5173d6a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([-0.4464, -0.6289])"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"C[0]"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "eeefed64",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.4464, -0.6289],\n",
" [ 1.3162, -0.2359],\n",
" [ 1.3732, 0.4287],\n",
" [ 1.3732, 0.4287],\n",
" [ 1.3732, 0.4287]])"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"C[[0,1,2,2,2]]"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "6791fef3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.4464, -0.6289],\n",
" [ 1.3162, -0.2359],\n",
" [ 1.3732, 0.4287],\n",
" [ 1.3732, 0.4287],\n",
" [ 1.3732, 0.4287]])"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# retrieve value from a matrix\n",
"C[torch.tensor([0,1,2,2,2])]"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "7b327d78",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([32, 3])"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# from 5 words, 32 training samples\n",
"X.shape"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "0f104d34",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0, 0, 0],\n",
" [ 0, 0, 5],\n",
" [ 0, 5, 13],\n",
" [ 5, 13, 13],\n",
" [13, 13, 1],\n",
" [ 0, 0, 0],\n",
" [ 0, 0, 15],\n",
" [ 0, 15, 12],\n",
" [15, 12, 9],\n",
" [12, 9, 22],\n",
" [ 9, 22, 9],\n",
" [22, 9, 1],\n",
" [ 0, 0, 0],\n",
" [ 0, 0, 1],\n",
" [ 0, 1, 22],\n",
" [ 1, 22, 1],\n",
" [ 0, 0, 0],\n",
" [ 0, 0, 9],\n",
" [ 0, 9, 19],\n",
" [ 9, 19, 1],\n",
" [19, 1, 2],\n",
" [ 1, 2, 5],\n",
" [ 2, 5, 12],\n",
" [ 5, 12, 12],\n",
" [12, 12, 1],\n",
" [ 0, 0, 0],\n",
" [ 0, 0, 19],\n",
" [ 0, 19, 15],\n",
" [19, 15, 16],\n",
" [15, 16, 8],\n",
" [16, 8, 9],\n",
" [ 8, 9, 1]])"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X\n",
"# X is the training data; each X[i] is a context of 3 consecutive chars\n",
"# X[a,b] is the b-th char in the a-th record"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "0f8bd881",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(9)"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X[10,2]"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "9f4c8ffe",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'i'"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"itos[X[10,2].item()]"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "64227f92",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([-0.5079, 0.7066])"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# indexing C with a single integer returns a 2-dim vector\n",
"# each char is encoded as a 2-dim row of C\n",
"# get the embedding from C for a char\n",
"C[X[10,2]]"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "0da98bb8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([32, 3, 2])"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# X:    [32, 3]\n",
"# C:    [27, 2]\n",
"# C[X]: [32, 3, 2]\n",
"C[X].shape"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "942cde17",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289]],\n",
"\n",
" [[-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.3297, -0.0899]],\n",
"\n",
" [[-0.4464, -0.6289],\n",
" [-0.3297, -0.0899],\n",
" [-1.2337, -0.4362]],\n",
"\n",
" [[-0.3297, -0.0899],\n",
" [-1.2337, -0.4362],\n",
" [-1.2337, -0.4362]],\n",
"\n",
" [[-1.2337, -0.4362],\n",
" [-1.2337, -0.4362],\n",
" [ 1.3162, -0.2359]],\n",
"\n",
" [[-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289]],\n",
"\n",
" [[-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [ 1.0291, -0.0078]],\n",
"\n",
" [[-0.4464, -0.6289],\n",
" [ 1.0291, -0.0078],\n",
" [ 0.5974, 1.3534]],\n",
"\n",
" [[ 1.0291, -0.0078],\n",
" [ 0.5974, 1.3534],\n",
" [-0.5079, 0.7066]],\n",
"\n",
" [[ 0.5974, 1.3534],\n",
" [-0.5079, 0.7066],\n",
" [-0.1061, -1.3482]],\n",
"\n",
" [[-0.5079, 0.7066],\n",
" [-0.1061, -1.3482],\n",
" [-0.5079, 0.7066]],\n",
"\n",
" [[-0.1061, -1.3482],\n",
" [-0.5079, 0.7066],\n",
" [ 1.3162, -0.2359]],\n",
"\n",
" [[-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289]],\n",
"\n",
" [[-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [ 1.3162, -0.2359]],\n",
"\n",
" [[-0.4464, -0.6289],\n",
" [ 1.3162, -0.2359],\n",
" [-0.1061, -1.3482]],\n",
"\n",
" [[ 1.3162, -0.2359],\n",
" [-0.1061, -1.3482],\n",
" [ 1.3162, -0.2359]],\n",
"\n",
" [[-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289]],\n",
"\n",
" [[-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.5079, 0.7066]],\n",
"\n",
" [[-0.4464, -0.6289],\n",
" [-0.5079, 0.7066],\n",
" [ 0.0288, 1.1582]],\n",
"\n",
" [[-0.5079, 0.7066],\n",
" [ 0.0288, 1.1582],\n",
" [ 1.3162, -0.2359]],\n",
"\n",
" [[ 0.0288, 1.1582],\n",
" [ 1.3162, -0.2359],\n",
" [ 1.3732, 0.4287]],\n",
"\n",
" [[ 1.3162, -0.2359],\n",
" [ 1.3732, 0.4287],\n",
" [-0.3297, -0.0899]],\n",
"\n",
" [[ 1.3732, 0.4287],\n",
" [-0.3297, -0.0899],\n",
" [ 0.5974, 1.3534]],\n",
"\n",
" [[-0.3297, -0.0899],\n",
" [ 0.5974, 1.3534],\n",
" [ 0.5974, 1.3534]],\n",
"\n",
" [[ 0.5974, 1.3534],\n",
" [ 0.5974, 1.3534],\n",
" [ 1.3162, -0.2359]],\n",
"\n",
" [[-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289]],\n",
"\n",
" [[-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [ 0.0288, 1.1582]],\n",
"\n",
" [[-0.4464, -0.6289],\n",
" [ 0.0288, 1.1582],\n",
" [ 1.0291, -0.0078]],\n",
"\n",
" [[ 0.0288, 1.1582],\n",
" [ 1.0291, -0.0078],\n",
" [ 1.7785, 0.6016]],\n",
"\n",
" [[ 1.0291, -0.0078],\n",
" [ 1.7785, 0.6016],\n",
" [ 1.5862, 0.9561]],\n",
"\n",
" [[ 1.7785, 0.6016],\n",
" [ 1.5862, 0.9561],\n",
" [-0.5079, 0.7066]],\n",
"\n",
" [[ 1.5862, 0.9561],\n",
" [-0.5079, 0.7066],\n",
" [ 1.3162, -0.2359]]])"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# X holds the chars of every training record\n",
"# each char -> 2-dim embedding\n",
"C[X]"
]
},
{
"cell_type": "markdown",
"id": "fc698b6b",
"metadata": {},
"source": [
"# 3. build neural network with the embedded input"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "11b97859",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([32, 3, 2])"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# embedding\n",
"# C[X] contains the embeddings for all inputs: one [3, 2] block per training sample\n",
"emb = C[X]\n",
"emb.shape"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "f4bd8c19",
"metadata": {},
"outputs": [],
"source": [
"# use the MLP architecture from the paper (Bengio et al. 2003)\n",
"# hidden layer\n",
"\n",
"# number of inputs is 3x2: a 2-dim embedding for each of the 3 context chars\n",
"# number of neurons (output): 100\n",
"block_size = 3\n",
"Cdim = 2\n",
"n_input1 = block_size * Cdim\n",
"n_output1 = 100\n",
"\n",
"W1 = torch.randn([n_input1,n_output1])\n",
"# bias\n",
"b1 = torch.rand(n_output1)\n",
"\n",
"# we usually do emb @ W1 + b1\n",
"# but the embedding is stacked, so we need to flatten it first"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "36172342",
"metadata": {},
"outputs": [
{
"ename": "RuntimeError",
"evalue": "mat1 and mat2 shapes cannot be multiplied (96x2 and 6x100)",
"output_type": "error",
"traceback": [
"\u001b[0;31m-------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0mTraceback (most recent call last)",
"\u001b[0;32m<ipython-input-30-355728df8662>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0memb\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mW1\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mb1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m: mat1 and mat2 shapes cannot be multiplied (96x2 and 6x100)"
]
}
],
"source": [
"emb @ W1 + b1"
]
},
{
"cell_type": "markdown",
"id": "914ace3d",
"metadata": {},
"source": [
"## concat the embeddings of the 3 context chars into a single vector\n",
"1. cat: not efficient, allocates memory to create a new tensor\n",
"2. view: efficient!"
]
},
{
"cell_type": "markdown",
"id": "522fad71",
"metadata": {},
"source": [
"### 1. cat"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "a0a80e36",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([32, 2])"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# embedding of the 1st context char for every sample\n",
"emb[:, 0, :].shape"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "d03f39e8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([96, 2])"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# concatenating along dim=0 stacks the rows -> [96, 2], which is NOT the shape we want\n",
"torch.cat([emb[:, 0, :], emb[:, 1, :], emb[:, 2, :]], dim=0).shape"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "682d63fe",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([32, 6])"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.cat([emb[:, 0, :], emb[:, 1, :], emb[:, 2, :]], dim=1).shape"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "eeee9a5f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.4464, -0.6289, -0.4464, -0.6289, -0.4464, -0.6289],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, -0.3297, -0.0899],\n",
" [-0.4464, -0.6289, -0.3297, -0.0899, -1.2337, -0.4362],\n",
" [-0.3297, -0.0899, -1.2337, -0.4362, -1.2337, -0.4362],\n",
" [-1.2337, -0.4362, -1.2337, -0.4362, 1.3162, -0.2359],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, -0.4464, -0.6289],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, 1.0291, -0.0078],\n",
" [-0.4464, -0.6289, 1.0291, -0.0078, 0.5974, 1.3534],\n",
" [ 1.0291, -0.0078, 0.5974, 1.3534, -0.5079, 0.7066],\n",
" [ 0.5974, 1.3534, -0.5079, 0.7066, -0.1061, -1.3482],\n",
" [-0.5079, 0.7066, -0.1061, -1.3482, -0.5079, 0.7066],\n",
" [-0.1061, -1.3482, -0.5079, 0.7066, 1.3162, -0.2359],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, -0.4464, -0.6289],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, 1.3162, -0.2359],\n",
" [-0.4464, -0.6289, 1.3162, -0.2359, -0.1061, -1.3482],\n",
" [ 1.3162, -0.2359, -0.1061, -1.3482, 1.3162, -0.2359],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, -0.4464, -0.6289],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, -0.5079, 0.7066],\n",
" [-0.4464, -0.6289, -0.5079, 0.7066, 0.0288, 1.1582],\n",
" [-0.5079, 0.7066, 0.0288, 1.1582, 1.3162, -0.2359],\n",
" [ 0.0288, 1.1582, 1.3162, -0.2359, 1.3732, 0.4287],\n",
" [ 1.3162, -0.2359, 1.3732, 0.4287, -0.3297, -0.0899],\n",
" [ 1.3732, 0.4287, -0.3297, -0.0899, 0.5974, 1.3534],\n",
" [-0.3297, -0.0899, 0.5974, 1.3534, 0.5974, 1.3534],\n",
" [ 0.5974, 1.3534, 0.5974, 1.3534, 1.3162, -0.2359],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, -0.4464, -0.6289],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, 0.0288, 1.1582],\n",
" [-0.4464, -0.6289, 0.0288, 1.1582, 1.0291, -0.0078],\n",
" [ 0.0288, 1.1582, 1.0291, -0.0078, 1.7785, 0.6016],\n",
" [ 1.0291, -0.0078, 1.7785, 0.6016, 1.5862, 0.9561],\n",
" [ 1.7785, 0.6016, 1.5862, 0.9561, -0.5079, 0.7066],\n",
" [ 1.5862, 0.9561, -0.5079, 0.7066, 1.3162, -0.2359]])"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# there are 32 training samples; for each one we want the embeddings of all 3 chars combined\n",
"# -> [32, 2*3]\n",
"torch.cat([emb[:, 0, :], emb[:, 1, :], emb[:, 2, :]], dim=1)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "2e6d72df",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([32, 3, 2])"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"emb.shape"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "46dc9af3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.3297, -0.0899],\n",
" [-1.2337, -0.4362],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [ 1.0291, -0.0078],\n",
" [ 0.5974, 1.3534],\n",
" [-0.5079, 0.7066],\n",
" [-0.1061, -1.3482],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [ 1.3162, -0.2359],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.5079, 0.7066],\n",
" [ 0.0288, 1.1582],\n",
" [ 1.3162, -0.2359],\n",
" [ 1.3732, 0.4287],\n",
" [-0.3297, -0.0899],\n",
" [ 0.5974, 1.3534],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [ 0.0288, 1.1582],\n",
" [ 1.0291, -0.0078],\n",
" [ 1.7785, 0.6016],\n",
" [ 1.5862, 0.9561]]),\n",
" tensor([[-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.3297, -0.0899],\n",
" [-1.2337, -0.4362],\n",
" [-1.2337, -0.4362],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [ 1.0291, -0.0078],\n",
" [ 0.5974, 1.3534],\n",
" [-0.5079, 0.7066],\n",
" [-0.1061, -1.3482],\n",
" [-0.5079, 0.7066],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [ 1.3162, -0.2359],\n",
" [-0.1061, -1.3482],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.5079, 0.7066],\n",
" [ 0.0288, 1.1582],\n",
" [ 1.3162, -0.2359],\n",
" [ 1.3732, 0.4287],\n",
" [-0.3297, -0.0899],\n",
" [ 0.5974, 1.3534],\n",
" [ 0.5974, 1.3534],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [ 0.0288, 1.1582],\n",
" [ 1.0291, -0.0078],\n",
" [ 1.7785, 0.6016],\n",
" [ 1.5862, 0.9561],\n",
" [-0.5079, 0.7066]]),\n",
" tensor([[-0.4464, -0.6289],\n",
" [-0.3297, -0.0899],\n",
" [-1.2337, -0.4362],\n",
" [-1.2337, -0.4362],\n",
" [ 1.3162, -0.2359],\n",
" [-0.4464, -0.6289],\n",
" [ 1.0291, -0.0078],\n",
" [ 0.5974, 1.3534],\n",
" [-0.5079, 0.7066],\n",
" [-0.1061, -1.3482],\n",
" [-0.5079, 0.7066],\n",
" [ 1.3162, -0.2359],\n",
" [-0.4464, -0.6289],\n",
" [ 1.3162, -0.2359],\n",
" [-0.1061, -1.3482],\n",
" [ 1.3162, -0.2359],\n",
" [-0.4464, -0.6289],\n",
" [-0.5079, 0.7066],\n",
" [ 0.0288, 1.1582],\n",
" [ 1.3162, -0.2359],\n",
" [ 1.3732, 0.4287],\n",
" [-0.3297, -0.0899],\n",
" [ 0.5974, 1.3534],\n",
" [ 0.5974, 1.3534],\n",
" [ 1.3162, -0.2359],\n",
" [-0.4464, -0.6289],\n",
" [ 0.0288, 1.1582],\n",
" [ 1.0291, -0.0078],\n",
" [ 1.7785, 0.6016],\n",
" [ 1.5862, 0.9561],\n",
" [-0.5079, 0.7066],\n",
" [ 1.3162, -0.2359]]))"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Removes a tensor dimension.\n",
"# Returns a tuple of all slices along a given dimension, already without it.\n",
"\n",
"torch.unbind(emb, dim=1)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "b6372560",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([32, 6])"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# unbind and then concat\n",
"torch.cat(torch.unbind(emb, dim=1),dim=1).shape"
]
},
{
"cell_type": "markdown",
"id": "70cf946b",
"metadata": {},
"source": [
"### 2. view\n",
"very efficient because `view` does not copy any data; it returns a tensor that is just a different view of the same underlying data (physical storage)\n",
"http://blog.ezyang.com/2019/05/pytorch-internals/"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "eb233714",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17])"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a = torch.arange(18)\n",
"a"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "018a896f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([18])"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a.shape"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "532dcfe7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8],\n",
" [ 9, 10, 11, 12, 13, 14, 15, 16, 17]])"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a.view([2,9])"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "7fa30b55",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0, 1],\n",
" [ 2, 3],\n",
" [ 4, 5],\n",
" [ 6, 7],\n",
" [ 8, 9],\n",
" [10, 11],\n",
" [12, 13],\n",
" [14, 15],\n",
" [16, 17]])"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a.view([9,2])"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "4d8da572",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0, 1, 2, 3, 4, 5],\n",
" [ 6, 7, 8, 9, 10, 11],\n",
" [12, 13, 14, 15, 16, 17]])"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a.view([3,6])"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "8e9755e5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" 0\n",
" 1\n",
" 2\n",
" 3\n",
" 4\n",
" 5\n",
" 6\n",
" 7\n",
" 8\n",
" 9\n",
" 10\n",
" 11\n",
" 12\n",
" 13\n",
" 14\n",
" 15\n",
" 16\n",
" 17\n",
"[torch.LongStorage of size 18]"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a.storage()\n",
"# the numbers are stored as a flat 1D sequence\n",
"# .view only reinterprets that sequence with a new shape\n",
"# no copy of the values is created"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "777f6498",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([32, 3, 2])"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"emb.shape"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "9c650672",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.4464, -0.6289, -0.4464, -0.6289, -0.4464, -0.6289],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, -0.3297, -0.0899],\n",
" [-0.4464, -0.6289, -0.3297, -0.0899, -1.2337, -0.4362],\n",
" [-0.3297, -0.0899, -1.2337, -0.4362, -1.2337, -0.4362],\n",
" [-1.2337, -0.4362, -1.2337, -0.4362, 1.3162, -0.2359],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, -0.4464, -0.6289],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, 1.0291, -0.0078],\n",
" [-0.4464, -0.6289, 1.0291, -0.0078, 0.5974, 1.3534],\n",
" [ 1.0291, -0.0078, 0.5974, 1.3534, -0.5079, 0.7066],\n",
" [ 0.5974, 1.3534, -0.5079, 0.7066, -0.1061, -1.3482],\n",
" [-0.5079, 0.7066, -0.1061, -1.3482, -0.5079, 0.7066],\n",
" [-0.1061, -1.3482, -0.5079, 0.7066, 1.3162, -0.2359],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, -0.4464, -0.6289],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, 1.3162, -0.2359],\n",
" [-0.4464, -0.6289, 1.3162, -0.2359, -0.1061, -1.3482],\n",
" [ 1.3162, -0.2359, -0.1061, -1.3482, 1.3162, -0.2359],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, -0.4464, -0.6289],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, -0.5079, 0.7066],\n",
" [-0.4464, -0.6289, -0.5079, 0.7066, 0.0288, 1.1582],\n",
" [-0.5079, 0.7066, 0.0288, 1.1582, 1.3162, -0.2359],\n",
" [ 0.0288, 1.1582, 1.3162, -0.2359, 1.3732, 0.4287],\n",
" [ 1.3162, -0.2359, 1.3732, 0.4287, -0.3297, -0.0899],\n",
" [ 1.3732, 0.4287, -0.3297, -0.0899, 0.5974, 1.3534],\n",
" [-0.3297, -0.0899, 0.5974, 1.3534, 0.5974, 1.3534],\n",
" [ 0.5974, 1.3534, 0.5974, 1.3534, 1.3162, -0.2359],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, -0.4464, -0.6289],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, 0.0288, 1.1582],\n",
" [-0.4464, -0.6289, 0.0288, 1.1582, 1.0291, -0.0078],\n",
" [ 0.0288, 1.1582, 1.0291, -0.0078, 1.7785, 0.6016],\n",
" [ 1.0291, -0.0078, 1.7785, 0.6016, 1.5862, 0.9561],\n",
" [ 1.7785, 0.6016, 1.5862, 0.9561, -0.5079, 0.7066],\n",
" [ 1.5862, 0.9561, -0.5079, 0.7066, 1.3162, -0.2359]])"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# hooray: the three 2-dim embeddings of each example are laid side by side, giving 6 values per row\n",
"emb.view([32,6])"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "64eb97d1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True]])"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"emb.view([32,6]) == torch.cat(torch.unbind(emb,1),1)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "2649e57a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([32, 100])"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(emb.view([X.shape[0],n_input1])@W1 + b1).shape"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "33adcb51",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.2642, -0.5217, -0.5207, ..., 1.6704, 0.4607, -0.2842],\n",
" [-0.5069, -0.4191, 0.5045, ..., 1.2557, 0.1441, 0.3194],\n",
" [-0.4562, -0.8144, 0.4525, ..., 0.3793, -1.2262, -0.2351],\n",
" ...,\n",
" [ 0.8261, 3.0585, 2.8153, ..., 0.3313, 4.5300, 2.5886],\n",
" [ 0.4643, 1.9310, 3.9621, ..., -1.6207, 2.8345, 0.2874],\n",
" [ 3.2720, -2.1254, -0.0563, ..., -1.9226, 2.0426, -0.6754]])"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# use -1, it will infer the dim\n",
"h = emb.view([-1,n_input1])@W1 + b1\n",
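"# b1 is 1D with 100 elements\n",
"# broadcasting for + b1 aligns the shapes from the right:\n",
"#   32, 100   (emb.view(...) @ W1)\n",
"#       100   (b1, treated as 1, 100)\n",
"# =>\n",
"#   32, 100   (b1 is added to every row)\n",
"# h holds the hidden-layer pre-activations (no nonlinearity is applied in this cell)\n",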
"h"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "62fdd381",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([32, 27])"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# final layer\n",
"# output is 27 possible chars\n",
"n_output2 = Vdim\n",
"W2 = torch.randn([n_output1, n_output2])\n",
"b2 = torch.randn(n_output2)\n",
"logits = h @ W2 + b2\n",
"logits.shape"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "f3e46359",
"metadata": {},
"outputs": [],
"source": [
"counts = logits.exp()"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "bc4ea2ae",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([32, 27])"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# sum by row\n",
"prob = counts / counts.sum(1, keepdims=True)\n",
"prob.shape"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "e6eecd51",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[1.8994e-06, 7.0665e-01, 2.4528e-14, 3.9459e-27, 3.7168e-15, 2.4027e-01,\n",
" 3.2191e-07, 8.0452e-07, 1.0725e-11, 1.2262e-09, 6.1807e-07, 4.1937e-11,\n",
" 6.2864e-16, 4.8229e-16, 1.9374e-11, 2.9170e-14, 8.0970e-09, 2.6827e-07,\n",
" 4.7176e-02, 5.6506e-06, 2.1539e-09, 1.4395e-21, 5.4801e-06, 3.7784e-07,\n",
" 6.4447e-20, 5.1319e-10, 5.8910e-03],\n",
" [1.5634e-06, 5.8689e-03, 5.5907e-12, 3.6511e-24, 1.9378e-13, 1.6033e-02,\n",
" 1.3811e-09, 8.4612e-10, 9.2904e-12, 7.1332e-04, 7.8706e-05, 1.0586e-08,\n",
" 9.6776e-10, 2.4594e-09, 5.3621e-12, 1.4119e-13, 2.0930e-08, 8.0390e-03,\n",
" 1.7799e-03, 7.4927e-03, 9.5619e-08, 3.6570e-17, 5.1102e-07, 6.5709e-08,\n",
" 8.0131e-14, 9.7636e-08, 9.5999e-01],\n",
" [2.5425e-13, 8.1657e-01, 4.2780e-17, 4.2832e-29, 1.9830e-15, 1.8342e-01,\n",
" 4.6704e-11, 6.8022e-12, 9.1839e-19, 7.4077e-16, 5.5468e-18, 9.6685e-19,\n",
" 8.7594e-16, 9.6322e-16, 3.2213e-11, 2.2215e-15, 8.7443e-19, 2.2657e-11,\n",
" 1.9050e-09, 7.0052e-06, 2.5675e-14, 7.9356e-15, 4.7929e-06, 1.3517e-14,\n",
" 3.5140e-19, 1.3572e-16, 6.8860e-12],\n",
" [1.5926e-16, 8.2677e-01, 4.5069e-28, 6.2221e-31, 3.4318e-19, 4.3759e-02,\n",
" 4.8483e-15, 1.3489e-20, 1.2572e-20, 3.5169e-17, 4.1630e-18, 1.0599e-19,\n",
" 1.0524e-20, 1.7805e-16, 2.4486e-17, 1.1565e-21, 1.0517e-13, 4.8244e-04,\n",
" 1.2899e-01, 5.0841e-13, 3.9280e-20, 3.3293e-29, 1.9242e-13, 3.8402e-11,\n",
" 5.0398e-27, 2.0794e-23, 6.0030e-15],\n",
" [5.6912e-18, 8.5762e-14, 5.4786e-09, 7.6643e-28, 9.8223e-34, 1.0855e-20,\n",
" 1.6902e-23, 2.9398e-26, 2.8179e-18, 2.9614e-25, 2.1332e-17, 1.3366e-05,\n",
" 1.3824e-13, 3.1554e-20, 2.1864e-30, 4.9378e-17, 4.9420e-28, 1.7444e-06,\n",
" 9.9998e-01, 2.7049e-28, 1.3291e-20, 1.7832e-37, 7.2543e-33, 7.4020e-16,\n",
" 9.4811e-27, 1.4014e-23, 2.4846e-14],\n",
" [1.8994e-06, 7.0665e-01, 2.4528e-14, 3.9459e-27, 3.7168e-15, 2.4027e-01,\n",
" 3.2191e-07, 8.0452e-07, 1.0725e-11, 1.2262e-09, 6.1807e-07, 4.1937e-11,\n",
" 6.2864e-16, 4.8229e-16, 1.9374e-11, 2.9170e-14, 8.0970e-09, 2.6827e-07,\n",
" 4.7176e-02, 5.6506e-06, 2.1539e-09, 1.4395e-21, 5.4801e-06, 3.7784e-07,\n",
" 6.4447e-20, 5.1319e-10, 5.8910e-03],\n",
" [6.7133e-07, 3.9065e-12, 7.0431e-06, 2.8241e-21, 6.2352e-18, 2.1865e-14,\n",
" 1.1576e-11, 2.3479e-14, 1.5960e-10, 1.1027e-07, 9.4439e-01, 3.4291e-05,\n",
" 1.0041e-07, 7.8863e-13, 2.0277e-18, 1.1248e-11, 5.9427e-11, 2.3680e-04,\n",
" 3.2660e-04, 4.0719e-10, 1.8607e-08, 3.9505e-23, 1.6343e-18, 1.7168e-11,\n",
" 5.8671e-17, 5.8176e-08, 5.5003e-02],\n",
" [2.3603e-19, 1.8190e-28, 1.0644e-06, 1.5602e-22, 6.6367e-20, 2.4823e-30,\n",
" 2.2450e-22, 2.0518e-26, 1.3708e-25, 9.2010e-09, 8.1778e-17, 3.9444e-13,\n",
" 7.3042e-02, 1.0480e-09, 5.7714e-22, 1.2623e-20, 6.2400e-27, 1.8503e-17,\n",
" 2.8257e-32, 3.2646e-07, 3.4103e-15, 9.0763e-01, 1.1386e-22, 2.2861e-30,\n",
" 1.9326e-02, 9.4481e-10, 1.4029e-11],\n",
" [5.4918e-28, 2.4647e-27, 5.9095e-19, 1.0829e-18, 3.9169e-10, 1.1083e-32,\n",
" 2.8731e-15, 8.3012e-37, 3.1838e-33, 3.6336e-30, 9.5694e-25, 1.5107e-38,\n",
" 9.1791e-09, 2.4099e-19, 1.0682e-17, 7.8429e-15, 4.1298e-31, 5.1046e-24,\n",
" 1.1585e-37, 1.8818e-10, 1.2956e-25, 1.0000e+00, 1.5812e-28, 8.5479e-44,\n",
" 2.8373e-21, 4.4365e-27, 1.8675e-37],\n",
" [5.3424e-22, 3.1758e-10, 6.6093e-29, 3.0561e-17, 4.7211e-18, 4.2464e-26,\n",
" 9.9979e-01, 2.6293e-29, 4.4755e-24, 0.0000e+00, 2.2950e-27, 5.8107e-26,\n",
" 2.0452e-27, 1.6954e-37, 7.9085e-17, 3.9599e-19, 5.0982e-15, 4.5363e-19,\n",
" 2.0801e-04, 1.9471e-31, 9.5267e-33, 2.2751e-28, 7.4330e-31, 1.7626e-20,\n",
" 1.2259e-40, 7.5773e-35, 1.2144e-39],\n",
" [6.6060e-16, 7.0911e-14, 8.9545e-33, 3.0449e-27, 7.3567e-22, 4.4476e-16,\n",
" 1.6143e-17, 3.1832e-25, 9.2399e-23, 1.8690e-02, 7.4776e-16, 1.8604e-06,\n",
" 9.2996e-18, 1.0148e-11, 2.1973e-21, 9.4467e-38, 9.7591e-01, 5.3289e-03,\n",
" 2.5007e-08, 1.3354e-12, 9.9222e-21, 6.2210e-22, 3.9712e-16, 7.9575e-07,\n",
" 7.1641e-10, 2.4088e-12, 6.9206e-05],\n",
" [2.1457e-19, 5.3917e-23, 9.9863e-01, 3.5339e-26, 2.4931e-23, 7.7466e-27,\n",
" 2.8573e-20, 1.9329e-25, 9.6134e-21, 7.1319e-31, 1.7529e-12, 3.0144e-25,\n",
" 1.3543e-09, 7.0777e-24, 8.2387e-26, 1.3696e-03, 9.6154e-39, 2.6552e-20,\n",
" 3.8563e-21, 5.2864e-19, 8.4760e-17, 7.3578e-25, 4.1313e-33, 1.1553e-35,\n",
" 3.3413e-32, 7.9586e-24, 7.1770e-24],\n",
" [1.8994e-06, 7.0665e-01, 2.4528e-14, 3.9459e-27, 3.7168e-15, 2.4027e-01,\n",
" 3.2191e-07, 8.0452e-07, 1.0725e-11, 1.2262e-09, 6.1807e-07, 4.1937e-11,\n",
" 6.2864e-16, 4.8229e-16, 1.9374e-11, 2.9170e-14, 8.0970e-09, 2.6827e-07,\n",
" 4.7176e-02, 5.6506e-06, 2.1539e-09, 1.4395e-21, 5.4801e-06, 3.7784e-07,\n",
" 6.4447e-20, 5.1319e-10, 5.8910e-03],\n",
" [5.5573e-07, 1.6617e-13, 1.8262e-05, 6.0024e-22, 7.0303e-20, 7.5283e-17,\n",
" 4.0332e-11, 3.7349e-14, 3.2555e-10, 2.5044e-11, 9.9666e-01, 1.9005e-05,\n",
" 4.2143e-10, 7.9021e-17, 8.4015e-20, 1.5246e-11, 8.2403e-12, 7.8824e-07,\n",
" 9.0538e-04, 2.0810e-13, 2.0073e-09, 1.0948e-26, 5.9718e-21, 4.5434e-12,\n",
" 1.3866e-20, 4.2469e-09, 2.3968e-03],\n",
" [1.2502e-09, 1.4916e-10, 1.5064e-12, 4.8487e-31, 3.0801e-20, 5.8148e-16,\n",
" 1.5651e-01, 8.4349e-01, 3.9499e-17, 1.7393e-26, 1.9220e-18, 1.6434e-16,\n",
" 6.4413e-27, 3.6102e-36, 3.7108e-10, 4.1036e-19, 1.1058e-17, 9.3626e-34,\n",
" 8.1936e-19, 5.4736e-13, 1.5986e-14, 9.5979e-14, 2.6343e-06, 5.5294e-17,\n",
" 3.1486e-23, 3.1949e-10, 1.4226e-11],\n",
" [4.2523e-22, 5.6052e-45, 9.8091e-45, 1.2472e-43, 5.6901e-29, 9.3326e-43,\n",
" 1.5425e-23, 9.4540e-35, 4.2318e-29, 4.1244e-19, 9.9999e-01, 4.4577e-39,\n",
" 3.9169e-41, 3.9937e-43, 8.8514e-39, 3.7042e-40, 8.0529e-06, 8.4035e-30,\n",
" 1.6566e-31, 6.3691e-26, 1.2070e-27, 0.0000e+00, 4.2959e-40, 7.0247e-33,\n",
" 0.0000e+00, 1.5955e-22, 2.0003e-19],\n",
" [1.8994e-06, 7.0665e-01, 2.4528e-14, 3.9459e-27, 3.7168e-15, 2.4027e-01,\n",
" 3.2191e-07, 8.0452e-07, 1.0725e-11, 1.2262e-09, 6.1807e-07, 4.1937e-11,\n",
" 6.2864e-16, 4.8229e-16, 1.9374e-11, 2.9170e-14, 8.0970e-09, 2.6827e-07,\n",
" 4.7176e-02, 5.6506e-06, 2.1539e-09, 1.4395e-21, 5.4801e-06, 3.7784e-07,\n",
" 6.4447e-20, 5.1319e-10, 5.8910e-03],\n",
" [3.5188e-13, 2.4852e-10, 1.3029e-16, 4.9331e-27, 2.8343e-16, 8.0934e-08,\n",
" 2.9858e-19, 9.6433e-20, 8.6115e-19, 9.6047e-01, 2.5603e-09, 1.3697e-12,\n",
" 1.7087e-07, 7.3772e-05, 8.8451e-18, 1.1927e-19, 9.9358e-14, 3.0442e-02,\n",
" 4.6515e-12, 7.8968e-03, 1.1242e-11, 1.5575e-15, 3.4302e-12, 9.6888e-15,\n",
" 2.2749e-10, 7.7974e-11, 1.1181e-03],\n",
" [5.1837e-27, 4.5088e-21, 3.2621e-09, 2.6907e-23, 5.7301e-22, 4.3898e-24,\n",
" 7.5240e-29, 2.8476e-37, 2.0032e-29, 5.0041e-21, 7.1890e-25, 8.2807e-22,\n",
" 1.0000e+00, 2.5662e-07, 7.4503e-26, 6.3922e-14, 2.8118e-39, 2.3482e-09,\n",
" 3.0197e-24, 9.3495e-14, 7.2589e-22, 2.5147e-10, 1.2226e-30, 9.4292e-35,\n",
" 4.8043e-14, 7.4667e-26, 3.7715e-25],\n",
" [3.3859e-21, 1.5635e-15, 9.9985e-01, 1.9184e-07, 7.5952e-22, 8.3803e-34,\n",
" 6.8879e-08, 2.9004e-30, 3.4280e-21, 2.3262e-43, 1.0164e-28, 2.7298e-08,\n",
" 1.5216e-04, 3.8501e-23, 3.7163e-19, 6.5912e-08, 8.2446e-32, 5.8319e-15,\n",
" 1.6726e-08, 2.1963e-29, 1.6910e-25, 4.8436e-10, 5.3335e-36, 5.5999e-25,\n",
" 7.1183e-19, 6.5876e-27, 3.9180e-33],\n",
" [1.3869e-11, 1.3064e-23, 1.3883e-11, 3.0361e-07, 1.1145e-14, 1.1017e-36,\n",
" 2.2892e-01, 2.3501e-22, 3.0828e-17, 4.0604e-17, 1.7824e-12, 7.6531e-01,\n",
" 9.6515e-08, 1.3337e-19, 4.9742e-15, 4.9165e-23, 5.4202e-03, 1.5549e-14,\n",
" 1.6286e-14, 8.5261e-17, 4.4036e-18, 3.3526e-04, 7.5974e-25, 9.9077e-15,\n",
" 1.6673e-05, 4.9403e-07, 2.8487e-13],\n",
" [3.6625e-13, 4.2774e-23, 4.5643e-19, 3.2510e-20, 1.6000e-03, 3.2078e-25,\n",
" 6.6707e-01, 4.6174e-16, 8.9175e-22, 1.4777e-17, 1.1382e-07, 1.5487e-32,\n",
" 2.6719e-16, 3.7821e-25, 1.7327e-08, 1.9661e-15, 8.0458e-10, 5.1622e-28,\n",
" 4.0367e-33, 2.8927e-01, 2.1098e-14, 4.2066e-02, 5.9124e-13, 8.9249e-31,\n",
" 6.7915e-21, 1.4951e-10, 7.5850e-20],\n",
" [3.4458e-19, 1.3673e-31, 3.7091e-26, 2.2236e-15, 2.1822e-08, 7.1872e-33,\n",
" 1.0542e-16, 2.6291e-40, 3.0007e-24, 1.2446e-06, 9.9918e-01, 1.9935e-26,\n",
" 3.3476e-06, 3.1589e-09, 7.6126e-24, 3.8830e-21, 7.0703e-05, 7.3992e-04,\n",
" 1.8127e-23, 3.4805e-08, 2.1202e-19, 2.6834e-18, 2.1624e-33, 2.3875e-29,\n",
" 4.8275e-20, 6.3455e-17, 1.7028e-19],\n",
" [8.1457e-35, 1.2857e-33, 1.1304e-07, 3.5911e-19, 1.1025e-25, 2.6485e-43,\n",
" 7.8327e-29, 2.8026e-45, 1.9902e-37, 1.8492e-35, 3.6882e-37, 9.1030e-25,\n",
" 2.3057e-01, 2.7072e-16, 7.7188e-29, 5.5285e-18, 0.0000e+00, 2.4823e-23,\n",
" 4.3372e-38, 1.6958e-22, 3.0084e-30, 7.6943e-01, 7.6148e-41, 0.0000e+00,\n",
" 2.9827e-12, 1.3995e-30, 2.2507e-38],\n",
" [4.4522e-22, 1.6499e-23, 1.1505e-09, 6.5774e-03, 1.5525e-14, 2.6204e-43,\n",
" 9.9342e-01, 2.1890e-34, 1.0482e-23, 0.0000e+00, 3.5921e-25, 1.5227e-17,\n",
" 7.4367e-09, 3.5658e-29, 6.7953e-17, 3.0815e-11, 4.1484e-22, 7.5288e-21,\n",
" 2.5777e-16, 1.3564e-28, 1.0549e-28, 2.9484e-07, 5.0902e-39, 3.8570e-30,\n",
" 2.4113e-23, 3.3730e-27, 5.3107e-40],\n",
" [1.8994e-06, 7.0665e-01, 2.4528e-14, 3.9459e-27, 3.7168e-15, 2.4027e-01,\n",
" 3.2191e-07, 8.0452e-07, 1.0725e-11, 1.2262e-09, 6.1807e-07, 4.1937e-11,\n",
" 6.2864e-16, 4.8229e-16, 1.9374e-11, 2.9170e-14, 8.0970e-09, 2.6827e-07,\n",
" 4.7176e-02, 5.6506e-06, 2.1539e-09, 1.4395e-21, 5.4801e-06, 3.7784e-07,\n",
" 6.4447e-20, 5.1319e-10, 5.8910e-03],\n",
" [1.1852e-16, 2.9272e-18, 4.8271e-16, 4.9006e-27, 1.1208e-19, 6.4937e-16,\n",
" 4.4038e-25, 7.2464e-27, 1.0100e-21, 9.8265e-01, 1.3072e-09, 7.8116e-13,\n",
" 3.0387e-05, 5.3026e-04, 1.3162e-23, 8.9681e-22, 1.5955e-17, 1.6768e-02,\n",
" 1.0449e-16, 4.9513e-06, 6.7697e-14, 2.6620e-17, 4.7354e-20, 8.4945e-20,\n",
" 7.0451e-10, 2.1320e-12, 1.2574e-05],\n",
" [4.9052e-25, 9.4807e-23, 1.0000e+00, 4.0314e-21, 2.3924e-25, 1.6132e-32,\n",
" 1.3216e-19, 1.9167e-30, 3.1326e-26, 1.5693e-39, 3.6699e-27, 1.1513e-21,\n",
" 2.5881e-07, 3.5708e-24, 1.4833e-24, 1.4357e-07, 1.4013e-44, 2.7540e-23,\n",
" 1.2590e-23, 5.6606e-24, 1.9721e-23, 1.2849e-13, 3.6746e-35, 2.3322e-37,\n",
" 2.0883e-24, 5.1857e-28, 3.8505e-32],\n",
" [1.4766e-13, 5.4573e-26, 2.6174e-08, 8.8484e-05, 7.2608e-16, 2.7285e-40,\n",
" 5.3116e-04, 2.4367e-27, 9.4136e-18, 6.8406e-20, 1.3292e-12, 9.9894e-01,\n",
" 4.2736e-04, 5.0113e-18, 5.5990e-18, 5.2619e-20, 3.0610e-07, 3.1204e-12,\n",
" 5.6372e-14, 1.0915e-19, 3.0624e-19, 1.0511e-05, 4.6000e-31, 8.9617e-18,\n",
" 1.2070e-06, 6.0971e-10, 1.5755e-16],\n",
" [6.6248e-20, 1.1519e-42, 1.4801e-09, 4.5097e-14, 1.3133e-12, 0.0000e+00,\n",
" 3.3392e-11, 2.1073e-31, 1.6451e-26, 3.0954e-21, 3.6437e-10, 3.6631e-25,\n",
" 8.9187e-07, 7.8893e-24, 3.7848e-21, 9.4501e-18, 2.3735e-20, 6.5117e-30,\n",
" 1.9898e-42, 4.6795e-12, 4.4955e-19, 1.0000e+00, 8.9663e-34, 1.9880e-41,\n",
" 1.5975e-14, 6.6194e-12, 1.6366e-24],\n",
" [1.3479e-28, 7.0023e-38, 3.6616e-32, 1.7165e-20, 3.1452e-08, 1.9898e-43,\n",
" 1.1325e-10, 6.2044e-38, 4.5461e-37, 6.2867e-29, 7.8839e-23, 1.9618e-44,\n",
" 6.9054e-19, 7.3255e-29, 7.6490e-18, 4.1989e-27, 5.3466e-19, 2.9622e-34,\n",
" 0.0000e+00, 6.0975e-11, 1.9852e-29, 1.0000e+00, 5.9715e-29, 0.0000e+00,\n",
" 1.1249e-23, 5.7915e-24, 1.0314e-38],\n",
" [9.2334e-16, 1.1982e-23, 1.8237e-18, 2.7255e-06, 5.3642e-07, 3.3659e-35,\n",
" 9.9622e-01, 3.5788e-34, 1.0860e-17, 7.1806e-31, 3.7684e-03, 1.3917e-25,\n",
" 1.9584e-11, 1.1933e-24, 3.2485e-18, 9.0902e-09, 5.8568e-06, 8.2379e-10,\n",
" 1.5351e-10, 2.3523e-20, 3.7594e-22, 1.1572e-23, 5.5352e-38, 1.7066e-26,\n",
" 1.3043e-34, 8.2526e-24, 2.0988e-31]])"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prob"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "5d0d5213",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([32])"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Y.shape"
]
},
{
"cell_type": "code",
"execution_count": 54,
"id": "c00f76b2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 5, 13, 13, 1, 0, 15, 12, 9, 22, 9, 1, 0, 1, 22, 1, 0, 9, 19,\n",
" 1, 2, 5, 12, 12, 1, 0, 19, 15, 16, 8, 9, 1, 0])"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Y holds the target character index (one of 27 classes) for each of the 32 training examples\n",
"Y"
]
},
{
"cell_type": "code",
"execution_count": 55,
"id": "8e9205e0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
" 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.arange(Y.shape[0])"
]
},
{
"cell_type": "code",
"execution_count": 56,
"id": "f5b57e82",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([2.4027e-01, 2.4594e-09, 9.6322e-16, 8.2677e-01, 5.6912e-18, 2.9170e-14,\n",
" 1.0041e-07, 9.2010e-09, 1.5812e-28, 0.0000e+00, 7.0911e-14, 2.1457e-19,\n",
" 7.0665e-01, 5.9718e-21, 1.4916e-10, 4.2523e-22, 1.2262e-09, 7.8968e-03,\n",
" 4.5088e-21, 9.9985e-01, 1.1017e-36, 2.6719e-16, 3.3476e-06, 1.2857e-33,\n",
" 4.4522e-22, 5.6506e-06, 8.9681e-22, 1.4013e-44, 9.4136e-18, 3.0954e-21,\n",
" 7.0023e-38, 9.2334e-16])"
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# for each record, get the prob of the correct char\n",
"# some of these probabilities are extremely small (one is exactly 0), so their logs are hugely negative\n",
"prob[torch.arange(Y.shape[0]), Y]"
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "227429fe",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(inf)"
]
},
"execution_count": 57,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# negative average log likelihood\n",
"# log of each prob, average, and then negative\n",
"loss = -prob[torch.arange(Y.shape[0]), Y].log().mean()\n",
"loss"
]
},
{
"cell_type": "markdown",
"id": "628a232a",
"metadata": {},
"source": [
"# 4. train the neural net with backpropagation"
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "14c3b9a8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([32, 3]), torch.Size([32]))"
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X.shape, Y.shape"
]
},
{
"cell_type": "code",
"execution_count": 59,
"id": "a4e2938d",
"metadata": {},
"outputs": [],
"source": [
"# network dimensions\n",
"Vdim = 27  # vocabulary size (number of possible characters)\n",
"Cdim = 2  # hyperparameter: embedding size per character\n",
"block_size = 3  # hyperparameter: number of context characters\n",
"n_input1 = Cdim * block_size  # width of the flattened context embedding\n",
"n_output1 = 100  # hyperparameter: number of hidden units\n",
"n_input2 = n_output1\n",
"n_output2 = Vdim  # one logit per character\n",
"\n",
"# draw every parameter from a single seeded generator so the values are reproducible\n",
"g = torch.Generator().manual_seed(2147483647)\n",
"C = torch.randn(Vdim, Cdim, generator=g)  # embedding table\n",
"W1 = torch.randn(n_input1, n_output1, generator=g)  # hidden layer weights\n",
"b1 = torch.randn(n_output1, generator=g)  # hidden layer bias\n",
"W2 = torch.randn(n_input2, n_output2, generator=g)  # output layer weights\n",
"b2 = torch.randn(n_output2, generator=g)  # output layer bias\n",
"parameters = [C, W1, b1, W2, b2]"
]
},
{
"cell_type": "code",
"execution_count": 60,
"id": "e265acb2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3481"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sum(p.nelement() for p in parameters) # total number of parameters"
]
},
{
"cell_type": "code",
"execution_count": 61,
"id": "b228f7d0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(17.7697)"
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# forward pass: embed the contexts, flatten each one, apply the tanh hidden layer\n",
"emb = C[X]\n",
"h = torch.tanh(emb.view(-1, n_input1) @ W1 + b1)\n",
"# softmax classification written out by hand\n",
"logits = h @ W2 + b2\n",
"counts = torch.exp(logits)\n",
"prob = counts / counts.sum(dim=1, keepdim=True)\n",
"# negative mean log-probability assigned to the correct next character\n",
"loss = -torch.log(prob[torch.arange(len(Y)), Y]).mean()\n",
"loss"
]
},
{
"cell_type": "markdown",
"id": "b8500dd2",
"metadata": {},
"source": [
"## softmax loss from scratch vs. F.cross_entropy\n",
"F.cross_entropy is more efficient and more robust than the hand-written softmax:\n",
"- 1. it is numerically well-behaved: with extreme logits, exp() can produce huge numbers that overflow to inf (and then nan), which F.cross_entropy avoids\n",
"- 2. the backward pass is more efficient\n",
"- 3. the forward pass is more efficient: it does not create the intermediate tensors (counts, prob) that the manual version builds from the logits"
]
},
{
"cell_type": "code",
"execution_count": 62,
"id": "c832a78b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(17.7697)"
]
},
"execution_count": 62,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"F.cross_entropy(logits, Y)"
]
},
{
"cell_type": "code",
"execution_count": 63,
"id": "3e0ba1c2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0., 0., 0., nan])"
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# numerical out of range in float\n",
"logits = torch.tensor([-100,-3,0,100])\n",
"counts = logits.exp()\n",
"prob = counts / counts.sum()\n",
"prob"
]
},
{
"cell_type": "code",
"execution_count": 64,
"id": "b84b0ee2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([3.7835e-44, 4.9787e-02, 1.0000e+00, inf])"
]
},
"execution_count": 64,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# e^100 is too BIG, out of range in float\n",
"counts"
]
},
{
"cell_type": "code",
"execution_count": 65,
"id": "2274f7c9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.0000e+00, 1.4013e-45, 3.7835e-44, 1.0000e+00])"
]
},
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# softmax is unchanged when the same constant is subtracted from every logit,\n",
"# so shift by the largest logit: the biggest exponent becomes e^0 = 1 and cannot overflow\n",
"# pytorch internally SUBTRACTS the maximal value in the tensor\n",
"logits = torch.tensor([-100,-3,0,100])\n",
"logits = logits - logits.max()  # computed from the data instead of a hard-coded 100\n",
"counts = logits.exp()\n",
"prob = counts / counts.sum()\n",
"prob"
]
},
{
"cell_type": "markdown",
"id": "42671fe8",
"metadata": {},
"source": [
"## train with backward()"
]
},
{
"cell_type": "code",
"execution_count": 66,
"id": "04f208d9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 17.76971435546875\n"
]
},
{
"ename": "RuntimeError",
"evalue": "element 0 of tensors does not require grad and does not have a grad_fn",
"output_type": "error",
"traceback": [
"\u001b[0;31m-------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0mTraceback (most recent call last)",
"\u001b[0;32m<ipython-input-66-b30885fcc503>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 19\u001b[0m \u001b[0;31m# update the parameters\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mp\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mparameters\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/.pyenv/versions/offline-simulation-3.6.9/lib/python3.6/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 305\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 306\u001b[0m inputs=inputs)\n\u001b[0;32m--> 307\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 308\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 309\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/.pyenv/versions/offline-simulation-3.6.9/lib/python3.6/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 154\u001b[0m Variable._execution_engine.run_backward(\n\u001b[1;32m 155\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 156\u001b[0;31m allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag\n\u001b[0m\u001b[1;32m 157\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 158\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mRuntimeError\u001b[0m: element 0 of tensors does not require grad and does not have a grad_fn"
]
}
],
"source": [
"# train the neural net (first attempt: backward() fails here because the parameters do not require grad yet)\n",
"\n",
"learning_rate = 0.1\n",
"\n",
"for step in range(10):\n",
"    # forward pass\n",
"    emb = C[X]\n",
"    h = torch.tanh(emb.view(-1,n_input1) @ W1 + b1)\n",
"    logits = h @ W2 + b2\n",
"    loss = F.cross_entropy(logits, Y)\n",
"    print(step, loss.item())\n",
"    # backward pass\n",
"\n",
"    # reset all gradients\n",
"    for p in parameters:\n",
"        p.grad = None\n",
"\n",
"    loss.backward()\n",
"    # update the parameters\n",
"    for p in parameters:\n",
"        p.data += -learning_rate * p.grad"
]
},
{
"cell_type": "code",
"execution_count": 67,
"id": "8331f712",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 17.76971435546875\n",
"1 9.42310619354248\n",
"2 8.315447807312012\n",
"3 6.388716220855713\n",
"4 3.51501202583313\n",
"5 4.386330604553223\n",
"6 2.32904052734375\n",
"7 1.8342325687408447\n",
"8 2.4214775562286377\n",
"9 2.067505359649658\n"
]
}
],
"source": [
"# autograd only tracks tensors that require grad, so enable it on every parameter\n",
"# before calling backward(); otherwise backward() raises a RuntimeError\n",
"for p in parameters:\n",
"    p.requires_grad = True\n",
"\n",
"learning_rate = 0.5\n",
"\n",
"for step in range(10):\n",
"    # forward pass\n",
"    emb = C[X]\n",
"    h = torch.tanh(emb.view(-1, n_input1) @ W1 + b1)\n",
"    logits = h @ W2 + b2\n",
"    loss = F.cross_entropy(logits, Y)\n",
"    print(step, loss.item())\n",
"\n",
"    # backward pass: clear old gradients first, because backward() accumulates into .grad\n",
"    for p in parameters:\n",
"        p.grad = None\n",
"    loss.backward()\n",
"\n",
"    # gradient-descent step; done under no_grad so the in-place update is not tracked by autograd\n",
"    with torch.no_grad():\n",
"        for p in parameters:\n",
"            p += -learning_rate * p.grad"
]
},
{
"cell_type": "code",
"execution_count": 68,
"id": "d63e0161",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.return_types.max(\n",
"values=tensor([17.8802, 12.6695, 15.6311, 17.3821, 18.0480, 17.8802, 14.4196, 11.9789,\n",
" 14.2498, 13.5953, 12.4208, 15.7624, 17.8802, 17.1180, 14.2179, 15.5572,\n",
" 17.8802, 13.5108, 11.3627, 16.2610, 16.8473, 12.2349, 7.4047, 6.9497,\n",
" 13.6933, 17.8802, 16.9499, 13.7836, 9.7913, 14.5982, 18.5608, 12.1148],\n",
" grad_fn=<MaxBackward0>),\n",
"indices=tensor([15, 12, 13, 1, 15, 15, 12, 12, 22, 9, 1, 0, 15, 15, 1, 15, 15, 15,\n",
" 15, 2, 5, 12, 12, 9, 0, 15, 15, 12, 8, 9, 1, 1]))"
]
},
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# we overfit the model with 32 data points and 3k parameters!\n",
"# torch reports the max value and the index\n",
"logits.max(1)"
]
},
{
"cell_type": "code",
"execution_count": 69,
"id": "357e8783",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 5, 13, 13, 1, 0, 15, 12, 9, 22, 9, 1, 0, 1, 22, 1, 0, 9, 19,\n",
" 1, 2, 5, 12, 12, 1, 0, 19, 15, 16, 8, 9, 1, 0])"
]
},
"execution_count": 69,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Y\n",
"# many of the predicted indices above match the targets\n",
"# the loss cannot reach 0, because the same context (e.g. \"...\") is followed by different characters in different words"
]
},
{
"cell_type": "markdown",
"id": "af33602f",
"metadata": {},
"source": [
"## train with all words"
]
},
{
"cell_type": "code",
"execution_count": 70,
"id": "01930ea1",
"metadata": {},
"outputs": [],
"source": [
"# hyperparameters, each defined exactly once so the dataset and the network cannot disagree\n",
"block_size = 3  # number of context characters used to predict the next one\n",
"Vdim = 27  # vocabulary size (number of possible characters)\n",
"Cdim = 2  # embedding size per character\n",
"n_input1 = Cdim * block_size\n",
"n_output1 = 100  # number of hidden units\n",
"n_input2 = n_output1\n",
"n_output2 = Vdim\n",
"\n",
"# build the dataset from all words: one (context, next character) pair per position\n",
"contexts, targets = [], []\n",
"for w in words:\n",
"    context = [0] * block_size  # every word starts from a context filled with index 0\n",
"    for ch in w + '.':\n",
"        ix = stoi[ch]\n",
"        contexts.append(context)\n",
"        targets.append(ix)\n",
"        context = context[1:] + [ix]  # slide the window one character forward\n",
"\n",
"X = torch.tensor(contexts)\n",
"Y = torch.tensor(targets)\n",
"\n",
"# re-initialize the parameters from a fixed seed\n",
"g = torch.Generator().manual_seed(2147483647)\n",
"C = torch.randn((Vdim, Cdim), generator=g)\n",
"W1 = torch.randn((n_input1, n_output1), generator=g)\n",
"b1 = torch.randn(n_output1, generator=g)\n",
"W2 = torch.randn((n_input2, n_output2), generator=g)\n",
"b2 = torch.randn(n_output2, generator=g)\n",
"parameters = [C, W1, b1, W2, b2]\n",
"\n",
"# enable gradient tracking so loss.backward() can populate .grad\n",
"for p in parameters:\n",
"    p.requires_grad = True"
]
},
{
"cell_type": "code",
"execution_count": 71,
"id": "41732590",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([228146, 3]), torch.Size([228146]))"
]
},
"execution_count": 71,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X.shape, Y.shape"
]
},
{
"cell_type": "code",
"execution_count": 72,
"id": "e6e0fcd3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 19.505226135253906\n",
"1 14.167920112609863\n",
"2 11.403884887695312\n",
"3 10.201157569885254\n",
"4 8.886395454406738\n",
"5 7.84780216217041\n",
"6 7.272636890411377\n",
"7 6.522163391113281\n",
"8 6.065618515014648\n",
"9 5.838536739349365\n"
]
}
],
"source": [
"learning_rate = 0.5\n",
"\n",
"# Full-batch gradient descent: every step uses all training examples.\n",
"# The step counter has a real name because it is printed; binding `_` would\n",
"# also shadow IPython's last-output variable.\n",
"for step in range(10):\n",
"    # forward pass\n",
"    emb = C[X]  # embedding lookup for every context character\n",
"    h = torch.tanh(emb.view(-1, n_input1) @ W1 + b1)\n",
"    logits = h @ W2 + b2\n",
"    loss = F.cross_entropy(logits, Y)\n",
"    print(step, loss.item())\n",
"\n",
"    # backward pass: clear the previous gradients, then backpropagate\n",
"    for p in parameters:\n",
"        p.grad = None\n",
"    loss.backward()\n",
"\n",
"    # gradient-descent update, performed outside autograd tracking\n",
"    with torch.no_grad():\n",
"        for p in parameters:\n",
"            p -= learning_rate * p.grad"
]
},
{
"cell_type": "markdown",
"id": "098cda74",
"metadata": {},
"source": [
"# 5. mini-batch\n",
"- using the full dataset for every forward and backward pass makes each iteration slow, because every step goes through all the training examples\n",
"- mini-batch training instead samples a small random subset of the examples in each iteration for the forward and backward pass\n",
"- pros: each iteration is much faster\n",
"- cons: the gradient is noisier, so the loss is less stable and may take more iterations to converge"
]
},
{
"cell_type": "code",
"execution_count": 73,
"id": "be243227",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 48049, 83229, 192839, 189260, 65187, 101329, 7717, 41259, 153683,\n",
" 122018, 5189, 147134, 29649, 101782, 192637, 109671, 132041, 63528,\n",
" 181138, 45778, 48738, 137578, 62475, 136876, 154447, 4831, 186918,\n",
" 131601, 148080, 94904, 126630, 193761])"
]
},
"execution_count": 73,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# mini-batch size: number of training examples sampled per step\n",
"batch_size = 32\n",
"# draw batch_size row indices into X, uniformly at random from [0, number of rows in X)\n",
"torch.randint(low=0, high=len(X), size=(batch_size,))"
]
},
{
"cell_type": "code",
"execution_count": 74,
"id": "651c0f50",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([32])"
]
},
"execution_count": 74,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# one index per example in the mini-batch\n",
"torch.randint(low=0, high=len(X), size=(batch_size,)).size()"
]
},
{
"cell_type": "code",
"execution_count": 75,
"id": "a74a6e27",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([32, 3])"
]
},
"execution_count": 75,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# indexing X with a batch of random row indices selects one context row per sampled example;\n",
"# batch_size is used instead of a hard-coded 32 so this check follows the configured batch size\n",
"X[torch.randint(0, X.shape[0], (batch_size,))].shape"
]
},
{
"cell_type": "code",
"execution_count": 76,
"id": "f552adb7",
"metadata": {},
"outputs": [],
"source": [
"# Rebuild the training set and reset the network before mini-batch training.\n",
"#\n",
"# X[i] holds the indices of the block_size characters that precede a position\n",
"# (left-padded with index 0); Y[i] holds the index of the character to predict.\n",
"\n",
"# hyperparameters -- each one is defined exactly once, so the dataset width\n",
"# and the layer sizes cannot drift apart\n",
"block_size = 3   # context length in characters\n",
"Vdim = 27        # vocabulary size (presumably len(stoi): 26 letters + '.'; confirm)\n",
"Cdim = 2         # embedding dimension per character\n",
"n_output1 = 100  # hidden layer width\n",
"\n",
"# derived layer sizes\n",
"n_input1 = Cdim * block_size  # the embeddings of one context are flattened into one row\n",
"n_input2 = n_output1\n",
"n_output2 = Vdim              # one logit per vocabulary entry\n",
"\n",
"contexts, targets = [], []\n",
"for w in words:\n",
"    context = [0] * block_size  # start-of-word padding\n",
"    for ch in w + '.':  # the trailing '.' makes the end of the word a target too\n",
"        ix = stoi[ch]\n",
"        contexts.append(context)\n",
"        targets.append(ix)\n",
"        # slide the window; this builds a NEW list, so rows already stored are not mutated\n",
"        context = context[1:] + [ix]\n",
"\n",
"X = torch.tensor(contexts)\n",
"Y = torch.tensor(targets)\n",
"\n",
"# fixed seed -> identical initial parameters every time this cell is run\n",
"g = torch.Generator().manual_seed(2147483647)\n",
"C = torch.randn((Vdim, Cdim), generator=g)  # embedding table\n",
"W1 = torch.randn((n_input1, n_output1), generator=g)\n",
"b1 = torch.randn(n_output1, generator=g)\n",
"W2 = torch.randn((n_input2, n_output2), generator=g)\n",
"b2 = torch.randn(n_output2, generator=g)\n",
"parameters = [C, W1, b1, W2, b2]\n",
"\n",
"# track gradients for every trainable tensor\n",
"for p in parameters:\n",
"    p.requires_grad = True"
]
},
{
"cell_type": "code",
"execution_count": 77,
"id": "22dbc71c",
"metadata": {},
"outputs": [],
"source": [
"# author's note: a good learning rate is around 1; 10 is too big; 0.01 is too slow in the beginning but good towards the later stage\n",
"learning_rate = 0.1\n",
"\n",
"# Mini-batch gradient descent:\n",
"# - each step is much faster than a full-batch step\n",
"# - the gradient is less reliable than one computed on the whole dataset,\n",
"#   but its direction is good enough\n",
"# - many small steps on little data beat few steps on all the data\n",
"for _ in range(10000):\n",
"    # mini-batch indices, drawn from the seeded generator g: re-running the\n",
"    # initialisation cell followed by this loop then samples the same batches\n",
"    ix = torch.randint(0, X.shape[0], (batch_size,), generator=g)\n",
"\n",
"    # forward pass on the sampled rows only\n",
"    emb = C[X[ix]]  # (batch_size, block_size, Cdim)\n",
"    h = torch.tanh(emb.view(-1, n_input1) @ W1 + b1)\n",
"    logits = h @ W2 + b2\n",
"    loss = F.cross_entropy(logits, Y[ix])\n",
"\n",
"    # backward pass: clear the previous gradients, then backpropagate\n",
"    for p in parameters:\n",
"        p.grad = None\n",
"    loss.backward()\n",
"\n",
"    # gradient-descent update, performed outside autograd tracking\n",
"    with torch.no_grad():\n",
"        for p in parameters:\n",
"            p -= learning_rate * p.grad"
]
},
{
"cell_type": "code",
"execution_count": 78,
"id": "c6536cdb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.7195205688476562\n"
]
}
],
"source": [
"# loss of the most recent mini-batch only (computed on Y[ix]), so this value is noisy\n",
"print(loss.item())"
]
},
{
"cell_type": "markdown",
"id": "eee2bbeb",
"metadata": {},
"source": [
"# 6. find a reasonable learning rate"
]
},
{
"cell_type": "code",
"execution_count": 79,
"id": "9df1bea7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0011,\n",
" 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011,\n",
" 0.0011, 0.0011, 0.0011, 0.0012, 0.0012, 0.0012, 0.0012, 0.0012, 0.0012,\n",
" 0.0012, 0.0012, 0.0012, 0.0012, 0.0012, 0.0012, 0.0013, 0.0013, 0.0013,\n",
" 0.0013, 0.0013, 0.0013, 0.0013, 0.0013, 0.0013, 0.0013, 0.0013, 0.0014,\n",
" 0.0014, 0.0014, 0.0014, 0.0014, 0.0014, 0.0014, 0.0014, 0.0014, 0.0014,\n",
" 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,\n",
" 0.0015, 0.0016, 0.0016, 0.0016, 0.0016, 0.0016, 0.0016, 0.0016, 0.0016,\n",
" 0.0016, 0.0017, 0.0017, 0.0017, 0.0017, 0.0017, 0.0017, 0.0017, 0.0017,\n",
" 0.0018, 0.0018, 0.0018, 0.0018, 0.0018, 0.0018, 0.0018, 0.0018, 0.0019,\n",
" 0.0019, 0.0019, 0.0019, 0.0019, 0.0019, 0.0019, 0.0019, 0.0020, 0.0020,\n",
" 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0021, 0.0021, 0.0021, 0.0021,\n",
" 0.0021, 0.0021, 0.0021, 0.0022, 0.0022, 0.0022, 0.0022, 0.0022, 0.0022,\n",
" 0.0022, 0.0023, 0.0023, 0.0023, 0.0023, 0.0023, 0.0023, 0.0024, 0.0024,\n",
" 0.0024, 0.0024, 0.0024, 0.0024, 0.0025, 0.0025, 0.0025, 0.0025, 0.0025,\n",
" 0.0025, 0.0026, 0.0026, 0.0026, 0.0026, 0.0026, 0.0027, 0.0027, 0.0027,\n",
" 0.0027, 0.0027, 0.0027, 0.0028, 0.0028, 0.0028, 0.0028, 0.0028, 0.0029,\n",
" 0.0029, 0.0029, 0.0029, 0.0029, 0.0030, 0.0030, 0.0030, 0.0030, 0.0030,\n",
" 0.0031, 0.0031, 0.0031, 0.0031, 0.0032, 0.0032, 0.0032, 0.0032, 0.0032,\n",
" 0.0033, 0.0033, 0.0033, 0.0033, 0.0034, 0.0034, 0.0034, 0.0034, 0.0034,\n",
" 0.0035, 0.0035, 0.0035, 0.0035, 0.0036, 0.0036, 0.0036, 0.0036, 0.0037,\n",
" 0.0037, 0.0037, 0.0037, 0.0038, 0.0038, 0.0038, 0.0039, 0.0039, 0.0039,\n",
" 0.0039, 0.0040, 0.0040, 0.0040, 0.0040, 0.0041, 0.0041, 0.0041, 0.0042,\n",
" 0.0042, 0.0042, 0.0042, 0.0043, 0.0043, 0.0043, 0.0044, 0.0044, 0.0044,\n",
" 0.0045, 0.0045, 0.0045, 0.0045, 0.0046, 0.0046, 0.0046, 0.0047, 0.0047,\n",
" 0.0047, 0.0048, 0.0048, 0.0048, 0.0049, 0.0049, 0.0049, 0.0050, 0.0050,\n",
" 0.0050, 0.0051, 0.0051, 0.0051, 0.0052, 0.0052, 0.0053, 0.0053, 0.0053,\n",
" 0.0054, 0.0054, 0.0054, 0.0055, 0.0055, 0.0056, 0.0056, 0.0056, 0.0057,\n",
" 0.0057, 0.0058, 0.0058, 0.0058, 0.0059, 0.0059, 0.0060, 0.0060, 0.0060,\n",
" 0.0061, 0.0061, 0.0062, 0.0062, 0.0062, 0.0063, 0.0063, 0.0064, 0.0064,\n",
" 0.0065, 0.0065, 0.0066, 0.0066, 0.0067, 0.0067, 0.0067, 0.0068, 0.0068,\n",
" 0.0069, 0.0069, 0.0070, 0.0070, 0.0071, 0.0071, 0.0072, 0.0072, 0.0073,\n",
" 0.0073, 0.0074, 0.0074, 0.0075, 0.0075, 0.0076, 0.0076, 0.0077, 0.0077,\n",
" 0.0078, 0.0079, 0.0079, 0.0080, 0.0080, 0.0081, 0.0081, 0.0082, 0.0082,\n",
" 0.0083, 0.0084, 0.0084, 0.0085, 0.0085, 0.0086, 0.0086, 0.0087, 0.0088,\n",
" 0.0088, 0.0089, 0.0090, 0.0090, 0.0091, 0.0091, 0.0092, 0.0093, 0.0093,\n",
" 0.0094, 0.0095, 0.0095, 0.0096, 0.0097, 0.0097, 0.0098, 0.0099, 0.0099,\n",
" 0.0100, 0.0101, 0.0101, 0.0102, 0.0103, 0.0104, 0.0104, 0.0105, 0.0106,\n",
" 0.0106, 0.0107, 0.0108, 0.0109, 0.0109, 0.0110, 0.0111, 0.0112, 0.0112,\n",
" 0.0113, 0.0114, 0.0115, 0.0116, 0.0116, 0.0117, 0.0118, 0.0119, 0.0120,\n",
" 0.0121, 0.0121, 0.0122, 0.0123, 0.0124, 0.0125, 0.0126, 0.0127, 0.0127,\n",
" 0.0128, 0.0129, 0.0130, 0.0131, 0.0132, 0.0133, 0.0134, 0.0135, 0.0136,\n",
" 0.0137, 0.0137, 0.0138, 0.0139, 0.0140, 0.0141, 0.0142, 0.0143, 0.0144,\n",
" 0.0145, 0.0146, 0.0147, 0.0148, 0.0149, 0.0150, 0.0151, 0.0152, 0.0154,\n",
" 0.0155, 0.0156, 0.0157, 0.0158, 0.0159, 0.0160, 0.0161, 0.0162, 0.0163,\n",
" 0.0165, 0.0166, 0.0167, 0.0168, 0.0169, 0.0170, 0.0171, 0.0173, 0.0174,\n",
" 0.0175, 0.0176, 0.0178, 0.0179, 0.0180, 0.0181, 0.0182, 0.0184, 0.0185,\n",
" 0.0186, 0.0188, 0.0189, 0.0190, 0.0192, 0.0193, 0.0194, 0.0196, 0.0197,\n",
" 0.0198, 0.0200, 0.0201, 0.0202, 0.0204, 0.0205, 0.0207, 0.0208, 0.0210,\n",
" 0.0211, 0.0212, 0.0214, 0.0215, 0.0217, 0.0218, 0.0220, 0.0221, 0.0223,\n",
" 0.0225, 0.0226, 0.0228, 0.0229, 0.0231, 0.0232, 0.0234, 0.0236, 0.0237,\n",
" 0.0239, 0.0241, 0.0242, 0.0244, 0.0246, 0.0247, 0.0249, 0.0251, 0.0253,\n",
" 0.0254, 0.0256, 0.0258, 0.0260, 0.0261, 0.0263, 0.0265, 0.0267, 0.0269,\n",
" 0.0271, 0.0273, 0.0274, 0.0276, 0.0278, 0.0280, 0.0282, 0.0284, 0.0286,\n",
" 0.0288, 0.0290, 0.0292, 0.0294, 0.0296, 0.0298, 0.0300, 0.0302, 0.0304,\n",
" 0.0307, 0.0309, 0.0311, 0.0313, 0.0315, 0.0317, 0.0320, 0.0322, 0.0324,\n",
" 0.0326, 0.0328, 0.0331, 0.0333, 0.0335, 0.0338, 0.0340, 0.0342, 0.0345,\n",
" 0.0347, 0.0350, 0.0352, 0.0354, 0.0357, 0.0359, 0.0362, 0.0364, 0.0367,\n",
" 0.0369, 0.0372, 0.0375, 0.0377, 0.0380, 0.0382, 0.0385, 0.0388, 0.0390,\n",
" 0.0393, 0.0396, 0.0399, 0.0401, 0.0404, 0.0407, 0.0410, 0.0413, 0.0416,\n",
" 0.0418, 0.0421, 0.0424, 0.0427, 0.0430, 0.0433, 0.0436, 0.0439, 0.0442,\n",
" 0.0445, 0.0448, 0.0451, 0.0455, 0.0458, 0.0461, 0.0464, 0.0467, 0.0471,\n",
" 0.0474, 0.0477, 0.0480, 0.0484, 0.0487, 0.0491, 0.0494, 0.0497, 0.0501,\n",
" 0.0504, 0.0508, 0.0511, 0.0515, 0.0518, 0.0522, 0.0526, 0.0529, 0.0533,\n",
" 0.0537, 0.0540, 0.0544, 0.0548, 0.0552, 0.0556, 0.0559, 0.0563, 0.0567,\n",
" 0.0571, 0.0575, 0.0579, 0.0583, 0.0587, 0.0591, 0.0595, 0.0599, 0.0604,\n",
" 0.0608, 0.0612, 0.0616, 0.0621, 0.0625, 0.0629, 0.0634, 0.0638, 0.0642,\n",
" 0.0647, 0.0651, 0.0656, 0.0660, 0.0665, 0.0670, 0.0674, 0.0679, 0.0684,\n",
" 0.0688, 0.0693, 0.0698, 0.0703, 0.0708, 0.0713, 0.0718, 0.0723, 0.0728,\n",
" 0.0733, 0.0738, 0.0743, 0.0748, 0.0753, 0.0758, 0.0764, 0.0769, 0.0774,\n",
" 0.0780, 0.0785, 0.0790, 0.0796, 0.0802, 0.0807, 0.0813, 0.0818, 0.0824,\n",
" 0.0830, 0.0835, 0.0841, 0.0847, 0.0853, 0.0859, 0.0865, 0.0871, 0.0877,\n",
" 0.0883, 0.0889, 0.0895, 0.0901, 0.0908, 0.0914, 0.0920, 0.0927, 0.0933,\n",
" 0.0940, 0.0946, 0.0953, 0.0959, 0.0966, 0.0973, 0.0979, 0.0986, 0.0993,\n",
" 0.1000, 0.1007, 0.1014, 0.1021, 0.1028, 0.1035, 0.1042, 0.1050, 0.1057,\n",
" 0.1064, 0.1072, 0.1079, 0.1087, 0.1094, 0.1102, 0.1109, 0.1117, 0.1125,\n",
" 0.1133, 0.1140, 0.1148, 0.1156, 0.1164, 0.1172, 0.1181, 0.1189, 0.1197,\n",
" 0.1205, 0.1214, 0.1222, 0.1231, 0.1239, 0.1248, 0.1256, 0.1265, 0.1274,\n",
" 0.1283, 0.1292, 0.1301, 0.1310, 0.1319, 0.1328, 0.1337, 0.1346, 0.1356,\n",
" 0.1365, 0.1374, 0.1384, 0.1394, 0.1403, 0.1413, 0.1423, 0.1433, 0.1443,\n",
" 0.1453, 0.1463, 0.1473, 0.1483, 0.1493, 0.1504, 0.1514, 0.1525, 0.1535,\n",
" 0.1546, 0.1557, 0.1567, 0.1578, 0.1589, 0.1600, 0.1611, 0.1623, 0.1634,\n",
" 0.1645, 0.1657, 0.1668, 0.1680, 0.1691, 0.1703, 0.1715, 0.1727, 0.1739,\n",
" 0.1751, 0.1763, 0.1775, 0.1788, 0.1800, 0.1812, 0.1825, 0.1838, 0.1850,\n",
" 0.1863, 0.1876, 0.1889, 0.1902, 0.1916, 0.1929, 0.1942, 0.1956, 0.1969,\n",
" 0.1983, 0.1997, 0.2010, 0.2024, 0.2038, 0.2053, 0.2067, 0.2081, 0.2096,\n",
" 0.2110, 0.2125, 0.2140, 0.2154, 0.2169, 0.2184, 0.2200, 0.2215, 0.2230,\n",
" 0.2246, 0.2261, 0.2277, 0.2293, 0.2309, 0.2325, 0.2341, 0.2357, 0.2373,\n",
" 0.2390, 0.2406, 0.2423, 0.2440, 0.2457, 0.2474, 0.2491, 0.2508, 0.2526,\n",
" 0.2543, 0.2561, 0.2579, 0.2597, 0.2615, 0.2633, 0.2651, 0.2669, 0.2688,\n",
" 0.2707, 0.2725, 0.2744, 0.2763, 0.2783, 0.2802, 0.2821, 0.2841, 0.2861,\n",
" 0.2880, 0.2900, 0.2921, 0.2941, 0.2961, 0.2982, 0.3002, 0.3023, 0.3044,\n",
" 0.3065, 0.3087, 0.3108, 0.3130, 0.3151, 0.3173, 0.3195, 0.3217, 0.3240,\n",
" 0.3262, 0.3285, 0.3308, 0.3331, 0.3354, 0.3377, 0.3400, 0.3424, 0.3448,\n",
" 0.3472, 0.3496, 0.3520, 0.3544, 0.3569, 0.3594, 0.3619, 0.3644, 0.3669,\n",
" 0.3695, 0.3720, 0.3746, 0.3772, 0.3798, 0.3825, 0.3851, 0.3878, 0.3905,\n",
" 0.3932, 0.3959, 0.3987, 0.4014, 0.4042, 0.4070, 0.4098, 0.4127, 0.4155,\n",
" 0.4184, 0.4213, 0.4243, 0.4272, 0.4302, 0.4331, 0.4362, 0.4392, 0.4422,\n",
" 0.4453, 0.4484, 0.4515, 0.4546, 0.4578, 0.4610, 0.4642, 0.4674, 0.4706,\n",
" 0.4739, 0.4772, 0.4805, 0.4838, 0.4872, 0.4906, 0.4940, 0.4974, 0.5008,\n",
" 0.5043, 0.5078, 0.5113, 0.5149, 0.5185, 0.5221, 0.5257, 0.5293, 0.5330,\n",
" 0.5367, 0.5404, 0.5442, 0.5479, 0.5517, 0.5556, 0.5594, 0.5633, 0.5672,\n",
" 0.5712, 0.5751, 0.5791, 0.5831, 0.5872, 0.5913, 0.5954, 0.5995, 0.6036,\n",
" 0.6078, 0.6120, 0.6163, 0.6206, 0.6249, 0.6292, 0.6336, 0.6380, 0.6424,\n",
" 0.6469, 0.6513, 0.6559, 0.6604, 0.6650, 0.6696, 0.6743, 0.6789, 0.6837,\n",
" 0.6884, 0.6932, 0.6980, 0.7028, 0.7077, 0.7126, 0.7176, 0.7225, 0.7275,\n",
" 0.7326, 0.7377, 0.7428, 0.7480, 0.7531, 0.7584, 0.7636, 0.7689, 0.7743,\n",
" 0.7796, 0.7850, 0.7905, 0.7960, 0.8015, 0.8071, 0.8127, 0.8183, 0.8240,\n",
" 0.8297, 0.8355, 0.8412, 0.8471, 0.8530, 0.8589, 0.8648, 0.8708, 0.8769,\n",
" 0.8830, 0.8891, 0.8953, 0.9015, 0.9077, 0.9140, 0.9204, 0.9268, 0.9332,\n",
" 0.9397, 0.9462, 0.9528, 0.9594, 0.9660, 0.9727, 0.9795, 0.9863, 0.9931,\n",
" 1.0000])"
]
},
"execution_count": 79,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# hyperparameter tuning: candidate learning rates, evenly spaced on a log10 scale\n",
"lre = torch.linspace(start=-3, end=0, steps=1000)  # exponents from -3 to 0\n",
"lrs = 10 ** lre  # learning rates from 0.001 up to 1\n",
"lrs"
]
},
{
"cell_type": "code",
"execution_count": 80,
"id": "8d1e7860",
"metadata": {},
"outputs": [],
"source": [
"# Rebuild the training set and reset the network before the learning-rate search.\n",
"#\n",
"# X[i] holds the indices of the block_size characters that precede a position\n",
"# (left-padded with index 0); Y[i] holds the index of the character to predict.\n",
"\n",
"# hyperparameters -- each one is defined exactly once, so the dataset width\n",
"# and the layer sizes cannot drift apart\n",
"block_size = 3   # context length in characters\n",
"Vdim = 27        # vocabulary size (presumably len(stoi): 26 letters + '.'; confirm)\n",
"Cdim = 2         # embedding dimension per character\n",
"n_output1 = 100  # hidden layer width\n",
"\n",
"# derived layer sizes\n",
"n_input1 = Cdim * block_size  # the embeddings of one context are flattened into one row\n",
"n_input2 = n_output1\n",
"n_output2 = Vdim              # one logit per vocabulary entry\n",
"\n",
"contexts, targets = [], []\n",
"for w in words:\n",
"    context = [0] * block_size  # start-of-word padding\n",
"    for ch in w + '.':  # the trailing '.' makes the end of the word a target too\n",
"        ix = stoi[ch]\n",
"        contexts.append(context)\n",
"        targets.append(ix)\n",
"        # slide the window; this builds a NEW list, so rows already stored are not mutated\n",
"        context = context[1:] + [ix]\n",
"\n",
"X = torch.tensor(contexts)\n",
"Y = torch.tensor(targets)\n",
"\n",
"# fixed seed -> identical initial parameters every time this cell is run\n",
"g = torch.Generator().manual_seed(2147483647)\n",
"C = torch.randn((Vdim, Cdim), generator=g)  # embedding table\n",
"W1 = torch.randn((n_input1, n_output1), generator=g)\n",
"b1 = torch.randn(n_output1, generator=g)\n",
"W2 = torch.randn((n_input2, n_output2), generator=g)\n",
"b2 = torch.randn(n_output2, generator=g)\n",
"parameters = [C, W1, b1, W2, b2]\n",
"\n",
"# track gradients for every trainable tensor\n",
"for p in parameters:\n",
"    p.requires_grad = True"
]
},
{
"cell_type": "code",
"execution_count": 81,
"id": "4fca2cf3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"20.058767318725586\n",
"18.386871337890625\n",
"20.938833236694336\n",
"18.30322265625\n",
"18.207124710083008\n",
"20.926156997680664\n",
"19.694604873657227\n",
"20.796886444091797\n",
"18.43265151977539\n",
"17.471906661987305\n",
"17.41988182067871\n",
"18.475698471069336\n",
"21.594234466552734\n",
"20.208314895629883\n",
"18.260589599609375\n",
"20.787334442138672\n",
"18.640947341918945\n",
"21.286579132080078\n",
"19.08941650390625\n",
"19.984092712402344\n",
"19.909046173095703\n",
"17.2227725982666\n",
"18.984329223632812\n",
"18.248714447021484\n",
"19.701366424560547\n",
"16.733842849731445\n",
"17.595561981201172\n",
"19.489112854003906\n",
"18.0367488861084\n",
"17.623903274536133\n",
"18.184213638305664\n",
"18.172489166259766\n",
"16.67631721496582\n",
"19.491979598999023\n",
"18.97934913635254\n",
"21.72498321533203\n",
"19.489974975585938\n",
"16.01494789123535\n",
"21.133420944213867\n",
"18.17195701599121\n",
"19.101388931274414\n",
"19.54134178161621\n",
"17.296428680419922\n",
"17.832054138183594\n",
"18.379337310791016\n",
"17.681211471557617\n",
"16.202953338623047\n",
"19.22760772705078\n",
"16.480175018310547\n",
"17.625823974609375\n",
"17.91048240661621\n",
"20.03180694580078\n",
"15.48385238647461\n",
"20.130367279052734\n",
"15.727002143859863\n",
"19.727458953857422\n",
"17.598100662231445\n",
"18.688796997070312\n",
"18.924501419067383\n",
"15.734272956848145\n",
"19.930004119873047\n",
"16.35295867919922\n",
"14.893880844116211\n",
"19.13114356994629\n",
"17.03234100341797\n",
"17.192468643188477\n",
"17.264965057373047\n",
"15.996824264526367\n",
"13.68630313873291\n",
"18.991613388061523\n",
"16.458696365356445\n",
"18.184864044189453\n",
"18.45516586303711\n",
"19.210384368896484\n",
"19.555805206298828\n",
"19.245872497558594\n",
"15.916244506835938\n",
"18.539262771606445\n",
"15.675761222839355\n",
"15.922590255737305\n",
"18.28230094909668\n",
"18.379098892211914\n",
"17.07571792602539\n",
"17.531736373901367\n",
"16.743484497070312\n",
"18.085182189941406\n",
"19.677038192749023\n",
"16.05499267578125\n",
"16.966365814208984\n",
"16.9460391998291\n",
"16.56806755065918\n",
"17.405820846557617\n",
"17.567480087280273\n",
"13.594080924987793\n",
"15.556037902832031\n",
"13.114115715026855\n",
"13.831795692443848\n",
"15.224092483520508\n",
"16.704626083374023\n",
"14.737277030944824\n",
"18.403032302856445\n",
"16.741727828979492\n",
"15.251632690429688\n",
"15.746190071105957\n",
"15.956998825073242\n",
"21.602508544921875\n",
"14.20954418182373\n",
"14.598978042602539\n",
"17.341718673706055\n",
"16.25075340270996\n",
"15.769144058227539\n",
"16.367794036865234\n",
"16.197172164916992\n",
"18.48675537109375\n",
"17.168405532836914\n",
"14.934187889099121\n",
"17.833589553833008\n",
"17.77256202697754\n",
"17.94159698486328\n",
"17.360082626342773\n",
"14.95344352722168\n",
"13.779696464538574\n",
"17.619510650634766\n",
"17.39665985107422\n",
"17.34027671813965\n",
"15.9515380859375\n",
"18.24655532836914\n",
"14.980691909790039\n",
"15.100189208984375\n",
"18.663522720336914\n",
"14.429710388183594\n",
"14.762701988220215\n",
"15.34671401977539\n",
"15.603941917419434\n",
"17.844085693359375\n",
"14.396878242492676\n",
"16.44132423400879\n",
"14.74105453491211\n",
"14.460693359375\n",
"14.309600830078125\n",
"14.838363647460938\n",
"16.653427124023438\n",
"13.572357177734375\n",
"14.395666122436523\n",
"18.12348175048828\n",
"14.245956420898438\n",
"15.573516845703125\n",
"15.885431289672852\n",
"14.404387474060059\n",
"17.079296112060547\n",
"15.201760292053223\n",
"15.457757949829102\n",
"13.696967124938965\n",
"15.21992015838623\n",
"13.501822471618652\n",
"15.713217735290527\n",
"17.039230346679688\n",
"15.352463722229004\n",
"13.331183433532715\n",
"14.037686347961426\n",
"14.475912094116211\n",
"15.033218383789062\n",
"14.212654113769531\n",
"15.011879920959473\n",
"13.989218711853027\n",
"14.684889793395996\n",
"16.003490447998047\n",
"13.3534517288208\n",
"18.113367080688477\n",
"15.766525268554688\n",
"13.79685115814209\n",
"15.681324005126953\n",
"14.373132705688477\n",
"13.612071990966797\n",
"14.888436317443848\n",
"13.91701602935791\n",
"14.809433937072754\n",
"15.310776710510254\n",
"14.299510955810547\n",
"15.394023895263672\n",
"16.06818962097168\n",
"13.481806755065918\n",
"14.091451644897461\n",
"17.476533889770508\n",
"13.11625862121582\n",
"15.260262489318848\n",
"14.466033935546875\n",
"14.652369499206543\n",
"15.203973770141602\n",
"12.611348152160645\n",
"13.802355766296387\n",
"13.701678276062012\n",
"12.002172470092773\n",
"17.47108268737793\n",
"12.778606414794922\n",
"16.25493621826172\n",
"12.363943099975586\n",
"14.833003997802734\n",
"12.832207679748535\n",
"15.066575050354004\n",
"14.566367149353027\n",
"14.7083158493042\n",
"11.102749824523926\n",
"15.345829963684082\n",
"14.21712875366211\n",
"10.669098854064941\n",
"14.981069564819336\n",
"12.768899917602539\n",
"14.53537368774414\n",
"15.360010147094727\n",
"12.52229118347168\n",
"15.298910140991211\n",
"12.180039405822754\n",
"15.207792282104492\n",
"12.739935874938965\n",
"14.423806190490723\n",
"12.861767768859863\n",
"14.22921085357666\n",
"10.41986083984375\n",
"13.890667915344238\n",
"10.88259220123291\n",
"13.321094512939453\n",
"11.663039207458496\n",
"11.932415008544922\n",
"10.430420875549316\n",
"10.10925579071045\n",
"11.853407859802246\n",
"10.79029369354248\n",
"14.451656341552734\n",
"14.122856140136719\n",
"13.67242431640625\n",
"13.091327667236328\n",
"13.005173683166504\n",
"12.685237884521484\n",
"13.005620956420898\n",
"13.553939819335938\n",
"12.267475128173828\n",
"14.190502166748047\n",
"11.30736255645752\n",
"10.914015769958496\n",
"12.445353507995605\n",
"10.342656135559082\n",
"12.580787658691406\n",
"10.81489372253418\n",
"10.514963150024414\n",
"12.052750587463379\n",
"12.980908393859863\n",
"12.380534172058105\n",
"11.733548164367676\n",
"14.590177536010742\n",
"11.620329856872559\n",
"11.907234191894531\n",
"13.085349082946777\n",
"14.325329780578613\n",
"10.8388032913208\n",
"14.467137336730957\n",
"9.308248519897461\n",
"8.384876251220703\n",
"12.706113815307617\n",
"14.621389389038086\n",
"13.659004211425781\n",
"12.338562965393066\n",
"14.114700317382812\n",
"11.619922637939453\n",
"10.123787879943848\n",
"11.096317291259766\n",
"12.064632415771484\n",
"12.706136703491211\n",
"11.947173118591309\n",
"13.065418243408203\n",
"13.177108764648438\n",
"11.022246360778809\n",
"11.563748359680176\n",
"10.7993803024292\n",
"12.816320419311523\n",
"10.871781349182129\n",
"10.823201179504395\n",
"11.811736106872559\n",
"11.892790794372559\n",
"10.869131088256836\n",
"11.294859886169434\n",
"9.91629409790039\n",
"10.768162727355957\n",
"9.785820960998535\n",
"11.019686698913574\n",
"14.884580612182617\n",
"8.939095497131348\n",
"11.121686935424805\n",
"9.681001663208008\n",
"14.208605766296387\n",
"10.606358528137207\n",
"10.238022804260254\n",
"14.800966262817383\n",
"9.842073440551758\n",
"10.831656455993652\n",
"11.870290756225586\n",
"11.204913139343262\n",
"11.685627937316895\n",
"14.99278736114502\n",
"9.216063499450684\n",
"9.720016479492188\n",
"10.127654075622559\n",
"10.670577049255371\n",
"11.96092700958252\n",
"10.001982688903809\n",
"9.47696590423584\n",
"10.871495246887207\n",
"9.49959659576416\n",
"7.812280654907227\n",
"10.504291534423828\n",
"9.808164596557617\n",
"12.335737228393555\n",
"8.316473007202148\n",
"10.166391372680664\n",
"11.575448036193848\n",
"13.462331771850586\n",
"7.442794322967529\n",
"9.652787208557129\n",
"9.31308364868164\n",
"8.436732292175293\n",
"9.656270027160645\n",
"10.529017448425293\n",
"10.405807495117188\n",
"9.85356330871582\n",
"8.51788330078125\n",
"10.666929244995117\n",
"9.939424514770508\n",
"10.175727844238281\n",
"10.883666038513184\n",
"8.571637153625488\n",
"8.996121406555176\n",
"9.41479206085205\n",
"9.743539810180664\n",
"10.75446605682373\n",
"9.327347755432129\n",
"10.134511947631836\n",
"8.580121040344238\n",
"11.218292236328125\n",
"9.430485725402832\n",
"8.714287757873535\n",
"8.426252365112305\n",
"10.338098526000977\n",
"11.220523834228516\n",
"10.878255844116211\n",
"9.532553672790527\n",
"8.539838790893555\n",
"7.672667980194092\n",
"10.36310863494873\n",
"10.894793510437012\n",
"9.161340713500977\n",
"10.406136512756348\n",
"8.765401840209961\n",
"11.612442016601562\n",
"8.186119079589844\n",
"11.474461555480957\n",
"9.826943397521973\n",
"7.516147136688232\n",
"9.676780700683594\n",
"7.89693546295166\n",
"8.4464693069458\n",
"9.113298416137695\n",
"9.641304969787598\n",
"9.967488288879395\n",
"10.455885887145996\n",
"8.237849235534668\n",
"7.047585487365723\n",
"7.616254806518555\n",
"9.337509155273438\n",
"9.760736465454102\n",
"8.914117813110352\n",
"8.533182144165039\n",
"8.845060348510742\n",
"8.286819458007812\n",
"7.084433555603027\n",
"8.056537628173828\n",
"7.038687705993652\n",
"10.185018539428711\n",
"9.052075386047363\n",
"8.2262601852417\n",
"7.8366827964782715\n",
"8.655904769897461\n",
"8.646445274353027\n",
"10.069315910339355\n",
"8.779277801513672\n",
"8.074405670166016\n",
"7.853659152984619\n",
"8.303882598876953\n",
"8.005335807800293\n",
"6.686523914337158\n",
"5.987373352050781\n",
"7.754853248596191\n",
"9.928034782409668\n",
"6.959237098693848\n",
"7.998958587646484\n",
"7.160789489746094\n",
"9.51343059539795\n",
"7.81441068649292\n",
"9.714046478271484\n",
"9.615175247192383\n",
"7.681924819946289\n",
"8.25478744506836\n",
"8.42897891998291\n",
"8.113512992858887\n",
"7.563684463500977\n",
"8.473864555358887\n",
"6.3443403244018555\n",
"9.098320007324219\n",
"7.37877893447876\n",
"6.578047275543213\n",
"7.377756118774414\n",
"8.60280990600586\n",
"7.422914981842041\n",
"7.379459857940674\n",
"8.935554504394531\n",
"6.389432430267334\n",
"8.996026039123535\n",
"8.94339370727539\n",
"5.5102949142456055\n",
"6.963716983795166\n",
"11.425223350524902\n",
"7.493052959442139\n",
"8.636194229125977\n",
"6.408836841583252\n",
"8.882610321044922\n",
"6.552365779876709\n",
"6.99999475479126\n",
"8.897416114807129\n",
"7.332728862762451\n",
"6.513792991638184\n",
"7.331733703613281\n",
"6.8848090171813965\n",
"8.904072761535645\n",
"6.285671710968018\n",
"6.721429824829102\n",
"8.188508987426758\n",
"8.023738861083984\n",
"8.223982810974121\n",
"7.631497859954834\n",
"6.868099689483643\n",
"7.845915794372559\n",
"8.994128227233887\n",
"6.873394012451172\n",
"6.575554847717285\n",
"5.681668758392334\n",
"9.183479309082031\n",
"7.379150390625\n",
"6.419661045074463\n",
"6.1244611740112305\n",
"7.120537281036377\n",
"7.353793144226074\n",
"6.679843902587891\n",
"6.7223920822143555\n",
"6.166895866394043\n",
"6.358139514923096\n",
"8.206326484680176\n",
"4.982204437255859\n",
"7.02251672744751\n",
"6.5761237144470215\n",
"6.247386932373047\n",
"5.406505107879639\n",
"5.372550010681152\n",
"5.643982410430908\n",
"4.830392837524414\n",
"5.968005180358887\n",
"6.492893218994141\n",
"4.8271164894104\n",
"7.547109603881836\n",
"5.955235481262207\n",
"6.241334438323975\n",
"7.30716609954834\n",
"7.0900068283081055\n",
"6.036428928375244\n",
"7.084930896759033\n",
"4.779949188232422\n",
"7.399024963378906\n",
"6.500931739807129\n",
"5.981712818145752\n",
"5.332479953765869\n",
"5.2537617683410645\n",
"5.506017208099365\n",
"6.235259056091309\n",
"6.222657680511475\n",
"4.586050510406494\n",
"6.005897521972656\n",
"6.667596817016602\n",
"5.921751499176025\n",
"5.109879016876221\n",
"3.7901201248168945\n",
"4.821964740753174\n",
"5.593708038330078\n",
"6.004891395568848\n",
"4.737380027770996\n",
"6.811749458312988\n",
"5.370182514190674\n",
"6.791818618774414\n",
"5.8135175704956055\n",
"6.2312235832214355\n",
"5.7961883544921875\n",
"5.895395755767822\n",
"7.342142105102539\n",
"5.725778102874756\n",
"4.989012718200684\n",
"4.228612422943115\n",
"4.350794315338135\n",
"5.3611674308776855\n",
"4.481586456298828\n",
"5.720825672149658\n",
"5.554401874542236\n",
"5.9270429611206055\n",
"4.660965919494629\n",
"5.52239465713501\n",
"5.524663925170898\n",
"4.811008930206299\n",
"5.698330879211426\n",
"6.13444185256958\n",
"4.305564880371094\n",
"4.850207328796387\n",
"4.782961845397949\n",
"5.758709907531738\n",
"5.040655136108398\n",
"4.9152045249938965\n",
"5.848385334014893\n",
"4.637724876403809\n",
"6.089776515960693\n",
"4.190829753875732\n",
"4.483164310455322\n",
"5.096348285675049\n",
"4.175036430358887\n",
"4.132568836212158\n",
"5.668691635131836\n",
"5.299254894256592\n",
"3.9058778285980225\n",
"6.583087921142578\n",
"5.95508337020874\n",
"5.55558967590332\n",
"3.896312952041626\n",
"5.0055389404296875\n",
"3.8073885440826416\n",
"4.782783031463623\n",
"5.341424942016602\n",
"5.220652103424072\n",
"4.3145551681518555\n",
"4.867645263671875\n",
"4.317051887512207\n",
"4.742574214935303\n",
"4.937697410583496\n",
"5.328518867492676\n",
"4.958920955657959\n",
"3.481821298599243\n",
"4.7215189933776855\n",
"5.35048770904541\n",
"4.197367191314697\n",
"4.604339122772217\n",
"4.23129940032959\n",
"4.756180763244629\n",
"4.884371757507324\n",
"5.155618190765381\n",
"4.574168682098389\n",
"3.391645669937134\n",
"3.3872737884521484\n",
"4.672423362731934\n",
"4.533355712890625\n",
"5.220793724060059\n",
"3.8227334022521973\n",
"3.7068281173706055\n",
"4.240579128265381\n",
"4.195023059844971\n",
"4.6160888671875\n",
"4.480436325073242\n",
"4.624906539916992\n",
"4.456013202667236\n",
"5.195305824279785\n",
"3.5553925037384033\n",
"3.7349674701690674\n",
"4.430811405181885\n",
"4.551756858825684\n",
"3.7298829555511475\n",
"4.410430431365967\n",
"3.553434371948242\n",
"3.4611258506774902\n",
"4.55808687210083\n",
"3.088973045349121\n",
"4.4551920890808105\n",
"3.5407283306121826\n",
"4.2899980545043945\n",
"4.096808433532715\n",
"3.5416300296783447\n",
"4.437430381774902\n",
"3.881909132003784\n",
"3.404031753540039\n",
"4.887137413024902\n",
"3.174898386001587\n",
"3.9722282886505127\n",
"4.0994086265563965\n",
"4.036005020141602\n",
"3.7181835174560547\n",
"3.022531747817993\n",
"4.7746100425720215\n",
"3.107983112335205\n",
"3.581955909729004\n",
"4.060939311981201\n",
"3.075693130493164\n",
"4.609142780303955\n",
"3.3816468715667725\n",
"5.256540298461914\n",
"3.156243324279785\n",
"4.870355129241943\n",
"4.980332851409912\n",
"4.01073694229126\n",
"3.294806957244873\n",
"4.303383827209473\n",
"3.302541732788086\n",
"3.692582845687866\n",
"4.456576824188232\n",
"3.896005630493164\n",
"3.8528528213500977\n",
"4.296661853790283\n",
"3.214301347732544\n",
"3.0818440914154053\n",
"3.5100314617156982\n",
"3.1930394172668457\n",
"2.9393343925476074\n",
"2.919846773147583\n",
"3.8025946617126465\n",
"2.892787456512451\n",
"4.956643581390381\n",
"2.888521671295166\n",
"3.342435359954834\n",
"3.0627057552337646\n",
"4.08694314956665\n",
"3.3463854789733887\n",
"4.060541152954102\n",
"2.9815173149108887\n",
"3.2821357250213623\n",
"3.0937447547912598\n",
"3.0083119869232178\n",
"3.954728126525879\n",
"3.924070119857788\n",
"3.2576563358306885\n",
"3.7045695781707764\n",
"3.2489681243896484\n",
"2.9568793773651123\n",
"3.5199520587921143\n",
"3.453016996383667\n",
"3.802501678466797\n",
"3.637922763824463\n",
"3.216110944747925\n",
"3.706007242202759\n",
"3.984416961669922\n",
"3.632723569869995\n",
"3.1989638805389404\n",
"3.6497514247894287\n",
"3.1191749572753906\n",
"4.295158386230469\n",
"2.5521702766418457\n",
"3.194572925567627\n",
"3.4275803565979004\n",
"3.4875407218933105\n",
"3.0968613624572754\n",
"2.905925750732422\n",
"2.9836156368255615\n",
"3.2466132640838623\n",
"3.4357643127441406\n",
"3.062093734741211\n",
"3.6897025108337402\n",
"4.675938129425049\n",
"3.194821357727051\n",
"3.059985876083374\n",
"2.917567491531372\n",
"3.519122838973999\n",
"3.1748576164245605\n",
"3.006070852279663\n",
"2.9797580242156982\n",
"3.3550615310668945\n",
"3.930544376373291\n",
"2.6503937244415283\n",
"3.325307607650757\n",
"3.9021999835968018\n",
"2.6885509490966797\n",
"3.1735334396362305\n",
"3.559260368347168\n",
"3.3719849586486816\n",
"3.2907395362854004\n",
"3.0547420978546143\n",
"3.8450896739959717\n",
"3.3199620246887207\n",
"3.0020692348480225\n",
"3.428286552429199\n",
"3.70822811126709\n",
"3.3926849365234375\n",
"3.313302516937256\n",
"2.6763691902160645\n",
"2.7028677463531494\n",
"3.20300555229187\n",
"3.1753597259521484\n",
"3.429269790649414\n",
"3.1132779121398926\n",
"2.4863622188568115\n",
"3.024989604949951\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.143970251083374\n",
"2.5822694301605225\n",
"3.616262674331665\n",
"3.358060598373413\n",
"3.2837748527526855\n",
"3.4139809608459473\n",
"3.3856797218322754\n",
"2.7517240047454834\n",
"2.959219455718994\n",
"3.2085742950439453\n",
"3.230847120285034\n",
"3.00468111038208\n",
"2.6261143684387207\n",
"3.411679983139038\n",
"2.6447250843048096\n",
"2.7005319595336914\n",
"3.232732057571411\n",
"2.8750433921813965\n",
"3.520543098449707\n",
"2.8991739749908447\n",
"2.8904154300689697\n",
"3.335700273513794\n",
"3.2348315715789795\n",
"3.2892212867736816\n",
"3.896378993988037\n",
"2.8900647163391113\n",
"3.0431835651397705\n",
"3.039491891860962\n",
"3.037858009338379\n",
"3.210955858230591\n",
"4.34665584564209\n",
"2.8819639682769775\n",
"3.726290225982666\n",
"3.7548255920410156\n",
"3.493394613265991\n",
"3.512141227722168\n",
"3.016710042953491\n",
"2.9070377349853516\n",
"3.4122490882873535\n",
"3.086726188659668\n",
"3.199303388595581\n",
"3.335646390914917\n",
"3.1801059246063232\n",
"3.533473491668701\n",
"3.0448317527770996\n",
"2.899057149887085\n",
"3.0487630367279053\n",
"3.3654282093048096\n",
"3.331571578979492\n",
"2.6786227226257324\n",
"3.215420722961426\n",
"2.7102489471435547\n",
"3.626858949661255\n",
"3.428629159927368\n",
"2.43825101852417\n",
"2.8668174743652344\n",
"3.504131555557251\n",
"2.7836339473724365\n",
"3.1075804233551025\n",
"3.2532730102539062\n",
"3.036126136779785\n",
"2.777801275253296\n",
"2.84604549407959\n",
"2.948404312133789\n",
"3.195976734161377\n",
"3.2230095863342285\n",
"3.0998973846435547\n",
"3.718081474304199\n",
"2.800168037414551\n",
"3.2489216327667236\n",
"2.8199150562286377\n",
"3.6910431385040283\n",
"3.173072338104248\n",
"2.7641613483428955\n",
"3.77432918548584\n",
"3.624133825302124\n",
"3.227541446685791\n",
"3.183627128601074\n",
"3.283186435699463\n",
"3.057158946990967\n",
"2.777653694152832\n",
"2.889397144317627\n",
"3.3341434001922607\n",
"3.4635066986083984\n",
"2.904613971710205\n",
"3.0934841632843018\n",
"3.5386648178100586\n",
"3.4154884815216064\n",
"3.043239116668701\n",
"3.262671947479248\n",
"4.166245937347412\n",
"2.961470127105713\n",
"2.8482401371002197\n",
"3.534893274307251\n",
"3.1096999645233154\n",
"2.7225348949432373\n",
"3.547924041748047\n",
"3.745229721069336\n",
"2.809553861618042\n",
"3.616903305053711\n",
"3.6580862998962402\n",
"3.534134864807129\n",
"2.620283365249634\n",
"2.9316813945770264\n",
"3.368833541870117\n",
"3.4300267696380615\n",
"3.5138981342315674\n",
"3.6709792613983154\n",
"3.0291519165039062\n",
"3.6491801738739014\n",
"4.69722318649292\n",
"3.6200473308563232\n",
"3.425041913986206\n",
"3.839419364929199\n",
"3.556222677230835\n",
"3.6515636444091797\n",
"3.271301746368408\n",
"3.6069674491882324\n",
"3.619744300842285\n",
"3.2222707271575928\n",
"3.6004958152770996\n",
"3.295170545578003\n",
"3.294151782989502\n",
"3.963306188583374\n",
"3.9888854026794434\n",
"3.395801544189453\n",
"2.7327609062194824\n",
"4.557826042175293\n",
"3.7835001945495605\n",
"4.577127456665039\n",
"3.4779865741729736\n",
"3.5828535556793213\n",
"3.474299192428589\n",
"3.554988145828247\n",
"3.217632532119751\n",
"3.7217812538146973\n",
"3.207597494125366\n",
"3.381564140319824\n",
"3.94753098487854\n",
"4.8308258056640625\n",
"5.170791149139404\n",
"3.3575708866119385\n",
"4.269369125366211\n",
"3.247758626937866\n",
"4.41501522064209\n",
"4.021042823791504\n",
"3.817215919494629\n",
"3.2845804691314697\n",
"4.472692489624023\n",
"4.064505100250244\n",
"3.219364881515503\n",
"3.844644784927368\n",
"4.002055644989014\n",
"3.275210380554199\n",
"3.7533915042877197\n",
"4.884897708892822\n",
"3.598559617996216\n",
"4.256794452667236\n",
"3.9089903831481934\n",
"3.7389872074127197\n",
"4.224829196929932\n",
"3.9468448162078857\n",
"3.9633235931396484\n",
"3.9781320095062256\n",
"4.462779521942139\n",
"3.4086098670959473\n",
"4.164345741271973\n",
"8.31614875793457\n",
"5.93451452255249\n",
"6.095202922821045\n",
"4.018470287322998\n",
"3.7814743518829346\n",
"4.7484130859375\n",
"4.075839042663574\n",
"4.838449001312256\n",
"4.355576992034912\n",
"4.112987518310547\n",
"4.829951286315918\n",
"4.8745436668396\n",
"4.683174133300781\n",
"4.704440593719482\n",
"3.4863169193267822\n",
"5.267629146575928\n",
"4.023910999298096\n",
"3.553920030593872\n",
"4.033780097961426\n",
"5.377976417541504\n",
"4.1329731941223145\n",
"5.838648796081543\n",
"4.945495128631592\n",
"4.220717906951904\n",
"6.133040904998779\n",
"4.624537944793701\n",
"6.009963512420654\n",
"4.677973747253418\n",
"4.602327823638916\n",
"3.3685643672943115\n",
"5.05299186706543\n",
"4.9118242263793945\n",
"4.44879674911499\n",
"5.753711223602295\n",
"5.209222316741943\n",
"4.082849025726318\n",
"4.3265061378479\n",
"4.385910987854004\n",
"3.9758224487304688\n",
"4.269911766052246\n",
"4.9230637550354\n",
"4.7449116706848145\n",
"4.327792644500732\n",
"3.6670405864715576\n",
"4.678028583526611\n",
"3.702824592590332\n",
"4.078690052032471\n",
"4.826587200164795\n",
"5.305018424987793\n",
"7.958984375\n",
"4.998692035675049\n",
"3.7893269062042236\n",
"3.986933708190918\n",
"3.4457900524139404\n",
"2.8971259593963623\n",
"4.211317539215088\n",
"4.968047618865967\n",
"5.523991107940674\n",
"4.809507846832275\n",
"5.05592679977417\n",
"6.213912487030029\n",
"6.041927337646484\n",
"6.665318012237549\n",
"6.162818431854248\n",
"5.069524765014648\n",
"5.342731475830078\n",
"6.58087682723999\n",
"8.372491836547852\n",
"7.336589336395264\n",
"8.737000465393066\n",
"8.807034492492676\n",
"7.782583236694336\n",
"5.720577239990234\n",
"7.279583930969238\n",
"6.570875644683838\n",
"6.362747669219971\n",
"7.10400915145874\n",
"5.268760681152344\n",
"5.338332653045654\n",
"5.195115566253662\n",
"6.703448295593262\n",
"5.327396869659424\n",
"5.858665943145752\n",
"7.968378067016602\n",
"7.84922981262207\n",
"6.56950569152832\n",
"7.11574649810791\n",
"7.103545188903809\n",
"8.86844253540039\n",
"7.60360860824585\n",
"6.428305149078369\n",
"5.785921573638916\n",
"7.077493190765381\n",
"7.3251953125\n",
"6.634866237640381\n",
"5.36044979095459\n",
"8.364887237548828\n",
"8.705479621887207\n",
"6.498984336853027\n",
"7.325703144073486\n",
"8.187382698059082\n",
"8.92408561706543\n",
"9.026152610778809\n",
"7.936668872833252\n",
"5.978590488433838\n",
"6.503759860992432\n",
"5.820070743560791\n",
"6.012407302856445\n",
"6.347898483276367\n",
"5.945157527923584\n",
"10.594411849975586\n",
"7.612778186798096\n",
"8.65031909942627\n",
"8.221137046813965\n",
"5.586372375488281\n",
"6.739910125732422\n",
"6.437129974365234\n",
"6.807755947113037\n",
"7.184710502624512\n",
"7.534907341003418\n",
"7.9743194580078125\n",
"11.786186218261719\n",
"8.150350570678711\n",
"9.01694393157959\n",
"9.777725219726562\n",
"10.811105728149414\n",
"10.679527282714844\n",
"8.601481437683105\n",
"12.968698501586914\n",
"7.7538743019104\n",
"8.0514554977417\n",
"8.91894245147705\n",
"8.04653263092041\n",
"8.193402290344238\n"
]
}
],
"source": [
"# Learning-rate sweep: step i trains with the candidate rate lrs[i] and records the\n",
"# matching lre[i], so the loss can be plotted against it afterwards.\n",
"# NOTE(review): lrs and lre come from an earlier cell; presumably lrs = 10**lre - confirm.\n",
"lri = []\n",
"lossi = []\n",
"\n",
"for i in range(1000):\n",
"    # mini-batch: sample batch_size random row indices\n",
"    ix = torch.randint(0, X.shape[0], (batch_size,))\n",
"\n",
"    # forward pass on the sampled rows only\n",
"    emb = C[X[ix]]  # one embedding vector per context character\n",
"    h = torch.tanh(emb.view(-1, n_input1) @ W1 + b1)\n",
"    logits = h @ W2 + b2\n",
"    loss = F.cross_entropy(logits, Y[ix])\n",
"\n",
"    # backward pass: clear stale gradients, then backpropagate\n",
"    for p in parameters:\n",
"        p.grad = None\n",
"    loss.backward()\n",
"\n",
"    # update the parameters with this step's learning rate\n",
"    lr = lrs[i]\n",
"    for p in parameters:\n",
"        p.data += -lr * p.grad\n",
"\n",
"    lri.append(lre[i])\n",
"    # detach() keeps a tensor (min(lossi).item() still works) but drops the autograd\n",
"    # state, so the recorded losses no longer require grad and can be plotted directly\n",
"    lossi.append(loss.detach())\n",
"\n",
"    print(loss.item())"
]
},
{
"cell_type": "code",
"execution_count": 82,
"id": "2efaea53",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x129ff3518>]"
]
},
"execution_count": 82,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Loss versus the recorded learning-rate value: the loss bottoms out somewhere around\n",
"# 10**-1 = 0.1 and becomes unstable / diverges as the learning rate keeps growing.\n",
"# float() accepts both tensors and plain numbers, so the list plots even when it\n",
"# still holds tensors that require grad.\n",
"plt.plot(lri, [float(l) for l in lossi])\n",
"plt.xlabel('learning-rate exponent (lre)')\n",
"plt.ylabel('mini-batch loss');"
]
},
{
"cell_type": "code",
"execution_count": 83,
"id": "ca8ce989",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.43825101852417\n"
]
}
],
"source": [
"# Lowest mini-batch loss seen during the sweep; in the late stage of the sweep the\n",
"# loss climbs again, i.e. training diverged at the largest learning rates.\n",
"best_loss = min(lossi)\n",
"print(best_loss.item())"
]
},
{
"cell_type": "markdown",
"id": "e07358e8",
"metadata": {},
"source": [
"# 7. cross-validation\n",
"train, valid, test split\n",
"\n",
"80%, 10%, 10%\n",
"\n",
"train split: fit the parameters; dev split: tune the hyperparameters; test split: final evaluation (use it only a few times)"
]
},
{
"cell_type": "code",
"execution_count": 84,
"id": "52c638a6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([182625, 3]) torch.Size([182625])\n",
"torch.Size([22655, 3]) torch.Size([22655])\n",
"torch.Size([22866, 3]) torch.Size([22866])\n"
]
}
],
"source": [
"def build_dataset(words, block_size=3):\n",
"    '''Build next-character training examples from a list of words.\n",
"\n",
"    Every word is terminated with '.', and each character (terminator included)\n",
"    becomes one target whose input is the block_size preceding character indices;\n",
"    positions before the start of the word are filled with index 0.\n",
"\n",
"    Args:\n",
"        words: iterable of strings; every character (and '.') must be a key of the\n",
"            global stoi mapping.\n",
"        block_size: number of context characters per example (default 3).\n",
"\n",
"    Returns:\n",
"        (X, Y): X holds one row of block_size context indices per example and Y the\n",
"        matching target index. Both shapes are printed before returning.\n",
"    '''\n",
"    X, Y = [], []\n",
"    for w in words:\n",
"        context = [0] * block_size\n",
"        for ch in w + '.':\n",
"            ix = stoi[ch]\n",
"            X.append(context)\n",
"            Y.append(ix)\n",
"            # slide the window: drop the oldest index, append the newest\n",
"            context = context[1:] + [ix]\n",
"\n",
"    X = torch.tensor(X)\n",
"    Y = torch.tensor(Y)\n",
"\n",
"    print(X.shape, Y.shape)\n",
"    return X, Y\n",
"\n",
"import random\n",
"\n",
"# Shuffle a copy so that `words` keeps its original order: re-running this cell then\n",
"# reproduces exactly the same split instead of re-shuffling an already shuffled list.\n",
"random.seed(42)\n",
"words_shuffled = list(words)\n",
"random.shuffle(words_shuffled)\n",
"\n",
"train_perc = 0.8\n",
"dev_perc = 0.1\n",
"test_perc = 0.1\n",
"assert abs(train_perc + dev_perc + test_perc - 1.0) < 1e-9, 'split fractions must sum to 1'\n",
"\n",
"# split boundaries: [0, n1) train, [n1, n2) dev, [n2, end) test\n",
"n1 = int(train_perc * len(words_shuffled))\n",
"n2 = int((train_perc + dev_perc) * len(words_shuffled))\n",
"\n",
"Xtr, Ytr = build_dataset(words_shuffled[:n1])\n",
"Xdev, Ydev = build_dataset(words_shuffled[n1:n2])\n",
"Xte, Yte = build_dataset(words_shuffled[n2:])"
]
},
{
"cell_type": "code",
"execution_count": 85,
"id": "5dabdf90",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([182625, 3]), torch.Size([182625]))"
]
},
"execution_count": 85,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# shapes of the training split: inputs first, targets second\n",
"(Xtr.shape, Ytr.shape)"
]
},
{
"cell_type": "code",
"execution_count": 86,
"id": "e7b6f9ed",
"metadata": {},
"outputs": [],
"source": [
"# model sizes\n",
"Vdim = 27          # number of rows in C and number of output classes\n",
"Cdim = 2           # hyperparameter: embedding size per character\n",
"block_size = 3     # hyperparameter: context length\n",
"n_output1 = 100    # hyperparameter: hidden-layer width\n",
"n_input1 = Cdim * block_size\n",
"n_input2 = n_output1\n",
"n_output2 = Vdim\n",
"\n",
"# draw all parameters, in a fixed order, from one seeded generator\n",
"g = torch.Generator().manual_seed(2147483647)\n",
"C = torch.randn((Vdim, Cdim), generator=g)\n",
"W1 = torch.randn((n_input1, n_output1), generator=g)\n",
"b1 = torch.randn(n_output1, generator=g)\n",
"W2 = torch.randn((n_input2, n_output2), generator=g)\n",
"b2 = torch.randn(n_output2, generator=g)\n",
"parameters = [C, W1, b1, W2, b2]\n",
"\n",
"# enable gradient tracking on every parameter\n",
"for p in parameters:\n",
"    p.requires_grad_(True)"
]
},
{
"cell_type": "code",
"execution_count": 87,
"id": "a1b009dc",
"metadata": {},
"outputs": [],
"source": [
"# Train on the training split with a fixed learning rate of 0.1.\n",
"lossi = []  # mini-batch loss per step, stored as plain floats\n",
"\n",
"for i in range(100000):\n",
"    # mini-batch: 32 random rows of the training split\n",
"    ix = torch.randint(0, Xtr.shape[0], (32,))\n",
"\n",
"    # forward pass on the sampled rows only\n",
"    emb = C[Xtr[ix]]  # one embedding vector per context character\n",
"    # flatten with n_input1 (= Cdim * block_size) rather than a hard-coded 6\n",
"    h = torch.tanh(emb.view(-1, n_input1) @ W1 + b1)\n",
"    logits = h @ W2 + b2\n",
"    loss = F.cross_entropy(logits, Ytr[ix])\n",
"\n",
"    # backward pass: clear stale gradients, then backpropagate\n",
"    for p in parameters:\n",
"        p.grad = None\n",
"    loss.backward()\n",
"\n",
"    # update the parameters\n",
"    lr = 0.1\n",
"    for p in parameters:\n",
"        p.data += -lr * p.grad\n",
"\n",
"    # .item() stores a float instead of keeping 100000 graph-attached tensors alive\n",
"    lossi.append(loss.item())"
]
},
{
"cell_type": "code",
"execution_count": 88,
"id": "a4384ccb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train loss 2.321897268295288\n",
"dev loss 2.329921007156372\n"
]
}
],
"source": [
"# Evaluate the trained model on the full training and dev splits.\n",
"# No gradients are needed here, so skip building the autograd graph.\n",
"with torch.no_grad():\n",
"    # training split\n",
"    emb = C[Xtr]\n",
"    h = torch.tanh(emb.view(-1, n_input1) @ W1 + b1)\n",
"    logits = h @ W2 + b2\n",
"    loss = F.cross_entropy(logits, Ytr)\n",
"    print('train loss', loss.item())\n",
"\n",
"    # dev split (note: a smaller learning rate could reduce the loss further)\n",
"    emb = C[Xdev]\n",
"    h = torch.tanh(emb.view(-1, n_input1) @ W1 + b1)\n",
"    logits = h @ W2 + b2\n",
"    loss = F.cross_entropy(logits, Ydev)\n",
"    print('dev loss', loss.item())"
]
},
{
"cell_type": "markdown",
"id": "e90a7368",
"metadata": {},
"source": [
"# 8. use a bigger hidden layer"
]
},
{
"cell_type": "code",
"execution_count": 89,
"id": "e9cca6e5",
"metadata": {},
"outputs": [],
"source": [
"# model sizes for the wider hidden layer\n",
"Vdim = 27          # number of rows in C and number of output classes\n",
"Cdim = 2           # hyperparameter: embedding size per character\n",
"block_size = 3     # hyperparameter: context length\n",
"batch_size = 32\n",
"n_output1 = 300    # hyperparameter: hidden-layer width\n",
"n_input1 = Cdim * block_size\n",
"n_input2 = n_output1\n",
"n_output2 = Vdim\n",
"\n",
"# draw all parameters, in a fixed order, from one seeded generator\n",
"g = torch.Generator().manual_seed(2147483647)\n",
"C = torch.randn((Vdim, Cdim), generator=g)\n",
"W1 = torch.randn((n_input1, n_output1), generator=g)\n",
"b1 = torch.randn(n_output1, generator=g)\n",
"W2 = torch.randn((n_input2, n_output2), generator=g)\n",
"b2 = torch.randn(n_output2, generator=g)\n",
"parameters = [C, W1, b1, W2, b2]\n",
"\n",
"# enable gradient tracking on every parameter\n",
"for p in parameters:\n",
"    p.requires_grad_(True)"
]
},
{
"cell_type": "code",
"execution_count": 90,
"id": "699aceed",
"metadata": {},
"outputs": [],
"source": [
"# Train the wider model; run this a few more times to see if the loss drops further.\n",
"stepi = []\n",
"lossi = []\n",
"\n",
"for i in range(100000):\n",
"    # mini-batch: sample from the training split only, so the dev and test rows\n",
"    # stay unseen and the dev loss remains an honest estimate\n",
"    ix = torch.randint(0, Xtr.shape[0], (batch_size,))\n",
"\n",
"    # forward pass on the sampled rows only\n",
"    emb = C[Xtr[ix]]  # one embedding vector per context character\n",
"    h = torch.tanh(emb.view(-1, n_input1) @ W1 + b1)\n",
"    logits = h @ W2 + b2\n",
"    loss = F.cross_entropy(logits, Ytr[ix])\n",
"\n",
"    # backward pass: clear stale gradients, then backpropagate\n",
"    for p in parameters:\n",
"        p.grad = None\n",
"    loss.backward()\n",
"\n",
"    # update the parameters with a step learning-rate decay\n",
"    lr = 0.1 if i < 10000 else 0.01\n",
"    for p in parameters:\n",
"        p.data += -lr * p.grad\n",
"\n",
"    stepi.append(i)\n",
"    # store a float rather than a tensor that keeps autograd state alive\n",
"    lossi.append(loss.item())"
]
},
{
"cell_type": "code",
"execution_count": 91,
"id": "00018c8d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x12ada6be0>]"
]
},
"execution_count": 91,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Mini-batch loss per step. It fluctuates because every step sees a different random\n",
"# mini-batch; a larger batch size would give a more stable curve.\n",
"# float() accepts both tensors and plain numbers, so the list plots either way.\n",
"plt.plot(stepi, [float(l) for l in lossi])\n",
"plt.xlabel('step')\n",
"plt.ylabel('mini-batch loss');"
]
},
{
"cell_type": "code",
"execution_count": 92,
"id": "697a5cdb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train loss 2.329075336456299\n",
"dev loss 2.3230931758880615\n"
]
}
],
"source": [
"# Evaluate the trained model on the full training and dev splits.\n",
"# No gradients are needed here, so skip building the autograd graph.\n",
"with torch.no_grad():\n",
"    # training split\n",
"    emb = C[Xtr]\n",
"    h = torch.tanh(emb.view(-1, n_input1) @ W1 + b1)\n",
"    logits = h @ W2 + b2\n",
"    loss = F.cross_entropy(logits, Ytr)\n",
"    print('train loss', loss.item())\n",
"\n",
"    # dev split (note: a smaller learning rate could reduce the loss further)\n",
"    emb = C[Xdev]\n",
"    h = torch.tanh(emb.view(-1, n_input1) @ W1 + b1)\n",
"    logits = h @ W2 + b2\n",
"    loss = F.cross_entropy(logits, Ydev)\n",
"    print('dev loss', loss.item())"
]
},
{
"cell_type": "code",
"execution_count": 93,
"id": "c7fde9ee",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 576x576 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# Visualize the learned embedding: one dot per row of C, labelled with its character.\n",
"# Only the first two embedding columns are drawn.\n",
"plt.figure(figsize=(8,8))\n",
"plt.scatter(C[:,0].data, C[:,1].data, s=200)\n",
"for i in range(C.shape[0]):\n",
"    # plt.text(x, y, the actual text)\n",
"    plt.text(C[i,0].item(), C[i,1].item(), itos[i], ha='center', va='center', color='white')\n",
"plt.xlabel('embedding dimension 0')\n",
"plt.ylabel('embedding dimension 1')\n",
"# the first positional argument of plt.grid is the on/off flag, so pass it explicitly\n",
"plt.grid(True)\n",
"\n",
"# C learns to separate the characters; the vowels a, e, i, o, u end up close together"
]
},
{
"cell_type": "markdown",
"id": "ce5424d2",
"metadata": {},
"source": [
"# 9. increase embedding size"
]
},
{
"cell_type": "code",
"execution_count": 94,
"id": "4825573a",
"metadata": {},
"outputs": [],
"source": [
"# model sizes for the larger embedding\n",
"Vdim = 27          # number of rows in C and number of output classes\n",
"Cdim = 10          # hyperparameter: embedding size per character\n",
"block_size = 3     # hyperparameter: context length\n",
"batch_size = 32\n",
"n_output1 = 200    # hyperparameter: hidden-layer width\n",
"n_input1 = Cdim * block_size\n",
"n_input2 = n_output1\n",
"n_output2 = Vdim\n",
"\n",
"# draw all parameters, in a fixed order, from one seeded generator\n",
"g = torch.Generator().manual_seed(2147483647)\n",
"C = torch.randn((Vdim, Cdim), generator=g)\n",
"W1 = torch.randn((n_input1, n_output1), generator=g)\n",
"b1 = torch.randn(n_output1, generator=g)\n",
"W2 = torch.randn((n_input2, n_output2), generator=g)\n",
"b2 = torch.randn(n_output2, generator=g)\n",
"parameters = [C, W1, b1, W2, b2]\n",
"\n",
"# enable gradient tracking on every parameter\n",
"for p in parameters:\n",
"    p.requires_grad_(True)"
]
},
{
"cell_type": "code",
"execution_count": 95,
"id": "fe484cca",
"metadata": {},
"outputs": [],
"source": [
"# Fresh bookkeeping lists for the next training run. They live in their own cell, so\n",
"# re-running the training cell below keeps appending to the same history.\n",
"lri, lossi, stepi = [], [], []"
]
},
{
"cell_type": "code",
"execution_count": 96,
"id": "97328a1f",
"metadata": {},
"outputs": [],
"source": [
"# Train the model with the larger embedding on the training split.\n",
"for i in range(100000):\n",
"    # mini-batch: batch_size random rows of the training split\n",
"    ix = torch.randint(0, Xtr.shape[0], (batch_size,))\n",
"\n",
"    # forward pass on the sampled rows only\n",
"    emb = C[Xtr[ix]]  # one embedding vector per context character\n",
"    h = torch.tanh(emb.view(-1, n_input1) @ W1 + b1)\n",
"    logits = h @ W2 + b2\n",
"    loss = F.cross_entropy(logits, Ytr[ix])\n",
"\n",
"    # backward pass: clear stale gradients, then backpropagate\n",
"    for p in parameters:\n",
"        p.grad = None\n",
"    loss.backward()\n",
"\n",
"    # update the parameters with a step learning-rate decay\n",
"    lr = 0.1 if i < 10000 else 0.01\n",
"    for p in parameters:\n",
"        p.data += -lr * p.grad\n",
"\n",
"    stepi.append(i)\n",
"    lossi.append(loss.log10().item())  # plot log-loss to squash it in"
]
},
{
"cell_type": "code",
"execution_count": 97,
"id": "91727b0a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x12b1601d0>]"
]
},
"execution_count": 97,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# log10 of the mini-batch loss per step. It fluctuates because every step sees a\n",
"# different random mini-batch; a larger batch size would give a more stable curve.\n",
"plt.plot(stepi, lossi)\n",
"plt.xlabel('step')\n",
"plt.ylabel('log10(mini-batch loss)');"
]
},
{
"cell_type": "code",
"execution_count": 98,
"id": "6218b79e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train loss 2.2576844692230225\n",
"dev loss 2.2767961025238037\n"
]
}
],
"source": [
"# Evaluate the trained model on the full training and dev splits.\n",
"# No gradients are needed here, so skip building the autograd graph.\n",
"with torch.no_grad():\n",
"    # training split\n",
"    emb = C[Xtr]\n",
"    h = torch.tanh(emb.view(-1, n_input1) @ W1 + b1)\n",
"    logits = h @ W2 + b2\n",
"    loss = F.cross_entropy(logits, Ytr)\n",
"    print('train loss', loss.item())\n",
"\n",
"    # dev split (note: a smaller learning rate could reduce the loss further)\n",
"    emb = C[Xdev]\n",
"    h = torch.tanh(emb.view(-1, n_input1) @ W1 + b1)\n",
"    logits = h @ W2 + b2\n",
"    loss = F.cross_entropy(logits, Ydev)\n",
"    print('dev loss', loss.item())"
]
},
{
"cell_type": "markdown",
"id": "b496f3e3",
"metadata": {},
"source": [
"# 10. generate new words with the trained model"
]
},
{
"cell_type": "code",
"execution_count": 99,
"id": "4e888e6a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"careah.\n",
"ambyn.\n",
"vih.\n",
"jors.\n",
"revty.\n",
"salaysy.\n",
"jazhnen.\n",
"amerync.\n",
"kaqii.\n",
"nelania.\n",
"chaiiv.\n",
"kaleig.\n",
"halm.\n",
"joce.\n",
"quinn.\n",
"shon.\n",
"raiviani.\n",
"wanthon.\n",
"jaryni.\n",
"jace.\n"
]
}
],
"source": [
"# Sample 20 new names from the trained model using a dedicated, seeded generator.\n",
"g = torch.Generator().manual_seed(2147483647+10)\n",
"for _ in range(20):\n",
"    sampled = []\n",
"    window = [0] * block_size  # every name starts from an all-zero context\n",
"    next_ix = None\n",
"    while next_ix != 0:\n",
"        # embed the current context: a batch of one, block_size embedding vectors\n",
"        context_emb = C[torch.tensor([window])]\n",
"        hidden = torch.tanh(context_emb.view(1, -1) @ W1 + b1)\n",
"        scores = hidden @ W2 + b2\n",
"        probs = F.softmax(scores, dim=1)\n",
"        # draw the next character index according to the predicted distribution\n",
"        next_ix = torch.multinomial(probs, num_samples=1, generator=g).item()\n",
"        sampled.append(next_ix)\n",
"        # slide the window one position to the right\n",
"        window = window[1:] + [next_ix]\n",
"    # index 0 ends a name and is included in the printed string\n",
"    print(''.join(itos[i] for i in sampled))"
]
},
{
"cell_type": "markdown",
"id": "3897ee18",
"metadata": {},
"source": [
"# 11. hyperparameter tuning"
]
},
{
"cell_type": "code",
"execution_count": 122,
"id": "805a4253",
"metadata": {},
"outputs": [],
"source": [
"# hyperparameter grid: every embedding size is combined with every hidden-layer width\n",
"Cdims, n_output1s = [2, 30, 50], [100, 200, 300]"
]
},
{
"cell_type": "code",
"execution_count": 123,
"id": "fa12acef",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 6min 14s, sys: 1min 35s, total: 7min 49s\n",
"Wall time: 7min 9s\n"
]
}
],
"source": [
"%%time\n",
"# Grid search: train one model per (Cdim, n_output1) pair and record its loss on the\n",
"# full training and dev splits.\n",
"loss_list = []\n",
"\n",
"for Cdim in Cdims:\n",
"    for n_output1 in n_output1s:\n",
"        Vdim = 27\n",
"        block_size = 3  # hyperparameter\n",
"        n_input1 = Cdim * block_size\n",
"        batch_size = 32\n",
"        n_input2 = n_output1\n",
"        n_output2 = Vdim\n",
"\n",
"        # every configuration starts from the same seed\n",
"        g = torch.Generator().manual_seed(2147483647)\n",
"        C = torch.randn((Vdim, Cdim), generator=g)\n",
"        W1 = torch.randn((n_input1, n_output1), generator=g)\n",
"        b1 = torch.randn(n_output1, generator=g)\n",
"        W2 = torch.randn((n_input2, n_output2), generator=g)\n",
"        b2 = torch.randn(n_output2, generator=g)\n",
"        parameters = [C, W1, b1, W2, b2]\n",
"\n",
"        for p in parameters:\n",
"            p.requires_grad = True\n",
"\n",
"        for i in range(100000):\n",
"            # mini-batch: batch_size random rows of the training split\n",
"            ix = torch.randint(0, Xtr.shape[0], (batch_size,))\n",
"\n",
"            # forward pass\n",
"            emb = C[Xtr[ix]]\n",
"            h = torch.tanh(emb.view(-1, n_input1) @ W1 + b1)\n",
"            logits = h @ W2 + b2\n",
"            loss = F.cross_entropy(logits, Ytr[ix])\n",
"\n",
"            # backward pass: clear stale gradients, then backpropagate\n",
"            for p in parameters:\n",
"                p.grad = None\n",
"            loss.backward()\n",
"\n",
"            # update the parameters with a step learning-rate decay\n",
"            lr = 0.1 if i < 10000 else 0.01\n",
"            for p in parameters:\n",
"                p.data += -lr * p.grad\n",
"\n",
"        # evaluation only: the full-split forward passes need no autograd graph\n",
"        with torch.no_grad():\n",
"            # loss on the full training split\n",
"            emb = C[Xtr]\n",
"            h = torch.tanh(emb.view(-1, n_input1) @ W1 + b1)\n",
"            logits = h @ W2 + b2\n",
"            loss_train = F.cross_entropy(logits, Ytr)\n",
"\n",
"            # loss on the dev split\n",
"            emb = C[Xdev]\n",
"            h = torch.tanh(emb.view(-1, n_input1) @ W1 + b1)\n",
"            logits = h @ W2 + b2\n",
"            loss_dev = F.cross_entropy(logits, Ydev)\n",
"\n",
"        loss_list.append([Cdim, n_output1, loss_train.item(), loss_dev.item()])"
]
},
{
"cell_type": "code",
"execution_count": 124,
"id": "9712fb33",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 125,
"id": "030c17b5",
"metadata": {},
"outputs": [],
"source": [
"# one row per trained configuration, in the order the grid search produced them\n",
"result_columns = ['Cdim', 'hidden_layer_size', 'train loss', 'dev loss']\n",
"df = pd.DataFrame(loss_list, columns=result_columns)"
]
},
{
"cell_type": "code",
"execution_count": 126,
"id": "2bb18918",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Cdim</th>\n",
" <th>hidden_layer_size</th>\n",
" <th>train loss</th>\n",
" <th>dev loss</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>50</td>\n",
" <td>100</td>\n",
" <td>2.215908</td>\n",
" <td>2.255365</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>30</td>\n",
" <td>100</td>\n",
" <td>2.226047</td>\n",
" <td>2.261118</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>50</td>\n",
" <td>200</td>\n",
" <td>2.183852</td>\n",
" <td>2.274030</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>30</td>\n",
" <td>200</td>\n",
" <td>2.195534</td>\n",
" <td>2.276445</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>50</td>\n",
" <td>300</td>\n",
" <td>2.179892</td>\n",
" <td>2.299338</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>30</td>\n",
" <td>300</td>\n",
" <td>2.198899</td>\n",
" <td>2.305203</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2</td>\n",
" <td>300</td>\n",
" <td>2.326979</td>\n",
" <td>2.324543</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>200</td>\n",
" <td>2.333926</td>\n",
" <td>2.331151</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2</td>\n",
" <td>100</td>\n",
" <td>2.368216</td>\n",
" <td>2.365333</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Cdim hidden_layer_size train loss dev loss\n",
"6 50 100 2.215908 2.255365\n",
"3 30 100 2.226047 2.261118\n",
"7 50 200 2.183852 2.274030\n",
"4 30 200 2.195534 2.276445\n",
"8 50 300 2.179892 2.299338\n",
"5 30 300 2.198899 2.305203\n",
"2 2 300 2.326979 2.324543\n",
"1 2 200 2.333926 2.331151\n",
"0 2 100 2.368216 2.365333"
]
},
"execution_count": 126,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# rank the configurations by dev loss, lowest first; the original row labels are kept\n",
"dfs = df.sort_values(by='dev loss')\n",
"dfs"
]
},
{
"cell_type": "code",
"execution_count": 154,
"id": "323b9fe5",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "\n",
"text/plain": [
"<Figure size 432x432 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(6,6))\n",
"plt.scatter(dfs['train loss'], dfs['dev loss'], s=800)\n",
"plt.xlabel('train loss')\n",
"plt.ylabel('dev loss')\n",
"# Label every point with its own row's settings. Iterating over the rows directly\n",
"# avoids relying on the index labels of df and dfs happening to line up.\n",
"for cdim, hidden, train_loss, dev_loss in zip(dfs['Cdim'], dfs['hidden_layer_size'],\n",
"                                              dfs['train loss'], dfs['dev loss']):\n",
"    plt.text(train_loss, dev_loss, str(cdim) + '\\n' + str(hidden),\n",
"             ha='center', va='center', color='white', fontsize=8)\n",
"# the first positional argument of plt.grid is the on/off flag, so pass it explicitly\n",
"plt.grid(True)\n",
"# note that the smallest train loss does not have the smallest dev loss\n",
"# due to overfitting at Cdim=50, hidden_layer_size=300"
},
{
"cell_type": "markdown",
"id": "d1b93302",
"metadata": {},
"source": [
"## use the best hyperparameters to calculate loss on the test dataset"
]
},
{
"cell_type": "code",
"execution_count": 155,
"id": "3579b344",
"metadata": {},
"outputs": [],
"source": [
"# dfs is sorted by dev loss, so the best configuration is the first row by position.\n",
"# .iloc[0] selects that row; plain [0] would look up the row *labelled* 0, i.e. the\n",
"# first configuration that was trained, regardless of where it ranks.\n",
"# int() turns the pandas/NumPy scalar into a plain Python int for the tensor shapes.\n",
"Cdim = int(dfs['Cdim'].iloc[0])\n",
"n_output1 = int(dfs['hidden_layer_size'].iloc[0])"
]
},
{
"cell_type": "code",
"execution_count": 156,
"id": "d1b70f03",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 2min 40s, sys: 7.38 s, total: 2min 47s\n",
"Wall time: 2min 45s\n"
]
}
],
"source": [
"%%time\n",
"# Retrain from scratch with the selected Cdim / n_output1, for more steps and with a\n",
"# finer learning-rate schedule.\n",
"Vdim = 27\n",
"block_size = 3  # hyperparameter\n",
"n_input1 = Cdim * block_size\n",
"batch_size = 32\n",
"n_input2 = n_output1\n",
"n_output2 = Vdim\n",
"\n",
"g = torch.Generator().manual_seed(2147483647)\n",
"C = torch.randn((Vdim, Cdim), generator=g)\n",
"W1 = torch.randn((n_input1, n_output1), generator=g)\n",
"b1 = torch.randn(n_output1, generator=g)\n",
"W2 = torch.randn((n_input2, n_output2), generator=g)\n",
"b2 = torch.randn(n_output2, generator=g)\n",
"parameters = [C, W1, b1, W2, b2]\n",
"\n",
"for p in parameters:\n",
"    p.requires_grad = True\n",
"\n",
"for i in range(500000):\n",
"    # mini-batch: batch_size random rows of the training split\n",
"    ix = torch.randint(0, Xtr.shape[0], (batch_size,))\n",
"\n",
"    # forward pass on the sampled rows only\n",
"    emb = C[Xtr[ix]]  # one embedding vector per context character\n",
"    h = torch.tanh(emb.view(-1, n_input1) @ W1 + b1)\n",
"    logits = h @ W2 + b2\n",
"    loss = F.cross_entropy(logits, Ytr[ix])\n",
"\n",
"    # backward pass: clear stale gradients, then backpropagate\n",
"    for p in parameters:\n",
"        p.grad = None\n",
"    loss.backward()\n",
"\n",
"    # piecewise-constant learning-rate decay\n",
"    if i < 10000:\n",
"        lr = 0.1\n",
"    elif i < 100000:\n",
"        lr = 0.05\n",
"    elif i < 200000:\n",
"        lr = 0.01\n",
"    else:\n",
"        lr = 0.005\n",
"\n",
"    # update the parameters\n",
"    for p in parameters:\n",
"        p.data += -lr * p.grad"
]
},
{
"cell_type": "code",
"execution_count": 157,
"id": "e81edd4a",
"metadata": {},
"outputs": [],
"source": [
"# Evaluate the retrained model on all three splits.\n",
"# No gradients are needed here, so skip building the autograd graph.\n",
"with torch.no_grad():\n",
"    # training split\n",
"    emb = C[Xtr]\n",
"    h = torch.tanh(emb.view(-1, n_input1) @ W1 + b1)\n",
"    logits = h @ W2 + b2\n",
"    loss_train = F.cross_entropy(logits, Ytr).item()\n",
"\n",
"    # dev split (note: a smaller learning rate could reduce the loss further)\n",
"    emb = C[Xdev]\n",
"    h = torch.tanh(emb.view(-1, n_input1) @ W1 + b1)\n",
"    logits = h @ W2 + b2\n",
"    loss_dev = F.cross_entropy(logits, Ydev).item()\n",
"\n",
"    # test split: the number to report\n",
"    emb = C[Xte]\n",
"    h = torch.tanh(emb.view(-1, n_input1) @ W1 + b1)\n",
"    logits = h @ W2 + b2\n",
"    loss_test = F.cross_entropy(logits, Yte).item()"
]
},
{
"cell_type": "code",
"execution_count": 158,
"id": "d9c6ed37",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train loss: 2.2831\n",
"dev loss: 2.2875\n",
"test loss: 2.2865\n"
]
}
],
"source": [
"# final report: loss of the retrained model on each split, four decimals\n",
"for template, value in (('train loss: %.4f', loss_train),\n",
"                        ('dev loss: %.4f', loss_dev),\n",
"                        ('test loss: %.4f', loss_test)):\n",
"    print(template % value)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "61337334",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment