# Created February 20, 2024 21:30
# GitHub gist: vwxyzjn/422b6db82a15db133d78d117be24416e
# (can be saved to your computer or opened in GitHub Desktop).
# Warning from GitHub: this file contains bidirectional Unicode text that may be
# interpreted or compiled differently than what appears below. To review it, open
# the file in an editor that reveals hidden Unicode characters.
""" | |
git clone https://github.com/argilla-io/distilabel.git | |
pip install -e ".[hf-inference-endpoints]" | |
""" | |
import asyncio | |
import os | |
import pandas as pd | |
from llm_swarm import LLMSwarm, LLMSwarmConfig | |
from huggingface_hub import AsyncInferenceClient | |
from transformers import AutoTokenizer, HfArgumentParser | |
from tqdm.asyncio import tqdm_asyncio | |
from datasets import load_dataset | |
HF_TOKEN = os.getenv("HF_TOKEN") | |
DATASET_TO_GENERATE = "argilla/OpenHermes-2.5-with-system-2" | |
NEW_DATASET_NAME = "argilla/OpenHermes-2.5-dpo-with-system-ckpt-2" | |
dataset = load_dataset(DATASET_TO_GENERATE, split="train") | |
dataset = dataset.select(range(1024)) | |
parser = HfArgumentParser([LLMSwarmConfig]) | |
isc = parser.parse_args_into_dataclasses()[0] | |
max_parallel_requests = 120 | |
with LLMSwarm(isc) as llm_swarm:
    # All generation runs inside this `with` block, while the swarm context is
    # active (presumably its endpoints are torn down when the block exits).
    client = AsyncInferenceClient(model=llm_swarm.endpoint)
    tokenizer = AutoTokenizer.from_pretrained(isc.model)
    # NOTE(review): kept as in the original; the empty-string special tokens
    # look unusual — confirm they are intended for this model's tokenizer.
    tokenizer.add_special_tokens({"sep_token": "", "cls_token": "", "mask_token": "", "pad_token": "[PAD]"})

    async def process_text(task):
        """Generate one completion for a single user message.

        Args:
            task: Text used as the content of a single ``user`` chat turn.

        Returns:
            The result of ``client.text_generation`` for the chat-formatted
            prompt.
        """
        # Render the chat template to a string, since text_generation takes a
        # raw prompt rather than a list of messages.
        prompt = tokenizer.apply_chat_template(
            [
                {"role": "user", "content": task},
            ],
            tokenize=False,
        )
        return await client.text_generation(
            prompt=prompt,
            max_new_tokens=1096,
        )

    async def main():
        """Generate a completion for every ``input`` row and print the results."""
        # Created inside the running event loop: on Python < 3.10 a Semaphore
        # constructed before asyncio.run() is bound to a different loop.
        semaphore = asyncio.Semaphore(max_parallel_requests)

        async def bounded_process_text(task):
            # Cap in-flight requests at max_parallel_requests; without this,
            # every row would be sent to the endpoint at the same time.
            async with semaphore:
                return await process_text(task)

        tasks = dataset["input"]
        results = await tqdm_asyncio.gather(*(bounded_process_text(task) for task in tasks))
        df = pd.DataFrame({"Task": tasks, "Completion": results})
        print(df)

    asyncio.run(main())
# (GitHub gist page footer) Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.