@tals
Created March 16, 2020 14:27
A StyleGAN PNG loader. Avoids the need for intermediate TFRecords if you have the CPU/GPU for it.
"""
A StyleGAN PNG data loader.
Avoids the need for intermediate TFRecords if you have the CPU/GPU for it.
Apply the following patch:
diff --git a/run_training.py b/run_training.py
index bc4c0a2..61e5a33 100755
--- a/run_training.py
+++ b/run_training.py
@@ -58,7 +58,11 @@ def run(dataset, data_dir, result_dir, config_id, num_gpus, total_kimg, gamma, m
     desc = 'stylegan2'
     desc += '-' + dataset
-    dataset_args = EasyDict(tfrecord_dir=dataset)
+    dataset_args = EasyDict(
+        path=os.path.join(data_dir, dataset),
+        class_name='artbreeder_extensions.dataset.PNGDataset',
+        num_workers=16
+    )
     assert num_gpus in [1, 2, 4, 8]
     sc.num_gpus = num_gpus
"""
from pathlib import Path
import numpy as np
import tensorflow as tf
from dnnlib import tflib


class PNGDataset:
    def __init__(
        self,
        path,                  # Directory containing the training PNGs.
        num_workers=16,        # Parallel calls for PNG decoding.
        repeat=True,           # Repeat the dataset indefinitely.
        shuffle_mb=4096,       # Shuffle data within specified window (megabytes), 0 = disable shuffling.
        prefetch_mb=2048,      # Amount of data to prefetch (megabytes), 0 = disable prefetching.
    ):
        # Dataset metadata mirroring StyleGAN's TFRecordDataset interface,
        # hard-coded here for 1024x1024 RGB images without labels.
        self.shape = [3, 1024, 1024]
        self.resolution = 1024
        self.label_size = 0
        self.dtype = np.uint8
        self.label_dtype = np.float32
        self.dynamic_range = [0, 255]
        self.resolution_log2 = int(np.log2(self.resolution))

        # Index every non-empty PNG in the given directory.
        path = Path(path)
        files = [str(x) for x in path.glob("*.png") if x.stat().st_size > 0]
        print(f"found {len(files)} files")
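
        # Build the tf.data input pipeline on the CPU: decode PNGs in parallel,
        # shuffle/prefetch within an approximate megabyte budget, and batch by a
        # run-time placeholder so the training loop can change the minibatch size.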
        with tf.name_scope("Dataset"), tf.device("/cpu:0"):
            self._tf_minibatch_in = tf.placeholder(
                tf.int64, name="minibatch_in", shape=[]
            )
            dset = tf.data.Dataset.from_generator(
                lambda: files,
                output_types=tf.string,
            )

            def map_func(x):
                x = tf.io.read_file(x)
                x = tf.io.decode_png(x)         # Assumes 3-channel RGB PNGs.
                x = tf.transpose(x, (2, 0, 1))  # HWC to CHW
                lbl = tf.zeros((0,), dtype=tf.float32)
                return (x, lbl)

            dset = dset.map(
                map_func,
                num_parallel_calls=num_workers,
            )

            # Translate the megabyte budgets into element counts for shuffle/prefetch.
            bytes_per_item = np.prod(self.shape) * np.dtype(self.dtype).itemsize
            if shuffle_mb > 0:
                dset = dset.shuffle(((shuffle_mb << 20) - 1) // bytes_per_item + 1)
            if repeat:
                dset = dset.repeat()
            if prefetch_mb > 0:
                dset = dset.prefetch(((prefetch_mb << 20) - 1) // bytes_per_item + 1)
            dset = dset.batch(self._tf_minibatch_in)

            self._cur_minibatch = 0
            self._tf_minibatch_np = None
            self._tf_iterator = tf.data.Iterator.from_structure(
                output_types=dset.output_types,
                output_shapes=dset.output_shapes,
            )
            self._tf_init_op = self._tf_iterator.make_initializer(dset)
    def close(self):
        pass

    # Re-initialize the iterator whenever the minibatch size changes
    # (lod is accepted for interface compatibility and ignored).
    def configure(self, minibatch_size, lod=0):
        assert minibatch_size >= 1
        if self._cur_minibatch != minibatch_size:
            self._tf_init_op.run({self._tf_minibatch_in: minibatch_size})
            self._cur_minibatch = minibatch_size

    def get_minibatch_tf(self):  # => images, labels
        return self._tf_iterator.get_next()

    # Get next minibatch as NumPy arrays.
    def get_minibatch_np(self, minibatch_size, lod=0):  # => images, labels
        self.configure(minibatch_size)
        with tf.name_scope("Dataset"):
            if self._tf_minibatch_np is None:
                self._tf_minibatch_np = self.get_minibatch_tf()
            return tflib.run(self._tf_minibatch_np)

    # The dataset is unconditional, so labels are always empty.
    def get_random_labels_tf(self, minibatch_size):  # => labels
        with tf.name_scope('Dataset'):
            return tf.zeros([minibatch_size, 0], self.label_dtype)

    # Get random labels as NumPy array.
    def get_random_labels_np(self, minibatch_size):  # => labels
        return np.zeros([minibatch_size, 0], self.label_dtype)
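

# ----------------------------------------------------------------------------
# Minimal standalone sketch of the loader (not part of the training patch).
# Assumes a TF1-style default session created via dnnlib.tflib and a directory
# of 1024x1024 RGB PNGs; "/path/to/pngs" below is a placeholder.
if __name__ == "__main__":
    tflib.init_tf()
    dataset = PNGDataset("/path/to/pngs", num_workers=4, shuffle_mb=256, prefetch_mb=256)
    images, labels = dataset.get_minibatch_np(minibatch_size=4)
    print(images.shape, images.dtype)  # e.g. (4, 3, 1024, 1024) uint8
    print(labels.shape)                # (4, 0) -- unconditional, so no labels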