Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save yangju2011/90f34cf6646119a05a7e85abc67c3b44 to your computer and use it in GitHub Desktop.
Save yangju2011/90f34cf6646119a05a7e85abc67c3b44 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"id": "dccd1777",
"metadata": {},
"source": [
"Original paper: Bengio et al. 2003 A Neural Probabilistic Language Model\n",
"- challenges\n",
" - curse of dimensionality for joint probability \n",
" - context: word sequence\n",
" - similarity between words\n",
"- solution\n",
" - a learned distributed feature vector (30-dimension) to represent each word in the vocabulary of 17000 words\n",
" - each training sequence informs the model about a combinatorial number of other sentences\n",
" - input layer is 30-dim vector per word, 3 words -> 90 neurons -> hidden layer -> softmax -> 17000 output neurons -> softmax P(w_i|context). most computation happens at the softmax output step.\n",
"\n",
"In this notebook, we will **predict the next character based on 3 previous characters**."
]
},
{
"cell_type": "markdown",
"id": "14b30ea6",
"metadata": {},
"source": [
"# 1. string <> integer map for vocabulary"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "6ed51a65",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"import matplotlib.pyplot as plt\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a72e7a3c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['emma', 'olivia', 'ava', 'isabella', 'sophia']"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"words = open('names.txt','r').read().splitlines()\n",
"words[:5]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6f531a94",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"32033"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# corpus\n",
"len(words)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "e83a6060",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}\n"
]
}
],
"source": [
"# build the vocabulary of chars and mappings to/from integers\n",
"# same as in bag-of-words\n",
"chars = sorted(list(set(''.join(words))))\n",
"# string to integer\n",
"stoi = {s:i+1 for i, s in enumerate(chars)}\n",
"# special char to represent start and end\n",
"stoi['.'] = 0\n",
"# integer to string\n",
"itos = {i:s for s, i in stoi.items()}\n",
"print (itos)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "ca1a37a3",
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"--------------\n",
"emma\n",
"... ---> e\n",
"context before: [0, 0, 0]\n",
"output char: e\n",
"output index 5\n",
"context after: [0, 0, 5]\n",
"..e ---> m\n",
"context before: [0, 0, 5]\n",
"output char: m\n",
"output index 13\n",
"context after: [0, 5, 13]\n",
".em ---> m\n",
"context before: [0, 5, 13]\n",
"output char: m\n",
"output index 13\n",
"context after: [5, 13, 13]\n",
"emm ---> a\n",
"context before: [5, 13, 13]\n",
"output char: a\n",
"output index 1\n",
"context after: [13, 13, 1]\n",
"mma ---> .\n",
"context before: [13, 13, 1]\n",
"output char: .\n",
"output index 0\n",
"context after: [13, 1, 0]\n",
"--------------\n",
"olivia\n",
"... ---> o\n",
"context before: [0, 0, 0]\n",
"output char: o\n",
"output index 15\n",
"context after: [0, 0, 15]\n",
"..o ---> l\n",
"context before: [0, 0, 15]\n",
"output char: l\n",
"output index 12\n",
"context after: [0, 15, 12]\n",
".ol ---> i\n",
"context before: [0, 15, 12]\n",
"output char: i\n",
"output index 9\n",
"context after: [15, 12, 9]\n",
"oli ---> v\n",
"context before: [15, 12, 9]\n",
"output char: v\n",
"output index 22\n",
"context after: [12, 9, 22]\n",
"liv ---> i\n",
"context before: [12, 9, 22]\n",
"output char: i\n",
"output index 9\n",
"context after: [9, 22, 9]\n",
"ivi ---> a\n",
"context before: [9, 22, 9]\n",
"output char: a\n",
"output index 1\n",
"context after: [22, 9, 1]\n",
"via ---> .\n",
"context before: [22, 9, 1]\n",
"output char: .\n",
"output index 0\n",
"context after: [9, 1, 0]\n",
"--------------\n",
"ava\n",
"... ---> a\n",
"context before: [0, 0, 0]\n",
"output char: a\n",
"output index 1\n",
"context after: [0, 0, 1]\n",
"..a ---> v\n",
"context before: [0, 0, 1]\n",
"output char: v\n",
"output index 22\n",
"context after: [0, 1, 22]\n",
".av ---> a\n",
"context before: [0, 1, 22]\n",
"output char: a\n",
"output index 1\n",
"context after: [1, 22, 1]\n",
"ava ---> .\n",
"context before: [1, 22, 1]\n",
"output char: .\n",
"output index 0\n",
"context after: [22, 1, 0]\n",
"--------------\n",
"isabella\n",
"... ---> i\n",
"context before: [0, 0, 0]\n",
"output char: i\n",
"output index 9\n",
"context after: [0, 0, 9]\n",
"..i ---> s\n",
"context before: [0, 0, 9]\n",
"output char: s\n",
"output index 19\n",
"context after: [0, 9, 19]\n",
".is ---> a\n",
"context before: [0, 9, 19]\n",
"output char: a\n",
"output index 1\n",
"context after: [9, 19, 1]\n",
"isa ---> b\n",
"context before: [9, 19, 1]\n",
"output char: b\n",
"output index 2\n",
"context after: [19, 1, 2]\n",
"sab ---> e\n",
"context before: [19, 1, 2]\n",
"output char: e\n",
"output index 5\n",
"context after: [1, 2, 5]\n",
"abe ---> l\n",
"context before: [1, 2, 5]\n",
"output char: l\n",
"output index 12\n",
"context after: [2, 5, 12]\n",
"bel ---> l\n",
"context before: [2, 5, 12]\n",
"output char: l\n",
"output index 12\n",
"context after: [5, 12, 12]\n",
"ell ---> a\n",
"context before: [5, 12, 12]\n",
"output char: a\n",
"output index 1\n",
"context after: [12, 12, 1]\n",
"lla ---> .\n",
"context before: [12, 12, 1]\n",
"output char: .\n",
"output index 0\n",
"context after: [12, 1, 0]\n",
"--------------\n",
"sophia\n",
"... ---> s\n",
"context before: [0, 0, 0]\n",
"output char: s\n",
"output index 19\n",
"context after: [0, 0, 19]\n",
"..s ---> o\n",
"context before: [0, 0, 19]\n",
"output char: o\n",
"output index 15\n",
"context after: [0, 19, 15]\n",
".so ---> p\n",
"context before: [0, 19, 15]\n",
"output char: p\n",
"output index 16\n",
"context after: [19, 15, 16]\n",
"sop ---> h\n",
"context before: [19, 15, 16]\n",
"output char: h\n",
"output index 8\n",
"context after: [15, 16, 8]\n",
"oph ---> i\n",
"context before: [15, 16, 8]\n",
"output char: i\n",
"output index 9\n",
"context after: [16, 8, 9]\n",
"phi ---> a\n",
"context before: [16, 8, 9]\n",
"output char: a\n",
"output index 1\n",
"context after: [8, 9, 1]\n",
"hia ---> .\n",
"context before: [8, 9, 1]\n",
"output char: .\n",
"output index 0\n",
"context after: [9, 1, 0]\n"
]
}
],
"source": [
"# build the dataset and see how it works\n",
"\n",
"block_size = 3\n",
"X, Y = [], []\n",
"for w in words[:5]:\n",
" print ('--------------')\n",
" print (w)\n",
" # integer 0 -> .\n",
" # ... ---> e\n",
" context = [0] * block_size\n",
" for ch in w + '.':\n",
" ix = stoi[ch]\n",
" X.append(context)\n",
" Y.append(ix)\n",
" print (''.join(itos[i] for i in context), '--->', itos[ix])\n",
" print ('context before:', context)\n",
" print ('output char:',ch)\n",
" print ('output index',ix)\n",
"\n",
" # sliding window to the next context\n",
" context = context[1:] + [ix]\n",
" print ('context after:', context)\n",
" \n",
" \n",
"X=torch.tensor(X)\n",
"Y=torch.tensor(Y)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "8d3190c7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([32, 3]), torch.int64, torch.Size([32]), torch.int64)"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X.shape, X.dtype, Y.shape, Y.dtype"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "cdf7a0f4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0, 0, 0],\n",
" [ 0, 0, 5],\n",
" [ 0, 5, 13],\n",
" [ 5, 13, 13],\n",
" [13, 13, 1]])"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X[:5]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "85ca81b4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 5, 13, 13, 1, 0])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Y[:5]"
]
},
{
"cell_type": "markdown",
"id": "c55a1fab",
"metadata": {},
"source": [
"# 2. feature vector"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "38c6aecf",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"27"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Vdim = len(stoi.keys())\n",
"Vdim"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "61fc8fec",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.4464, -0.6289],\n",
" [ 1.3162, -0.2359],\n",
" [ 1.3732, 0.4287],\n",
" [-0.0174, 0.2338],\n",
" [ 0.3223, 0.4244],\n",
" [-0.3297, -0.0899],\n",
" [-0.3791, 0.2541],\n",
" [-0.8308, 0.4593],\n",
" [ 1.5862, 0.9561],\n",
" [-0.5079, 0.7066],\n",
" [ 0.5257, 1.8360],\n",
" [-0.7879, 1.7276],\n",
" [ 0.5974, 1.3534],\n",
" [-1.2337, -0.4362],\n",
" [-1.8687, -1.5495],\n",
" [ 1.0291, -0.0078],\n",
" [ 1.7785, 0.6016],\n",
" [ 1.2043, -0.8093],\n",
" [-0.0207, 0.7163],\n",
" [ 0.0288, 1.1582],\n",
" [-1.0672, -1.6540],\n",
" [-0.9492, 0.2938],\n",
" [-0.1061, -1.3482],\n",
" [ 0.3623, -0.1373],\n",
" [-0.0161, -0.2087],\n",
" [-0.8985, -0.2570],\n",
" [-0.4422, 0.4705]])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# take 27 unique chars, put into a 2-D feature vector\n",
"# layer with 27 input and 2 output\n",
"\n",
"Cdim = 2\n",
"# initialize with random numbers from a normal distribution with mean `0` and variance `1`\n",
"C = torch.randn((Vdim, Cdim))\n",
"C"
]
},
{
"cell_type": "markdown",
"id": "9b184414",
"metadata": {},
"source": [
"## embedding of the integer (feature vector for a char)\n",
"1. index for C\n",
"2. one-hot encodding @ C"
]
},
{
"cell_type": "markdown",
"id": "e37da9d0",
"metadata": {},
"source": [
"### 1. index for C"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "5058e4c9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([-0.4464, -0.6289])"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# prefered. it is FASTER than matrix multiplication\n",
"C[0]"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "09e5c2ad",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([-0.4464, -0.6289])"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"C[stoi['.']]"
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "80d34c62",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([-0.3297, -0.0899])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"C[5]"
]
},
{
"cell_type": "markdown",
"id": "6b8ef1c0",
"metadata": {},
"source": [
"### 2. one-hot encodding @ C"
]
},
{
"cell_type": "code",
"execution_count": 14,
"id": "3d832521",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0])"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"F.one_hot(torch.tensor(5), num_classes=27)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "a9c49619",
"metadata": {},
"outputs": [
{
"ename": "RuntimeError",
"evalue": "expected scalar type Long but found Float",
"output_type": "error",
"traceback": [
"\u001b[0;31m-------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0mTraceback (most recent call last)",
"\u001b[0;32m<ipython-input-15-246e9ef054ac>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# one_hot is int, C is float, does not know how to multiply\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mF\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mone_hot\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtensor\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m5\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_classes\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m27\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mC\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m: expected scalar type Long but found Float"
]
}
],
"source": [
"# one_hot is int, C is float, does not know how to multiply\n",
"F.one_hot(torch.tensor(5), num_classes=27) @ C"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "4ac24f98",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.int64"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"F.one_hot(torch.tensor(5), num_classes=27).dtype"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "4e1dfed9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([-0.3297, -0.0899])"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# matrix multiplication to get the i-th row from C\n",
"F.one_hot(torch.tensor(5), num_classes=27).float() @ C"
]
},
{
"cell_type": "markdown",
"id": "a87d59b9",
"metadata": {},
"source": [
"## embedding of X"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "f5173d6a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([-0.4464, -0.6289])"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"C[0]"
]
},
{
"cell_type": "code",
"execution_count": 19,
"id": "eeefed64",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.4464, -0.6289],\n",
" [ 1.3162, -0.2359],\n",
" [ 1.3732, 0.4287],\n",
" [ 1.3732, 0.4287],\n",
" [ 1.3732, 0.4287]])"
]
},
"execution_count": 19,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"C[[0,1,2,2,2]]"
]
},
{
"cell_type": "code",
"execution_count": 20,
"id": "6791fef3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.4464, -0.6289],\n",
" [ 1.3162, -0.2359],\n",
" [ 1.3732, 0.4287],\n",
" [ 1.3732, 0.4287],\n",
" [ 1.3732, 0.4287]])"
]
},
"execution_count": 20,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# retrieve value from a matrix\n",
"C[torch.tensor([0,1,2,2,2])]"
]
},
{
"cell_type": "code",
"execution_count": 21,
"id": "7b327d78",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([32, 3])"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# from 2 words, 32 training samples\n",
"X.shape"
]
},
{
"cell_type": "code",
"execution_count": 22,
"id": "0f104d34",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0, 0, 0],\n",
" [ 0, 0, 5],\n",
" [ 0, 5, 13],\n",
" [ 5, 13, 13],\n",
" [13, 13, 1],\n",
" [ 0, 0, 0],\n",
" [ 0, 0, 15],\n",
" [ 0, 15, 12],\n",
" [15, 12, 9],\n",
" [12, 9, 22],\n",
" [ 9, 22, 9],\n",
" [22, 9, 1],\n",
" [ 0, 0, 0],\n",
" [ 0, 0, 1],\n",
" [ 0, 1, 22],\n",
" [ 1, 22, 1],\n",
" [ 0, 0, 0],\n",
" [ 0, 0, 9],\n",
" [ 0, 9, 19],\n",
" [ 9, 19, 1],\n",
" [19, 1, 2],\n",
" [ 1, 2, 5],\n",
" [ 2, 5, 12],\n",
" [ 5, 12, 12],\n",
" [12, 12, 1],\n",
" [ 0, 0, 0],\n",
" [ 0, 0, 19],\n",
" [ 0, 19, 15],\n",
" [19, 15, 16],\n",
" [15, 16, 8],\n",
" [16, 8, 9],\n",
" [ 8, 9, 1]])"
]
},
"execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X\n",
"# X is training data, each X[i] is 3 word in a row\n",
"# X[a,b] is the b-th char in the a-th record"
]
},
{
"cell_type": "code",
"execution_count": 23,
"id": "0f8bd881",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(9)"
]
},
"execution_count": 23,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X[10,2]"
]
},
{
"cell_type": "code",
"execution_count": 24,
"id": "9f4c8ffe",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'i'"
]
},
"execution_count": 24,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"itos[X[10,2].item()]"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "64227f92",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([-0.5079, 0.7066])"
]
},
"execution_count": 25,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# index 1 of C is a 2-dim value\n",
"# each char is encoded in a 2-d array in C\n",
"# get the embedding fromm C for a char \n",
"C[X[10,2]]"
]
},
{
"cell_type": "code",
"execution_count": 26,
"id": "0da98bb8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([32, 3, 2])"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# [32, 3]\n",
"# [27, 2]\n",
"# [32, 3, 2]\n",
"C[X].shape"
]
},
{
"cell_type": "code",
"execution_count": 27,
"id": "942cde17",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[[-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289]],\n",
"\n",
" [[-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.3297, -0.0899]],\n",
"\n",
" [[-0.4464, -0.6289],\n",
" [-0.3297, -0.0899],\n",
" [-1.2337, -0.4362]],\n",
"\n",
" [[-0.3297, -0.0899],\n",
" [-1.2337, -0.4362],\n",
" [-1.2337, -0.4362]],\n",
"\n",
" [[-1.2337, -0.4362],\n",
" [-1.2337, -0.4362],\n",
" [ 1.3162, -0.2359]],\n",
"\n",
" [[-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289]],\n",
"\n",
" [[-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [ 1.0291, -0.0078]],\n",
"\n",
" [[-0.4464, -0.6289],\n",
" [ 1.0291, -0.0078],\n",
" [ 0.5974, 1.3534]],\n",
"\n",
" [[ 1.0291, -0.0078],\n",
" [ 0.5974, 1.3534],\n",
" [-0.5079, 0.7066]],\n",
"\n",
" [[ 0.5974, 1.3534],\n",
" [-0.5079, 0.7066],\n",
" [-0.1061, -1.3482]],\n",
"\n",
" [[-0.5079, 0.7066],\n",
" [-0.1061, -1.3482],\n",
" [-0.5079, 0.7066]],\n",
"\n",
" [[-0.1061, -1.3482],\n",
" [-0.5079, 0.7066],\n",
" [ 1.3162, -0.2359]],\n",
"\n",
" [[-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289]],\n",
"\n",
" [[-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [ 1.3162, -0.2359]],\n",
"\n",
" [[-0.4464, -0.6289],\n",
" [ 1.3162, -0.2359],\n",
" [-0.1061, -1.3482]],\n",
"\n",
" [[ 1.3162, -0.2359],\n",
" [-0.1061, -1.3482],\n",
" [ 1.3162, -0.2359]],\n",
"\n",
" [[-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289]],\n",
"\n",
" [[-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.5079, 0.7066]],\n",
"\n",
" [[-0.4464, -0.6289],\n",
" [-0.5079, 0.7066],\n",
" [ 0.0288, 1.1582]],\n",
"\n",
" [[-0.5079, 0.7066],\n",
" [ 0.0288, 1.1582],\n",
" [ 1.3162, -0.2359]],\n",
"\n",
" [[ 0.0288, 1.1582],\n",
" [ 1.3162, -0.2359],\n",
" [ 1.3732, 0.4287]],\n",
"\n",
" [[ 1.3162, -0.2359],\n",
" [ 1.3732, 0.4287],\n",
" [-0.3297, -0.0899]],\n",
"\n",
" [[ 1.3732, 0.4287],\n",
" [-0.3297, -0.0899],\n",
" [ 0.5974, 1.3534]],\n",
"\n",
" [[-0.3297, -0.0899],\n",
" [ 0.5974, 1.3534],\n",
" [ 0.5974, 1.3534]],\n",
"\n",
" [[ 0.5974, 1.3534],\n",
" [ 0.5974, 1.3534],\n",
" [ 1.3162, -0.2359]],\n",
"\n",
" [[-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289]],\n",
"\n",
" [[-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [ 0.0288, 1.1582]],\n",
"\n",
" [[-0.4464, -0.6289],\n",
" [ 0.0288, 1.1582],\n",
" [ 1.0291, -0.0078]],\n",
"\n",
" [[ 0.0288, 1.1582],\n",
" [ 1.0291, -0.0078],\n",
" [ 1.7785, 0.6016]],\n",
"\n",
" [[ 1.0291, -0.0078],\n",
" [ 1.7785, 0.6016],\n",
" [ 1.5862, 0.9561]],\n",
"\n",
" [[ 1.7785, 0.6016],\n",
" [ 1.5862, 0.9561],\n",
" [-0.5079, 0.7066]],\n",
"\n",
" [[ 1.5862, 0.9561],\n",
" [-0.5079, 0.7066],\n",
" [ 1.3162, -0.2359]]])"
]
},
"execution_count": 27,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# X is all the chars in all trainign record\n",
"# each char -> 2D \n",
"C[X]"
]
},
{
"cell_type": "markdown",
"id": "fc698b6b",
"metadata": {},
"source": [
"# 3. build neural network with the embedded input"
]
},
{
"cell_type": "code",
"execution_count": 28,
"id": "11b97859",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([32, 3, 2])"
]
},
"execution_count": 28,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# embedding\n",
"# C[X] contains the embedding for all input for [3,2]\n",
"emb = C[X]\n",
"emb.shape"
]
},
{
"cell_type": "code",
"execution_count": 29,
"id": "f4bd8c19",
"metadata": {},
"outputs": [],
"source": [
"# use the word2vec architeture from the paper\n",
"# hidden layer\n",
"\n",
"# number of inputs is 3X2: 2 dimension embedding with 3 words\n",
"# number of neurons (output): 100\n",
"block_size = 3\n",
"Cdim = 2\n",
"n_input1 = block_size * Cdim\n",
"n_output1 = 100\n",
"\n",
"W1 = torch.randn([n_input1,n_output1])\n",
"# bias\n",
"b1 = torch.rand(n_output1)\n",
"\n",
"# we usually do emb @ W1 + b1\n",
"# but embedding is staked, need to flatten it"
]
},
{
"cell_type": "code",
"execution_count": 30,
"id": "36172342",
"metadata": {},
"outputs": [
{
"ename": "RuntimeError",
"evalue": "mat1 and mat2 shapes cannot be multiplied (96x2 and 6x100)",
"output_type": "error",
"traceback": [
"\u001b[0;31m-------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0mTraceback (most recent call last)",
"\u001b[0;32m<ipython-input-30-355728df8662>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0memb\u001b[0m \u001b[0;34m@\u001b[0m \u001b[0mW1\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0mb1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0m: mat1 and mat2 shapes cannot be multiplied (96x2 and 6x100)"
]
}
],
"source": [
"emb @ W1 + b1"
]
},
{
"cell_type": "markdown",
"id": "914ace3d",
"metadata": {},
"source": [
"## concat input of 3 words into a single vector\n",
"1. cat: not efficent, use memory to create new tensor\n",
"2. view: efficient!"
]
},
{
"cell_type": "markdown",
"id": "522fad71",
"metadata": {},
"source": [
"### 1. cat"
]
},
{
"cell_type": "code",
"execution_count": 31,
"id": "a0a80e36",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([32, 2])"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 1st word \n",
"emb[:, 0, :].shape"
]
},
{
"cell_type": "code",
"execution_count": 32,
"id": "d03f39e8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([96, 2])"
]
},
"execution_count": 32,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# concat the input to the correct dimension\n",
"torch.cat([emb[:, 0, :], emb[:, 1, :], emb[:, 2, :]], dim=0).shape"
]
},
{
"cell_type": "code",
"execution_count": 33,
"id": "682d63fe",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([32, 6])"
]
},
"execution_count": 33,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.cat([emb[:, 0, :], emb[:, 1, :], emb[:, 2, :]], dim=1).shape"
]
},
{
"cell_type": "code",
"execution_count": 34,
"id": "eeee9a5f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.4464, -0.6289, -0.4464, -0.6289, -0.4464, -0.6289],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, -0.3297, -0.0899],\n",
" [-0.4464, -0.6289, -0.3297, -0.0899, -1.2337, -0.4362],\n",
" [-0.3297, -0.0899, -1.2337, -0.4362, -1.2337, -0.4362],\n",
" [-1.2337, -0.4362, -1.2337, -0.4362, 1.3162, -0.2359],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, -0.4464, -0.6289],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, 1.0291, -0.0078],\n",
" [-0.4464, -0.6289, 1.0291, -0.0078, 0.5974, 1.3534],\n",
" [ 1.0291, -0.0078, 0.5974, 1.3534, -0.5079, 0.7066],\n",
" [ 0.5974, 1.3534, -0.5079, 0.7066, -0.1061, -1.3482],\n",
" [-0.5079, 0.7066, -0.1061, -1.3482, -0.5079, 0.7066],\n",
" [-0.1061, -1.3482, -0.5079, 0.7066, 1.3162, -0.2359],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, -0.4464, -0.6289],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, 1.3162, -0.2359],\n",
" [-0.4464, -0.6289, 1.3162, -0.2359, -0.1061, -1.3482],\n",
" [ 1.3162, -0.2359, -0.1061, -1.3482, 1.3162, -0.2359],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, -0.4464, -0.6289],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, -0.5079, 0.7066],\n",
" [-0.4464, -0.6289, -0.5079, 0.7066, 0.0288, 1.1582],\n",
" [-0.5079, 0.7066, 0.0288, 1.1582, 1.3162, -0.2359],\n",
" [ 0.0288, 1.1582, 1.3162, -0.2359, 1.3732, 0.4287],\n",
" [ 1.3162, -0.2359, 1.3732, 0.4287, -0.3297, -0.0899],\n",
" [ 1.3732, 0.4287, -0.3297, -0.0899, 0.5974, 1.3534],\n",
" [-0.3297, -0.0899, 0.5974, 1.3534, 0.5974, 1.3534],\n",
" [ 0.5974, 1.3534, 0.5974, 1.3534, 1.3162, -0.2359],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, -0.4464, -0.6289],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, 0.0288, 1.1582],\n",
" [-0.4464, -0.6289, 0.0288, 1.1582, 1.0291, -0.0078],\n",
" [ 0.0288, 1.1582, 1.0291, -0.0078, 1.7785, 0.6016],\n",
" [ 1.0291, -0.0078, 1.7785, 0.6016, 1.5862, 0.9561],\n",
" [ 1.7785, 0.6016, 1.5862, 0.9561, -0.5079, 0.7066],\n",
" [ 1.5862, 0.9561, -0.5079, 0.7066, 1.3162, -0.2359]])"
]
},
"execution_count": 34,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# there are 32 training samples, we want for each training sample, all 3 words are combined\n",
"# -> [32, 2*3]\n",
"torch.cat([emb[:, 0, :], emb[:, 1, :], emb[:, 2, :]], dim=1)"
]
},
{
"cell_type": "code",
"execution_count": 35,
"id": "2e6d72df",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([32, 3, 2])"
]
},
"execution_count": 35,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"emb.shape"
]
},
{
"cell_type": "code",
"execution_count": 36,
"id": "46dc9af3",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(tensor([[-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.3297, -0.0899],\n",
" [-1.2337, -0.4362],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [ 1.0291, -0.0078],\n",
" [ 0.5974, 1.3534],\n",
" [-0.5079, 0.7066],\n",
" [-0.1061, -1.3482],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [ 1.3162, -0.2359],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.5079, 0.7066],\n",
" [ 0.0288, 1.1582],\n",
" [ 1.3162, -0.2359],\n",
" [ 1.3732, 0.4287],\n",
" [-0.3297, -0.0899],\n",
" [ 0.5974, 1.3534],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [ 0.0288, 1.1582],\n",
" [ 1.0291, -0.0078],\n",
" [ 1.7785, 0.6016],\n",
" [ 1.5862, 0.9561]]),\n",
" tensor([[-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.3297, -0.0899],\n",
" [-1.2337, -0.4362],\n",
" [-1.2337, -0.4362],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [ 1.0291, -0.0078],\n",
" [ 0.5974, 1.3534],\n",
" [-0.5079, 0.7066],\n",
" [-0.1061, -1.3482],\n",
" [-0.5079, 0.7066],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [ 1.3162, -0.2359],\n",
" [-0.1061, -1.3482],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [-0.5079, 0.7066],\n",
" [ 0.0288, 1.1582],\n",
" [ 1.3162, -0.2359],\n",
" [ 1.3732, 0.4287],\n",
" [-0.3297, -0.0899],\n",
" [ 0.5974, 1.3534],\n",
" [ 0.5974, 1.3534],\n",
" [-0.4464, -0.6289],\n",
" [-0.4464, -0.6289],\n",
" [ 0.0288, 1.1582],\n",
" [ 1.0291, -0.0078],\n",
" [ 1.7785, 0.6016],\n",
" [ 1.5862, 0.9561],\n",
" [-0.5079, 0.7066]]),\n",
" tensor([[-0.4464, -0.6289],\n",
" [-0.3297, -0.0899],\n",
" [-1.2337, -0.4362],\n",
" [-1.2337, -0.4362],\n",
" [ 1.3162, -0.2359],\n",
" [-0.4464, -0.6289],\n",
" [ 1.0291, -0.0078],\n",
" [ 0.5974, 1.3534],\n",
" [-0.5079, 0.7066],\n",
" [-0.1061, -1.3482],\n",
" [-0.5079, 0.7066],\n",
" [ 1.3162, -0.2359],\n",
" [-0.4464, -0.6289],\n",
" [ 1.3162, -0.2359],\n",
" [-0.1061, -1.3482],\n",
" [ 1.3162, -0.2359],\n",
" [-0.4464, -0.6289],\n",
" [-0.5079, 0.7066],\n",
" [ 0.0288, 1.1582],\n",
" [ 1.3162, -0.2359],\n",
" [ 1.3732, 0.4287],\n",
" [-0.3297, -0.0899],\n",
" [ 0.5974, 1.3534],\n",
" [ 0.5974, 1.3534],\n",
" [ 1.3162, -0.2359],\n",
" [-0.4464, -0.6289],\n",
" [ 0.0288, 1.1582],\n",
" [ 1.0291, -0.0078],\n",
" [ 1.7785, 0.6016],\n",
" [ 1.5862, 0.9561],\n",
" [-0.5079, 0.7066],\n",
" [ 1.3162, -0.2359]]))"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Removes a tensor dimension.\n",
"# Returns a tuple of all slices along a given dimension, already without it.\n",
"\n",
"torch.unbind(emb, dim=1)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "b6372560",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([32, 6])"
]
},
"execution_count": 37,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# unbind and then concat\n",
"torch.cat(torch.unbind(emb, dim=1),dim=1).shape"
]
},
{
"cell_type": "markdown",
"id": "70cf946b",
"metadata": {},
"source": [
"### 2. view\n",
"very efficient because a view does not create a new tensor; instead, it just returns a tensor which is a different view on the underlying data (physical storage)\n",
"http://blog.ezyang.com/2019/05/pytorch-internals/"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "eb233714",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17])"
]
},
"execution_count": 38,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a = torch.arange(18)\n",
"a"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "018a896f",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([18])"
]
},
"execution_count": 39,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a.shape"
]
},
{
"cell_type": "code",
"execution_count": 40,
"id": "532dcfe7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8],\n",
" [ 9, 10, 11, 12, 13, 14, 15, 16, 17]])"
]
},
"execution_count": 40,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a.view([2,9])"
]
},
{
"cell_type": "code",
"execution_count": 41,
"id": "7fa30b55",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0, 1],\n",
" [ 2, 3],\n",
" [ 4, 5],\n",
" [ 6, 7],\n",
" [ 8, 9],\n",
" [10, 11],\n",
" [12, 13],\n",
" [14, 15],\n",
" [16, 17]])"
]
},
"execution_count": 41,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a.view([9,2])"
]
},
{
"cell_type": "code",
"execution_count": 42,
"id": "4d8da572",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[ 0, 1, 2, 3, 4, 5],\n",
" [ 6, 7, 8, 9, 10, 11],\n",
" [12, 13, 14, 15, 16, 17]])"
]
},
"execution_count": 42,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a.view([3,6])"
]
},
{
"cell_type": "code",
"execution_count": 43,
"id": "8e9755e5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
" 0\n",
" 1\n",
" 2\n",
" 3\n",
" 4\n",
" 5\n",
" 6\n",
" 7\n",
" 8\n",
" 9\n",
" 10\n",
" 11\n",
" 12\n",
" 13\n",
" 14\n",
" 15\n",
" 16\n",
" 17\n",
"[torch.LongStorage of size 18]"
]
},
"execution_count": 43,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a.storage()\n",
"# numbers stored in 1D vector\n",
"# .view changes the 1D sequence into a tensor\n",
"# no copy of the value created"
]
},
{
"cell_type": "code",
"execution_count": 44,
"id": "777f6498",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([32, 3, 2])"
]
},
"execution_count": 44,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"emb.shape"
]
},
{
"cell_type": "code",
"execution_count": 45,
"id": "9c650672",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.4464, -0.6289, -0.4464, -0.6289, -0.4464, -0.6289],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, -0.3297, -0.0899],\n",
" [-0.4464, -0.6289, -0.3297, -0.0899, -1.2337, -0.4362],\n",
" [-0.3297, -0.0899, -1.2337, -0.4362, -1.2337, -0.4362],\n",
" [-1.2337, -0.4362, -1.2337, -0.4362, 1.3162, -0.2359],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, -0.4464, -0.6289],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, 1.0291, -0.0078],\n",
" [-0.4464, -0.6289, 1.0291, -0.0078, 0.5974, 1.3534],\n",
" [ 1.0291, -0.0078, 0.5974, 1.3534, -0.5079, 0.7066],\n",
" [ 0.5974, 1.3534, -0.5079, 0.7066, -0.1061, -1.3482],\n",
" [-0.5079, 0.7066, -0.1061, -1.3482, -0.5079, 0.7066],\n",
" [-0.1061, -1.3482, -0.5079, 0.7066, 1.3162, -0.2359],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, -0.4464, -0.6289],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, 1.3162, -0.2359],\n",
" [-0.4464, -0.6289, 1.3162, -0.2359, -0.1061, -1.3482],\n",
" [ 1.3162, -0.2359, -0.1061, -1.3482, 1.3162, -0.2359],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, -0.4464, -0.6289],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, -0.5079, 0.7066],\n",
" [-0.4464, -0.6289, -0.5079, 0.7066, 0.0288, 1.1582],\n",
" [-0.5079, 0.7066, 0.0288, 1.1582, 1.3162, -0.2359],\n",
" [ 0.0288, 1.1582, 1.3162, -0.2359, 1.3732, 0.4287],\n",
" [ 1.3162, -0.2359, 1.3732, 0.4287, -0.3297, -0.0899],\n",
" [ 1.3732, 0.4287, -0.3297, -0.0899, 0.5974, 1.3534],\n",
" [-0.3297, -0.0899, 0.5974, 1.3534, 0.5974, 1.3534],\n",
" [ 0.5974, 1.3534, 0.5974, 1.3534, 1.3162, -0.2359],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, -0.4464, -0.6289],\n",
" [-0.4464, -0.6289, -0.4464, -0.6289, 0.0288, 1.1582],\n",
" [-0.4464, -0.6289, 0.0288, 1.1582, 1.0291, -0.0078],\n",
" [ 0.0288, 1.1582, 1.0291, -0.0078, 1.7785, 0.6016],\n",
" [ 1.0291, -0.0078, 1.7785, 0.6016, 1.5862, 0.9561],\n",
" [ 1.7785, 0.6016, 1.5862, 0.9561, -0.5079, 0.7066],\n",
" [ 1.5862, 0.9561, -0.5079, 0.7066, 1.3162, -0.2359]])"
]
},
"execution_count": 45,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# horray, 2 get stacked up to 3\n",
"emb.view([32,6])"
]
},
{
"cell_type": "code",
"execution_count": 46,
"id": "64eb97d1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True],\n",
" [True, True, True, True, True, True]])"
]
},
"execution_count": 46,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"emb.view([32,6]) == torch.cat(torch.unbind(emb,1),1)"
]
},
{
"cell_type": "code",
"execution_count": 47,
"id": "2649e57a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([32, 100])"
]
},
"execution_count": 47,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"(emb.view([X.shape[0],n_input1])@W1 + b1).shape"
]
},
{
"cell_type": "code",
"execution_count": 48,
"id": "33adcb51",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[-0.2642, -0.5217, -0.5207, ..., 1.6704, 0.4607, -0.2842],\n",
" [-0.5069, -0.4191, 0.5045, ..., 1.2557, 0.1441, 0.3194],\n",
" [-0.4562, -0.8144, 0.4525, ..., 0.3793, -1.2262, -0.2351],\n",
" ...,\n",
" [ 0.8261, 3.0585, 2.8153, ..., 0.3313, 4.5300, 2.5886],\n",
" [ 0.4643, 1.9310, 3.9621, ..., -1.6207, 2.8345, 0.2874],\n",
" [ 3.2720, -2.1254, -0.0563, ..., -1.9226, 2.0426, -0.6754]])"
]
},
"execution_count": 48,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# use -1, it will infer the dim\n",
"h = emb.view([-1,n_input1])@W1 + b1\n",
"# b1 is 1D\n",
"# the broadcasting with + b1, align on the right\n",
"# 32, 100\n",
"# 1, 100\n",
"# => \n",
"# 32, 100\n",
"# 32, 100 -> broadcasted \n",
"# hidden layer\n",
"h"
]
},
{
"cell_type": "code",
"execution_count": 49,
"id": "62fdd381",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([32, 27])"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# final layer\n",
"# output is 27 possibe chars\n",
"n_output2 = Vdim\n",
"W2 = torch.randn([n_output1, n_output2])\n",
"b2 = torch.randn(n_output2)\n",
"logits = h @ W2 + b2\n",
"logits.shape"
]
},
{
"cell_type": "code",
"execution_count": 50,
"id": "f3e46359",
"metadata": {},
"outputs": [],
"source": [
"counts = logits.exp()"
]
},
{
"cell_type": "code",
"execution_count": 51,
"id": "bc4ea2ae",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([32, 27])"
]
},
"execution_count": 51,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# sum by row\n",
"prob = counts / counts.sum(1, keepdims=True)\n",
"prob.shape"
]
},
{
"cell_type": "code",
"execution_count": 52,
"id": "e6eecd51",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[1.8994e-06, 7.0665e-01, 2.4528e-14, 3.9459e-27, 3.7168e-15, 2.4027e-01,\n",
" 3.2191e-07, 8.0452e-07, 1.0725e-11, 1.2262e-09, 6.1807e-07, 4.1937e-11,\n",
" 6.2864e-16, 4.8229e-16, 1.9374e-11, 2.9170e-14, 8.0970e-09, 2.6827e-07,\n",
" 4.7176e-02, 5.6506e-06, 2.1539e-09, 1.4395e-21, 5.4801e-06, 3.7784e-07,\n",
" 6.4447e-20, 5.1319e-10, 5.8910e-03],\n",
" [1.5634e-06, 5.8689e-03, 5.5907e-12, 3.6511e-24, 1.9378e-13, 1.6033e-02,\n",
" 1.3811e-09, 8.4612e-10, 9.2904e-12, 7.1332e-04, 7.8706e-05, 1.0586e-08,\n",
" 9.6776e-10, 2.4594e-09, 5.3621e-12, 1.4119e-13, 2.0930e-08, 8.0390e-03,\n",
" 1.7799e-03, 7.4927e-03, 9.5619e-08, 3.6570e-17, 5.1102e-07, 6.5709e-08,\n",
" 8.0131e-14, 9.7636e-08, 9.5999e-01],\n",
" [2.5425e-13, 8.1657e-01, 4.2780e-17, 4.2832e-29, 1.9830e-15, 1.8342e-01,\n",
" 4.6704e-11, 6.8022e-12, 9.1839e-19, 7.4077e-16, 5.5468e-18, 9.6685e-19,\n",
" 8.7594e-16, 9.6322e-16, 3.2213e-11, 2.2215e-15, 8.7443e-19, 2.2657e-11,\n",
" 1.9050e-09, 7.0052e-06, 2.5675e-14, 7.9356e-15, 4.7929e-06, 1.3517e-14,\n",
" 3.5140e-19, 1.3572e-16, 6.8860e-12],\n",
" [1.5926e-16, 8.2677e-01, 4.5069e-28, 6.2221e-31, 3.4318e-19, 4.3759e-02,\n",
" 4.8483e-15, 1.3489e-20, 1.2572e-20, 3.5169e-17, 4.1630e-18, 1.0599e-19,\n",
" 1.0524e-20, 1.7805e-16, 2.4486e-17, 1.1565e-21, 1.0517e-13, 4.8244e-04,\n",
" 1.2899e-01, 5.0841e-13, 3.9280e-20, 3.3293e-29, 1.9242e-13, 3.8402e-11,\n",
" 5.0398e-27, 2.0794e-23, 6.0030e-15],\n",
" [5.6912e-18, 8.5762e-14, 5.4786e-09, 7.6643e-28, 9.8223e-34, 1.0855e-20,\n",
" 1.6902e-23, 2.9398e-26, 2.8179e-18, 2.9614e-25, 2.1332e-17, 1.3366e-05,\n",
" 1.3824e-13, 3.1554e-20, 2.1864e-30, 4.9378e-17, 4.9420e-28, 1.7444e-06,\n",
" 9.9998e-01, 2.7049e-28, 1.3291e-20, 1.7832e-37, 7.2543e-33, 7.4020e-16,\n",
" 9.4811e-27, 1.4014e-23, 2.4846e-14],\n",
" [1.8994e-06, 7.0665e-01, 2.4528e-14, 3.9459e-27, 3.7168e-15, 2.4027e-01,\n",
" 3.2191e-07, 8.0452e-07, 1.0725e-11, 1.2262e-09, 6.1807e-07, 4.1937e-11,\n",
" 6.2864e-16, 4.8229e-16, 1.9374e-11, 2.9170e-14, 8.0970e-09, 2.6827e-07,\n",
" 4.7176e-02, 5.6506e-06, 2.1539e-09, 1.4395e-21, 5.4801e-06, 3.7784e-07,\n",
" 6.4447e-20, 5.1319e-10, 5.8910e-03],\n",
" [6.7133e-07, 3.9065e-12, 7.0431e-06, 2.8241e-21, 6.2352e-18, 2.1865e-14,\n",
" 1.1576e-11, 2.3479e-14, 1.5960e-10, 1.1027e-07, 9.4439e-01, 3.4291e-05,\n",
" 1.0041e-07, 7.8863e-13, 2.0277e-18, 1.1248e-11, 5.9427e-11, 2.3680e-04,\n",
" 3.2660e-04, 4.0719e-10, 1.8607e-08, 3.9505e-23, 1.6343e-18, 1.7168e-11,\n",
" 5.8671e-17, 5.8176e-08, 5.5003e-02],\n",
" [2.3603e-19, 1.8190e-28, 1.0644e-06, 1.5602e-22, 6.6367e-20, 2.4823e-30,\n",
" 2.2450e-22, 2.0518e-26, 1.3708e-25, 9.2010e-09, 8.1778e-17, 3.9444e-13,\n",
" 7.3042e-02, 1.0480e-09, 5.7714e-22, 1.2623e-20, 6.2400e-27, 1.8503e-17,\n",
" 2.8257e-32, 3.2646e-07, 3.4103e-15, 9.0763e-01, 1.1386e-22, 2.2861e-30,\n",
" 1.9326e-02, 9.4481e-10, 1.4029e-11],\n",
" [5.4918e-28, 2.4647e-27, 5.9095e-19, 1.0829e-18, 3.9169e-10, 1.1083e-32,\n",
" 2.8731e-15, 8.3012e-37, 3.1838e-33, 3.6336e-30, 9.5694e-25, 1.5107e-38,\n",
" 9.1791e-09, 2.4099e-19, 1.0682e-17, 7.8429e-15, 4.1298e-31, 5.1046e-24,\n",
" 1.1585e-37, 1.8818e-10, 1.2956e-25, 1.0000e+00, 1.5812e-28, 8.5479e-44,\n",
" 2.8373e-21, 4.4365e-27, 1.8675e-37],\n",
" [5.3424e-22, 3.1758e-10, 6.6093e-29, 3.0561e-17, 4.7211e-18, 4.2464e-26,\n",
" 9.9979e-01, 2.6293e-29, 4.4755e-24, 0.0000e+00, 2.2950e-27, 5.8107e-26,\n",
" 2.0452e-27, 1.6954e-37, 7.9085e-17, 3.9599e-19, 5.0982e-15, 4.5363e-19,\n",
" 2.0801e-04, 1.9471e-31, 9.5267e-33, 2.2751e-28, 7.4330e-31, 1.7626e-20,\n",
" 1.2259e-40, 7.5773e-35, 1.2144e-39],\n",
" [6.6060e-16, 7.0911e-14, 8.9545e-33, 3.0449e-27, 7.3567e-22, 4.4476e-16,\n",
" 1.6143e-17, 3.1832e-25, 9.2399e-23, 1.8690e-02, 7.4776e-16, 1.8604e-06,\n",
" 9.2996e-18, 1.0148e-11, 2.1973e-21, 9.4467e-38, 9.7591e-01, 5.3289e-03,\n",
" 2.5007e-08, 1.3354e-12, 9.9222e-21, 6.2210e-22, 3.9712e-16, 7.9575e-07,\n",
" 7.1641e-10, 2.4088e-12, 6.9206e-05],\n",
" [2.1457e-19, 5.3917e-23, 9.9863e-01, 3.5339e-26, 2.4931e-23, 7.7466e-27,\n",
" 2.8573e-20, 1.9329e-25, 9.6134e-21, 7.1319e-31, 1.7529e-12, 3.0144e-25,\n",
" 1.3543e-09, 7.0777e-24, 8.2387e-26, 1.3696e-03, 9.6154e-39, 2.6552e-20,\n",
" 3.8563e-21, 5.2864e-19, 8.4760e-17, 7.3578e-25, 4.1313e-33, 1.1553e-35,\n",
" 3.3413e-32, 7.9586e-24, 7.1770e-24],\n",
" [1.8994e-06, 7.0665e-01, 2.4528e-14, 3.9459e-27, 3.7168e-15, 2.4027e-01,\n",
" 3.2191e-07, 8.0452e-07, 1.0725e-11, 1.2262e-09, 6.1807e-07, 4.1937e-11,\n",
" 6.2864e-16, 4.8229e-16, 1.9374e-11, 2.9170e-14, 8.0970e-09, 2.6827e-07,\n",
" 4.7176e-02, 5.6506e-06, 2.1539e-09, 1.4395e-21, 5.4801e-06, 3.7784e-07,\n",
" 6.4447e-20, 5.1319e-10, 5.8910e-03],\n",
" [5.5573e-07, 1.6617e-13, 1.8262e-05, 6.0024e-22, 7.0303e-20, 7.5283e-17,\n",
" 4.0332e-11, 3.7349e-14, 3.2555e-10, 2.5044e-11, 9.9666e-01, 1.9005e-05,\n",
" 4.2143e-10, 7.9021e-17, 8.4015e-20, 1.5246e-11, 8.2403e-12, 7.8824e-07,\n",
" 9.0538e-04, 2.0810e-13, 2.0073e-09, 1.0948e-26, 5.9718e-21, 4.5434e-12,\n",
" 1.3866e-20, 4.2469e-09, 2.3968e-03],\n",
" [1.2502e-09, 1.4916e-10, 1.5064e-12, 4.8487e-31, 3.0801e-20, 5.8148e-16,\n",
" 1.5651e-01, 8.4349e-01, 3.9499e-17, 1.7393e-26, 1.9220e-18, 1.6434e-16,\n",
" 6.4413e-27, 3.6102e-36, 3.7108e-10, 4.1036e-19, 1.1058e-17, 9.3626e-34,\n",
" 8.1936e-19, 5.4736e-13, 1.5986e-14, 9.5979e-14, 2.6343e-06, 5.5294e-17,\n",
" 3.1486e-23, 3.1949e-10, 1.4226e-11],\n",
" [4.2523e-22, 5.6052e-45, 9.8091e-45, 1.2472e-43, 5.6901e-29, 9.3326e-43,\n",
" 1.5425e-23, 9.4540e-35, 4.2318e-29, 4.1244e-19, 9.9999e-01, 4.4577e-39,\n",
" 3.9169e-41, 3.9937e-43, 8.8514e-39, 3.7042e-40, 8.0529e-06, 8.4035e-30,\n",
" 1.6566e-31, 6.3691e-26, 1.2070e-27, 0.0000e+00, 4.2959e-40, 7.0247e-33,\n",
" 0.0000e+00, 1.5955e-22, 2.0003e-19],\n",
" [1.8994e-06, 7.0665e-01, 2.4528e-14, 3.9459e-27, 3.7168e-15, 2.4027e-01,\n",
" 3.2191e-07, 8.0452e-07, 1.0725e-11, 1.2262e-09, 6.1807e-07, 4.1937e-11,\n",
" 6.2864e-16, 4.8229e-16, 1.9374e-11, 2.9170e-14, 8.0970e-09, 2.6827e-07,\n",
" 4.7176e-02, 5.6506e-06, 2.1539e-09, 1.4395e-21, 5.4801e-06, 3.7784e-07,\n",
" 6.4447e-20, 5.1319e-10, 5.8910e-03],\n",
" [3.5188e-13, 2.4852e-10, 1.3029e-16, 4.9331e-27, 2.8343e-16, 8.0934e-08,\n",
" 2.9858e-19, 9.6433e-20, 8.6115e-19, 9.6047e-01, 2.5603e-09, 1.3697e-12,\n",
" 1.7087e-07, 7.3772e-05, 8.8451e-18, 1.1927e-19, 9.9358e-14, 3.0442e-02,\n",
" 4.6515e-12, 7.8968e-03, 1.1242e-11, 1.5575e-15, 3.4302e-12, 9.6888e-15,\n",
" 2.2749e-10, 7.7974e-11, 1.1181e-03],\n",
" [5.1837e-27, 4.5088e-21, 3.2621e-09, 2.6907e-23, 5.7301e-22, 4.3898e-24,\n",
" 7.5240e-29, 2.8476e-37, 2.0032e-29, 5.0041e-21, 7.1890e-25, 8.2807e-22,\n",
" 1.0000e+00, 2.5662e-07, 7.4503e-26, 6.3922e-14, 2.8118e-39, 2.3482e-09,\n",
" 3.0197e-24, 9.3495e-14, 7.2589e-22, 2.5147e-10, 1.2226e-30, 9.4292e-35,\n",
" 4.8043e-14, 7.4667e-26, 3.7715e-25],\n",
" [3.3859e-21, 1.5635e-15, 9.9985e-01, 1.9184e-07, 7.5952e-22, 8.3803e-34,\n",
" 6.8879e-08, 2.9004e-30, 3.4280e-21, 2.3262e-43, 1.0164e-28, 2.7298e-08,\n",
" 1.5216e-04, 3.8501e-23, 3.7163e-19, 6.5912e-08, 8.2446e-32, 5.8319e-15,\n",
" 1.6726e-08, 2.1963e-29, 1.6910e-25, 4.8436e-10, 5.3335e-36, 5.5999e-25,\n",
" 7.1183e-19, 6.5876e-27, 3.9180e-33],\n",
" [1.3869e-11, 1.3064e-23, 1.3883e-11, 3.0361e-07, 1.1145e-14, 1.1017e-36,\n",
" 2.2892e-01, 2.3501e-22, 3.0828e-17, 4.0604e-17, 1.7824e-12, 7.6531e-01,\n",
" 9.6515e-08, 1.3337e-19, 4.9742e-15, 4.9165e-23, 5.4202e-03, 1.5549e-14,\n",
" 1.6286e-14, 8.5261e-17, 4.4036e-18, 3.3526e-04, 7.5974e-25, 9.9077e-15,\n",
" 1.6673e-05, 4.9403e-07, 2.8487e-13],\n",
" [3.6625e-13, 4.2774e-23, 4.5643e-19, 3.2510e-20, 1.6000e-03, 3.2078e-25,\n",
" 6.6707e-01, 4.6174e-16, 8.9175e-22, 1.4777e-17, 1.1382e-07, 1.5487e-32,\n",
" 2.6719e-16, 3.7821e-25, 1.7327e-08, 1.9661e-15, 8.0458e-10, 5.1622e-28,\n",
" 4.0367e-33, 2.8927e-01, 2.1098e-14, 4.2066e-02, 5.9124e-13, 8.9249e-31,\n",
" 6.7915e-21, 1.4951e-10, 7.5850e-20],\n",
" [3.4458e-19, 1.3673e-31, 3.7091e-26, 2.2236e-15, 2.1822e-08, 7.1872e-33,\n",
" 1.0542e-16, 2.6291e-40, 3.0007e-24, 1.2446e-06, 9.9918e-01, 1.9935e-26,\n",
" 3.3476e-06, 3.1589e-09, 7.6126e-24, 3.8830e-21, 7.0703e-05, 7.3992e-04,\n",
" 1.8127e-23, 3.4805e-08, 2.1202e-19, 2.6834e-18, 2.1624e-33, 2.3875e-29,\n",
" 4.8275e-20, 6.3455e-17, 1.7028e-19],\n",
" [8.1457e-35, 1.2857e-33, 1.1304e-07, 3.5911e-19, 1.1025e-25, 2.6485e-43,\n",
" 7.8327e-29, 2.8026e-45, 1.9902e-37, 1.8492e-35, 3.6882e-37, 9.1030e-25,\n",
" 2.3057e-01, 2.7072e-16, 7.7188e-29, 5.5285e-18, 0.0000e+00, 2.4823e-23,\n",
" 4.3372e-38, 1.6958e-22, 3.0084e-30, 7.6943e-01, 7.6148e-41, 0.0000e+00,\n",
" 2.9827e-12, 1.3995e-30, 2.2507e-38],\n",
" [4.4522e-22, 1.6499e-23, 1.1505e-09, 6.5774e-03, 1.5525e-14, 2.6204e-43,\n",
" 9.9342e-01, 2.1890e-34, 1.0482e-23, 0.0000e+00, 3.5921e-25, 1.5227e-17,\n",
" 7.4367e-09, 3.5658e-29, 6.7953e-17, 3.0815e-11, 4.1484e-22, 7.5288e-21,\n",
" 2.5777e-16, 1.3564e-28, 1.0549e-28, 2.9484e-07, 5.0902e-39, 3.8570e-30,\n",
" 2.4113e-23, 3.3730e-27, 5.3107e-40],\n",
" [1.8994e-06, 7.0665e-01, 2.4528e-14, 3.9459e-27, 3.7168e-15, 2.4027e-01,\n",
" 3.2191e-07, 8.0452e-07, 1.0725e-11, 1.2262e-09, 6.1807e-07, 4.1937e-11,\n",
" 6.2864e-16, 4.8229e-16, 1.9374e-11, 2.9170e-14, 8.0970e-09, 2.6827e-07,\n",
" 4.7176e-02, 5.6506e-06, 2.1539e-09, 1.4395e-21, 5.4801e-06, 3.7784e-07,\n",
" 6.4447e-20, 5.1319e-10, 5.8910e-03],\n",
" [1.1852e-16, 2.9272e-18, 4.8271e-16, 4.9006e-27, 1.1208e-19, 6.4937e-16,\n",
" 4.4038e-25, 7.2464e-27, 1.0100e-21, 9.8265e-01, 1.3072e-09, 7.8116e-13,\n",
" 3.0387e-05, 5.3026e-04, 1.3162e-23, 8.9681e-22, 1.5955e-17, 1.6768e-02,\n",
" 1.0449e-16, 4.9513e-06, 6.7697e-14, 2.6620e-17, 4.7354e-20, 8.4945e-20,\n",
" 7.0451e-10, 2.1320e-12, 1.2574e-05],\n",
" [4.9052e-25, 9.4807e-23, 1.0000e+00, 4.0314e-21, 2.3924e-25, 1.6132e-32,\n",
" 1.3216e-19, 1.9167e-30, 3.1326e-26, 1.5693e-39, 3.6699e-27, 1.1513e-21,\n",
" 2.5881e-07, 3.5708e-24, 1.4833e-24, 1.4357e-07, 1.4013e-44, 2.7540e-23,\n",
" 1.2590e-23, 5.6606e-24, 1.9721e-23, 1.2849e-13, 3.6746e-35, 2.3322e-37,\n",
" 2.0883e-24, 5.1857e-28, 3.8505e-32],\n",
" [1.4766e-13, 5.4573e-26, 2.6174e-08, 8.8484e-05, 7.2608e-16, 2.7285e-40,\n",
" 5.3116e-04, 2.4367e-27, 9.4136e-18, 6.8406e-20, 1.3292e-12, 9.9894e-01,\n",
" 4.2736e-04, 5.0113e-18, 5.5990e-18, 5.2619e-20, 3.0610e-07, 3.1204e-12,\n",
" 5.6372e-14, 1.0915e-19, 3.0624e-19, 1.0511e-05, 4.6000e-31, 8.9617e-18,\n",
" 1.2070e-06, 6.0971e-10, 1.5755e-16],\n",
" [6.6248e-20, 1.1519e-42, 1.4801e-09, 4.5097e-14, 1.3133e-12, 0.0000e+00,\n",
" 3.3392e-11, 2.1073e-31, 1.6451e-26, 3.0954e-21, 3.6437e-10, 3.6631e-25,\n",
" 8.9187e-07, 7.8893e-24, 3.7848e-21, 9.4501e-18, 2.3735e-20, 6.5117e-30,\n",
" 1.9898e-42, 4.6795e-12, 4.4955e-19, 1.0000e+00, 8.9663e-34, 1.9880e-41,\n",
" 1.5975e-14, 6.6194e-12, 1.6366e-24],\n",
" [1.3479e-28, 7.0023e-38, 3.6616e-32, 1.7165e-20, 3.1452e-08, 1.9898e-43,\n",
" 1.1325e-10, 6.2044e-38, 4.5461e-37, 6.2867e-29, 7.8839e-23, 1.9618e-44,\n",
" 6.9054e-19, 7.3255e-29, 7.6490e-18, 4.1989e-27, 5.3466e-19, 2.9622e-34,\n",
" 0.0000e+00, 6.0975e-11, 1.9852e-29, 1.0000e+00, 5.9715e-29, 0.0000e+00,\n",
" 1.1249e-23, 5.7915e-24, 1.0314e-38],\n",
" [9.2334e-16, 1.1982e-23, 1.8237e-18, 2.7255e-06, 5.3642e-07, 3.3659e-35,\n",
" 9.9622e-01, 3.5788e-34, 1.0860e-17, 7.1806e-31, 3.7684e-03, 1.3917e-25,\n",
" 1.9584e-11, 1.1933e-24, 3.2485e-18, 9.0902e-09, 5.8568e-06, 8.2379e-10,\n",
" 1.5351e-10, 2.3523e-20, 3.7594e-22, 1.1572e-23, 5.5352e-38, 1.7066e-26,\n",
" 1.3043e-34, 8.2526e-24, 2.0988e-31]])"
]
},
"execution_count": 52,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prob"
]
},
{
"cell_type": "code",
"execution_count": 53,
"id": "5d0d5213",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([32])"
]
},
"execution_count": 53,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Y.shape"
]
},
{
"cell_type": "code",
"execution_count": 54,
"id": "c00f76b2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 5, 13, 13, 1, 0, 15, 12, 9, 22, 9, 1, 0, 1, 22, 1, 0, 9, 19,\n",
" 1, 2, 5, 12, 12, 1, 0, 19, 15, 16, 8, 9, 1, 0])"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# we have 32 training records, and 27 dimensions\n",
"Y"
]
},
{
"cell_type": "code",
"execution_count": 55,
"id": "8e9205e0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,\n",
" 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31])"
]
},
"execution_count": 55,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.arange(Y.shape[0])"
]
},
{
"cell_type": "code",
"execution_count": 56,
"id": "f5b57e82",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([2.4027e-01, 2.4594e-09, 9.6322e-16, 8.2677e-01, 5.6912e-18, 2.9170e-14,\n",
" 1.0041e-07, 9.2010e-09, 1.5812e-28, 0.0000e+00, 7.0911e-14, 2.1457e-19,\n",
" 7.0665e-01, 5.9718e-21, 1.4916e-10, 4.2523e-22, 1.2262e-09, 7.8968e-03,\n",
" 4.5088e-21, 9.9985e-01, 1.1017e-36, 2.6719e-16, 3.3476e-06, 1.2857e-33,\n",
" 4.4522e-22, 5.6506e-06, 8.9681e-22, 1.4013e-44, 9.4136e-18, 3.0954e-21,\n",
" 7.0023e-38, 9.2334e-16])"
]
},
"execution_count": 56,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# for each record, get the prob of the correct char\n",
"# some prob is really low -> log likelihood :D \n",
"prob[torch.arange(Y.shape[0]), Y]"
]
},
{
"cell_type": "code",
"execution_count": 57,
"id": "227429fe",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(inf)"
]
},
"execution_count": 57,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# negative average log likelihood\n",
"# log of each prob, average, and then negative\n",
"loss = -prob[torch.arange(Y.shape[0]), Y].log().mean()\n",
"loss"
]
},
{
"cell_type": "markdown",
"id": "628a232a",
"metadata": {},
"source": [
"# 4. train the neural net with backpropagation"
]
},
{
"cell_type": "code",
"execution_count": 58,
"id": "14c3b9a8",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([32, 3]), torch.Size([32]))"
]
},
"execution_count": 58,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X.shape, Y.shape"
]
},
{
"cell_type": "code",
"execution_count": 59,
"id": "a4e2938d",
"metadata": {},
"outputs": [],
"source": [
"Vdim=27\n",
"Cdim=2 # hyperparamter\n",
"block_size = 3 # hyperparamter\n",
"n_input1 = Cdim * block_size\n",
"n_output1 = 100 # hyperparamter\n",
"n_input2 = n_output1\n",
"n_output2 = Vdim\n",
"\n",
"g = torch.Generator().manual_seed(2147483647)\n",
"C = torch.randn((Vdim, Cdim), generator=g)\n",
"W1 = torch.randn((n_input1,n_output1), generator=g)\n",
"b1 = torch.randn(n_output1, generator=g)\n",
"W2 = torch.randn((n_input2,n_output2), generator=g)\n",
"b2 = torch.randn(n_output2, generator=g)\n",
"parameters = [C, W1, b1, W2, b2]"
]
},
{
"cell_type": "code",
"execution_count": 60,
"id": "e265acb2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3481"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sum(p.nelement() for p in parameters) # total number of parameters"
]
},
{
"cell_type": "code",
"execution_count": 61,
"id": "b228f7d0",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(17.7697)"
]
},
"execution_count": 61,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# forward pass\n",
"emb = C[X]\n",
"h = torch.tanh(emb.view(-1,n_input1) @ W1 + b1)\n",
"# softmax classification\n",
"logits = h @ W2 + b2\n",
"counts = logits.exp()\n",
"prob = counts / counts.sum(1, keepdims=True)\n",
"loss = - prob[torch.arange(Y.shape[0]), Y].log().mean()\n",
"loss"
]
},
{
"cell_type": "markdown",
"id": "b8500dd2",
"metadata": {},
"source": [
"## softmax loss from scratch v.s. F.cross_entropy\n",
"F.cross_entropy more efficient\n",
"- 1. numerically well-behaved, with more extreme valued weights, with exp() -> we could get very huge numbers. forward pass more efficient\n",
"- 2. backward pass is more efficient \n",
"- 3. forward pass is more efficient. we do not create intermediate tensors as in logits."
]
},
{
"cell_type": "code",
"execution_count": 62,
"id": "c832a78b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor(17.7697)"
]
},
"execution_count": 62,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"F.cross_entropy(logits, Y)"
]
},
{
"cell_type": "code",
"execution_count": 63,
"id": "3e0ba1c2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0., 0., 0., nan])"
]
},
"execution_count": 63,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# numerical out of range in float\n",
"logits = torch.tensor([-100,-3,0,100])\n",
"counts = logits.exp()\n",
"prob = counts / counts.sum()\n",
"prob"
]
},
{
"cell_type": "code",
"execution_count": 64,
"id": "b84b0ee2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([3.7835e-44, 4.9787e-02, 1.0000e+00, inf])"
]
},
"execution_count": 64,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# e^100 is too BIG, out of range in float\n",
"counts"
]
},
{
"cell_type": "code",
"execution_count": 65,
"id": "2274f7c9",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.0000e+00, 1.4013e-45, 3.7835e-44, 1.0000e+00])"
]
},
"execution_count": 65,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# use normalization to linearly offset the logits without affecting the prob\n",
"# pytorch internally SUBTRACTS the maximal value in the tensor\n",
"logits = torch.tensor([-100,-3,0,100]) - 100\n",
"counts = logits.exp()\n",
"prob = counts / counts.sum()\n",
"prob"
]
},
{
"cell_type": "markdown",
"id": "42671fe8",
"metadata": {},
"source": [
"## train with backward()"
]
},
{
"cell_type": "code",
"execution_count": 66,
"id": "04f208d9",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 17.76971435546875\n"
]
},
{
"ename": "RuntimeError",
"evalue": "element 0 of tensors does not require grad and does not have a grad_fn",
"output_type": "error",
"traceback": [
"\u001b[0;31m-------------------------\u001b[0m",
"\u001b[0;31mRuntimeError\u001b[0mTraceback (most recent call last)",
"\u001b[0;32m<ipython-input-66-b30885fcc503>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 16\u001b[0m \u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgrad\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 18\u001b[0;31m \u001b[0mloss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 19\u001b[0m \u001b[0;31m# update the parameters\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 20\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mp\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mparameters\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/.pyenv/versions/offline-simulation-3.6.9/lib/python3.6/site-packages/torch/_tensor.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(self, gradient, retain_graph, create_graph, inputs)\u001b[0m\n\u001b[1;32m 305\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 306\u001b[0m inputs=inputs)\n\u001b[0;32m--> 307\u001b[0;31m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mautograd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgradient\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 308\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 309\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mregister_hook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m~/.pyenv/versions/offline-simulation-3.6.9/lib/python3.6/site-packages/torch/autograd/__init__.py\u001b[0m in \u001b[0;36mbackward\u001b[0;34m(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)\u001b[0m\n\u001b[1;32m 154\u001b[0m Variable._execution_engine.run_backward(\n\u001b[1;32m 155\u001b[0m \u001b[0mtensors\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgrad_tensors_\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mretain_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcreate_graph\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 156\u001b[0;31m allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag\n\u001b[0m\u001b[1;32m 157\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 158\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mRuntimeError\u001b[0m: element 0 of tensors does not require grad and does not have a grad_fn"
]
}
],
"source": [
"# train the neural net\n",
"\n",
"learning_rate = 0.1\n",
"\n",
"for _ in range(10):\n",
" # forward pass\n",
" emb = C[X]\n",
" h = torch.tanh(emb.view(-1,n_input1) @ W1 + b1)\n",
" logits = h @ W2 + b2\n",
" loss = F.cross_entropy(logits, Y)\n",
" print (_, loss.item())\n",
" # backward pass\n",
"\n",
" # initalize all gradients\n",
" for p in parameters:\n",
" p.grad = None\n",
" \n",
" loss.backward()\n",
" # update the parameters\n",
" for p in parameters:\n",
" p.data += -learning_rate * p.grad"
]
},
{
"cell_type": "code",
"execution_count": 67,
"id": "8331f712",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 17.76971435546875\n",
"1 9.42310619354248\n",
"2 8.315447807312012\n",
"3 6.388716220855713\n",
"4 3.51501202583313\n",
"5 4.386330604553223\n",
"6 2.32904052734375\n",
"7 1.8342325687408447\n",
"8 2.4214775562286377\n",
"9 2.067505359649658\n"
]
}
],
"source": [
"# make sure to set requires_grad = true, otherwise will complain\n",
"for p in parameters:\n",
" p.requires_grad=True\n",
" \n",
"learning_rate = 0.5\n",
"\n",
"for _ in range(10):\n",
" # forward pass\n",
" emb = C[X]\n",
" h = torch.tanh(emb.view(-1,n_input1) @ W1 + b1)\n",
" logits = h @ W2 + b2\n",
" loss = F.cross_entropy(logits, Y)\n",
" print (_, loss.item())\n",
" # backward pass\n",
"\n",
" # initalize all gradients\n",
" for p in parameters:\n",
" p.grad = None\n",
" \n",
" loss.backward()\n",
" # update the parameters\n",
" for p in parameters:\n",
" p.data += -learning_rate * p.grad"
]
},
{
"cell_type": "code",
"execution_count": 68,
"id": "d63e0161",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.return_types.max(\n",
"values=tensor([17.8802, 12.6695, 15.6311, 17.3821, 18.0480, 17.8802, 14.4196, 11.9789,\n",
" 14.2498, 13.5953, 12.4208, 15.7624, 17.8802, 17.1180, 14.2179, 15.5572,\n",
" 17.8802, 13.5108, 11.3627, 16.2610, 16.8473, 12.2349, 7.4047, 6.9497,\n",
" 13.6933, 17.8802, 16.9499, 13.7836, 9.7913, 14.5982, 18.5608, 12.1148],\n",
" grad_fn=<MaxBackward0>),\n",
"indices=tensor([15, 12, 13, 1, 15, 15, 12, 12, 22, 9, 1, 0, 15, 15, 1, 15, 15, 15,\n",
" 15, 2, 5, 12, 12, 9, 0, 15, 15, 12, 8, 9, 1, 1]))"
]
},
"execution_count": 68,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# we overfit the model with 32 data points and 3k parameters!\n",
"# torch reports the max value and the index\n",
"logits.max(1)"
]
},
{
"cell_type": "code",
"execution_count": 69,
"id": "357e8783",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 5, 13, 13, 1, 0, 15, 12, 9, 22, 9, 1, 0, 1, 22, 1, 0, 9, 19,\n",
" 1, 2, 5, 12, 12, 1, 0, 19, 15, 16, 8, 9, 1, 0])"
]
},
"execution_count": 69,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"Y\n",
"# seems we predict super well\n",
"# we will not get 0 loss, because many words can be the starting char following \"...\""
]
},
{
"cell_type": "markdown",
"id": "af33602f",
"metadata": {},
"source": [
"## train with all words"
]
},
{
"cell_type": "code",
"execution_count": 70,
"id": "01930ea1",
"metadata": {},
"outputs": [],
"source": [
"# build the dataset\n",
"\n",
"# context to take, input X \n",
"block_size = 3\n",
"X, Y = [], []\n",
"for w in words:\n",
" context = [0] * block_size\n",
" for ch in w + '.':\n",
" ix = stoi[ch]\n",
" X.append(context)\n",
" Y.append(ix)\n",
" context = context[1:] + [ix]\n",
" \n",
"X=torch.tensor(X)\n",
"Y=torch.tensor(Y)\n",
"\n",
"Vdim=27\n",
"Cdim=2 # hyperparamter\n",
"block_size = 3 # hyperparamter\n",
"n_input1 = Cdim * block_size\n",
"n_output1 = 100 # hyperparamter\n",
"n_input2 = n_output1\n",
"n_output2 = Vdim\n",
"\n",
"g = torch.Generator().manual_seed(2147483647)\n",
"C = torch.randn((Vdim, Cdim), generator=g)\n",
"W1 = torch.randn((n_input1,n_output1), generator=g)\n",
"b1 = torch.randn(n_output1, generator=g)\n",
"W2 = torch.randn((n_input2,n_output2), generator=g)\n",
"b2 = torch.randn(n_output2, generator=g)\n",
"parameters = [C, W1, b1, W2, b2]\n",
"\n",
"for p in parameters:\n",
" p.requires_grad=True"
]
},
{
"cell_type": "code",
"execution_count": 71,
"id": "41732590",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([228146, 3]), torch.Size([228146]))"
]
},
"execution_count": 71,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X.shape, Y.shape"
]
},
{
"cell_type": "code",
"execution_count": 72,
"id": "e6e0fcd3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"0 19.505226135253906\n",
"1 14.167920112609863\n",
"2 11.403884887695312\n",
"3 10.201157569885254\n",
"4 8.886395454406738\n",
"5 7.84780216217041\n",
"6 7.272636890411377\n",
"7 6.522163391113281\n",
"8 6.065618515014648\n",
"9 5.838536739349365\n"
]
}
],
"source": [
"learning_rate = 0.5\n",
"\n",
"for _ in range(10):\n",
" # forward pass\n",
" emb = C[X]\n",
" h = torch.tanh(emb.view(-1,n_input1) @ W1 + b1)\n",
" logits = h @ W2 + b2\n",
" loss = F.cross_entropy(logits, Y)\n",
" print (_, loss.item())\n",
" # backward pass\n",
"\n",
" # initalize all gradients\n",
" for p in parameters:\n",
" p.grad = None\n",
" \n",
" loss.backward()\n",
" # update the parameters\n",
" for p in parameters:\n",
" p.data += -learning_rate * p.grad"
]
},
{
"cell_type": "markdown",
"id": "098cda74",
"metadata": {},
"source": [
"# 5. mini-batch\n",
"- using all data in forward and backward pass takes quite a long time in each iteration because each time we go through all the training examples!\n",
"- mini-batch randomly selects some data in each iteration for forward and backward pass\n",
"- pros: much faster in each iteration\n",
"- cons: less stable, may take longer time for loss to converge "
]
},
{
"cell_type": "code",
"execution_count": 73,
"id": "be243227",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([ 48049, 83229, 192839, 189260, 65187, 101329, 7717, 41259, 153683,\n",
" 122018, 5189, 147134, 29649, 101782, 192637, 109671, 132041, 63528,\n",
" 181138, 45778, 48738, 137578, 62475, 136876, 154447, 4831, 186918,\n",
" 131601, 148080, 94904, 126630, 193761])"
]
},
"execution_count": 73,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# generate index for the starting number \n",
"# generate 32 of them\n",
"batch_size = 32\n",
"# Returns a tensor filled with random integers generated uniformly\n",
"torch.randint(0, X.shape[0], (batch_size,))"
]
},
{
"cell_type": "code",
"execution_count": 74,
"id": "651c0f50",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([32])"
]
},
"execution_count": 74,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"torch.randint(0, X.shape[0], (batch_size,)).shape"
]
},
{
"cell_type": "code",
"execution_count": 75,
"id": "a74a6e27",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"torch.Size([32, 3])"
]
},
"execution_count": 75,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"X[torch.randint(0, X.shape[0], (32,))].shape"
]
},
{
"cell_type": "code",
"execution_count": 76,
"id": "f552adb7",
"metadata": {},
"outputs": [],
"source": [
"# build the dataset\n",
"\n",
"# context to take, input X \n",
"block_size = 3\n",
"X, Y = [], []\n",
"for w in words:\n",
" context = [0] * block_size\n",
" for ch in w + '.':\n",
" ix = stoi[ch]\n",
" X.append(context)\n",
" Y.append(ix)\n",
" context = context[1:] + [ix]\n",
" \n",
"X=torch.tensor(X)\n",
"Y=torch.tensor(Y)\n",
"\n",
"Vdim=27\n",
"Cdim=2 # hyperparamter\n",
"block_size = 3 # hyperparamter\n",
"n_input1 = Cdim * block_size\n",
"n_output1 = 100 # hyperparamter\n",
"n_input2 = n_output1\n",
"n_output2 = Vdim\n",
"\n",
"g = torch.Generator().manual_seed(2147483647)\n",
"C = torch.randn((Vdim, Cdim), generator=g)\n",
"W1 = torch.randn((n_input1,n_output1), generator=g)\n",
"b1 = torch.randn(n_output1, generator=g)\n",
"W2 = torch.randn((n_input2,n_output2), generator=g)\n",
"b2 = torch.randn(n_output2, generator=g)\n",
"parameters = [C, W1, b1, W2, b2]\n",
"\n",
"for p in parameters:\n",
" p.requires_grad=True"
]
},
{
"cell_type": "code",
"execution_count": 77,
"id": "22dbc71c",
"metadata": {},
"outputs": [],
"source": [
"# a good learning rate is around 1, 10 is too big, 0.01 is too slow in the beginning, but good towards later stage\n",
"learning_rate = 0.1\n",
"\n",
"# much faster\n",
"# gradient is less reliable than the whole dataset\n",
"# direction is good enough \n",
"# update with small amount of data for multiple steps > update with all the data with fewer steps\n",
"for _ in range(10000):\n",
" # mini-batch index\n",
" ix = torch.randint(0, X.shape[0], (batch_size,))\n",
" \n",
" # forward pass\n",
" emb = C[X[ix]] # (32,3,2)\n",
" h = torch.tanh(emb.view(-1,n_input1) @ W1 + b1)\n",
" logits = h @ W2 + b2\n",
" loss = F.cross_entropy(logits, Y[ix])\n",
" # backward pass\n",
"\n",
" # initalize all gradients\n",
" for p in parameters:\n",
" p.grad = None\n",
" \n",
" loss.backward()\n",
" # update the parameters\n",
" for p in parameters:\n",
" p.data += -learning_rate * p.grad"
]
},
{
"cell_type": "code",
"execution_count": 78,
"id": "c6536cdb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.7195205688476562\n"
]
}
],
"source": [
"print (loss.item())"
]
},
{
"cell_type": "markdown",
"id": "eee2bbeb",
"metadata": {},
"source": [
"# 6. find a reasonable learning rate"
]
},
{
"cell_type": "code",
"execution_count": 79,
"id": "9df1bea7",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0011,\n",
" 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011, 0.0011,\n",
" 0.0011, 0.0011, 0.0011, 0.0012, 0.0012, 0.0012, 0.0012, 0.0012, 0.0012,\n",
" 0.0012, 0.0012, 0.0012, 0.0012, 0.0012, 0.0012, 0.0013, 0.0013, 0.0013,\n",
" 0.0013, 0.0013, 0.0013, 0.0013, 0.0013, 0.0013, 0.0013, 0.0013, 0.0014,\n",
" 0.0014, 0.0014, 0.0014, 0.0014, 0.0014, 0.0014, 0.0014, 0.0014, 0.0014,\n",
" 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015, 0.0015,\n",
" 0.0015, 0.0016, 0.0016, 0.0016, 0.0016, 0.0016, 0.0016, 0.0016, 0.0016,\n",
" 0.0016, 0.0017, 0.0017, 0.0017, 0.0017, 0.0017, 0.0017, 0.0017, 0.0017,\n",
" 0.0018, 0.0018, 0.0018, 0.0018, 0.0018, 0.0018, 0.0018, 0.0018, 0.0019,\n",
" 0.0019, 0.0019, 0.0019, 0.0019, 0.0019, 0.0019, 0.0019, 0.0020, 0.0020,\n",
" 0.0020, 0.0020, 0.0020, 0.0020, 0.0020, 0.0021, 0.0021, 0.0021, 0.0021,\n",
" 0.0021, 0.0021, 0.0021, 0.0022, 0.0022, 0.0022, 0.0022, 0.0022, 0.0022,\n",
" 0.0022, 0.0023, 0.0023, 0.0023, 0.0023, 0.0023, 0.0023, 0.0024, 0.0024,\n",
" 0.0024, 0.0024, 0.0024, 0.0024, 0.0025, 0.0025, 0.0025, 0.0025, 0.0025,\n",
" 0.0025, 0.0026, 0.0026, 0.0026, 0.0026, 0.0026, 0.0027, 0.0027, 0.0027,\n",
" 0.0027, 0.0027, 0.0027, 0.0028, 0.0028, 0.0028, 0.0028, 0.0028, 0.0029,\n",
" 0.0029, 0.0029, 0.0029, 0.0029, 0.0030, 0.0030, 0.0030, 0.0030, 0.0030,\n",
" 0.0031, 0.0031, 0.0031, 0.0031, 0.0032, 0.0032, 0.0032, 0.0032, 0.0032,\n",
" 0.0033, 0.0033, 0.0033, 0.0033, 0.0034, 0.0034, 0.0034, 0.0034, 0.0034,\n",
" 0.0035, 0.0035, 0.0035, 0.0035, 0.0036, 0.0036, 0.0036, 0.0036, 0.0037,\n",
" 0.0037, 0.0037, 0.0037, 0.0038, 0.0038, 0.0038, 0.0039, 0.0039, 0.0039,\n",
" 0.0039, 0.0040, 0.0040, 0.0040, 0.0040, 0.0041, 0.0041, 0.0041, 0.0042,\n",
" 0.0042, 0.0042, 0.0042, 0.0043, 0.0043, 0.0043, 0.0044, 0.0044, 0.0044,\n",
" 0.0045, 0.0045, 0.0045, 0.0045, 0.0046, 0.0046, 0.0046, 0.0047, 0.0047,\n",
" 0.0047, 0.0048, 0.0048, 0.0048, 0.0049, 0.0049, 0.0049, 0.0050, 0.0050,\n",
" 0.0050, 0.0051, 0.0051, 0.0051, 0.0052, 0.0052, 0.0053, 0.0053, 0.0053,\n",
" 0.0054, 0.0054, 0.0054, 0.0055, 0.0055, 0.0056, 0.0056, 0.0056, 0.0057,\n",
" 0.0057, 0.0058, 0.0058, 0.0058, 0.0059, 0.0059, 0.0060, 0.0060, 0.0060,\n",
" 0.0061, 0.0061, 0.0062, 0.0062, 0.0062, 0.0063, 0.0063, 0.0064, 0.0064,\n",
" 0.0065, 0.0065, 0.0066, 0.0066, 0.0067, 0.0067, 0.0067, 0.0068, 0.0068,\n",
" 0.0069, 0.0069, 0.0070, 0.0070, 0.0071, 0.0071, 0.0072, 0.0072, 0.0073,\n",
" 0.0073, 0.0074, 0.0074, 0.0075, 0.0075, 0.0076, 0.0076, 0.0077, 0.0077,\n",
" 0.0078, 0.0079, 0.0079, 0.0080, 0.0080, 0.0081, 0.0081, 0.0082, 0.0082,\n",
" 0.0083, 0.0084, 0.0084, 0.0085, 0.0085, 0.0086, 0.0086, 0.0087, 0.0088,\n",
" 0.0088, 0.0089, 0.0090, 0.0090, 0.0091, 0.0091, 0.0092, 0.0093, 0.0093,\n",
" 0.0094, 0.0095, 0.0095, 0.0096, 0.0097, 0.0097, 0.0098, 0.0099, 0.0099,\n",
" 0.0100, 0.0101, 0.0101, 0.0102, 0.0103, 0.0104, 0.0104, 0.0105, 0.0106,\n",
" 0.0106, 0.0107, 0.0108, 0.0109, 0.0109, 0.0110, 0.0111, 0.0112, 0.0112,\n",
" 0.0113, 0.0114, 0.0115, 0.0116, 0.0116, 0.0117, 0.0118, 0.0119, 0.0120,\n",
" 0.0121, 0.0121, 0.0122, 0.0123, 0.0124, 0.0125, 0.0126, 0.0127, 0.0127,\n",
" 0.0128, 0.0129, 0.0130, 0.0131, 0.0132, 0.0133, 0.0134, 0.0135, 0.0136,\n",
" 0.0137, 0.0137, 0.0138, 0.0139, 0.0140, 0.0141, 0.0142, 0.0143, 0.0144,\n",
" 0.0145, 0.0146, 0.0147, 0.0148, 0.0149, 0.0150, 0.0151, 0.0152, 0.0154,\n",
" 0.0155, 0.0156, 0.0157, 0.0158, 0.0159, 0.0160, 0.0161, 0.0162, 0.0163,\n",
" 0.0165, 0.0166, 0.0167, 0.0168, 0.0169, 0.0170, 0.0171, 0.0173, 0.0174,\n",
" 0.0175, 0.0176, 0.0178, 0.0179, 0.0180, 0.0181, 0.0182, 0.0184, 0.0185,\n",
" 0.0186, 0.0188, 0.0189, 0.0190, 0.0192, 0.0193, 0.0194, 0.0196, 0.0197,\n",
" 0.0198, 0.0200, 0.0201, 0.0202, 0.0204, 0.0205, 0.0207, 0.0208, 0.0210,\n",
" 0.0211, 0.0212, 0.0214, 0.0215, 0.0217, 0.0218, 0.0220, 0.0221, 0.0223,\n",
" 0.0225, 0.0226, 0.0228, 0.0229, 0.0231, 0.0232, 0.0234, 0.0236, 0.0237,\n",
" 0.0239, 0.0241, 0.0242, 0.0244, 0.0246, 0.0247, 0.0249, 0.0251, 0.0253,\n",
" 0.0254, 0.0256, 0.0258, 0.0260, 0.0261, 0.0263, 0.0265, 0.0267, 0.0269,\n",
" 0.0271, 0.0273, 0.0274, 0.0276, 0.0278, 0.0280, 0.0282, 0.0284, 0.0286,\n",
" 0.0288, 0.0290, 0.0292, 0.0294, 0.0296, 0.0298, 0.0300, 0.0302, 0.0304,\n",
" 0.0307, 0.0309, 0.0311, 0.0313, 0.0315, 0.0317, 0.0320, 0.0322, 0.0324,\n",
" 0.0326, 0.0328, 0.0331, 0.0333, 0.0335, 0.0338, 0.0340, 0.0342, 0.0345,\n",
" 0.0347, 0.0350, 0.0352, 0.0354, 0.0357, 0.0359, 0.0362, 0.0364, 0.0367,\n",
" 0.0369, 0.0372, 0.0375, 0.0377, 0.0380, 0.0382, 0.0385, 0.0388, 0.0390,\n",
" 0.0393, 0.0396, 0.0399, 0.0401, 0.0404, 0.0407, 0.0410, 0.0413, 0.0416,\n",
" 0.0418, 0.0421, 0.0424, 0.0427, 0.0430, 0.0433, 0.0436, 0.0439, 0.0442,\n",
" 0.0445, 0.0448, 0.0451, 0.0455, 0.0458, 0.0461, 0.0464, 0.0467, 0.0471,\n",
" 0.0474, 0.0477, 0.0480, 0.0484, 0.0487, 0.0491, 0.0494, 0.0497, 0.0501,\n",
" 0.0504, 0.0508, 0.0511, 0.0515, 0.0518, 0.0522, 0.0526, 0.0529, 0.0533,\n",
" 0.0537, 0.0540, 0.0544, 0.0548, 0.0552, 0.0556, 0.0559, 0.0563, 0.0567,\n",
" 0.0571, 0.0575, 0.0579, 0.0583, 0.0587, 0.0591, 0.0595, 0.0599, 0.0604,\n",
" 0.0608, 0.0612, 0.0616, 0.0621, 0.0625, 0.0629, 0.0634, 0.0638, 0.0642,\n",
" 0.0647, 0.0651, 0.0656, 0.0660, 0.0665, 0.0670, 0.0674, 0.0679, 0.0684,\n",
" 0.0688, 0.0693, 0.0698, 0.0703, 0.0708, 0.0713, 0.0718, 0.0723, 0.0728,\n",
" 0.0733, 0.0738, 0.0743, 0.0748, 0.0753, 0.0758, 0.0764, 0.0769, 0.0774,\n",
" 0.0780, 0.0785, 0.0790, 0.0796, 0.0802, 0.0807, 0.0813, 0.0818, 0.0824,\n",
" 0.0830, 0.0835, 0.0841, 0.0847, 0.0853, 0.0859, 0.0865, 0.0871, 0.0877,\n",
" 0.0883, 0.0889, 0.0895, 0.0901, 0.0908, 0.0914, 0.0920, 0.0927, 0.0933,\n",
" 0.0940, 0.0946, 0.0953, 0.0959, 0.0966, 0.0973, 0.0979, 0.0986, 0.0993,\n",
" 0.1000, 0.1007, 0.1014, 0.1021, 0.1028, 0.1035, 0.1042, 0.1050, 0.1057,\n",
" 0.1064, 0.1072, 0.1079, 0.1087, 0.1094, 0.1102, 0.1109, 0.1117, 0.1125,\n",
" 0.1133, 0.1140, 0.1148, 0.1156, 0.1164, 0.1172, 0.1181, 0.1189, 0.1197,\n",
" 0.1205, 0.1214, 0.1222, 0.1231, 0.1239, 0.1248, 0.1256, 0.1265, 0.1274,\n",
" 0.1283, 0.1292, 0.1301, 0.1310, 0.1319, 0.1328, 0.1337, 0.1346, 0.1356,\n",
" 0.1365, 0.1374, 0.1384, 0.1394, 0.1403, 0.1413, 0.1423, 0.1433, 0.1443,\n",
" 0.1453, 0.1463, 0.1473, 0.1483, 0.1493, 0.1504, 0.1514, 0.1525, 0.1535,\n",
" 0.1546, 0.1557, 0.1567, 0.1578, 0.1589, 0.1600, 0.1611, 0.1623, 0.1634,\n",
" 0.1645, 0.1657, 0.1668, 0.1680, 0.1691, 0.1703, 0.1715, 0.1727, 0.1739,\n",
" 0.1751, 0.1763, 0.1775, 0.1788, 0.1800, 0.1812, 0.1825, 0.1838, 0.1850,\n",
" 0.1863, 0.1876, 0.1889, 0.1902, 0.1916, 0.1929, 0.1942, 0.1956, 0.1969,\n",
" 0.1983, 0.1997, 0.2010, 0.2024, 0.2038, 0.2053, 0.2067, 0.2081, 0.2096,\n",
" 0.2110, 0.2125, 0.2140, 0.2154, 0.2169, 0.2184, 0.2200, 0.2215, 0.2230,\n",
" 0.2246, 0.2261, 0.2277, 0.2293, 0.2309, 0.2325, 0.2341, 0.2357, 0.2373,\n",
" 0.2390, 0.2406, 0.2423, 0.2440, 0.2457, 0.2474, 0.2491, 0.2508, 0.2526,\n",
" 0.2543, 0.2561, 0.2579, 0.2597, 0.2615, 0.2633, 0.2651, 0.2669, 0.2688,\n",
" 0.2707, 0.2725, 0.2744, 0.2763, 0.2783, 0.2802, 0.2821, 0.2841, 0.2861,\n",
" 0.2880, 0.2900, 0.2921, 0.2941, 0.2961, 0.2982, 0.3002, 0.3023, 0.3044,\n",
" 0.3065, 0.3087, 0.3108, 0.3130, 0.3151, 0.3173, 0.3195, 0.3217, 0.3240,\n",
" 0.3262, 0.3285, 0.3308, 0.3331, 0.3354, 0.3377, 0.3400, 0.3424, 0.3448,\n",
" 0.3472, 0.3496, 0.3520, 0.3544, 0.3569, 0.3594, 0.3619, 0.3644, 0.3669,\n",
" 0.3695, 0.3720, 0.3746, 0.3772, 0.3798, 0.3825, 0.3851, 0.3878, 0.3905,\n",
" 0.3932, 0.3959, 0.3987, 0.4014, 0.4042, 0.4070, 0.4098, 0.4127, 0.4155,\n",
" 0.4184, 0.4213, 0.4243, 0.4272, 0.4302, 0.4331, 0.4362, 0.4392, 0.4422,\n",
" 0.4453, 0.4484, 0.4515, 0.4546, 0.4578, 0.4610, 0.4642, 0.4674, 0.4706,\n",
" 0.4739, 0.4772, 0.4805, 0.4838, 0.4872, 0.4906, 0.4940, 0.4974, 0.5008,\n",
" 0.5043, 0.5078, 0.5113, 0.5149, 0.5185, 0.5221, 0.5257, 0.5293, 0.5330,\n",
" 0.5367, 0.5404, 0.5442, 0.5479, 0.5517, 0.5556, 0.5594, 0.5633, 0.5672,\n",
" 0.5712, 0.5751, 0.5791, 0.5831, 0.5872, 0.5913, 0.5954, 0.5995, 0.6036,\n",
" 0.6078, 0.6120, 0.6163, 0.6206, 0.6249, 0.6292, 0.6336, 0.6380, 0.6424,\n",
" 0.6469, 0.6513, 0.6559, 0.6604, 0.6650, 0.6696, 0.6743, 0.6789, 0.6837,\n",
" 0.6884, 0.6932, 0.6980, 0.7028, 0.7077, 0.7126, 0.7176, 0.7225, 0.7275,\n",
" 0.7326, 0.7377, 0.7428, 0.7480, 0.7531, 0.7584, 0.7636, 0.7689, 0.7743,\n",
" 0.7796, 0.7850, 0.7905, 0.7960, 0.8015, 0.8071, 0.8127, 0.8183, 0.8240,\n",
" 0.8297, 0.8355, 0.8412, 0.8471, 0.8530, 0.8589, 0.8648, 0.8708, 0.8769,\n",
" 0.8830, 0.8891, 0.8953, 0.9015, 0.9077, 0.9140, 0.9204, 0.9268, 0.9332,\n",
" 0.9397, 0.9462, 0.9528, 0.9594, 0.9660, 0.9727, 0.9795, 0.9863, 0.9931,\n",
" 1.0000])"
]
},
"execution_count": 79,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# hyperparameter tuning\n",
"lre = torch.linspace(-3, 0, 1000)\n",
"lrs = 10**lre # search between 0.001 to 1\n",
"lrs"
]
},
{
"cell_type": "code",
"execution_count": 80,
"id": "8d1e7860",
"metadata": {},
"outputs": [],
"source": [
"# build the dataset\n",
"\n",
"# context to take, input X \n",
"block_size = 3\n",
"X, Y = [], []\n",
"for w in words:\n",
" context = [0] * block_size\n",
" for ch in w + '.':\n",
" ix = stoi[ch]\n",
" X.append(context)\n",
" Y.append(ix)\n",
" context = context[1:] + [ix]\n",
" \n",
"X=torch.tensor(X)\n",
"Y=torch.tensor(Y)\n",
"\n",
"Vdim=27\n",
"Cdim=2 # hyperparamter\n",
"block_size = 3 # hyperparamter\n",
"n_input1 = Cdim * block_size\n",
"n_output1 = 100 # hyperparamter\n",
"n_input2 = n_output1\n",
"n_output2 = Vdim\n",
"\n",
"g = torch.Generator().manual_seed(2147483647)\n",
"C = torch.randn((Vdim, Cdim), generator=g)\n",
"W1 = torch.randn((n_input1,n_output1), generator=g)\n",
"b1 = torch.randn(n_output1, generator=g)\n",
"W2 = torch.randn((n_input2,n_output2), generator=g)\n",
"b2 = torch.randn(n_output2, generator=g)\n",
"parameters = [C, W1, b1, W2, b2]\n",
"\n",
"for p in parameters:\n",
" p.requires_grad=True"
]
},
{
"cell_type": "code",
"execution_count": 81,
"id": "4fca2cf3",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"20.058767318725586\n",
"18.386871337890625\n",
"20.938833236694336\n",
"18.30322265625\n",
"18.207124710083008\n",
"20.926156997680664\n",
"19.694604873657227\n",
"20.796886444091797\n",
"18.43265151977539\n",
"17.471906661987305\n",
"17.41988182067871\n",
"18.475698471069336\n",
"21.594234466552734\n",
"20.208314895629883\n",
"18.260589599609375\n",
"20.787334442138672\n",
"18.640947341918945\n",
"21.286579132080078\n",
"19.08941650390625\n",
"19.984092712402344\n",
"19.909046173095703\n",
"17.2227725982666\n",
"18.984329223632812\n",
"18.248714447021484\n",
"19.701366424560547\n",
"16.733842849731445\n",
"17.595561981201172\n",
"19.489112854003906\n",
"18.0367488861084\n",
"17.623903274536133\n",
"18.184213638305664\n",
"18.172489166259766\n",
"16.67631721496582\n",
"19.491979598999023\n",
"18.97934913635254\n",
"21.72498321533203\n",
"19.489974975585938\n",
"16.01494789123535\n",
"21.133420944213867\n",
"18.17195701599121\n",
"19.101388931274414\n",
"19.54134178161621\n",
"17.296428680419922\n",
"17.832054138183594\n",
"18.379337310791016\n",
"17.681211471557617\n",
"16.202953338623047\n",
"19.22760772705078\n",
"16.480175018310547\n",
"17.625823974609375\n",
"17.91048240661621\n",
"20.03180694580078\n",
"15.48385238647461\n",
"20.130367279052734\n",
"15.727002143859863\n",
"19.727458953857422\n",
"17.598100662231445\n",
"18.688796997070312\n",
"18.924501419067383\n",
"15.734272956848145\n",
"19.930004119873047\n",
"16.35295867919922\n",
"14.893880844116211\n",
"19.13114356994629\n",
"17.03234100341797\n",
"17.192468643188477\n",
"17.264965057373047\n",
"15.996824264526367\n",
"13.68630313873291\n",
"18.991613388061523\n",
"16.458696365356445\n",
"18.184864044189453\n",
"18.45516586303711\n",
"19.210384368896484\n",
"19.555805206298828\n",
"19.245872497558594\n",
"15.916244506835938\n",
"18.539262771606445\n",
"15.675761222839355\n",
"15.922590255737305\n",
"18.28230094909668\n",
"18.379098892211914\n",
"17.07571792602539\n",
"17.531736373901367\n",
"16.743484497070312\n",
"18.085182189941406\n",
"19.677038192749023\n",
"16.05499267578125\n",
"16.966365814208984\n",
"16.9460391998291\n",
"16.56806755065918\n",
"17.405820846557617\n",
"17.567480087280273\n",
"13.594080924987793\n",
"15.556037902832031\n",
"13.114115715026855\n",
"13.831795692443848\n",
"15.224092483520508\n",
"16.704626083374023\n",
"14.737277030944824\n",
"18.403032302856445\n",
"16.741727828979492\n",
"15.251632690429688\n",
"15.746190071105957\n",
"15.956998825073242\n",
"21.602508544921875\n",
"14.20954418182373\n",
"14.598978042602539\n",
"17.341718673706055\n",
"16.25075340270996\n",
"15.769144058227539\n",
"16.367794036865234\n",
"16.197172164916992\n",
"18.48675537109375\n",
"17.168405532836914\n",
"14.934187889099121\n",
"17.833589553833008\n",
"17.77256202697754\n",
"17.94159698486328\n",
"17.360082626342773\n",
"14.95344352722168\n",
"13.779696464538574\n",
"17.619510650634766\n",
"17.39665985107422\n",
"17.34027671813965\n",
"15.9515380859375\n",
"18.24655532836914\n",
"14.980691909790039\n",
"15.100189208984375\n",
"18.663522720336914\n",
"14.429710388183594\n",
"14.762701988220215\n",
"15.34671401977539\n",
"15.603941917419434\n",
"17.844085693359375\n",
"14.396878242492676\n",
"16.44132423400879\n",
"14.74105453491211\n",
"14.460693359375\n",
"14.309600830078125\n",
"14.838363647460938\n",
"16.653427124023438\n",
"13.572357177734375\n",
"14.395666122436523\n",
"18.12348175048828\n",
"14.245956420898438\n",
"15.573516845703125\n",
"15.885431289672852\n",
"14.404387474060059\n",
"17.079296112060547\n",
"15.201760292053223\n",
"15.457757949829102\n",
"13.696967124938965\n",
"15.21992015838623\n",
"13.501822471618652\n",
"15.713217735290527\n",
"17.039230346679688\n",
"15.352463722229004\n",
"13.331183433532715\n",
"14.037686347961426\n",
"14.475912094116211\n",
"15.033218383789062\n",
"14.212654113769531\n",
"15.011879920959473\n",
"13.989218711853027\n",
"14.684889793395996\n",
"16.003490447998047\n",
"13.3534517288208\n",
"18.113367080688477\n",
"15.766525268554688\n",
"13.79685115814209\n",
"15.681324005126953\n",
"14.373132705688477\n",
"13.612071990966797\n",
"14.888436317443848\n",
"13.91701602935791\n",
"14.809433937072754\n",
"15.310776710510254\n",
"14.299510955810547\n",
"15.394023895263672\n",
"16.06818962097168\n",
"13.481806755065918\n",
"14.091451644897461\n",
"17.476533889770508\n",
"13.11625862121582\n",
"15.260262489318848\n",
"14.466033935546875\n",
"14.652369499206543\n",
"15.203973770141602\n",
"12.611348152160645\n",
"13.802355766296387\n",
"13.701678276062012\n",
"12.002172470092773\n",
"17.47108268737793\n",
"12.778606414794922\n",
"16.25493621826172\n",
"12.363943099975586\n",
"14.833003997802734\n",
"12.832207679748535\n",
"15.066575050354004\n",
"14.566367149353027\n",
"14.7083158493042\n",
"11.102749824523926\n",
"15.345829963684082\n",
"14.21712875366211\n",
"10.669098854064941\n",
"14.981069564819336\n",
"12.768899917602539\n",
"14.53537368774414\n",
"15.360010147094727\n",
"12.52229118347168\n",
"15.298910140991211\n",
"12.180039405822754\n",
"15.207792282104492\n",
"12.739935874938965\n",
"14.423806190490723\n",
"12.861767768859863\n",
"14.22921085357666\n",
"10.41986083984375\n",
"13.890667915344238\n",
"10.88259220123291\n",
"13.321094512939453\n",
"11.663039207458496\n",
"11.932415008544922\n",
"10.430420875549316\n",
"10.10925579071045\n",
"11.853407859802246\n",
"10.79029369354248\n",
"14.451656341552734\n",
"14.122856140136719\n",
"13.67242431640625\n",
"13.091327667236328\n",
"13.005173683166504\n",
"12.685237884521484\n",
"13.005620956420898\n",
"13.553939819335938\n",
"12.267475128173828\n",
"14.190502166748047\n",
"11.30736255645752\n",
"10.914015769958496\n",
"12.445353507995605\n",
"10.342656135559082\n",
"12.580787658691406\n",
"10.81489372253418\n",
"10.514963150024414\n",
"12.052750587463379\n",
"12.980908393859863\n",
"12.380534172058105\n",
"11.733548164367676\n",
"14.590177536010742\n",
"11.620329856872559\n",
"11.907234191894531\n",
"13.085349082946777\n",
"14.325329780578613\n",
"10.8388032913208\n",
"14.467137336730957\n",
"9.308248519897461\n",
"8.384876251220703\n",
"12.706113815307617\n",
"14.621389389038086\n",
"13.659004211425781\n",
"12.338562965393066\n",
"14.114700317382812\n",
"11.619922637939453\n",
"10.123787879943848\n",
"11.096317291259766\n",
"12.064632415771484\n",
"12.706136703491211\n",
"11.947173118591309\n",
"13.065418243408203\n",
"13.177108764648438\n",
"11.022246360778809\n",
"11.563748359680176\n",
"10.7993803024292\n",
"12.816320419311523\n",
"10.871781349182129\n",
"10.823201179504395\n",
"11.811736106872559\n",
"11.892790794372559\n",
"10.869131088256836\n",
"11.294859886169434\n",
"9.91629409790039\n",
"10.768162727355957\n",
"9.785820960998535\n",
"11.019686698913574\n",
"14.884580612182617\n",
"8.939095497131348\n",
"11.121686935424805\n",
"9.681001663208008\n",
"14.208605766296387\n",
"10.606358528137207\n",
"10.238022804260254\n",
"14.800966262817383\n",
"9.842073440551758\n",
"10.831656455993652\n",
"11.870290756225586\n",
"11.204913139343262\n",
"11.685627937316895\n",
"14.99278736114502\n",
"9.216063499450684\n",
"9.720016479492188\n",
"10.127654075622559\n",
"10.670577049255371\n",
"11.96092700958252\n",
"10.001982688903809\n",
"9.47696590423584\n",
"10.871495246887207\n",
"9.49959659576416\n",
"7.812280654907227\n",
"10.504291534423828\n",
"9.808164596557617\n",
"12.335737228393555\n",
"8.316473007202148\n",
"10.166391372680664\n",
"11.575448036193848\n",
"13.462331771850586\n",
"7.442794322967529\n",
"9.652787208557129\n",
"9.31308364868164\n",
"8.436732292175293\n",
"9.656270027160645\n",
"10.529017448425293\n",
"10.405807495117188\n",
"9.85356330871582\n",
"8.51788330078125\n",
"10.666929244995117\n",
"9.939424514770508\n",
"10.175727844238281\n",
"10.883666038513184\n",
"8.571637153625488\n",
"8.996121406555176\n",
"9.41479206085205\n",
"9.743539810180664\n",
"10.75446605682373\n",
"9.327347755432129\n",
"10.134511947631836\n",
"8.580121040344238\n",
"11.218292236328125\n",
"9.430485725402832\n",
"8.714287757873535\n",
"8.426252365112305\n",
"10.338098526000977\n",
"11.220523834228516\n",
"10.878255844116211\n",
"9.532553672790527\n",
"8.539838790893555\n",
"7.672667980194092\n",
"10.36310863494873\n",
"10.894793510437012\n",
"9.161340713500977\n",
"10.406136512756348\n",
"8.765401840209961\n",
"11.612442016601562\n",
"8.186119079589844\n",
"11.474461555480957\n",
"9.826943397521973\n",
"7.516147136688232\n",
"9.676780700683594\n",
"7.89693546295166\n",
"8.4464693069458\n",
"9.113298416137695\n",
"9.641304969787598\n",
"9.967488288879395\n",
"10.455885887145996\n",
"8.237849235534668\n",
"7.047585487365723\n",
"7.616254806518555\n",
"9.337509155273438\n",
"9.760736465454102\n",
"8.914117813110352\n",
"8.533182144165039\n",
"8.845060348510742\n",
"8.286819458007812\n",
"7.084433555603027\n",
"8.056537628173828\n",
"7.038687705993652\n",
"10.185018539428711\n",
"9.052075386047363\n",
"8.2262601852417\n",
"7.8366827964782715\n",
"8.655904769897461\n",
"8.646445274353027\n",
"10.069315910339355\n",
"8.779277801513672\n",
"8.074405670166016\n",
"7.853659152984619\n",
"8.303882598876953\n",
"8.005335807800293\n",
"6.686523914337158\n",
"5.987373352050781\n",
"7.754853248596191\n",
"9.928034782409668\n",
"6.959237098693848\n",
"7.998958587646484\n",
"7.160789489746094\n",
"9.51343059539795\n",
"7.81441068649292\n",
"9.714046478271484\n",
"9.615175247192383\n",
"7.681924819946289\n",
"8.25478744506836\n",
"8.42897891998291\n",
"8.113512992858887\n",
"7.563684463500977\n",
"8.473864555358887\n",
"6.3443403244018555\n",
"9.098320007324219\n",
"7.37877893447876\n",
"6.578047275543213\n",
"7.377756118774414\n",
"8.60280990600586\n",
"7.422914981842041\n",
"7.379459857940674\n",
"8.935554504394531\n",
"6.389432430267334\n",
"8.996026039123535\n",
"8.94339370727539\n",
"5.5102949142456055\n",
"6.963716983795166\n",
"11.425223350524902\n",
"7.493052959442139\n",
"8.636194229125977\n",
"6.408836841583252\n",
"8.882610321044922\n",
"6.552365779876709\n",
"6.99999475479126\n",
"8.897416114807129\n",
"7.332728862762451\n",
"6.513792991638184\n",
"7.331733703613281\n",
"6.8848090171813965\n",
"8.904072761535645\n",
"6.285671710968018\n",
"6.721429824829102\n",
"8.188508987426758\n",
"8.023738861083984\n",
"8.223982810974121\n",
"7.631497859954834\n",
"6.868099689483643\n",
"7.845915794372559\n",
"8.994128227233887\n",
"6.873394012451172\n",
"6.575554847717285\n",
"5.681668758392334\n",
"9.183479309082031\n",
"7.379150390625\n",
"6.419661045074463\n",
"6.1244611740112305\n",
"7.120537281036377\n",
"7.353793144226074\n",
"6.679843902587891\n",
"6.7223920822143555\n",
"6.166895866394043\n",
"6.358139514923096\n",
"8.206326484680176\n",
"4.982204437255859\n",
"7.02251672744751\n",
"6.5761237144470215\n",
"6.247386932373047\n",
"5.406505107879639\n",
"5.372550010681152\n",
"5.643982410430908\n",
"4.830392837524414\n",
"5.968005180358887\n",
"6.492893218994141\n",
"4.8271164894104\n",
"7.547109603881836\n",
"5.955235481262207\n",
"6.241334438323975\n",
"7.30716609954834\n",
"7.0900068283081055\n",
"6.036428928375244\n",
"7.084930896759033\n",
"4.779949188232422\n",
"7.399024963378906\n",
"6.500931739807129\n",
"5.981712818145752\n",
"5.332479953765869\n",
"5.2537617683410645\n",
"5.506017208099365\n",
"6.235259056091309\n",
"6.222657680511475\n",
"4.586050510406494\n",
"6.005897521972656\n",
"6.667596817016602\n",
"5.921751499176025\n",
"5.109879016876221\n",
"3.7901201248168945\n",
"4.821964740753174\n",
"5.593708038330078\n",
"6.004891395568848\n",
"4.737380027770996\n",
"6.811749458312988\n",
"5.370182514190674\n",
"6.791818618774414\n",
"5.8135175704956055\n",
"6.2312235832214355\n",
"5.7961883544921875\n",
"5.895395755767822\n",
"7.342142105102539\n",
"5.725778102874756\n",
"4.989012718200684\n",
"4.228612422943115\n",
"4.350794315338135\n",
"5.3611674308776855\n",
"4.481586456298828\n",
"5.720825672149658\n",
"5.554401874542236\n",
"5.9270429611206055\n",
"4.660965919494629\n",
"5.52239465713501\n",
"5.524663925170898\n",
"4.811008930206299\n",
"5.698330879211426\n",
"6.13444185256958\n",
"4.305564880371094\n",
"4.850207328796387\n",
"4.782961845397949\n",
"5.758709907531738\n",
"5.040655136108398\n",
"4.9152045249938965\n",
"5.848385334014893\n",
"4.637724876403809\n",
"6.089776515960693\n",
"4.190829753875732\n",
"4.483164310455322\n",
"5.096348285675049\n",
"4.175036430358887\n",
"4.132568836212158\n",
"5.668691635131836\n",
"5.299254894256592\n",
"3.9058778285980225\n",
"6.583087921142578\n",
"5.95508337020874\n",
"5.55558967590332\n",
"3.896312952041626\n",
"5.0055389404296875\n",
"3.8073885440826416\n",
"4.782783031463623\n",
"5.341424942016602\n",
"5.220652103424072\n",
"4.3145551681518555\n",
"4.867645263671875\n",
"4.317051887512207\n",
"4.742574214935303\n",
"4.937697410583496\n",
"5.328518867492676\n",
"4.958920955657959\n",
"3.481821298599243\n",
"4.7215189933776855\n",
"5.35048770904541\n",
"4.197367191314697\n",
"4.604339122772217\n",
"4.23129940032959\n",
"4.756180763244629\n",
"4.884371757507324\n",
"5.155618190765381\n",
"4.574168682098389\n",
"3.391645669937134\n",
"3.3872737884521484\n",
"4.672423362731934\n",
"4.533355712890625\n",
"5.220793724060059\n",
"3.8227334022521973\n",
"3.7068281173706055\n",
"4.240579128265381\n",
"4.195023059844971\n",
"4.6160888671875\n",
"4.480436325073242\n",
"4.624906539916992\n",
"4.456013202667236\n",
"5.195305824279785\n",
"3.5553925037384033\n",
"3.7349674701690674\n",
"4.430811405181885\n",
"4.551756858825684\n",
"3.7298829555511475\n",
"4.410430431365967\n",
"3.553434371948242\n",
"3.4611258506774902\n",
"4.55808687210083\n",
"3.088973045349121\n",
"4.4551920890808105\n",
"3.5407283306121826\n",
"4.2899980545043945\n",
"4.096808433532715\n",
"3.5416300296783447\n",
"4.437430381774902\n",
"3.881909132003784\n",
"3.404031753540039\n",
"4.887137413024902\n",
"3.174898386001587\n",
"3.9722282886505127\n",
"4.0994086265563965\n",
"4.036005020141602\n",
"3.7181835174560547\n",
"3.022531747817993\n",
"4.7746100425720215\n",
"3.107983112335205\n",
"3.581955909729004\n",
"4.060939311981201\n",
"3.075693130493164\n",
"4.609142780303955\n",
"3.3816468715667725\n",
"5.256540298461914\n",
"3.156243324279785\n",
"4.870355129241943\n",
"4.980332851409912\n",
"4.01073694229126\n",
"3.294806957244873\n",
"4.303383827209473\n",
"3.302541732788086\n",
"3.692582845687866\n",
"4.456576824188232\n",
"3.896005630493164\n",
"3.8528528213500977\n",
"4.296661853790283\n",
"3.214301347732544\n",
"3.0818440914154053\n",
"3.5100314617156982\n",
"3.1930394172668457\n",
"2.9393343925476074\n",
"2.919846773147583\n",
"3.8025946617126465\n",
"2.892787456512451\n",
"4.956643581390381\n",
"2.888521671295166\n",
"3.342435359954834\n",
"3.0627057552337646\n",
"4.08694314956665\n",
"3.3463854789733887\n",
"4.060541152954102\n",
"2.9815173149108887\n",
"3.2821357250213623\n",
"3.0937447547912598\n",
"3.0083119869232178\n",
"3.954728126525879\n",
"3.924070119857788\n",
"3.2576563358306885\n",
"3.7045695781707764\n",
"3.2489681243896484\n",
"2.9568793773651123\n",
"3.5199520587921143\n",
"3.453016996383667\n",
"3.802501678466797\n",
"3.637922763824463\n",
"3.216110944747925\n",
"3.706007242202759\n",
"3.984416961669922\n",
"3.632723569869995\n",
"3.1989638805389404\n",
"3.6497514247894287\n",
"3.1191749572753906\n",
"4.295158386230469\n",
"2.5521702766418457\n",
"3.194572925567627\n",
"3.4275803565979004\n",
"3.4875407218933105\n",
"3.0968613624572754\n",
"2.905925750732422\n",
"2.9836156368255615\n",
"3.2466132640838623\n",
"3.4357643127441406\n",
"3.062093734741211\n",
"3.6897025108337402\n",
"4.675938129425049\n",
"3.194821357727051\n",
"3.059985876083374\n",
"2.917567491531372\n",
"3.519122838973999\n",
"3.1748576164245605\n",
"3.006070852279663\n",
"2.9797580242156982\n",
"3.3550615310668945\n",
"3.930544376373291\n",
"2.6503937244415283\n",
"3.325307607650757\n",
"3.9021999835968018\n",
"2.6885509490966797\n",
"3.1735334396362305\n",
"3.559260368347168\n",
"3.3719849586486816\n",
"3.2907395362854004\n",
"3.0547420978546143\n",
"3.8450896739959717\n",
"3.3199620246887207\n",
"3.0020692348480225\n",
"3.428286552429199\n",
"3.70822811126709\n",
"3.3926849365234375\n",
"3.313302516937256\n",
"2.6763691902160645\n",
"2.7028677463531494\n",
"3.20300555229187\n",
"3.1753597259521484\n",
"3.429269790649414\n",
"3.1132779121398926\n",
"2.4863622188568115\n",
"3.024989604949951\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.143970251083374\n",
"2.5822694301605225\n",
"3.616262674331665\n",
"3.358060598373413\n",
"3.2837748527526855\n",
"3.4139809608459473\n",
"3.3856797218322754\n",
"2.7517240047454834\n",
"2.959219455718994\n",
"3.2085742950439453\n",
"3.230847120285034\n",
"3.00468111038208\n",
"2.6261143684387207\n",
"3.411679983139038\n",
"2.6447250843048096\n",
"2.7005319595336914\n",
"3.232732057571411\n",
"2.8750433921813965\n",
"3.520543098449707\n",
"2.8991739749908447\n",
"2.8904154300689697\n",
"3.335700273513794\n",
"3.2348315715789795\n",
"3.2892212867736816\n",
"3.896378993988037\n",
"2.8900647163391113\n",
"3.0431835651397705\n",
"3.039491891860962\n",
"3.037858009338379\n",
"3.210955858230591\n",
"4.34665584564209\n",
"2.8819639682769775\n",
"3.726290225982666\n",
"3.7548255920410156\n",
"3.493394613265991\n",
"3.512141227722168\n",
"3.016710042953491\n",
"2.9070377349853516\n",
"3.4122490882873535\n",
"3.086726188659668\n",
"3.199303388595581\n",
"3.335646390914917\n",
"3.1801059246063232\n",
"3.533473491668701\n",
"3.0448317527770996\n",
"2.899057149887085\n",
"3.0487630367279053\n",
"3.3654282093048096\n",
"3.331571578979492\n",
"2.6786227226257324\n",
"3.215420722961426\n",
"2.7102489471435547\n",
"3.626858949661255\n",
"3.428629159927368\n",
"2.43825101852417\n",
"2.8668174743652344\n",
"3.504131555557251\n",
"2.7836339473724365\n",
"3.1075804233551025\n",
"3.2532730102539062\n",
"3.036126136779785\n",
"2.777801275253296\n",
"2.84604549407959\n",
"2.948404312133789\n",
"3.195976734161377\n",
"3.2230095863342285\n",
"3.0998973846435547\n",
"3.718081474304199\n",
"2.800168037414551\n",
"3.2489216327667236\n",
"2.8199150562286377\n",
"3.6910431385040283\n",
"3.173072338104248\n",
"2.7641613483428955\n",
"3.77432918548584\n",
"3.624133825302124\n",
"3.227541446685791\n",
"3.183627128601074\n",
"3.283186435699463\n",
"3.057158946990967\n",
"2.777653694152832\n",
"2.889397144317627\n",
"3.3341434001922607\n",
"3.4635066986083984\n",
"2.904613971710205\n",
"3.0934841632843018\n",
"3.5386648178100586\n",
"3.4154884815216064\n",
"3.043239116668701\n",
"3.262671947479248\n",
"4.166245937347412\n",
"2.961470127105713\n",
"2.8482401371002197\n",
"3.534893274307251\n",
"3.1096999645233154\n",
"2.7225348949432373\n",
"3.547924041748047\n",
"3.745229721069336\n",
"2.809553861618042\n",
"3.616903305053711\n",
"3.6580862998962402\n",
"3.534134864807129\n",
"2.620283365249634\n",
"2.9316813945770264\n",
"3.368833541870117\n",
"3.4300267696380615\n",
"3.5138981342315674\n",
"3.6709792613983154\n",
"3.0291519165039062\n",
"3.6491801738739014\n",
"4.69722318649292\n",
"3.6200473308563232\n",
"3.425041913986206\n",
"3.839419364929199\n",
"3.556222677230835\n",
"3.6515636444091797\n",
"3.271301746368408\n",
"3.6069674491882324\n",
"3.619744300842285\n",
"3.2222707271575928\n",
"3.6004958152770996\n",
"3.295170545578003\n",
"3.294151782989502\n",
"3.963306188583374\n",
"3.9888854026794434\n",
"3.395801544189453\n",
"2.7327609062194824\n",
"4.557826042175293\n",
"3.7835001945495605\n",
"4.577127456665039\n",
"3.4779865741729736\n",
"3.5828535556793213\n",
"3.474299192428589\n",
"3.554988145828247\n",
"3.217632532119751\n",
"3.7217812538146973\n",
"3.207597494125366\n",
"3.381564140319824\n",
"3.94753098487854\n",
"4.8308258056640625\n",
"5.170791149139404\n",
"3.3575708866119385\n",
"4.269369125366211\n",
"3.247758626937866\n",
"4.41501522064209\n",
"4.021042823791504\n",
"3.817215919494629\n",
"3.2845804691314697\n",
"4.472692489624023\n",
"4.064505100250244\n",
"3.219364881515503\n",
"3.844644784927368\n",
"4.002055644989014\n",
"3.275210380554199\n",
"3.7533915042877197\n",
"4.884897708892822\n",
"3.598559617996216\n",
"4.256794452667236\n",
"3.9089903831481934\n",
"3.7389872074127197\n",
"4.224829196929932\n",
"3.9468448162078857\n",
"3.9633235931396484\n",
"3.9781320095062256\n",
"4.462779521942139\n",
"3.4086098670959473\n",
"4.164345741271973\n",
"8.31614875793457\n",
"5.93451452255249\n",
"6.095202922821045\n",
"4.018470287322998\n",
"3.7814743518829346\n",
"4.7484130859375\n",
"4.075839042663574\n",
"4.838449001312256\n",
"4.355576992034912\n",
"4.112987518310547\n",
"4.829951286315918\n",
"4.8745436668396\n",
"4.683174133300781\n",
"4.704440593719482\n",
"3.4863169193267822\n",
"5.267629146575928\n",
"4.023910999298096\n",
"3.553920030593872\n",
"4.033780097961426\n",
"5.377976417541504\n",
"4.1329731941223145\n",
"5.838648796081543\n",
"4.945495128631592\n",
"4.220717906951904\n",
"6.133040904998779\n",
"4.624537944793701\n",
"6.009963512420654\n",
"4.677973747253418\n",
"4.602327823638916\n",
"3.3685643672943115\n",
"5.05299186706543\n",
"4.9118242263793945\n",
"4.44879674911499\n",
"5.753711223602295\n",
"5.209222316741943\n",
"4.082849025726318\n",
"4.3265061378479\n",
"4.385910987854004\n",
"3.9758224487304688\n",
"4.269911766052246\n",
"4.9230637550354\n",
"4.7449116706848145\n",
"4.327792644500732\n",
"3.6670405864715576\n",
"4.678028583526611\n",
"3.702824592590332\n",
"4.078690052032471\n",
"4.826587200164795\n",
"5.305018424987793\n",
"7.958984375\n",
"4.998692035675049\n",
"3.7893269062042236\n",
"3.986933708190918\n",
"3.4457900524139404\n",
"2.8971259593963623\n",
"4.211317539215088\n",
"4.968047618865967\n",
"5.523991107940674\n",
"4.809507846832275\n",
"5.05592679977417\n",
"6.213912487030029\n",
"6.041927337646484\n",
"6.665318012237549\n",
"6.162818431854248\n",
"5.069524765014648\n",
"5.342731475830078\n",
"6.58087682723999\n",
"8.372491836547852\n",
"7.336589336395264\n",
"8.737000465393066\n",
"8.807034492492676\n",
"7.782583236694336\n",
"5.720577239990234\n",
"7.279583930969238\n",
"6.570875644683838\n",
"6.362747669219971\n",
"7.10400915145874\n",
"5.268760681152344\n",
"5.338332653045654\n",
"5.195115566253662\n",
"6.703448295593262\n",
"5.327396869659424\n",
"5.858665943145752\n",
"7.968378067016602\n",
"7.84922981262207\n",
"6.56950569152832\n",
"7.11574649810791\n",
"7.103545188903809\n",
"8.86844253540039\n",
"7.60360860824585\n",
"6.428305149078369\n",
"5.785921573638916\n",
"7.077493190765381\n",
"7.3251953125\n",
"6.634866237640381\n",
"5.36044979095459\n",
"8.364887237548828\n",
"8.705479621887207\n",
"6.498984336853027\n",
"7.325703144073486\n",
"8.187382698059082\n",
"8.92408561706543\n",
"9.026152610778809\n",
"7.936668872833252\n",
"5.978590488433838\n",
"6.503759860992432\n",
"5.820070743560791\n",
"6.012407302856445\n",
"6.347898483276367\n",
"5.945157527923584\n",
"10.594411849975586\n",
"7.612778186798096\n",
"8.65031909942627\n",
"8.221137046813965\n",
"5.586372375488281\n",
"6.739910125732422\n",
"6.437129974365234\n",
"6.807755947113037\n",
"7.184710502624512\n",
"7.534907341003418\n",
"7.9743194580078125\n",
"11.786186218261719\n",
"8.150350570678711\n",
"9.01694393157959\n",
"9.777725219726562\n",
"10.811105728149414\n",
"10.679527282714844\n",
"8.601481437683105\n",
"12.968698501586914\n",
"7.7538743019104\n",
"8.0514554977417\n",
"8.91894245147705\n",
"8.04653263092041\n",
"8.193402290344238\n"
]
}
],
"source": [
"lri = []\n",
"lossi = []\n",
"\n",
"for i in range(1000):\n",
" # mini-batch\n",
" ix = torch.randint(0, X.shape[0], (batch_size,))\n",
" \n",
" # forward pass\n",
" # only grab 32 rows\n",
" emb = C[X[ix]] # (32,3,2)\n",
" h = torch.tanh(emb.view(-1,n_input1) @ W1 + b1)\n",
" logits = h @ W2 + b2\n",
" loss = F.cross_entropy(logits, Y[ix])\n",
" # backward pass\n",
"\n",
" # initalize all gradients\n",
" for p in parameters:\n",
" p.grad = None\n",
" \n",
" loss.backward()\n",
" \n",
" lr = lrs[i]\n",
" # update the parameters\n",
" for p in parameters:\n",
" p.data += -lr * p.grad\n",
" \n",
" lri.append(lre[i])\n",
" lossi.append(loss)\n",
" \n",
" print (loss.item())"
]
},
{
"cell_type": "code",
"execution_count": 82,
"id": "2efaea53",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x129ff3518>]"
]
},
"execution_count": 82,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXoAAAD6CAYAAACvZ4z8AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAA+K0lEQVR4nO2deZgU1dX/P2c2BoZdENlkk0VFBRwR9x0RjChJjCYxbtGY6Bs1Rn+ocYkaxayvvpoY4p4Y475EUEFFBQRhUHZQFgcZ1oFhX2amp+/vj67qqe6u6r2H7p7zeR4eum/dqro13f2tU+eee44YY1AURVHyl4IDPQBFURQls6jQK4qi5Dkq9IqiKHmOCr2iKEqeo0KvKIqS56jQK4qi5DkxhV5EeorINBFZKiJLRORGq/0PIrJcRBaKyBsi0t5j/0oRWSQi80WkIs3jVxRFUWIgseLoRaQr0NUY84WItAHmARcCPYCPjDE+EXkYwBjz/1z2rwTKjTFb4h1Up06dTO/evePtriiK0uyZN2/eFmNMZ7dtRbF2NsZsADZYr3eJyDKguzFmiqPbbOB76RgsQO/evamoUONfURQlXkRkjde2hHz0ItIbGAp8HrbpKuBdj90MMEVE5onItVGOfa2IVIhIRXV1dSLDUhRFUaIQt9CLSGvgNeAmY8xOR/udgA94wWPXk40xw4DzgOtF5FS3TsaYicaYcmNMeefOrk8fiqIoShLEJfQiUkxA5F8wxrzuaL8COB/4kfFw9htj1ln/bwbeAIanOGZFURQlAeKJuhHgKWCZMebPjvZRwG3ABcaYvR77llkTuIhIGTASWJyOgSuKoijxEY9FfxJwGXCmFSI5X0RGA48BbYCpVtsTACLSTUQmW/t2AWaIyAJgDjDJGPNe+i9DURRF8SKeqJsZgLhsmuzShjFmPTDaer0aOCaVASqKoiipoStjFUVR8pxmKfR1Pj8vV6xFi64oitIcyGuh37W/3rX98Wkrue3Vhby9YH1az7d553721zek9ZiKoiipkrdCP3nRBo66dwoLq7ZHbNuyuxaAnfvcbwTJMvzBD/npc7qiV1GU7CJvhX76ikBqnYVVOyK2iTW1nAnHzYyVcaf0URRFaRLyVuiLCgJq7nfxw4trEJGiKEp+krdCX2gJ/d1vLeGyp8JT8wTQuVhFUZoDeSv0BdJotdtuHBt700fLNzflkBRFUQ4IeSv0hVGuzL4FfPJ1Ncs27PTuqCiKkgfEXBmbS9T6Gpi6dBMDurShoCA+P3y6I28URVGyjbwS+ocmL+fZzyoBuOS4np79xOHW8afJT6+LrxRFyVbyynXzbU1jEs3/zF0b1z4mTUGW6bphKIqipJu8EvqkUIteUZQ8p1kKvcNzk7ZFUyrziqJkK81S6J2kyxB3W5ilKIqSDTRLoXeujE2Xj151XlGUbKV5Cr3TdaMCrShKnhNPzdieIjJNRJaKyBIRudFq7ygiU0VkhfV/B4/9L7f6rBCRy9N9AcngjLB36vz++gbqfP6kjqmuG0VRspV4LHofcIsx5ghgBHC9iBwBjAc+NMb0Bz603ocgIh2Be4DjgeHAPV43hHSQTEoDO1rGGMOgu97jjD9+nNS5VecVRclWYgq9MWaDMeYL6/UuYBnQHRgLPGd1ew640GX3c4GpxpgaY8w2YCowKg3jTpjNO/cHX7u5bv7vo5UArNu+D2MM/gQD41XnFUXJVhLy0YtIb2Ao8DnQxRizwdq0Eejiskt3wLlyqcpqczv2tSJSISIV1dXViQwrLq58dq5ruz0Z++aX64Jt33lsBn3vcK197om6bhRFyVbiFnoRaQ28BtxkjAnJBGYC/o+UlM4YM9EYU26MKe/cuXMqh3Jl3fZ9wdfOFAi2PjuFevG6xBOdqc4ripKtxCX0IlJMQORfMMa8bjVvEpGu1vaugJuDfB3gTDrTw2prcrbvrWfip6uAsMnYoNDHf6z5a7ezYO320EYVekVRspR4om4EeApYZoz5s2PT24AdRXM58JbL7u8DI0WkgzUJO9JqOyA8OHl5RFt9QyDKJprrxdfg55ste4LvL3x8JmMfnxnSR103iqJkK/FY9CcBlwFnish8699oYAJwjoisAM623iMi5SLyJIAxpga4H5hr/bvPajtgzFixhTfnNz5U/PyFL6jz+aO6Xh5+bzln/PFjNu7Y79lHZV5RlGwlZppiY8wM8CyyepZL/wrgp473TwNPJzvAdPNjl7KC9Q3+qBa5XaFq8679HNKu1LWPWvSKomQrzXJlbDgNxkQVanvyduueOs8+qvOKomQrKvTAn6d8zaadtZ7b7WJV1bu8+6QrZ46iKEq6UaGHYFWqcG57dQEQSI0AjRO3bqhFryhKtqJCH4WXK6qYsWILq62Im2ghmCr0iqJkK3lVMzYTOCdvo1WRUteNoijZSrO06IsKvIKIotMQxaTXmrGKomQrzVLok9Xk6K6b6Edt8Bt8UXz8iqIomaJZCn2yMe9RXTcxDnnRX2dy2J3vJnVeRVGUVGiWQp/sxOkDk5Z5ir2zeXetj9XVu0O2L6zakdxJFUVRUqRZCn0qrHdJg7B1dy3TVzamVr786Tmc+adPmnJYiqIonuRV1E3HshJqoqxeTQcnTfgoou0nT89hyfrG1Mbz1mwDAq4eZ0pkRVGUA0FeWfSf3Hr6ATnvis27XdvtKJ3X5lU15XAURVFCyCuhb1NazJHd2oa0vfbzEzN+3kIPq91nCf0tryxw3f7szG/YvMs9I+bCqu0sXZ94ARRFUZRw8kroAd75n5O5f+yRAPTo0JJje2WsFnmQQo+4/Loo4ZQrN+/m3v8u5YZ/f+m6/YLHZjL60elpGZ+iKM2bvBN6EaFFUWETn9O9/f7/LmWRR7SNzx+4CezYW5+pYSmKogB5KPQALYoDl9VU+We8LPpX5lVx3b/mhbTd9eZi1jvq1yqKomSamFE3IvI0cD6w2Rgz2Gp7CRhodWkPbDfGDHHZtxLYBTQAPmNMeVpGHYMWRU17/wr30Ys03mTKWoQ+Xfxz9hoqt+7hzjGHN9XwFEVp5sQTXvks8BjwvN1gjPmB/VpE/gREWw10hjFmS7IDTAbbdRMrLUG6KAiz6AtEaLDOXdLENx1FUZRw4ikl+KmI9HbbZhUOvxg4M83jSommsuh/9ORsvjusB+Gem0IRGqyMOovXRUbOaGy9oihNSaqKeAqwyRizwmO7AaaIyDwRuTbFc8WN7aPfXevL6HlmrtzKr15eQIGL6yYaycj82pq99B4/iYVV29m2p46h901h/trtSRxJUZTmRqpCfynwYpTtJxtjhgHnAdeLyKleHUXkWhGpEJGK6upqr25xcXjXQCz9j0f0ithWXJhea7pbu9IIoa/1Rc9S6fThx8vHXwf+Ji/NXcvn39SwbW89f522MrGDKIrSLEk6BYKIFAHjgGO9+hhj1ln/bxaRN4DhwKcefScCEwHKy8tTcq63Kili+f2jXF04pUWF1Dekz9JvXVrE/vrE0g87bwtxe3Fc7gyaAl9RlHhIxaI/G1hujHFd3y8iZSLSxn4NjAQWp3C+hCgtLgz6wscc1TXY3qI4NArmwYuOSuk8LYoKPcMrvUjGR2+Lurr3FUVJlJhCLyIvArOAgSJSJSJXW5suIcxtIyLdRGSy9bYLMENEFgBzgEnGmPfSN/T4+b9Lh3JxeQ8gcqL25MM6pXTsFkUFEZOxsXD2X75xFz/4+yx27K3n/SUbY+4riIq9oigJEU/UzaUe7Ve4tK0HRluvVwPHpDi+tFBQIEHPR2lxqNAXpeizLykqSNiiBwnxxHz+TQ03vPgF01dsYfbtZ3FIu9KUxqQoiuKk2QR511t5Z0rDXDfpEPrwydhYiEQWE19dvQeAOo+J3KZa5asoSv7RjIQ+oJThrpvigtT+BCWFSQg93sLtdSh78Zdzu4q/oijx0GyE/sxBBwNwVPd2Ie2pWvT76htYuiGxdMIiMGnRhpC2eFfxCsnF4SuK0nzJqwpT0fjusT0476hDEIT99X5eqlgLQHFhYve6Tq1b0KdTK+ZWBqpIrduWXIKyv328KuS9nbs+2cLliqIoXjQbix4C8fUtSwp5+HtHB9sSFXqAuoZGMbYFOhGWbdgV0WbPIXgdT+VfUZRkaVZC70biETOhE6YNSQj9tzV7I9p81s3D63i2oR8ag6/yryhKbJq90MeDXbEKAv71+obUhN6NWtuib4h9PFvsbfeRoihKNFTo48BpRQvgs0S5pKggmI44VeybRyI3jh37tDqVoiixUaEHHrlkCA+N806FEB7yaLtuWhQWpM2it+8XdolBG791/FfmuWaaUBRFiYkKPTB2SHf6H9zac7szTl4ELh1+KAAHtS4JWvfpwuc3fPlto0umwRjmr93OMiuEU9MfKIqSKM1W6N+98RReue6E4Ptodnm4tt5w5mGsfnA0ZS2Kggux0oWvwXDfO0uD7xv8hu176xxjUaVXFCUxmq3QH961Lcf17hhXXxEYf94gAA4qa4GIUFAgFIhQ62uIuX9ZSWHMPjYNfhN019jvw/Pbq9QripIIzVbow4kmniJCn05lAHR1JBwrEIjHRd+qRfzr0nx+f8gE78KqHTELmSiKokRDhd7Cqde/GXN4xPajewRSJ/zstH6NjR4O8+evGh7yPhELvMFvcLr9L/3HbPbXNz41hJ+yqQqgK4qSuzSbFAjxcmyvDgw9tH1oo4Gu7VpSOWFMSPMCj5qt4flzEplA9YW5bgBqnUIf1t8YnaBVFCU6atFbDOjSBoAbzjgsIhtleErhWBSFZcRMZAK1wW8iYvOjuW78xlDn82OMYXX17oTGqShK80Ateot2LYuDFvvCqu0h2xINlQ9Pn5NIlgU3i/6BScuCr7fuqQux4DfvquXECR9xVPd2LFq3g0uHHxp1TYCiKM2PeEoJPi0im0VksaPtXhFZJyLzrX+jPfYdJSJfichKERmfzoFnknCLPtGMkuH7J1IjtsHvj3q+N75cx1vz1wffr9seyJ65aN0OAF6c861a9oqihBCP6+ZZYJRL+1+MMUOsf5PDN4pIIfA4cB5wBHCpiByRymCbinAfe+IWffJO8/qGSNdNOG8vaBT6tS4J0q56dm7S51cUJf+IKfTGmE+BmiSOPRxYaYxZbYypA/4DjE3iOE1OUZhQJxrZEm7RJ5KT5rZXF7K2Jv4c9796eUFEW7oXcYVjjOG1eVUh0UCKomQvqUzG3iAiCy3XTgeX7d2BtY73VVabKyJyrYhUiEhFdXV1CsNKnQjXTYImffj+PTu2SnlMB5pZq7YGF4d9umILt7yygIcmL4uxl6Io2UCyQv83oB8wBNgA/CnVgRhjJhpjyo0x5Z07d071cClxiLUoqrW10CkZ182jlw7lqpP6AHBU97YhqY4zTbpj65dt2Mml/5jNA+8sY9LCDXy1MZB3Z8vuuhh7KoqSDSQVdWOM2WS/FpF/AO+4dFsH9HS872G1ZT2tSoqonDCG3/53Cc/MrEx4MrawAC44plswy6WvwaQty+WBYNuegKCv3Lybf85e07hB4/cVJSdIyqIXka6OtxcBi126zQX6i0gfESkBLgHeTuZ8BwrbBROPzv/n2hER+xVbk7o+v6Fmb9Pljo/3ljJvTQ376xswxvDH979ilUe0jtfxVOcVJTeIJ7zyRWAWMFBEqkTkauD3IrJIRBYCZwA3W327ichkAGOMD7gBeB9YBrxsjFmSoevICPacbDwLpkb0PSj42q5Da0ff+Px+avbUpn+AKbC2Zi/f/dssfvPmYjbvquWxaSv58ZOfu/a1b3Thf4dEwkYVRTlwxHTdGGMudWl+yqPvemC04/1kICL0MlewhSxRr0uHshKgcYWsr8GwY58vrWOLxoYd+zHGICK8PHct97+zlAX3jKTAEU1kRwK9Oq+KV62iJvs8omhsgQ9/slGZV5TcQFMgRME2WBP10dtpiZ2um/HnDeK7w3qkdXzRsCdKH5i0lF21vrhCPBtihGWGb1WDXlFyAxX6KBQm4KN3Yj8JNLpuDN3bt+RPFx9D5YQxXHlS73QO05WZK7cAcFDrFgBs2V3LVuufFz6PRxf7+ud8E7qcQnVeUXIDFfooBC36OH03Fw3tTo8OLYPvbV99eLnBe77TGGo58bJjUxylOze9NB+AjpYbqXp3Lcc+8AHHPvCB5z776hvYUxvpYvKcjFWTXlFyAhX6KBQk6KP/yw+GMOP/nRl877TovTjniC7JDzAO2rcsBmDbnviifu5+K/75cpV5RckNVOij0DgZm1wMfNBHH6WAeKat4jJr0deu/fEJ/eZd+yPaPBdgqdIrSk6gQh+FxvDK5Ci0om4SXSyVLnfOlCUbg0K/28UlkypaqFxRcgMV+ijYQpZsSgHbbWLXm42XHh1Cc+PcNmpgUuf/4tvttCwORADt3B+f0Ifn6YFoPvqkhqUoShOjhUei0OuggOD27JBcUrLencr459XDObaXW843b8IFNFbYYzzE67pJRLxV5xUlN1Chj8LYId04uE0LTuh3UOzOHpzSP3aCtmm/Pp0z/vhx8H24VX1cn45Jn9+eX1i+YZdreziu4p27aXoURUFdN1EREU48rJPnhOknt56e9LHfuv4k/n3N8QAc3KZF2HkD/x92cGuW3TcqJL1CIhhM0O00a/XWkG1e8wZu1+qVAiL8hnTDv7/g4682JzNURVEyiAp9CtiLkZLhmJ7tObFfJyAQGTP9tjOC22z5NMbQ0lpl++b1JyV8jlWbd3uGhnq1J1Icy6nzxhjeWbiBK57R6laKkm2o0KdAYRpnI3t2bEW/zmWcc0QX+nQq4/yju/LopUOD24f0bJ/wMT9YttnTReMdMupi0cfhusnlNMyKku+ojz4FCtJ8m/zwltODrx/74bC0HNPTovd03cBD7y7DGLhj9OGAt9A73Tz2ojCNxFGU7EMt+hSws1Oef3TXGD0PHF6hoV4FyAX4+yermfjp6sZjeBzbKeq20KfzKUdRlPSgQp8ChQXC53ecxZ8vHtIk55t/9zmcfFinhPZxc9EYY/jhP9xzzyfrgbFDQAsScfIrioKvwc+tryxgzdY9GTuHCn2KdGlbSklR0/wZ27cqoagwMSF1E+5an3dKBp+/cdv2vYFUx15PBU5Nr7f2U4teURLji2+388q8Kn79yoKMnUOFPsdIdNLTbwwlhaEfczShr3fk5Rly31TqfP4opQQbRd0eV5Fa9IqSEMmuvE+EeEoJPi0im0VksaPtDyKyXEQWisgbItLeY99Kq+TgfBGpSOO4my2JCr0x0Kl1SUhbrc+9khTAzJWh8fY+vz/KZKyzn7puFCUVMpk7Kh6L/llgVFjbVGCwMeZo4Gvg9ij7n2GMGWKMKU9uiIqTww5unVD/Br+hpKiAK07sHWz7yVNz4t4/3vuKnaGzUIVeUbKOmEJvjPkUqAlrm2IV/waYDTRdjbxmzp1jDk+of53PT4EIV53UJ9i2fOOuKHuEEphkjb0yNmjRq49eUbKOdPjorwLe9dhmgCkiMk9Ero12EBG5VkQqRKSiuro6DcPKT1oUFSbkB69r8COSfHx7NNeNkxWbAjcP9dErSmI0xVLDlIReRO4EfMALHl1ONsYMA84DrheRU72OZYyZaIwpN8aUd+4cOxFYc8bLPdL7oMgsm/UNfgoLJGmhb/B7Zbpp5LNVW7juX19EHZuiKDHI4E8naaEXkSuA84EfGY9pY2PMOuv/zcAbwPBkz6c0YovpontH0teR697NbVJruW4OblOa1Ll8fuOZLsE+38rNu4Nt6rlRlOwjKaEXkVHAbcAFxpi9Hn3KRKSN/RoYCSx266skhi2w4fL727FHctJhoZku63x+RISSogJ+dlrfhM/V4DeeE7K2qDvvA+q6UZTsI57wyheBWcBAEakSkauBx4A2wFQrdPIJq283EZls7doFmCEiC4A5wCRjzHsZuYpmRjurcpXxE/K4N/CQNrzw0xEhfQMWfeB1rIlSN5H2+Y1nnO+6bftYvG5HyHYNr1SUxGiCMPrYSc2MMZe6ND/l0Xc9MNp6vRo4JqXRKa68eM0I3l+ykXatikPa3eJwl23YyaBD2gCxV60WFkgwesbm5pfmM25Yd9f+7y3ZyHtLNnLPd45oPIb6bhQlKTL5y9GVsTnIoQe14ppTI90wXsa0HU4Zy9ouLoz8Osxfu52ZK7dE3c9pkehkrKLEx/76Bv7w/vKoCxjThaYpziNiuWZiabBXHp0ilxuAE+czgAq9osTHUzO+4fFpqzh1wM6Mn0st+hzHKauxhT6Wj97961BbH93icPro4xX6v0z9mt/+d0lcfRUlH7FzTtm/r0x6PVXoc5yQGq8xvijhkz7OClbR2BdD6J3EuzL2kQ9X8MzMSvbW+WJ3VpQ8JusXTCnZRSxjOjwe/tT+obntvaJr9td7Z7sM7Nf4OlHXze2vL0qovxeL1+2g9/hJrN++Ly3HU5RME/ylNIHSq9DnEbGs6XAhD/e9e33f5q3ZFv24jj0TjbpZXZ2eYgv/mr0GgE++1vQZihKOTsbmEbGEPjyqJjxu3rtgeHT8MSz69dv3cdN/5jO4ezvatyrml2f1T/mcXjRFTLKiZIJMpilWoc8jYhnTrUtDP+5w4fcqGB6LWK6bv3+yijmVNcypDCRBdQp9ovn1vdDwfSVXiZ1NKnXUdZPjPHjRUcHXMYW+RajQh4tystaw84vqFqsvUQamFriiZB4V+hxneJ+OwdexXDdtSqM/wNlulAnjjoraL5wtu+qCr52h+F98u41d++sjbkAff7U54pzpoimsI0XJNdR1k0fEEvrTBx4cdbvtRenjyIgZD0/P/Cb4utCKxd9b52PcXz/jlP6dGNClTUj/K56Z6zinCrPSvJlbGQh20Dh6JS68Ihvt4uClxYUR2966/qTga1t0S4sLOcUKvSwpSuwrYrv9q3fVAjB9xRaemvGNZ/80ueiDZHJCS1FyFRX6PMAWZS9f+Ou/ONFz3/5dGmvQ2sZ1aXEh/7z6eConjGHcUPeEZl7YTxWn/eHjuPqr60ZpTmzYsY+FVdub/LzquskDJl5Wzsad+z23e+WwgVALuCFo0ce+/xe5ZLqEQFrjX708P+b+NukTerXkleznlIen4fMbKieMadLzqkWfB7QsKYzqV3fLYXPBMd2AUL+gHerodPF4hT+6iTzA0vU7ef2LdTHHbOP3w5ffbuPU309j1/76uPdTlFzE63cD6qNXEuSRS4Zw2oDGurtuBUXsPDduX64WDr98onHuiX5Z/caweP1Ovq3Zy5bddbF3UJQ8Ify3ksn5pbiEXkSeFpHNIrLY0dZRRKaKyArr/w4e+15u9VkhIpena+CKN2OHdOe5q4bT0rLMnfHy7910CveNPTL43u3LFWLRJ+haiZUXJxy/MUFLvsHvt/43bIriioqGBvEoSiTxWvTPAqPC2sYDHxpj+gMfWu9DEJGOwD3A8QQKg9/jdUNQ0k/HshIgVPwGHdKWn5zQO/jeLVInFYs+0SIKDX7Ytd8XfA3wvx98zfEPfsjGHfGLva6MVRRv4hJ6Y8ynQE1Y81jgOev1c8CFLrueC0w1xtQYY7YBU4m8YSgZ4kYr1UD7smLPPs5InT9ffAzH9GgX0pboZKmdYztejDHstoTeZ1n0n1qJya55voLe4ycldDxFUSJJJeqmizFmg/V6I4Fi4OF0B9Y63ldZbRGIyLXAtQCHHnpoCsNSbC4+ricXH9czah+nRT9uWA/GDesRst3XkJjQ1yUo9E7XjaXzwayai9btCPb7YOkmOrYuYdih0R8I1XOj5CpZPxlrAvlvU/qNGWMmGmPKjTHlnTt3jr2Dkhai5aEBuPT4zN50G/wm6LqxLXq3yeOfPl/BuL9+5nkc9dwoijepCP0mEekKYP2/2aXPOsBpUvaw2pQc4YywtAkdy0q4Y/SgtB3fmEYf/c0vzeeJT1ZFXY0755saduzTMExFSYRUhP5twI6iuRx4y6XP+8BIEelgTcKOtNqUHGLm+DODr687rS/XntovbcfeVesLpi+u3LqXCe8uj7DonRPCF/99FmMenZ628ytKcyDe8MoXgVnAQBGpEpGrgQnAOSKyAjjbeo+IlIvIkwDGmBrgfmCu9e8+q03JIbq3bxl8ffXJfTN+vvA8+fUNoX7/qm1RygVqfKWSA3iV7cwUcU3GGmMu9dh0lkvfCuCnjvdPA08nNTol60i0JmwyhAt9XUPsCV4Nr1Ryiaa2R3RlrJISww5tn/ZjhufmqU8wkkdRsp1EFyKmigq9khIXJpjdMh4iXTfqjlHyC78xTZpSW4VeSZp+nctcE6alSnG4RR+H6yYWt726gIufmJXycRQlHfhdvtKxQp1TQYVeSYpF945k0i9PiZoCOVnCbx5uQj/nmxp6j5/El98GqvPEehJ+uaIqGN2jKE3N4nU72FPrC753c91kcoJW89ErSdGmNJBWIdz6TgcRPnoX141dd3bmyi38Y/pqJi/aCDSu2rN/NCLCoLveTfsYFSVe1tbs5fz/m8GPHIsP3VKLVG7dgzEmI5a9WvRKXHRrV8rpAyNXLGfCdROeudLNorcrWS2s2hEUeYC731pC7/GT6HP7ZPrcPhlIPKOmosRLfYM/whLfuruW//fqQvbXBxL8fbl2OwCbrfKaAH6XZIFra/bx6ryqjIxThV6Ji89uP4tnrxwe0Z6JcEuncIN7Rkzb6KnZoznslQODMYaj753CNc9XhLT/ccpXvFSxlje/DCQB2FcXcNm0adHoQPFKCmvfFNKNum6UlGiCsHr21LoJfeDE0Sr2KEom2bqnjn31DXywzC37S0DMx7+2kP/MDeR1dLpk1m/f16T1jdWiV1IiWX9i13alcfe1c+E4sW8w82NYQInm01eUePFKyW3/JvzGBEUeQo2ii/4609OqzwRq0SspURhF6If36ciZgw5mSM/2XDJxdsi2ggRuEG61ZJ+c/k1c+yaaNllR4sVrIZ/9zQ733Tu/8/UNxjXKJlMPyCr0Skq4zcW2aVHEM1cex7G9OgStm+tO68cTn6yKuh8EfP7hVribRb+7NrLNDRV6JVN8tNzdZVMQtOhD2+vDgufdIm8yZeSr60bhmSuO43cXDU5qXzfXzdzfnE15744h2249dyBTbz41+N7rSaBf57KItt9NXpbU2AB2ujwNKM2PLbtr6T1+Ei/PXRu7c5zc985S13bbRRMu5OFGR1O6blToFc4YdDA/Or5XUvva1kuHVo3lCp3FxW0KC4T+Xdo07ucxi2vH56eLU34/LaLtgXeWMvbxmWk9j5LdrNm6B4AX536b8XOJh0UfKfSRSv/vz791Db1MFRV6JSVsvT6iW9uE9pt4Wblre6uSyJtEunlyxjcsyFAYm5KtBL6oTZFLzDZ+wn3w4VlY7c0XHNMtdP8MhLKp0CspEfRHJugK79e5zDW1sFsZwXSSCWtJyX7s71qyn/60rzbjiyPn0uzVW/lq004gDteN39CiqID//cGQYNuQnu2THGF0VOiVlChwhJIlgldYZqbz3WvcffMkGPGShEn/ydfVXPnMXB6ftipm30smzmbmyq1AbNfNfl8DrUoKQyz4l342IuHxxYMKvZISpcWBr1Dblsn71p/48bHB1+Fhl31dJmdTwZfoo4eSF9jfqmTu89VW6oI1NXs8+7iFSoY3hScy27a3ng5lJSFtLYoy47pMWuhFZKCIzHf82ykiN4X1OV1Edjj63J3yiJWsYkjP9tx9/hH8/rtHJ7yv/b135tBxWvTv33Qq7914avhuKaEWffMk2SdPiC+rpFvivfBzhYcNb9tTR8dWjUJ/cJsWCY8tXpKOozfGfAUMARCRQmAd8IZL1+nGmPOTPY+S3YgIV53cJ+7+15/RL+IR2GnFOx9jBx7ShnTT4PKD3LGvnlYlhREFT5T8IQXPTeMxoixncruBhLtqfGHfvZo9dfTo0AqApfedm9AiwkRJ1zf7LGCVMWZNmo6n5Cm3njuIygljgMaQTKcV36Io9ldyys3JW/luFv0xv53CTS/NT/qYSu6QjM7Hs49bqo3wrKsNfkOfTgFXZKfWLdi2t46OZYHfQKuSItew5HSRLqG/BHjRY9sJIrJARN4VkSO9DiAi14pIhYhUVFdXp2lYSjbz+i9O4v4LB4cIfbSUCjYDuiRv6W/Ysc+1fdLCDSyq2sFTM+JLraDkFo0WfWZcd24GRLjQ+/x+WlpiLuLuo88UKadAEJES4ALgdpfNXwC9jDG7RWQ08CbQ3+04xpiJwESA8vJydaTmILNuPzOh3O99OpUFLRybTLvQL3gsdKGU84f/ncdmAHC1hyuq9/hJXDikG/97ydDMDVDJKCm5bqLYIGu27uHoHu1D2v4Rlo+pwW+CLp5d++up8/lDfPSZJB0W/XnAF8aYTeEbjDE7jTG7rdeTgWIR6ZSGcypZSNd2LSOEO1GaMnUruD9yR4u1f3P++kwOR8kQtsAnMxnr9ZVsU9poJ4cbEG74/CY4DtsgaiqLPh1CfykebhsROUSsgGkRGW6db2sazqnkKbF+h/0Pbp3W87nV7vRKP6vkLvbHnE4zokVRQYjbcen6nVH7+/0mwpDpkAsWvYiUAecArzvarhOR66y33wMWi8gC4FHgEpPJCrhKznPR0O5Rt795/UlpO9esVVtdV/S6VbRSchvbkk+n/Pisla02ox+dHrN/+MOiPRmbaVLy0Rtj9gAHhbU94Xj9GPBYKudQmgcf/Oo0RKBf59a88z8ne4ptWYv0Zda+7KnPmX/PyIh2N4teC5jkNkGhT+EY4S76hoaA0O+ti88wcProbZrKotd89EpWcJjDJTO4ezvXPo9cMiSt5/S5/PCAYFFnCFiAD727nC5t46+IpWQfwft0Ui569518fmMZHvGlwvb5TcT5O+ZK1I2iNBXJhFX+oLwnAC9VuOchd5t4rfX58TX42VvfwO79PiZ+ujrh8yrZhe2ySWoy1oXdtT4a/IYWxfF7v90s+rZpTsvthS4FVHKGZDJbtm1ZxNE93Z8QwN0ls7++gTveWMTR905JShjW1uxl+966hPdTMof9KabigbPDK+et2cbge96nrsFPSQKrqX1+f8SzQSZSEruhQq/kDMlktmxTWkzn1t45RDbtrI1oq/X5ebmiCnDPYRKLU34/jdP+8HHC+ymZw35ySyZ8N/xe/+W324KvixIQ+v31fteymE2BCr2SMxS5FJr98q5zItoGd28sgtKmtIhWJd4eSrdICefk2t665H6YO/ZpCcNswrbk0+G5ceakKS6MND7OP7orrcOCBg6yfPE1ew7Mk54KvZIzFLr8qNwWnLzw0xGMs8I0W7coCqZSjpe9jsLjuw+QBaakFxMMr0z9WM4nSzd34pCe7UP6/M+Zh7kuJBxzdNfUBxMnKvRK1mP7QZ0/qkm/PJnffy+QGnn5/aOC7ZUTxtCuZTFDD20PBEocJprj+4XPG+uK7knSoleyi3REx9rZKz9btSXY5vaUWVggIZlQfX7jmj7hvMGHpD6oOFGhV7KeEmtRitNKOrJbOy62Imrcsv79eEQvPht/Jkd2a5ewRT9jZeMPeXetd4z0xh37ufutxfga/Oyp9QULVCjZhz2pvm67e1K71+ZVMXVpRBYXIDIi8/0ljf2KXJ4yiwokZCGV329ck55lumymExV6JeuxhT6Rx24RoVv7lkBo1Z7rz+gXkqMkFis27fLcduurC3h+1hrmVNYw5tHpHPe7D+IfoNKkOL86bqtjb3llAdc8X+FaFzZa5JXbZGxhQUGI797nN+xzWVRV6PI0kClU6JWs54yBBwONgp8otkXfpkURt547KKFH5v/7aGVEW4Pf8MQnq4LWYauSIiq37o3op9k+sgenWEdz47jVhbUjdtzcL8UuVnlhASGumwa/YZBLEZ2mrHOjQq9kPQ+NO4ppvz6ddknWpbUt+lrLWnP7oSci/m8vWMeEd5ezujpQQ7S23t29o8nRsgfnTTd87YTTZVO1LfKGHd7fLpgDuFYlKxCJEPqHxh3t2Cdwc1CLXlEclBQVpJT+2F69aBeCcDO0Rx7ZJa5jdW7Tgg+WbQ5p+8HE2a59d+7XEMtswZm8Lly4r3m+orGfgV+8MI8vHLHy4UspnGsr3Hz0rUqKGNarvWN/Q8uSQsp7deD8o7tyXO+OQHxFdtKFCr2S99gTY4cfEoivd1s0E4+X5bQBnSmQQDUqL65/4Yvg6wO1OEaJxOm68bmlLLXYvGs/kxdt5Of/mhdsC3fBOWvBntAvJKcjAK1Li7j7/CO5dHggWMCuU/zqz0/ksR8OC07MJrMAMFlU6JW8R0R49boT+NdPjwfcRT2e8LveB7WKWUFr0qLGm4AKffbg/Hyj6DxW+YyQz7nB4aNftmEndY4J276dWlM5YQynDegcbGvdooiSogKG9uwQ2D/sC2fXVDioddMkNANNaqbkEcf36ei5rbx34za3SdJ4ctqUlhSyz8Mf74aX7145EMRn0dtRN8402U6hPu+R0JXUJUWBG4MzVNKe/Lct9nBX0V3nH8GYo7umVPs4UVTolbxgyW/PdZ0Yc8NN0uOJkGlVXBTy2B6LT1dUs9/np2+nMuav3c7SDTu5deTAJktkpTTi1Fqn8FZU1oT0qw8KfePnHK20pP2dcz692es6bP99+KddWlzIif2atqKqCr2SFyRSkMTtdxvLdXPz2QNoWZKYp/Pxaat4fNoqCqTx+KMHd+WoHu1Yv30ffmPo0aFV3Mf728erOK53h5CnEyU+nE9stoX+3uKNXOfwxQPUNUSmSmhc7BR5g7a/d3McN4xDrNoF5w3uSsWIbdx4dv+Ux58qKfvoRaRSRBaJyHwRqXDZLiLyqIisFJGFIjIs1XMqSirY1vsxPdo52hq3uxWDOLn/QbR0WYEbD86byIKq7Tw5fTUnTviIkx+ehjGG3uMn8eDkZTGP8/B7y/neE7OSGkO+s2LTLi5+YhZ7at3nRZyfgc8S83CRB6h3eWJzW+xkY4u6nY5j3NDuQfEvKSrg/gsH0ylK9tSmIl0W/RnGmC0e284D+lv/jgf+Zv2vKAcEW9SdLhSnxffMFccx9vGZIfv4jXuqhUT5zZuLQ97bLoKJn67mjtGHp3z85spv/7uUOZU1VKzZFjIxahMtjt7J0g2RBb7tfEduLhxb1C8u70mvjq04vFvbiD7ZQFNE3YwFnjcBZgPtRaTp0rYpShh2CoQyR/pipxC4hTf7/YFY6HRjly1MpIDFjBWhNtWyDTv51+w1aR1XrrF9XyD9r9eiOjfXTSy+3bqXXfvr2WPlO3Lud8WJvfntBUeG9D++70FNVjEqUdIh9AaYIiLzRORal+3dAWcdtyqrLQQRuVZEKkSkorq6Og3DUhR37hxzOLefNyhkkZTTWHPTgQZjknbdRMOO4gnPa75++z4G3PkuS9fvtMbUOKgHJi0N6XveI9MjnhSaG9v2BBanOee5b31lAXe/Ffi7OD/Tz1ZtDbZH49Q/TOPCx2cG3UFOt9C4Yd25/MTeqQ+8iUiH6+ZkY8w6ETkYmCoiy40xnyZ6EGPMRGAiQHl5uSYJUTJGm9JifnZaPxav2xFscxYrccMYMiP0lv83PDnWh8s2Udfg55+zK/E1mBBR6W4la1Ma2WWtQnZmiXxlXqBK2H1jB4fcyO9K4Ka4qnoPh7QL+OG3721c6Wy35QopC70xZp31/2YReQMYDjiFfh3Q0/G+h9WmKAeUwd3bsfS+c6nz+WnfqoQ/X3wMXdu1dA2/7NOpjE0796d9DGtqArlVwkND7TGsrdnHjJVbmOlIneyW8jYRqrbt5S9TV/DQuKOSThTXlPxzViV3vbWEVQ+O9lxNav9NvPzvqRQFn7lyK9BYNeyyEb04uE1uCX1Kn7KIlIlIG/s1MBIIv12+DfzEir4ZAewwxnivIVeUJqRVSRHtWwWibMYN68EJ/Q6KyEHy9g0n0a19ywgf/bWn9k35/Fc+MxeALbtrmb92O0/N+CYYyw2N8wXrdzTeZLwW/ESL93Zy5xuLee2LKmaszA0X6f2TAhFJzkVM++sb+N7fPuO5zyqBRqH3edT4TUcmUVvoR/SNTHuQ7aR6O+8CzBCRBcAcYJIx5j0RuU5ErrP6TAZWAyuBfwC/SPGcipJRBndvy02O2GfbZeN03Qw9tD3XnBIQ+uJCCYbXpcKFj8/k/neWcv0LXwR9827U+wz/mfNtiPBB/Ja+bd0WNGFSrXRQ72u8vkc/XEHFmm3c8/YSoHFFq8/v55+zKpkbthDq33PWkirb9wYmfFtlYFI+06TkujHGrAaOcWl/wvHaANench5FaUpEhJvOHsD/frACaExz7BT6R34wNOhGKC0uTGu6gykelY5s5lTWMKeyhm9r9nLbqEHB9mhhg05soU9XUq2vN+2ibWlxxv3WgRtbIKpld1i8vH3plz01J2K/HXvrWbB2e8rn32PNp+Si0Ge/g05RDjB2muNSxw/80INaBddJ9u3cmi2765p8XH/9eBWbdzW6dLzCBo0xXPzELKYs2RjoZ2dPTJNFP/IvnzLioQ/TcqxoONMSxJvuArzLByZLIquwswUVekWJQall0beyLPoRfQMpCDqUlfDIJUN46vJyzxWZ5xzRhQcuHJz0uWNZ6e8v3tjY18M/XdfgZ05lDdf/O5BC2T6kZJnrZnX1bve6rdZ4nVkjExH60Y9Oj90pATKxniLTqNArSgxsi76osIApN5/Ks1cOD24bO6Q7nVq34PozDqPXQZF5a753bA9O6d+YwOqhcUcldO7PVm2Nut25Wnfc32aGFMywsQtl2MJuT9pmW6nDM//0SUgRkHDqQiz69N6kOrQqZuyQbnH1bZEDkUrh5N6IFaWJcf6wB3Rp45oKoUNZCZ/cegY/PblPSLsQGiN/8mGhWQt/dPyhKY3NaV2uqt7DuL9+xqdfVwdX3C6q2sFjVt1b2yVvu3jq/YYGvwmJ8slmbKGv2rbXtZZvItwxelDI+1m3nxWcg7ELhniRyNNEtpB7I1aUJiYRF8cxPdtH7OssIB1uDaaaP6fIZUL1J0/P4ZcvfgnAdx6bwROfBApe21E2tuvG1+Dnimfm0P/Od/n9e8vZV9fA5gTWCrw1fx1PTl+d0vgTYeueWgDufXtpjJ6xufbUfiHvS4sLg5/z4O7t+NP3j/FMRqZCryjNnNFHdeXmswcE35cUFYRY9OEi4SbUibBzn/vcwOff1PD0jG9C2myht1029Q2G6VbenL9+vIqfPP05wx+Mf1L1xv/M54FJjVk3d+yNv0bu5l37efNL93WT9Q1+18VpVz0bcOuEh5UmUiMgGvb93G/gu8f2YPIvT3btl263UVOgQq8oHjzx42O5IsF8JoUFwo1n9+fLu87h5rMHcMphnUIKSIevRN3lMYkbL9O+2uzavmNfPfe9E2r52kJmT/CGu2zmVgb8+xf/fVZE+GI8THhvecw+j320gmPvn8rw333ITS/NZ1X17og+419bxPEPfsjcyhr21PpCJmFr9kRGNw1/8IO4x3jGwM7c7zE5bt9z7Ruh0y3m9N+rRa8oecSowYdwb1iGwnjpUFbCjWf3p6BAKC5wt+h/cXo/fnlmakUpahOwZm2L3l496rXCds43NcxYEVg1O29NDSc//FEwl0w0Ji1cH/K+elctvcdP4sNljZE0f5zyNVsdYv31xl1AaHTRa18EctR8/4lZEZOzw+6fGnwKsdke40ni5Z+dEHz9zJXDuWxEL9d+BWGT1a0c2U3vdKSQVqFXFCUCp0XvfOy/bdQgDmlXSm+XaJ14+Wi5u0Xvxo599cxatTVoId/80gLPvrYR/ciHK6nato+KysZonpcr1nLmnz6O2Gfnfh97an0cdsdkfv3KAtZs3QPA1c9V8NlK93IVW3bXsmnnfgb85l3X7bGijmLRtV0px/XuELXPif0CKQ2Cri2rvbBAqJwwhsoJY2jrSH+croVmTYkKvaJkmKICoWVxIb+94EjXid3JN57CgntGUt4ruiClg0v/MZtvtuyJ2W9B1XZ6j5/EBmux0ZXPzmX73jp++eKX3PbqQlZXux9j5F8+xec3vDqvijVb9wbbf/jk5679t+2t5/gHP4x7VW+iPDTuKM/J9NMGdObS4T15+orjQtrdhpKLIZVOcnv0ipIDiAjL7h8VTDV8Sv9OjBvaWJKhVUkR7VoW8+9rRgTbXvv5CeGHaVImfhqIplmxudGH/t+FG3h7wXqvXYDQVai3vBL6xOBzCeP8uxURlCwPfzf6ugS7oMvjPxzGn74fmq3luauG89C4o4ORT6cPDFSmOtblhptti8sSJffW8ipKjvPPq90raTonao/tFb0A+L3fOYJ7/5t6mGEiJJLH3Y0rn50b0bYnSj1WJwO7tKFdq2LmfBOarMyZjuDoHu1YWLUjZPthB7cGYMzRsYvanT7wYL56YFQwt1E+oUKvKDlG5YQxAHEJ/dUn9+GpsDDLA0X4JGqivHjNCKYt30xBQWOo5cgjDmHQIW0Y3L0dD1w4mGUbdnLYwa15bNpKfnXOgIRFOx9FHtR1oyhZRVlJIb888zAAVj04OjhRCNCvcxkzx58ZfB++ytaNu84/gi/vOif9A41C6wwk/epQVkxhgXD2EV04c1BjCciSogLeu+lU/vj9YygtLmTooR1oU1rM7ecdnreinQwq9IqSRSy5bxS/GjkQCER3/OUHQ7jutH589cAoptx8WkgZwX9ePdz1GKcO6BzyvkNZScj7jmHvbW45ZwBHdW/nWWA7Xs6Pw03iRs+O7iUS7xx9OI9eMjSVITV71HWjKFlMl7aljD9vkOs2EeG5q4Zz+dOBHOyv/fwEhvbsQEGBsHnX/mA9WifDe3fk/gsH4zeGf81eQ4Pf8J+5gaIcR3Rry/+c1Z/7/ruUp2fGdvcc17sDcyu3cUr/TkxfsYVbzhmAAc4bfEjwmPHStrSIV352omu642vSUMkrHVw2ohdfbdp1oIeRFEkLvYj0BJ4nUGXKABONMY+E9TkdeAuwvzWvG2PuS/aciqKEctqAzkGfvROvmqYvX9cYzfO7i47iN28uCr63V4KGpxhw46NbTqNv58BEpzGB5GjOVA/TbzuDU34/LWK/x384LJguGQKRLh9/Vc2Ce0YiIkwYdxTjX18UsV824LWiNhdIxaL3AbcYY76w6sbOE5GpxpjwGaLpxpjzUziPoigZomNZY+IuO0LlqO7tIvq98YsTuf31RSy3VrJ2c7iQRCRkURgQssAIAjeG9q1K6FhWwpijx7Bxx35EoH2rYmp9/mD44sXlPVm+cRc/OK4n5z3inUf+9V+cyKrNkekTFHeSFnqrwPcG6/UuEVkGdAeaNuZLUZSkueGMw+h9UCvGDukeXPH5/fKenDKgMydN+CjYb+ihHXj2yuH86MnZ3DH68JhZN9u1LOaLu85h2P1TAejUpgVtSxvF31l20DlpWlAgwbQT153WzzWmHWDYoR0YdmjmF5jlC2nx0YtIb2Ao4Lb87QSrePh64NfGmCUex7gWuBbg0ENTy9GtKEoo7954Cvtc6tqWFBUwbliPkLbCAqF7+5aM6NuR2atrWHrfuUBAnD+85fS4z+mc9G1dkrjUeM1NKIkjqVaZEZHWwCfA74wxr4dtawv4jTG7RWQ08IgxJmYWp/LyclNR4V1pRlGUzFPf4Ke+wR+S3CtRFq/bwdzKGq48qU/szkpKiMg8Y0y527aULHoRKQZeA14IF3kAY8xOx+vJIvJXEelkjElt5YSiKBmnuLAg5UyNg7u3Y7CLz19pWpL+FCUwe/IUsMwY82ePPodY/RCR4db5UktHpyiKoiREKhb9ScBlwCIRmW+13QEcCmCMeQL4HvBzEfEB+4BLTLZVJFYURclzUom6mUGg9nG0Po8BjyV7DkVRFCV1NAWCoihKnqNCryiKkueo0CuKouQ5KvSKoih5jgq9oihKnpPyythMICLVwJokd+8E5MuCrHy5lny5DtBryUby5TogtWvpZYzp7LYhK4U+FUSkwmsZcK6RL9eSL9cBei3ZSL5cB2TuWtR1oyiKkueo0CuKouQ5+Sj0Ew/0ANJIvlxLvlwH6LVkI/lyHZCha8k7H72iKIoSSj5a9IqiKIoDFXpFUZQ8J+eFXkTuF5GFIjJfRKaISDePfpeLyArr3+VNPc5YiMgfRGS5dS1viEh7j36VIrLIut6sLMOVwLWMEpGvRGSliIxv4mHGhYh8X0SWiIhfRDzD3nLkc4n3WrL6cxGRjiIy1fotTxUR1+KxItJgfR7zReTtph5nNGL9jUWkhYi8ZG3/3CrXmjzGmJz+B7R1vP4l8IRLn47Aauv/DtbrDgd67GFjHAkUWa8fBh726FcJdDrQ4031WoBCYBXQFygBFgBHHOixu4zzcGAg8DFQHqVfLnwuMa8lFz4X4PfAeOv1+Ci/ld0HeqzJ/o2BX9haBlwCvJTKOXPeojeOcoVAGeA2u3wuMNUYU2OM2QZMBUY1xfjixRgzxRjjs97OBnpE65/NxHktw4GVxpjVxpg64D/A2KYaY7wYY5YZY7460ONIB3FeSy58LmOB56zXzwEXHrihJEU8f2PnNb4KnGVX60uGnBd6ABH5nYisBX4E3O3SpTuw1vG+ymrLVq4C3vXYZoApIjJPRK5twjEli9e15NpnEotc+1y8yIXPpYsxZoP1eiPQxaNfqYhUiMhsEbmwaYYWF/H8jYN9LKNpB3BQsidMqTh4UyEiHwCHuGy60xjzljHmTuBOEbkduAG4p0kHGCexrsPqcyfgA17wOMzJxph1InIwMFVElhtjPs3MiL1J07VkBfFcSxzkzOeSC0S7DucbY4wREa8Y8V7WZ9IX+EhEFhljVqV7rLlATgi9MebsOLu+AEwmUujXAac73vcg4KdsUmJdh4hcAZwPnGUs55zLMdZZ/28WkTcIPAY2uaCk4VrWAT0d73tYbU1OAt+vaMfIic8lDrLic4l2HSKySUS6GmM2iEhXYLPHMezPZLWIfAwMJeAbP9DE8ze2+1SJSBHQDtia7Alz3nUjIv0db8cCy126vQ+MFJEO1gz9SKstaxCRUcBtwAXGmL0efcpEpI39msB1LG66UcZHPNcCzAX6i0gfESkhMOGUVZER8ZIrn0uc5MLn8jZgR85dDkQ8qVi/9RbW607AScDSJhthdOL5Gzuv8XvAR17GX1wc6BnoNMxgv0bgR7UQ+C/Q3WovB5509LsKWGn9u/JAj9vlOlYS8MnNt/7ZM+7dgMnW674EZugXAEsIPI4f8LEncy3W+9HA1wSsrGy9losI+FBrgU3A+zn8ucS8llz4XAj4qj8EVgAfAB2t9uBvHjgRWGR9JouAqw/0uMOuIeJvDNxHwDgCKAVesX5Lc4C+qZxPUyAoiqLkOTnvulEURVGio0KvKIqS56jQK4qi5Dkq9IqiKHmOCr2iKEqeo0KvKIqS56jQK4qi5Dn/H1uwf1MglKf4AAAAAElFTkSuQmCC\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# in the beginning, as we increase the learning rate, we are not so stable \n",
"# somewhere around 10**-1 = 0.1\n",
"plt.plot(lri, lossi)"
]
},
{
"cell_type": "code",
"execution_count": 83,
"id": "ca8ce989",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"2.43825101852417\n"
]
}
],
"source": [
"# over the late stage, loss increases, diverged!\n",
"print (min(lossi).item())"
]
},
{
"cell_type": "markdown",
"id": "e07358e8",
"metadata": {},
"source": [
"# 7. cross-validation\n",
"train, valid, test split\n",
"\n",
"80%, 10%, 10%\n",
"\n",
"param, hyperparam, evaluate (few times)"
]
},
{
"cell_type": "code",
"execution_count": 84,
"id": "52c638a6",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([182625, 3]) torch.Size([182625])\n",
"torch.Size([22655, 3]) torch.Size([22655])\n",
"torch.Size([22866, 3]) torch.Size([22866])\n"
]
}
],
"source": [
"def build_dataset(words):\n",
" block_size = 3\n",
" X, Y = [], []\n",
" for w in words:\n",
" context = [0] * block_size\n",
" for ch in w + '.':\n",
" ix = stoi[ch]\n",
" X.append(context)\n",
" Y.append(ix)\n",
" context = context[1:] + [ix]\n",
"\n",
" X=torch.tensor(X)\n",
" Y=torch.tensor(Y)\n",
" \n",
" print(X.shape, Y.shape)\n",
" return X, Y\n",
"\n",
"import random\n",
"random.seed(42)\n",
"# shuffle\n",
"random.shuffle(words)\n",
"train_perc = 0.8\n",
"dev_perc = 0.1\n",
"test_perc = 0.1\n",
"\n",
"n1=int(train_perc*len(words))\n",
"n2=int((train_perc + dev_perc)*len(words))\n",
"\n",
"Xtr, Ytr = build_dataset(words[:n1])\n",
"Xdev, Ydev = build_dataset(words[n1:n2])\n",
"Xte, Yte = build_dataset(words[n2:])"
]
},
{
"cell_type": "code",
"execution_count": 85,
"id": "5dabdf90",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(torch.Size([182625, 3]), torch.Size([182625]))"
]
},
"execution_count": 85,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# training dataset\n",
"Xtr.shape, Ytr.shape"
]
},
{
"cell_type": "code",
"execution_count": 86,
"id": "e7b6f9ed",
"metadata": {},
"outputs": [],
"source": [
"Vdim=27\n",
"Cdim=2 # hyperparamter\n",
"block_size = 3 # hyperparamter\n",
"n_input1 = Cdim * block_size\n",
"n_output1 = 100 # hyperparamter\n",
"n_input2 = n_output1\n",
"n_output2 = Vdim\n",
"\n",
"g = torch.Generator().manual_seed(2147483647)\n",
"C = torch.randn((Vdim, Cdim), generator=g)\n",
"W1 = torch.randn((n_input1,n_output1), generator=g)\n",
"b1 = torch.randn(n_output1, generator=g)\n",
"W2 = torch.randn((n_input2,n_output2), generator=g)\n",
"b2 = torch.randn(n_output2, generator=g)\n",
"parameters = [C, W1, b1, W2, b2]\n",
"\n",
"for p in parameters:\n",
" p.requires_grad=True"
]
},
{
"cell_type": "code",
"execution_count": 87,
"id": "a1b009dc",
"metadata": {},
"outputs": [],
"source": [
"# find a reasonable learning rate \n",
"lri = []\n",
"lossi = []\n",
"learning_rate = 1\n",
"\n",
"for i in range(100000):\n",
" # mini-batch\n",
" ix = torch.randint(0, Xtr.shape[0], (32,))\n",
" \n",
" # forward pass\n",
" # only grab 32 rows\n",
" emb = C[Xtr[ix]] # (32,3,2)\n",
" h = torch.tanh(emb.view(-1,6) @ W1 + b1)\n",
" logits = h @ W2 + b2\n",
" loss = F.cross_entropy(logits, Ytr[ix])\n",
" # backward pass\n",
"\n",
" # initalize all gradients\n",
" for p in parameters:\n",
" p.grad = None\n",
" \n",
" loss.backward()\n",
" \n",
" lr = 0.1 #lrs[i]\n",
" # update the parameters\n",
" for p in parameters:\n",
" p.data += -lr * p.grad\n",
" \n",
"# lri.append(lre[i])\n",
" lossi.append(loss)\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 88,
"id": "a4384ccb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train loss 2.321897268295288\n",
"dev loss 2.329921007156372\n"
]
}
],
"source": [
"# compute loss on dev dataset\n",
"emb = C[Xtr] # (32,3,2)\n",
"h = torch.tanh(emb.view(-1,n_input1) @ W1 + b1)\n",
"logits = h @ W2 + b2\n",
"loss = F.cross_entropy(logits, Ytr)\n",
"print ('train loss', loss.item())\n",
"\n",
"# can have even smaller loss with smaller learning rate\n",
"emb = C[Xdev] # (32,3,2)\n",
"h = torch.tanh(emb.view(-1,n_input1) @ W1 + b1)\n",
"logits = h @ W2 + b2\n",
"loss = F.cross_entropy(logits, Ydev)\n",
"print ('dev loss', loss.item())"
]
},
{
"cell_type": "markdown",
"id": "e90a7368",
"metadata": {},
"source": [
"# 8. use a bigger hidden layer"
]
},
{
"cell_type": "code",
"execution_count": 89,
"id": "e9cca6e5",
"metadata": {},
"outputs": [],
"source": [
"Vdim=27\n",
"Cdim=2 # hyperparamter\n",
"block_size = 3 # hyperparamter\n",
"n_input1 = Cdim * block_size\n",
"batch_size = 32\n",
"n_output1 = 300 # hyperparamter\n",
"n_input2 = n_output1\n",
"n_output2 = Vdim\n",
"\n",
"g = torch.Generator().manual_seed(2147483647)\n",
"C = torch.randn((Vdim, Cdim), generator=g)\n",
"W1 = torch.randn((n_input1,n_output1), generator=g)\n",
"b1 = torch.randn(n_output1, generator=g)\n",
"W2 = torch.randn((n_input2, n_output2), generator=g)\n",
"b2 = torch.randn(n_output2, generator=g)\n",
"parameters = [C, W1, b1, W2, b2]\n",
"\n",
"for p in parameters:\n",
" p.requires_grad=True"
]
},
{
"cell_type": "code",
"execution_count": 90,
"id": "699aceed",
"metadata": {},
"outputs": [],
"source": [
"# run this a bit more times to see if we can reduce loss further\n",
"stepi = []\n",
"lossi = []\n",
"\n",
"for i in range(100000):\n",
" # mini-batch\n",
" ix = torch.randint(0, X.shape[0], (batch_size,))\n",
" \n",
" # forward pass\n",
" # only grab 32 rows\n",
" emb = C[X[ix]] # (32,3,2)\n",
" h = torch.tanh(emb.view(-1,n_input1) @ W1 + b1)\n",
" logits = h @ W2 + b2\n",
" loss = F.cross_entropy(logits, Y[ix])\n",
" # backward pass\n",
"\n",
" # initalize all gradients\n",
" for p in parameters:\n",
" p.grad = None\n",
" \n",
" loss.backward()\n",
" \n",
" # update the parameters\n",
" # we could do learning rate decay\n",
" lr = 0.1 if i < 10000 else 0.01\n",
" for p in parameters:\n",
" p.data += -lr * p.grad\n",
" \n",
" stepi.append(i)\n",
" lossi.append(loss)"
]
},
{
"cell_type": "code",
"execution_count": 91,
"id": "00018c8d",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x12ada6be0>]"
]
},
"execution_count": 91,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXQAAAD4CAYAAAD8Zh1EAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAaqklEQVR4nO3deZwU9bnv8e+DILggQhgRxTBq3DBERESNieKCxyWJW+KVk6smJiGJ5lyN5h7HGDXRq4n3uETjkrjFJR40igYDirJGAQMZBGGQZZB1cDa2GbZhtuf80TVNDzO9MFtPFZ/369WvqfpVdddTXT3frq7V3F0AgPDrku0CAABtg0AHgIgg0AEgIgh0AIgIAh0AIqJrR06sb9++npub25GTBIDQmzt37np3z0k3XocGem5urvLz8ztykgAQema2OpPx2OQCABFBoANARBDoABARBDoARASBDgARQaADQEQQ6AAQEaEI9CmLS/Xk9OXZLgMAOrVQBPr0peV69sOV2S4DADq1UAQ6ACA9Ah0AIoJAB4CICE2gc+9TAEgtFIFulu0KAKDzC0WgAwDSI9ABICIIdACICAIdACIiNIHOMS4AkFooAp2DXAAgvVAEOgAgPQIdACKCQAeAiAhNoHPmPwCkljbQzewIM5tmZp+a2SIzuylo/7WZrTOz+cHj4vYq0jj3HwDS6prBOLWSbnX3j82sp6S5ZjYpGPaIuz/YfuUBADKVNtDdvVhScdC9xcwWSzq8vQsDAOyZPdqGbma5kk6WNDto+pmZLTCz582sd5LnjDazfDPLLy8vb121AICkMg50MztQ0lhJN7t7paSnJB0taYhia/APNfc8d3/a3Ye5+7CcnJzWVwwAaFZGgW5m3RQL81fc/U1JcvdSd69z93pJz0ga3n5lcoMLAEgnk6NcTNJzkha7+8MJ7f0TRrtcUkHblwcAyFQmR7mcKekaSQvNbH7Q9ktJo8xsiGLXzVol6cftUB8AIEOZHOUyQ81fH+udti8HANBSoTlTFACQWmgCnV2iAJBaKAKdM/8BIL1QBDoAID0CHQAigkAHgIgIT6CzVxQAUgpFoBu3iQaAtEIR6ACA9Ah0AIgIAh0AIoJAB4CICE2gc5ALAKQWikDn1H8ASC8UgQ4ASI9AB4CIINABICJCE+jcJBoAUgtFoLNPFADSC0WgAwDSI9ABICIIdACIiNAEOrtEASC1UAQ6Z4oCQHqhCHQAQHoEOgBEBIEOABFBoANARIQm0DnzHwBSSxvoZnaEmU0zs0/NbJGZ3RS09zGzSWZWGPzt3V5FGoe5AEBamayh10q61d0HSTpd0o1mNkhSnqQp7n6MpClBPwAgS9IGursXu/vHQfcWSYslHS7pUkkvBqO9KOmydqoRAJCBPdqGbma5kk6WNFtSP3cvDgaVSOqX5DmjzSzfzPLLy8tbUysAIIWMA93MDpQ0VtLN7l6ZOMxjFytvdreluz/t7sPcfVhOTk6LC3VO/geAlDIKdDPrpliYv+LubwbNpWbWPxjeX1JZ+5TI9dABIBOZHOVikp6TtNjdH04Y9Lak64Lu6ySNa/vyAACZ6prBOGdKukbSQjObH7T9UtLvJP3VzH4gabWkq9qlQgBARtIGurvPUPKtHue1bTkAgJYKzZmiAIDUQhPonPoPAKmFI9A5zAUA0gpHoAMA0iLQASAiCHQAiIjQBDr7RAEgtVAEurFXFADSCkWgAwDSI9ABICIIdACIiPAEOntFASClUAQ694gGgPRCEegAgPQIdACICAIdACKCQAeAiAhNoDuHuQBASqEIdA5yAYD0QhHoAID0CHQAiAgCHQAiIjSBzk2iASC1UAQ6p/4DQHqhCHQAQHoEOgBEBIEOABERmkBnnygApBaKQOcm0QCQXtpAN7PnzazMzAoS2n5tZuvMbH7wuLh9ywQApJPJGvoLki5spv0Rdx8SPN5p27IAAHsqbaC7+weSNnZALQCAVmjNNvSfmdmCYJNM72QjmdloM8s3s/zy8vJWTA4AkEpLA/0pSUdLGiKpWNJDyUZ096fdfZi7D8vJyWnh5CTn3H8ASKlFge7upe5e5+71kp6RNLxty2qMU/8BIL0WBbqZ9U/ovVxSQbJxAQAdo2u6EcxsjKQRkvqaWZGkuyWNMLMhip3vs0rSj9uvRABAJtIGuruPaqb5uXaoBQDQCqE4U1Ti1H8ASCcUgc4+UQBILxSBDgBIj0AHgIgg0AEgIgh0AIiI0AQ6Z/4DQGrhCHTO/QeAtMIR6ACAtAh0AIgIAh0AIoJAB4CICEWgs0sUANILRaADANIj0AEgIgh0AIiIUAU6N4oGgORCEeicKAoA6YUi0AEA6RHoABARBDoARASBDgAREapA5yAXAEguFIFunPwPAGmFItABAOkR6AAQEQQ6AEREqAKdfaIAkFwoAp1T/wEgvbSBbmbPm1mZmRUktPUxs0lmVhj87d2+ZQIA0slkDf0FSRfu1pYnaYq7HyNpStAPAMiitIHu7h9I2rhb86WSXgy6X5R0WduWBQDYUy3dht7P3YuD7hJJ/ZKNaGajzSzfzPLLy8tbOLkYrocOAMm1eqeox1I2adK6+9PuPszdh+Xk5LRoGuwTBYD0WhropWbWX5KCv2VtVxIAoCVaGuhvS7ou6L5O0ri2KQcA0FKZHLY4RtJHko4zsyIz+4Gk30kaaWaFks4P+gEAWdQ13QjuPirJoPPauBYAQCuE4kzRBhzjAgDJhSLQOfUfANILRaADANILRaDnr94kSarnxCIASCoUgT59aewM043bqrNcCQB0XqEI9AasoANAcuEK9GwXAACdWLgCnVV0AEgqZIGe7QoAoPMKVaADAJIj0AEgIkIV6GxyAYDkwhXoHOcCAEmFK9DJcwBIKlSBDgBIjkAHgIgIVaCzxQUAkgtVoNfVuyqrarJdBgB0SqEK9MemFOorv35fG7buzHYpANDphCrQJywsliSt38pldAFgd6EK9AYcjw4ATYUq0OvqY0FeW0egA8DuQhXoDf4wtTDbJQBApxPKQC8s3ZrtEgCg0wlloHOzaABoKqSBnu0KAKDzCWWgc5QLADQVzkAnzwGgCQIdACKia2uebGarJG2RVCep1t2HtUVR6azbvEMlFVU6tFePjpgcAIRCW6yhn+PuQzoqzBuc/tspKt/CNV0AoEEoN7k0OPW+ydpeXZvtMgCgU2htoLuk981srpmNbm4EMxttZvlmll9eXt7KyTVVsYPL6QKA1PpA/5q7D5V0kaQbzeys3Udw96fdfZi7D8vJyWnl5JpK3EFaV++qqatv82kAQBi0KtDdfV3wt0zSW5KGt0VRe+LTzyv11rwibamq0UWPfqBj7ni3o0sAgE6hxUe5mNkBkrq4+5ag+wJJ97RZZRn64Uv5TdpenLVK1301t6NLAYCsas0aej9JM8zsE0lzJE1w94ltU1br3P32IlXV1GlHdV22SwGADtPiQHf3Fe5+UvA40d3va8vCWmvovZN0wl3Nf798+nmlzn1oOvcnBRApoT5sMZXtwdr58rItTYb9fvIyrSjfplnLN3R0WQDQbiIb6A3Of/gD5eZN0Mr12+Jr5F3MJEnurp21dbr9zYVaz42nAYRcq079D5NzHpwuSbr/8sGauqRMklTnrncXlmjMnDWqrKrRo/9riLru00Xz127WG3PXah8z/ebSL2exagDI3F4T6A1++dbCePejkwvV76DY9WAmLCjWhAXFWvW7S3TZEzPj4xDoAMJirwv0RIVlW1VY1vh2dicm2ZEKYO/w8kerVL61WreMPDbbpeyxyG9D31PbdjvUcVqweQbA3uHOcYv02JRw3og+FIF+7vGHZG3a33/hX6qpq1f5lp2d+rj20soqVdXE6vusfKvmrNyY5YoAdLRQBPr5J/TL6vRfzy/SqfdN1gl3TdS2nemv7jhtSZlWlG9t0v741ELl5k1QfTM3Rf14zSbljV0gb+HdO067f4p+FJw1e95D/9BVf/pIm7dXqy7FDVinLimNfwlkavP26hbV1xLVtfXauC359HZUx45QyuQCbWWVVZq6pLRJu7unfQ/GzV+nM347JeV72ZlV1dRl/LmqqqlTxfbG7+cT05brg2V7fmG9eWs26d2FxXv8vFv/+ome+WBFs8MaVq5SWVC0WSvXb1PBuoo9nnbYhSLQg6MMsyZxR+pzM1bGuwvWVSg3b4KWl23R8rKtWrtxu6TYWv25D/1DkvSXf66Oh/uD7y+TFDu6ZndXPDlLr/5rrapq6vXWvCLl5k1QVU2d5q3ZpIkFxbrtjQWN/qmWlW7Rk9OXS4pt85OkDwvXN3rNIfdM0v99/RNJsVDaurNWb8wt0uRPS/XWvCJd/0K+fvP3RcrNm6DX/rVGm7dXq6SiqtFrbN1Zq+KKHZKkiQXFGnLPJD0wcYken7rrJ+nO2jrd9Oo8/X7yMm3cVq3t1bXxw0Cra+u1dbcvwav++JF++pe5Wla6Rc9+uEJ//+RzSdKaDdtVVlml3LwJuvGVj/WTv8zV0HsnafrSMs36bNe8bdxWrbItVRozZ43GzFmjO95aqDUbYu/97BUb9NT0z1RX7xo3f138y/OKp2bp+hfy9cDEJTruV+9q7upNGju3SA+9v0zH3zlRz364QnePK1DFjhqt3bhdc1dv1PKyrRr9Ur5uenW+iiuqtLi4UqWVsfdnRuF65eZN0NB7J+nZDxuHz8v/XK2CdRUaO7dIf5hSqKqaOhVt2q6nP/hM3/zDDF3x5K6d7hc88g/l5k3QztrYl0pxxY5GX5qzPluvy56YqQcmLtE1z82OL2tJ2razVoPumqgnpi3XG3OLJMW+oC59Yqb+Nm+dyrfs1JaqGh1/50T94vUFem9RiXb3yuzV8fdOkr71+AyddM/7Ou+h6fEvuv96b6mufX6O6utdlVU1ys2boNEv5aukokrXPDdbEwuKm6ykzFy+Xpc/OUs/feVjTSwoiX+h1NbV61d/W6h1m3fE652+tCx+Yb1x89dp7MdFuu+dxXpy+vL4tNZu3K6dtXUa8V/Tdep9k5WbN0FTl5Sqvj72hby0ZEv8Utrfenymznlwur7xhxlaUb5Vv3xrof42b128ttfz1+qucQWat2aTdtbW6c2PizToromq2F6T9Mu9qqZOv31nsbZX1+qteUWNTkqcvrRMbwef4YJ1Fbr08RnKzZugz4L/+5c/WqWV67c1+7ptzVq6RtgSw4YN8/z8ptdeSWfs3CLdGgRTZ3fryGP10KRlGY174zlHa2nJVs1fuzkegGa7riDZ98DubXp8/NnH5ugfe7CmdclX+mvCgtga1n//8DT9fcHnGjNnbaNxDuvVQ8Ny+8Q/0Il+d8Vg5b0Z+zK8ePChemdhicb86HSNeuafTcY99KAeKqmsatKeyomHHaRFn1fG+xffc2GzZwcf2feANv2H+vLhB6lgXWWT9sf//WRdMri/jrz9nbSv8dcfn6GePbrqokc/lCR179pFO2tjVwrt0a2Lnr32VP3v52anfI2nrzlFo1+e26jt68f0bfLFnuhHXz9SZVt26oErv6KbXp2n9xY1/dWSqYP376bNwdr8rSOP1XMzV+rEww7SDSO+pO8+27j2G0YcrV9ccJymLS3TD17M12lH9tErPzxNX0q4mF7i62XqwhMP1cSEL6orhw7Q2I+Lmh33ye8O1YjjcjTorvfibacM7K25qzc1O/4VQw9XFzPNWr5enwcrOvt27aLqYDkN6n+QNm6rjn9ue+3Xrcmvxfd/fpYueOQDSdLMvHN1+MH77dH8NTCzuZncRCgUgf72J5/r/4yZ1w4VAUDH+MnZRyvvouNb9NxMAz0Um1yO69cz2yUAQKsk++XQlsIR6IcS6ADCrSPugRyKQJekZ67t0HtQA0Cb6nvgvu0+jdAE+pAjDs52CQDQYsMG9mn3aYQm0HN6dtd/nPulbJcBAC3y/qdNDxtta6EJdEn6+fnHauGvL2jxoT8AkC0dcV5aqC7O1aWLqWePbpp8y9l6fuZKndC/p47OOVDFFVU6pGd3TV1Spv83YbFuGXmsJi8u1YKi2Jlib97wVQ04eD8Nv39Kk9d84MrBum3srhOHdj+uFQDCIhTHobfE4uJKXfToh3rwOyfp26cMkCTl5k2QJF13xkDd9c0TtaOmTgd27xpvl6Q5d5ynQ3r20KLPK3TJYzM6pFYA0ffa6NN12lFfaNFzI3Ucekuc0P8gzb9rZDzMJemj28/V8Nw++vnIY7VPF9OB3WM/UFbcf7EK77tIM247R4f0jF0f/cTDeunm84+JP/eovgfozm8M0vj/+Jok6bYLj9fj/36yvn5MX4278Uzde9mu66Zfe8bARrWsuP9ifSehjoevOklXn3qEfnZO430Co4YfoVtGHqsBvZtuUtqv2z6N+u/+5iC9d/NZ8f7zT+inab8YEe9/8frh8e6+B+6riwcfqoN6NP+D7MZzjpYkTb31bPXar5uk2Bl3k285q8m4N59/jK4cGpuX4/r11ANXDtbiey7UpUMOi48z+PBejZ5z/gn9dMXJhwfduy60dvWpR8S7rzl9oK4aNkCPjTpZb97wVUlS7/27acRxOTr84P100oBe6t+rR7P1J3roOyfppYR5l6SfnfMlffnwgyRJU249W4vvuVD3Xz5Y3z3ti5Kkn444Ov4e7O74Q3vqPy88Lt5/4zlH65Ce3TVq+K7ae3Rr+b9R4vuWzJ+/d6ruu/zLGjmon7731dx4+4/PPkqjzzoq7fMf/M5JzS7L3R3V94B497dOalzXxYMP1c3nH6P9uu2jUcO/mPQ1fnXJCfHul64fruODQ46Pyom99sH7xz5f/Q7qrvsvH5z0dT6+c6T23af59zXx83XvpSc2+t9q8IUDUh9R8sU++6cc3lIN/z8NvjJgV60tDfM94u4d9jjllFM8bJ6YVuifrN2U0bglFTt8edmWeP/SkkpfUlwZ76+urfOyyqomz6urq/eKHdWN2gpLK71g3eZGbYuLK3zgbeN94G3jk9YwdUmpT1pU4u7u23bWNDs9d/c38tf60bdP8J01dU2GVdfWeX19vbu7f7is3Nds2JZ0eok2bt3pdXX1GY2bqKa2aQ2p5K/a6LV19fEa3d3vHlfgj0xa2mi8ok3bvbRiR8avu6O61ues3OCr1m/1HdW1GS/3RIWlW7x4865p3vv3RZ6bN95rg/dl8/ZqX7Nhm9fXN66/waRFJT7oznf9883b48/Z3bKSyiafl7LKKh9423ifubw83la0aXujeXhp1kr/cFm5r16/zSt3VHv5lir/52frG73Ohq07fUd1bbx7WUmlN6fhczh1cWnaZf7+ohLfvK066fCBt433u8cVeG1dvW/YurPRsJKKHf7anDXu7j57xQavqqlNOS1398rgvZm2pNR3VNd6bV29V6f4jNXU1vn973wan+8d1bXNLpvN26s9f9VGf2Jaob84a2WjYXV19fH3YfX6bT7wtvH+wbIyd3d/+aNV/pOX89PWnYqkfM8gYyO7yQVA+/nzzJUafmQfnXhYr/Qjo9Uy3eQSqp2iADqH7595ZLZLQDMiuw0dAPY2BDoARASBDgARQaADQEQQ6AAQEQQ6AEQEgQ4AEUGgA0BEdOiZomZWLml1C5/eV1Ly25lHE/O8d2Ce9w6tmeeB7p6TbqQODfTWMLP8TE59jRLmee/APO8dOmKe2eQCABFBoANARIQp0J/OdgFZwDzvHZjnvUO7z3NotqEDAFIL0xo6ACAFAh0AIiIUgW5mF5rZUjNbbmZ52a5nT5jZEWY2zcw+NbNFZnZT0N7HzCaZWWHwt3fQbmb2WDCvC8xsaMJrXReMX2hm1yW0n2JmC4PnPGZm1vFz2pSZ7WNm88xsfNB/pJnNDup8zcz2Ddq7B/3Lg+G5Ca9xe9C+1Mz+LaG9030mzOxgM3vDzJaY2WIzOyPqy9nMfh58rgvMbIyZ9Yjacjaz582szMwKEtrafbkmm0ZKmdynLpsPSftI+kzSUZL2lfSJpEHZrmsP6u8vaWjQ3VPSMkmDJP1/SXlBe56kB4LuiyW9K8kknS5pdtDeR9KK4G/voLt3MGxOMK4Fz70o2/Md1HWLpP+WND7o/6ukq4PuP0r6adB9g6Q/Bt1XS3ot6B4ULO/uko4MPgf7dNbPhKQXJf0w6N5X0sFRXs6SDpe0UtJ+Ccv3e1FbzpLOkjRUUkFCW7sv12TTSFlrtv8JMngzz5D0XkL/7ZJuz3ZdrZifcZJGSloqqX/Q1l/S0qD7T5JGJYy/NBg+StKfEtr/FLT1l7Qkob3ReFmczwGSpkg6V9L44MO6XlLX3ZerpPcknRF0dw3Gs92XdcN4nfEzIalXEG62W3tkl7Nigb42CKmuwXL+tyguZ0m5ahzo7b5ck00j1SMMm1waPjQNioK20Al+Yp4sabakfu5eHAwqkdQv6E42v6nai5ppz7bfS/pPSfVB/xckbXb32qA/sc74vAXDK4Lx9/S9yKYjJZVL+nOwmelZMztAEV7O7r5O0oOS1kgqVmy5zVW0l3ODjliuyaaRVBgCPRLM7EBJYyXd7O6VicM89hUcmeNHzewbksrcfW62a+lAXRX7Wf6Uu58saZtiP5PjIrice0u6VLEvs8MkHSDpwqwWlQUdsVwznUYYAn2dpCMS+gcEbaFhZt0UC/NX3P3NoLnUzPoHw/tLKgvak81vqvYBzbRn05mSvmVmqyS9qthml0clHWxmXYNxEuuMz1swvJekDdrz9yKbiiQVufvsoP8NxQI+ysv5fEkr3b3c3WskvanYso/ycm7QEcs12TSSCkOg/0vSMcGe830V25nydpZryliwx/o5SYvd/eGEQW9LatjTfZ1i29Yb2q8N9pafLqki+Nn1nqQLzKx3sGZ0gWLbF4slVZrZ6cG0rk14raxw99vdfYC75yq2vKa6+3clTZP07WC03ee54b34djC+B+1XB0dHHCnpGMV2IHW6z4S7l0haa2bHBU3nSfpUEV7Oim1qOd3M9g9qapjnyC7nBB2xXJNNI7ls7lTZgx0SFyt2dMhnku7Idj17WPvXFPuptEDS/OBxsWLbDqdIKpQ0WVKfYHyT9EQwrwslDUt4reslLQ8e309oHyapIHjO49ptx1yW53+Edh3lcpRi/6jLJb0uqXvQ3iPoXx4MPyrh+XcE87VUCUd1dMbPhKQhkvKDZf03xY5miPRylvQbSUuCul5W7EiVSC1nSWMU20dQo9gvsR90xHJNNo1UD079B4CICMMmFwBABgh0AIgIAh0AIoJAB4CIINABICIIdACICAIdACLifwCVTs/l6kGHJQAAAABJRU5ErkJggg==\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.plot(stepi, lossi)\n",
"# there is flucturation because of mini-batch\n",
"# increase batch size to be more stable"
]
},
{
"cell_type": "code",
"execution_count": 92,
"id": "697a5cdb",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train loss 2.329075336456299\n",
"dev loss 2.3230931758880615\n"
]
}
],
"source": [
"# compute loss on dev dataset\n",
"emb = C[Xtr] # (32,3,2)\n",
"h = torch.tanh(emb.view(-1,n_input1) @ W1 + b1)\n",
"logits = h @ W2 + b2\n",
"loss = F.cross_entropy(logits, Ytr)\n",
"print ('train loss', loss.item())\n",
"\n",
"# can have even smaller loss with smaller learning rate\n",
"emb = C[Xdev] # (32,3,2)\n",
"h = torch.tanh(emb.view(-1,n_input1) @ W1 + b1)\n",
"logits = h @ W2 + b2\n",
"loss = F.cross_entropy(logits, Ydev)\n",
"print ('dev loss', loss.item())"
]
},
{
"cell_type": "code",
"execution_count": 93,
"id": "c7fde9ee",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAewAAAHSCAYAAAAuWvi9AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAyk0lEQVR4nO3de3zcdZ3v8fcnk0lIk5aUQi+mxUIpeCHUQilVBFLLKu1RqYq7Zb1Qi/RYceWy7lk9cPSwu6yrHpVl1XqKdgH3ElwV7MF2qyDh3qv0ztILF9PSC5SmJWmazOV7/shMSSa/SeaWSb6T1/PxyIPM/H6/+f3yZTrv+X5/34s55wQAAIa2ssG+AAAA0D8CGwAADxDYAAB4gMAGAMADBDYAAB4gsAEA8ED5YF9AOqeffrqbPHlyxvu3tbWpurp64C7IY5RNMMolPcomGOWSHmUTLNty2bhx4+vOuTOCtg3ZwJ48ebI2bNiQ8f5NTU1qaGgYuAvyGGUTjHJJj7IJRrmkR9kEy7ZczOyVdNtoEgcAwAMENgAAHiCwAQDwAIENAIAHCGwAADxAYAMA4AECGwAADxDYAAB4gMAGAMADBDYAAB4gsAEAyEE0FtexExHF4q4o5xuyc4kDADDUdERjWrl1v5Y27dGuQ60qLzNF407njq3RFxqmaF79BFWWhwbk3AQ2AAAZ2NTcooXL1ykSi6utMyZJisS6atcvHGzV7Q9u0x0rdui+RTM1bVJtwc9PkzgAAP3Y3Nyia5etUUt75GRYp2rrjKmlPaIFy9Zoc3NLwa+BwAYAoA8d0ZiuW75O7ZHgoE7VHunavyOa2f6ZokkcAIA+rNy6X5FYvMdzf/GBczR/ep3eaOvU/pZ2bd13TPc8+eLJ7ZFYXKu2HlBtAa+DGjYAAH1Y2rSnRzP4BRNP1dzzx2vePz6phcvXqX5iba9j2jpjWtq0u6DXQQ0bAIA0YnGnXYdaezw34+2j9bsdB9URjatD0qPPHww8duehVknVBbsWatgAAKTR1hlVeZnldGx5mSnuCjdGm8AGACCN6opyRVMmRtnwyhHNeec4VZaXaURFSB94x9jAY6NxpzLLLeyD0CQOAEAaoTLT1LE12nnwrWbxLXuP6pHnD2rVTZfp9dZOvXDwTb15ItLr2HPH1kiihg0AQFEsaZii6oqes5cte+JFfeC7j+uzy9eqrrZKW/cd7bG9uiKkJQ3nFPQ6CGwAAPowr36CwqGecfnNj9dr5Zffr9/8xWX6z20HtP3VYz22h0Nlmls/vqDXQZM4AAB9qCwP6b5FM7Vg2ZqTk6fc1Lgp7f5V4a79Cz2nODVsAAD6MW1SrRoXz1JtVbhX83hSdUVItVVhNS6eNSBziVPDBgAgA9Mm1WrtbXO0ausBLW3arZ09VusaqSUNUzS3fjyrdQEAMNgqy0OaP71O86fXKRZ3auuMqrqiXKEcx2png8AGACAHoTLTqFPCRTsf97ABAPAAgQ0AgAcIbAAAPEBgAwDgAQIbAAAPENgAAHiAwAYAwAMENgAAHiCwAQDwAIENAIAHCGwAADxAYAMA4AECGwAADxDYAAB4gMAGAMADBDYAAB4gsAEA8ACBDQCABwhsAAA8QGADAOABAhsAAA8Q2AAAeIDABgDAAwQ2AAAeILABAPAAgQ0AgAcIbAAAPEBgAwDggbwD28wmmdljZrbDzLab2U0B+5iZ3W1mu81si5ldmO95AQAYTsoL8BpRSX/pnPuDmY2UtNHMfuec29Ftn7mSpiZ+LpG0NPFfAACQgbxr2M65/c65PyR+f1PS85LqUna7WtL9rssaSbVmNiHfcwMAMFyYc65wL2Y2WdITks53zh3r9vzDkv7BOfdU4vGjkv7aObch5fjFkhZL0rhx4y5qbGzM+Nytra2qqanJ+28oRZRNMMolPcomGOWSHmUTLNtymT179kbn3IygbYVoEpckmVmNpF9Kurl7WGfDObdM0jJJmjFjhmtoaMj42KamJmWz/3BC2QSjXNKjbIJRLulRNsEKWS4F6SVuZmF1hfW/Oud+FbDLPkmTuj2emHgOAABkoBC9xE3STyU975z7XprdVkj6bKK3+CxJR51z+/M9NwAAw0UhmsQvlfQZSVvNbFPiuf8p6UxJcs79WNJKSfMk7ZZ0XNLnCnBeAACGjbwDO9GRzPrZx0m6Md9zAQAwXDHTGQAAHiCwAQDwAIENAIAHCGwAADxAYAMA4AECGwAADxDYAAB4gMAGAMADBDYAAB4gsAEA8ACBDQCABwhsAAA8QGADAOABAhsAAA8Q2AAAeIDABgDAAwQ2AAAeILABAPAAgQ0AgAcIbAAAPEBgAwDgAQIbAAAPENgAAHiAwAYAwAMENgAAHiCwAQDwAIENAIAHCGwAADxAYAMA4AECGwAADxDYAAB4gMAGAMADBDYAAB4gsAEA8ACBDQCABwhsAAA8QGADAOABAhsAAA8Q2AAAeIDABgDAAwQ2AAAeILABAPAAgQ0AgAcIbAAAPEBgAwDgAQIbAAAPENgAAHiAwAYAwAMENgAAHiCwAQDwAIENAIAHCGwAADxAYAMA4AECGwAADxDYAAB4gMAGAMADBDYAAB4gsAEA8ACBDQCABwhsAAA8QGADAOABAhsAAA8Q2AAAeIDABgDAAwQ2AAAeKEhgm9lyMztkZtvSbG8ws6Nmtinx8/VCnBcAgOGivECvc6+kH0i6v499nnTOfbhA5wMAYFgpSA3bOfeEpDcK8VoAAKC3Yt7Dfq+ZbTazVWb27iKeFwAA75lzrjAvZDZZ0sPOufMDto2SFHfOtZrZPEn/6JybGrDfYkmLJWncuHEXNTY2Znz+1tZW1dTU5Hr5JY2yCUa5pEfZBKNc0qNsgmVbLrNnz97onJsRtK0ogR2w78uSZjjnXk+3z4wZM9yGDRsyPn9TU5MaGhoy3n84oWyCUS7pUTbBKJf0KJtg2ZaLmaUN7KI0iZvZeDOzxO8zE+c9XIxzAwBQCgrSS9zM/l1Sg6TTzWyvpG9ICkuSc+7Hkq6RtMTMopLaJS1wharaAwAwDBQksJ1z1/az/QfqGvYFAABywExnAAB4gMAGAMADBDYAAB4gsAEA8ACBDQCABwhsAAA8QGADAOABAhsAAA8Q2AAAeIDABgDAAwQ2AKDkRGNxHTsRUSxeOstWFGQucQAABltHNKaVW/dradMe7TrUqvIyUzTudO7YGn2hYYrm1U9QZXlosC8zZwQ2AMB7m5pbtHD5OkVicbV1xiRJkVhX7fqFg626/cFtumPFDt23aKamTaodxCvNHU3iAACvbW5u0bXL1qilPXIyrFO1dcbU0h7RgmVrtLm5pbgXWCAENgDAWx3RmK5bvk7tkeCgTtUe6dq/I5rZ/kMJTeIAAG+t3LpfkVi8x3M3zj5Hn7iwTofbOrW/pV1b9x3TPU++eHJ7JBbXqq0HNH96XbEvNy/UsAEA3lratKdHM/j5daP0kWkTNO/uJ/W5f16vCybW9jqmrTOmpU27i3iVhUENGwDgpVjcadeh1h7PzZx8mlZvP6gTkbikuB55/mDgsTsPtSoWdwqVWRGutDCoYQMAvNTWGVV5joFbXmZq64wW+IoGFoENAPBSdUW5oikTo6x96Q198F3jVFlepuqKkOa8c1zgsdG4U3WFX43Mfl0tAAAJoTLT1LE12nnwrWbx7a8e08Nb9mvVTZfpcFuntuxtCTz23LE1XjWHS9SwAQAeW9IwRdUVPWcv++Fju/WB7z6uT/74Wb30eluvY6orQlrScE6xLrFgCGwAgLfm1U9QOJRdlIVDZZpbP36ArmjgENgAAG9Vlod036KZqgoHzxF+1yO7eozBrgp37e/jnOIENgDAa9Mm1apx8SzVVoV7NY8nVVeEVFsVVuPiWd7OJU6nMwCA96ZNqtXa2+Zo1dYDWtq0Wzt7rNY1Uksapmhu/Xgva9ZJBDYAoCRUloc0f3qd5k+vUyzu1NYZVXVFuXe9wdMhsAEAJSdUZhp1SniwL6OguIcNAIAHCGwAADxAYAMA4AECGwAADxDYAAB4gMAGAMADBDYAAB4gsAEA8ACBDQCABwhsAAA8QGADAOABAhsAAA8Q2AAAeIDABgDAAwQ2AAAeILABAPAAgQ0AgAcIbAAAPEBgAwDgAQIbAAAPENgAAHiAwAYAwAMENgAAHiCwAQDwAIENAIAHCGwAADxAYAMA4AECGwAADxDYAAB4gMAGAMADBDYAAB4gsAEA8ACBDQCABwhsAAA8QGADAOABAhsAAA8Q2AAAeKAggW1my83skJltS7PdzOxuM9ttZlvM7MJCnBcAgOGiUDXseyVd1cf2uZKmJn4WS1paoPMCADAsFCSwnXNPSHqjj12ulnS/67JGUq2ZTSjEuQEAGA6KdQ+7TlJzt8d7E88BAIAMmHOuMC9kNlnSw8658wO2PSzpH5xzTyUePyrpr51zG1L2W6yuJnONGzfuosbGxozP39raqpqamtz/gBJG2QSjXNKjbIJRLulRNsGyLZfZs2dvdM7NCNpWXrCr6ts+SZO6PZ6YeK4H59wyScskacaMGa6hoSHjEzQ1NSmb/YcTyiYY5ZIeZROMckmPsglWyHIpVpP4CkmfTfQWnyXpqHNuf5HODQCA9wpSwzazf5fUIOl0M9sr6RuSwpLknPuxpJWS5knaLem4pM8V4rwAAAwXBQls59y1/Wx3km4sxLkAABiOmOkMAAAPENgAAHiAwAYAwAMENgAAHiCwAQDwAIENAIAHCGwAADxAYAMA4AECGwAADxDYAAB4gMAGAMADBDYAAB4gsAEA8ACBDQCABwhsAAA8QGADAOABAhsAAA8Q2AAAeIDABgDAAwQ2AAAeILABAPAAgQ0AgAcIbAAAPEBgAwDgAQIbAAAPENgAAHiAwAYAwAMENgAAHiCwAQDwAIENAIAHCGwAADxAYAMA4AECGwAADxDYAAB4gMAGAMADBDYAAB4gsAEA8ACBDQCABwhsAAA8QGADAOABAhsAAA8Q2AAAeIDABgDAAwQ2AAAeILABAPAAgQ0AgAcIbAAAPEBgAwDgAQIbAAAPENgAAHiAwAYAwAMENgAAHiCwAQDwAIENAIAHCGwAADxAYAMA4AECGwAADxDYAAB4gMAGAMADBDYAAB4gsAEA8ACBDQCABwhsAAA8QGADAOABAhsAAA8Q2AAAeKAggW1mV5nZC2a228y+GrB9oZm9ZmabEj+fL8R5AQAYLsrzfQEzC0n6oaQ/kbRX0nozW+Gc25Gy6wPOuS/lez4AAIajQtSwZ0ra7Zx70TnXKalR0tUFeF0AAJBgzrn8XsDsGklXOec+n3j8GUmXdK9Nm9lCSd+U9JqknZJucc41B7zWYkmLJWncuHEXNTY2Znwdra2tqqmpyeMvKV2UTTDKJT3KJhjlkh5lEyzbcpk9e/ZG59yMoG15N4ln6P9J+nfnXIeZ/XdJ90n6QOpOzrllkpZJ0owZM1xDQ0PGJ2hqalI2+w8nlE0wyiU9yiYY5ZIeZROskOVSiCbxfZImdXs8MfHcSc65w865jsTDn0i6qADnBYCsRWNxHTsRUSyeX+siUGyFqGGvlzTVzM5SV1AvkPTn3XcwswnOuf2Jhx+V9HwBzgsAGemIxrRy634tbdqjXYdaVV5misadzh1boy80TNG8+gmqLA8N9mUCfco7sJ1zUTP7kqTVkkKSljvntpvZ30ja4JxbIenLZvZRSVFJb0hamO95ASATm5pbtHD5OkVicbV1xiRJkVhX7fqFg626/cFtumPFDt23aKamTaodxCsF+laQe9jOuZWSVqY89/Vuv39N0tcKcS4AyNTm5hZdu2yN2iOxtPt0hXhMC5atUePiWYQ2hixmOgNQkjqiMV23fF2vsN5+x4cC92+PdO3fEU0f7sBgIrABlKSVW/crEotndUwkFteqrQcG6IqA/BDYGFLowYtCWdq05+Q960y1dca0tGn3AF0RkJ9ijcMG0qIHLwotFnfadag1p2N3HmpVLO4UKrMCXxWQHwIbg4oevBgIbZ1RlZfZyfdSNsrLTG2dUY06JTwAVwbkjiZxDJpkD96W9kjapsu2zpha2iNasGyNNje3FPcC4a3qinJFc7ytEo07VVdQl8HQQ2BjUKTrwZsOPXiRjVCZaerY3Oa1PndsDc3hGJIIbAyKoB68E0dX6ZFbr9A3P16v395yue5fNFOV5W+9RenBi2wsaZii6ors+j5UV4S0pOGcAboiID8ENgZFuh68k8eM0M+efUUf/P4TOnYiornnTzi5jR68yMa8+gkKh3p/xL37G6vTHhMOlWlu/fiBvCwgZwQ2iq6vHrzNR9q1Y/8xSdK2fUc18bSqHtuTPXiB/lSWh3TfopmqCmdWy64Kd+3PiAQMVQQ2ii7ZgzdIZ/StZvJYXL32S/bgBTIxbVKtGhfPUm1VOG3zeHVFSLVVYaYlxZBHV0gUHT14UUzTJtVq7W1ztGrrAS1t2q2dPcb6j9SShimaWz+emjWGPD75UHTJHrw7D2Y/sQU9eJGLyvKQ5k+v0/zpdYrFndo6o6quKOe9BK/QJI5BEdSDd++Rdn3oridOPr7nyRd11yO7Tj6mBy8KIVRmGnVKeEiENVPxIhvUsDEo5tVP0B0rdkjKfFw1PXhRCpiKF7miho1BQQ/eoYWaXnFsam7RJXc+qtsf3KadB1vlXNdUvM69NRXvJXc+yqx+CEQNG4Mm2YP3upS5xLurrggpHCpjLvEB0FdNb8k7IuqIxviCVEDJqXj7mt2v699ATAuWraHXOnqhho1BlezBe+fH6nXeuBqZSeGQyUw6b9xI3fmxeq29bQ4fXAXWX01v35F2anoFxFS8KARq2Bh09OAtrkxqenHnTi66Qk0vf6lT8f71Vefp1ZYT+tmaVyRJN185VW0dMd3z5Isn90lOxTt/el3RrxdDEzVsDClDqQdvKaKmNzhSp+J9eMt+ffiCt6bd/W/1E/Twlld7HMNUvEhFYAPDSNCiK/1h0ZX8BE3Fu/3VYxpTU6mxIyv1zgkjdbQ9ov1HT/Q6lql40R1N4sAwkm7RlYmjq3Tfopna1Nyi+rpT9cyzz0qKSHqrpkfTbG6SU/FGYj2D9zdb92te/QSdMbJSD2/ZH3hscireUaeEi3GpGOKoYQODqJjDqfpadEWSzhpTrX9JrJT2Zlt7j23U9HKXbirehze/qo9Me5vmnj9ev9kaHNiRmNPvth/glgQkUcMGim6wJs5IV9NL2tfSrufS9Aqnppe7dFPx7jrUqurKkA4e69Brb3akPf7rv96uv334+YIMbYzG4joeidGp01PDJrB5o2Io2NTcooUp486TAZqcOOOOFTsGZNx5f4uuHA9oKk9i0ZX8LGmYotsf3NbrdsRVdz0ZuP8vl7xPn1j6jKT8x2Yzs1rpKOkm8Y5oTA8+t1cf/P7jmnr7Kl30t7/TObet1Ie+/7gefG4vzUwoquRwqpb2SOB9ZKnrwzk5nKrQY6CTNb1csOhKfubVT1A4lPnHbTKsu8ulxz4zq5WWkg1s3qgYSobKcKqgRVf6w6Ir+ct2Kt7td3wo8PlseuwP9hdEFF5JBjZvVAw1Q2U4VbqaXupKad2x6EphJKfira0KZ/2lKSnTsdlD5QsiCqvkAps3Koai1OFUF0w8VatuukyV5WWqCof021su17njejZXD8TEGSy6MriSU/H+7fzzc36NTHrsd/+COHF0lR699Qr9n09eoN//5RW668/eo0vPGaNffOG9euwrDZo28VRJjLf3QckFdmpN5pY/OVeLLp188vFXPniePtftscQbFQMraDjVlr1H9cjzB/WXHzxPX5v3Dj303L5evYilgRlOlUlNr8xMtVVhpiXNUDbD8yrLQ7ryXeMUDuXWJyDZY78vqV8Q3z5mhO554iXN+d7jmnJGja5+T52u+fGz+vuVz+vG2V23O5hZbegrucBOfaP+x4ZmffzCiZIkM+kj0ybowef29TiGNyoGUnI4Vaq7H92ly6aergvqTtWPH98TeGwmH8656G/RlbrRVSy60o98OrX212O/L/312A/6gth8pF0vHHxTzkk7D72pp3e/Lkn6rwPHNHF01cn9GG8/tJXUOI2gN+reI+06crxT737bKJ1eU6ntrx5Ty/FIr2OTb1R6wqLQ0n04146o0IiKkMrLTJXlocDbOAM5nKqvRVeamppoBu9DvsPz0o3NTuorMvvrsR803r4z+laro3Pu5GPnpFDZW/U2xtsPbSVVw05Xk3lgfbOuuWiiPnnRRP18Q3PgsQNVkwHSDaf6+4/V67u/3amHNr2qr859R+CxxRpOVcqLrqQ2V+c7u1yhOrWm67FfOyKsluOdgcdk0mN/IGvvGFwl9X8m3Rt19fYDuuVPzlW4rExfbnwu8FjeqBhIqRNnfPzCOkXjca3Y/KrKTPrVkvfpvVPG6Nk9h08ew3Cq3AVOFhJzCofK1BmLq7xMijllPXmIk3Lq1Lr2tjm9Xn9e/QTdsWKHpLdea+zISjUunqV7nnwp8PUy6bHfX+29L4y3H9pKKqHSvVEjMac1ew7r2ImI0n3x5I2KgZT64fyrP+zTr/7Q1Zci7qT5P+o9UQbDqXLTV3N1Z6JDarKFONvZ5Y62R3p0av3UJWfqU5ecKUkaeUpYe4+069p71vQ4Jt261ske+wu6rU1+6M0OfeC7jweeO5se+92/IKYO2fvKf2w5+Xv3bXxBHPpKqklcCm5mMpOmn1mrB9YHN4fzRsVAYzhVcWTSXJ1085VTdcNlZ2c1J8Nrb3b0eN1/XftHzbv7KX30B09r/9ET+slTL/Y6pq9OrZn02K+uCGXdYz/bmdUkviD6oOQCO/WNes7YGj3+ldl6es9hvXz4eOAxvFFRDAP14Ywu2c7BkKq/ORlicacTaV77Gx95t57d87oeff5Q4Pa+el/312P/zo/VZ91jny+IpamkmsSl3s1Muw+16vLvPJZ2f96oKKbkh/OqrQe0tGm3dvZYjGGkljRM0dz68bwfc5DJbHI3zj5Hn7iwTofbOrW/pV1b9x3rsT1d87XU1anVrPdts2sumqi60VX6+optac/bX+/rvnrs5yr5BfG6lNsD3VVXhBQOlQ3IYjMovJILbIk3Koa2gfhwRs85GBZffrY6o3Hd+8zL+l8ffqfeOWGU/n7l8/qziydqc/NRfe1XW/XwX7y/V2Anm6+DAru6olzO9awln183SjdcdrY++X+fkeujY3Y2nVqTPfYLgS+IpaUkA1vijQo/FPLDeThLnYNh/Utv6POXna17n3lZ9XW1qigv06yzTtOrLSf07J7Dau2I6pHnDwa+Vro5GaLxeK/7wte9d7JqR4TVeMMsSdKWfUf11V9u7fWag9mplS+IpaNkA1vijQr4LJs17FMnC9m676jq605VTWW5OqNxbX/1qN5WW6W31VZp3ctv9PlaQc3XyZ7nN0ztWY3+q19sST28l8ryMn3+srP73a8Y+ILot5IO7O54owJDX+D46bjrd7x06hwM0bhT85Hjuuaiidr4xyP6r/3HNOvsMRo/6hQ1v3Fc1RUhzXnnOP3b2j/2eq3U5utkz/P2SExx5zRxdJV+et3FaVc36/03xfV3D+/QueNGcvsNeSm5XuIA/LTxlSOa+XeP6LZfbc16Dfug2eTWv/yGbrj8bK176bDWv/yGrnzXOL1yuE2rbrpM9y6aqS17e7+O1LP5Ot+e50lHT0RZyhd5I7ABDJrkAhqXf/sxfWLpMzp6IqrjkeCe3v2Nl06dg2HdS29o7MhK/eGVFr3e2qmOSEyN65v1ge8+rk/++Fnd1LhJ9zzZc9x06pwM/fU8n3RalX7z5ffrgsQSlUkL3zdZj9zatZRlEkv5Il/DpkkcwNASNCNZJtJN95k6m9wzew5r6m2rTm5PN4NYd6lzMqSu/tfd2adX65/+fLq+8h+b9fz+N3ts+8yst+tTP1mrA8dO9Hi+r2FjQH+oYQMouqAZyUadUq5Pz3p7RscHrWEfMtOPPn1hxpOFpEqdkyFo9b+k06ordM9nZ+imxk29wvrO+edr0mkjdO+ii3X9+8/qsS2TpXzzXZwEpYsaNoCiSndfeFRVWJ+Z9Xb9y5pX9PEL63RDomf18/uP6dafb+6xbzL45taP79VJLRJzKrOu2nJHtO+JVKT0czIELVOZ9OaJiPa1tOviyaO1OyXUb3tom6447wxdu2yNjmS4lG+une0wvBDYAIoq3X3hv77qHXr7mBF69NYrNKamQrP/T5OOHI/o1Krg0R0vHGzVzDsfVTRgkY+4k+S6grtrzWdTLB60Wlf6ORn6WqYyEnP67z/bqPuvn6m2jphWbH41478/ddhYvmtrY/ggsAEUVbr7wt/6z//SueNG6t/W/VFn1FSerJ0ebe9dS03qa1tHIvROKS/T8oUX65Kzx5wM7kzmZOhvmcr2SEzX37teP7v+Eh3vjOqRNPOIp+o+bKz7kLGJo6v0qy/2Hi7WVVYxLVi2hjnmhznuYQMomr7uCw+UE9G4vvivf1A03lWrT87JkMkESkGr/3VfkvLYiaiu/uHTGYe19NawsWyHjNHLHAQ2gKJJ3hfuyzO7X9e8+vGqHdHVZJyuSTwbQZ3UMpHLMpV96T5sLOjWQKjM9M2P1+u3t1yu+xfNVGV5z3Pn+negNBDYAIqmr/vCrR1RVVeGtOtQq3742G49sPi9WnXTZfpfH35nj/0mjq7S6psv7/FcVTik5Qsv1qqbLtPqmy/Xhy+Y0GN7W2dMP2ranXXv62yXqUx6/7ceC+xw1n3YWNCtgcljRuhnz76iD37/CR07EdHc83v/Hf31Mkfp4h42gKLp675wy/GINr5yRKtvvlxNLxzKeOpPSbrivDN08NgJLbp3vSRpZGXvj7adB1t14d/8NtHZLPPe191X/yuz3Jujuw8bS3droPlIu3bs71pBbNu+o5p4WlXvvyPN4iQofdSwARRV0H3hpJsaN+lDdz2hb676rz5fo7zMdNefvUeP3HqFfvSpC/Xy6626bOrp+upV79DFk0frzY5o4HHRuDKa6jRVcvW/utFVOm9cjcykcMhkJp152giNCIc0Ihz8cVpdEVJtVbhHh7F0twY6uw1Di8UVuE+ylzmGH2rYAIoqdUayXEwZW6P/8cst2vjKEX37mgt02dSx+m93P6XZ543VVz54np7e87rufrTvpuNse19XlncF7+pbrujV07wjGstqKd++bg30J5u1tVFa+L8OoKiS94UXJIYz5WJfS7s2vnJEkvTgc/v0hcvP1v3PvqyHNu3TsRMRLbh4UsavlW6q076krv6X7VK+/Q0Z68tgrq2NwUWT+BDHNIUoRcn7wrVV4bTN4+melyTnev57qK4s10M3XqqVX36/bpozVf/0++w6ZhWy93Wmw8ZSbw10Hy4mSfc8+aLuemRXj2NSFyfB8EINewhimkIMB8n7wn01Jf/gsV3afait17ETR4/QhWfW6g9/bNHV73mb/nP7Af3kyZfSnmvi6Crd97mZ2rrvqM6vO1U7D76pW3++SScSK4Mle18Xc1GOXG4NpC5OguGFwB5imKYQw0l/TclOTrc/uK3X8Kc9h1r1mfdO1revOVW7Dr2pf1nzSr/nSr3v/ZlZk3ssr1ns3tfZ3hpIXZwEww9N4kNI0ApGqfpbExjwVVBTctDEJXuPtGvO9x7XLQ9s0pXfe1xL/uUPJ2vKfUm9733x5NE9tg9G7+tMbw2k9jLH8ERgDxF9TVP4yyXv6/Uc0xRiOMh14pIgqfe9U3uFDFbv6+StgTs/Vt9ryNh540bqzo/Va+1tcwhr0CQ+VKRbwUiSPrH0mcDnkx1linnfDSi27hOXdL9V1F11RUjlZaaln75I31ixTbsyuO+9/uU3emwfzN7X2fYyx/BEDXuISLeCkSRtv+NDgc8zTSGGi0xqoetuv1LvO+d0fXH2OYHNy8n73o/ceoVOrQr3uO89lHpfZ7M4CYYXathDQD4rGDFNIYaLTGuh6XpfR+NOtzywKfC16X0NH1DDHgIyWcEoHaYpxHDUVy002/ve9L6GLwoS2GZ2lZm9YGa7zeyrAdsrzeyBxPa1Zja5EOctFUxTCBRWau/r1ElJJHpfwz95B7aZhST9UNJcSe+SdK2ZvStlt+slHXHOnSPp+5K+le95S0lymsJcME0hEIze1yg1haiazZS02zn3oiSZWaOkqyXt6LbP1ZL+d+L3X0j6gZmZSx1nMYwtaZgSOEFEX4ZSRxlgKKL3NUpJIZrE6yQ1d3u8N/Fc4D7Ouaiko5LGFODcJSNogoikd39jdeDzdJQBMkfva/jO8q3kmtk1kq5yzn0+8fgzki5xzn2p2z7bEvvsTTzek9jn9ZTXWixpsSSNGzfuosbGxoyvo7W1VTU1uTUrDxXtkZhefK1N8Qz+n5SZ6ewzqjPqWFMKZTMQKJf0KJtglEt6lE2wbMtl9uzZG51zM4K2FaJJfJ+k7mvZTUw8F7TPXjMrl3SqpMOpL+ScWyZpmSTNmDHDNTQ0ZHwRTU1Nymb/oWpzc0u/E0SEQ2VZzSVeKmVTaJRLepRNMMolPcomWCHLpRCBvV7SVDM7S13BvEDSn6fss0LSdZKelXSNpN9z/zpYJisYza0fzxAUABhm8g5s51zUzL4kabWkkKTlzrntZvY3kjY451ZI+qmkn5nZbklvqCvUkQYdZQAAqQoygNc5t1LSypTnvt7t9xOSPlmIcw03yY4yAIDhjZnOAADwAIENAIAHCGwAADxAYAMA4AECGwUVjcV17EREsRwXMwEABGOZJ+StIxrTyq37tbRpj3b1GDdeoy80TNG8+gmMGweAPBHYyMum5hYtTJmZLRLrql2/cLBVtz+4TXes2JHVzGwAgN5oEkfONje36Npla9TSHkm7ylhbZ0wt7REtWLZGm5tbinuBAFBCCGzkpCMa03XL16k9ktlyoO2Rrv07opkvHwoAeAuBjZys3LpfkVj85OOJo6u0+ubLTz6+4bKzdfOVU3scE4nFtWrrgaJdIwCUEgIbOVnatCdtM3g6bZ0xLW3aPUBXBACljcBG1mJxp12HWnM6duehVoZ8AUAOCGxkra0zqvKUlcOiMafuT1WGg99a5WWmts7oQF4eAJQkAhtZ23Ww9eTQraTXWzs0pqZStSPCqgiVac47xgYeG407VVcwmhAAssUnJ7KyublFn/7J2l7PR+NOdz+6S7++8VIdOHZCe14LbjI/d2wN63oDQA4IbGSsv6Fc9z7zsu595uW0x1dXhLSk4ZwBujoAKG0ENjKWOpQraf576rTw0smqCJk2Nbfo9oe2KahfWThUprn144twpQBQeriHjYwFDeWackaNPjxtgq5Z+ozm3f2UYnFp/vS6XsdWhUO6b9FM5hQHgBxRw0ZG0g3luvScMaqvO1UrvnSpJKkyHNLhto5e+/3bDZcwlzgA5IHARkaSQ7lSe4ebmX65ca++vfqFtMeWl0lTxtYM9CUCQEmjSRwZqa4oVzTgxvTTu1/X3PoJGlNdIUk6tSqsutqqHvvEnHIeysX62gDQhRo2MhIqM00dW6OdB3s2i+8+1Krv/vYF/ez6mTIzRWNOX//1Nu1raT+5T7ZDuVhfGwB6I7CRsSUNU3T7g9t6dTx7eMt+Pbxlf+Ax2Q7lYn1tAAhGkzgyNq9+gsKh7N4y2QzlYn1tAEiPwEbGKsu7hmZVhTNrjs5mKBfrawNA3whsZGXapFo1Lp6l2qqwqiuCg7i6IqTaqrAaF8/KuNk63aQs17//LK2++XKtvvlyLbp0co9trK8NYDjhHjayNm1SrdbeNkerth7Q0qbd2tmjY9hILWmYorn147PqGBY0Kcv5daP0yRkTNf+HT8tMeujGS7X2pTe0/dVjkt5aXztoopbBEo3FdTwSU3VFOXOmAygoAhs5qSwPaf70Os2fXqdY3KmtM5pzSKWblOXiyadp9faDJ5vJ/3PbAV08+bSTgS29tb72YIYjvdoBFAOBjbyFykyjTgnnfHy6SVkykVxfO5/z54Ne7QCKhXvYGHTpJmVZ99Ib+uC7xumUcJmqwiF96N3jtf7lN3rsM5jra9OrHUAxEdgYdMlJWVJtf/WYfrFxr3594/v10I2X6oH1f+zRHC4N3vra9GoHUGw0iWNISDcpy0+fekk/feqlwGMGc33tdL3al33mIk04tUqV4TL989Mv6d/XNZ/cluzVPpQ6yQHwBzVsDAkDPSlLoQX1apekv/rFFn3kB0/pI//0lBa+7yzVjnjr3nqyVzsA5IIaNoaE5KQsC5atyaiZeTDX107Xq12SPnfpZH3o3V1fIibUnqKzxlTrueMtJ7cPhV7tAPxEDRtDxkBNylJoyV7tqWadfZouPed0fexHT2vuPz6pHa8eU2W45z+xZK92AMgWNWwMKQMxKUuhpevVPvKUsI62R3QiEteUM6o1PeALxWD2agfgNz45MOQUclKWgZBuqdHHX3hNn7rkTD1y6xV68bVWPRcwjGuwerUD8B+BjSEt30lZBkpQr/bOWFwL/3l92mMGs1c7AP9xDxvIgW+92gH4j8AGcjCQS40CQBACG8iRL73aAZQG7mEDefChVzuA0kBgA3ka6r3aAZQGmsSBAkr2aiesh69oLK5jJyKKBYzVB/JBDRsA8tQRjWnl1v1a2rRHu3rcFqnRFxqmaF79BG6LIG8ENgDkYVNzixYuX6dILH5yXH4k1lW7fuFgq25/cJvuWLFD9y2aScdD5IUmcQDI0ebmFl27bI1a2iOBq7dJXau0tbRHtGDZGm0OmP0OyBSBDQA56IjGdN3ydRmtLidJ7ZGu/Tuime0PpCKwASAHK7fuVyQWz+qYSCyuVVsPDNAVodQR2ACQg6VNe9I2g6fT1hnT0qbdA3RFKHUENgBkKRZ32nWotf8dA+w81MqQL+SEwAaALLV1RlWe41j78jJTW2e0wFeE4YDABoAsVVeUK5pjLTkad6quYEQtskdgA0CWQmWmqWNr0m7/54UXa+zIysBt546tYSY85ITABoAcLGmYknaVts/du16H3uzo9Xx1RUhLGs4Z6EtDiSKwASAH8+onKBzK7iM0HCrT3PrxA3RFKHUENgDkoLI8pPsWzVRVOLM5wqvCXfszpzhyRWADQI6mTapV4+JZqq0Kp20er64IqbYqrMbFs5hLHHmhqyIA5GHapFqtvW2OVm09oKVNu7Wzx2pdI7WkYYrm1o+nZo28EdgAkKfK8pDmT6/T/Ol1isWd2jqjqq4opzc4CorABoACCpWZRp0SHuzLQAniHjYAAB4gsAEA8ACBDQCABwhsAAA8kFdgm9lpZvY7M9uV+O/oNPvFzGxT4mdFPucEAGA4yreG/VVJjzrnpkp6NPE4SLtz7j2Jn4/meU4AAIadfAP7akn3JX6/T9L8PF8PAAAEyDewxznn9id+PyBpXJr9TjGzDWa2xszm53lOAACGHXOu70XYzewRSUHLy9wm6T7nXG23fY8453rdxzazOufcPjM7W9LvJc1xzu0J2G+xpMWSNG7cuIsaGxsz/kNaW1tVU5N+fdrhjLIJRrmkR9kEo1zSo2yCZVsus2fP3uicmxG0rd/A7ouZvSCpwTm338wmSGpyzp3XzzH3SnrYOfeLvvabMWOG27BhQ8bX0tTUpIaGhoz3H04om2CUS3qUTTDKJT3KJli25WJmAxbY35F02Dn3D2b2VUmnOef+R8o+oyUdd851mNnpkp6VdLVzbkc/r/2apFeyuJzTJb2e3V8wbFA2wSiX9CibYJRLepRNsGzL5e3OuTOCNuQb2GMk/VzSmeoK1z91zr1hZjMkfcE593kze5+k/ysprq575nc5536a80nTX8uGdN9KhjvKJhjlkh5lE4xySY+yCVbIcslr8Q/n3GFJcwKe3yDp84nfn5FUn895AAAY7pjpDAAAD5RSYC8b7AsYwiibYJRLepRNMMolPcomWMHKJa972AAAoDhKqYYNAEDJ8jaws1h45Ewz+62ZPW9mO8xscpEvtegyLZvEvqPMbK+Z/aCY1zgYMikXM3uPmT1rZtvNbIuZ/dlgXGuxmNlVZvaCme1ODM1M3V5pZg8ktq8dDv9+pIzK5dbE58kWM3vUzN4+GNc5GPorm277fcLMXGLUUMnLpFzM7E8T75vtZvZv2Z7D28BW5guP3C/pO865d0qaKelQka5vMGVaNpL0t5KeKMpVDb5MyuW4pM86594t6SpJd5lZbfEusXjMLCTph5LmSnqXpGvN7F0pu10v6Yhz7hxJ35f0reJeZfFlWC7PSZrhnLtA0i8kfbu4Vzk4MiwbmdlISTdJWlvcKxwcmZSLmU2V9DVJlyY+X27O9jw+B3a/C48kCqzcOfc7SXLOtTrnjhftCgdPRouymNlF6pr//bfFuaxB12+5OOd2Oud2JX5/VV1f8AInMSgBMyXtds696JzrlNSorjLqrnuZ/ULSHDOzIl7jYOi3XJxzj3X7LFkjaWKRr3GwZPKekboqAt+SdKKYFzeIMimXGyT90Dl3RJKcc1lXHn0O7EwWHjlXUouZ/crMnjOz7yS+CZW6fsvGzMokfVfSV4p5YYMs08VqJElmNlNShaRe896XiDpJzd0e7008F7iPcy4q6aikMUW5usGTSbl0d72kVQN6RUNHv2VjZhdKmuSc+00xL2yQZfKeOVfSuWb2dGIhrKuyPUleE6cMtH4WHjnJOefMLKi7e7mkyyRNl/RHSQ9IWiip4DOtFVsByuaLklY65/aWUoWpAOWSfJ0Jkn4m6TrnXLywV4lSYWafljRD0hWDfS1DQaIi8D11fc6ip3JJUyU1qKtF5gkzq3fOtWTzAkOWc+7KdNvM7KCZTei28EhQ88JeSZuccy8mjnlI0iyVQGAXoGzeK+kyM/uipBpJFWbW6pzr6373kFeAcpGZjZL0G0m3OefWDNClDgX7JE3q9nhi4rmgffaaWbmkUyUdLs7lDZpMykVmdqW6vghe4ZzrKNK1Dbb+ymakpPMlNSUqAuMlrTCzjyZmwCxVmbxn9kpa65yLSHrJzHaqK8DXZ3oSn5vEV0i6LvH7dZJ+HbDPekm1Zpa8B/kBSX0uOlIi+i0b59ynnHNnOucmq6tZ/H7fwzoD/ZaLmVVIelBd5dHninIlYL2kqWZ2VuLvXqCuMuque5ldI+n3rvQnb+i3XMxsurrWSPhoLvciPdZn2TjnjjrnTnfOTU58tqxRVxmVclhLmf1bekhdtWtZ10JY50p6MZuT+BzY/yDpT8xsl6QrE49lZjPM7CeS5JyLqSuMHjWzrZJM0j2DdL3F1G/ZDFOZlMufSrpc0kIz25T4ec+gXO0AS9yT/pKk1ZKel/Rz59x2M/sbM/toYrefShpjZrsl3aq+RxyUhAzL5Tvqapn6j8R7JPXDuSRlWDbDToblslrSYTPbIekxSX+VWI8jY8x0BgCAB3yuYQMAMGwQ2AAAeIDABgDAAwQ2AAAeILABAPAAgQ0AgAcIbAAAPEBgAwDggf8PhUA2U68NcKUAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 576x576 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"# visualize the embedding\n",
"\n",
"plt.figure(figsize=(8,8))\n",
"plt.scatter(C[:,0].data, C[:,1].data, s=200)\n",
"for i in range(C.shape[0]):\n",
" # plt.text(x,y, the actual text)\n",
" plt.text(C[i,0].item(), C[i,1].item(), itos[i], ha='center', va='center', color='white')\n",
"plt.grid('minor')\n",
"\n",
"# visualize the char with the 2-D embeddin\n",
"# C learns to separate stuff\n",
"# aeiou are similar "
]
},
{
"cell_type": "markdown",
"id": "ce5424d2",
"metadata": {},
"source": [
"# 9. increase embedding size"
]
},
{
"cell_type": "code",
"execution_count": 94,
"id": "4825573a",
"metadata": {},
"outputs": [],
"source": [
"Vdim=27\n",
"Cdim=10 # hyperparamter\n",
"block_size = 3 # hyperparamter\n",
"n_input1 = Cdim * block_size\n",
"batch_size = 32\n",
"n_output1 = 200 # hyperparamter\n",
"n_input2 = n_output1\n",
"n_output2 = Vdim\n",
"\n",
"g = torch.Generator().manual_seed(2147483647)\n",
"C = torch.randn((Vdim, Cdim), generator=g)\n",
"W1 = torch.randn((n_input1,n_output1), generator=g)\n",
"b1 = torch.randn(n_output1, generator=g)\n",
"W2 = torch.randn((n_input2, n_output2), generator=g)\n",
"b2 = torch.randn(n_output2, generator=g)\n",
"parameters = [C, W1, b1, W2, b2]\n",
"\n",
"for p in parameters:\n",
" p.requires_grad=True"
]
},
{
"cell_type": "code",
"execution_count": 95,
"id": "fe484cca",
"metadata": {},
"outputs": [],
"source": [
"# run this a bit more times\n",
" \n",
"# find a reasonable learning rate \n",
"lri = []\n",
"lossi = []\n",
"stepi = []"
]
},
{
"cell_type": "code",
"execution_count": 96,
"id": "97328a1f",
"metadata": {},
"outputs": [],
"source": [
"for i in range(100000):\n",
" # mini-batch\n",
" ix = torch.randint(0, Xtr.shape[0], (32,))\n",
" \n",
" # forward pass\n",
" # only grab 32 rows\n",
" emb = C[Xtr[ix]] # (32,3,2)\n",
" h = torch.tanh(emb.view(-1,n_input1) @ W1 + b1)\n",
" logits = h @ W2 + b2\n",
" loss = F.cross_entropy(logits, Ytr[ix])\n",
" # backward pass\n",
"\n",
" # initalize all gradients\n",
" for p in parameters:\n",
" p.grad = None\n",
" \n",
" loss.backward()\n",
" \n",
" lr = 0.1 if i < 10000 else 0.01\n",
" # update the parameters\n",
" for p in parameters:\n",
" p.data += -lr * p.grad\n",
" \n",
"# lri.append(lre[i])\n",
" stepi.append(i)\n",
" lossi.append(loss.log10().item()) # plot log-loss to squash it in"
]
},
{
"cell_type": "code",
"execution_count": 97,
"id": "91727b0a",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x12b1601d0>]"
]
},
"execution_count": 97,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAArvklEQVR4nO3dd3wUZf4H8M83nUAgkAJIAqFJEaRFBEFp0j1Qz1NQ8fRUPCvWO1BBsaHe/TjPEwXOwp0nIFaQqiiCgCCh10BCTWgBklDTn98fO7uZ3ezubJJNNjN83q8XL3Znnp35TmbnO888zzOzopQCERFZS1CgAyAiIv9jcicisiAmdyIiC2JyJyKyICZ3IiILCgnUimNjY1VSUlKgVk9EZEobN248pZSKMyoXsOSelJSElJSUQK2eiMiUROSQL+XYLENEZEFM7kREFsTkTkRkQUzuREQWxORORGRBTO5ERBbE5E5EZEGmS+55hcX4cmMG+KhiIiLPAnYTU0W9tXQPPllzEDF1wtCvTXygwyEiqpFMV3M/eS4fAHA+ryjAkRAR1VymS+5ERGTMtMmdLe5ERJ6ZLrlLoAMgIjIB0yV3IiIyZtrkzqGQRESemS65i7BhhojIiGFyF5GPReSkiOwwKHeNiBSJyG3+C4+IiCrCl5r7LABDvBUQkWAAbwH43g8xERFRJRkmd6XUKgBnDIo9DuArACf9EZQ3bJQhIjJW6TZ3EWkC4BYAH/hQdqyIpIhISlZWVmVXTUREHvijQ/UdAH9VSpUYFVRKzVRKJSulkuPiDH+8m4iIKsgfDw5LBjBXG8USC2CYiBQppb71w7I94khIIiLPKl1zV0o1V0olKaWSAHwJ4JGqTOx9rrTV+JvFRFbVKoiITM+w5i4icwD0BRArIhkAXgIQCgBKqelVGp0bdSJsIYeFmG6IPhFRtTFM7kqp0b4uTCl1b6WiKQc2yxAReWa66i+HQhIRGTNdciciImNM7kREFmS65M4HhxERGTNdciciImOmTe4cLUNE5JnpkjsbZYiIjJkuuRMRkTEmdyIiCzJtcldgozsRkSemS+4cCUlEZMx0yZ2IiIyZNrlzKCQRkWemS+5sliEiMma65E5ERMZMm9zZKkNE5JnpkrvwHlUiIkOmS+5ERGTMtMldcbgMEZFH5kvubJUhIjJkvuRORESGDJO7iHwsIidFZIeH+XeJyDYR2S4ia0Wkk//D1NFaY4pL2CxDROSJLzX3WQCGeJl/AEAfpVRHAK8CmOmHuDz6ftcJAMAHP6dX5WqIiEwtxKiAUmqViCR5mb9W93YdgAQ/xOXRqfP5AIDMnEtVuRoiIlPzd5v7/QCWeJopImNFJEVEUrKysiq1Iv5QNhGRZ35L7iLSD7bk/ldPZZRSM5VSyUqp5Li4OH+tmoiIXBg2y/hCRK4G8CGAoUqp0/5YJhERVVyla+4i0hTA1wDGKKX2Vj4k73jvEhGRMcOau4jMAdAXQKyIZAB4CUAoACilpgOYBCAGwPtaO3iRUiq5qgImIiJjvoyWGW0w/wEAD/gtIiIiqjTT3qHKsTJERJ6ZNrmz6Z2IyDPTJnciIvLMdMk9v6gYALD72NkAR0JEVHOZLrmnHj8X6BCIiGo80yV3IiIyZrrkzkfKEBEZM11yJyIiY6ZL7sIR7kREhkyX3ImIyBiTOxGRBTG5ExFZkOmSO0fLEBEZM11yJyIiY6ZL7hGhwYEOgYioxjNdcn/g+uaBDoGIqMYzXXKPCGHNnYjIiOmSOxERGTNdcudoGSIiY6ZL7oo/wUREZMgwuYvIxyJyUkR2eJgvIvKuiKSJyDYR6er/MEvlaT/WQUREnvlSc58FYIiX+UMBtNb+jQXwQeXD8owPDiMiMmaY3JVSqwCc8VJkJID/Kpt1AKJFpLG/AnTFNnciImP+aHNvAuCI7n2GNq1KHMvNq6pFExFZRrV2qIrIWBFJEZGUrKysCi0j9Th/GJuIyIg/knsmgETd+wRtWhlKqZlKqWSlVHJcXFyFVsY2dyIiY/5I7gsA3KONmukBIFcpdcwPyyUiogoKMSogInMA9AUQKyIZAF4CEAoASqnpABYDGAYgDcBFAPdVVbBEROQbw+SulBptMF8BeNRvERngaBkiImOmu0OVyZ2IyJjpkrvehfyiQIdARFQjmS6560fLvLZoVwAjISKquUyX3PUjIXMvFQYuDiKiGsx8yV2HyZ2IyD1TJ/fM7EuBDoGIqEYyXXLXD5Y5ePpiwOIgIqrJzJfcORaSiMiQ+ZJ7oAMgIjIB0yX30GDThUxEVO1MlynrhAcHOgQiohrPdMmdiIiMMbkTEVmQ6ZI7R8sQERkzXXInIiJjTO5ERBbE5E5EZEFM7kREFmS65N6jRYzT+7zC4gBFQkRUc5kuuTesGx7oEIiIajzTJXcOhSQiMuZTcheRISKSKiJpIjLezfymIrJCRDaLyDYRGeb/UN3befRsda2KiMg0DJO7iAQDmAZgKID2AEaLSHuXYi8CmKeU6gJgFID3/R2oIx6X919uPFJVqyIiMi1fau7dAaQppfYrpQoAzAUw0qWMAlBXe10PwFH/heidUtW1JiIi8/AluTcBoK8eZ2jT9F4GcLeIZABYDOBxdwsSkbEikiIiKVlZWRUIl4iIfOGvDtXRAGYppRIADAPwqYiUWbZSaqZSKlkplRwXF1ehFbE/lYjImC/JPRNAou59gjZN734A8wBAKfUrgAgAsf4I0FV8VERVLJaIyFJ8Se4bALQWkeYiEgZbh+kClzKHAQwAABFpB1tyr5J2l9g6YU7vT5zNq4rVEBGZmmFyV0oVAXgMwDIAu2EbFbNTRF4RkRFasWcAPCgiWwHMAXCvUtXT1bkilW33RESuQnwppJRaDFtHqX7aJN3rXQB6+Tc0IiKqKNPdoUpERMYskdyTxi8KdAhERDWK6ZI771kiIjJmvuTO7E5EZMh0yd2Tc3mFgQ6BiKjGMF1yVx4aZqb+sLeaIyEiqrlMl9yDPTx/oKiY7TVERHamS+4hwe5D5jNniIhKmS65ExGRMcskd1bciYhKWSa551ziaBkiIjvLJPf5W45i59FcPPrZJhQVlwQ6HCKigLJMcgeA4e+uxqLtx3Dw9MVAh0JEFFCWSu5ERGRj2eSennUeeYXFSM86j2p6tDwRUY3h0/PczeZ8fhFunrYGLeNqIz3rAiYMbYuH+rQMdFhERNXGkjX3SwXFAID0rAsAgClL9uC3A2cCGRIRUbWyZHIvcdMMc/uMXwMQCRFRYFgyud/14fpAh0BEFFCWTO5ERJc7JnciIgvyKbmLyBARSRWRNBEZ76HM7SKyS0R2ishs/4ZJRETlYTgUUkSCAUwDMBBABoANIrJAKbVLV6Y1gAkAeimlskUkvqoCriqnz+ejWCnER0UEOhQiokrzpebeHUCaUmq/UqoAwFwAI13KPAhgmlIqGwCUUif9G2bVUEphypLd2Hk0F91eW47ur/8Y6JAuS3mFxTiemxfoMIgsxZfk3gTAEd37DG2a3pUArhSRNSKyTkSGuFuQiIwVkRQRScnKyqpYxH50qbAYM1bux20fcJhkIP1p1gb0mMITK5E/+atDNQRAawB9AYwG8G8RiXYtpJSaqZRKVkolx8XF+WnVvissLnH6IW37cPhLhcXVHguVWpt+OtAhEFmOL8k9E0Ci7n2CNk0vA8ACpVShUuoAgL2wJfsapfULS9Dx5e/xt2V7cOKs92aAD35OR+rxcwCAT9cdQtL4RTh9Pr86wiQvtmXk4MNf9gc6DKIaz5fkvgFAaxFpLiJhAEYBWOBS5lvYau0QkVjYmmlq7BE4bUU6rn3jR+w6dtbtfKUU3lq6ByPeW43sCwX4+7JUAEBG9qXqDJPcGPHeGry2aHegwyCq8QyTu1KqCMBjAJYB2A1gnlJqp4i8IiIjtGLLAJwWkV0AVgB4TilV46+1/zDde1t7flEJbnh7BXK1X3l6et6WaojK5siZi1i641iVriOvsBhvLd2DPDZLEVmOT0+FVEotBrDYZdok3WsF4Gntn6l9tTEDreLrON6fyy9yvLY/iKw6DH/3F5zNK8LBN4dX2TpmrT2ID35OR63QYDwxoMa1ohFRJVjykb+V8cwXWwMdAgDgbF6R0/sjZy4isUGkX9dRUGT7OcLCGvSzhPO3ZCI4SHDT1VcEOhQiUzPl4wd+3zUhYOtekVr9Q/jXpp/C9W+vwLebXfuxrWfc3C14bPbmQIcRcKfO5yPnYkGgwyATM2Vyv7lL4Gp1W4/kVPs67aN2tgRg3eRfX27MwIFTxs17ya8tR+dXfqiGiMiqTJncA0kgbqdP/T4V01em+7SMpz7fghe/3e5T2eIS8/xEYGFxCaYs3u22xvnprwfxw64TAYiqZnn2i60Y9s9f/La8MxcK8PiczVi97xSSxi/Cbg8jwMpj0bZjSDnIH7cpr3FzNyNp/KJAh+HA5O4n7/6UhjeX7PGp7DebM/G/dYcd7w+fvoh3lu91/NbrY7M3Oebd+e91Pi0z+0L5L+HL89OySinMWJle5jEBJSUKa9JOQSmF8V9tx4xV+90OVZw4fyce/G9KuWP0p3kbjuCnPbYTTNrJ8ziWG5ihrf68ae6pz7fgu61HcfdHtt8wWLrjuM+f3Xk0FyVuKg+Pzt6E2wxGkrkqLC7Bwm1HTfd7xQ99moI3FvtnaO38LUf9shx/YXIvp38s34tJ83dU+PO5FwvLTLtv1m94Z/k+LNpuG/q4cFvpEMj1B85g8ne7ynxGb+mOY+jy6g9ef0rw8OmLWLK9dLlr0k7hH8v3AoCHaxFn+09dwJQle/Dn/210mv7h6v2468P1eHreVny1KQNAaQft5O92Ivm15U7lH/jPBh/WVnknz+aVGeL5l6+24U+zbCeYG6euRM8pPznN33goG0njF2FHZi5+O3AGU79PxbwNR/yWsBZv9//Q1pxLZb9PvtiWkYPh767GeyvSfP7Mgq1Hy9zIdyG/COfzi/Duj/vw2OzNaDdpKXYddX/1kDR+Eab4KZH6y7KdJzBz1X7szzpvuhOTEVMm98b1Avvkxv/+esiwTNrJ85iyZHeZmpG7sfL22vBjszfjo9UHPC5z1tqDmLYiDUnjF+HgqQsoLC5B9oUCrNeSurcx64PfWYWHPyu9Iijvr1Vtz8gFAKfHNwBwtB+764v4ZM1BnHJJBst3n8SF/CIkjV+Ej1YfcNxDUBF5hcXIvlCArUdyMP6rbU4HZ/c3fsT95TyRLN9tq9Wv3JuF22f8ind/SsNfvtqGVftOAQDeWLzb6apKb+R7q5E0fhG2ZeSgpERBKYX8otJ9se/EOTzymfvPAkBmju0q4suNGfjD9LXlilvvnz/uK7OP7FbuzULS+EUY8d5qxw152zNzfVruybN5eGLOZjz0qfPJvcPLy9DhpWWO73BeYQnGzd2M5btOuG3ambHKP/c2Hsu9hC6vfI+0k+f9srz+/7cSs9YehFIKn6w5gPP5RcYf0izefgzr9te823pMmdxbxUcFOgTM23AESeMXub0V/oOf03Hj1JWYsXI/Zqzaj4XbSi/XTp4rTXZFxSVoMWERLhSUJoFXF3qvpf9Nu1u2799/xrNfbEWXV39wNK9sPJSNyd/tdPs5b00Bpy4UIL+oGE/P24JjuZdwNq/Q0da/RPviPvn5FgCAa93GkU911f9tGbn4amOGx/V9rY36+c/ag3jGw41hf1vmvYlr/pZMjPloPbq8+gNGTluDuRuO4Owl5wNyTZrxATf2vynIyL4IwHMz1f6s85j6fSpmrtrvdFWlt1U7+Y14bw0+Wn0AN/1rNdq8uBQbD2UDAC4WOP/9d+iS6tr0U+j15k+YvyUTz36xFRsOZhvGXVKicNeH69yeVF2bB/MKi/G/dYcwa42t4rAtIxdHtZOJr5XVAu1qzP45O3efVwAe+G+Ko2nnfH4RZq8/XLZgJSzadgzZFwvx2XrnilZ+ke3GvIsFRVp8Cm8v3YP9WcYngU2Hc7A67RQmf7fLbXNoQVEJ7vz3Ory8wPkYe+SzTRg107j5dOOhMxjz0XoUVdPQY45zr6B//rgPAPDaot3o0SLGMb2wuARvLS09uOyvY+uEo3tSA6cO0ns+/g2V6S+1t/GJLrFm5pT/0bmHTl/AT7tP4utNmTh1vgCr9mbhjz2bYfLIDk61fQDIuViISwXFWLA1E3/9ajsiQm31g/26G7wOnLrgdL+AayfTxG9tzVoKCsd1z/i575PfHK+nrUjHgHYN0bVpfQC2E6HoNnTc3C3l3k53vt91AqEhQZh2Z1fHNHFpp/LWLFZUXIIsl6uT3cfPYqfWNLFw21F0Towu87nRM9dh++TBtvLHbKOh/rP2YJly2RcKUK9WKIKCnIM6l1/k8eTl+piMt5em4uM1nq8I7UpKFApL3Ccecf2j+GjUzF+x+XAO8os8J7SZq9LR58p4tGnkW6VtW0YODp62fd9cBzh8tu4wPvg5HcEieHZwG2TmXML7P6fj4zUHsOfVoV6Xe/jMReQVlmjrKD35XvniErRvXBe9WsVgbfpprE0/jZdHXOVTrICtKfbX/afw2qLdyMi+hGO5eX6/Z8UdJvcKytTVYG7612rH6683ua+x5lwsxO+nr3V6nk1VPA1RKYWXF+zEje0aonfrWLdlClwOtDVppzGgbUMAwKq9tkcxf7kxA/f3blHms2cuFKDv31fgxFlbQrMfDBVx5Mwl1GsS6ni/ItX5MdC3vr8Wc8f28KlWBACdXvkeAJD2eulBrJTCqwt3O125eBrRoLTrkvSTvt+J/MbiPWUSp75j7ZM1B5F1Lr/M3/xcfhH2njiHKxuWJrRNh3PKLL/Lq7bhkKv/2g+T5u/ET3tO2u5a9lIpWLnX+e94+oK3B97ZFpR7qRCdJn+PmNphXsp6Xa2Dvqlk3X7P/UBvLd2D+KhwvLF4D/6+bC/2avtt7m+HUVBcgnt6JiH7QgFCggVREaXfkxHvrXG8/nLjETSPjcSYnkkA4DiJ2Pt97FcW9u/pybN5gADxURF41KXisvVIjtv+p4KiEmw5kuM0FFkphSNnLmHfyXNlyr+zfC+e6N/acUIeMHVlmebJ6sDk7mdzfjvidrprR6Q/6WsvB09fwC/7TmHW2oNIblYf2RcL8OMzfZ3K93rrJ7h6xaU56EJBMW742wq367Mndn8oLPKeLirSAfeFrknoXz+lGdZaM7Mv4aFPU9C4Xi0AcHQMe1JSonA09xIS6kdi9m9l+19ch696asoZ9I9VeKhPC8d9DN7syMzFT3tsN9DlXix0nIg8ST1+DkmxkQgPCXY73z6iSSlbf8r6A7aKxmndqKuBU1fih6f7ALB1nALAsdw8DPi/n/HjM32x+bBx85EnxSUKH/xcOnS4oLgEr3y3C59vOOxopry2eQwGv7MKYSFB+HxsD7RrXLdMH83ZvCJMnL8T+UUlGNOzmaOd392Vxvr9p3GHVlHYOXmwYwCDnv5jZ/MKUVd3UtFrPmGx2+kA8M7yfejRIsZxRR+IxA4wuftdIG40WrC1tKZ45EzpFUWK1t47VDeuesuRHGSdqzmPLk494T2xbc3wrcNPb8LXpfcQTP1hr2H58u6zEdNWY0fmWcx5sEelrlwAYMbK8ncwdnrle2yeONBrmcHvrMJ1LWPw8b3XeB2ipwD87r3Vbuft09XA9SfZ9KwLOJZ7Cbe8X9rxu6kciT4j+yJ6v1W24uB6Eh78zioAtprzLe+vRbvGdfHRH5PdLvO1RbudhuDak/T/1pWefO/QXQFe9dIyt8vR9yFc/fL3+MuQNt43xoOiYoUdmbmOTvpAYHK3AKOagf7GlpunrfFSknyxI9P29xzt4z0I/vDit86deL+knTL8zNr002g7canXMvarAU/O5RXiXF5RmSYz12Gk5Xmo3s+pFfsVtt3HznrtqNezXxWUd3TONy6P+NBfXZSHiHNzrd71b6/At4/2ctsX408SqLGdycnJKiWl4je11KQ7wYjIGkKCBEV+uCt8+t3dDJtiK/rEVxHZqJRyfwmjY8qhkEREVcEfiR2o2j42XzG5ExFZEJM7EZEFmTa522+eISKiskybIddNGBDoEIiIaizTJvfoSO930hERXc5Mm9yJiMgzn5K7iAwRkVQRSROR8V7K/V5ElIgYjsEkIqKqY5jcRSQYwDQAQwG0BzBaRNq7KRcFYByA8j0ovBKSm9WvrlUREZmKLzX37gDSlFL7lVIFAOYCGOmm3KsA3gJQ/mfOVlDTmKp/bCYRkRn5ktybANA/6jBDm+YgIl0BJCqlvD4TQETGikiKiKRkZVXs+RJ6r93cAeMGtK70coiIrKbSHaoiEgRgKoBnjMoqpWYqpZKVUslxcXGVXTUiw0Lw1MArK70cIiKr8SW5ZwJI1L1P0KbZRQHoAOBnETkIoAeABexUJSIKHF+S+wYArUWkuYiEARgFYIF9plIqVykVq5RKUkolAVgHYIRSquKPfCQiokoxTO5KqSIAjwFYBmA3gHlKqZ0i8oqIjKjqAImIqPx8+rEOpdRiAItdpk3yULZv5cMiIqLK4B2qREQWxORORGRBlkju4SGW2AwiIr+xRFb85L5rAh0CEVGNYonkfl3L2ECHQERUo1giuQNANz5EjIjIwTLJ/auHrwt0CERENYZlkjsREZWyVHI/+ObwQIdARFQjWCq5ExGRjeWS+6P9WgY6BCKigLNccn9ucFs2zxDRZc9yyd2bMT2aBToEIqJqcVkl96sT6lXoc20bRfk5EiKiqmXp5H5ju3jH66jwELRrXLdMmbCQIHz0R+cfjfr3PckY2qGR4/3tyYn4+dm+VRYnEZG/+fQ8dzOyt7v/uPsEmsfWRou4Osi+UOBU5vlhbXF7ciKiI8Ocpg9s3xAD2zfEsdxLGDdnC27p0gT1azuXISKqySxdcweAAe0aokVcHQBA/dph+HVCf/RuZXsWTeuGUWUSu17jerUw7889HYn9sweu9Wmdc8f2qGTU5G/PDW6DTonRgQ6DqNpYPrm7alyvluMRwaKb/uuE/oaf7dXKtweU9WgR4/Q+OEg8lDT27uguPpdNfW2I1/kHpgyrcByVseGFGwOyXr1H+7XCTR0bBzqMMvq1iQt0CBQAlckJvrrskjsATLm1I+7rleSowQO2pP/to70w50H/1LrHDWjteO3LbhzWsZHb6SM6XeHT+oZ2aITwkGCfytrpa7LpbwzDdS2dT0rXtYyB+OE7GBcVXvmF+MHA9g29zk9sUKvKY/jgrq5O7+tEhFbp+v7QLaFKltsitrZP5e66tmmVrD+QerRoUOllPD+snR8i8e6yTO7xdSPw0u+uQkiw8+Z3ToxGT5cEZ6R5bG23CbBz02jH65u7NHG8/p2bZL1l0kCnxLxz8mDMuu8a7H1tKADginoRAIDxQ9s6fU7/JMz37nROGu6IS6D/HtMNANC/bTyCgwR/HeK8/PiocKeDePnTN6BOuG/dNMuf7uN2+vPD2rqd7ouRnb2f6FrF13G8nn53N9zsUj4ptjbWTRjg8fOeYvbm3/c4d8a/9fuO2DF5sMfy+qu6h/q08GkdU27t6FO5P/dpie0vD3Ka1lL3N9EL0dUc3Q00MFK3lm8npddv6YhlT97geTkRtu9Tzxbej7terUrnd2hiHO/gq7yfyCvD29V0dGTp3+WJ/q3QVZcH9P7UK8nPUZV1WSZ3f0h/Yxh++Us/zH+sl9v59kOnc2I03tQdnO4qwtGRYbinp20Mfscm9VA7PAR928QjzN58pCXl4R0bOx0o+idhervM69823nGi0IuLCsfBN4fj43ttP3ZyRbRzzTVIBLO1K5nuzRugVXyU28T16f3dy0xrFV8Hb97aERtftDXJvHlrR0y/uxtGdGpSpiwArP5rP3x4T7JTTXPby4OcTib/HNUFD17fHK/f0gEA0CwmEgemDMPBN4cj7fWhWPbkDYjR+kfqR4aiUb2yNfFG2onSnfCQYMx7qCdae0iIrv41ugsGtm+Ig28Ox61dm6BZTCRu6ZLg9gS4/eVB+HVCf6eO+YTosvFtnjgQr97cwWlaX4Omm54tYjD7gWsxfmhbRLlcCXjaFv1Jacm4670u3+6hG1pg1XP9MPuBax0Vmvt7Nzf8XBvdUOKJN7XHHcmJAIBaocGOk08XD0nQ7nbtM08MaI2FjxvHO2NMMmLrVHwQhP3k19lNP018VAQ6NnEeVv1w35bYMXkwUnRNkEEejslXR15VpqJVFXxK7iIyRERSRSRNRMa7mf+0iOwSkW0i8qOIWPZuoRljumHKrR0RHCRIbBCJuhGh+PCeZKcmHr2oiBCnKwRP+7RL0/pY+HhvzHuoZ5l5+s+08XHM/VVXlNZunh/W1nGisHvvzi5lvmBxUeF4TZ9YBGhYNwKLn7jecQJwt/zrW5cmn/FD26J7ku2ydVT3poipE+54PaRDI4/bn1A/Eje2b4i//aGTY1rdiNAyJ5MXhrfHXdc2w9Inr8f8R3s5tiEkOAjBQYIWcaVXGs1jI92vzA17Z3n35g3ww9N98NsLnmv4dvqrsKm3d8bK5/o5/s4/PdMHmyYOdMyPighFY9eTjQge7tMSjepGYPaD12L98wNQv3YYxvRo5ri34vH+rcp+DkDTBqXbNmdsD1yn+/49oTUJ/vDUDRjQriGWPXkDnhl4pcu6DTevjAnD2qFpTCSuaxWL2mG2E9jwq0v7MVxr1K+MvKrMMu7u0dRxJdsxoZ7HMGqFOjcxxtYJx/43huGpG1uXKWvPofba/QOOE47z0vVNn5snDsTDfZ0fVWIfEr3t5UH47vHeuLZ5A8wd2wMH3xxepnKk//vb11Qn3HasP9G/lTat7Na99Lv2GNMzqewGVwHDa2wRCQYwDcBAABkANojIAqXULl2xzQCSlVIXReRhAG8DuKMqAg60wVeVbRsf0K4hBrRriKTxixzTXBPne3d2wWOzN2PIVY0wf8tRAMCcB3sg7eQ5R5kOTdzfZNXhinrIyL6E8FBb4ph4U3u0aWg7+OuEh2C4m47CRU9cj1vfX4NNh3MQHFSa2JeMux4J9WuVqeHZ3dYtAZPm70CJKv1ytr/C+aBd+HhvJDaIxKB/rMSJs/lO8/7cpyX+3Me35/vc37s5Plp9oMyB7Iu2jbxfmosI6mrb6Msluq+d5XY3Xe29c9Y+Qssbge1vu+75sieSbx7phWU7jzs16QG2prHU4+fxzeYMHD5z0e1yn7qxNUZdk+i4EmvTKAptGkVh+e4T2JqR61h3ZUy9oxNmrz+MLonR2P7yIJzNK0JkaDB2HTuLuz5cD8B5H4UFB6GguASALak3rBuOZwe1waJtR7HpcI6jEgAA3z7aC+0aR6HNi0sB2JpBbP0/pVHvemUw2k9aBgDo3ToOq/ZmOb6vSitjLz7tzq5IbFALVydEI/tCAY5kX0T92mF4dlAbdEqohz//bxMA23fAPoS6buNQfK6raNnvh9l4KBsAUKIUPOnePAZAGq5Jqo+f9550mleRJrCK8qUBtTuANKXUfgAQkbkARgJwJHel1Apd+XUA7vZnkGbUKaEewoKD8Ehf21n8pquvwJCrGjnV4nu2jPGpjX/qHZ1w/9HmiI+yNSnoL4W9te9Ov7sbFm47hua6dnOjL1dEaDA2TxqEez/5DU8MaOW2jP0k9N3jvXHotPsE44n98IyLCsfEm9ojKSYSPavoZxLtB7eX49CjBpFh6No0GpsO53hYdsXT4+juiZjz2xEEeVlGrbDgMokdAFrFR6FVfBROnc/H8t0n3XzSFptrExsAzH+sN8Z8tB6/7DtlGP8n916DGavSsW7/Gbfz46Mi8OSNtquBqIhQR2WhV6tYjB/aFm8u2YMrot03gdUJD8H6523NF50To3Fr1wRcdUVdvLpwl2OanrtBBZFhpalr3IDW2JmZi65No7E67ZRjeph2rHVrVt/RHFe/dpijaSw4SDCkQ2OEhwQhv6jE698DKK3EAaXJ/aor6mLn0bNOV6S9W8dix+TBtua571OdllGR72JF+dIs0wTAEd37DG2aJ/cDWOJuhoiMFZEUEUnJysryPUqTaKxrz42ODMPe14c6JW/XDlxfRYaF4Jok33roFzzWC88Osh108XUj8Ccf2kRd1asVim8e6YVmMd5HRMRHRfgclydjeiY5dYRW1uu3dES/NnHolFgPcKnJuZp6eycPc2z76utHejk6yP45qrPT/CFuruA8ca3lx2kn6fqRvo+UWTO+P1JeLG3PtffRRJdjGYCtySYqPASdE6LdzrePbOrXNr5c26j30A0tsGXSQCTUL226uFMbNRMS5HwMhIUEoVNiNEKCg/DNI9dhga4Pq3/bePiiW7P62DhxIO7v3QI3XBmHh26wdVR/en93PN6/FRrW9f9oLXuS9nSOtPe72K9K7UNeE+pX/YgsO7/eoSoidwNIBuB22IFSaiaAmQCQnJxcjeew6rF03A3IuVRgXLAKXZ0Qjas9HLg1SXlqMF893BPHc/ONCwK4smEUPrnP1sFrr625Phtoy6SByLlYiKTY2rjqino4cTbP4/IWPt4bOzJzMaRDY/yw6wQWbjuGd+7o7NTW7I27J5Q+3r8VWsbVxpAOvifPJi41cRHBrxP6O9VgfXFNUgNs93K1980j1yHlYLZjHRUhImVuDpx0U3s8P6yd147/Lk2dfwf5g7u74lxekcfy9/VKQr82pSeAepGh+O+fSjv3W8TVwTOD2vgcd3m+k46mH4MGrnfu6IzP1h/GuAGtcb6gyNFUWB18+WZkAkjUvU/QpjkRkRsBvACgj1LKtyPRYupFhqKeDzWpni1ikFdUXA0R1UAVyBfdmlXs6qBzYjS+evg6dHJ5YFx0ZJgj+djboz1JqB/pqIHaL/NDK3gFZhcaHISRnb1d/PrGXUdreTzaryVSj59zmqbfXn8KChKElfPGnfCQYITX8dwf89LvynbYVkRFzmH2E4HRJsXXjcBTWmd2dSZ2wLfkvgFAaxFpDltSHwXgTn0BEekCYAaAIUop9w2B5DDHgo8n+NfoLmVGEHhXPRdu+nsBKmviTe0RUyesSsdQV6fnBpfec9CvTRyO5jhfwVTHXZQ1QdMGkdh74nw5k7xBu0wNYJjclVJFIvIYgGUAggF8rJTaKSKvAEhRSi0A8DcAdQB8oV3KHVZKjajCuKmGcXdzljtGl7F2T7sO3asB6tcOwwvD2wc6jCphb8rSu61bAvadOIfCEoWsc9a9GP/fA9di06FsRJRj1Naoa5pi+e6TeHbQlRjz0W8Y3tG37391ElWd3bc6ycnJKiUlJSDrpsDJOpePa15fjpjaYdioGwdORL4RkY1KqWSjcrxDlapVDb6KJbIUJneqVvax3eW5BCai8rPsj3VQzdSgdhieG9zG7V21ROQ/TO5U7R7t5/7OVyLyHzbLEBFZEJM7EZEFMbkTEVkQkzsRkQUxuRMRWRCTOxGRBTG5ExFZEJM7EZEFBezBYSKSBeBQBT8eC+CUYSlr4TZfHrjNl4fKbHMzpVScUaGAJffKEJEUX56KZiXc5ssDt/nyUB3bzGYZIiILYnInIrIgsyb3mYEOIAC4zZcHbvPlocq32ZRt7kRE5J1Za+5EROQFkzsRkQWZLrmLyBARSRWRNBEZH+h4ykNEEkVkhYjsEpGdIjJOm95ARH4QkX3a//W16SIi72rbuk1EuuqW9Uet/D4R+aNuejcR2a595l2RmvGrpSISLCKbRWSh9r65iKzX4vxcRMK06eHa+zRtfpJuGRO06akiMlg3vcZ9J0QkWkS+FJE9IrJbRHpafT+LyFPa93qHiMwRkQir7WcR+VhETorIDt20Kt+vntbhlVLKNP8ABANIB9ACQBiArQDaBzqucsTfGEBX7XUUgL0A2gN4G8B4bfp4AG9pr4cBWAJAAPQAsF6b3gDAfu3/+trr+tq837Syon12aKC3W4vraQCzASzU3s8DMEp7PR3Aw9rrRwBM116PAvC59rq9tr/DATTXvgfBNfU7AeA/AB7QXocBiLbyfgbQBMABALV0+/deq+1nADcA6Apgh25ale9XT+vwGmugD4Jy/mF7Alimez8BwIRAx1WJ7ZkPYCCAVACNtWmNAaRqr2cAGK0rn6rNHw1ghm76DG1aYwB7dNOdygVwOxMA/AigP4CF2hf3FIAQ1/0KYBmAntrrEK2cuO5re7ma+J0AUE9LdOIy3bL7GbbkfkRLWCHafh5sxf0MIAnOyb3K96undXj7Z7ZmGfsXyC5Dm2Y62mVoFwDrATRUSh3TZh0H0FB77Wl7vU3PcDM90N4B8BcAJdr7GAA5Sqki7b0+Tse2afNztfLl/VsEUnMAWQA+0ZqiPhSR2rDwflZKZQL4O4DDAI7Btt82wtr72a469qundXhktuRuCSJSB8BXAJ5USp3Vz1O2U7NlxqeKyE0ATiqlNgY6lmoUAtul+wdKqS4ALsB2Ke1gwf1cH8BI2E5sVwCoDWBIQIMKgOrYr76uw2zJPRNAou59gjbNNEQkFLbE/plS6mtt8gkRaazNbwzgpDbd0/Z6m57gZnog9QIwQkQOApgLW9PMPwFEi0iIVkYfp2PbtPn1AJxG+f8WgZQBIEMptV57/yVsyd7K+/lGAAeUUllKqUIAX8O27628n+2qY796WodHZkvuGwC01nrgw2DriFkQ4Jh8pvV8fwRgt1Jqqm7WAgD2HvM/wtYWb59+j9br3gNArnZptgzAIBGpr9WYBsHWHnkMwFkR6aGt6x7dsgJCKTVBKZWglEqCbX/9pJS6C8AKALdpxVy32f63uE0rr7Tpo7RRFs0BtIat86nGfSeUUscBHBGRNtqkAQB2wcL7GbbmmB4iEqnFZN9my+5nnerYr57W4VkgO2Eq2JkxDLZRJukAXgh0POWMvTdsl1PbAGzR/g2Dra3xRwD7ACwH0EArLwCmadu6HUCybll/ApCm/btPNz0ZwA7tM+/BpVMvwNvfF6WjZVrAdtCmAfgCQLg2PUJ7n6bNb6H7/AvadqVCNzqkJn4nAHQGkKLt629hGxVh6f0MYDKAPVpcn8I24sVS+xnAHNj6FAphu0K7vzr2q6d1ePvHxw8QEVmQ2ZpliIjIB0zuREQWxORORGRBTO5ERBbE5E5EZEFM7kREFsTkTkRkQf8PkXT9MEO3dNkAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x288 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.plot(stepi, lossi)\n",
"# there is flucturation because of mini-batch\n",
"# increase batch size to be more stable"
]
},
{
"cell_type": "code",
"execution_count": 98,
"id": "6218b79e",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train loss 2.2576844692230225\n",
"dev loss 2.2767961025238037\n"
]
}
],
"source": [
"# compute loss on dev dataset\n",
"emb = C[Xtr] # (32,3,2)\n",
"h = torch.tanh(emb.view(-1,n_input1) @ W1 + b1)\n",
"logits = h @ W2 + b2\n",
"loss = F.cross_entropy(logits, Ytr)\n",
"print ('train loss', loss.item())\n",
"\n",
"# can have even smaller loss with smaller learning rate\n",
"emb = C[Xdev] # (32,3,2)\n",
"h = torch.tanh(emb.view(-1,n_input1) @ W1 + b1)\n",
"logits = h @ W2 + b2\n",
"loss = F.cross_entropy(logits, Ydev)\n",
"print ('dev loss', loss.item())"
]
},
{
"cell_type": "markdown",
"id": "b496f3e3",
"metadata": {},
"source": [
"# 10. generate new words with the trained model"
]
},
{
"cell_type": "code",
"execution_count": 99,
"id": "4e888e6a",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"careah.\n",
"ambyn.\n",
"vih.\n",
"jors.\n",
"revty.\n",
"salaysy.\n",
"jazhnen.\n",
"amerync.\n",
"kaqii.\n",
"nelania.\n",
"chaiiv.\n",
"kaleig.\n",
"halm.\n",
"joce.\n",
"quinn.\n",
"shon.\n",
"raiviani.\n",
"wanthon.\n",
"jaryni.\n",
"jace.\n"
]
}
],
"source": [
"g = torch.Generator().manual_seed(2147483647+10)\n",
"for _ in range(20):\n",
" \n",
" out = []\n",
" context = [0] * block_size # we start from ...\n",
" while True:\n",
" # embed current context using C\n",
" # 1 dim training set size \n",
" emb = C[torch.tensor([context])] # (1, block_size, d), d is the embedding vector C \n",
" h = torch.tanh(emb.view(1, -1) @ W1 + b1)\n",
" logits = h @ W2 + b2\n",
" probs = F.softmax(logits, dim=1)\n",
" # from the prob, sample by probability to get the next index \n",
" ix = torch.multinomial(probs, num_samples=1, generator=g).item()\n",
" # shift to the next char\n",
" context = context[1:] + [ix]\n",
" out.append(ix)\n",
" if ix == 0:\n",
" break # stop \n",
" \n",
" print (''.join(itos[i] for i in out))"
]
},
{
"cell_type": "markdown",
"id": "3897ee18",
"metadata": {},
"source": [
"# 11. hyperparameter tuning"
]
},
{
"cell_type": "code",
"execution_count": 122,
"id": "805a4253",
"metadata": {},
"outputs": [],
"source": [
"Cdims = [2, 30, 50]\n",
"n_output1s = [100, 200, 300]"
]
},
{
"cell_type": "code",
"execution_count": 123,
"id": "fa12acef",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 6min 14s, sys: 1min 35s, total: 7min 49s\n",
"Wall time: 7min 9s\n"
]
}
],
"source": [
"%%time\n",
"loss_list = []\n",
"\n",
"for Cdim in Cdims:\n",
" for n_output1 in n_output1s:\n",
" Vdim=27\n",
" block_size = 3 # hyperparamter\n",
" n_input1 = Cdim * block_size\n",
" batch_size = 32\n",
" n_input2 = n_output1\n",
" n_output2 = Vdim\n",
"\n",
" g = torch.Generator().manual_seed(2147483647)\n",
" C = torch.randn((Vdim, Cdim), generator=g)\n",
" W1 = torch.randn((n_input1,n_output1), generator=g)\n",
" b1 = torch.randn(n_output1, generator=g)\n",
" W2 = torch.randn((n_input2, n_output2), generator=g)\n",
" b2 = torch.randn(n_output2, generator=g)\n",
" parameters = [C, W1, b1, W2, b2]\n",
"\n",
" for p in parameters:\n",
" p.requires_grad=True\n",
" \n",
" for i in range(100000):\n",
" # mini-batch\n",
" ix = torch.randint(0, Xtr.shape[0], (32,))\n",
"\n",
" # forward pass\n",
" emb = C[Xtr[ix]] \n",
" h = torch.tanh(emb.view(-1,n_input1) @ W1 + b1)\n",
" logits = h @ W2 + b2\n",
" loss = F.cross_entropy(logits, Ytr[ix])\n",
" # backward pass\n",
"\n",
" # initalize all gradients\n",
" for p in parameters:\n",
" p.grad = None\n",
"\n",
" loss.backward()\n",
" \n",
" lr = 0.1 if i < 10000 else 0.01\n",
" # update the parameters\n",
" for p in parameters:\n",
" p.data += -lr * p.grad\n",
" \n",
" # compute loss on dev dataset\n",
" emb = C[Xtr]\n",
" h = torch.tanh(emb.view(-1,n_input1) @ W1 + b1)\n",
" logits = h @ W2 + b2\n",
" loss_train = F.cross_entropy(logits, Ytr)\n",
"\n",
" # can have even smaller loss with smaller learning rate\n",
" emb = C[Xdev]\n",
" h = torch.tanh(emb.view(-1,n_input1) @ W1 + b1)\n",
" logits = h @ W2 + b2\n",
" loss_dev = F.cross_entropy(logits, Ydev)\n",
" \n",
" loss_list.append([Cdim, n_output1, loss_train.item(), loss_dev.item()])"
]
},
{
"cell_type": "code",
"execution_count": 124,
"id": "9712fb33",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd"
]
},
{
"cell_type": "code",
"execution_count": 125,
"id": "030c17b5",
"metadata": {},
"outputs": [],
"source": [
"df = pd.DataFrame(loss_list, columns = ['Cdim', 'hidden_layer_size','train loss', 'dev loss'])"
]
},
{
"cell_type": "code",
"execution_count": 126,
"id": "2bb18918",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>Cdim</th>\n",
" <th>hidden_layer_size</th>\n",
" <th>train loss</th>\n",
" <th>dev loss</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>50</td>\n",
" <td>100</td>\n",
" <td>2.215908</td>\n",
" <td>2.255365</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>30</td>\n",
" <td>100</td>\n",
" <td>2.226047</td>\n",
" <td>2.261118</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>50</td>\n",
" <td>200</td>\n",
" <td>2.183852</td>\n",
" <td>2.274030</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>30</td>\n",
" <td>200</td>\n",
" <td>2.195534</td>\n",
" <td>2.276445</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>50</td>\n",
" <td>300</td>\n",
" <td>2.179892</td>\n",
" <td>2.299338</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>30</td>\n",
" <td>300</td>\n",
" <td>2.198899</td>\n",
" <td>2.305203</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>2</td>\n",
" <td>300</td>\n",
" <td>2.326979</td>\n",
" <td>2.324543</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>2</td>\n",
" <td>200</td>\n",
" <td>2.333926</td>\n",
" <td>2.331151</td>\n",
" </tr>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>2</td>\n",
" <td>100</td>\n",
" <td>2.368216</td>\n",
" <td>2.365333</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" Cdim hidden_layer_size train loss dev loss\n",
"6 50 100 2.215908 2.255365\n",
"3 30 100 2.226047 2.261118\n",
"7 50 200 2.183852 2.274030\n",
"4 30 200 2.195534 2.276445\n",
"8 50 300 2.179892 2.299338\n",
"5 30 300 2.198899 2.305203\n",
"2 2 300 2.326979 2.324543\n",
"1 2 200 2.333926 2.331151\n",
"0 2 100 2.368216 2.365333"
]
},
"execution_count": 126,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dfs = df.sort_values('dev loss')\n",
"dfs"
]
},
{
"cell_type": "code",
"execution_count": 154,
"id": "323b9fe5",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAZIAAAFzCAYAAAAUmo/dAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAA71UlEQVR4nO3de3iUd53//+c7mUnIoYQcbIohgNiiPSGHSMH225Zdqrvq91tQe9h1W1vrhbq625N7ratbtV+367oq1r12V1pttR6+Uqptf1V3raAFt1oKoVBooaUthYaDLYQQyIFkDu/fHzPAEGYmk9yZZEJej+vKxeRz3/fMOzeTeedzfz73+2PujoiIyGAVjXQAIiIyuimRiIhIIEokIiISiBKJiIgEokQiIiKBKJGIiEggoZEOYCjV1dX51KlT027r7OykoqJieAMagEKPDwo/xkKPDwo/RsUXXL5j7OyN0tYZoas3Sm80jpkd3+bulISKKC8JUV0RpqLk1I/4vvFt2LDhgLu/KUhMp1UimTp1Ks3NzWm3rV69mssvv3x4AxqAQo8PCj/GQo8PCj9GxRdcPmJ0dx7ZuIdvrtxOV2cvZZEY47LcAmgGFi6mrKKEW6+YzuJZDccTTt/4zGxX0PhOq0QiInK62dfezS3LN7FlTztdvbGcjnGHrt4YXb3dfP6R53hwfQt3XzuTiVVleYlRYyQiIgVq/c6DLFy6hg272nJOIn11R2Js2NXGwqVraN55cIgjTFAiEREpQOt3HuT6+9bR2RMjGg9Wyioadzp7Ylx337pBJ6RsdGlLRKTA7Gvv5obvraM7kv1Df2bjBO54/3nE3dm8+xBf/sW2rPt3R2K8eqCXfe3dQ3qZSz0SEZEC4u7c/JNN9ETi/e67p62bv/zOWq5a9hS1FaW8rf6MHJ4fblm+iaEs2KtEIiJSQB7ZuIfn9rbndDlrf0cPPdFEwonG48RySA6Os3l3O49u3BM41mOUSERECoS7J6b4DnAc4+1nnUFNRSkvv9GR0/7dkRhLV20fsl6JxkhERApE8642Wjt7B3RMVVmYO688n0//eOOAjmvt6GXDrrYBHZOJeiQiIgXioeaWfgfYUxUXGXdfM5N//uU29nf0DOi1uiMxVjS3DDTEtNQjEREpEM072xjI1ab3XTiRGZOq+If3ngvAv/7qBZ557VBOx7onekBDQYlERKQARGJxWtq6BnTMY8/u5bFn9w76NVsOdgFYf/v1R5e2REQKwIGOHkJFw/uRHCoyKA4F7lAokYiIFIBI1LHAfYOBMTNsCF5ViUREpACEQzag8ZGh4O74ELyqxkhERApAXWUp0Xj2u9nPPKOU+294J+ecWcl5X3ycWNy54/3ncmHDBJ7f286dP98KkLYtnWjcIRaNBo1dPRIRkQIQLi6isbo86z7t3RH+8rtr2dhyCIDz3zye8pIQV9/zFOHiImZMqkrblkljTTlA4B6JEomISIFomlqddcSiJxrncPeJDsSsydU8+dIBAJ58+QCzJ1enbUvHDJqmpN82UEokIiIF4qqmRsrCxTnvP35ciI6eRGI5cjTC+LJQ2rZ0ysLFXN3UGDxolEhERApG05RqaitKct7/yNEolaWJRFFZGuZwdzRtWzp1laXMUY9EROT0YmbcesV0ykty65U881obF59dC8AlZ9ex8bW2tG19lYWLuXXhOcfXcQ9KiUREpIAsntXAhQ1ViZsF+wgVGT+66SLOnTieH3x0LuHiInqicVZ8fD4xd57d3c7zew+f0pbKMGZMqmLRrIYhi1nTf0VECoiZcfe1M1m4dA3RnpMLOEbjzl/d9/RJbZuSM7hSZZvyawbfunbWkPVGQD0SEZGCM7GqjAdunDuggfdclIWLeUtdBWdVjRvS581bIjGzRjN7wsy2mtnzZnZzmn2uNLPNZrbJzJrN7JKUbZPN7Ndmti35HFPzFauISKFpmlrDD2+aS0VpcdrLXAMRKjIqSov54U1zcx5/GYh89kiiwO3ufh4wD/iUmZ3XZ5/fAO9w95nAR4Hvpmz7AfA1dz8XmAu8kcdYRUQKTtPUGlbddhlzplQPundSXlLMnCnVrLrtMpqm1gxxhAl5GyNx933AvuTjI2a2DWgAtqbsk7ouZAXJOyyTCSfk7ivT7CciMmZMrCpj+ZJ5PLpxD0tXbae1o5fuSCxrhSyzxGWs2soSbls4nUWzGoZ0TKSvYRlsT16WmgU8nWbbYuArwJnA+5LN04FDZvYw8BZgFfBZdx/YQsYiIqcBM2Px7EksmtXAhl1trGhuoXlXGy0HuwgVGWaGuxONO4015TRNqebqpkbmTKnOawI5Ht9QLf6e8QXMKoE1wF3u/nCW/S4FvuDuC83sQ8B9JJLPa8CDwH+5+31pjlsCLAGor6+fs3z58rTP39HRQWVlZdAfJ28KPT4o/BgLPT4o/BgVX3DDGaMD0ZjjOIYRKrZ+V6nqG9+CBQs2uHtTsEDc8/YFhIHHgdty3H8HUEdiTGVNSvt1wH/0d/ycOXM8kyeeeCLjtkJQ6PG5F36MhR6fe+HHqPiCK/QY+8YHNHvAz/p8ztoyEr2Kbe6+NMM+Zyf3w8xmA6VAK7AemGBmb0ru+iekjK2IiEjhyOcYycUkehJbzGxTsu1zwGQAd18GfBC43swiQDdwTTJDxszsM8BvkolmA/CdPMYqIiKDlM9ZW0/Sz6Ly7v5V4KsZtq0EZuQhNBERGUK6s11ERAJRIhERkUCUSEREJBAlEhERCUSJREREAlEiERGRQJRIREQkECUSEREJRIlEREQCUSIREZFAlEhERCQQJRIREQlEiURERAJRIhERkUCUSEREJBAlEhERCUSJREREAlEiERGRQJRIREQkECUSEREJRIlEREQCUSIREZFAlEhERCQQJRIREQlEiURERAJRIhERkUCUSEREJBAlEhERCUSJREREAlEiERGRQJRIREQkECUSEREJRIlEREQCUSIREZFAlEhERCQQJRIREQlEiURERAJRIhERkUCUSEREJBAlEhERCUSJREREAlEiERGRQJRIREQkECUSEREJRIlEREQCyVsiMbNGM3vCzLaa2fNmdnOafa40s81mtsnMms3skj7bx5vZbjP793zFKSIiwYTy+NxR4HZ3f8bMzgA2mNlKd9+ass9vgMfc3c1sBrACeHvK9i8Dv8tjjCIiElDeeiTuvs/dn0k+PgJsAxr67NPh7p78tgI49hgzmwPUA7/OV4wiIhLcsIyRmNlUYBbwdJpti83sBeCXwEeTbUXAN4DPDEd8IiIyeHaiQ5CnFzCrBNYAd7n7w1n2uxT4grsvNLNPA+Xu/q9mdgPQ5O6fznDcEmAJQH19/Zzly5enff6Ojg4qKyuD/TB5VOjxQeHHWOjxQeHHqPiCK/QY+8a3YMGCDe7eFOhJ3T1vX0AYeBy4Lcf9dwB1wI+B14CdwAHgMPAv/R0/Z84cz+SJJ57IuK0QFHp87oUfY6HH5174MSq+4Ao9xr7xAc0e8LM+b4PtZmbAfcA2d1+aYZ+zgVfc3c1sNlAKtLr7h1P2uYFEj+Sz+YpVREQGL5+zti4GrgO2mNmmZNvngMkA7r4M+CBwvZlFgG7gmmSGFBGRUSJvicTdnwSsn32+Cny1n32+D3x/yAITEZEhpTvbRUQkECUSEREJRIlEREQCUSIREZFA8jlrS0RkxEVicQ509BCJOuGQUVdZSrhYf0MPJSUSETmtuDvNu9p4qLmF5p1ttLR1ESoqwgzcIRqP01hdTtPUaq5qaqRpSjWJ295ksJRIROS04O48snEP31y5ndbOXrojMY7dlRaJxU7ad8eBTl5t7eQXm/dRW1HCrVdMZ/GsBiWUQVIiEZFRb197N7cs38SWPe109cb6P4BE76SrN0ZXbzeff+Q5Hlzfwt3XzmRiVVmeoz396EKhiIxq63ceZOHSNWzY1ZZzEumrOxJjw642Fi5dQ/POg0Mc4elPiURERq31Ow9y/X3r6OyJEY2nr640s3ECP/vku3joE/O54/3nZnyuaNzp7Ilx3X3rlEwGSIlEREalfe3d3PC9dXRHsvdC9rR185ffWctVy56itqKUt9WfkXX/7kiMj3xvHfvau4cy3NOaEomIjDruzs0/2URPJN7vvvs7euiJJvaLxuPEcqgL2xOJc8vyTaiGbG6USERk1Hlk4x6e29ue8XJWOm8/6wxqKkp5+Y2OfveNxp3Nu9t5dOOeIGGOGUokIjKquDvfXLl9QAPrVWVh7rzyfP7+p5tzPqY7EmPpqu3qleRAiURERpXmXW20dvbmvH9xkXH3NTP5519uY39Hz4Beq7Wjlw272gYa4pijRCIio8pDzS39DrCnet+FE5kxqYp/eO+5LF8yj9mTJ+R8bHckxormlkFEObbohkQRGVWad7YxkKtNjz27l8ee3Tuo13JP9IAkO/VIRGTUcKClrWtYX7PlYBeRWP+zw8YyJRIRGTWiMSdUNLwfW6Ei48AAx1bGGiUSERk1HGe46yqaGZGoZm5lozESERk1DMs6PjKzcQJ3vP884u5s3n2IL/9iG0suncYV59Wzp62bzzz0LNG4p23LxD2xjolkph6JiIwaoWIjGs88XtG3HMpFb6lh/rRarlr2FC/88TDvPr+e2oqSU9qyicadusrSof5RTitKJCJS8Nyd9TsPsretOzHinkHfcijn1J/B2h2tADz58gFmT67mwklVp7Rl01hTrhUV+6FLWyJSsPouVvXJt/USiff/sXWsHMrh7sjxO9OPHI0yvizM+HFhOnqiJ7VlYgZNU7InGlEiEZECNZjFquBEOZRP/3gjFzRUMbFqHACVpSEOd0c4cjR6SlsmZeFirm5qDPaDjAHqr4lIwRnsYlV9y6Fs3n2Ii6bVAnDJ2XVsfO1Q2rZM6ipLmaMeSb/UIxGRgnJssaqBlEE5JrUcCsC//uoF1r3aykOfmM/eQ93c//tXicT8lLZ0ysLF3LrwHK3jngMlEhEpGLkuVpVumi+kL4fyzGuHWLZmx0lty9bsOKUtVajImDGpikWzGgb5k4wturQlIgVhIItVDXTVw4EqDRfxrWtnqTeSI/VIRKQgDGSxqtRy8LmuepirsnAxD9w4l7OSA/LSPyUSERlxg1msCga26mF/QkVGabiIB26cS9PUmsDPN5YokYjIiBvoYlVw8jTfoMpLirmwoYq7r53JxKqywM831iiRiMiIG+hiVUFWPTzGLHEZq7ayhNsWTmfRrAaNiQySEomIjLiBLlaVbprvM1nuB+krXGwsntXA1U2NzJlSrQQSkBKJiIyoSCw+4MWqgqx6eMxdiy9UDa0horMoIiPqQEePFqsa5ZRIRGRERaJarGq0UyIRkREVDmVfrCoftFjV0NIYiYiMqLrK0qyLVU2vr+QrH5hBLO5Uxjv4xpYt3PH+c7mwYQLP723nzp9vBUjblokWqxpa6pGIyIgKFxfRWF2ecfuO/Z188Nt/4Op7ngISdbbKS0Jcfc9ThIuLmDGpivPfPP6Utmy0WNXQ0pkUkRHXNLU64zhJasmUWDzGu95ay5MvHQBOrHA4a3L1KW2ZaLGqoadEIiIj7qqmRsrCxRm3Lzz3TB6/5VLKx40jXFyUssJhhPFlIcaPC53SlokWqxp6SiQiMuKaplRTW1GScfuqbW/wnrt/R0dXN9G4U1maSBSVpWEOd0c5cjR6SlsmWqxq6CmRiMiIMzNuvWI65SWn9kpKUsYyeiNR3J2Lz05d4bCNZ15rO6UtHS1WlR9KJCJSEBbPauDChipCRSd/yF/2tjfx4JJ5PLhkHuXjSvn2mlfoicZZ8fH5xNx5dnc7z+89fEpbX1qsKn/yNv3XzBqBHwD1gAP3uvu3+uxzJfBlIA5EgVvc/Ukzmwl8GxgPxIC73P3BfMUqIiPPzLj72pksXLqGaM+JAo4rt77Oyq2vA3D7hVHcQ2mn9/Y35VeLVeVPPnskUeB2dz8PmAd8yszO67PPb4B3uPtM4KPAd5PtXcD17n4+8GfA3WY2IY+xikgBmFhVxgM3zs068D4YWqwqv/KWSNx9n7s/k3x8BNgGNPTZp8P9+D2tFSR6Lrj7dnd/Kfl4L/AG8KZ8xSoihaNpag0/vGkuFaXFp1zmGqhQkVFRWswPb9JiVfk0LGMkZjYVmAU8nWbbYjN7AfgliV5J3+1zgRLglTyHKSIFomlqDatuu4w5U6oH3TspLylmzpRqVt12mZJInpnnuciNmVUCa0iMczycZb9LgS+4+8KUtonAauAj7r42w3FLgCUA9fX1c5YvX572+Ts6OqisrBzsj5F3hR4fFH6MhR4fFH6MhRjfoa4Irx85SjTmvGmc83p39v2LzAgVG/VnjGNCeXh4gkxRiOcwVd/4FixYsMHdm4I8Z14TiZmFgV8Aj7v70hz23wHMdfcDZjaeRBL5Z3f/aS6v19TU5M3NzWm3rV69mssvvzzX0IddoccHhR9joccHhR9jocbn7mzY1caLG5/mvh1ltBzsIlRkmBnuTjTuNNaU0zSlesQXqyrUc3hM3/jMLHAiyeesLQPuA7ZlSiJmdjbwiru7mc0GSoFWMysBHgF+kGsSEZHTl5nRNLWGjp1l/Pb2y4nE4hzo6CESTVTxrassVe2sEZTP6r8XA9cBW8xsU7Ltc8BkAHdfBnwQuN7MIkA3cE0yqVwNXArUmtkNyWNvcPdNiMiYFy4uYmJV2UiHIUl5SyTu/iSQtW/p7l8Fvpqm/UfAj/IUmoiIDCH1BUVEJBAlEhERCaTfRGJmFWZWlHw83cz+T3I2loiISE49kt8B48ysAfg1iQH07+czKBERGT1ySSTm7l3AB4D/dPergPPzG5aIiIwWOSUSM5sPfJhEGROAoa2oJiIio1YuieQW4B+AR9z9eTObBjyR16hERGTU6Pc+EndfQ6JWFslB9wPu/rf5DkxEREaHXGZt/T8zG29mFcBzwFYz+7v8hyYiIqNBLpe2znP3w8Ai4L+Bt5CYuSUiIpJTIgkn7xtZBDzm7hGSC1CJiIjkkkjuAXaSWMHwd2Y2BTicz6BERGT0yGWw/d+Af0tp2mVmC/IXkoiIjCa5DLZXmdlSM2tOfn2DRO9EREQkp0tb9wNHgKuTX4eB7+UzKBERGT1yWY/kre7+wZTv70xZqEpERMa4XHok3WZ2ybFvzOxiEqsZioiI5NQj+STwgJlVkVjx8CBwQz6DEhGR0SOXWVubgHeY2fjk95r6KyIix2VMJGZ2W4Z2ANx9aZ5iEhGRUSRbj+SMYYtijIrE4hzo6CESdSIxJxKLEy7W6sciMrpkTCTufudwBjIWuDvNu9p4qLmF5p1ttLR1ESoqwgw+9fZe/voLv6KxupymqdVc1dRI05Tq4z1AEZFClctguwTk7jyycQ/fXLmd1s5euiMxPFmtLBKLARD3RK9kx4FOXm3t5Beb91FbUcKtV0xn8awGJRQRKVhKJHm2r72bW5ZvYsuedrp6Yzkd4w5dvTG6erv5/CPP8eD6Fu6+diYTq8ryHK2IyMDlUiJFy+oO0vqdB1m4dA0bdrXlnET66o7E2LCrjYVL19C88+AQRygiElwuI7uvmtm9ZvanpusrOVu/8yDX37eOzp4Y0fiJqvvT6yv52SffxYqPz+drH5oBwB3vP5cPXnEJX/zf56V9rmjc6eyJcd1965RMRKTg5JJI3g6sAj5FIqn8e+qd7nKqfe3d3PC9dXRHTu2F7NjfyQe//QeuvucpAGY2TqC8JMTPVj5JuLiIGZOqMj5vdyTGR763jn3tKiwgIoWj30Ti7l3uvsLdPwDMAsaTXMNdTuXu3PyTTfRE4mm3p/ZOeqNx3vXWWp586QAAT758gNmTq7M+f08kzi3LN+GutcVEpDDkdNOCmV1mZv8JbADGkagCLGk8snEPz+1tPylh9LXw3DN5/JZLqTujlHBxER09UQCOHI0wviz7/Ido3Nm8u51HN+4Z0rhFRAYrl8H2ncAtwP8AF7r71e7+szzHNSq5O99cub3fgfVV297gPXf/jn3tR4nGncrSRPKoLA1zuDva7+t0R2IsXbVdvRIRKQi59EhmuPtid/+Ju3fmPaJRrHlXG62dvVn3KUm5c73jaBR35+KzawG45Ow6Nr7WltNrtXb0smFXbvuKiORTLonkLDP7jZk9B2BmM8zsH/Mc16j0UHNL2gH2VJe97U08uGQeDy6ZR90ZJXx7zSv0RON88IpLiLnz7O72nF6rOxJjRXPLUIQtIhJILjckfgf4O+AeAHffbGb/D/infAY2GjXvbKO/q00rt77Oyq2vn9R258+30nFhlG9syf3+UPdED0hEZKTl0iMpd/d1fdr6v5A/xkRicVrauob1NVsOdhGJpZ8dJiIyXHJJJAfM7K2AA5jZh4B9eY1qFDrQ0UOoaHgr94aKjAMdPcP6miIifeVyLeVTwL3A281sD/Aq8Fd5jWoYpZZyD4eMusrSQZVyj0Sd4b7v38yIRDVzS0RGVi4rJO4AFppZBVDk7kfyH1b+ZCvl7g7ReHxQpdzDIcs6PjK9vpKvfGAGsbizq7WTv/vpZu54/7lc2DCB5/e207FzM8BJbXf+fGu/P0s4pKo1IjKyxtQKiQ8/sztrKfdjBlPKva6ylGg883jFsdIoAF/70IzjpVGuvucp/mnRBZxZM4Hz31x0UtuMSVVszjKLKxp36ipLB3AGRESGXrZrOGckv5qATwINya9PALPzH9rQ2dfezY4Dnfzjo8/R0tZNV2+s39lVx0q5t7QlSrlfe+/arDWuwsVFNFaXZ9zeX2mUs95Uw6zJ1QMql9JYU64VFUVkxGX8FHL3O5OrJE4CZrv77e5+OzAHmDxcAQZ1rJR7V08s76Xcm6ZWZx0nyVYapTQcYvy4UM7lUsygaUr2RCMiMhxy+XO2Hki9Xbs32VbwUku5O8EGpXMp5X5VUyNl4czLt2QrjdITiXLkaDTncill4WKubmoM8BOJiAyNXBLJD4B1ZvYlM/sS8DTw/XwGNRSylXKfVF3G+s8vZPmSefzgo3MBWHLpNB76xHzuvmYmoaLM3YpspdybplRTW1GS9rj+SqO8fuAgz7zWlnO5lLrKUuaoRyIiBSCXMvJ3ATcCbcmvG939K/kOLIj+SrlDYgzi2nvXcv3966itKGH+tFquWvYUL/zxMO8+P3uHK1MpdzPj1iumU15yaq8kU2mUFR+fT8yd11sP8fzewye1ZSqXUhYu5taF52gddxEpCDnV5HD3Z4Bn8hzLkMmllPv8abWs+Ph8Hn/+j7yyv4O1O1qBRIJZNLOB/9ryx4zHppZyXzx70knbFs9q4MH1LWzY1XbS62cqjXLM7Ree2pZOqMiYMamKRbMasu4nIjJcTrspP7mUcn/jcA8Lvr6av/jOWi4+u44Zk6pSBrmjjC8L9/s6mUq5mxl3XzuT0nB+Tm1puIhvXTtLvRERKRh5SyRm1mhmT5jZVjN73sxuTrPPlWa22cw2mVlz6hK+ZvYRM3sp+fWRXF83l1LuvbE43ZEYsbjz222vs6u1K2WQO8Th7khOr5WplPvEqjIeuHFu1oH3wSgLF/PAjXM5q2rckD6viEgQ+eyRRIHb3f08YB7wKTM7r88+vwHe4e4zgY8C3wUwsxrgi8BFwFzgi2aW08hyLqXcK1LGMOZMrWFXaxcXTUsd5D6Uy0tlLeXeNLWGH940l4rS4qyD97kIFRkVpcX88Ka5NE2tCfRcIiJDLW+JxN33JcdWSJZV2UbihsbUfTr8xLWhCjg+R/c9wEp3P+jubcBK4M9yed1cSrm/8y01/PzTl/CzT76L19uPsqnlEOtebeWhT8znvDeP59dbM4+PnBx/9lLuTVNrWHXbZcyZUj3o3kl5STFzplSz6rbLlEREpCDlvgBGAGY2FZhFYupw322Lga8AZwLvSzY3AKl/6u+mTxJKxyGnUu6rX9zP6hf3n9S2bM0Olq3Z0e+xfR0r5Z7pDvOJVWUsXzKPRzfuYemq7bR2nFyeJR2zxGWs2soSbls4nUX9lGcRERlJlu91v82sElgD3OXuD2fZ71LgC+6+0Mw+A4xz939KbrsD6Hb3r6c5bgmwBKC+vn7OHUvvJZ7mZ6ovg9czVzgZtCIzptefQbg4tw/6rt4YbZ29dPbG6I3FOXbUmeOc148aJcVFVJQUU11RknYa8Ujq6OigsrJypMPIqNDjg8KPUfEFV+gx9o1vwYIFG9y9Kchz5rVHYmZh4GfAj7MlEQB3/52ZTTOzOmAPcHnK5knA6gzH3UuizD0zZs72/3ihJO2MrdsHuAJhrspLivnVey5icm3mOluZpJawf3HT0yy68rKCrp21evVqLr/88pEOI6NCjw8KP0bFF1yhx5iP+PI5a8uA+4BtmSoFm9nZyf0ws9lAKdAKPA6828yqk4Ps70629fOa9Ds+MtSClHIPFxcxsaqMybXlhIutoJOIiEgm+eyRXAxcB2wxs03Jts+RLPjo7suADwLXm1kE6AauSQ6+HzSzLwPrk8f9X3fPXC0xKVRcRFeWUu79rQly7GbAgawJolLuIjLW5S2RuPuTQNY/1d39q8BXM2y7H7h/IK9pQGN1OTsOdKbd3t+aIDMmVRGL+4DWBFEpdxEZ6067T8Bspdz7WxNk9uTqAa0JolLuIiKnYSLpr5R7tjVBxpcNbE0QlXIXETkNE0m2Uu6QfU2Qw90DWxNEpdxFRE7DRJKtlHt/a4JsfK0t5zVBVMpdRCThtEskkCjlfmFD1Sk1rvpbE+TZ3e05rQmiUu4iIicMS4mU4XaslPvCpWuI9py4ObG/NUGytaVSKXcRkRNOyx4JqJS7iMhwOW0TCZxcyt2y39LSL5VyFxFJ77ROJHCilHt5abFKuYuI5MFpn0ggcZlrWl0F/7z4AhpryigvKc540+IxZokE0lhTxl2LLmD5knlMrCobnoBFREaR03KwPZPFsyexaFYDG3a1saK5heZdbbQc7CJUZJgZ7k407jTWlNM0pZqrmxqZM6Vag+oiIlmMqUQCiRldTVNrjl+iSi3lHg4ZdZWlqp0lIjIAYy6R9HWslLuIiAyO/vQWEZFAlEhERCQQJRIREQlEiURERAJRIhERkUCUSEREJBAlEhERCUSJREREAlEiERGRQJRIREQkECUSEREJRIlEREQCUSIREZFAlEhERCQQJRIREQlEiURERAJRIhERkUCUSEREJBAlEhERCUSJREREAlEiERGRQJRIREQkECUSEREJRIlEREQCUSIREZFAlEhERCQQJRIREQlEiURERAJRIhERkUCUSEREJJDQSAcwFkRicQ509BCJOuGQUVdZSrhYOVxETg9KJHng7jTvauOh5haad7bR0tZFqKgIM3CHaDxOY3U5TVOruaqpkaYp1SMdsojIoOUtkZhZI/ADoB5w4F53/1affT4M/D1gwBHgk+7+bHLbrcDHksduAW5096P5incouDuPbNzDN1dup7Wzl+5IDPfEtkgsdtK+Ow508mprJ7/YvI/aihJuvzCGu2NmIxC5iMjg5fP6ShS43d3PA+YBnzKz8/rs8ypwmbtfCHwZuBfAzBqAvwWa3P0CoBi4No+xBravvZtr713LPz76HC1t3XT1nkgimbhDV2+MlrZu9hxKHL+vvXt4AhYRGSJ5SyTuvs/dn0k+PgJsAxr67PMHd29LfrsWmJSyOQSUmVkIKAf25ivWoNbvPMjCpWvYsKuNrt5Y/wekEXdnw642Fi5dQ/POg0McoYhI/gzLiK+ZTQVmAU9n2e0m4L8B3H0P8HXgNWAf0O7uv85zmIOyfudBrr9vHZ09MaLxfrog/YjGnc6eGNfdt07JRERGDfP+rr8EfQGzSmANcJe7P5xhnwXAfwKXuHurmVUDPwOuAQ4BDwE/dfcfpTl2CbAEoL6+fs7y5cvTxtHR0UFlZWXwHyhFJOZsf/0I8QznsKbqDP7kopm4O4eOdPKbtRu5ZPYF1NdO4I2Dh/ifDc8d37e+DF5PuapVZMb0+jMIFxfOmEk+zuFQKvT4oPBjVHzBFXqMfeNbsGDBBndvCvKceZ21ZWZhEgnhx1mSyAzgu8Cfu3trsnkh8Kq770/u8zDwLuCUROLu95IcW2lqavLLL788bSyrV68m07Z0+puy6+5cc89annmtO2NPJFR0lM8/+RQAX/vQDH5zqI66I2EWP7CWf1p0ASvbatm8ux2A2y+M8o0toZRjjTlHxrF8ybyCGYAf6DkcboUeHxR+jIovuEKPMR/x5XPWlgH3AdvcfWmGfSYDDwPXufv2lE2vAfPMrBzoBv4UaM5XrDDwKbstrZ08t7c96+Ws1G290TjvemstT750AIAnXz7A7MnVxxNJumM3727n0Y17WDx7Utp9REQKQT57JBcD1wFbzGxTsu1zwGQAd18GfAGoBf4z+Vd31N2b3P1pM/sp8AyJ2V8bSfY6htpgpuz+/Nm99MacWA5jIgvPPZO/e8/b2dnayRtHeujoiQJw5GiE6fXZu7/dkRhLV21n0ayGgumViIj0lbdE4u5Pkrg/JNs+HyNxr0i6bV8EvpiH0I7b197NLcs3sWVPe86zrdyhOxLP+TVWbXuDVdve4Ev/53yicaeyNHHKK0vDHO6O9nt8a0cvG3a10TS1JufXFBEZTmO2TsdQTNntT0nKmErH0SjuzsVn1wJwydl1bHytLdOhx3VHYqxobslLfCIiQ2FMJpKhnLKbzWVvexMPLpnHg0vmUXdGCd9e8wo90TgrPj6fmDvPZhgfSeUOzbv6TzgiIiNlzNXa2tfezQ3fW0d3JH0vZFJ1GY/89cW8sr+D3mic6+9fx5JLp3HFefXsaevmMw89m3PyWbn1dVZuff2ktjt/vnXAMbcc7CISi6vQo4gUpDH1yeTu3PyTTfT0M8bx5MsHuPbetVx//zpqK0qYP62Wq5Y9xQt/PMy7z68fpmhPCBUZBzp6hv11RURyMaYSySMb9/Q7ZRdg/rRaVnx8Pjdd8hYunFTF2h2J21uOTdkdbmZGJJrfG0dFRAZrTF3a+ubK7f0OrL9xuIcFX19NbyzOd65voqK0mNaOXgCOHI0yviyc9fiZjRO44/3nEXdn8+5DfPkX29JeGuvblo174qZIEZFCNGZ6JF29MVo7e/vdrzcWpzsSIxZ3frvtdXa1dqVM2Q1xuDuS9fg9bd385XfWctWyp6itKOWit9SccmlsoJfLonGnrrI09x9WRGQYjZlE0pa82bA/FSXFxx/PmVrDrtYuLpqWOmX3UNbj93f00BNNjMFE43HOqT/jlEtjA71c1lhTroF2ESlYY+bSVmdvDPf+Lw+98y013H7F2+iNxVn/6kE2tRxi3autPPSJ+ew91M39v381p9d7+1lnUFNRyuHuCMcKYx67NDZ+XDjlDvdEW6ZVSMzQCooiUtDGRCKJxOL0xuIk1sfKbvWL+1n94v6T2pat2cGyNTtyfr2qsjB3Xnk+n/7xRi5oqGJi1TjgxKWxI0ejp7RlUhYu5uqmxpxfW0RkuI2J6yUHOnqy12oZQsVFxt3XzOSff7mN/R09bN596JRLY+naMqmrLGWOeiQiUsDGRCIZzqmz77twIjMmVfEP7z2X5UvmMaW2/PilsfPePJ5fb/0jrZ29p7SlUxYu5taF56hgo4gUtDFxaau/qbODnbKb7n6Ux57dy2PPnrwq8DOvHTrl0lh/l8tCRcaMSVUsmtWQcR8RkUIwJnokdZWlZOuT5GPKblCl4SK+de0s9UZEpOCNiUQSLi46qRJvX/mYshtEWbiYB26cy1nJAXkRkUI2Ji5tQeL+ELM42ZaoH+iU3aEWKjJKw0U8cONcrT8iIqPGmOiRAFRXlFAWzjz999iU3b//6WaOHI2ecjd7urahVGTGnCnVrLrtMiURERlVxkwiKS8ppraiJO22oZ6yW1xklJcU09/whlkirsaaMhomlLF8yTwmVpUN6ucTERkpY+bSFsCtV0znHx997pTCjalTdgH+9VcvnHI3eyTmOd3hXhYu5q5F5zO5toIVzS0072qj5WAXoSLDzHB3onGnsaacpinVXN3UyJwp1axZs0YD6yIyKo2pRLJ4VgMPrm9hw662k6buDvWU3cWzJ2Fmxy9RRWJxDnT0EIkmqvjWVZaqdpaInDbG1KeZmXH3tTMpDefnx840ZTdcXMTEqjIm15YzsapMSURETitj7hNtYlUZD9w4N+vA+2Boyq6IjFVjLpEANE2t4Yc3zaWitJhQUbBxiVCRUVFazA9v0pRdERmbxmQigUQyWXXbZcyZUj3o3kl5SbGm7IrImDemBtv7mliVmHL76MY9LF21ndaOxOJX2W5aNEtcxqqtLOG2hdNZNKtBs61EZEwb04kEEgPwi2dPYtGsBjbsahvQlF0lEBERJZLjjk3X1ZRdEZGBUSLJ4NiUXRERyU5/YouISCBKJCIiEogSieQkEosTiTmvtXaxr72bSCw+0iGJSIHQGImk5e4072rjoeYWmne20dLWxS3nR7n5W7/DPbEAWGN1OU1Tq7mqqZEmzWITGbOUSOQk7s4jG/fwzZXbae08+b6auPtJlZN3HOjk1dZOfrF5H7UVJdx6xXQW674akTFHiUSO29fezS3LN7FlT/sppfYzcYeu3hhdvd18/pHneHB9C3dfO1Mz3kTGEI2RCADrdx5k4dI1bNjVlnMS6as7EmPDrjYWLl1D886DQxyhiBQqJRJh/c6DXH/fOjp7Yiet0zIY0bjT2RPjuvvWKZmIjBG6tDXG7Wvv5obvraM7kr4XMr2+kq98YAaxuFMZ7+AbW7Zwx/vP5cKGCTy/t507f7417XHdkRgf+d46Vt12mS5ziZzm1CMZw9ydm3+yiZ5I5qm8O/Z38sFv/4Gr73kKgJmNEygvCXH1PU8RLi5ixqSqjMf2ROLcsnwTnq0KpoiMekokY9gjG/fw3N72rJezUrfF4jHe9dZannzpAABPvnyA2ZOrsx67eXc7j27cM3RBi0jBUSIZo9ydb67cntPA+sJzz+TxWy6lfNw4wsVFdPREAThyNML4suxXR7sjMZau2q5eichpTIlkjGre1UZrZ29O+67a9gbvuft3dHR1E407laWJ5FFZGuZwd7Tf41s7etmwqy1QvCJSuJRIxqiHmlsyDrCnKkkpnd8bieLuXHx2LQCXnF3Hxtf6TxDdkRgrmlsGH6yIFDQlkjGqeWdb1pUgj7nsbW/iwSXzeHDJPMrHlfLtNa/QE42z4uPzibnz7O72fp/DPdEDEpHTk6b/jkGRWJyWtq6c9l259XVWbn0dgNsvjOIeyjjlN5uWg11EYnEtDiZyGtJv9Rh0oKOHUNHw/teHiowDHT3D+poiMjzy9mliZo1m9oSZbTWz583s5jT7fNjMNpvZFjP7g5m9I2XbBDP7qZm9YGbbzGx+vmIdayJRZ7jrKpoZkahmbomcjvJ5aSsK3O7uz5jZGcAGM1vp7qnXRV4FLnP3NjP7c+Be4KLktm8Bv3L3D5lZCVCex1hPe6lr0Ld19Qz7dFx3JxxSVWCR01HeEom77wP2JR8fMbNtQAOwNWWfP6QcshaYBGBmVcClwA3J/XqB3OaqCpB+PZFQURFmEI87R6PZF6Y684xS7r/hnZxzZiXnffFxYnHnktkXcNG7qk8qjZJLuRRI3JxYV1k6pD+jiBQGG46/TM1sKvA74AJ3P5xhn88Ab3f3j5nZTBK9k63AO4ANwM3u3pnmuCXAEoD6+vo5y5cvTxtDR0cHlZWVwX+YPBnK+A51RXj98FGicSc+yP/f4qIiQqFi3nvpXB79zR+omzCeuedO5Zd/eJbL3zmDra+8hrtz4fSp/PbpE21vHDyU9vlKQ8VMr8/v+S/0/2Mo/BgVX3CFHmPf+BYsWLDB3ZuCPGfeZ22ZWSXwM+CWLElkAXATcElKXLOBv3H3p83sW8BngTv6Huvu95JIOjQ1Nfnll1+eNo7Vq1eTaVshGIr4TqwncpSu3qEY/nJmzDG++VyIv5hbR9Xu/XxjS4jn/SBnja8jGnf+5+mD/DKl7ftbOk55FjO4as4kllz+jjSvMXQK/f8YCj9GxRdcoceYj/jyOnXHzMIkksiP3f3hDPvMAL4LXOnurcnm3cBud386+f1PSSQWyWAo1hPJZvy4EL2Rk0ujjB8XyqlcSlm4mKubGoc8JhEpDPmctWXAfcA2d1+aYZ/JwMPAde6+/Vi7u/8RaDGztyWb/pSUsRU52VCuJ5LJkaNRSsInl0Y5cjSaU7mUuspS5kzJXNxRREa3fF7auhi4DthiZpuSbZ8DJgO4+zLgC0At8J/Jdb6jKdfq/gb4cXLG1g7gxjzGOmr1t57IpOoyHvnri3llfwe90TjX37+OJZdO44rz6tnT1s1nHno2p+TzzGttXHN+I/AGl5xdx083tBCNOx++aDK/3LLveFtfZeFibl14jtZxFzmN5XPW1pNA1k8Pd/8Y8LEM2zYBgQaATne5rCcCiXLvtz64CYDaihLmT6vlqmVP8YnLpvHu8+v5ry1/POWYUJHx/Rvncu7E8fzgo3P52uMvEo3FWPHx+Wzdd/h4aZRj5VJS21KfY8akKhbNahiaH1hECpJKpIxiuawnAjB/Wi0rPj6fx5//I6/s72DtjsRQ1JMvH2DRzIa0iSQad/7qvqdPavufDQf4xpaT3zLZpvyWhov41rWz1BsROc0pkYxSua4n8sbhHhZ8fTW9sTjfub6JitJiWjsSt+QcORplfFk4L/GVhYt54Ma5nFU1Li/PLyKFQ4lklMp1PZHeWBySuea3217nSE+Us8YnPtwrS0Mc7o4MaVyhIqM0XMQDN86laWrNkD63iBQmFW0cpXJdT6SipPj44zlTa9jV2sVF01LXEzk0ZDGVlxQzZ0o1q267TElEZAxRj2SUynU9kXe+pYbbr3gbvbE46189yKaWQ6x7tZWHPjGfvYe6uf/3rwaKwyxxGau2soTbFk5n0awGjYmIjDFKJKPQQNYTWf3ifla/uP+ktmVrdrBszY4Bv26RGeUlxbg70bjTWFNO05Rqrm5qZM6UaiUQkTFKiWQUOraeSCQ29HewZzIuZEx7UwW/es98wiGjrrJUi1SJCKBEMiqNxHoiRUVFFBcZk2tVzV9ETqZEMgqFQ5Z1fCRdCfh05d5zLQEPienGlv3+UhEZo3RtYhSqqywlGs98N3t7d4S//O5aNrYcAuD8N4+nvCTE1fc8Rbi4iBmTqtK2ZRONO6FiJRIROZUSySgULi6isTrzJaaeaPykAoqzJlfz5EsHgMTd7LMnV6dty6axplz9ERFJS4lklGqaWp3zOEm6cu+5loCHxBTfJlXvFZEMlEhGqauaGikLF/e/I6Qt955rCXjQeiIikp0SySjVNKWa2oqSnPZ95rU2Lj479W72trRtmWg9ERHJRolklDIzbr1iOuUlp/ZKQkXGj2666HgJ+HBx0fFy7zF3nt3dzvN7D5/Slo7WExGR/mj67yi2eFYDD65vYcOutpNKyacrAb8pOYMrVX9TfrWeiIjkQj2SUczMuPvamZSG8/PfqPVERCQXSiSj3MSqMh64cW7OA++50noiIpIrJZLTQNPUGn5401wqSosJFQXrPYSKjIrSYn54k9YTEZHcKJGcJpqm1rDqtsuYM6V60L0TrSciIoOhwfbTyMSqMpYvmcejG/ewdNV2Wjt66Y7Estbl0noiIhKUEslpxsxYPHsSi2Y1sGFXGyuaW2je1UbLwS5CRYaZaT0RERlS5rksszdKmNl+YFeGzXXAgWEMZ6DyHZ9RHAoZZo47sWgUGOh//lg/h0Oh0GNUfMEVeox945vi7m8K8oSnVSLJxsya3b1ppOPIpNDjg8KPsdDjg8KPUfEFV+gx5iM+DbaLiEggSiQiIhLIWEok9450AP0o9Pig8GMs9Pig8GNUfMEVeoxDHt+YGSMREZH8GEs9EhERyYNRl0jMrNHMnjCzrWb2vJndnGaft5vZU2bWY2afSWl/m5ltSvk6bGa3JLd9ycz2pGx7b55j/LCZbTazLWb2BzN7R8q2PzOzF83sZTP7bEr7W8zs6WT7g2aW24IkQxhftmML7BzuTLZvMrPmlPYaM1tpZi8l/x3UQisBz2He34c5xndlMr5NZtZsZpekbPtI8hy9ZGYfSWmfk/x5XjazfzMb/M1HQWI0s5nJ3/Hnk9uvSTnm+2b2aso5nDnc8SW3xVJieCylfUh+j4PGaGYL+rwPj5rZouS2gZ1Ddx9VX8BEYHby8RnAduC8PvucCbwTuAv4TIbnKQb+SGIONcCXMu2bpxjfBVQnH/858HRKXK8A04AS4NljxwIrgGuTj5cBnxyB+DIeWyjnMPn9TqAuzfP+K/DZ5OPPAl8difjy/T7MMb5KTlzengG8kHxcA+xI/ludfHzs51gHzAMM+G/gz0coxunAOcnHbwb2AROS338f+NBInsPk9x0ZnndIfo+HIsaUfWqAg0D5YM7hqOuRuPs+d38m+fgIsA1o6LPPG+6+Hohkeao/BV5x90w3MOY7xj+4+7FlCdcCk5KP5wIvu/sOd+8FlgNXJv/y+xPgp8n9HgAWDXd8uRw7FAKew2yuJHHuYITOYR95eR/mGF+HJz81gApO3KD6HmClux9Mxr8S+DMzmwiMd/e1yeN+wCDPX9AY3X27u7+UfLwXeAMIdFPdUMaXyVD+Hg9xjB8C/tvduwYTx6hLJKnMbCowC3i6n13TuRb4SZ+2Tye7gPcP9pJHXznGeBOJv+4g8SZoSdm2O9lWCxxy92if9uGOr79jC+EcQuKX5ddmtsHMlqS017v7vuTjPwL1IxTfMXl/H2aLz8wWm9kLwC+BjyabM70HG5KP+7YHNogYU7fPJdF7fyWl+a7kOfymmZWOUHzjkpeS1h67ZESefo8DxHhMuvdh7udwsF2qkf4i0V3bAHwgyz5fIs1lAhJvugMkPlSOtdWTuMxQROKS2P3DFOMCEn9F1Ca//xDw3ZTt1wH/TqKswcsp7Y3Ac8MdX7ZjC+UcJtsakv+eSeLy4KXJ7w/1ObZtBM9h3t+HucSX3O9SYFXy8WeAf0zZdkeyrenYPsn2/wX8Yjj+j/vGmNI2EXgRmNenzYBSEn/xf2Ek4kt5D04jcan1rfn4PR6ic7gfCA/2HAYKfqS+gDDwOHBbP/t9ifSJ5Erg11mOmxr0PzeXGElcr3wFmJ7SNh94POX7f0h+WfJDJ5Ruv+GKbwDHjtg5zPY+SH7oTEw+ngi8OFLx5ft9mOvvScr+O5IfdH8B3JPSfk+ybSInjwGctN9wxph8PB54hizX8oHLCZDsgsTXp/37JP5IHNLf46GIEbgZuDfIORx1l7aS1xjvA7a5+9JBPs1f0Kcbl7z+e8xi4LlBPndOMZrZZOBh4Dp3356yaT1wTnJmRwmJLudjnvgffYLEmxHgI8D/N9zxZTu2UM6hmVWY2RnHHgPvTonlMRLnDkboHKbI2/swx/jOTu6Hmc0m8ddnK4kPpXebWXXy0tq7SXzY7QMOm9m85HHXM8jzFzTG5O/GI8AP3P2nfY6ZmPL8ixiBc5g8d6XJ9jrgYmDrUP4eB40xZZeM78Ocz2GQTDgSX8AlJK5/bwY2Jb/eC3wC+ERyn7NIXHs8DBxKPh6f3FaRPIlVfZ73h8CW5PM+RvKv1jzG+F2gLWV7c8rx7yUx++IV4PMp7dNIzJp5GXgIKB3u+DIdW0jnMHmenk1+Pd/nHNYCvwFeAlYBNSP0f5zX92GO8f198vxsAp4CLkk5/qPJ99nLwI0p7U0kPlReIXHJ1fL8f5w2RuCvSEym2ZTyNTO57bfJc/gc8COgcgTie1cyhmeT/9401L/HQ/T/PBXYAxT1ed4BnUPd2S4iIoGMuktbIiJSWJRIREQkECUSEREJRIlEREQCUSIREZFAlEhEksxsgpn99SCP/S8zmzCA/b9kKZWpRUYzJRKREyYAaROJmYWyHeju73X3Q3mISaTgKZGInPAvwFuT6y98zcwuN7P/scRaElsBzOzRZCHI51OLQVpi/ZM6M5tqZtvM7DvJfX5tZmXZXtQSa2usTRbIe+RYoUYz+1tLrDOx2cyWJ9susxNrRGw8dge/yEjSDYkiScnqqb9w9wuS319OolrqBe7+arKtxt0PJpPDeuAyd281s50k7vquJHHHcpO7bzKzFSRK3Pyoz2t9icR6FV83s83A37j7GjP7vySqMNxiZnuBt7h7j5lNcPdDZvZz4F/c/fdmVgkc9ROVZEVGhHokItmtO5ZEkv7WzJ4lsb5II3BOmmNedfdNyccbSJShSMvMqkgsyLQm2fQAiQqtkCh78WMz+yvgWLL4PbDUzP42eZySiIw4JRKR7DqPPUj2UBYC8939HcBGYFyaY3pSHseArOMrWbwP+A9gNrDezELu/i/Ax4Ay4Pdm9vZBPrfIkFEiETnhCInlSjOpIrF+SVfyA3xe0Bd093agzcz+V7LpOmCNmRUBje7+BImie1VApZm91d23uPtXSVxaUyKRETfYv5RETjvJsY7fm9lzJFYz/GWfXX4FfMLMtpFY12TtEL30R4BlZlZOYq2IG0ksbvWj5KUvA/4tOUbyZTNbAMRJVHRNt+qiyLDSYLuIiASiS1siIhKIEomIiASiRCIiIoEokYiISCBKJCIiEogSiYiIBKJEIiIigSiRiIhIIP8/xztiC2N8qvAAAAAASUVORK5CYII=\n",
"text/plain": [
"<Figure size 432x432 with 1 Axes>"
]
},
"metadata": {
"needs_background": "light"
},
"output_type": "display_data"
}
],
"source": [
"plt.figure(figsize=(6,6))\n",
"plt.scatter(dfs['train loss'], dfs['dev loss'],s=800)\n",
"plt.xlabel('train loss')\n",
"plt.ylabel('dev loss')\n",
"for i in range(len(dfs)):\n",
" plt.text(dfs['train loss'][i],dfs['dev loss'][i], \n",
" str(df['Cdim'][i]) + '\\n' + str(df['hidden_layer_size'][i]),\n",
" ha='center', va='center', color='white',fontsize=8)\n",
"plt.grid('minor')\n",
"# note that the smalles train loss does not have the smallest dev loss \n",
"# due to overfitting at Cdim=50, hidden_layer_size=300"
]
},
{
"cell_type": "markdown",
"id": "d1b93302",
"metadata": {},
"source": [
"## use the best hyperparameter to calculate loss on the test dataset"
]
},
{
"cell_type": "code",
"execution_count": 155,
"id": "3579b344",
"metadata": {},
"outputs": [],
"source": [
"Cdim = dfs['Cdim'][0]\n",
"n_output1 = dfs['hidden_layer_size'][0]"
]
},
{
"cell_type": "code",
"execution_count": 156,
"id": "d1b70f03",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 2min 40s, sys: 7.38 s, total: 2min 47s\n",
"Wall time: 2min 45s\n"
]
}
],
"source": [
"%%time\n",
"\n",
"Vdim=27\n",
"block_size = 3 # hyperparamter\n",
"n_input1 = Cdim * block_size\n",
"batch_size = 32\n",
"n_input2 = n_output1\n",
"n_output2 = Vdim\n",
"\n",
"g = torch.Generator().manual_seed(2147483647)\n",
"C = torch.randn((Vdim, Cdim), generator=g)\n",
"W1 = torch.randn((n_input1,n_output1), generator=g)\n",
"b1 = torch.randn(n_output1, generator=g)\n",
"W2 = torch.randn((n_input2, n_output2), generator=g)\n",
"b2 = torch.randn(n_output2, generator=g)\n",
"parameters = [C, W1, b1, W2, b2]\n",
"\n",
"for p in parameters:\n",
" p.requires_grad=True\n",
" \n",
"for i in range(500000):\n",
" # mini-batch\n",
" ix = torch.randint(0, Xtr.shape[0], (32,))\n",
"\n",
" # forward pass\n",
" # only grab 32 rows\n",
" emb = C[Xtr[ix]] # (32,3,2)\n",
" h = torch.tanh(emb.view(-1,n_input1) @ W1 + b1)\n",
" logits = h @ W2 + b2\n",
" loss = F.cross_entropy(logits, Ytr[ix])\n",
" # backward pass\n",
"\n",
" # initalize all gradients\n",
" for p in parameters:\n",
" p.grad = None\n",
"\n",
" loss.backward()\n",
"\n",
" if i < 10000:\n",
" lr = 0.1 \n",
" elif i < 100000:\n",
" lr = 0.05\n",
" elif i < 200000:\n",
" lr = 0.01\n",
" else:\n",
" lr = 0.005\n",
" # update the parameters\n",
" for p in parameters:\n",
" p.data += -lr * p.grad"
]
},
{
"cell_type": "code",
"execution_count": 157,
"id": "e81edd4a",
"metadata": {},
"outputs": [],
"source": [
"# compute loss on dev dataset\n",
"emb = C[Xtr] # (32,3,2)\n",
"h = torch.tanh(emb.view(-1,n_input1) @ W1 + b1)\n",
"logits = h @ W2 + b2\n",
"loss_train = F.cross_entropy(logits, Ytr).item()\n",
"\n",
"# can have even smaller loss with smaller learning rate\n",
"emb = C[Xdev] # (32,3,2)\n",
"h = torch.tanh(emb.view(-1,n_input1) @ W1 + b1)\n",
"logits = h @ W2 + b2\n",
"loss_dev = F.cross_entropy(logits, Ydev).item()\n",
"\n",
"# report loss with test dataset\n",
"emb = C[Xte] # (32,3,2)\n",
"h = torch.tanh(emb.view(-1,n_input1) @ W1 + b1)\n",
"logits = h @ W2 + b2\n",
"loss_test = F.cross_entropy(logits, Yte).item()"
]
},
{
"cell_type": "code",
"execution_count": 158,
"id": "d9c6ed37",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"train loss: 2.2831\n",
"dev loss: 2.2875\n",
"test loss: 2.2865\n"
]
}
],
"source": [
"print ('train loss: %.4f'%loss_train)\n",
"print ('dev loss: %.4f'%loss_dev)\n",
"print ('test loss: %.4f'%loss_test)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "61337334",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment