Skip to content

Instantly share code, notes, and snippets.

@alvarobartt
Created December 21, 2023 09:30
Show Gist options
  • Save alvarobartt/d08888dd2660b6763421dd6b1142127c to your computer and use it in GitHub Desktop.
Save alvarobartt/d08888dd2660b6763421dd6b1142127c to your computer and use it in GitHub Desktop.
Unpacks a 🤗`trl.ConstantLengthDataset` to estimate how many steps one epoch amounts to, to avoid wrong epoch estimations
# Usage:
# python steps-unpacked-constant-length-dataset.py \
# --dataset-path "argilla/ultrafeedback-binarized-preferences-cleaned" \
# --hf-tokenizer "alignment-handbook/zephyr-7b-sft-full" \
# --gradient-accumulation-steps 2 \
# --per-eval-batch-size 32 \
# --num-devices 8 \
# --max-seq-length 2048 \
# --num-of-sequences 1024 \
# --chars-per-token 3.6
from datasets import load_dataset
from transformers import AutoTokenizer
from trl.trainer.utils import ConstantLengthDataset
import argparse
def get_args() -> argparse.Namespace:
    """Parse the CLI arguments used to build the `ConstantLengthDataset`.

    Returns:
        argparse.Namespace: the dataset/tokenizer identifiers plus the
        packing and batching hyper-parameters (all optional, default `None`).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--dataset-path", type=str)
    parser.add_argument("--hf-tokenizer", type=str)
    # Bug fix: the calls below used the undefined name `parsed` instead of
    # `parser`, which raised `NameError` before any of these arguments were
    # registered.
    parser.add_argument("--gradient-accumulation-steps", type=int)
    parser.add_argument("--per-eval-batch-size", type=int)
    parser.add_argument("--num-devices", type=int)
    parser.add_argument("--max-seq-length", type=int)
    parser.add_argument("--num-of-sequences", type=int)
    parser.add_argument("--chars-per-token", type=float)
    return parser.parse_args()
def main() -> None:
    """Iterate once over a packed `trl.ConstantLengthDataset` to count how
    many packed sequences — and therefore optimizer steps — one training
    epoch amounts to, and print the estimate."""
    args = get_args()
    print(f"Args: {args}")

    # `split="train"` already selects the split, so this is a `Dataset`,
    # not a `DatasetDict`.
    dataset = load_dataset(args.dataset_path, split="train")
    tokenizer = AutoTokenizer.from_pretrained(args.hf_tokenizer)

    def formatting_func(example: dict) -> str:
        # Bug fix: the original stub `def formatting_func(...) -> ...:` was a
        # syntax error (`...` is not a valid parameter list). Adapt this to
        # your dataset's schema; the default mirrors `dataset_text_field`.
        return example["text"]

    dataset = ConstantLengthDataset(
        tokenizer,
        # Bug fix: was `dataset["train"]`, but the split was already selected
        # above, so indexing by split name would raise a KeyError.
        dataset,
        dataset_text_field="text",
        formatting_func=formatting_func,
        seq_length=args.max_seq_length,
        infinite=False,
        num_of_sequences=args.num_of_sequences,
        chars_per_token=args.chars_per_token,
        eos_token_id=tokenizer.eos_token_id,
    )

    # Count the packed sequences yielded for one full pass over the data.
    counter = 0
    for _ in dataset:
        counter += 1

    # Bug fix: the script registers `--per-eval-batch-size`, so the parsed
    # attribute is `per_eval_batch_size` (`args.per_device_batch_size` raised
    # AttributeError). NOTE(review): the formula itself is kept as in the
    # original — confirm dividing by `len(dataset)` matches the steps
    # estimate you expect for your trainer configuration.
    steps = counter / (
        len(dataset)
        / args.gradient_accumulation_steps
        / args.per_eval_batch_size
        / args.num_devices
    )
    print(f"1 training epoch equals to {steps} steps")


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment