Skip to content

Instantly share code, notes, and snippets.

@lazarinastoy
Forked from chetanambi/streamlit_demo.py
Last active September 1, 2021 12:16
Show Gist options
  • Save lazarinastoy/d1fcb94e48b3607ec734182865196391 to your computer and use it in GitHub Desktop.
Save lazarinastoy/d1fcb94e48b3607ec734182865196391 to your computer and use it in GitHub Desktop.
import torch
import streamlit as st
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers import T5Tokenizer, T5ForConditionalGeneration
st.title('Text Summarization Demo')
st.markdown('Using BART and T5 transformer model')
model = st.selectbox('Select the model', ('BART', 'T5'))

# Default generation settings for the selected model; the widgets below
# are pre-filled with these values and let the user override them.
if model == 'BART':
    _num_beams = 4
    _no_repeat_ngram_size = 3
    _length_penalty = 1.0  # float so the widget accepts fractional values
    _min_length = 12
    _max_length = 128
    _early_stopping = True
else:
    _num_beams = 4
    _no_repeat_ngram_size = 3
    _length_penalty = 2.0  # float so the widget accepts fractional values
    _min_length = 30
    _max_length = 200
    _early_stopping = True

# `st.beta_columns` was renamed to `st.columns` and later removed from
# Streamlit; prefer the current name and fall back for old releases.
_columns = getattr(st, 'columns', None) or st.beta_columns

col1, col2, col3 = _columns(3)
_num_beams = col1.number_input("num_beams", min_value=1, value=_num_beams)
_no_repeat_ngram_size = col2.number_input("no_repeat_ngram_size", min_value=0, value=_no_repeat_ngram_size)
_length_penalty = col3.number_input("length_penalty", value=_length_penalty)

col1, col2, col3 = _columns(3)
_min_length = col1.number_input("min_length", min_value=0, value=_min_length)
_max_length = col2.number_input("max_length", min_value=1, value=_max_length)
# early_stopping is a boolean flag, so a checkbox replaces the numeric input.
_early_stopping = col3.checkbox("early_stopping", value=_early_stopping)

text = st.text_area('Text Input')
# Per-model configuration:
#   (checkpoint, model class, tokenizer class, task prefix, max input tokens)
# BART's learned position embeddings cover 1024 tokens; 512 is the input
# length T5 was pre-trained with. T5 selects its task via a text prefix.
# NOTE(review): facebook/bart-base is not fine-tuned for summarization;
# a checkpoint such as facebook/bart-large-cnn would likely give better
# summaries. Kept as-is to preserve the demo's original model choice.
_MODEL_SPECS = {
    'BART': ('facebook/bart-base', BartForConditionalGeneration, BartTokenizer, '', 1024),
    'T5': ('t5-base', T5ForConditionalGeneration, T5Tokenizer, 'summarize: ', 512),
}


def _cache_resource(func):
    """Cache *func* across Streamlit reruns when ``st.cache_resource`` exists.

    Older Streamlit releases lack ``st.cache_resource``. There *func* is
    returned unchanged, so models are reloaded on every call as before.
    """
    cache = getattr(st, 'cache_resource', None)
    return cache(func) if cache is not None else func


@_cache_resource
def _load_model(model_name, device_type):
    """Load ``(model, tokenizer)`` for *model_name*, with the model on *device_type*."""
    checkpoint, model_cls, tokenizer_cls, _, _ = _MODEL_SPECS[model_name]
    # The model must live on the same device as the input tensors;
    # otherwise generate() fails with a device-mismatch error on GPU.
    seq2seq = model_cls.from_pretrained(checkpoint).to(device_type)
    tokenizer = tokenizer_cls.from_pretrained(checkpoint)
    return seq2seq, tokenizer


def run_model(input_text):
    """Summarize *input_text* with the selected model and render the result.

    Reads the module-level ``model`` selection and the ``_num_beams``,
    ``_no_repeat_ngram_size``, ``_length_penalty``, ``_min_length``,
    ``_max_length`` and ``_early_stopping`` widget values. Writes the summary
    to the page with ``st.success``, or shows a warning if the input is empty.

    Args:
        input_text: Text to summarize; converted with ``str()`` first.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Collapse all whitespace (including newlines) to single spaces. Words on
    # either side of a line break stay separated instead of being glued.
    text = ' '.join(str(input_text).split())
    if not text:
        st.warning('Please enter some text to summarize.')
        return

    # Any selection other than 'BART' uses T5, matching the selectbox options.
    model_name = 'BART' if model == 'BART' else 'T5'
    _, _, _, prefix, max_input_tokens = _MODEL_SPECS[model_name]
    seq2seq, tokenizer = _load_model(model_name, device.type)

    # Truncate so long inputs do not exceed the model's positional capacity.
    input_ids = tokenizer.encode(prefix + text,
                                 return_tensors='pt',
                                 truncation=True,
                                 max_length=max_input_tokens).to(device)

    with torch.no_grad():  # inference only; no gradients needed
        summary_ids = seq2seq.generate(input_ids,
                                       num_beams=_num_beams,
                                       no_repeat_ngram_size=_no_repeat_ngram_size,
                                       length_penalty=_length_penalty,
                                       min_length=_min_length,
                                       max_length=_max_length,
                                       early_stopping=bool(_early_stopping))

    # generate() returns one sequence per input by default; show the first.
    summary = tokenizer.decode(summary_ids[0],
                               skip_special_tokens=True,
                               clean_up_tokenization_spaces=False)
    st.write('Summary')
    st.success(summary)
# Run summarization only when the user clicks Submit (Streamlit reruns the
# whole script on every interaction).
if st.button('Submit'):
    run_model(text)

# Markdown links need `[label](url)` with no space between `]` and `(`.
st.write('Author: Chetan Ambi, Code Source: [github](https://gist.github.com/chetanambi/d54d83443df5f131c6bd0ca5dffa5742#file-streamlit_demo-py)')
@lazarinastoy
Copy link
Author

Chetan Ambi has shared the code for a really cool Streamlit app that lets the user choose which transformer model to use: BART or T5.
You can find his full tutorial with Python code in this article.
This approach works well when you want to summarize individual snippets of text.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment