Skip to content

Instantly share code, notes, and snippets.

@dreiss
Created October 16, 2019 14:18
Show Gist options
  • Save dreiss/ee4ff1ed2e137326d13e96bb4f953061 to your computer and use it in GitHub Desktop.
Save dreiss/ee4ff1ed2e137326d13e96bb4f953061 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"import collections\n",
"import operator\n",
"import torch\n",
"import numpy as np"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"MODEL_ROOT = os.path.join(os.environ[\"HOME\"], \"Downloads\")\n",
"MODEL_ID = \"f140225004\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def grab_mod(sd, path):\n",
" return {\n",
" k.replace(path, \"\", 1): v\n",
" for k, v in sd.items()\n",
" if k.startswith(path)\n",
" }"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"class ByteLSTM(torch.nn.Module):\n",
" def __init__(\n",
" self,\n",
" embedding_width,\n",
" lstm_width,\n",
" lstm_depth,\n",
" mlp_widths,\n",
" ):\n",
" super(ByteLSTM, self).__init__()\n",
" self.embedding = torch.nn.Embedding(256, embedding_width)\n",
" self.lstm = torch.nn.LSTM(\n",
" input_size=embedding_width,\n",
" hidden_size=lstm_width,\n",
" num_layers=lstm_depth,\n",
" batch_first=True,\n",
" bidirectional=True,\n",
" )\n",
" mlp_layers = []\n",
" for in_width, out_width in zip([2 * lstm_width] + mlp_widths, mlp_widths):\n",
" if mlp_layers:\n",
" mlp_layers.append(torch.nn.ReLU())\n",
" mlp_layers.append(torch.nn.Linear(in_width, out_width))\n",
" self.mlp = torch.nn.Sequential(*mlp_layers)\n",
" \n",
" self.lstm_width = lstm_width\n",
" self.lstm_depth = lstm_depth\n",
" vocab = []\n",
"\n",
" def forward(self, byte_input):\n",
" token_emb = self.embedding(byte_input.long())\n",
" empty = torch.zeros(self.lstm_depth * 2, token_emb.size(0), self.lstm_width)\n",
" rep, ns = self.lstm(token_emb, (empty, empty))\n",
" pooled = torch.sum(rep, 1) / token_emb.shape[1]\n",
" raw_scores = self.mlp(pooled)\n",
" normalized = torch.nn.functional.softmax(raw_scores, dim=1)\n",
" return normalized\n",
"\n",
" @torch.jit.export\n",
" def get_classes(self):\n",
" return self.vocab\n",
"\n",
" def run_on_text(self, text, limit=3, mod=None):\n",
" mod = mod if mod is not None else self\n",
" text_bytes = text.lower().encode(\"utf-8\")\n",
" text_tensor = torch.as_tensor(np.ndarray(shape=(1, len(text_bytes)), dtype=np.byte, buffer=text_bytes))\n",
" scores = mod(text_tensor)\n",
" pairs = zip(self.get_classes(), scores[0].detach().numpy())\n",
" return collections.OrderedDict(sorted(pairs, key=operator.itemgetter(1), reverse=True)[:limit])"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ByteLSTM(\n",
" (embedding): Embedding(256, 64)\n",
" (lstm): LSTM(64, 128, batch_first=True, bidirectional=True)\n",
" (mlp): Sequential(\n",
" (0): Linear(in_features=256, out_features=64, bias=True)\n",
" (1): ReLU()\n",
" (2): Linear(in_features=64, out_features=16, bias=True)\n",
" )\n",
")"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = ByteLSTM(64, 128, 1, [64, 16])\n",
"model.eval()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:root:This caffe2 python run does not have GPU support. Will run in CPU only mode.\n",
"Install apex from https://github.com/NVIDIA/apex/.\n"
]
}
],
"source": [
"train_output = torch.load(f\"{MODEL_ROOT}/model-{MODEL_ID}.pt\", map_location=\"cpu\")"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"model.embedding.load_state_dict(grab_mod(train_output[\"model_state\"], \"embedding.word_embedding.\"))\n",
"model.lstm.load_state_dict(grab_mod(train_output[\"model_state\"], \"representation.lstm.lstm.\"))\n",
"model.mlp[0].load_state_dict(grab_mod(train_output[\"model_state\"], \"decoder.mlp.0.\"))\n",
"model.mlp[2].load_state_dict(grab_mod(train_output[\"model_state\"], \"decoder.mlp.2.\"))\n",
"model.vocab = list(train_output[\"tensorizers\"][\"labels\"].vocab)\n",
"None"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"OrderedDict([('nba', 0.8307514),\n",
" ('nfl', 0.091046356),\n",
" ('fantasyfootball', 0.035300486)])"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.run_on_text(\"lebron\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ScriptModule(\n",
" original_name=ByteLSTM\n",
" (embedding): ScriptModule(original_name=Embedding)\n",
" (lstm): ScriptModule(original_name=LSTM)\n",
" (mlp): _ConstSequential(\n",
" original_name=_ConstSequential\n",
" (0): ScriptModule(original_name=Linear)\n",
" (1): ScriptModule(original_name=ReLU)\n",
" (2): ScriptModule(original_name=Linear)\n",
" )\n",
")"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"smod = torch.jit.script(model)\n",
"smod.eval()"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"OrderedDict([('nba', 0.8307514),\n",
" ('nfl', 0.091046356),\n",
" ('fantasyfootball', 0.035300486)])"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"smod.run_on_text(\"lebron\")"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"smod.save(\"model-reddit16-{MODEL_ID}.pt1\")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"ScriptModule(\n",
" original_name=ByteLSTM\n",
" (embedding): ScriptModule(original_name=Embedding)\n",
" (lstm): ScriptModule(original_name=LSTM)\n",
" (mlp): ScriptModule(\n",
" original_name=_ConstSequential\n",
" (0): ScriptModule(original_name=Linear)\n",
" (1): ScriptModule(original_name=ReLU)\n",
" (2): ScriptModule(original_name=Linear)\n",
" )\n",
")"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loaded = torch.jit.load(\"model-reddit16-{MODEL_ID}.pt1\")\n",
"loaded.eval()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"OrderedDict([('nba', 0.8307514),\n",
" ('nfl', 0.091046356),\n",
" ('fantasyfootball', 0.035300486)])"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.run_on_text(\"lebron\", mod=loaded)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"OrderedDict([('StarWarsBattlefront', 0.73975956),\n",
" ('gaming', 0.22228362),\n",
" ('Overwatch', 0.019848222)])"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.run_on_text(\"Vader is the worst in the game\", mod=loaded)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"OrderedDict([('leagueoflegends', 0.9999074),\n",
" ('SquaredCircle', 6.828745e-05),\n",
" ('nba', 1.3255787e-05)])"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.run_on_text(\"If he hadn't been in the jungle, it would have been an easy win\", mod=loaded)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": [
"['news',\n",
" 'nba',\n",
" 'CFB',\n",
" 'StarWarsBattlefront',\n",
" 'Overwatch',\n",
" 'fantasyfootball',\n",
" 'soccer',\n",
" 'todayilearned',\n",
" 'The_Donald',\n",
" 'leagueoflegends',\n",
" 'hockey',\n",
" 'nfl',\n",
" 'SquaredCircle',\n",
" 'politics',\n",
" 'worldnews',\n",
" 'gaming']"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"loaded.get_classes()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "pytorch3-nightly",
"language": "python",
"name": "pytorch3-nightly"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment