Skip to content

Instantly share code, notes, and snippets.

@devig
Created April 14, 2024 08:46
Show Gist options
  • Save devig/e529f1920dec0b61a5dc5ef47ecc3ae8 to your computer and use it in GitHub Desktop.
import torch
from transformers import AutoTokenizer, pipeline, AutoModelForCausalLM
import time
def custom_chat_template(chat):
    """Render a chat transcript as plain text.

    Each turn is a dict with ``role`` and ``content`` keys and becomes one
    line of the form ``role: content``; lines are joined with newlines
    (no trailing newline).
    """
    rendered = []
    for turn in chat:
        rendered.append(f"{turn['role']}: {turn['content']}")
    return "\n".join(rendered)
def generate_response(system_prompt, user_prompt):
    """Generate a chat completion on CPU for the given system/user prompts.

    Builds a plain-text prompt from the two messages, runs the WestLake-7B
    model through a ``transformers`` text-generation pipeline, and returns
    the generated text (includes the prompt, since the pipeline's default
    ``return_full_text`` is in effect). Returns ``None`` on any failure.
    """
    try:
        print("Initializing the model and tokenizer...")
        model_name = "macadeliccc/WestLake-7B-v2-laser-truthy-dpo"
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        # float16 halves memory; NOTE(review): fp16 inference on CPU can be
        # slow or unsupported for some ops — confirm on the target machine.
        model = AutoModelForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.float16
        )
        # Explicitly run on CPU.
        device = torch.device("cpu")
        print(f"Using device: {device}")
        print("Setting up the text generation pipeline...")
        text_gen_pipeline = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device=-1,  # -1 indicates CPU
        )
        print("Preparing the chat template...")
        chat = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        # BUGFIX: Hugging Face expects ``tokenizer.chat_template`` to be a
        # Jinja2 template *string*; assigning a Python callable to it (as the
        # original did) abuses the API and breaks ``apply_chat_template`` and
        # ``save_pretrained``. Call our formatter directly instead.
        prompt = custom_chat_template(chat)
        print(f"Generated prompt: {prompt[:50]}...")  # first 50 chars only
        print("Generating response...")
        start_time = time.time()
        outputs = text_gen_pipeline(
            prompt,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.1,
            top_k=50,
            top_p=0.95,
        )
        print(f"Response generated in {time.time() - start_time:.2f} seconds.")
        return outputs[0]['generated_text']
    except Exception as e:
        # Top-level boundary: report the error and signal failure via None
        # so the caller can degrade gracefully.
        print(f"An error occurred: {str(e)}")
        return None
# Example usage — guarded so importing this module does not trigger a
# multi-gigabyte model download and generation run as a side effect.
if __name__ == "__main__":
    system_prompt = (
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request."
    )
    user_prompt = "Give me top 3 jokes"
    response = generate_response(system_prompt, user_prompt)
    print(response if response else "No response generated.")
@devig
Copy link
Author

devig commented Apr 14, 2024

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment