Created
January 14, 2021 11:48
-
-
Save johndpope/fc09184dc56c45578b605fa42e3df099 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# -*- coding: utf-8 -*- | |
"""DALL-E Pytorch - COCO dataset | |
Automatically generated by Colaboratory. | |
Original file is located at | |
https://colab.research.google.com/drive/1KxG1iGBoKt2fLVH7uXG_vhvll2OlFkey | |
**Install necessary dependencies** | |
""" | |
# Commented out IPython magic to ensure Python compatibility. | |
!pip install einops | |
!pip install x-transformers | |
!git clone https://github.com/htoyryla/DALLE-pytorch | |
# %cd DALLE-pytorch/ | |
"""**Download and process COCO dataset**""" | |
from PIL import Image | |
import json | |
import os | |
#Removes old data files | |
cleanBeforeProcessing = True | |
#Image dimensions | |
imgSize = 256 | |
#Download image dataset | |
!wget http://images.cocodataset.org/zips/train2017.zip | |
#Extract archive | |
!unzip train2017.zip | |
#Clean up | |
!rm train2017.zip | |
#Resize all images | |
for root, dirs, files in os.walk("train2017"): | |
for fname in files: | |
if ".png" in fname or ".jpg" in fname: | |
path = os.path.join(root, fname) | |
im = Image.open(path) | |
im_resized = im.resize((imgSize, imgSize)) | |
im_resized.save(path) | |
print("Resizing: " + fname) | |
#Download captions | |
!wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip | |
#Extract archive | |
!unzip annotations_trainval2017.zip | |
#Make data format compatable with trainDALLE.py | |
if cleanBeforeProcessing: | |
os.system('rm od-captionsonly.txt') | |
os.system('rm od-captions.txt') | |
#Read images information | |
with open('annotations/captions_train2017.json', 'r') as f: | |
array = json.load(f) | |
#Read all images paths | |
imagePaths = [] | |
for root, dirs, files in os.walk("train2017"): | |
for fname in files: | |
if ".png" in fname or ".jpg" in fname: | |
path = os.path.join(root, fname) | |
imagePaths.append(path) | |
#Read information for each image | |
idAlreadyProcessed = {} | |
count = 0 | |
for info in array['annotations']: | |
image_id = int(info['id']) | |
temp = "000000000000" | |
image_id_str = temp[:len(temp)-len(str(image_id))] + str(image_id) | |
path = "train2017/" + image_id_str + ".jpg" | |
caption = info['caption'] | |
#Choose first caption and image_id | |
with open("od-captionsonly.txt", 'a') as out: | |
out.write(caption + '\n') | |
with open("od-captions.txt", 'a') as out: | |
out.write(str(image_id) + " : " + caption + '\n') | |
#In order to make sure there only are one caption for each image | |
idAlreadyProcessed[image_id] = True | |
count += 1 | |
if count%5000 == 0: | |
print("Processed " + str(count)) | |
**Traing DALL-E** | |
!python3 trainDALLE.py |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment