Created
October 16, 2019 14:18
-
-
Save dreiss/ee4ff1ed2e137326d13e96bb4f953061 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import os\n", | |
"import collections\n", | |
"import operator\n", | |
"import torch\n", | |
"import numpy as np" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Identifier of the training checkpoint; the file is read as model-{MODEL_ID}.pt.\n", | |
"MODEL_ID = \"f140225004\"\n", | |
"# Directory the checkpoint is read from (~/Downloads).\n", | |
"MODEL_ROOT = os.path.join(os.environ[\"HOME\"], \"Downloads\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"def grab_mod(sd, path):\n", | |
"    \"\"\"Return the entries of state dict `sd` whose keys start with `path`,\n", | |
"    re-keyed with that prefix removed (used to load one submodule's weights).\"\"\"\n", | |
"    prefix_len = len(path)\n", | |
"    return {key[prefix_len:]: value for key, value in sd.items() if key.startswith(path)}" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"class ByteLSTM(torch.nn.Module):\n", | |
"    \"\"\"Byte-level text classifier.\n", | |
"\n", | |
"    Pipeline: byte embedding -> bidirectional LSTM -> mean over time -> MLP -> softmax.\n", | |
"    \"\"\"\n", | |
"\n", | |
"    def __init__(\n", | |
"        self,\n", | |
"        embedding_width,\n", | |
"        lstm_width,\n", | |
"        lstm_depth,\n", | |
"        mlp_widths,\n", | |
"    ):\n", | |
"        \"\"\"Build the layers.\n", | |
"\n", | |
"        Args:\n", | |
"            embedding_width: size of each byte embedding (one row per byte value 0-255).\n", | |
"            lstm_width: hidden size of each LSTM direction.\n", | |
"            lstm_depth: number of stacked LSTM layers.\n", | |
"            mlp_widths: output widths of the MLP's Linear layers, in order; the last\n", | |
"                entry is the number of classes. A ReLU sits between consecutive Linears.\n", | |
"        \"\"\"\n", | |
"        super(ByteLSTM, self).__init__()\n", | |
"        self.embedding = torch.nn.Embedding(256, embedding_width)\n", | |
"        self.lstm = torch.nn.LSTM(\n", | |
"            input_size=embedding_width,\n", | |
"            hidden_size=lstm_width,\n", | |
"            num_layers=lstm_depth,\n", | |
"            batch_first=True,\n", | |
"            bidirectional=True,\n", | |
"        )\n", | |
"        mlp_layers = []\n", | |
"        # Both LSTM directions are concatenated, so the first Linear takes 2 * lstm_width.\n", | |
"        for in_width, out_width in zip([2 * lstm_width] + mlp_widths, mlp_widths):\n", | |
"            if mlp_layers:\n", | |
"                mlp_layers.append(torch.nn.ReLU())\n", | |
"            mlp_layers.append(torch.nn.Linear(in_width, out_width))\n", | |
"        self.mlp = torch.nn.Sequential(*mlp_layers)\n", | |
"\n", | |
"        self.lstm_width = lstm_width\n", | |
"        self.lstm_depth = lstm_depth\n", | |
"        # Class labels, index-aligned with the MLP outputs. Stored on self (not as a\n", | |
"        # local) so get_classes() works even before labels are assigned.\n", | |
"        # NOTE(review): TorchScript infers this attribute's type when scripting, so it\n", | |
"        # presumably needs to hold a non-empty list of str by then - confirm.\n", | |
"        self.vocab = []\n", | |
"\n", | |
"    def forward(self, byte_input):\n", | |
"        \"\"\"Return class probabilities of shape (batch, num_classes).\n", | |
"\n", | |
"        byte_input: (batch, seq_len) integer tensor of byte values in [0, 255].\n", | |
"        \"\"\"\n", | |
"        token_emb = self.embedding(byte_input.long())\n", | |
"        # Zero initial (h0, c0); the first dim is num_layers * 2 directions.\n", | |
"        empty = torch.zeros(self.lstm_depth * 2, token_emb.size(0), self.lstm_width)\n", | |
"        rep, _ = self.lstm(token_emb, (empty, empty))\n", | |
"        # Mean-pool the per-step LSTM outputs over the sequence dimension.\n", | |
"        pooled = torch.sum(rep, 1) / token_emb.shape[1]\n", | |
"        raw_scores = self.mlp(pooled)\n", | |
"        normalized = torch.nn.functional.softmax(raw_scores, dim=1)\n", | |
"        return normalized\n", | |
"\n", | |
"    @torch.jit.export\n", | |
"    def get_classes(self):\n", | |
"        \"\"\"Return the class labels; exported so scripted/saved modules expose it too.\"\"\"\n", | |
"        return self.vocab\n", | |
"\n", | |
"    def run_on_text(self, text, limit=3, mod=None):\n", | |
"        \"\"\"Classify `text` and return the top-`limit` classes with their probabilities.\n", | |
"\n", | |
"        Args:\n", | |
"            text: input string; lower-cased and UTF-8 encoded before scoring.\n", | |
"            limit: number of highest-probability classes to return.\n", | |
"            mod: module used for scoring (e.g. a scripted or reloaded copy);\n", | |
"                defaults to self. Labels always come from self.get_classes().\n", | |
"\n", | |
"        Returns:\n", | |
"            OrderedDict mapping label -> probability, highest first.\n", | |
"\n", | |
"        Raises:\n", | |
"            ValueError: if `text` is empty (mean pooling would divide by zero).\n", | |
"        \"\"\"\n", | |
"        mod = mod if mod is not None else self\n", | |
"        text_bytes = text.lower().encode(\"utf-8\")\n", | |
"        if not text_bytes:\n", | |
"            raise ValueError(\"text must be non-empty\")\n", | |
"        # Unsigned bytes: signed int8 (np.byte) would map UTF-8 bytes >= 0x80 from\n", | |
"        # non-ASCII characters to negative embedding indices.\n", | |
"        text_tensor = torch.tensor([list(text_bytes)], dtype=torch.uint8)\n", | |
"        scores = mod(text_tensor)\n", | |
"        pairs = zip(self.get_classes(), scores[0].detach().numpy())\n", | |
"        return collections.OrderedDict(sorted(pairs, key=operator.itemgetter(1), reverse=True)[:limit])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"ByteLSTM(\n", | |
" (embedding): Embedding(256, 64)\n", | |
" (lstm): LSTM(64, 128, batch_first=True, bidirectional=True)\n", | |
" (mlp): Sequential(\n", | |
" (0): Linear(in_features=256, out_features=64, bias=True)\n", | |
" (1): ReLU()\n", | |
" (2): Linear(in_features=64, out_features=16, bias=True)\n", | |
" )\n", | |
")" | |
] | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Dimensions must match the checkpoint whose weights are loaded into this model.\n", | |
"model = ByteLSTM(\n", | |
"    embedding_width=64,\n", | |
"    lstm_width=128,\n", | |
"    lstm_depth=1,\n", | |
"    mlp_widths=[64, 16],\n", | |
")\n", | |
"model.eval()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"WARNING:root:This caffe2 python run does not have GPU support. Will run in CPU only mode.\n", | |
"Install apex from https://github.com/NVIDIA/apex/.\n" | |
] | |
} | |
], | |
"source": [ | |
"# NOTE: torch.load unpickles the checkpoint, which can execute code - only load trusted files.\n", | |
"train_output = torch.load(\n", | |
"    f\"{MODEL_ROOT}/model-{MODEL_ID}.pt\",\n", | |
"    map_location=\"cpu\",  # remap every stored tensor onto the CPU\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# The checkpoint's parameter names carry the training model's prefixes; strip each\n", | |
"# prefix so the tensors land in the matching submodule of this model.\n", | |
"for submodule, prefix in [\n", | |
"    (model.embedding, \"embedding.word_embedding.\"),\n", | |
"    (model.lstm, \"representation.lstm.lstm.\"),\n", | |
"    (model.mlp[0], \"decoder.mlp.0.\"),\n", | |
"    (model.mlp[2], \"decoder.mlp.2.\"),\n", | |
"]:\n", | |
"    submodule.load_state_dict(grab_mod(train_output[\"model_state\"], prefix))\n", | |
"# Class labels, index-aligned with the model's output scores.\n", | |
"model.vocab = list(train_output[\"tensorizers\"][\"labels\"].vocab)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"OrderedDict([('nba', 0.8307514),\n", | |
" ('nfl', 0.091046356),\n", | |
" ('fantasyfootball', 0.035300486)])" | |
] | |
}, | |
"execution_count": 8, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Eager-mode prediction: the top classes for a short query.\n", | |
"model.run_on_text(text=\"lebron\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"ScriptModule(\n", | |
" original_name=ByteLSTM\n", | |
" (embedding): ScriptModule(original_name=Embedding)\n", | |
" (lstm): ScriptModule(original_name=LSTM)\n", | |
" (mlp): _ConstSequential(\n", | |
" original_name=_ConstSequential\n", | |
" (0): ScriptModule(original_name=Linear)\n", | |
" (1): ScriptModule(original_name=ReLU)\n", | |
" (2): ScriptModule(original_name=Linear)\n", | |
" )\n", | |
")" | |
] | |
}, | |
"execution_count": 9, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Compile the eager model to TorchScript and put it in inference mode.\n", | |
"smod = torch.jit.script(model).eval()\n", | |
"smod" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"OrderedDict([('nba', 0.8307514),\n", | |
" ('nfl', 0.091046356),\n", | |
" ('fantasyfootball', 0.035300486)])" | |
] | |
}, | |
"execution_count": 10, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Same query through the scripted module; compare with the eager result.\n", | |
"smod.run_on_text(text=\"lebron\")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# f-string so the exported file name actually contains the checkpoint ID;\n", | |
"# save and reload both use this single path.\n", | |
"EXPORT_PATH = f\"model-reddit16-{MODEL_ID}.pt1\"\n", | |
"smod.save(EXPORT_PATH)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"ScriptModule(\n", | |
" original_name=ByteLSTM\n", | |
" (embedding): ScriptModule(original_name=Embedding)\n", | |
" (lstm): ScriptModule(original_name=LSTM)\n", | |
" (mlp): ScriptModule(\n", | |
" original_name=_ConstSequential\n", | |
" (0): ScriptModule(original_name=Linear)\n", | |
" (1): ScriptModule(original_name=ReLU)\n", | |
" (2): ScriptModule(original_name=Linear)\n", | |
" )\n", | |
")" | |
] | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Reload from the same path the scripted module was saved to.\n", | |
"loaded = torch.jit.load(EXPORT_PATH)\n", | |
"loaded.eval()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"OrderedDict([('nba', 0.8307514),\n", | |
" ('nfl', 0.091046356),\n", | |
" ('fantasyfootball', 0.035300486)])" | |
] | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Score with the reloaded TorchScript module; labels still come from the eager model.\n", | |
"model.run_on_text(text=\"lebron\", mod=loaded)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"OrderedDict([('StarWarsBattlefront', 0.73975956),\n", | |
" ('gaming', 0.22228362),\n", | |
" ('Overwatch', 0.019848222)])" | |
] | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Another sentence scored by the reloaded module.\n", | |
"model.run_on_text(\n", | |
"    text=\"Vader is the worst in the game\",\n", | |
"    mod=loaded,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"OrderedDict([('leagueoflegends', 0.9999074),\n", | |
" ('SquaredCircle', 6.828745e-05),\n", | |
" ('nba', 1.3255787e-05)])" | |
] | |
}, | |
"execution_count": 15, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# A longer sentence scored by the reloaded module.\n", | |
"model.run_on_text(\n", | |
"    text=\"If he hadn't been in the jungle, it would have been an easy win\",\n", | |
"    mod=loaded,\n", | |
")" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": { | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": [ | |
"['news',\n", | |
" 'nba',\n", | |
" 'CFB',\n", | |
" 'StarWarsBattlefront',\n", | |
" 'Overwatch',\n", | |
" 'fantasyfootball',\n", | |
" 'soccer',\n", | |
" 'todayilearned',\n", | |
" 'The_Donald',\n", | |
" 'leagueoflegends',\n", | |
" 'hockey',\n", | |
" 'nfl',\n", | |
" 'SquaredCircle',\n", | |
" 'politics',\n", | |
" 'worldnews',\n", | |
" 'gaming']" | |
] | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Label list from the @torch.jit.export method; it is still available after save/load.\n", | |
"loaded.get_classes()" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "pytorch3-nightly", | |
"language": "python", | |
"name": "pytorch3-nightly" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.4" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment