Last active
February 25, 2021 18:54
-
-
Save jessvb/a66ce4689c4d92a107384947a71e6ecd to your computer and use it in GitHub Desktop.
Create a story alongside the generative text model GPT-2: you write one line, GPT-2 writes the next, and so on until you have a full story.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
## This is a demo of text generation with GPT-2 (Generative Pre-trained Transformer).
## Copy this file to a working directory to which you have write access and
## run it with python3 using the shell command
##     python3 gpt2_no_finetuning.py
## WARNING: This program can generate uncannily realistic output. But the output
## can also be biased, controversial, and obscene. If you distribute the program or
## the results to others, include appropriate disclaimers.
import os
import subprocess
import sys


def _pip_install(requirement):
    """Install *requirement* with pip into the interpreter running this script.

    Running pip as ``sys.executable -m pip`` guarantees the package is installed
    into the same environment this script imports from; a bare ``pip`` on PATH
    may belong to a different Python. A failed install only prints a warning
    (the original script ignored failures entirely), so a machine that already
    has the package installed can still proceed.
    """
    result = subprocess.run(
        [sys.executable, "-m", "pip", "install", requirement], check=False
    )
    if result.returncode != 0:
        print(
            f"WARNING: 'pip install {requirement}' failed "
            f"(exit code {result.returncode}); continuing anyway."
        )


_pip_install("tensorflow==1.15")
_pip_install("gpt-2-simple")

# Working directories used below: downloaded models go under models/ and the
# model is copied into checkpoint/run1 before loading. os.makedirs is the
# portable equivalent of `mkdir -p`.
os.makedirs("checkpoint", exist_ok=True)
os.makedirs("models", exist_ok=True)

# These imports must come after the installs above.
import gpt_2_simple as gpt2
import requests
import tensorflow as tf
import shutil

# Name of the model directory under models/. Swap in the commented alternative
# to use a different model; delete checkpoint/run1 afterwards so it is re-copied.
model_name = "124M"
# model_name = "anne_50"
path = os.path.join("models", model_name)
if not os.path.isdir(path):
    print(f"Downloading {model_name} model...")
    gpt2.download_gpt2(model_name=model_name)

# gpt2.load_gpt2(sess) is called with its defaults, which load from
# checkpoint/run1, so the model files are copied there first. The copy is
# skipped when the directory already exists: the previous `cp -r src dst`
# would, on every re-run, nest another copy as checkpoint/run1/<model_name>.
run_dir = os.path.join("checkpoint", "run1")
if not os.path.isdir(run_dir):
    shutil.copytree(path, run_dir)

sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess)
############################### Story Generation ###############################
# Number of tokens GPT-2 generates per turn.
gen_length = 15
GENERATION_NOTICE = "Starting generation. \n Takes about 30 seconds without a GPU\n ============= \n"


def _read_line(on_eof):
    """Read one line from stdin; return *on_eof* if input is closed (e.g. Ctrl-D)."""
    try:
        return input()
    except EOFError:
        return on_eof


def _generate(prompt):
    """Print the generation notice and return one GPT-2 sample for *prompt*."""
    print(GENERATION_NOTICE)
    # Raise `length` for longer passages; each call yields a different sample.
    return gpt2.generate(
        sess, length=gen_length, top_k=10, prefix=prompt, return_as_list=True
    )[0]


def _append(story_so_far, segment):
    """Join *segment* onto *story_so_far*, adding a space so words don't run together."""
    needs_space = (
        story_so_far
        and segment
        and not story_so_far[-1].isspace()
        and not segment[0].isspace()
    )
    return story_so_far + (" " if needs_space else "") + segment


print("Type in a sentence to start the story (or just hit enter for an automatic prompt):")
prefix = _read_line(on_eof="")
# input() strips the newline and never returns None; treat blank input as empty.
if not prefix.strip():
    prefix = "Once upon a time "
story = _generate(prefix)
print(f"\n ============= \n{story}")

while True:
    print("Type in a sentence to continue the story (type 'exit' to finish writing your story):")
    prefix = _read_line(on_eof="exit")
    if prefix.strip().lower() == "exit":
        break
    # Note: GPT-2 is conditioned only on the line just typed, not the whole story.
    story = _append(story, _generate(prefix))
    print(f"\n ============= \n{story}")

print(f"Thanks for writing a story with GPT-2! Here it is:\n ============= \n{story}")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment