Skip to content

Instantly share code, notes, and snippets.

@MarcSkovMadsen
Last active October 30, 2021 07:46
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save MarcSkovMadsen/eae998fbcb299fae9e92ab0089e7eff8 to your computer and use it in GitHub Desktop.
Save MarcSkovMadsen/eae998fbcb299fae9e92ab0089e7eff8 to your computer and use it in GitHub Desktop.
Hugging Face GPT2 Transformer Example
import logging
import tensorflow as tf
from transformers import TFGPT2LMHeadModel, GPT2Tokenizer
from transformers import tf_top_k_top_p_filtering
import panel as pn
pn.extension()
import panel.widgets as pnw
from math import pi
from bokeh.plotting import figure
from bokeh.models import ColumnDataSource
from transformers import TFGPT2LMHeadModel, GPT2Tokenizer
# Configure Panel so components stretch to the available width by default.
pn.extension(sizing_mode="stretch_width")

ACCENT_BASE_COLOR = "#f37736"

# The theme is read from the session arguments; each value is a list of
# bytes, hence the [b"default"] fallback and the decode().
THEME = pn.state.session_args.get("theme", [b"default"])[0].decode()
if THEME == "dark":
    GENERATED_TEXT_BACKGROUND = "#181818"
else:
    GENERATED_TEXT_BACKGROUND = "#f0f0f0"

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)

logger.info("Setting Tensorflow random seed to 1234")
tf.random.set_seed(1234)

# Tokenizer and model for word generation.
# They should only be loaded once, as loading them takes ~5s,
# so they are kept in pn.state.cache and shared between sessions.
logger.info("Loading gpt2 tokenizer ...")
if "gpt2-tokenizer" in pn.state.cache:
    tokenizer = pn.state.cache["gpt2-tokenizer"]
else:
    tokenizer = pn.state.cache["gpt2-tokenizer"] = GPT2Tokenizer.from_pretrained("gpt2")

logger.info("Loading gpt2 model ...")
if "gpt2-model" in pn.state.cache:
    model = pn.state.cache["gpt2-model"]
else:
    # The end-of-sequence token doubles as the padding token.
    model = pn.state.cache["gpt2-model"] = TFGPT2LMHeadModel.from_pretrained(
        "gpt2", pad_token_id=tokenizer.eos_token_id
    )
def get_pred(
    sequence="Please input some text",
    model=model,
    tokenizer=tokenizer,
    temperature=0.7,
    top_k=50,
    top_p=0.95,
):
    """Sample the next token for ``sequence`` and return it with its filtered logits.

    Args:
        sequence: The text to continue.
        model: The language model; defaults to the module-level ``model``
            (bound when this function is defined).
        tokenizer: The tokenizer used to encode ``sequence`` and decode the
            sampled token; defaults to the module-level ``tokenizer``.
        temperature: Divisor applied to the logits before filtering. Values
            at or below zero are clamped to a tiny positive number, which
            avoids a division by zero.
        top_k: Passed through to ``tf_top_k_top_p_filtering``.
        top_p: Passed through to ``tf_top_k_top_p_filtering``.

    Returns:
        A tuple ``(resulting_string, filtered_next_token_logits)``: the decoded
        sampled token and the filtered logits it was sampled from, from which
        the caller can derive the probabilities of each candidate.
    """
    # Re-seeding on every call keeps the sampling reproducible.
    tf.random.set_seed(1234)
    input_ids = tokenizer.encode(sequence, return_tensors="tf")

    # Keep only the logits at the last position: the prediction for the next token.
    next_token_logits = model(input_ids)[0][:, -1, :]

    # The UI slider allows temperature == 0.0, which would otherwise divide
    # by zero and hand inf/NaN logits to the filter and the sampler.
    safe_temperature = max(temperature, 1e-6)
    next_token_logits = next_token_logits / safe_temperature

    # Drop unlikely candidates before sampling.
    filtered_next_token_logits = tf_top_k_top_p_filtering(next_token_logits, top_k, top_p)

    # Draw a single token id from the filtered distribution.
    next_token = tf.random.categorical(filtered_next_token_logits, dtype=tf.int32, num_samples=1)
    resulting_string = tokenizer.decode(next_token.numpy().tolist()[0])
    return resulting_string, filtered_next_token_logits
def get_plot_data(filtered_next_token_logits):
    """Convert filtered logits into the probabilities and words shown in the plot.

    Returns a tuple ``(probability_list, word_list)``: a NumPy array with the
    highest probabilities and a list with the decoded token for each of them,
    in the same order.
    """
    probs = tf.nn.softmax(filtered_next_token_logits)

    # Plot at most 100 candidates, and never more than have a non-zero probability.
    n_candidates = min(100, tf.math.count_nonzero(probs).numpy())
    top = tf.math.top_k(probs[0], n_candidates)

    # Decode each token id on its own so every bar gets its own label.
    words = [tokenizer.decode([token_id]) for token_id in top.indices.numpy()]
    return top.values.numpy(), words
def clean_plot_data(word_list, probability_list):
    """Prepare the data for plotting.

    - Aggregates words that appear multiple times by summing their probabilities.
    - Orders the words by descending total probability; words with equal
      totals keep their order of first appearance.

    Args:
        word_list: The words, one per probability.
        probability_list: The probability of each word, in the same order.

    Returns:
        A tuple ``(words, probabilities)`` of two lists of equal length.
    """
    totals = {}
    for word, probability in zip(word_list, probability_list):
        totals[word] = totals.get(word, 0) + probability

    # sorted() is stable, also with reverse=True, so ties keep insertion order.
    ranked = sorted(totals.items(), key=lambda item: item[1], reverse=True)
    return [word for word, _ in ranked], [total for _, total in ranked]
