Generate labels with DeepSeek and `transformers`.
""" | |
Implementation of the label generation part in https://danielvanstrien.xyz/posts/2025/deepseek/distil-deepseek-modernbert.html | |
using `transformers` and DeepSeek. | |
""" | |
import contextlib
import json
import math
import re

import polars as pl
import torch
from datasets import ClassLabel, Dataset, Value
from huggingface_hub import snapshot_download
from tqdm.auto import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
JSON_PATTERN = re.compile(r"```json\n(.*?)```", re.DOTALL)
DIRECT_JSON_PATTERN = re.compile(r"\{[^}]*\}", re.DOTALL)

BATCH_SIZE = 64
NUM_SAMPLES = 3000


def load_model():
    repo_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
    model = AutoModelForCausalLM.from_pretrained(
        repo_id, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2"
    ).to("cuda")
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    # Decoder-only models should be left-padded for batched generation so that
    # new tokens are appended directly after each prompt rather than after padding.
    tokenizer.padding_side = "left"
    return model, tokenizer


def format_text_as_prompt(data: dict[str, str]) -> str:
    return f"""Look at the title and abstract for the following arXiv paper. Assess whether the paper is likely to introduce a newly created dataset.

Title: {data['title']}
Abstract: {data['abstract']}

Your role is to decide whether the paper introduces a newly created dataset. Think first about whether the paper is likely to introduce one, then return your reasoning and the label you've chosen.

You should choose out of the "new_dataset" or "no_new_dataset" labels.

Return your reasoning and the label you've chosen as a JSON object like this:

```json
{{
    "label": "new_dataset" | "no_new_dataset",
    "explanation": "The reasoning the model used to come to its conclusion"
}}
```
"""


def load_dataset() -> pl.DataFrame:
    # Download only the parquet shards of the metadata snapshot, then scan
    # and materialize them with polars.
    files = snapshot_download(
        repo_id="librarian-bots/arxiv-metadata-snapshot",
        allow_patterns=["*.parquet"],
        repo_type="dataset",
    )
    return pl.scan_parquet(files).collect()


@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
@torch.no_grad()
def predict_label_without_structured_output(
    data: list[dict[str, str]], model: torch.nn.Module, tokenizer
) -> list[str]:
    prompts = [format_text_as_prompt(d) for d in data]
    texts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt}], tokenize=False, add_generation_prompt=True
        )
        for prompt in prompts
    ]
    model_inputs = tokenizer(
        texts,
        return_tensors="pt",
        padding=True,  # important so the prompts line up in a batch
        truncation=True,  # so they don't exceed the model's max length
    ).to(model.device)
    generated_ids = model.generate(**model_inputs, max_new_tokens=2048)
    # Slice off the (padded) prompt tokens so only the generated continuation remains.
    results_ids = []
    for i, output_ids in enumerate(generated_ids):
        input_len = len(model_inputs.input_ids[i])
        results_ids.append(output_ids[input_len:])
    return tokenizer.batch_decode(
        results_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )


def try_extract_json_from_text(text: str) -> tuple[str, dict | None]:
    # Prefer a fenced ```json block; fall back to the first bare JSON object.
    if match := JSON_PATTERN.search(text):
        json_results = match.group(1)
        with contextlib.suppress(json.JSONDecodeError):
            return text, json.loads(json_results)
    if match := DIRECT_JSON_PATTERN.search(text):
        json_text = match.group(0)
        with contextlib.suppress(json.JSONDecodeError):
            return text, json.loads(json_text)
    return text, None
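
# Illustrative only (example strings are made up): both response styles the
# model tends to produce parse to the same dict, and unparseable text yields None:
#   try_extract_json_from_text('```json\n{"label": "new_dataset", "explanation": "..."}\n```')
#   -> (<original text>, {"label": "new_dataset", "explanation": "..."})
#   try_extract_json_from_text('no JSON here at all')
#   -> ('no JSON here at all', None)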


def create_and_push_ds(df: pl.DataFrame):
    ds = Dataset.from_polars(
        df.select(["id", "title", "abstract", "labels", "explanations"]),
    )
    # polars strings come through Arrow as `large_string`; cast them down to
    # plain `string` and turn the labels into a proper ClassLabel feature.
    large_string_columns = [
        k for k, v in ds.features.items() if isinstance(v, Value) and v.dtype == "large_string"
    ]
    for column in large_string_columns:
        ds = ds.cast_column(column, Value("string"))
    ds = ds.cast_column("labels", ClassLabel(names=["new_dataset", "no_new_dataset"]))
    ds.push_to_hub("sayakpaul/arxiv-new-datasets")


def chunked(iterable, batch_size):
    for i in range(0, len(iterable), batch_size):
        yield iterable[i : i + batch_size]
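
# Example: chunked([1, 2, 3, 4, 5], 2) yields [1, 2], [3, 4], [5].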


def main():
    df = load_dataset()
    model, tokenizer = load_model()

    sample_df = df.sample(NUM_SAMPLES, seed=42)
    examples = sample_df.select(pl.col(["abstract", "title"])).to_dicts()
    total_batches = math.ceil(len(examples) / BATCH_SIZE)

    # Run batched inference over the sample.
    raw_predictions = []
    for batch_examples in tqdm(chunked(examples, BATCH_SIZE), total=total_batches):
        preds = predict_label_without_structured_output(batch_examples, model, tokenizer)
        raw_predictions.extend(preds)

    parsed_results = [try_extract_json_from_text(result) for result in raw_predictions]
    labels_and_explanations = [
        (result[1].get("label"), result[1].get("explanation"))
        if isinstance(result[1], dict)
        else (None, None)
        for result in parsed_results
    ]
    # Unzip the list of tuples into separate lists.
    labels, explanations = zip(*labels_and_explanations)
    sample_df = sample_df.with_columns(
        pl.Series(list(labels)).alias("labels"),
        pl.Series(list(explanations)).alias("explanations"),
    )
    create_and_push_ds(sample_df)


if __name__ == "__main__":
    main()
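
# Usage note (the filename below is illustrative): running the script end to end
# assumes a CUDA GPU, flash-attn installed for attn_implementation="flash_attention_2",
# and a prior `huggingface-cli login` so push_to_hub can write to your namespace:
#   python generate_labels.py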