script to run deepseek-r1 with a min-thinking-tokens parameter, replacing </think> with a random continuation string to extend the model's chain of thought
# Updated from voooogel's r1.py sampler. Looks to see if the tokens are of high
# confidence, and if not forces a resample.
import argparse
import random
import sys

from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
import torch

parser = argparse.ArgumentParser()
parser.add_argument("question", type=str)
parser.add_argument(
    "-m", "--model-name", default="Qwen/Qwen2.5-0.5B-Instruct"
)
parser.add_argument("-d", "--device", default="auto")
parser.add_argument(
    "-r", "--replacements", nargs="+",
    default=["\nWait, but", "\nHowever...", "\nSo, that means,"],
)
parser.add_argument("-t", "--min-thinking-tokens", type=int, default=128)
parser.add_argument("-p", "--prefill", default="")
parser.add_argument("--max-new-tokens", type=int, default=256)
# New arguments for confidence-based forcing:
parser.add_argument(
    "--confidence-window", type=int, default=3,
    help="Number of recent tokens to average for confidence check",
)
parser.add_argument(
    "--confidence-threshold", type=float, default=0.6,
    help="Average probability below which we force a replacement",
)
args = parser.parse_args()

tokenizer = AutoTokenizer.from_pretrained(args.model_name)
model = AutoModelForCausalLM.from_pretrained(
    args.model_name, torch_dtype=torch.float16, device_map=args.device
)

special_tokens_dict = {"additional_special_tokens": ["<think>", "</think>"]}
tokenizer.add_special_tokens(special_tokens_dict)
model.resize_token_embeddings(len(tokenizer))

# Grab the last token from each special token sequence
_start_think_token = tokenizer.encode("<think>")[-1]
end_think_token = tokenizer.encode("</think>")[-1]
# start_think_token, end_think_token = tokenizer.encode("<think></think>")

@torch.inference_mode()
def reasoning_effort(question: str, min_thinking_tokens: int):
    # prompt = f"{question}\<think> {args.prefill}\n"
    # tokens = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
    tokens = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": "You are a brilliant assistant. " + question},
            {"role": "assistant", "content": "<think> \n" + args.prefill},
        ],
        continue_final_message=True,
        return_tensors="pt",
    )
    tokens = tokens.to(model.device)
    # Sampling temperature (previously passed to apply_chat_template, where it has no effect).
    temperature = 0.6

    kv = DynamicCache()
    n_thinking_tokens = 0
    forced_replacements = 0
    max_forced_replacements = 2

    # We'll store the probability of each generated token here:
    token_confidences = []

    yield tokenizer.decode(tokens[0])

    while True:
        # Stop if we've added too many tokens. (`tokens` holds only the latest input
        # chunk each step, so compare a running count rather than its length.)
        if n_thinking_tokens >= args.max_new_tokens:
            break

        out = model(input_ids=tokens, past_key_values=kv, use_cache=True)
        logits = out.logits[0, -1, :]
        probs = torch.softmax(logits / temperature, dim=-1)
        next_token = torch.multinomial(probs, 1).item()

        # Capture the probability (i.e., confidence) for the chosen token
        token_prob = probs[next_token].item()
        token_confidences.append(token_prob)
        kv = out.past_key_values

        if len(token_confidences) >= args.confidence_window:
            recent_avg_conf = sum(token_confidences[-args.confidence_window:]) / args.confidence_window
            if recent_avg_conf < args.confidence_threshold and forced_replacements < max_forced_replacements:
                forced_replacements += 1
                replacement = random.choice(args.replacements)
                yield replacement
                replacement_tokens = tokenizer.encode(replacement, add_special_tokens=False)
                print(
                    f"\nLow-confidence trigger: Average over last {args.confidence_window} tokens = {recent_avg_conf:.3f}. "
                    f"Forced replacement {repr(replacement)} -> {replacement_tokens}",
                    file=sys.stderr,
                )
                n_thinking_tokens += len(replacement_tokens)
                # Replace the current tokens with the replacement tokens (the low-confidence sample is discarded).
                tokens = torch.tensor([replacement_tokens]).to(tokens.device)
                token_confidences = []  # Reset the confidence list after forced injection.
                continue  # Go to next iteration.

        # If the model emits an end token, handle it.
        if next_token in (end_think_token, model.config.eos_token_id):
            if n_thinking_tokens < min_thinking_tokens and forced_replacements < max_forced_replacements:
                forced_replacements += 1
                replacement = random.choice(args.replacements)
                yield replacement
                replacement_tokens = tokenizer.encode(replacement, add_special_tokens=False)
                print(f"\nReplacement trigger: {repr(replacement)} -> {replacement_tokens}", file=sys.stderr)
                n_thinking_tokens += len(replacement_tokens)
                tokens = torch.tensor([replacement_tokens]).to(tokens.device)
                token_confidences = []
            else:
                break
        else:
            yield tokenizer.decode([next_token])
            n_thinking_tokens += 1
            tokens = torch.tensor([[next_token]]).to(tokens.device)

    # Optionally, you can print a summary of the final token confidences.
    if token_confidences:
        print(f"\nFinal {len(token_confidences)} token confidences: {token_confidences}", file=sys.stderr)


for chunk in reasoning_effort(args.question, args.min_thinking_tokens):
    print(chunk, end="", flush=True)
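
Usage sketch (the filename, question, and flag values below are illustrative assumptions, not part of the gist; save the script locally as, say, r1_confidence.py):

    python r1_confidence.py "How many primes are there below 100?" -t 200 --confidence-window 5 --confidence-threshold 0.5

The positional argument is the question; -t/--min-thinking-tokens is the floor below which an emitted </think> (or EOS) is swapped for a random continuation string, and the two confidence flags set how many recent token probabilities are averaged and the average below which a replacement is forced. Generated text streams to stdout while the trigger messages go to stderr.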