@thistleknot
Last active November 22, 2023 03:50
tiny mistral — decoding strategies (greedy, beam, temperature, top-k, top-p) with TinyMistral-248M-Evol-Instruct under 4-bit NF4 quantization
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
# 4-bit NF4 quantization config: double quantization shrinks the footprint
# further, and bfloat16 is used as the compute dtype for dequantized ops.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
# Load tokenizer and model. bitsandbytes places the quantized weights on the
# GPU at load time, so no manual .to(device) call is needed (calling .to()
# on a 4-bit model raises an error).
tokenizer = AutoTokenizer.from_pretrained("Felladrin/TinyMistral-248M-Evol-Instruct")
model = AutoModelForCausalLM.from_pretrained(
    "Felladrin/TinyMistral-248M-Evol-Instruct",
    quantization_config=nf4_config,
)
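# Optional sanity check (not in the original gist): confirm where the
# quantized model landed and how small its footprint is. Both model.device
# and get_memory_footprint() are standard transformers APIs.
print("Model device:", model.device)
print(f"Memory footprint: {model.get_memory_footprint() / 1e6:.1f} MB")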
input_text = """
### Instruction:
Tell me a story that involves a lot of imagination.
### Response:"""
# Tokenize the prompt and move the input ids to the model's device (GPU),
# otherwise generate() fails with a CPU/GPU device mismatch.
input_ids = tokenizer.encode(input_text, return_tensors='pt').to(model.device)
# Greedy Search Decoding
greedy_output = model.generate(input_ids, max_length=1024)
print("Greedy Search:", tokenizer.decode(greedy_output[0], skip_special_tokens=True))
# Beam Search Decoding
beam_output = model.generate(
    input_ids,
    max_length=1024,
    num_beams=5,
    early_stopping=True,  # stop once num_beams finished candidates exist
)
print("Beam Search:", tokenizer.decode(beam_output[0], skip_special_tokens=True))
# Sampling with Temperature
sample_output = model.generate(
    input_ids,
    do_sample=True,
    max_length=1024,
    top_k=0,  # top_k=0 disables top-k filtering, so sampling covers the full vocabulary
    temperature=0.7,
)
print("Sampling with Temperature:", tokenizer.decode(sample_output[0], skip_special_tokens=True))
# Top-k Sampling
top_k_output = model.generate(
    input_ids,
    do_sample=True,
    max_length=1024,
    top_k=50,
)
print("Top-k Sampling:", tokenizer.decode(top_k_output[0], skip_special_tokens=True))
# Nucleus (Top-p) Sampling
top_p_output = model.generate(
    input_ids,
    do_sample=True,
    max_length=1024,
    top_p=0.92,
    top_k=0,
)
print("Nucleus (Top-p) Sampling:", tokenizer.decode(top_p_output[0], skip_special_tokens=True))