@fancyerii
Created September 21, 2023 11:44
diff --git a/src/llama_recipes/configs/datasets.py b/src/llama_recipes/configs/datasets.py
index 62230e4..3e4c76f 100644
--- a/src/llama_recipes/configs/datasets.py
+++ b/src/llama_recipes/configs/datasets.py
@@ -15,8 +15,8 @@ class samsum_dataset:
@dataclass
class grammar_dataset:
dataset: str = "grammar_dataset"
- train_split: str = "src/llama_recipes/datasets/grammar_dataset/gtrain_10k.csv"
- test_split: str = "src/llama_recipes/datasets/grammar_dataset/grammar_validation.csv"
+ train_split: str = "src/llama_recipes/datasets2/grammar_dataset/gtrain_10k.csv"
+ test_split: str = "src/llama_recipes/datasets2/grammar_dataset/grammar_validation.csv"
input_length: int = 2048
@@ -25,7 +25,7 @@ class alpaca_dataset:
dataset: str = "alpaca_dataset"
train_split: str = "train"
test_split: str = "val"
- data_path: str = "src/llama_recipes/datasets/alpaca_data.json"
+ data_path: str = "src/llama_recipes/datasets2/alpaca_data.json"
@dataclass
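
Note: taken together, the config hunks above and the import hunk at the end of this diff repoint llama-recipes from the bundled `llama_recipes.datasets` package to a renamed `llama_recipes.datasets2` package; presumably the deleted files below reappear under that new name in a part of the change not shown in this gist. A minimal sketch of how these dataclass configs are consumed after the patch (the override is my own illustration, not part of the diff):

from llama_recipes.configs.datasets import alpaca_dataset, grammar_dataset

cfg = alpaca_dataset()
print(cfg.data_path)      # -> src/llama_recipes/datasets2/alpaca_data.json after this patch

# fields are plain dataclass attributes, so they can be overridden per run,
# e.g. to point the grammar task at a custom CSV (hypothetical file name):
gcfg = grammar_dataset(train_split="my_grammar_train.csv")
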
diff --git a/src/llama_recipes/datasets/__init__.py b/src/llama_recipes/datasets/__init__.py
deleted file mode 100644
index 57d2376..0000000
--- a/src/llama_recipes/datasets/__init__.py
+++ /dev/null
@@ -1,6 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-
-from llama_recipes.datasets.grammar_dataset.grammar_dataset import get_dataset as get_grammar_dataset
-from llama_recipes.datasets.alpaca_dataset import InstructionDataset as get_alpaca_dataset
-from llama_recipes.datasets.samsum_dataset import get_preprocessed_samsum as get_samsum_dataset
\ No newline at end of file
diff --git a/src/llama_recipes/datasets/alpaca_dataset.py b/src/llama_recipes/datasets/alpaca_dataset.py
deleted file mode 100644
index 091aef9..0000000
--- a/src/llama_recipes/datasets/alpaca_dataset.py
+++ /dev/null
@@ -1,78 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-
-# For dataset details visit: https://crfm.stanford.edu/2023/03/13/alpaca.html
-
-import copy
-import json
-
-import torch
-from torch.utils.data import Dataset
-
-
-PROMPT_DICT = {
- "prompt_input": (
- "Below is an instruction that describes a task, paired with an input that provides further context. "
- "Write a response that appropriately completes the request.\n\n"
- "### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
- ),
- "prompt_no_input": (
- "Below is an instruction that describes a task. "
- "Write a response that appropriately completes the request.\n\n"
- "### Instruction:\n{instruction}\n\n### Response:"
- ),
-}
-
-class InstructionDataset(Dataset):
- def __init__(self, dataset_config, tokenizer, partition="train", max_words=30):
- self.ann = json.load(open(dataset_config.data_path))
- if partition == "train":
- self.ann = self.ann
- else:
- self.ann = self.ann[:200]
-
- self.max_words = max_words
- # tokenizer = Tokenizer(model_path=model_path + "./tokenizer.model")
- self.tokenizer = tokenizer
- # self.tokenizer1 = tokenizer
-
- def __len__(self):
- return len(self.ann)
-
- def __getitem__(self, index):
- IGNORE_INDEX = -100 # The default setting in CrossEntropyLoss
-
-
- ann = self.ann[index]
- if ann.get("input", "") == "":
- prompt = PROMPT_DICT["prompt_no_input"].format_map(ann)
- else:
- prompt = PROMPT_DICT["prompt_input"].format_map(ann)
- example = prompt + ann["output"]
- prompt = torch.tensor(
- self.tokenizer.encode(prompt), dtype=torch.int64
- )
- example = self.tokenizer.encode(example)
- example.append(self.tokenizer.eos_token_id)
- example = torch.tensor(
- example, dtype=torch.int64
- )
- padding = self.max_words - example.shape[0]
- if padding > 0:
- example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1))
- elif padding < 0:
- example = example[: self.max_words]
- labels = copy.deepcopy(example)
- labels[: len(prompt)] = -1
- example_mask = example.ge(0)
- label_mask = labels.ge(0)
- example[~example_mask] = 0
- labels[~label_mask] = IGNORE_INDEX
- example_mask = example_mask.float()
- label_mask = label_mask.float()
-
- return {
- "input_ids": example,
- "labels": labels,
- "attention_mask":example_mask,
- }
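
For readers skimming the deleted InstructionDataset above, the key detail is the label masking: prompt tokens and padding are excluded from the loss. Below is a self-contained sketch of that scheme using a stand-in whitespace tokenizer (my own simplification; the real code uses the Llama tokenizer):

import torch

IGNORE_INDEX = -100                      # default ignore_index of torch.nn.CrossEntropyLoss
MAX_WORDS = 16

def fake_encode(text):
    # stand-in for tokenizer.encode(); deterministic toy ids only
    return [hash(w) % 1000 + 3 for w in text.split()]

prompt = "### Instruction:\nfix the grammar\n\n### Response:"
output = "The grammar is fixed."

prompt_ids = fake_encode(prompt)
example = torch.tensor(fake_encode(prompt + " " + output) + [2], dtype=torch.int64)  # 2 as a toy eos id

padding = MAX_WORDS - example.shape[0]
if padding > 0:
    example = torch.cat((example, torch.zeros(padding, dtype=torch.int64) - 1))   # pad with -1
else:
    example = example[:MAX_WORDS]

labels = example.clone()
labels[: len(prompt_ids)] = -1            # hide the prompt span from the loss
attention_mask = example.ge(0).float()    # padding (-1) positions get mask 0
example[~example.ge(0)] = 0               # padding becomes token id 0 in input_ids
labels[~labels.ge(0)] = IGNORE_INDEX      # prompt + padding are ignored by the loss

print(example, labels, attention_mask, sep="\n")
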
diff --git a/src/llama_recipes/datasets/grammar_dataset/__init__.py b/src/llama_recipes/datasets/grammar_dataset/__init__.py
deleted file mode 100644
index b193f67..0000000
--- a/src/llama_recipes/datasets/grammar_dataset/__init__.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-
diff --git a/src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py b/src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py
deleted file mode 100644
index 47383c4..0000000
--- a/src/llama_recipes/datasets/grammar_dataset/grammar_dataset.py
+++ /dev/null
@@ -1,85 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-
-# For dataset details visit: https://huggingface.co/datasets/jfleg
-# For download and preparation see: recipes/ft_datasets/grammar_dataset/grammar_dataset_process.ipynb
-
-
-from datasets import load_dataset
-from pathlib import Path
-
-from torch.utils.data import Dataset
-
-from llama_recipes.datasets.utils import ConcatDataset
-
-
-class grammar(Dataset):
- def __init__(
- self,
- tokenizer,
- csv_name=None,
- ):
-
- try:
- self.dataset = load_dataset(
- "csv",
- data_files={"train": [csv_name]}, # "eval": "grammar_validation.csv"},
- delimiter=",",
- )
- except Exception as e:
- print("Loading of grammar dataset failed! Please see recipes/ft_datasets/grammar_dataset/grammar_dataset_process.ipynb for details on how to download the dataset.")
- raise e
-
- # self.dataset = load_dataset("wikihow", "all", data_dir="data/", split=type_path)
- # if num_samples:
- # self.dataset = self.dataset.select(list(range(0, num_samples)))
- self.tokenizer = tokenizer
- self.print_text = False # print_text
-
- def __len__(self):
- return self.dataset["train"].shape[0]
-
- def convert_to_features(self, example_batch):
-
- # Create prompt and tokenize contexts and questions
-
- if self.print_text:
- print("Input Text: ", self.clean_text(example_batch["text"]))
-
- input_ = example_batch["input"]
- target_ = example_batch["target"]
-
- prompt = f"Correct this to standard English: {input_}\n---\nCorrected: {target_}"
- sample = self.tokenizer(prompt)
-
- return sample
-
- def __getitem__(self, index):
- sample = self.convert_to_features(self.dataset["train"][index])
- source_ids = sample["input_ids"]
-
- src_mask = sample["attention_mask"]
-
- return {
- "input_ids": source_ids,
- "attention_mask": src_mask,
- "labels": source_ids.copy(),
- }
-
-
-def get_dataset(
- dataset_config, tokenizer, csv_name=None
-):
- """cover function for handling loading the working dataset"""
- """dataset loading"""
- if csv_name is None:
- currPath = Path.cwd() / "datasets_grammar" / "grammar_train.csv"
- print(f"Loading dataset {currPath}")
- csv_name = str(currPath)
- dataset = grammar(
- tokenizer=tokenizer,
- csv_name=csv_name,
- )
-
- return ConcatDataset(dataset, chunk_size=dataset_config.input_length)
-
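
The deleted grammar dataset boils down to: load a two-column CSV with the Hugging Face datasets CSV loader and wrap each row in a fixed correction prompt. A short sketch of that path, assuming a local gtrain_10k.csv produced by the notebook below (file name taken from the configs above):

from datasets import load_dataset

ds = load_dataset("csv", data_files={"train": ["gtrain_10k.csv"]}, delimiter=",")
row = ds["train"][0]
prompt = f"Correct this to standard English: {row['input']}\n---\nCorrected: {row['target']}"
print(prompt)   # this string is what gets tokenized and chunked by ConcatDataset
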
diff --git a/src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb b/src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb
deleted file mode 100644
index ccbddca..0000000
--- a/src/llama_recipes/datasets/grammar_dataset/grammar_dataset_process.ipynb
+++ /dev/null
@@ -1,463 +0,0 @@
-{
- "cells": [
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Copyright (c) Meta Platforms, Inc. and affiliates.\n",
- "This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.\n",
- "\n",
- "Use this notebook to pull in datasets and apply pre-processing. Most grammar datasets unfortunately require preprocessing before being usable in training. (example - jfleg has 4 targets per input, so we have to rematch as 1:1 pairings) "
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "metadata": {},
- "outputs": [],
-
- "source": [
- "import csv\n",
- "from datasets import load_metric, load_dataset\n",
- "from pathlib import Path"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "metadata": {},
- "outputs": [],
- "source": [
- "list_replacements = [\n",
- " (\" .\", \".\"), \n",
- " (\" ,\", \",\"),\n",
- " (\" '\", \"'\"),\n",
- " (\" ?\", \"?\"),\n",
- " (\" !\", \"!\"),\n",
- " (\" :\", \"!\"),\n",
- " (\" ;\", \"!\"),\n",
- " (\" n't\", \"n't\"),\n",
- " (\" v\", \"n't\"),\n",
- " (\"2 0 0 6\", \"2006\"),\n",
- " (\"5 5\", \"55\"),\n",
- " (\"4 0 0\", \"400\"),\n",
- " (\"1 7-5 0\", \"1750\"),\n",
- " (\"2 0 %\", \"20%\"),\n",
- " (\"5 0\", \"50\"),\n",
- " (\"1 2\", \"12\"),\n",
- " (\"1 0\", \"10\"),\n",
- " ('\" ballast water', '\"ballast water')\n",
- " ]"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 3,
- "metadata": {},
- "outputs": [],
- "source": [
- "def correct_spacing(item):\n",
- " \"\"\" we iterate through the list of all replacements per each item in dataset\"\"\"\n",
- " for fix in list_replacements:\n",
- " item = item.replace(fix[0], fix[1])\n",
- " return item\n",
- "\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "metadata": {},
- "outputs": [],
- "source": [
- "def generate_csv(csv_path, dataset):\n",
- " \"\"\" apply spacing corrections and save out matched pairs to csv file as dataset\"\"\"\n",
- " with open(csv_path, 'w', newline='') as csvfile:\n",
- " writer = csv.writer(csvfile)\n",
- " writer.writerow([\"input\", \"target\"])\n",
- " for case in dataset:\n",
- " \t # Adding the t5 task indication prefix to input \n",
-
- " input_text = case[\"sentence\"]\n",
-
- " input_text = correct_spacing(input_text)\n",
- "\n",
- " for correction in case[\"corrections\"]:\n",
- " correction = correct_spacing(correction)\n",
- " # a few of the cases contain blank strings. \n",
- " if input_text and correction:\n",
- " writer.writerow([input_text, correction])"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "In Jfleg - validation will be used as 'train', test will be 'validation'"
- ]
- },
- {
- "cell_type": "code",
-
- "execution_count": 5,
-
- "metadata": {},
- "outputs": [
- {
- "name": "stderr",
- "output_type": "stream",
- "text": [
-
- "Found cached dataset jfleg (/data/home/mreso/.cache/huggingface/datasets/jfleg/default/1.0.0/ed4ab2367351fe31949f48849ae6732b164f0d5ea6bb5d4357ff4293ac89511b)\n",
- "Found cached dataset jfleg (/data/home/mreso/.cache/huggingface/datasets/jfleg/default/1.0.0/ed4ab2367351fe31949f48849ae6732b164f0d5ea6bb5d4357ff4293ac89511b)\n"
-
- ]
- }
- ],
- "source": [
- "train_dataset = load_dataset(\"jfleg\", split='validation[:]') \n",
- "eval_dataset = load_dataset(\"jfleg\", split='test[:]')\n"
- ]
- },
- {
- "cell_type": "code",
-
- "execution_count": 6,
-
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Dataset({\n",
- " features: ['sentence', 'corrections'],\n",
- " num_rows: 755\n",
- "})\n",
- "Dataset({\n",
- " features: ['sentence', 'corrections'],\n",
- " num_rows: 748\n",
- "})\n"
- ]
- }
- ],
- "source": [
- "print(train_dataset)\n",
- "print(eval_dataset)\n"
- ]
- },
- {
- "cell_type": "code",
-
- "execution_count": 7,
-
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Students can focus on only a few subjects they are intwerested in and they will become an experts in those areas . \n",
- "['Students can focus on only a few subjects they are interested in and they will become experts in those areas . ', 'Students can focus on only a few subjects they are interested in and they will become experts in those areas . ', 'Students can focus on only a few subjects they are interested in and they will become an expert in those areas . ', 'Students can focus on only a few subjects they are interested in and they will become an expert in those areas . ']\n"
- ]
- }
- ],
- "source": [
- "print(train_dataset['sentence'][22])\n",
- "print(train_dataset['corrections'][22])"
- ]
- },
- {
- "cell_type": "code",
-
- "execution_count": 8,
-
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "'Students can focus on only a few subjects they are intwerested in and they will become an experts in those areas. '"
- ]
- },
-
- "execution_count": 8,
-
- "metadata": {},
- "output_type": "execute_result"
- }
- ],
- "source": [
- "clean22 = correct_spacing(train_dataset['sentence'][22])\n",
- "clean22"
- ]
- },
- {
- "cell_type": "code",
-
- "execution_count": 9,
-
- "metadata": {},
- "outputs": [],
- "source": [
- "jfleg_dir = Path.cwd()/'jfleg_dataset' # if you only use 'jfleg', hf will try and use that and complain\n",
- "jfleg_dir.mkdir(parents=True,exist_ok=True)\n",
- "c4_dir = Path.cwd()/'c4_dataset'\n",
- "c4_dir.mkdir(parents=True,exist_ok=True)"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Process Jfleg data "
- ]
- },
- {
- "cell_type": "code",
-
- "execution_count": 10,
-
- "metadata": {},
- "outputs": [],
- "source": [
- "j_train_file = jfleg_dir/'jtrain.csv'\n",
- "j_eval_file = jfleg_dir/'jeval.csv'"
- ]
- },
- {
- "cell_type": "code",
-
- "execution_count": 11,
-
- "metadata": {},
- "outputs": [],
- "source": [
- "generate_csv(j_train_file, train_dataset)"
- ]
- },
- {
- "cell_type": "code",
-
- "execution_count": 12,
-
- "metadata": {},
- "outputs": [],
- "source": [
- "generate_csv(j_eval_file, eval_dataset)"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Process C4_200M (!) - we'll pull 10K to start"
- ]
- },
- {
- "cell_type": "code",
-
- "execution_count": 13,
-
- "metadata": {},
- "outputs": [],
- "source": [
- "c4_dataset = load_dataset(\"liweili/c4_200m\", streaming = True)"
- ]
- },
- {
- "cell_type": "code",
-
- "execution_count": 14,
-
- "metadata": {},
- "outputs": [],
- "source": [
- "iterator = iter(c4_dataset['train'])"
- ]
- },
- {
- "cell_type": "code",
-
- "execution_count": 15,
-
- "metadata": {},
- "outputs": [],
- "source": [
- "def c4_generate_csv(csv_path, iterator, num_examples):\n",
- " with open(csv_path, 'w', newline='') as csvfile:\n",
- " writer = csv.writer(csvfile)\n",
- " writer.writerow([\"input\", \"target\"])\n",
- " for i in range(0,num_examples):\n",
- " data = next(iterator)\n",
-
- " input_text = data[\"input\"]\n",
-
- " input_text = correct_spacing(input_text)\n",
- " correction = correct_spacing(data[\"output\"])\n",
- " if input_text and correction:\n",
- " writer.writerow([input_text, correction])"
- ]
- },
- {
- "cell_type": "code",
-
- "execution_count": 16,
-
- "metadata": {},
- "outputs": [],
- "source": [
- "c4_dir = Path.cwd()/'c4_dataset'\n",
- "c4_dir.mkdir(parents=True,exist_ok=True)"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "You can modify the following to make the csv file with desired number of instances, here we go for 10k to make a quick test"
- ]
- },
- {
- "cell_type": "code",
-
- "execution_count": 17,
-
- "metadata": {},
- "outputs": [],
- "source": [
- "c4_filename = c4_dir/'c4train_10k.csv'"
- ]
- },
- {
- "cell_type": "code",
-
- "execution_count": 18,
-
- "metadata": {},
- "outputs": [],
- "source": [
- "c4_generate_csv(c4_filename, iterator, num_examples=10000)"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "metadata": {},
- "source": [
- "Create a single training file by combining jtrain and c4train"
- ]
- },
- {
- "cell_type": "code",
-
- "execution_count": 19,
-
- "metadata": {},
- "outputs": [],
- "source": [
- "merge_list = [j_train_file, c4_filename, ]"
- ]
- },
- {
- "cell_type": "code",
-
- "execution_count": 20,
-
- "metadata": {},
- "outputs": [],
- "source": [
- "import pandas as pd"
- ]
- },
- {
- "cell_type": "code",
-
- "execution_count": 21,
-
- "metadata": {},
- "outputs": [],
- "source": [
- "combined_csv = pd.concat([pd.read_csv(fn) for fn in merge_list])\n"
- ]
- },
- {
- "cell_type": "code",
-
- "execution_count": 22,
-
- "metadata": {},
- "outputs": [],
- "source": [
- "merged_name = \"gtrain_10k.csv\""
- ]
- },
- {
- "cell_type": "code",
-
- "execution_count": 23,
-
- "metadata": {},
- "outputs": [],
- "source": [
- "combined_csv.to_csv(merged_name, index=False, encoding = 'utf-8-sig', )"
- ]
- },
- {
- "cell_type": "code",
-
- "execution_count": 24,
-
- "metadata": {},
- "outputs": [],
- "source": [
- "eval_name = \"grammar_validation.csv\""
- ]
-
- },
- {
- "cell_type": "code",
- "execution_count": 25,
- "metadata": {},
- "outputs": [],
- "source": [
- "eval_csv = pd.read_csv(j_eval_file)\n",
- "eval_csv.to_csv(eval_name, index=False, encoding = 'utf-8-sig', )"
- ]
-
- }
- ],
- "metadata": {
- "interpreter": {
- "hash": "5b2c14c5f2a3b21e6c2412c8196f5145870350e81c0b737cae3e5c60eb1e1eac"
- },
- "kernelspec": {
-
- "display_name": "Python 3 (ipykernel)",
-
- "language": "python",
- "name": "python3"
- },
- "language_info": {
- "codemirror_mode": {
- "name": "ipython",
- "version": 3
- },
- "file_extension": ".py",
- "mimetype": "text/x-python",
- "name": "python",
- "nbconvert_exporter": "python",
- "pygments_lexer": "ipython3",
- "version": "3.10.11"
-
- }
- },
- "nbformat": 4,
- "nbformat_minor": 4
-
-}
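
As a quick orientation, the deleted notebook above does four things: pull jfleg, fix its tokenized spacing, stream 10k pairs from liweili/c4_200m, and merge everything into gtrain_10k.csv plus grammar_validation.csv. A condensed script-form sketch (my summary; the notebook's correct_spacing cleanup is omitted for brevity):

import csv
import pandas as pd
from datasets import load_dataset

def write_pairs(path, rows):
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["input", "target"])
        writer.writerows(rows)

# jfleg: its validation split is used as "train", its test split as "validation"
jfleg_train = load_dataset("jfleg", split="validation")
jfleg_eval = load_dataset("jfleg", split="test")
write_pairs("jtrain.csv", [(ex["sentence"], c) for ex in jfleg_train for c in ex["corrections"] if ex["sentence"] and c])
write_pairs("jeval.csv", [(ex["sentence"], c) for ex in jfleg_eval for c in ex["corrections"] if ex["sentence"] and c])

# stream 10k (input, output) pairs from c4_200m
c4 = iter(load_dataset("liweili/c4_200m", streaming=True)["train"])
write_pairs("c4train_10k.csv", [(d["input"], d["output"]) for d in (next(c4) for _ in range(10_000))])

# merge jfleg train with the c4 sample; keep jfleg test as the validation file
pd.concat([pd.read_csv("jtrain.csv"), pd.read_csv("c4train_10k.csv")]).to_csv("gtrain_10k.csv", index=False, encoding="utf-8-sig")
pd.read_csv("jeval.csv").to_csv("grammar_validation.csv", index=False, encoding="utf-8-sig")
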
diff --git a/src/llama_recipes/datasets/samsum_dataset.py b/src/llama_recipes/datasets/samsum_dataset.py
deleted file mode 100644
index fd91782..0000000
--- a/src/llama_recipes/datasets/samsum_dataset.py
+++ /dev/null
@@ -1,33 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-
-# For dataset details visit: https://huggingface.co/datasets/samsum
-
-import datasets
-
-from llama_recipes.datasets.utils import Concatenator
-
-def get_preprocessed_samsum(dataset_config, tokenizer, split):
- dataset = datasets.load_dataset("samsum", split=split)
-
- prompt = (
- f"Summarize this dialog:\n{{dialog}}\n---\nSummary:\n{{summary}}{{eos_token}}"
- )
-
- def apply_prompt_template(sample):
- return {
- "text": prompt.format(
- dialog=sample["dialogue"],
- summary=sample["summary"],
- eos_token=tokenizer.eos_token,
- )
- }
-
- dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
-
- dataset = dataset.map(
- lambda sample: tokenizer(sample["text"]),
- batched=True,
- remove_columns=list(dataset.features),
- ).map(Concatenator(), batched=True)
- return dataset
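
The deleted samsum loader is a thin wrapper: fill a fixed summarization template per record, tokenize, and pack with Concatenator. Here is the template applied to one hand-written record (the record and eos string are mine, for illustration):

sample = {
    "dialogue": "Amanda: I baked cookies. Do you want some?\nJerry: Sure!",
    "summary": "Amanda baked cookies and will bring Jerry some.",
}
eos_token = "</s>"   # placeholder; the real code uses tokenizer.eos_token

text = f"Summarize this dialog:\n{sample['dialogue']}\n---\nSummary:\n{sample['summary']}{eos_token}"
print(text)
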
diff --git a/src/llama_recipes/datasets/utils.py b/src/llama_recipes/datasets/utils.py
deleted file mode 100644
index 0a11d8c..0000000
--- a/src/llama_recipes/datasets/utils.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
-
-from tqdm import tqdm
-from itertools import chain
-
-from torch.utils.data import Dataset
-
-class Concatenator(object):
- def __init__(self, chunk_size=2048):
- self.chunk_size=chunk_size
- self.residual = {"input_ids": [], "attention_mask": []}
-
- def __call__(self, batch):
- concatenated_samples = {
- k: v + list(chain(*batch[k])) for k, v in self.residual.items()
- }
-
- total_length = len(concatenated_samples[list(concatenated_samples.keys())[0]])
-
- if total_length >= self.chunk_size:
- chunk_num = total_length // self.chunk_size
- result = {
- k: [
- v[i : i + self.chunk_size]
- for i in range(0, chunk_num * self.chunk_size, self.chunk_size)
- ]
- for k, v in concatenated_samples.items()
- }
- self.residual = {
- k: v[(chunk_num * self.chunk_size) :]
- for k, v in concatenated_samples.items()
- }
- else:
- result = concatenated_samples
- self.residual = {k: [] for k in concatenated_samples.keys()}
-
- result["labels"] = result["input_ids"].copy()
-
- return result
-
-class ConcatDataset(Dataset):
- def __init__(self, dataset, chunk_size=4096):
- self.dataset = dataset
- self.chunk_size = chunk_size
-
- self.samples = []
-
- buffer = {
- "input_ids": [],
- "attention_mask": [],
- "labels": [],
- }
-
- for sample in tqdm(self.dataset, desc="Preprocessing dataset", dynamic_ncols=True):
- buffer = {k: v + sample[k] for k,v in buffer.items()}
-
- while len(next(iter(buffer.values()))) > self.chunk_size:
- self.samples.append({k: v[:self.chunk_size] for k,v in buffer.items()})
- buffer = {k: v[self.chunk_size:] for k,v in buffer.items()}
-
- def __getitem__(self, idx):
- return self.samples[idx]
-
- def __len__(self):
- return len(self.samples)
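
To see what the deleted Concatenator does when used with datasets.map(batched=True): it concatenates tokenized samples, emits as many fixed-size chunks as fit, and carries the remainder into the next batch via self.residual. A toy call, assuming the class now lives at llama_recipes.datasets2.utils (an assumption on my part, consistent with the import change below):

from llama_recipes.datasets2.utils import Concatenator   # hypothetical new location

packer = Concatenator(chunk_size=8)
batch = {
    "input_ids": [[1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11]],
    "attention_mask": [[1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1]],
}
packed = packer(batch)
print(packed["input_ids"])   # [[1, 2, 3, 4, 5, 6, 7, 8]] -- one full chunk of 8 ids
print(packer.residual)       # the trailing 3 ids wait for the next batch
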
diff --git a/src/llama_recipes/utils/dataset_utils.py b/src/llama_recipes/utils/dataset_utils.py
index 6d5f02c..18955b6 100644
--- a/src/llama_recipes/utils/dataset_utils.py
+++ b/src/llama_recipes/utils/dataset_utils.py
@@ -7,7 +7,7 @@ from pathlib import Path
import torch
-from llama_recipes.datasets import (
+from llama_recipes.datasets2 import (
get_grammar_dataset,
get_alpaca_dataset,
get_samsum_dataset,