@cahya-wirawan
Created September 8, 2020 11:14
BERT - GPT2 - CNN
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# BERT - GPT2"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import nlp\n",
"import logging\n",
"import transformers \n",
"from transformers import BertTokenizer, GPT2Tokenizer, EncoderDecoderModel, Trainer, TrainingArguments\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"3.0.2\n"
]
}
],
"source": [
"print(transformers.__version__)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"logging.basicConfig(level=logging.INFO)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:transformers.configuration_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-config.json from cache at /root/.cache/torch/transformers/b945b69218e98b3e2c95acf911789741307dec43c698d35fad11c1ae28bda352.9da767be51e1327499df13488672789394e2ca38b877837e52618a67d7002391\n",
"INFO:transformers.configuration_utils:Model config BertConfig {\n",
" \"architectures\": [\n",
" \"BertForMaskedLM\"\n",
" ],\n",
" \"attention_probs_dropout_prob\": 0.1,\n",
" \"gradient_checkpointing\": false,\n",
" \"hidden_act\": \"gelu\",\n",
" \"hidden_dropout_prob\": 0.1,\n",
" \"hidden_size\": 768,\n",
" \"initializer_range\": 0.02,\n",
" \"intermediate_size\": 3072,\n",
" \"layer_norm_eps\": 1e-12,\n",
" \"max_position_embeddings\": 512,\n",
" \"model_type\": \"bert\",\n",
" \"num_attention_heads\": 12,\n",
" \"num_hidden_layers\": 12,\n",
" \"pad_token_id\": 0,\n",
" \"type_vocab_size\": 2,\n",
" \"vocab_size\": 28996\n",
"}\n",
"\n",
"INFO:transformers.modeling_utils:loading weights file https://cdn.huggingface.co/bert-base-cased-pytorch_model.bin from cache at /root/.cache/torch/transformers/d8f11f061e407be64c4d5d7867ee61d1465263e24085cfa26abf183fdc830569.3fadbea36527ae472139fe84cddaa65454d7429f12d543d80bfc3ad70de55ac2\n",
"INFO:transformers.modeling_utils:All model checkpoint weights were used when initializing BertModel.\n",
"\n",
"INFO:transformers.modeling_utils:All the weights of BertModel were initialized from the model checkpoint at bert-base-cased.\n",
"If your task is similar to the task the model of the ckeckpoint was trained on, you can already use BertModel for predictions without further training.\n",
"INFO:transformers.configuration_utils:loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-config.json from cache at /root/.cache/torch/transformers/4be02c5697d91738003fb1685c9872f284166aa32e061576bbe6aaeb95649fcf.db13c9bc9c7bdd738ec89e069621d88e05dc670366092d809a9cbcac6798e24e\n",
"INFO:transformers.configuration_utils:Model config GPT2Config {\n",
" \"activation_function\": \"gelu_new\",\n",
" \"architectures\": [\n",
" \"GPT2LMHeadModel\"\n",
" ],\n",
" \"attn_pdrop\": 0.1,\n",
" \"bos_token_id\": 50256,\n",
" \"embd_pdrop\": 0.1,\n",
" \"eos_token_id\": 50256,\n",
" \"initializer_range\": 0.02,\n",
" \"layer_norm_epsilon\": 1e-05,\n",
" \"model_type\": \"gpt2\",\n",
" \"n_ctx\": 1024,\n",
" \"n_embd\": 768,\n",
" \"n_head\": 12,\n",
" \"n_layer\": 12,\n",
" \"n_positions\": 1024,\n",
" \"resid_pdrop\": 0.1,\n",
" \"summary_activation\": null,\n",
" \"summary_first_dropout\": 0.1,\n",
" \"summary_proj_to_labels\": true,\n",
" \"summary_type\": \"cls_index\",\n",
" \"summary_use_proj\": true,\n",
" \"task_specific_params\": {\n",
" \"text-generation\": {\n",
" \"do_sample\": true,\n",
" \"max_length\": 50\n",
" }\n",
" },\n",
" \"vocab_size\": 50257\n",
"}\n",
"\n",
"INFO:transformers.modeling_encoder_decoder:Initializing gpt2 as a decoder model. Cross attention layers are added to gpt2 and randomly initialized if gpt2's architecture allows for cross attention layers.\n",
"INFO:transformers.modeling_utils:loading weights file https://cdn.huggingface.co/gpt2-pytorch_model.bin from cache at /root/.cache/torch/transformers/d71fd633e58263bd5e91dd3bde9f658bafd81e11ece622be6a3c2e4d42d8fd89.778cf36f5c4e5d94c8cd9cefcf2a580c8643570eb327f0d4a1f007fab2acbdf1\n",
"INFO:transformers.modeling_utils:All model checkpoint weights were used when initializing GPT2LMHeadModel.\n",
"\n",
"WARNING:transformers.modeling_utils:Some weights of GPT2LMHeadModel were not initialized from the model checkpoint at gpt2 and are newly initialized: ['h.0.attn.masked_bias', 'h.1.attn.masked_bias', 'h.2.attn.masked_bias', 'h.3.attn.masked_bias', 'h.4.attn.masked_bias', 'h.5.attn.masked_bias', 'h.6.attn.masked_bias', 'h.7.attn.masked_bias', 'h.8.attn.masked_bias', 'h.9.attn.masked_bias', 'h.10.attn.masked_bias', 'h.11.attn.masked_bias', 'lm_head.weight']\n",
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
"INFO:transformers.configuration_encoder_decoder:Set `config.is_decoder=True` for decoder_config\n",
"INFO:transformers.tokenization_utils_base:loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-cased-vocab.txt from cache at /root/.cache/torch/transformers/5e8a2b4893d13790ed4150ca1906be5f7a03d6c4ddf62296c383f6db42814db2.e13dbb970cb325137104fb2e5f36fe865f27746c6b526f6352861b1980eb80b1\n"
]
}
],
"source": [
"model = EncoderDecoderModel.from_encoder_decoder_pretrained(\"bert-base-cased\", \"gpt2\")\n",
"# cache is currently not supported by EncoderDecoder framework\n",
"model.decoder.config.use_cache = False\n",
"bert_tokenizer = BertTokenizer.from_pretrained(\"bert-base-cased\")"
]
},
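{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check, added here as a hypothetical sketch (it was not part of the original run): according to the log above, the encoder should be a `BertModel`, the decoder a `GPT2LMHeadModel`, and the decoder config should have been switched to `is_decoder=True`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical sanity check, not executed in the original notebook:\n",
"# confirm which classes ended up on the encoder and decoder side and\n",
"# that the decoder config was flagged as a decoder (see the INFO log above).\n",
"print(type(model.encoder).__name__)\n",
"print(type(model.decoder).__name__)\n",
"print(model.decoder.config.is_decoder)"
]
},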
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"# CLS token will work as BOS token\n",
"bert_tokenizer.bos_token = bert_tokenizer.cls_token\n",
"\n",
"# SEP token will work as EOS token\n",
"bert_tokenizer.eos_token = bert_tokenizer.sep_token\n"
]
},
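{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal check (hypothetical, not in the original run): after the aliasing above, the BERT tokenizer's BOS/EOS ids should simply resolve to the `[CLS]`/`[SEP]` ids."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical check: BOS should now be [CLS] and EOS should be [SEP].\n",
"print(bert_tokenizer.bos_token_id, bert_tokenizer.cls_token_id)\n",
"print(bert_tokenizer.eos_token_id, bert_tokenizer.sep_token_id)"
]
},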
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# make sure GPT2 appends EOS in begin and end\n",
"def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):\n",
" outputs = [self.bos_token_id] + token_ids_0 + [self.eos_token_id]\n",
" return outputs"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:transformers.tokenization_utils_base:loading file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-vocab.json from cache at /root/.cache/torch/transformers/f2808208f9bec2320371a9f5f891c184ae0b674ef866b79c58177067d15732dd.1512018be4ba4e8726e41b9145129dc30651ea4fec86aa61f4b9f40bf94eac71\n",
"INFO:transformers.tokenization_utils_base:loading file https://s3.amazonaws.com/models.huggingface.co/bert/gpt2-merges.txt from cache at /root/.cache/torch/transformers/d629f792e430b3c76a1291bb2766b0a047e36fae0588f9dbc1ae51decdff691b.70bec105b4158ed9a1747fea67a43f5dee97855c64d62b6ec3742f4cfdb5feda\n"
]
}
],
"source": [
"GPT2Tokenizer.build_inputs_with_special_tokens = build_inputs_with_special_tokens\n",
"gpt2_tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\")\n",
"## gpt2_tokenizer = GPT2Tokenizer.from_pretrained(\"/output/gpt2-id-100/small\")\n",
"# set pad_token_id to unk_token_id -> be careful here as unk_token_id == eos_token_id == bos_token_id\n",
"gpt2_tokenizer.pad_token = gpt2_tokenizer.unk_token"
]
},
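{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small illustration of the monkey-patched `build_inputs_with_special_tokens` (hypothetical, not from the original run): every encoded sequence should now start and end with GPT-2's BOS/EOS id, which for the pretrained `gpt2` vocabulary is the same id that was just reused as the pad token."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical illustration of the patched special-token handling:\n",
"# the first and last ids should both equal gpt2_tokenizer.bos_token_id.\n",
"ids = gpt2_tokenizer(\"a short summary\").input_ids\n",
"print(ids[0], ids[-1], gpt2_tokenizer.bos_token_id)"
]
},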
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"# set decoding params\n",
"model.config.decoder_start_token_id = gpt2_tokenizer.bos_token_id\n",
"model.config.eos_token_id = gpt2_tokenizer.eos_token_id\n",
"model.config.max_length = 142\n",
"model.config.min_length = 56\n",
"model.config.no_repeat_ngram_size = 3\n",
"model.early_stopping = True\n",
"model.length_penalty = 2.0\n",
"model.num_beams = 4"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:nlp.load:Checking /root/.cache/huggingface/datasets/720d2e20d8dc6d98f21195a39cc934bb41dd0a40b57ea3d323661a7c5d70522c.4fe1f8a4d3f3c15617ba15dd2d93f559a09627c62d0b04e22f89a5131b7bffb9.py for additional imports.\n",
"INFO:filelock:Lock 140606189019032 acquired on /root/.cache/huggingface/datasets/720d2e20d8dc6d98f21195a39cc934bb41dd0a40b57ea3d323661a7c5d70522c.4fe1f8a4d3f3c15617ba15dd2d93f559a09627c62d0b04e22f89a5131b7bffb9.py.lock\n",
"INFO:nlp.load:Found main folder for dataset https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/cnn_dailymail/cnn_dailymail.py at /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/datasets/cnn_dailymail\n",
"INFO:nlp.load:Found specific version folder for dataset https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/cnn_dailymail/cnn_dailymail.py at /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/datasets/cnn_dailymail/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8\n",
"INFO:nlp.load:Found script file from https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/cnn_dailymail/cnn_dailymail.py to /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/datasets/cnn_dailymail/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8/cnn_dailymail.py\n",
"INFO:nlp.load:Found dataset infos file from https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/cnn_dailymail/dataset_infos.json to /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/datasets/cnn_dailymail/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8/dataset_infos.json\n",
"INFO:nlp.load:Found metadata file for dataset https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/cnn_dailymail/cnn_dailymail.py at /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/datasets/cnn_dailymail/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8/cnn_dailymail.json\n",
"INFO:filelock:Lock 140606189019032 released on /root/.cache/huggingface/datasets/720d2e20d8dc6d98f21195a39cc934bb41dd0a40b57ea3d323661a7c5d70522c.4fe1f8a4d3f3c15617ba15dd2d93f559a09627c62d0b04e22f89a5131b7bffb9.py.lock\n",
"INFO:nlp.info:Loading Dataset Infos from /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/datasets/cnn_dailymail/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8\n",
"INFO:nlp.builder:Overwrite dataset info from restored data version.\n",
"INFO:nlp.info:Loading Dataset info from /root/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8\n",
"INFO:nlp.builder:Reusing dataset cnn_dailymail (/root/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8)\n",
"INFO:nlp.builder:Constructing Dataset for split train, from /root/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8\n",
"INFO:nlp.utils.info_utils:All the checksums matched successfully for post processing resources\n",
"INFO:nlp.load:Checking /root/.cache/huggingface/datasets/720d2e20d8dc6d98f21195a39cc934bb41dd0a40b57ea3d323661a7c5d70522c.4fe1f8a4d3f3c15617ba15dd2d93f559a09627c62d0b04e22f89a5131b7bffb9.py for additional imports.\n",
"INFO:filelock:Lock 140605783911784 acquired on /root/.cache/huggingface/datasets/720d2e20d8dc6d98f21195a39cc934bb41dd0a40b57ea3d323661a7c5d70522c.4fe1f8a4d3f3c15617ba15dd2d93f559a09627c62d0b04e22f89a5131b7bffb9.py.lock\n",
"INFO:nlp.load:Found main folder for dataset https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/cnn_dailymail/cnn_dailymail.py at /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/datasets/cnn_dailymail\n",
"INFO:nlp.load:Found specific version folder for dataset https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/cnn_dailymail/cnn_dailymail.py at /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/datasets/cnn_dailymail/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8\n",
"INFO:nlp.load:Found script file from https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/cnn_dailymail/cnn_dailymail.py to /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/datasets/cnn_dailymail/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8/cnn_dailymail.py\n",
"INFO:nlp.load:Found dataset infos file from https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/cnn_dailymail/dataset_infos.json to /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/datasets/cnn_dailymail/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8/dataset_infos.json\n",
"INFO:nlp.load:Found metadata file for dataset https://s3.amazonaws.com/datasets.huggingface.co/nlp/datasets/cnn_dailymail/cnn_dailymail.py at /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/datasets/cnn_dailymail/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8/cnn_dailymail.json\n",
"INFO:filelock:Lock 140605783911784 released on /root/.cache/huggingface/datasets/720d2e20d8dc6d98f21195a39cc934bb41dd0a40b57ea3d323661a7c5d70522c.4fe1f8a4d3f3c15617ba15dd2d93f559a09627c62d0b04e22f89a5131b7bffb9.py.lock\n",
"INFO:nlp.info:Loading Dataset Infos from /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/datasets/cnn_dailymail/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8\n",
"INFO:nlp.builder:Overwrite dataset info from restored data version.\n",
"INFO:nlp.info:Loading Dataset info from /root/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8\n",
"INFO:nlp.builder:Reusing dataset cnn_dailymail (/root/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8)\n",
"INFO:nlp.builder:Constructing Dataset for split validation[:5%], from /root/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8\n",
"INFO:nlp.utils.info_utils:All the checksums matched successfully for post processing resources\n"
]
}
],
"source": [
"# load train and validation data\n",
"train_dataset = nlp.load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"train\")\n",
"val_dataset = nlp.load_dataset(\"cnn_dailymail\", \"3.0.0\", split=\"validation[:5%]\")\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:nlp.load:Checking /root/.cache/huggingface/datasets/5ecb6e4b474317b41ae1fe5d702d1af8d86d452f0b1d70f77a12f6f014ded6ac.35bc2c477aa456d2f589656477ccb0b463c21cdfb83a9de86d63de8560a96d1b.py for additional imports.\n",
"INFO:filelock:Lock 140606158875056 acquired on /root/.cache/huggingface/datasets/5ecb6e4b474317b41ae1fe5d702d1af8d86d452f0b1d70f77a12f6f014ded6ac.35bc2c477aa456d2f589656477ccb0b463c21cdfb83a9de86d63de8560a96d1b.py.lock\n",
"INFO:nlp.load:Found main folder for metric https://s3.amazonaws.com/datasets.huggingface.co/nlp/metrics/rouge/rouge.py at /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/metrics/rouge\n",
"INFO:nlp.load:Found specific version folder for metric https://s3.amazonaws.com/datasets.huggingface.co/nlp/metrics/rouge/rouge.py at /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/metrics/rouge/06783dbed5f6b6a5413f84d2a5f0d9dc9cb871f1aeb3787f2c90a8e3fe60b1c1\n",
"INFO:nlp.load:Found script file from https://s3.amazonaws.com/datasets.huggingface.co/nlp/metrics/rouge/rouge.py to /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/metrics/rouge/06783dbed5f6b6a5413f84d2a5f0d9dc9cb871f1aeb3787f2c90a8e3fe60b1c1/rouge.py\n",
"INFO:nlp.load:Couldn't find dataset infos file at https://s3.amazonaws.com/datasets.huggingface.co/nlp/metrics/rouge/dataset_infos.json\n",
"INFO:nlp.load:Found metadata file for metric https://s3.amazonaws.com/datasets.huggingface.co/nlp/metrics/rouge/rouge.py at /sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/nlp/metrics/rouge/06783dbed5f6b6a5413f84d2a5f0d9dc9cb871f1aeb3787f2c90a8e3fe60b1c1/rouge.json\n",
"INFO:filelock:Lock 140606158875056 released on /root/.cache/huggingface/datasets/5ecb6e4b474317b41ae1fe5d702d1af8d86d452f0b1d70f77a12f6f014ded6ac.35bc2c477aa456d2f589656477ccb0b463c21cdfb83a9de86d63de8560a96d1b.py.lock\n",
"INFO:filelock:Lock 140605692902928 acquired on /root/.cache/huggingface/metrics/rouge/default/1.0.0/06783dbed5f6b6a5413f84d2a5f0d9dc9cb871f1aeb3787f2c90a8e3fe60b1c1/1-rouge-0.arrow.lock\n"
]
}
],
"source": [
"# load rouge for validation\n",
"rouge = nlp.load_metric(\"rouge\", experiment_id=1)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"encoder_length = 512\n",
"decoder_length = 128\n",
"batch_size = 16\n"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# map data correctly\n",
"def map_to_encoder_decoder_inputs(batch): # Tokenizer will automatically set [BOS] <text> [EOS] \n",
" # use bert tokenizer here for encoder\n",
" inputs = bert_tokenizer(batch[\"article\"], padding=\"max_length\", truncation=True, max_length=encoder_length)\n",
" # force summarization <= 128\n",
" outputs = gpt2_tokenizer(batch[\"highlights\"], padding=\"max_length\", truncation=True, max_length=decoder_length)\n",
"\n",
" batch[\"input_ids\"] = inputs.input_ids\n",
" batch[\"attention_mask\"] = inputs.attention_mask\n",
" batch[\"decoder_input_ids\"] = outputs.input_ids\n",
" batch[\"labels\"] = outputs.input_ids.copy()\n",
" batch[\"decoder_attention_mask\"] = outputs.attention_mask\n",
"\n",
" # complicated list comprehension here because pad_token_id alone is not good enough to know whether label should be excluded or not\n",
" batch[\"labels\"] = [\n",
" [-100 if mask == 0 else token for mask, token in mask_and_tokens] for mask_and_tokens in [zip(masks, labels) for masks, labels in zip(batch[\"decoder_attention_mask\"], batch[\"labels\"])]\n",
" ]\n",
"\n",
" assert all([len(x) == encoder_length for x in inputs.input_ids])\n",
" assert all([len(x) == decoder_length for x in outputs.input_ids])\n",
"\n",
" return batch\n",
"\n",
"\n",
"def compute_metrics(pred):\n",
" labels_ids = pred.label_ids\n",
" pred_ids = pred.predictions\n",
"\n",
" # all unnecessary tokens are removed\n",
" pred_str = gpt2_tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n",
" labels_ids[labels_ids == -100] = gpt2_tokenizer.eos_token_id\n",
" label_str = gpt2_tokenizer.batch_decode(labels_ids, skip_special_tokens=True)\n",
"\n",
" rouge_output = rouge.compute(predictions=pred_str, references=label_str, rouge_types=[\"rouge2\"])[\"rouge2\"].mid\n",
"\n",
" return {\n",
" \"rouge2_precision\": round(rouge_output.precision, 4),\n",
" \"rouge2_recall\": round(rouge_output.recall, 4),\n",
" \"rouge2_fmeasure\": round(rouge_output.fmeasure, 4),\n",
" }\n",
"\n"
]
},
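{
"cell_type": "markdown",
"metadata": {},
"source": [
"A hypothetical spot check of `map_to_encoder_decoder_inputs` (not executed in the original notebook): at this point `train_dataset` still has its raw `article`/`highlights` columns, and slicing an `nlp.Dataset` yields a dict of column lists, which is the batch format the function expects. Encoder inputs should come out padded to `encoder_length`, and padded decoder positions should be masked to -100 so the loss ignores them."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical spot check, not part of the original run.\n",
"toy_batch = dict(train_dataset[:2])\n",
"toy_batch = map_to_encoder_decoder_inputs(toy_batch)\n",
"print(len(toy_batch[\"input_ids\"][0]))   # expected: encoder_length (512)\n",
"print(len(toy_batch[\"labels\"][0]))      # expected: decoder_length (128)\n",
"print(toy_batch[\"labels\"][0][-5:])      # padded tail should be all -100"
]
},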
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:nlp.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8/cache-13d913db1e04ac617cf6323b5be63ae6.arrow\n",
"INFO:nlp.arrow_dataset:Set __getitem__(key) output type to torch for ['input_ids', 'attention_mask', 'decoder_input_ids', 'decoder_attention_mask', 'labels'] columns (when key is int or slice) and don't output other (un-formated) columns.\n",
"INFO:nlp.arrow_dataset:Loading cached processed dataset at /root/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/d8c27f2d603e2864036d92b0ec379f081896f6c28605ffd2e194c42cd04d48d8/cache-c84479f093fba2dbcbc88f32a9900f77.arrow\n",
"INFO:nlp.arrow_dataset:Set __getitem__(key) output type to torch for ['input_ids', 'attention_mask', 'decoder_input_ids', 'decoder_attention_mask', 'labels'] columns (when key is int or slice) and don't output other (un-formated) columns.\n"
]
}
],
"source": [
"# make train dataset ready\n",
"train_dataset = train_dataset.map(\n",
" map_to_encoder_decoder_inputs, batched=True, batch_size=batch_size, remove_columns=[\"article\", \"highlights\"],\n",
")\n",
"train_dataset.set_format(\n",
" type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"decoder_input_ids\", \"decoder_attention_mask\", \"labels\"],\n",
")\n",
"\n",
"# same for validation dataset\n",
"val_dataset = val_dataset.map(\n",
" map_to_encoder_decoder_inputs, batched=True, batch_size=batch_size, remove_columns=[\"article\", \"highlights\"],\n",
")\n",
"val_dataset.set_format(\n",
" type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"decoder_input_ids\", \"decoder_attention_mask\", \"labels\"],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"# set training arguments - these params are not really tuned, feel free to change\n",
"training_args = TrainingArguments(\n",
" output_dir=\"./\",\n",
" per_device_train_batch_size=batch_size,\n",
" per_device_eval_batch_size=batch_size,\n",
" #predict_from_generate=True,\n",
" #evaluate_during_training=True,\n",
" do_train=True,\n",
" do_eval=True,\n",
" logging_steps=1000,\n",
" save_steps=1000,\n",
" eval_steps=1000,\n",
" overwrite_output_dir=True,\n",
" warmup_steps=2000,\n",
" save_total_limit=10,\n",
" fp16=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:transformers.training_args:PyTorch: setting up devices\n",
"INFO:transformers.trainer:Automatic Weights & Biases logging enabled, to disable set os.environ[\"WANDB_DISABLED\"] = \"true\"\n"
]
},
{
"data": {
"text/html": [
"\n",
" Logging results to <a href=\"https://wandb.com\" target=\"_blank\">Weights & Biases</a> <a href=\"https://docs.wandb.com/integrations/jupyter.html\" target=\"_blank\">(Documentation)</a>.<br/>\n",
" Project page: <a href=\"https://app.wandb.ai/cahya/huggingface\" target=\"_blank\">https://app.wandb.ai/cahya/huggingface</a><br/>\n",
" Run page: <a href=\"https://app.wandb.ai/cahya/huggingface/runs/2q9d4if2\" target=\"_blank\">https://app.wandb.ai/cahya/huggingface/runs/2q9d4if2</a><br/>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:wandb.run_manager:system metrics and metadata threads started\n",
"INFO:wandb.run_manager:checking resume status, waiting at most 10 seconds\n",
"INFO:wandb.run_manager:resuming run from id: UnVuOnYxOjJxOWQ0aWYyOmh1Z2dpbmdmYWNlOmNhaHlh\n",
"INFO:wandb.run_manager:upserting run before process can begin, waiting at most 10 seconds\n",
"INFO:wandb.run_manager:saving pip packages\n",
"INFO:wandb.run_manager:initializing streaming files api\n",
"INFO:wandb.run_manager:unblocking file change observer, beginning sync with W&B servers\n",
"INFO:wandb.run_manager:shutting down system stats and metadata service\n",
"INFO:wandb.run_manager:file/dir modified: /root/Work/language-modeling-private/Transformers/Tasks/wandb/run-20200908_110903-2q9d4if2/config.yaml\n",
"INFO:wandb.run_manager:file/dir created: /root/Work/language-modeling-private/Transformers/Tasks/wandb/run-20200908_110903-2q9d4if2/requirements.txt\n",
"INFO:wandb.run_manager:file/dir created: /root/Work/language-modeling-private/Transformers/Tasks/wandb/run-20200908_110903-2q9d4if2/wandb-summary.json\n",
"INFO:wandb.run_manager:file/dir created: /root/Work/language-modeling-private/Transformers/Tasks/wandb/run-20200908_110903-2q9d4if2/wandb-history.jsonl\n",
"INFO:wandb.run_manager:file/dir created: /root/Work/language-modeling-private/Transformers/Tasks/wandb/run-20200908_110903-2q9d4if2/wandb-metadata.json\n",
"INFO:wandb.run_manager:file/dir created: /root/Work/language-modeling-private/Transformers/Tasks/wandb/run-20200908_110903-2q9d4if2/wandb-events.jsonl\n",
"INFO:wandb.run_manager:stopping streaming files and file change observer\n",
"INFO:wandb.run_manager:file/dir modified: /root/Work/language-modeling-private/Transformers/Tasks/wandb/run-20200908_110903-2q9d4if2/wandb-metadata.json\n"
]
}
],
"source": [
"# instantiate trainer\n",
"trainer = Trainer(\n",
" model=model,\n",
" args=training_args,\n",
" compute_metrics=compute_metrics,\n",
" train_dataset=train_dataset,\n",
" eval_dataset=val_dataset,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"INFO:transformers.trainer:***** Running training *****\n",
"INFO:transformers.trainer: Num examples = 287113\n",
"INFO:transformers.trainer: Num Epochs = 3\n",
"INFO:transformers.trainer: Instantaneous batch size per device = 16\n",
"INFO:transformers.trainer: Total train batch size (w. parallel, distributed & accumulation) = 128\n",
"INFO:transformers.trainer: Gradient Accumulation steps = 1\n",
"INFO:transformers.trainer: Total optimization steps = 6732\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "b5b54269b5604a078896beb13cb96452",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='Epoch', max=3.0, style=ProgressStyle(description_width='i…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "757b3693158a4c278138b542e2ffd69a",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(FloatProgress(value=0.0, description='Iteration', max=2244.0, style=ProgressStyle(description_w…"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/opt/conda/conda-bld/pytorch_1591914895884/work/torch/csrc/utils/tensor_numpy.cpp:141: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"\n"
]
},
{
"ename": "TypeError",
"evalue": "Caught TypeError in replica 0 on device 0.\nOriginal Traceback (most recent call last):\n File \"/sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py\", line 60, in _worker\n output = module(*input, **kwargs)\n File \"/sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 550, in __call__\n result = self.forward(*input, **kwargs)\n File \"/root/Work/transformers/src/transformers/modeling_encoder_decoder.py\", line 290, in forward\n **kwargs_decoder,\n File \"/sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 550, in __call__\n result = self.forward(*input, **kwargs)\nTypeError: forward() got an unexpected keyword argument 'encoder_hidden_states'\n",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-16-c108335b43e0>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# start training\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mtrainer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m~/Work/transformers/src/transformers/trainer.py\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(self, model_path)\u001b[0m\n\u001b[1;32m 497\u001b[0m \u001b[0;32mcontinue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 498\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 499\u001b[0;31m \u001b[0mtr_loss\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_training_step\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 500\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 501\u001b[0m if (step + 1) % self.args.gradient_accumulation_steps == 0 or (\n",
"\u001b[0;32m~/Work/transformers/src/transformers/trainer.py\u001b[0m in \u001b[0;36m_training_step\u001b[0;34m(self, model, inputs, optimizer)\u001b[0m\n\u001b[1;32m 620\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"mems\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_past\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 621\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 622\u001b[0;31m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 623\u001b[0m \u001b[0mloss\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;31m# model outputs are always tuple in transformers (see doc)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 624\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 548\u001b[0m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 549\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 550\u001b[0;31m \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 551\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mhook\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 552\u001b[0m \u001b[0mhook_result\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mhook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mresult\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, *inputs, **kwargs)\u001b[0m\n\u001b[1;32m 153\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 154\u001b[0m \u001b[0mreplicas\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreplicate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice_ids\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 155\u001b[0;31m \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparallel_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreplicas\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 156\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgather\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moutput_device\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 157\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/torch/nn/parallel/data_parallel.py\u001b[0m in \u001b[0;36mparallel_apply\u001b[0;34m(self, replicas, inputs, kwargs)\u001b[0m\n\u001b[1;32m 163\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 164\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mparallel_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mreplicas\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 165\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mparallel_apply\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreplicas\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdevice_ids\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mreplicas\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 166\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 167\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0mgather\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moutput_device\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py\u001b[0m in \u001b[0;36mparallel_apply\u001b[0;34m(modules, inputs, kwargs_tup, devices)\u001b[0m\n\u001b[1;32m 83\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mresults\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 84\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mExceptionWrapper\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 85\u001b[0;31m \u001b[0moutput\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreraise\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 86\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 87\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0moutputs\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;32m/sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/torch/_utils.py\u001b[0m in \u001b[0;36mreraise\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 393\u001b[0m \u001b[0;31m# (https://bugs.python.org/issue2651), so we work around it.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 394\u001b[0m \u001b[0mmsg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mKeyErrorMessage\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 395\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexc_type\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m: Caught TypeError in replica 0 on device 0.\nOriginal Traceback (most recent call last):\n File \"/sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/torch/nn/parallel/parallel_apply.py\", line 60, in _worker\n output = module(*input, **kwargs)\n File \"/sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 550, in __call__\n result = self.forward(*input, **kwargs)\n File \"/root/Work/transformers/src/transformers/modeling_encoder_decoder.py\", line 290, in forward\n **kwargs_decoder,\n File \"/sysadmin/wirawan/miniconda3/envs/transformers-cuda9/lib/python3.7/site-packages/torch/nn/modules/module.py\", line 550, in __call__\n result = self.forward(*input, **kwargs)\nTypeError: forward() got an unexpected keyword argument 'encoder_hidden_states'\n"
]
}
],
"source": [
"# start training\n",
"trainer.train()"
]
},
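{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `TypeError` above is raised inside the decoder call: `EncoderDecoderModel` forwards `encoder_hidden_states` to GPT-2 for cross-attention, but the GPT-2 `forward` in the installed transformers build does not accept that keyword argument, which suggests it predates cross-attention support for GPT-2. Upgrading to a transformers version whose GPT-2 accepts `encoder_hidden_states` should remove this particular error. The diagnostic below is a hypothetical addition (not from the original run) that simply checks whether the installed decoder exposes the argument."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import inspect\n",
"# Hypothetical diagnostic, not part of the original run: does the installed\n",
"# GPT-2 decoder accept cross-attention inputs at all?\n",
"print(\"encoder_hidden_states\" in inspect.signature(model.decoder.forward).parameters)"
]
},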
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}