def get_plot(word_list, probability_list):
    """Build the Bokeh bar chart of word probabilities.

    The data is first passed through ``clean_plot_data``; the resulting words
    become the categories on the x axis and the probabilities the bar heights.
    """
    words, probs = clean_plot_data(word_list, probability_list)
    source = ColumnDataSource(data={"word_list": words, "probability_list": probs})

    plot = figure(
        x_range=source.data["word_list"],
        height=250,
        title="Probabilities",
        toolbar_location=None,
        tools="",
    )
    plot.vbar(
        x="word_list",
        top="probability_list",
        width=0.8,
        source=source,
        color=ACCENT_BASE_COLOR,
    )
    # Rotate the word labels by a quarter turn.
    plot.xaxis.major_label_orientation = pi / 2
    return plot
logger.info("Creating Widgets and Panes")

# Sampling parameters. Temperature and top p share the same 0..1 range.
_unit_range = dict(value=1.0, start=0.0, end=1.0, step=0.01)
temperature_pn = pnw.FloatSlider(name="Temperature", **_unit_range)
top_k_pn = pnw.IntSlider(name="Top K", value=0, start=0, end=100)
top_p_pn = pnw.FloatSlider(name="Top p", **_unit_range)
settings = pn.Column(temperature_pn, top_k_pn, top_p_pn)

# Text entry, and the pane that shows the entered text plus the generated words.
text_input = pnw.TextInput(value="Enter a string here...")
generated_text = pn.pane.HTML(
    object=text_input.value,
    background=GENERATED_TEXT_BACKGROUND,
    min_height=200,
    sizing_mode="stretch_both",
)
# Mirror the input's value into the pane's object.
text_input.link(generated_text, value="object")

predict_button = pnw.Button(name="▶ Predict", button_type="primary")
text_part = pn.Column(text_input, predict_button, generated_text)

# Starts empty; predict() assigns the figure.
bokeh_plot = pn.pane.Bokeh(sizing_mode="stretch_both")
def predict(event=None):
    """Append one predicted word to the generated text and refresh the plot.

    ``event`` is accepted so the function can be used as a button click
    handler; it is not used.
    """
    # NOTE: the original author left a `bokeh_plot.loading` toggle disabled
    # around this body.
    next_word, logits = get_pred(
        sequence=generated_text.object,
        model=model,
        tokenizer=tokenizer,
        temperature=temperature_pn.value,
        top_k=top_k_pn.value,
        top_p=top_p_pn.value,
    )
    generated_text.object += next_word

    probs, words = get_plot_data(logits)
    bokeh_plot.object = get_plot(words, probs.tolist())
# Run one prediction right away so the generated text and the plot are populated.
predict()
# Each click on the button appends one more predicted word.
predict_button.on_click(predict)
# Periodic callback that keeps predicting; created stopped (start=False).
# Its `period` and `running` parameters are exposed in the sidebar.
auto_predict_callback = pn.state.add_periodic_callback(predict, period=1000, start=False)
def text_change_cb(event):
    """Replace the generated text with the new value carried by ``event``."""
    new_text = event.new
    generated_text.object = new_text
# Reset the generated text whenever the value of the text input changes.
text_input.param.watch(text_change_cb, "value")

# Both logos are rendered with the same fixed size and margin.
_logo_options = dict(embed=False, height=115, margin=25, sizing_mode="fixed")
panel_logo_pane = pn.pane.PNG(
    "https://panel.holoviz.org/_static/logo_stacked.png",
    link_url="https://panel.holoviz.org",
    **_logo_options,
)
hugging_face_pane = pn.pane.PNG(
    "https://raw.githubusercontent.com/huggingface/transformers/master/docs/source/imgs/transformers_logo_name.png",
    link_url="https://huggingface.co/",
    **_logo_options,
)
image_component = pn.layout.FlexBox(
    panel_logo_pane,
    hugging_face_pane,
    justify_content="center",
    margin=25,
    sizing_mode="stretch_both",
)

# Markdown shown at the bottom of the sidebar.
_info_markdown = """
# 🎓 Info
**GPT-2** is a large *transformer-based* language model with 1.5 billion parameters, trained on a
dataset of 8 million web pages.
GPT-2 is trained with a simple objective: **predict the next word**, given all
of the previous words within some text."""

_sidebar = [
    "# ⚙️ Parameters",
    settings,
    "# 🏃 Auto Predict",
    # Expose only the period and the running switch of the periodic callback.
    pn.Param(auto_predict_callback.param, parameters=["period", "running"], show_name=False),
    _info_markdown,
]

app = pn.template.FastListTemplate(
    site="Awesome Panel",
    title="Hugging Face Transformers",
    sidebar=_sidebar,
    main=[image_component, text_part, bokeh_plot],
    accent_base_color=ACCENT_BASE_COLOR,
    header_background=ACCENT_BASE_COLOR,
)

logger.info("Serving the App")
app.servable()
@MarcSkovMadsen
Copy link
Author

MarcSkovMadsen commented Oct 30, 2021

Install

pip install panel==0.12.4 bokeh==2.4.1 tensorflow==2.6.0 transformers==4.12.2

Serve

panel serve gpt2_transformers.py

and open at http://localhost:5006/gpt2_transformers

Serve with live reload

panel serve gpt2_transformers.py --autoreload

Resources

@MarcSkovMadsen
Copy link
Author

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment