Skip to content

Instantly share code, notes, and snippets.

@alfredplpl
Last active May 15, 2022 11:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save alfredplpl/f721a07aa6e07f47909866bb9c4409e0 to your computer and use it in GitHub Desktop.
Save alfredplpl/f721a07aa6e07f47909866bb9c4409e0 to your computer and use it in GitHub Desktop.
Digit Generation by DDPM
import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms
import torch.optim as optim
import cv2
from tqdm import tqdm
from denoising_diffusion_pytorch import Unet, GaussianDiffusion
import numpy as np
class Mydataset(Dataset):
    """Wrap an (image, label) dataset and rescale each image with ``2 * x - 1``.

    With images in [0, 1] (as produced by torchvision's ``ToTensor``) this
    yields images in [-1, 1]. Labels are passed through unchanged.
    """

    def __init__(self, mnist_data):
        # Underlying dataset; items are indexed as item[0] = image, item[1] = label.
        self.mnist = mnist_data

    def __len__(self):
        return len(self.mnist)

    def __getitem__(self, idx):
        # Index the wrapped dataset only once: every lookup re-runs its
        # transform pipeline, so fetching image and label through two separate
        # lookups would do that work twice per sample.
        item = self.mnist[idx]
        image, label = item[0], item[1]
        image = (image * 2) - 1
        return image, label
# Preprocessing: resize digits to 32 px with nearest-neighbour sampling
# (matching image_size=32 used for the diffusion model), then convert to tensors.
mnist_transform = transforms.Compose(
    [
        transforms.Resize(32, interpolation=transforms.InterpolationMode.NEAREST),
        transforms.ToTensor(),
    ]
)

# MNIST is downloaded into the current directory if not already present.
mnist_data = datasets.MNIST(".", transform=mnist_transform, download=True)

# Wrapper that rescales each image with 2 * x - 1.
mnist_data_norm = Mydataset(mnist_data)

# Shuffled mini-batches of 512 samples for training.
dataloader = torch.utils.data.DataLoader(
    mnist_data_norm,
    batch_size=512,
    shuffle=True,
)
# Use the first CUDA device when available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# Denoising network: base width 16, width multipliers (1, 2, 4, 8),
# one input channel for grayscale digits.
model = Unet(dim=16, dim_mults=(1, 2, 4, 8), channels=1)
model = model.to(device)

# Diffusion wrapper around the U-Net: 32 px single-channel images,
# 1000 diffusion steps, L1 training loss.
diffusion = GaussianDiffusion(
    model,
    image_size=32,
    timesteps=1000,
    loss_type="l1",
    channels=1,
).to(device)

# Adam with default hyperparameters over all of the diffusion module's
# parameters (presumably including the wrapped U-Net's weights, since it is
# passed in as a submodule; confirm for the installed library version).
optimizer = optim.Adam(diffusion.parameters())
def _to_grid_image(imgs, rows=4, cols=4):
    """Tile the first ``rows * cols`` images of a batch into one uint8 grayscale image.

    ``imgs`` is indexed as (N, C, H, W) and only channel 0 is used; the tile
    size is taken from the batch shape rather than hard-coded. Pixel values are
    clipped to [0, 1] before scaling to [0, 255], so out-of-range samples
    saturate instead of wrapping around in the uint8 cast.
    NOTE(review): assumes ``diffusion.sample`` returns values in [0, 1]; confirm
    for the installed denoising_diffusion_pytorch version.
    """
    arr = np.asarray(imgs.detach().cpu().numpy(), dtype=np.float32)
    _, _, h, w = arr.shape
    # (rows*cols, H, W) -> (rows, cols, H, W) -> (rows, H, cols, W) -> (rows*H, cols*W):
    # tile k = r * cols + c lands at grid row r, grid column c.
    grid = (
        arr[: rows * cols, 0]
        .reshape(rows, cols, h, w)
        .transpose(0, 2, 1, 3)
        .reshape(rows * h, cols * w)
    )
    return (np.clip(grid, 0.0, 1.0) * 255).astype(np.uint8)


for epoch in range(2000):
    # One full pass over the training set.
    for images, _labels in tqdm(dataloader):
        optimizer.zero_grad()
        batch = images.to(device)
        # Calling the diffusion module returns the training loss for the batch.
        loss = diffusion(batch)
        loss.backward()
        optimizer.step()

    # After each epoch, draw 16 samples and save them as a 4x4 grid.
    imgs = diffusion.sample(batch_size=4 * 4)
    path = f"./MNIST/ddpm_{epoch}.bmp"
    # cv2.imwrite reports failure (e.g. a missing directory) only via its
    # return value; surface it instead of silently losing the samples.
    if not cv2.imwrite(path, _to_grid_image(imgs, rows=4, cols=4)):
        print(f"warning: could not write {path}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment