Skip to content

Instantly share code, notes, and snippets.

@8bitnand
Last active November 9, 2023 14:34
Show Gist options
  • Save 8bitnand/ade2d6d5e2a7dfd3193e1baaa7b32e1f to your computer and use it in GitHub Desktop.
Save 8bitnand/ade2d6d5e2a7dfd3193e1baaa7b32e1f to your computer and use it in GitHub Desktop.
Unconditional image generation - part 1: the diffusion. Changes from the original code.
# source https://huggingface.co/docs/diffusers/tutorials/basic_training
def load_pipline(config):
    """Load a pretrained DDPM pipeline and return its ``unet`` attribute.

    The checkpoint id and the cache directory are fixed inside this function.
    ``config`` is accepted so call sites keep working, but it is not read here.
    """
    checkpoint_id = "mrm8488/ddpm-ema-butterflies-128"
    cache_location = "models/pretrained"
    pretrained = DDPMPipeline.from_pretrained(checkpoint_id, cache_dir=cache_location)
    return pretrained.unet
def main():
    """Assemble the data, model, scheduler and optimizer, then call ``train_loop``.

    Uses names that are not defined in this block and must exist at module
    level: ``config``, ``transform_stc``, ``load_dataset``, ``DDPMScheduler``,
    ``get_cosine_schedule_with_warmup`` and ``train_loop``.
    """
    # Point the shared config at the dataset used by this variant of the script.
    config.dataset_name = "m1guelpf/nouns"

    # Data: training split, with the transform applied through set_transform.
    dataset = load_dataset(config.dataset_name, split="train")
    dataset.set_transform(transform_stc)
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=config.train_batch_size,
        shuffle=True,
    )

    # Model and noise scheduler: start from the pretrained network.
    unet = load_pipline(config)
    ddpm_scheduler = DDPMScheduler(num_train_timesteps=1000)

    # Optimizer plus cosine LR schedule sized to (batches per epoch * epochs).
    adamw = torch.optim.AdamW(unet.parameters(), lr=config.learning_rate)
    warmup_cosine = get_cosine_schedule_with_warmup(
        optimizer=adamw,
        num_warmup_steps=config.lr_warmup_steps,
        num_training_steps=(len(loader) * config.num_epochs),
    )

    train_loop(
        config=config,
        model=unet,
        noise_scheduler=ddpm_scheduler,
        optimizer=adamw,
        train_dataloader=loader,
        lr_scheduler=warmup_cosine,
    )
# Start the run only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
# To run: cd to the project directory, then execute these shell commands:
#   accelerate config
#   accelerate launch diffusion.model_main.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment