import code
import random
import re
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LogitsProcessor,
    LogitsProcessorList,
    set_seed,
)

class StopAfterPlusIsGenerated(LogitsProcessor):
    """Force EOS immediately after the "+" vote-marker token is generated."""

    def __init__(self, plus_token_id, eos_token_id):
        super().__init__()
        self.plus_token_id = plus_token_id
        self.eos_token_id = eos_token_id

    def __call__(self, input_ids, scores):
        # Score vector that puts all probability mass on the EOS token.
        forced_eos = torch.full((scores.size(1),), -float("inf")).to(
            device=scores.device, dtype=scores.dtype
        )
        forced_eos[self.eos_token_id] = 0
        # Any sequence whose last token was "+" is forced to emit EOS next.
        scores[input_ids[:, -1] == self.plus_token_id] = forced_eos
        return scores
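
# Minimal sanity check of the processor (illustrative only; the token ids below are made up):
# proc = StopAfterPlusIsGenerated(plus_token_id=5, eos_token_id=2)
# toy_scores = proc(torch.tensor([[1, 5]]), torch.zeros(1, 10))
# print(toy_scores)  # row is -inf everywhere except index 2 (the toy EOS id)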

model = AutoModelForCausalLM.from_pretrained(
    "float-trip/drama-llama-3",
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
).to("cuda")
tokenizer = AutoTokenizer.from_pretrained("float-trip/drama-llama-3")
tokenizer.pad_token = tokenizer.eos_token
logits_processor = LogitsProcessorList(
    [StopAfterPlusIsGenerated(482, tokenizer.eos_token_id)]
)
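
# Sanity check (assumption: token id 482 decodes to the "+" vote marker for this tokenizer):
# print(repr(tokenizer.decode([482])))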

def gen(prompt, stop_after_plus=True):
    seed = random.randint(0, 100000)
    set_seed(seed)
    encoded = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=4096,
    ).to("cuda")
    gen_tokens = model.generate(
        input_ids=encoded.input_ids,
        attention_mask=encoded.attention_mask,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.90,
        use_cache=True,
        max_new_tokens=512,
        logits_processor=logits_processor if stop_after_plus else None,
    )
    # Decode, strip the prompt prefix, and drop the trailing (possibly truncated) line.
    lines = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)[0][
        len(prompt) :
    ].split("\n")[:-1]
    return "\n".join(lines).strip()

class Chat:
    def __init__(self, user):
        self.msgs = [":marseywave:"]
        self.user = user

    def chat(self, msg):
        self.msgs.append(msg)
        prompt = self.build_prompt()
        g = gen(prompt)
        reply = self.extract_comment(g)
        print(reply)
        self.msgs.append(reply)

    def build_prompt(self):
        if len(self.msgs) % 2 == 1:
            print("Prompt must be built after a user message")
            return
        prompt = "[Post] [Date] 05/2024 [Hole] N/A [Author] ChattyMatty [Title] what's going on guys [URL] N/A [Votes] +15 / -1\n\nanyone wanna talk\n\n[Comments]\n\n"
        # Render the conversation as a nested comment thread, one extra space of indent per reply.
        for i, msg in enumerate(self.msgs + [""]):
            indent = " " * i
            msg = "\n".join([indent + line for line in msg.strip().split("\n")])
            if i % 2 == 0:
                user = self.user
            else:
                user = "ChattyMatty"
            prompt += f"{indent}{user} +2 / -0\n{msg}\n\n"
        return prompt.strip()

    def extract_comment(self, text):
        # Cut the generation off where the next "<username> +N" vote line begins.
        pattern = r"\s*\S+\s*\+\d+"
        parts = re.split(pattern, text, maxsplit=1)
        body = parts[0] if parts else ""
        return "\n".join([l.strip() for l in body.split("\n")])

def guidance(g_prompt, scale=5):
    # Classifier-free guidance: complete "[Post]" while steering away from g_prompt.
    inputs = tokenizer(["[Post]"], return_tensors="pt").to("cuda")
    neg_inputs = tokenizer([g_prompt], return_tensors="pt").to("cuda")
    out = model.generate(
        inputs["input_ids"],
        guidance_scale=scale,
        negative_prompt_ids=neg_inputs["input_ids"],
    )
    result = tokenizer.batch_decode(out, skip_special_tokens=True)[0]
    return result

# Usage:
# x = Chat("X")
# x.chat("tell me a story")
# print(gen("[Post] [Date] 05/2024 [Hole] N/A [Author] SIMPSONIANTHIELITEDOOMER [Title] Here are my five favorite things about Marsey! [URL] N/A [Votes] +90 / -2", stop_after_plus=False))
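# Hypothetical guidance example (the negative-prompt text below is illustrative, not from the gist):
# print(guidance("[Post] [Date] 05/2024 [Hole] N/A [Author] ChattyMatty [Title] what's going on guys [URL] N/A [Votes] +15 / -1", scale=3))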
code.interact(local=dict(globals(), **locals()))