# (c) 2023 Iikka Hauhio
# Use freely but attribute me
import argparse
import readline  # noqa: F401 -- imported for its side effect: line editing in input()
import traceback

from pynvml import *
import rich, rich.text, rich.panel, rich.live, rich.console
import torch
import transformers


# This function was copied from the Huggingface website
def print_gpu_utilization():
    nvmlInit()
    handle = nvmlDeviceGetHandleByIndex(0)
    info = nvmlDeviceGetMemoryInfo(handle)
    print(f"GPU memory occupied: {info.used//1024**2} MB.")


class Generator:
    def __init__(self, model_name: str, int8=False):
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
        # device_map="auto" requires the accelerate package; load_in_8bit=True
        # additionally requires bitsandbytes.
        self.model = transformers.AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype="auto",
            device_map="auto",
            load_in_8bit=int8
        )
        self.temperature = 0.7
        self.max_length = 250
        self.repetition_penalty = 0.7
        self.end = "</s>"
        self.result = None
        self.input_ids = None
        self.token_ids = None
        if not int8:
            # 8-bit models are already placed on the GPU during loading.
            self.model.to("cuda")

    def predict(
        self,
        prompt: str,
    ):
        tokenized = self.tokenizer(prompt, return_tensors="pt")
        self.input_ids = tokenized["input_ids"].to(self.model.device)  # type: ignore
        self.result = []
        self.token_ids = self.input_ids
        yield from self.complete()
    def complete(self):
        num = 0
        with torch.inference_mode():
            while num < self.max_length:
                num += 1
                # Recompute the full forward pass each step (simple, but slow:
                # no key/value cache is reused between steps).
                output = self.model(self.token_ids)
                logits = output.logits[0, -1, :]
                if self.repetition_penalty > 0.:
                    # Scale down the logits of recently generated tokens; the
                    # effect decays exponentially with distance from the end
                    # of the sequence.
                    for i, token_id in enumerate(reversed(self.token_ids[0])):
                        logits[token_id] *= 1.0 - self.repetition_penalty * 2 ** (-i * 0.1)
                # Temperature is applied after the softmax: softmax(x)**(1/T)
                # is proportional to softmax(x/T).
                probs = logits.softmax(dim=-1)
                probs.pow_(1 / self.temperature)
                probs /= probs.sum()  # renormalize so the yielded value is a probability
                next_token = probs.multinomial(1)[0]
                t = self.tokenizer.decode(next_token)
                yield t, probs[next_token].item()
                if t == self.end:
                    break
                self.result.append(next_token.item())
                self.token_ids = torch.concat(
                    (self.input_ids, torch.tensor([self.result]).to(self.model.device)),
                    dim=-1
                )
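
# Minimal standalone usage sketch of the Generator class (hypothetical; assumes
# the "gpt2" checkpoint is available and a CUDA device is present):
#
#   gen = Generator("gpt2")
#   gen.temperature = 0.9
#   for token, prob in gen.predict("The quick brown fox"):
#       print(repr(token), round(prob, 3))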

def get_colors(prob: float):
    r = 255 * 2 * prob if prob < 0.5 else 255
    g = 255  # *abs(2*prob-1)
    b = 255 * (2 - 2 * prob) if prob >= 0.5 else 255
    return r, g, b
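# The scale above runs from cyan at prob 0 through white at prob 0.5 to yellow
# at prob 1, so low-probability (surprising) tokens show up in blue-green and
# high-probability tokens in yellow.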


def main():
    console = rich.console.Console()
    parser = argparse.ArgumentParser()
    parser.add_argument("model", type=str)
    parser.add_argument("--int8", action="store_true")
    args = parser.parse_args()
    generator = Generator(args.model, args.int8)
    prev_text = ""
    while True:
        try:
            # Allow a literal \n in the input to stand for a newline.
            prompt = input("> ").replace("\\n", "\n")
        except EOFError:
            break
        except KeyboardInterrupt:
            continue
        try:
            cont = False
            if prompt.startswith("/"):
                # Split "/command arg" into the command and its argument.
                if " " in prompt:
                    i = prompt.index(" ")
                    command = prompt[:i]
                    arg = prompt[i+1:]
                else:
                    command = prompt
                    arg = ""
                if command == "/gpuinfo":
                    print_gpu_utilization()
                    continue
                elif command == "/temp":
                    generator.temperature = float(arg)
                    continue
                elif command == "/n":
                    generator.max_length = int(arg)
                    continue
                elif command == "/scale":
                    for prob in [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]:
                        r, g, b = get_colors(prob)
                        console.print(f"[rgb({int(r)},{int(g)},{int(b)})]{prob:.1f}")
                    continue
                elif command == "/cont":
                    # Continue generating from where the previous output stopped.
                    cont = True
                    prompt = prev_text
                else:
                    print("commands: /gpuinfo /temp /n /scale /cont")
                    continue
            text = rich.text.Text(prompt)
            text.stylize("i", 0, len(prompt))
            panel = rich.panel.Panel.fit(text)
            prev_text = prompt
            with rich.live.Live(panel, refresh_per_second=8):
                if cont:
                    tokens = generator.complete()
                else:
                    tokens = generator.predict(prompt)
                for token, prob in tokens:
                    try:
                        r, g, b = get_colors(prob)
                        if token != generator.end:
                            prev_text += token
                            text.append(token)
                            # Color each token by the probability it was sampled with.
                            text.stylize(f"rgb({int(r)},{int(g)},{int(b)})", -len(token))
                    except KeyboardInterrupt:
                        break
        except Exception:
            traceback.print_exc()
        print()


if __name__ == "__main__":
    main()
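
# Usage sketch (hypothetical file name generate.py; assumes torch, transformers,
# rich, pynvml and accelerate are installed, plus bitsandbytes for --int8):
#
#   python generate.py gpt2
#   python generate.py <model-id> --int8
#
# Type text at the "> " prompt to complete it, or enter a command such as
# /temp 0.8, /n 100, /cont, /scale, or /gpuinfo.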