Skip to content

Instantly share code, notes, and snippets.

@cahya-wirawan
Created August 17, 2020 18:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save cahya-wirawan/0e3eedbcd78c28602dbc554c447aed2a to your computer and use it in GitHub Desktop.
Save cahya-wirawan/0e3eedbcd78c28602dbc554c447aed2a to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoModelForSeq2SeqLM, AutoTokenizer"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"#model_name = \"facebook/mbart-large-en-ro\"\n",
"model_name = \"cahya/mbart-large-en-de\""
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Use the GPU when one is visible, otherwise fall back to the CPU so that the\n",
"# later `.to(device)` calls do not fail on machines without CUDA.\n",
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"# Key that is looked up in `model.config.task_specific_params`.\n",
"task = \"translation\""
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"## Helper functions from transformers/examples/seq2seq/utils.py\n",
"import logging\n",
"\n",
"# The helper below logs through `logger`; define it here so the call cannot\n",
"# raise NameError inside the notebook.\n",
"logger = logging.getLogger(__name__)\n",
"\n",
"\n",
"def use_task_specific_params(model, task):\n",
"    \"\"\"Update the model config with the parameters registered for `task`.\n",
"\n",
"    Reads `model.config.task_specific_params`; when it is not None, the entry\n",
"    stored under `task` (an empty dict if the task is missing) is applied with\n",
"    `model.config.update`. Nothing is changed when the mapping is None.\n",
"    \"\"\"\n",
"    task_specific_params = model.config.task_specific_params\n",
"\n",
"    if task_specific_params is not None:\n",
"        pars = task_specific_params.get(task, {})\n",
"        logger.info(f\"using task specific params for {task}: {pars}\")\n",
"        model.config.update(pars)\n",
"\n",
"\n",
"def trim_batch(\n",
"    input_ids, pad_token_id, attention_mask=None,\n",
"):\n",
"    \"\"\"Remove columns that are populated exclusively by pad_token_id.\n",
"\n",
"    Returns the trimmed `input_ids` alone when `attention_mask` is None,\n",
"    otherwise a `(input_ids, attention_mask)` tuple with the same columns kept.\n",
"    \"\"\"\n",
"    # True for every column that contains at least one non-pad token.\n",
"    keep_column_mask = input_ids.ne(pad_token_id).any(dim=0)\n",
"    if attention_mask is None:\n",
"        return input_ids[:, keep_column_mask]\n",
"    else:\n",
"        return (input_ids[:, keep_column_mask], attention_mask[:, keep_column_mask])\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"model = AutoModelForSeq2SeqLM.from_pretrained(model_name).to(device)\n",
"tokenizer = AutoTokenizer.from_pretrained(model_name)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# mBART target sequences begin with the target language code, so generation\n",
"# has to start from that code rather than from the tokenizer's BOS token.\n",
"# NOTE(review): assumes this en-de checkpoint tags German as \"de_DE\" -- confirm.\n",
"target_lang = \"de_DE\"\n",
"decoder_start_token_id = tokenizer.lang_code_to_id[target_lang]"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"use_task_specific_params(model, task)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"texts = [\"I am very hungry\", \"The weather is very hot today\"]"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.\n"
]
}
],
"source": [
"#batch = tokenizer(texts, return_tensors=\"pt\", truncation=True, padding=\"max_length\").to(device)\n",
"batch = tokenizer(texts, return_tensors=\"pt\", truncation=True, padding='longest').to(device)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'input_ids': tensor([[ 87, 444, 4552, 1926, 47285, 2, 250004, 1],\n",
" [ 581, 92949, 83, 4552, 8010, 18925, 2, 250004]],\n",
" device='cuda:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 0],\n",
" [1, 1, 1, 1, 1, 1, 1, 1]], device='cuda:0')}"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"batch"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"input_ids, attention_mask = trim_batch(**batch, pad_token_id=tokenizer.pad_token_id)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"summaries = model.generate(\n",
" input_ids=input_ids,\n",
" attention_mask=attention_mask,\n",
" decoder_start_token_id=decoder_start_token_id\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"tensor([[0, 0, 6, 5],\n",
" [0, 0, 6, 5]], device='cuda:0')"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"summaries"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"dec = tokenizer.batch_decode(summaries, skip_special_tokens=True, clean_up_tokenization_spaces=False)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"['.', '.']"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dec"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment