# Adapted from https://github.com/borisdayma/dalle-mini/blob/main/tools/train/train.py
import random
from dataclasses import dataclass, field
from functools import partial
from typing import Optional

import jax
import jax.numpy as jnp
import numpy as np

# Alias the Hugging Face Dataset so the dataclass below can reuse the name
# without shadowing the import in type annotations.
from datasets import Dataset as HFDataset, load_dataset
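
# This module implements a small data pipeline for JAX training: it loads a
# Hugging Face dataset containing images, conditioning images, and captions,
# resizes and normalizes the images, tokenizes the captions, and yields
# device-ready batches in both streaming and non-streaming modes.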


@dataclass
class Dataset:
    dataset_repo_or_path: str
    resolution: int
    batch_size: int
    streaming: bool = True
    image_column: str = "image"
    conditioning_image_column: str = "conditioning_image"
    caption_column: str = "caption"
    max_train_samples: Optional[int] = None
    preprocessing_num_workers: Optional[int] = None
    overwrite_cache: bool = False
    seed_dataset: Optional[int] = None
    train_dataset: HFDataset = field(init=False)
    rng_dataset: jnp.ndarray = field(init=False)
    multi_hosts: bool = field(init=False)

    def __post_init__(self):
        if self.seed_dataset is None:
            # Create a random seed if none was provided.
            self.seed_dataset = random.randint(0, 2**32 - 1)
        # Set up the numpy RNG.
        self.np_rng = np.random.default_rng(self.seed_dataset)
        self.multi_hosts = jax.process_count() > 1

        # Load the dataset.
        dataset = load_dataset(
            self.dataset_repo_or_path,
            streaming=self.streaming,
        )
        if "train" not in dataset:
            raise ValueError("Training requires a training dataset")
        self.train_dataset = dataset["train"]
        if self.max_train_samples is not None:
            self.train_dataset = (
                self.train_dataset.take(self.max_train_samples)
                if self.streaming
                else self.train_dataset.select(range(self.max_train_samples))
            )

    def preprocess(self, tokenizer):
        if self.streaming:
            # In streaming mode we need to shuffle early, before mapping.
            if hasattr(self, "train_dataset"):
                self.train_dataset = self.train_dataset.shuffle(
                    buffer_size=5000, seed=self.seed_dataset
                )
        else:
            self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)

        partial_preprocess_function = partial(
            preprocess_function,
            tokenizer=tokenizer,
            image_column=self.image_column,
            conditioning_column=self.conditioning_image_column,
            caption_column=self.caption_column,
            resolution=self.resolution,
        )

        if hasattr(self, "train_dataset"):
            if self.streaming:
                self.train_dataset = self.train_dataset.map(
                    partial_preprocess_function,
                    batched=True,
                    batch_size=self.batch_size,
                )
            else:
                self.train_dataset = self.train_dataset.map(
                    partial_preprocess_function,
                    batched=True,
                    batch_size=self.batch_size,
                    num_proc=self.preprocessing_num_workers,
                    load_from_cache_file=not self.overwrite_cache,
                    desc="Preprocessing datasets",
                )

    def dataloader(self, split, batch_size, epoch=None):
        def _dataloader_datasets_non_streaming(
            dataset: HFDataset,
            rng: jnp.ndarray = None,
        ):
            """
            Yields batches of size `batch_size` from a truncated `dataset`,
            shuffled if `rng` is set. The final incomplete batch is skipped.
            """
            steps_per_epoch = len(dataset) // batch_size

            if rng is not None:
                batch_idx = jax.random.permutation(rng, len(dataset))
            else:
                batch_idx = jnp.arange(len(dataset))

            batch_idx = batch_idx[: steps_per_epoch * batch_size]  # Skip incomplete batch.
            batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))

            for idx in batch_idx:
                batch = dataset[idx]
                batch = {k: jnp.array(v) for k, v in batch.items()}
                yield batch

        def _dataloader_datasets_streaming(
            dataset: HFDataset,
            epoch: int,
        ):
            keys = ["pixel_values", "conditioning_pixel_values", "input_ids"]
            batch = {k: [] for k in keys}
            first_loop = True  # stop after one loop in some cases
            while (self.multi_hosts and split == "train") or first_loop:
                # In multi-host setups we run forever (no epochs): all hosts need
                # to stop at the same time and the training data may not be split
                # equally between them. For validation data, each host would read
                # the entire batch and keep only its own shard (could be improved,
                # but not necessary here).
                if epoch is not None:
                    assert split == "train"
                    # Reshuffle the training data at each epoch.
                    dataset.set_epoch(epoch)
                    epoch += 1
                for item in dataset:
                    for k in keys:
                        batch[k].append(item[k])
                    if len(batch[keys[0]]) == batch_size:
                        batch = {k: jnp.array(v) for k, v in batch.items()}
                        yield batch
                        batch = {k: [] for k in keys}
                first_loop = False

        if split != "train":
            # Only the training split is supported for now.
            raise ValueError(f"Unsupported split: {split!r}")
        ds = self.train_dataset

        if self.streaming:
            return _dataloader_datasets_streaming(ds, epoch)
        else:
            self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
            return _dataloader_datasets_non_streaming(ds, input_rng)

    @property
    def length(self):
        len_train_dataset, len_eval_dataset = None, None
        if self.streaming:
            # The length of a streaming dataset is unknown, so fall back to
            # `max_train_samples` / `max_eval_samples` when they are defined.
            if self.max_train_samples is not None:
                len_train_dataset = self.max_train_samples
            if getattr(self, "max_eval_samples", None) is not None:
                len_eval_dataset = self.max_eval_samples
        else:
            len_train_dataset = (
                len(self.train_dataset) if hasattr(self, "train_dataset") else None
            )
            len_eval_dataset = (
                len(self.eval_dataset) if hasattr(self, "eval_dataset") else None
            )
        return len_train_dataset, len_eval_dataset


def tokenize_captions(examples, tokenizer, caption_column):
    captions = list(examples[caption_column])
    inputs = tokenizer(
        captions,
        max_length=tokenizer.model_max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )
    return inputs.input_ids


def preprocess_image(image, resolution, rescale=False):
    image = np.array(image.convert("RGB").resize((resolution, resolution)))
    image = image / 255.0  # scale to [0, 1]
    if rescale:
        image = (2.0 * image) - 1.0  # rescale to [-1, 1]
    return image


def preprocess_function(examples, tokenizer, image_column, conditioning_column, caption_column, resolution):
    # Target images are rescaled to [-1, 1]; conditioning images stay in [0, 1].
    images = [
        preprocess_image(image, resolution, rescale=True)
        for image in examples[image_column]
    ]
    conditioning_images = [
        preprocess_image(image, resolution, rescale=False)
        for image in examples[conditioning_column]
    ]
    examples["pixel_values"] = images
    examples["conditioning_pixel_values"] = conditioning_images
    examples["input_ids"] = tokenize_captions(examples, tokenizer, caption_column)
    return examples
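

# Example usage: a minimal sketch of how this pipeline might be driven.
# The dataset repo id below is a hypothetical placeholder, and the CLIP
# tokenizer checkpoint is an assumption; substitute whatever your training
# setup actually uses.
if __name__ == "__main__":
    from transformers import CLIPTokenizer

    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    dataset = Dataset(
        dataset_repo_or_path="user/controlnet-dataset",  # hypothetical repo id
        resolution=512,
        batch_size=4,
        streaming=True,
    )
    dataset.preprocess(tokenizer=tokenizer)
    train_loader = dataset.dataloader(split="train", batch_size=4, epoch=0)
    first_batch = next(iter(train_loader))
    print({k: v.shape for k, v in first_batch.items()})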