@ThilinaRajapakse
Last active September 19, 2021 06:46
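from simpletransformers.language_modeling import LanguageModelingModel
import logging

# Show INFO-level progress messages, but limit the transformers library to warnings and errors.
logging.basicConfig(level=logging.INFO)
transformers_logger = logging.getLogger("transformers")
transformers_logger.setLevel(logging.WARNING)

train_args = {
    "reprocess_input_data": True,  # rebuild features from the text files even if a cached copy exists
    "overwrite_output_dir": True,  # allow saving over a model already in the output directory
    "train_batch_size": 64,  # lower this if training runs out of GPU memory
    "num_train_epochs": 3,
    "mlm": False,  # GPT-2 is a causal language model, so masked-LM training is disabled
}

# Arguments are (model_type, model_name): fine-tune the pretrained "gpt2" checkpoint.
model = LanguageModelingModel('gpt2', 'gpt2', args=train_args)

# eval_file is only consulted during training when "evaluate_during_training" is enabled.
model.train_model("data/train.txt", eval_file="data/test.txt")

# Evaluate the fine-tuned model on the held-out text file.
model.eval_model("data/test.txt")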
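The script assumes `data/train.txt` and `data/test.txt` already exist. If you start from a single raw text file, a minimal sketch like the one below can produce both. The helper name `split_corpus`, the 10% test fraction, and the one-passage-per-line layout are assumptions for illustration, not part of this gist or of simpletransformers.

```python
import random
from pathlib import Path


def split_corpus(source_path, out_dir="data", test_fraction=0.1, seed=42):
    """Hypothetical helper (not part of simpletransformers): shuffle the
    non-empty lines of source_path and write them to <out_dir>/train.txt
    and <out_dir>/test.txt. Returns (n_train_lines, n_test_lines)."""
    lines = [
        line
        for line in Path(source_path).read_text(encoding="utf-8").splitlines()
        if line.strip()  # drop blank lines
    ]
    if len(lines) < 2:
        raise ValueError("need at least two non-empty lines to make a split")

    # A seeded local RNG gives a reproducible split without touching global state.
    random.Random(seed).shuffle(lines)

    n_test = max(1, int(len(lines) * test_fraction))  # hold out at least one line
    test_lines, train_lines = lines[:n_test], lines[n_test:]

    out = Path(out_dir)
    out.mkdir(parents=True, exist_ok=True)
    (out / "train.txt").write_text("\n".join(train_lines) + "\n", encoding="utf-8")
    (out / "test.txt").write_text("\n".join(test_lines) + "\n", encoding="utf-8")
    return len(train_lines), len(test_lines)
```

For example, `split_corpus("corpus.txt")` with a hypothetical `corpus.txt` writes both files under `data/` and returns the line counts, so the paths in the training script resolve without changes.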