{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 🔸 reference (BERT)\n",
"\n",
"https://huggingface.co/bert-base-uncased\n",
"\n",
"https://huggingface.co/bert-base-multilingual-cased <-- use this one for fairness; it is similar in size to our model"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/pi/code/m2/fff/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n",
"config.json: 100%|██████████| 625/625 [00:00<00:00, 1.84MB/s]\n",
"model.safetensors: 100%|██████████| 714M/714M [01:28<00:00, 8.07MB/s] \n",
"Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias']\n",
"- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
"tokenizer_config.json: 100%|██████████| 29.0/29.0 [00:00<00:00, 115kB/s]\n",
"vocab.txt: 100%|██████████| 996k/996k [00:00<00:00, 2.76MB/s]\n",
"tokenizer.json: 100%|██████████| 1.96M/1.96M [00:00<00:00, 10.9MB/s]\n"
]
}
],
"source": [
"from transformers import pipeline\n",
"# unmasker = pipeline('fill-mask', model='bert-base-uncased') # 110M params\n",
"unmasker = pipeline('fill-mask', model='bert-base-multilingual-cased') # 179M params, 714MB\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"62.6 ms ± 1.08 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"%timeit unmasker(\"The capital of France is [MASK]!\")"
]
},
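{
"cell_type": "markdown",
"metadata": {},
"source": [
"Quick quality check (added sketch, not part of the original run): the fill-mask pipeline returns a list of dicts with `score` and `token_str` keys (top 5 by default), so we can print what the reference model actually predicts for the masked token."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added sketch: inspect the reference pipeline's top-5 fill-mask predictions\n",
"for pred in unmasker('The capital of France is [MASK]!'):\n",
"    print(round(pred['score'], 3), pred['token_str'])"
]
},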
{
"cell_type": "markdown",
"metadata": {},
"source": [
"https://huggingface.co/pbelcak/UltraFastBERT-1x11-long"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 🔸 Our model\n",
"\n",
"https://huggingface.co/pbelcak/UltraFastBERT-1x11-long\n",
"\n",
"(189M params)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import cramming"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer, AutoModelForMaskedLM\n",
"import torch\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"pbelcak/UltraFastBERT-1x11-long\")\n",
"model = AutoModelForMaskedLM.from_pretrained(\"pbelcak/UltraFastBERT-1x11-long\")\n"
]
},
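{
"cell_type": "markdown",
"metadata": {},
"source": [
"Sanity check (added sketch, not part of the original run): count the model's parameters to confirm the ~189M figure quoted above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added sketch: verify the quoted parameter count\n",
"n_params = sum(p.numel() for p in model.parameters())\n",
"print(f'{n_params / 1e6:.0f}M parameters')"
]
},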
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"text = \"The capital of France is <mask>!\"\n",
"# text = \"The cat sat on <mask> mat.\"\n",
"encoded_input = tokenizer(text, return_tensors='pt')"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"77.6 ms ± 2.08 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)\n"
]
}
],
"source": [
"%timeit output = model(**encoded_input)"
]
},
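{
"cell_type": "markdown",
"metadata": {},
"source": [
"Note that the BERT timing above went through the pipeline (tokenization + forward pass + decoding), while this one times only the forward pass on a pre-tokenized input. The cell below is an added sketch (not part of the original run) that includes tokenization, for a rougher like-for-like comparison."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added sketch: time tokenization + forward pass together\n",
"%timeit model(**tokenizer(text, return_tensors='pt'))"
]
},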
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Text: The capital of France is <mask>!\n",
"Most likely token prediction: ['paris france french here germany']\n"
]
}
],
"source": [
"output = model(**encoded_input)\n",
"\n",
"mask_token_index = (encoded_input.input_ids[0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]\n",
"logits = output['outputs']\n",
"top5_indices = torch.topk(logits[mask_token_index], k=5).indices\n",
"token_winners = [tokenizer.decode(u) for u in top5_indices]\n",
"print('Text:', text)\n",
"print(f'Most likely token prediction: {token_winners}')"
]
},
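{
"cell_type": "markdown",
"metadata": {},
"source": [
"Added sketch (not part of the original run): decode the top-5 candidates one token at a time, each with its softmax probability, instead of joining them into a single string. This assumes the same logits layout that the indexing in the previous cell relies on."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Added sketch: per-token top-5 predictions with probabilities\n",
"probs = torch.softmax(logits[mask_token_index], dim=-1)\n",
"top5 = torch.topk(probs, k=5)\n",
"for idx, p in zip(top5.indices[0], top5.values[0]):\n",
"    print(f'{tokenizer.decode([int(idx)]):>10}  {float(p):.3f}')"
]
}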
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}