DBRX on 3x 3090 GPUs

The snippet below loads Databricks' DBRX Instruct with 4-bit bitsandbytes quantization, shards it across three 24 GB RTX 3090s via device_map="auto", and runs a short chat-formatted generation.
# conda create -n dbrx python=3.10 -y && conda activate dbrx
# pip install torch transformers tiktoken flash_attn bitsandbytes
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Community re-upload of DBRX Instruct with weights repacked so bitsandbytes quantization works.
model_name = "SinclairSchneider/dbrx-instruct-quantization-fixed"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# load_in_4bit quantizes the 132B-parameter MoE weights to roughly 66 GB, and
# device_map="auto" shards them across all visible GPUs (3x 24 GB = 72 GB here).
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True, load_in_4bit=True)

input_text = "What does it take to build a great LLM?"
messages = [{"role": "user", "content": input_text}]

# return_dict=True yields input_ids plus attention_mask; "cuda" resolves to
# the first GPU, where device_map="auto" places the embedding layer.
input_ids = tokenizer.apply_chat_template(messages, return_dict=True, tokenize=True, add_generation_prompt=True, return_tensors="pt").to("cuda")
outputs = model.generate(**input_ids, max_new_tokens=200)
print(tokenizer.decode(outputs[0]))
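
To confirm the 4-bit shards actually fit on the three cards, you can print a per-GPU memory report right after loading. This is a sketch using standard torch.cuda calls; nothing in it is specific to DBRX:

# Report how much memory is in use on each visible GPU after model load.
for i in range(torch.cuda.device_count()):
    free, total = torch.cuda.mem_get_info(i)  # bytes free / total on device i
    print(f"GPU {i}: {(total - free) / 2**30:.1f} GiB in use of {total / 2**30:.1f} GiB")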
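
Note that more recent transformers releases deprecate passing load_in_4bit= directly to from_pretrained in favor of an explicit quantization config. A minimal equivalent sketch (untested here; the bnb_4bit_compute_dtype choice is an assumption, picked to match the bfloat16 dtype used above):

from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch

# Assumption: compute in bfloat16 to match the torch_dtype used above.
quant_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(
    "SinclairSchneider/dbrx-instruct-quantization-fixed",
    device_map="auto",
    quantization_config=quant_config,
    trust_remote_code=True,
)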