

@nbertagnolli
Last active March 25, 2020 19:58
Use gpt_2_simple to fine-tune GPT-2 on a new dataset
import os

import gpt_2_simple as gpt2
# The name of the pretrained GPT-2 model to use, e.g. 124M or 355M
# ("117M" is the older name of the 124M model).
# 124M is about as big as I can fit on my 1080 Ti.
model_name = "124M"
# Download the model if it is not present
if not os.path.isdir(os.path.join("models", model_name)):
    print(f"Downloading {model_name} model...")
    gpt2.download_gpt2(model_name=model_name)
# Start a TensorFlow session to pass to gpt_2_simple
sess = gpt2.start_tf_sess()
# Number of training steps. Choose it so that the model passes over the
# dataset only once or twice, to avoid overfitting.
num_steps = 100
# This is the path to the text file we want to use for training.
text_path = "proverbs.txt"
# Pass in the session and the path to the training text, then fine-tune.
gpt2.finetune(sess,
              text_path,
              model_name=model_name,
              steps=num_steps)
# Generate a sample from the fine-tuned model
gpt2.generate(sess)
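The script decides whether to download by checking only that `models/<model_name>` is a directory, so an interrupted download would be treated as complete. A stricter check could look for each checkpoint file. The sketch below assumes the file list shown is what `gpt2.download_gpt2` fetches (to the best of my knowledge); `missing_model_files` is a hypothetical helper, not part of gpt_2_simple.

```python
import os

# Assumed contents of a complete GPT-2 checkpoint directory
# (believed to be the files gpt2.download_gpt2 fetches; verify for your version).
EXPECTED_FILES = [
    "checkpoint",
    "encoder.json",
    "hparams.json",
    "model.ckpt.data-00000-of-00001",
    "model.ckpt.index",
    "model.ckpt.meta",
    "vocab.bpe",
]


def missing_model_files(model_name: str, model_dir: str = "models") -> list:
    """Hypothetical helper: list expected files absent from model_dir/model_name."""
    path = os.path.join(model_dir, model_name)
    return [f for f in EXPECTED_FILES
            if not os.path.isfile(os.path.join(path, f))]
```

Used in place of the directory check, this would read `if missing_model_files(model_name): gpt2.download_gpt2(model_name=model_name)`, which also re-downloads after a partial fetch.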
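The comment on `num_steps` ties the step count to one or two passes over the data, but `gpt2.finetune` counts steps, not epochs. A rough conversion is sketched below under two loudly stated assumptions: each step reads one 1024-token context window (GPT-2's context length), and English text averages about 4 characters per BPE token. Both `steps_for_epochs` and `estimate_tokens` are hypothetical helpers; scale `tokens_per_step` up if your batch size or gradient accumulation means each step reads more windows.

```python
import math


def steps_for_epochs(num_tokens: int, epochs: float = 1.5,
                     tokens_per_step: int = 1024) -> int:
    """Rough number of steps needed to see the data about `epochs` times.

    ASSUMPTION: one 1024-token window per step; multiply tokens_per_step by
    the number of windows your training loop actually consumes per step.
    """
    return max(1, math.ceil(epochs * num_tokens / tokens_per_step))


def estimate_tokens(text: str, chars_per_token: float = 4.0) -> int:
    """Crude token count using the ~4 characters per token rule of thumb."""
    return math.ceil(len(text) / chars_per_token)
```

For example, a 400,000-character `proverbs.txt` is roughly 100,000 tokens by this heuristic, and `steps_for_epochs(100_000)` gives 147 steps for about 1.5 passes under the one-window-per-step assumption.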
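`gpt2.generate(sess)` samples text from the fine-tuned checkpoint, and the character of the output depends on sampling settings such as its `temperature` and `top_k` arguments. The plain-Python sketch below illustrates what those two knobs do to a single next-token choice over a list of logits. `sample_next_token` is a hypothetical illustration of temperature and top-k sampling, not gpt_2_simple's TensorFlow implementation.

```python
import math
import random


def sample_next_token(logits, temperature=1.0, top_k=0, rng=None):
    """Illustrative only: pick an index from `logits` with temperature/top-k.

    Lower temperature sharpens the distribution toward the highest logit;
    top_k > 0 keeps only the k highest-scoring candidates before sampling.
    """
    rng = rng if rng is not None else random.Random()
    indices = list(range(len(logits)))
    if top_k > 0:
        indices = sorted(indices, key=lambda i: logits[i], reverse=True)[:top_k]
    scaled = [logits[i] / temperature for i in indices]
    peak = max(scaled)  # subtract the max before exp() for numerical stability
    weights = [math.exp(s - peak) for s in scaled]
    return rng.choices(indices, weights=weights, k=1)[0]
```

With `top_k=1` this always returns the highest-scoring index, and a very small temperature behaves almost the same way; larger temperatures spread probability onto less likely tokens, giving more varied but less coherent text.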