@Norod
Created March 15, 2024 11:22
A simple inference script for CohereForAI/aya-101 with a Gradio-based UI, RTL support, and streaming text
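
The gist does not list its dependencies; based on the imports and on aya-101 being loaded with device_map='auto', the following install and launch commands should work. The package list and the file name aya101_gradio.py are assumptions, not part of the original gist.

pip install torch transformers gradio accelerate sentencepiece
python aya101_gradio.py
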
import torch
import gradio as gr
from threading import Thread
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, TextIteratorStreamer
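# Load the tokenizer and model for the aya-101 checkpoint.
# Note: device_map='auto' requires the accelerate package, and torch.bfloat16 assumes
# a GPU with bf16 support (consider torch.float16 or torch.float32 otherwise).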
checkpoint = "CohereForAI/aya-101"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, device_map='auto', torch_dtype=torch.bfloat16)
text_title = checkpoint.replace("/", " - ") + ' - Gradio Demo'
########################################################################
# Settings
########################################################################
# Set the maximum number of tokens to generate
max_new_tokens = 250
# Set the value of the repetition penalty
# The higher the value, the less repetitive the generated text will be
# Note that `repetition_penalty` has to be a strictly positive float
repetition_penalty = 1.8
# Set the text direction
# For languages that are written from right to left (RTL), set rtl to True
rtl = False
########################################################################
print(f"Settings: max_new_tokens = {max_new_tokens}, repetition_penalty = {repetition_penalty}, rtl = {rtl}")
if rtl:
    text_title += " - RTL"
    text_align = 'right'
    css = "#output_text{direction: rtl} #input_text{direction: rtl}"
else:
    text_align = 'left'
    css = ""
def generate(text = ""):
print("Create streamer")
yield "[Please wait for an answer]"
decode_kwargs = dict(skip_special_tokens = True, clean_up_tokenization_spaces = True)
streamer = TextIteratorStreamer(tokenizer, timeout = 5., decode_kwargs = decode_kwargs)
inputs = tokenizer([text], add_special_tokens = False, return_tensors = "pt").to('cuda')
print(tokenizer.decode(inputs['input_ids'][0], skip_special_tokens=True))
generation_kwargs = dict(inputs, streamer = streamer, max_new_tokens=max_new_tokens, repetition_penalty=repetition_penalty)
print("Create thread")
thread = Thread(target = model.generate, kwargs = generation_kwargs)
thread.start()
generated_text = ""
for new_text in streamer:
if tokenizer.eos_token not in new_text:
new_text = new_text.replace(tokenizer.pad_token, "")
yield generated_text + new_text
print(new_text, end ="")
generated_text += new_text
else:
new_text = new_text.replace(tokenizer.eos_token, "\n")
print(new_text, end ="")
generated_text += new_text
return generated_text
return generated_text
demo = gr.Interface(
    title = text_title,
    fn = generate,
    inputs = gr.Textbox(label = "Enter your prompt here", elem_id = "input_text", text_align = text_align, rtl = rtl),
    outputs = gr.Textbox(type = "text", label = "Generated text will appear here", elem_id = "output_text", text_align = text_align, rtl = rtl),
    css = css,
    allow_flagging = 'never'
)
demo.queue()
demo.launch(debug = True)
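
As a quick sanity check, the same streaming generator can be driven without the Gradio UI. This is a minimal sketch assuming the model, tokenizer, and generate() from the script above are already defined in the session (with the demo.launch() call skipped or commented out); the prompt is only an example.

final_text = ""
for chunk in generate("Translate to English: Bonjour, comment allez-vous?"):
    final_text = chunk  # each yielded chunk is the text accumulated so far
print(final_text)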