Per token logits
@sekstini · Last active December 29, 2023 21:37
In [1]:
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
```
In [2]:
```python
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
```
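Everything below keys off SentencePiece token pieces, where the `▁` prefix marks a leading space. As a quick check of what this tokenizer produces for the prompt used later (a sketch, not in the original gist; the expected pieces are read off the per-token printout further down):

```python
# Sketch: inspect the token pieces for the prompt.
# Mistral's tokenizer prepends the BOS token <s> automatically.
ids = tokenizer("The quick brown fox jumps over the lazy").input_ids
print(tokenizer.convert_ids_to_tokens(ids))
# ['<s>', '▁The', '▁quick', '▁brown', '▁f', 'ox', '▁j', 'umps', '▁over', '▁the', '▁lazy']
```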
In [3]:
```python
model = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", device_map=0, torch_dtype="auto")
```
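Not in the original notebook, but a quick sanity check of what the loading arguments resolved to: `torch_dtype="auto"` takes the dtype recorded in the checkpoint's config, and `device_map=0` places all weights on GPU 0.

```python
# Assumed follow-up cell: confirm weight placement and the resolved dtype.
print(model.device, model.dtype)  # e.g. cuda:0 torch.bfloat16
```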
In [4]:
```python
with torch.inference_mode():
    inputs = tokenizer("The quick brown fox jumps over the lazy", return_tensors="pt")
    # Single full forward pass over the prompt; no generation, so the KV cache is not needed.
    outputs = model(**inputs.to(model.device), use_cache=False)
```
In [5]:
```python
inputs.input_ids.shape, outputs.logits.shape
```

Out[5]:
```
(torch.Size([1, 11]), torch.Size([1, 11, 32000]))
```
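The shapes line up one-to-one, but the alignment is shifted: the logits at position `i` are the model's distribution over the token at position `i + 1`. A minimal check (an added sketch, not in the gist): the argmax at the final position is the model's guess for the continuation of "... the lazy", which per the probabilities printed below should be `▁dog`.

```python
# Sketch: the last row of logits scores whatever token would come next.
next_id = outputs.logits[0, -1].argmax().item()
print(tokenizer.convert_ids_to_tokens(next_id))  # expected: ▁dog (~99% below)
```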
In [6]:
```python
# outputs.logits[0][i] is the distribution over the token at position i + 1,
# so each input token below is printed with the model's top guesses for what follows it.
for input_id, input_id_logits in zip(inputs.input_ids[0], outputs.logits[0]):
    probabilities = torch.softmax(input_id_logits, dim=-1)
    top_probabilities, top_indices = torch.topk(probabilities, k=5)

    print(tokenizer.convert_ids_to_tokens(input_id.item()))
    for p, index in zip(top_probabilities, top_indices):
        print(f"{p.item():10.2%}: {tokenizer.convert_ids_to_tokens(index.item())}")
    print()
```

```
<s>
    45.57%: ▁Question
    37.78%: ▁Q
     8.43%: ▁#
     5.44%: ▁User

▁The
     1.74%: ▁first
     1.63%: ▁
     1.58%: ▁following
     0.99%: ▁new

▁quick
    40.03%: est
    35.33%: ▁answer
     8.93%: ▁and
     1.99%: -

▁brown
    98.40%: ▁f
     0.43%: ▁j
     0.33%: ▁dog
     0.11%: ▁Fox

▁f
    99.99%: ox
     0.00%: ▁Fox
     0.00%: fox
     0.00%: oxy

ox
    91.65%: ▁j
     5.86%: ▁jumped
     0.45%: <0x0A>
     0.45%: ▁is

▁j
    99.98%: umps
     0.01%: umped
     0.00%: umper
     0.00%: umbled

umps
    99.40%: ▁over
     0.12%: <0x0A>
     0.07%: ▁
     0.03%: ▁on

▁over
    93.97%: ▁the
     5.64%: ▁a
     0.06%: ▁lazy
     0.05%: <0x0A>

▁the
    99.05%: ▁lazy
     0.22%: ▁L
     0.20%: ▁la
     0.05%: ▁l

▁lazy
    99.11%: ▁dog
     0.19%: ▁red
     0.14%: ,
     0.11%: ▁dogs
```
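A natural follow-up (an added sketch, not part of the gist) is to score the tokens the prompt actually contains rather than the top-k alternatives. Under the same shift as above, logits at positions `0..n-2` line up with targets `input_ids[1:]`; `gather` pulls out each realized token's log-probability, and exponentiating the mean negative log-likelihood gives the prompt's perplexity.

```python
# Sketch: per-token log-likelihood of the actual prompt, plus perplexity.
with torch.inference_mode():
    log_probs = torch.log_softmax(outputs.logits[0, :-1], dim=-1)  # positions 0..n-2
    targets = inputs.input_ids[0, 1:]                              # tokens 1..n-1
    token_log_probs = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

for piece, lp in zip(tokenizer.convert_ids_to_tokens(targets.tolist()), token_log_probs):
    print(f"{lp.item():8.3f}  {piece}")

print("perplexity:", torch.exp(-token_log_probs.mean()).item())
```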
Kernel: Python 3 (ipykernel) · Python 3.11.6