Skip to content

Instantly share code, notes, and snippets.

@enpassanty
Last active February 10, 2022 17:40
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save enpassanty/4bac3d1ed5d8995ac3c48050b0c2aca1 to your computer and use it in GitHub Desktop.
HF pipeline error
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "Rrd6BdmQqJuX",
"outputId": "d9a0b503-a19d-4464-ac78-d175846551d3"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[K |████████████████████████████████| 3.5 MB 8.9 MB/s \n",
"\u001b[K |████████████████████████████████| 6.8 MB 58.3 MB/s \n",
"\u001b[K |████████████████████████████████| 67 kB 3.1 MB/s \n",
"\u001b[K |████████████████████████████████| 895 kB 65.7 MB/s \n",
"\u001b[K |████████████████████████████████| 596 kB 23.7 MB/s \n",
"\u001b[?25hMounted at /content/gdrive\n"
]
}
],
"source": [
"! pip install transformers tokenizers --quiet\n",
"from google.colab import drive\n",
"drive.mount('/content/gdrive')"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "y-YuscSIfcp2"
},
"outputs": [],
"source": [
"vocab_size = 50000\n",
"tokenizer_folder = \"./gdrive/MyDrive/nlp-chart/chart_bpe_tokenizer/\"\n",
"model_folder = './gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/'"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "QS3BGLO3xGCM",
"outputId": "b25e7b85-e77c-40cd-9ec1-58ab5af261ea"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 4min 45s, sys: 7.47 s, total: 4min 52s\n",
"Wall time: 1min 19s\n"
]
}
],
"source": [
"%%time \n",
"from tokenizers import ByteLevelBPETokenizer\n",
"\n",
"# Initialize a tokenizer\n",
"tokenizer = ByteLevelBPETokenizer(lowercase=True)\n",
"# Customize training\n",
"tokenizer.train(files='./gdrive/MyDrive/nlp-chart/train charts.txt',\n",
" vocab_size=vocab_size, \n",
" min_frequency=5,\n",
" show_progress=True,\n",
" special_tokens=[\n",
" \"<s>\",\n",
" \"<pad>\",\n",
" \"</s>\",\n",
" \"<unk>\",\n",
" \"<mask>\",\n",
"])\n",
"#Save the Tokenizer to disk\n",
"tokenizer.save_model(tokenizer_folder)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "mZhVX5clhWQw",
"outputId": "1ded3c90-e868-4a6c-9675-ce71c9c64dbe"
},
"outputs": [
{
"data": {
"text/plain": [
"('./gdrive/MyDrive/nlp-chart/chart_bpe_tokenizer/tokenizer_config.json',\n",
" './gdrive/MyDrive/nlp-chart/chart_bpe_tokenizer/special_tokens_map.json',\n",
" './gdrive/MyDrive/nlp-chart/chart_bpe_tokenizer/vocab.json',\n",
" './gdrive/MyDrive/nlp-chart/chart_bpe_tokenizer/merges.txt',\n",
" './gdrive/MyDrive/nlp-chart/chart_bpe_tokenizer/added_tokens.json',\n",
" './gdrive/MyDrive/nlp-chart/chart_bpe_tokenizer/tokenizer.json')"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from transformers import RobertaTokenizerFast\n",
"\n",
"tokenizer = RobertaTokenizerFast.from_pretrained(tokenizer_folder, return_special_tokens_mask=True, max_length=512) \n",
"\n",
"tokenizer.save_pretrained(tokenizer_folder) "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "7sksvk5ezTD0"
},
"outputs": [],
"source": [
"from transformers import RobertaConfig\n",
"\n",
"config = RobertaConfig(\n",
" vocab_size=vocab_size,\n",
" max_position_embeddings=514,\n",
" num_attention_heads=12,\n",
" num_hidden_layers=6,\n",
" type_vocab_size=1,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "_1TEpJpTzmtl"
},
"outputs": [],
"source": [
"from transformers import RobertaForMaskedLM\n",
"\n",
"model = RobertaForMaskedLM(config=config)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "alM5RIwDzqW8",
"outputId": "4018af4e-4861-4b77-dbd1-3140ce82ed45"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.7/dist-packages/transformers/data/datasets/language_modeling.py:125: FutureWarning: This dataset will be removed from the library soon, preprocessing should be handled with the 🤗 Datasets library. You can have a look at this example script for pointers: https://github.com/huggingface/transformers/blob/master/examples/pytorch/language-modeling/run_mlm.py\n",
" FutureWarning,\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 6min 49s, sys: 15.6 s, total: 7min 4s\n",
"Wall time: 1min 58s\n"
]
}
],
"source": [
"%%time\n",
"from transformers import LineByLineTextDataset\n",
"\n",
"dataset = LineByLineTextDataset(\n",
" tokenizer=tokenizer,\n",
" file_path='./gdrive/MyDrive/nlp-chart/train charts.txt',\n",
" block_size=256,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "VGaNV6gOz5jD"
},
"outputs": [],
"source": [
"from transformers import DataCollatorForLanguageModeling\n",
"\n",
"data_collator = DataCollatorForLanguageModeling(\n",
" tokenizer=tokenizer, mlm=True, mlm_probability=0.15\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "HY1s41Inz8RS",
"outputId": "8474b173-63c7-48db-ae1b-d3d3d25a755b"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using amp half precision backend\n"
]
}
],
"source": [
"# Wall time: 1h 40min 25s - fp16=False\n",
"from transformers import Trainer, TrainingArguments\n",
"\n",
"training_args = TrainingArguments(\n",
" output_dir= model_folder,\n",
" overwrite_output_dir=True,\n",
" num_train_epochs=50,\n",
" per_device_train_batch_size=32,\n",
" eval_steps=1000,\n",
" save_steps=2000,\n",
" save_total_limit=1,\n",
" prediction_loss_only=True,\n",
" fp16=True\n",
")\n",
"\n",
"trainer = Trainer(\n",
" model=model,\n",
" args=training_args,\n",
" data_collator=data_collator,\n",
" train_dataset=dataset,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"background_save": true,
"base_uri": "https://localhost:8080/",
"height": 1000
},
"id": "kEbMCmb6z_R9",
"outputId": "d2d5bfed-0815-45c3-af6b-37d7673e1ac5"
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.7/dist-packages/transformers/optimization.py:309: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use thePyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
" FutureWarning,\n",
"***** Running training *****\n",
" Num examples = 28502\n",
" Num Epochs = 50\n",
" Instantaneous batch size per device = 32\n",
" Total train batch size (w. parallel, distributed & accumulation) = 32\n",
" Gradient Accumulation steps = 1\n",
" Total optimization steps = 44550\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" \n",
" <progress value='43501' max='44550' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" [43501/44550 8:30:03 < 12:18, 1.42 it/s, Epoch 48.82/50]\n",
" </div>\n",
" <table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: left;\">\n",
" <th>Step</th>\n",
" <th>Training Loss</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <td>500</td>\n",
" <td>5.653200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1000</td>\n",
" <td>4.251700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>1500</td>\n",
" <td>3.606700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2000</td>\n",
" <td>3.214100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>2500</td>\n",
" <td>2.938400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3000</td>\n",
" <td>2.725400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>3500</td>\n",
" <td>2.539300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4000</td>\n",
" <td>2.321800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>4500</td>\n",
" <td>2.115200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5000</td>\n",
" <td>1.967400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>5500</td>\n",
" <td>1.842000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6000</td>\n",
" <td>1.740900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>6500</td>\n",
" <td>1.653800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7000</td>\n",
" <td>1.588300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>7500</td>\n",
" <td>1.528300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8000</td>\n",
" <td>1.476700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>8500</td>\n",
" <td>1.425400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9000</td>\n",
" <td>1.399000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>9500</td>\n",
" <td>1.350500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10000</td>\n",
" <td>1.321800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>10500</td>\n",
" <td>1.287000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11000</td>\n",
" <td>1.259600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>11500</td>\n",
" <td>1.226800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12000</td>\n",
" <td>1.192300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>12500</td>\n",
" <td>1.186700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13000</td>\n",
" <td>1.159600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>13500</td>\n",
" <td>1.135700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14000</td>\n",
" <td>1.115100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>14500</td>\n",
" <td>1.103700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15000</td>\n",
" <td>1.085300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>15500</td>\n",
" <td>1.063700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>16000</td>\n",
" <td>1.052800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>16500</td>\n",
" <td>1.032800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>17000</td>\n",
" <td>1.024100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>17500</td>\n",
" <td>1.006600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>18000</td>\n",
" <td>1.001100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>18500</td>\n",
" <td>0.981500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>19000</td>\n",
" <td>0.974000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>19500</td>\n",
" <td>0.970600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>20000</td>\n",
" <td>0.946800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>20500</td>\n",
" <td>0.945900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>21000</td>\n",
" <td>0.930600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>21500</td>\n",
" <td>0.924300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>22000</td>\n",
" <td>0.911800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>22500</td>\n",
" <td>0.901100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>23000</td>\n",
" <td>0.890700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>23500</td>\n",
" <td>0.876900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>24000</td>\n",
" <td>0.872500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>24500</td>\n",
" <td>0.859500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>25000</td>\n",
" <td>0.857200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>25500</td>\n",
" <td>0.847200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>26000</td>\n",
" <td>0.844200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>26500</td>\n",
" <td>0.826900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>27000</td>\n",
" <td>0.826400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>27500</td>\n",
" <td>0.813700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>28000</td>\n",
" <td>0.811400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>28500</td>\n",
" <td>0.808900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>29000</td>\n",
" <td>0.795900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>29500</td>\n",
" <td>0.799500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>30000</td>\n",
" <td>0.792800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>30500</td>\n",
" <td>0.790600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>31000</td>\n",
" <td>0.784600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>31500</td>\n",
" <td>0.776200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>32000</td>\n",
" <td>0.777300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>32500</td>\n",
" <td>0.769700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>33000</td>\n",
" <td>0.770300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>33500</td>\n",
" <td>0.761700</td>\n",
" </tr>\n",
" <tr>\n",
" <td>34000</td>\n",
" <td>0.764800</td>\n",
" </tr>\n",
" <tr>\n",
" <td>34500</td>\n",
" <td>0.761200</td>\n",
" </tr>\n",
" <tr>\n",
" <td>35000</td>\n",
" <td>0.756500</td>\n",
" </tr>\n",
" <tr>\n",
" <td>35500</td>\n",
" <td>0.750000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>36000</td>\n",
" <td>0.745600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>36500</td>\n",
" <td>0.743400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>37000</td>\n",
" <td>0.745600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>37500</td>\n",
" <td>0.746600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>38000</td>\n",
" <td>0.741000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>38500</td>\n",
" <td>0.738000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>39000</td>\n",
" <td>0.733900</td>\n",
" </tr>\n",
" <tr>\n",
" <td>39500</td>\n",
" <td>0.735100</td>\n",
" </tr>\n",
" <tr>\n",
" <td>40000</td>\n",
" <td>0.736300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>40500</td>\n",
" <td>0.735000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>41000</td>\n",
" <td>0.727600</td>\n",
" </tr>\n",
" <tr>\n",
" <td>41500</td>\n",
" <td>0.728000</td>\n",
" </tr>\n",
" <tr>\n",
" <td>42000</td>\n",
" <td>0.725400</td>\n",
" </tr>\n",
" <tr>\n",
" <td>42500</td>\n",
" <td>0.728300</td>\n",
" </tr>\n",
" <tr>\n",
" <td>43000</td>\n",
" <td>0.724000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table><p>"
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Saving model checkpoint to ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-2000\n",
"Configuration saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-2000/config.json\n",
"Model weights saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-2000/pytorch_model.bin\n",
"Deleting older checkpoint [gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-8000] due to args.save_total_limit\n",
"Saving model checkpoint to ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-4000\n",
"Configuration saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-4000/config.json\n",
"Model weights saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-4000/pytorch_model.bin\n",
"Deleting older checkpoint [gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-2000] due to args.save_total_limit\n",
"Saving model checkpoint to ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-6000\n",
"Configuration saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-6000/config.json\n",
"Model weights saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-6000/pytorch_model.bin\n",
"Deleting older checkpoint [gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-4000] due to args.save_total_limit\n",
"Saving model checkpoint to ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-8000\n",
"Configuration saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-8000/config.json\n",
"Model weights saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-8000/pytorch_model.bin\n",
"Deleting older checkpoint [gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-6000] due to args.save_total_limit\n",
"Saving model checkpoint to ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-10000\n",
"Configuration saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-10000/config.json\n",
"Model weights saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-10000/pytorch_model.bin\n",
"Deleting older checkpoint [gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-8000] due to args.save_total_limit\n",
"Saving model checkpoint to ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-12000\n",
"Configuration saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-12000/config.json\n",
"Model weights saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-12000/pytorch_model.bin\n",
"Deleting older checkpoint [gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-10000] due to args.save_total_limit\n",
"Saving model checkpoint to ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-14000\n",
"Configuration saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-14000/config.json\n",
"Model weights saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-14000/pytorch_model.bin\n",
"Deleting older checkpoint [gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-12000] due to args.save_total_limit\n",
"Saving model checkpoint to ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-16000\n",
"Configuration saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-16000/config.json\n",
"Model weights saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-16000/pytorch_model.bin\n",
"Deleting older checkpoint [gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-14000] due to args.save_total_limit\n",
"Saving model checkpoint to ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-18000\n",
"Configuration saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-18000/config.json\n",
"Model weights saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-18000/pytorch_model.bin\n",
"Deleting older checkpoint [gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-16000] due to args.save_total_limit\n",
"Saving model checkpoint to ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-20000\n",
"Configuration saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-20000/config.json\n",
"Model weights saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-20000/pytorch_model.bin\n",
"Deleting older checkpoint [gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-18000] due to args.save_total_limit\n",
"Saving model checkpoint to ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-22000\n",
"Configuration saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-22000/config.json\n",
"Model weights saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-22000/pytorch_model.bin\n",
"Deleting older checkpoint [gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-20000] due to args.save_total_limit\n",
"Saving model checkpoint to ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-24000\n",
"Configuration saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-24000/config.json\n",
"Model weights saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-24000/pytorch_model.bin\n",
"Deleting older checkpoint [gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-22000] due to args.save_total_limit\n",
"Saving model checkpoint to ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-26000\n",
"Configuration saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-26000/config.json\n",
"Model weights saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-26000/pytorch_model.bin\n",
"Deleting older checkpoint [gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-24000] due to args.save_total_limit\n",
"Saving model checkpoint to ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-28000\n",
"Configuration saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-28000/config.json\n",
"Model weights saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-28000/pytorch_model.bin\n",
"Deleting older checkpoint [gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-26000] due to args.save_total_limit\n",
"Saving model checkpoint to ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-30000\n",
"Configuration saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-30000/config.json\n",
"Model weights saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-30000/pytorch_model.bin\n",
"Deleting older checkpoint [gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-28000] due to args.save_total_limit\n",
"Saving model checkpoint to ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-32000\n",
"Configuration saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-32000/config.json\n",
"Model weights saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-32000/pytorch_model.bin\n",
"Deleting older checkpoint [gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-30000] due to args.save_total_limit\n",
"Saving model checkpoint to ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-34000\n",
"Configuration saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-34000/config.json\n",
"Model weights saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-34000/pytorch_model.bin\n",
"Deleting older checkpoint [gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-32000] due to args.save_total_limit\n",
"Saving model checkpoint to ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-36000\n",
"Configuration saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-36000/config.json\n",
"Model weights saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-36000/pytorch_model.bin\n",
"Deleting older checkpoint [gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-34000] due to args.save_total_limit\n",
"Saving model checkpoint to ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-38000\n",
"Configuration saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-38000/config.json\n",
"Model weights saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-38000/pytorch_model.bin\n",
"Deleting older checkpoint [gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-36000] due to args.save_total_limit\n",
"Saving model checkpoint to ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-40000\n",
"Configuration saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-40000/config.json\n",
"Model weights saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-40000/pytorch_model.bin\n",
"Deleting older checkpoint [gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-38000] due to args.save_total_limit\n",
"Saving model checkpoint to ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-42000\n",
"Configuration saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-42000/config.json\n",
"Model weights saved in ./gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-42000/pytorch_model.bin\n",
"Deleting older checkpoint [gdrive/MyDrive/nlp-chart/roberta_mlm_2_6_2022/checkpoint-40000] due to args.save_total_limit\n"
]
}
],
"source": [
"%%time\n",
"trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "T3ixRemT791D"
},
"outputs": [],
"source": [
"from transformers import pipeline\n",
"\n",
"fill_mask = pipeline(\n",
" \"fill-mask\",\n",
" model= model_folder+'checkpoint-8000',\n",
" tokenizer= tokenizer_folder\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "URzWNLWX8YlR"
},
"outputs": [],
"source": [
"predictions = fill_mask(\"congestive <mask> failure\")\n",
"\n",
"for prediction in predictions:\n",
" print(prediction['sequence'].strip('<s>').strip('</s>'), end='\\t--- ')\n",
" print(f\"{round(100*prediction['score'],2)}% confidence\")"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"machine_shape": "hm",
"name": "roberta mlm test 2 -7-22.ipynb",
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment