Skip to content

Instantly share code, notes, and snippets.

@vwxyzjn
Created February 20, 2024 21:30
Show Gist options
  • Save vwxyzjn/422b6db82a15db133d78d117be24416e to your computer and use it in GitHub Desktop.
"""
git clone https://github.com/argilla-io/distilabel.git
pip install -e ".[hf-inference-endpoints]"
"""
import asyncio
import os
import pandas as pd
from llm_swarm import LLMSwarm, LLMSwarmConfig
from huggingface_hub import AsyncInferenceClient
from transformers import AutoTokenizer, HfArgumentParser
from tqdm.asyncio import tqdm_asyncio
from datasets import load_dataset
# Generate completions for a slice of an OpenHermes dataset by fanning out
# requests to an llm_swarm-managed inference endpoint, then collect the
# prompt/completion pairs into a DataFrame.
#
# NOTE: the pasted source had lost all indentation (SyntaxError as-is);
# the block structure below is reconstructed from the obvious logical nesting.
HF_TOKEN = os.getenv("HF_TOKEN")  # read but not used below — presumably picked up implicitly by hub clients; TODO confirm
DATASET_TO_GENERATE = "argilla/OpenHermes-2.5-with-system-2"
NEW_DATASET_NAME = "argilla/OpenHermes-2.5-dpo-with-system-ckpt-2"  # not pushed anywhere in this snippet

# Work on a fixed 1024-example prefix of the training split.
dataset = load_dataset(DATASET_TO_GENERATE, split="train")
dataset = dataset.select(range(1024))

# Swarm configuration comes from the command line.
parser = HfArgumentParser([LLMSwarmConfig])
isc = parser.parse_args_into_dataclasses()[0]
max_parallel_requests = 120

with LLMSwarm(isc) as llm_swarm:
    # Cap the number of in-flight requests against the endpoint.
    semaphore = asyncio.Semaphore(max_parallel_requests)
    client = AsyncInferenceClient(model=llm_swarm.endpoint)
    tokenizer = AutoTokenizer.from_pretrained(isc.model)
    tokenizer.add_special_tokens({"sep_token": "", "cls_token": "", "mask_token": "", "pad_token": "[PAD]"})

    async def process_text(task):
        """Render *task* as a single-turn chat prompt and return the model's completion."""
        prompt = tokenizer.apply_chat_template(
            [
                {"role": "user", "content": task},
            ],
            tokenize=False,
        )
        # BUG FIX: the semaphore was created but never acquired, so
        # max_parallel_requests had no effect; acquire it around the call.
        async with semaphore:
            return await client.text_generation(
                prompt=prompt,
                max_new_tokens=1096,
            )

    async def main():
        """Fan out one request per input row (with a progress bar) and print the results."""
        tasks = dataset["input"]
        results = await tqdm_asyncio.gather(*(process_text(task) for task in tasks))
        df = pd.DataFrame({"Task": tasks, "Completion": results})
        print(df)

    asyncio.run(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment