Generate labels with DeepSeek and `transformers`.
"""
Implementation of the label generation part in https://danielvanstrien.xyz/posts/2025/deepseek/distil-deepseek-modernbert.html
using `transformers` and DeepSeek.
"""
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import re
import contextlib
import math
from tqdm.auto import tqdm
import json
import polars as pl
from datasets import Dataset, Value, ClassLabel
from huggingface_hub import snapshot_download
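
# Two fallbacks for pulling JSON out of the model's reply: a fenced ```json block,
# or, failing that, the first bare {...} object in the text.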
JSON_PATTERN = re.compile(r"```json\n(.*?)```", re.DOTALL)
DIRECT_JSON_PATTERN = re.compile(r"\{[^}]*\}", re.DOTALL)
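
# Generation batch size and the number of arXiv papers sampled for labeling.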
BATCH_SIZE = 64
NUM_SAMPLES = 3000

@torch.no_grad()
def load_model():
    repo_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
    # `attn_implementation="flash_attention_2"` requires the `flash-attn` package to be installed.
    model = AutoModelForCausalLM.from_pretrained(
        repo_id, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
    ).to("cuda")
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    return model, tokenizer

def format_text_as_prompt(data: dict[str, str]):
    return f"""Look at the title and abstract for the following arXiv paper. Assess whether the paper is likely to introduce a newly created dataset.
Title: {data['title']}
Abstract: {data['abstract']}
Your role is to decide whether the paper introduces a newly created dataset. First you should think about whether the paper is likely to introduce a newly created dataset. You should then return your reasoning and the label you've chosen.
You should choose out of the "new_dataset" or "no_new_dataset" labels.
Return your reasoning and the label you've chosen as a JSON object like this:
```json
{{
"label": "new_dataset" | "no_new_dataset",
"explanation": "The reasoning the model used to come to its conclusion"
}}
```
"""

def load_dataset():
    # `snapshot_download` returns the local folder containing the downloaded parquet files.
    files = snapshot_download(
        repo_id="librarian-bots/arxiv-metadata-snapshot",
        allow_patterns=["*.parquet"],
        repo_type="dataset",
    )
    df = pl.scan_parquet(files)
    df = df.collect()
    return df

@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
@torch.no_grad()
def predict_label_without_structured_output(
    data: list[dict[str, str]], model: torch.nn.Module, tokenizer
) -> list[str]:
    prompts = [format_text_as_prompt(d) for d in data]
    texts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True
        )
        for prompt in prompts
    ]
    model_inputs = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,  # important so they line up in a batch
        truncation=True,  # so they don't exceed the model's max length
    ).to(model.device)
    generated_ids = model.generate(**model_inputs, max_new_tokens=2048)
    # Strip the prompt tokens so only the newly generated completion is decoded.
    results_ids = []
    for i, output_ids in enumerate(generated_ids):
        input_len = len(model_inputs.input_ids[i])
        results_ids.append(output_ids[input_len:])
    outputs = tokenizer.batch_decode(results_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    return outputs

def try_extract_json_from_text(text: str) -> tuple[str, dict | None]:
    if match := JSON_PATTERN.search(text):
        json_results = match.group(1)
        with contextlib.suppress(json.JSONDecodeError):
            return text, json.loads(json_results)
    if match := DIRECT_JSON_PATTERN.search(text):
        json_text = match.group(0)
        with contextlib.suppress(json.JSONDecodeError):
            return text, json.loads(json_text)
    return text, None

def create_and_push_ds(df):
    ds = Dataset.from_polars(
        df.select(["id", "title", "abstract", "labels", "explanations"]),
    )
    # Polars produces `large_string` columns; cast them to plain strings before pushing.
    large_string_columns = [k for k, v in ds.features.items() if isinstance(v, Value) and v.dtype == "large_string"]
    for column in large_string_columns:
        ds = ds.cast_column(column, Value("string"))
    ds = ds.cast_column("labels", ClassLabel(names=["new_dataset", "no_new_dataset"]))
    ds.push_to_hub("sayakpaul/arxiv-new-datasets")

def chunked(iterable, batch_size):
    for i in range(0, len(iterable), batch_size):
        yield iterable[i : i + batch_size]

def main():
    df = load_dataset()
    model, tokenizer = load_model()

    sample_df = df.sample(NUM_SAMPLES, seed=42)
    examples = sample_df.select(pl.col(["abstract", "title"])).to_dicts()
    total_batches = math.ceil(len(examples) / BATCH_SIZE)

    # Run batched inference over the sample.
    raw_predictions = []
    for batch_examples in tqdm(chunked(examples, BATCH_SIZE), total=total_batches):
        preds = predict_label_without_structured_output(batch_examples, model, tokenizer)
        raw_predictions.extend(preds)

    parsed_results = [try_extract_json_from_text(result) for result in raw_predictions]
    labels_and_explanations = [
        (result[1].get("label"), result[1].get("explanation"))
        if result[1] is not None and isinstance(result[1], dict)
        else (None, None)
        for result in parsed_results
    ]
    # Unzip the list of tuples into separate lists.
    labels, explanations = zip(*labels_and_explanations)
    labels = list(labels)
    explanations = list(explanations)

    sample_df = sample_df.with_columns(
        pl.Series(labels).alias("labels"),
        pl.Series(explanations).alias("explanations"),
    )
    create_and_push_ds(sample_df)

if __name__ == "__main__":
    main()
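
# A quick way to sanity-check the pushed dataset afterwards (a sketch, assuming the
# push above succeeded and you have access to the repo):
#
#   from datasets import load_dataset
#   ds = load_dataset("sayakpaul/arxiv-new-datasets", split="train")
#   print(ds.features["labels"].names)  # ["new_dataset", "no_new_dataset"]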