# Adapted from https://github.com/borisdayma/dalle-mini/blob/main/tools/train/train.py
import random
from dataclasses import dataclass, field
from functools import partial
from typing import Optional
import jax
import jax.numpy as jnp
import numpy as np
from datasets import Dataset, load_dataset


@dataclass
class Dataset:
    dataset_repo_or_path: str
    resolution: int
    batch_size: int
    streaming: bool = True
    image_column: str = "image"
    conditioning_image_column: str = "conditioning_image"
    caption_column: str = "caption"
    max_train_samples: Optional[int] = None
    preprocessing_num_workers: Optional[int] = None
    overwrite_cache: bool = False
    seed_dataset: Optional[int] = None
    train_dataset: Dataset = field(init=False)
    rng_dataset: jnp.ndarray = field(init=False)
    multi_hosts: bool = field(init=False)

    def __post_init__(self):
        if self.seed_dataset is None:
            # create a random seed
            self.seed_dataset = random.randint(0, 2**32 - 1)
        # set numpy rng
        self.np_rng = np.random.default_rng(self.seed_dataset)
        self.multi_hosts = jax.process_count() > 1
        # load dataset
        dataset = load_dataset(
            self.dataset_repo_or_path,
            streaming=self.streaming,
        )
        if "train" not in dataset:
            raise ValueError("Training requires a training dataset")
        self.train_dataset = dataset["train"]
        if self.max_train_samples is not None:
            self.train_dataset = (
                self.train_dataset.take(self.max_train_samples)
                if self.streaming
                else self.train_dataset.select(range(self.max_train_samples))
            )

    def preprocess(self, tokenizer):
        if self.streaming:
            # we need to shuffle early in streaming mode
            if hasattr(self, "train_dataset"):
                self.train_dataset = self.train_dataset.shuffle(
                    buffer_size=5000, seed=self.seed_dataset
                )
        else:
            self.rng_dataset = jax.random.PRNGKey(self.seed_dataset)
        # preprocess
        partial_preprocess_function = partial(
            preprocess_function,
            tokenizer=tokenizer,
            image_column=self.image_column,
            conditioning_column=self.conditioning_image_column,
            caption_column=self.caption_column,
            resolution=self.resolution,
        )
        ds = "train_dataset"
        if hasattr(self, ds):
            setattr(
                self,
                ds,
                (
                    getattr(self, ds).map(
                        partial_preprocess_function,
                        batched=True,
                        batch_size=self.batch_size,
                    )
                    if self.streaming
                    else getattr(self, ds).map(
                        partial_preprocess_function,
                        batched=True,
                        batch_size=self.batch_size,
                        num_proc=self.preprocessing_num_workers,
                        load_from_cache_file=not self.overwrite_cache,
                        desc="Preprocessing datasets",
                    )
                ),
            )

    def dataloader(self, split, batch_size, epoch=None):
        def _dataloader_datasets_non_streaming(
            dataset: Dataset,
            rng: jax.random.PRNGKey = None,
        ):
            """
            Returns batches of size `batch_size` from truncated `dataset`.
            Shuffles batches if `rng` is set.
            """
            steps_per_epoch = len(dataset) // batch_size
            if rng is not None:
                batch_idx = jax.random.permutation(rng, len(dataset))
            else:
                batch_idx = jnp.arange(len(dataset))
            batch_idx = batch_idx[
                : steps_per_epoch * batch_size
            ]  # Skip incomplete batch.
            batch_idx = batch_idx.reshape((steps_per_epoch, batch_size))
            for idx in batch_idx:
                batch = dataset[idx]
                batch = {k: jnp.array(v) for k, v in batch.items()}
                yield batch

        def _dataloader_datasets_streaming(
            dataset: Dataset,
            epoch: int,
        ):
            keys = ["pixel_values", "conditioning_pixel_values", "input_ids"]
            batch = {k: [] for k in keys}
            first_loop = True  # stop after one loop in some cases
            while (self.multi_hosts and split == "train") or first_loop:
                # in multi-host, we run forever (no epoch) as hosts need to stop
                # at the same time and training data may not be split equally.
                # For validation data we put the entire batch on each host and then
                # keep only the one specific to each host (could be improved but not necessary)
                if epoch is not None:
                    assert split == "train"
                    # reshuffle training data at each epoch
                    dataset.set_epoch(epoch)
                    epoch += 1
                for item in dataset:
                    for k in keys:
                        batch[k].append(item[k])
                    if len(batch[keys[0]]) == batch_size:
                        batch = {k: jnp.array(v) for k, v in batch.items()}
                        yield batch
                        batch = {k: [] for k in keys}
                first_loop = False

        if split == "train":  # Only this split is supported for now.
            ds = self.train_dataset
        else:
            raise ValueError(f"Unsupported split: {split}")
        if self.streaming:
            return _dataloader_datasets_streaming(ds, epoch)
        else:
            self.rng_dataset, input_rng = jax.random.split(self.rng_dataset)
            return _dataloader_datasets_non_streaming(ds, input_rng)

    @property
    def length(self):
        len_train_dataset, len_eval_dataset = None, None
        if self.streaming:
            # we don't know the length, let's just assume max_samples if defined
            if self.max_train_samples is not None:
                len_train_dataset = self.max_train_samples
            if getattr(self, "max_eval_samples", None) is not None:
                len_eval_dataset = self.max_eval_samples
        else:
            len_train_dataset = (
                len(self.train_dataset) if hasattr(self, "train_dataset") else None
            )
            len_eval_dataset = (
                len(self.eval_dataset) if hasattr(self, "eval_dataset") else None
            )
        return len_train_dataset, len_eval_dataset


def tokenize_captions(examples, tokenizer, caption_column):
    captions = []
    for caption in examples[caption_column]:
        captions.append(caption)
    inputs = tokenizer(
        captions,
        max_length=tokenizer.model_max_length,
        padding="max_length",
        truncation=True,
        return_tensors="np",
    )
    return inputs.input_ids


def preprocess_image(image, resolution, rescale=False):
    image = np.array(image.convert("RGB").resize((resolution, resolution)))
    image = image / 255.0  # scale to [0, 1]
    if rescale:
        image = (2.0 * image) - 1.0  # rescale to [-1, 1]
    return image


def preprocess_function(examples, tokenizer, image_column, conditioning_column, caption_column, resolution):
    images = [
        preprocess_image(image, resolution, rescale=True)
        for image in examples[image_column]
    ]
    conditioning_images = [
        preprocess_image(image, resolution, rescale=False)
        for image in examples[conditioning_column]
    ]
    examples["pixel_values"] = images
    examples["conditioning_pixel_values"] = conditioning_images
    examples["input_ids"] = tokenize_captions(examples, tokenizer, caption_column)
    return examples
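

# --- Usage sketch (illustrative addition, not part of the adapted script) ---
# A minimal example of how this Dataset class might be driven end to end:
# build the dataset, preprocess it with a tokenizer, and pull one batch from
# the streaming dataloader. The tokenizer checkpoint, the dataset repo name,
# and the hyperparameters below are assumptions for illustration only; the
# dataset repo is assumed to expose "image", "conditioning_image", and
# "caption" columns matching the defaults above.
if __name__ == "__main__":
    from transformers import CLIPTokenizer  # assumed tokenizer; any HF tokenizer works

    tokenizer = CLIPTokenizer.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="tokenizer"
    )
    dataset = Dataset(
        dataset_repo_or_path="your-username/controlnet-dataset",  # hypothetical repo
        resolution=256,
        batch_size=4,
        streaming=True,
        max_train_samples=16,
    )
    dataset.preprocess(tokenizer=tokenizer)
    train_loader = dataset.dataloader(split="train", batch_size=4, epoch=0)
    batch = next(train_loader)
    for name, value in batch.items():
        # e.g. pixel_values (4, 256, 256, 3), conditioning_pixel_values (4, 256, 256, 3), input_ids (4, 77)
        print(name, value.shape)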