
@vgel
Last active February 15, 2025 06:56
script to run deepseek-r1 with a min-thinking-tokens parameter, replacing </think> with a random continuation string to extend the model's chain of thought
import argparse
import random
import sys

from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache
import torch

parser = argparse.ArgumentParser()
parser.add_argument("question", type=str)
parser.add_argument(
    "-m", "--model-name", default="deepseek-ai/DeepSeek-R1-Distill-Qwen-32B"
)
parser.add_argument("-d", "--device", default="auto")
parser.add_argument(
    "-r", "--replacements", nargs="+", default=["\nWait, but", "\nHmm", "\nSo"]
)
parser.add_argument("-t", "--min-thinking-tokens", type=int, default=128)
parser.add_argument("-p", "--prefill", default="")
args = parser.parse_args()

tokenizer = AutoTokenizer.from_pretrained(args.model_name)
model = AutoModelForCausalLM.from_pretrained(
    args.model_name, torch_dtype=torch.bfloat16, device_map=args.device
)

# with an up-to-date tokenizer config, "<think></think>" encodes to
# [BOS, <think>, </think>] (see the comments below)
_, _start_think_token, end_think_token = tokenizer.encode("<think></think>")


@torch.inference_mode
def reasoning_effort(question: str, min_thinking_tokens: int):
    tokens = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": question},
            {"role": "assistant", "content": "<think>\n" + args.prefill},
        ],
        continue_final_message=True,
        return_tensors="pt",
    )
    tokens = tokens.to(model.device)
    kv = DynamicCache()
    n_thinking_tokens = 0

    yield tokenizer.decode(list(tokens[0]))
    while True:
        out = model(input_ids=tokens, past_key_values=kv, use_cache=True)
        # sample the next token from the model's distribution
        next_token = torch.multinomial(
            torch.softmax(out.logits[0, -1, :], dim=-1), 1
        ).item()
        kv = out.past_key_values

        if (
            next_token in (end_think_token, model.config.eos_token_id)
            and n_thinking_tokens < min_thinking_tokens
        ):
            # the model tried to stop thinking too early: swap in a random
            # continuation string and keep generating
            replacement = random.choice(args.replacements)
            yield replacement
            replacement_tokens = tokenizer.encode(replacement)
            n_thinking_tokens += len(replacement_tokens)
            tokens = torch.tensor([replacement_tokens]).to(tokens.device)
        elif next_token == model.config.eos_token_id:
            break
        else:
            yield tokenizer.decode([next_token])
            n_thinking_tokens += 1
            tokens = torch.tensor([[next_token]]).to(tokens.device)


for chunk in reasoning_effort(args.question, args.min_thinking_tokens):
    print(chunk, end="", flush=True)
@vgel (Author) commented Jan 22, 2025

you can run it like this:

python r1.py "What is 1+1?"

to change the model, use -m / --model-name

python r1.py -m 'deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B'  "What is 1+1?"

to change the thinking duration, use -t / --min-thinking-tokens (default 128)

python r1.py -t 32 "What is 1+1?"

you can also use -r / --replacements to alter the list of continuation strings, and -d / --device to customize which device to load the model on.
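for example (the device and replacement strings below are just placeholders; put the question before -r so argparse doesn't fold it into the replacements list):

python r1.py "What is 1+1?" -d cuda:0 -r "Wait, let me double-check that" "Alternatively,"

(the default replacements start with a newline; in bash you can get the same effect with $'\n...' quoting)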

@vgel (Author) commented Jan 22, 2025

Changed start_think_token, end_think_token = tokenizer.encode("<think></think>") to _, _start_think_token, end_think_token = tokenizer.encode("<think></think>") because DeepSeek just pushed a tokenizer config update (https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B/commit/ca24ee48c0532a014ddea4c32437a6f4be981ab2). If you're getting an error about not enough values to unpack, make sure your tokenizer config is up to date, or edit that line to remove the leading _,
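if you want that line to work with either tokenizer config, a small sketch (assuming <think> and </think> each encode to a single token, as they do for the R1 distills) is to skip the special tokens entirely:

_start_think_token, end_think_token = tokenizer.encode(
    "<think></think>", add_special_tokens=False
)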

@secemp9 commented Jan 22, 2025

Changed start_think_token, end_think_token = tokenizer.encode("<think></think>") to _, _start_think_token, end_think_token = tokenizer.encode("<think></think>") because DeepSeek just pushed a tokenizer config update (https://huggingface.co/deepseek-ai/DeepSeek-R1-Distill-Qwen-32B/commit/ca24ee48c0532a014ddea4c32437a6f4be981ab2). If you're getting an error about not enough values to unpack, make sure your tokenizer config is up to date, or edit that line to remove the leading _,

ah, just had that exact error and thought I was going crazy...good thing I checked back here

@DenisSergeevitch

That’s almost what I thought yesterday – thank you for sharing!

I thought the approach could be a bit different:
If the </think> tag appears, we inject a new prompt like:

The first level of problem-solving has been achieved. Now find a new logical solution through intensive contemplation, but do not copy the previous <think> thoughts, make new ones.

<think>

and perform this process ~five times.

I’m not a skilled developer and don’t have R1-level hardware – could you please try this approach too?
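a rough, untested sketch of that variant against the script above (it reuses tokenizer, model, and end_think_token from the gist; max_passes and the reinjected prompt wording are just placeholders):

import torch
from transformers import DynamicCache

REINJECT = (
    "\nThe first level of problem-solving has been achieved. Now find a new"
    " logical solution through intensive contemplation, but do not copy the"
    " previous thoughts, make new ones.\n<think>\n"
)


@torch.inference_mode
def rethink(question: str, max_passes: int = 5):
    tokens = tokenizer.apply_chat_template(
        [
            {"role": "user", "content": question},
            {"role": "assistant", "content": "<think>\n"},
        ],
        continue_final_message=True,
        return_tensors="pt",
    ).to(model.device)
    kv = DynamicCache()
    passes = 0

    yield tokenizer.decode(list(tokens[0]))
    while True:
        out = model(input_ids=tokens, past_key_values=kv, use_cache=True)
        next_token = torch.multinomial(
            torch.softmax(out.logits[0, -1, :], dim=-1), 1
        ).item()
        kv = out.past_key_values

        if next_token == end_think_token and passes < max_passes:
            # instead of a short continuation string, inject the whole
            # "find a new solution" prompt and re-open the think block
            passes += 1
            yield REINJECT
            tokens = torch.tensor(
                [tokenizer.encode(REINJECT, add_special_tokens=False)]
            ).to(tokens.device)
        elif next_token == model.config.eos_token_id:
            break
        else:
            yield tokenizer.decode([next_token])
            tokens = torch.tensor([[next_token]]).to(tokens.device)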

@sebington

Hi, I tried this on my CPU-only machine and it was very slow, even with the Distill 1.5B model. So I asked Claude to generate a 'CPU-friendly' version: https://gist.github.com/sebington/ece931a90048109a38b1df1fa4dc4a03
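for what it's worth, the usual CPU-only tweak (an assumption here, not necessarily what that gist does) is to load the weights in float32 and pin the model to the CPU, since bfloat16 is slow or unsupported on many CPUs:

model = AutoModelForCausalLM.from_pretrained(
    args.model_name, torch_dtype=torch.float32, device_map="cpu"
)

combined with the 1.5B distill via -m as shown above.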

@bindingsoul

Is there a way to format the output? I want just the main output without the think text, once DeepSeek-R1 has finished thinking and produced its answer.

@vgel (Author) commented Jan 23, 2025

@bindingsoul

Is there a way to format the output? I want just the main output without the think text, once DeepSeek-R1 has finished thinking and produced its answer.

Yes, the loop at the bottom is over token strings, so just don't print until you see </think>:

has_stopped_thinking = False
for chunk in reasoning_effort(args.question, args.min_thinking_tokens):
    if not has_stopped_thinking:
        # swallow everything up to and including the chunk with </think>
        if "</think>" in chunk:
            has_stopped_thinking = True
    else:
        print(chunk, end="", flush=True)

(wrote this on my phone, untested)

@vTuanpham

We need someone to straight up eval on the ARC-AGI by setting the min-thinking-tokens to 50k per task
