Skip to content

Instantly share code, notes, and snippets.

@strangeloopcanon
Forked from vgel/r1.py
Last active February 10, 2025 01:13
Show Gist options
  • Save strangeloopcanon/7dcc8f5651f47ae20ff37968085c4dd2 to your computer and use it in GitHub Desktop.
Save strangeloopcanon/7dcc8f5651f47ae20ff37968085c4dd2 to your computer and use it in GitHub Desktop.
Script to run DeepSeek-R1 with a min-thinking-tokens parameter: replacing </think> with a random continuation string extends the model's chain of thought.
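# Updated from voooogel's r1.py sampler. Tracks the probability of recently sampled tokens and, when their average drops below a threshold, injects one of the replacement strings (e.g. "\nWait, but") instead of the sampled token.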
import argparse
import random
import sys
from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
import torch
# Command-line interface. All options are read at module level through `args`.
parser = argparse.ArgumentParser(
    description=(
        "Stream a chat model's chain of thought, injecting continuation strings "
        "when it tries to stop thinking too early or its sampled tokens become "
        "low-confidence."
    )
)
parser.add_argument("question", type=str, help="Question to ask the model.")
parser.add_argument(
    "-m",
    "--model-name",
    default="Qwen/Qwen2.5-0.5B-Instruct",
    help="Model id or local path passed to from_pretrained.",
)
parser.add_argument(
    "-d",
    "--device",
    default="auto",
    help="Value passed as device_map to from_pretrained.",
)
parser.add_argument(
    "-r",
    "--replacements",
    nargs="+",
    default=["\nWait, but", "\nHowever...", "\nSo, that means,"],
    help="Continuation strings; one is chosen at random for each forced injection.",
)
parser.add_argument(
    "-t",
    "--min-thinking-tokens",
    type=int,
    default=128,
    help="Minimum thinking tokens before the model may end its reasoning.",
)
parser.add_argument(
    "-p",
    "--prefill",
    default="",
    help="Text placed after the opening <think> in the assistant turn.",
)
parser.add_argument(
    "--max-new-tokens",
    type=int,
    default=256,
    help="Upper bound on tokens generated after the prompt.",
)
# Confidence-based forcing options.
parser.add_argument(
    "--confidence-window",
    type=int,
    default=3,
    help="Number of recent tokens to average for confidence check",
)
parser.add_argument(
    "--confidence-threshold",
    type=float,
    default=0.6,
    help="Average probability below which we force a replacement",
)
args = parser.parse_args()
# The window size is used as a divisor when averaging recent confidences; a
# value below 1 would divide by zero (0) or average a nonsensical slice (<0).
if args.confidence_window < 1:
    parser.error("--confidence-window must be at least 1")
# Load the tokenizer and half-precision model named on the command line.
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
model = AutoModelForCausalLM.from_pretrained(
    args.model_name, torch_dtype=torch.float16, device_map=args.device
)
# Register the reasoning delimiters as special tokens so the tokenizer keeps
# each one whole, then resize the model's embedding table to len(tokenizer)
# so every tokenizer id has an embedding row.
special_tokens_dict = {"additional_special_tokens": ["<think>", "</think>"]}
tokenizer.add_special_tokens(special_tokens_dict)
model.resize_token_embeddings(len(tokenizer))
# Look the ids up directly instead of taking encode(...)[-1]. Depending on the
# tokenizer, encode() may wrap the text in BOS/EOS markers, and a trailing
# EOS would silently be mistaken for the end-of-thinking id.
_start_think_token = tokenizer.convert_tokens_to_ids("<think>")  # unused in this script
end_think_token = tokenizer.convert_tokens_to_ids("</think>")
@torch.inference_mode()
def reasoning_effort(
    question: str,
    min_thinking_tokens: int,
    temperature: float = 0.6,
    max_forced_replacements: int = 2,
):
    """Stream the model's reasoning for *question*, forcing longer thinking.

    Yields decoded text: first the full rendered prompt, then each accepted
    sampled token or injected replacement string.

    While the model is still thinking (before it emits ``</think>``), a random
    string from ``args.replacements`` is injected instead of the sampled token
    when either:

    * the mean probability of the last ``args.confidence_window`` sampled tokens
      is below ``args.confidence_threshold``; or
    * the model samples ``</think>`` or EOS before ``min_thinking_tokens``
      thinking tokens have been produced.

    Both triggers share a budget of ``max_forced_replacements`` injections.
    Once ``</think>`` is accepted, it is emitted and generation continues
    without forcing. Generation stops on EOS, or after ``args.max_new_tokens``
    tokens (sampled plus injected) have been generated past the prompt.

    Args:
        question: User question, prefixed with a short instruction.
        min_thinking_tokens: Thinking tokens required before ``</think>``/EOS
            is accepted (while the injection budget lasts).
        temperature: Sampling temperature applied to the logits; must be > 0.
        max_forced_replacements: Maximum number of injected replacement strings.

    Raises:
        ValueError: If ``temperature`` is not positive, or a replacement string
            encodes to zero tokens.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be > 0, got {temperature}")

    # Encode every replacement once. An empty encoding would later feed a
    # zero-length input to the model, so reject it up front.
    encoded_replacements = {
        text: tokenizer.encode(text, add_special_tokens=False)
        for text in args.replacements
    }
    empty = [text for text, ids in encoded_replacements.items() if not ids]
    if empty:
        raise ValueError(f"replacement strings encode to no tokens: {empty!r}")

    # config.eos_token_id may be None, a single int, or a list of ints
    # depending on the checkpoint; normalise it to a set for membership tests.
    eos = model.config.eos_token_id
    if eos is None:
        eos_ids = set()
    elif isinstance(eos, int):
        eos_ids = {eos}
    else:
        eos_ids = set(eos)

    tokens = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": "You are a brilliant assistant. " + question},
            {"role": "assistant", "content": "<think> \n" + args.prefill},
        ],
        continue_final_message=True,
        return_tensors="pt",
    )
    tokens = tokens.to(model.device)
    kv = DynamicCache()
    window = args.confidence_window

    n_generated = 0  # tokens generated after the prompt (sampled + injected)
    n_thinking_tokens = 0
    forced_replacements = 0
    thinking = True  # becomes False once </think> is accepted
    # Untempered probability of each sampled token since the last injection.
    token_confidences = []

    yield tokenizer.decode(tokens[0])
    # Count generated tokens explicitly: `tokens` only ever holds the newest
    # input, so its length says nothing about how much has been generated.
    while n_generated < args.max_new_tokens:
        # The cache holds the context, so only the new token(s) are fed each step.
        out = model(input_ids=tokens, past_key_values=kv, use_cache=True)
        kv = out.past_key_values
        # Upcast: half-precision softmax/multinomial is imprecise and not
        # supported on every device.
        logits = out.logits[0, -1, :].float()
        next_token = torch.multinomial(
            torch.softmax(logits / temperature, dim=-1), 1
        ).item()
        # Confidence is measured on the model's untempered distribution so the
        # threshold does not depend on the sampling temperature.
        token_confidences.append(torch.softmax(logits, dim=-1)[next_token].item())

        recent_avg_conf = None
        if thinking and window > 0 and len(token_confidences) >= window:
            recent_avg_conf = sum(token_confidences[-window:]) / window
        low_confidence = (
            recent_avg_conf is not None
            and recent_avg_conf < args.confidence_threshold
        )
        wants_to_stop = next_token == end_think_token or next_token in eos_ids
        too_short = wants_to_stop and n_thinking_tokens < min_thinking_tokens

        if (
            thinking
            and (low_confidence or too_short)
            and forced_replacements < max_forced_replacements
        ):
            # Discard the sampled token and steer the model with a continuation.
            forced_replacements += 1
            replacement = random.choice(args.replacements)
            replacement_tokens = encoded_replacements[replacement]
            yield replacement
            if low_confidence:
                print(
                    f"\nLow-confidence trigger: Average over last {args.confidence_window} tokens = {recent_avg_conf:.3f}. "
                    f"Forced replacement {repr(replacement)} -> {replacement_tokens}",
                    file=sys.stderr,
                )
            else:
                print(
                    f"\nReplacement trigger: {repr(replacement)} -> {replacement_tokens}",
                    file=sys.stderr,
                )
            n_thinking_tokens += len(replacement_tokens)
            n_generated += len(replacement_tokens)
            tokens = torch.tensor([replacement_tokens], device=model.device)
            token_confidences = []  # start a fresh window after an injection
            continue

        if next_token in eos_ids:
            break
        if thinking:
            if next_token == end_think_token:
                thinking = False  # accept the end of reasoning; go on to the answer
            else:
                n_thinking_tokens += 1
        yield tokenizer.decode([next_token])
        n_generated += 1
        tokens = torch.tensor([[next_token]], device=model.device)

    # Summary of the confidences recorded since the last injection.
    if token_confidences:
        print(
            f"\nFinal {len(token_confidences)} tokens confidences: {token_confidences}",
            file=sys.stderr,
        )
# Script entry: write each generated piece to stdout immediately so the
# reasoning appears as it is produced.
stream = reasoning_effort(args.question, args.min_thinking_tokens)
for piece in stream:
    sys.stdout.write(piece)
    sys.stdout.flush()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment