Skip to content

Instantly share code, notes, and snippets.

@suyashdamle
Last active May 24, 2019 20:17
Show Gist options
  • Save suyashdamle/a2afb1baed28ed664fa46c8f4107f103 to your computer and use it in GitHub Desktop.
Save suyashdamle/a2afb1baed28ed664fa46c8f4107f103 to your computer and use it in GitHub Desktop.
'''
- For the generator: creates batches of fake-image inputs in NCHW format;
  channel 0 is Gaussian noise and channel 1 is the label encoding.
- Generate random integers in [0, 9] (shape (batch_size,)).
  These are the labels for the fake images.
- Convert them into a one-hot encoding: --> (batch_size, 10)
- Add a trailing axis and broadcast along it: --> (batch_size, 10, 10)
- Add another dimension for the channel: --> (batch_size, 1, 10, 10)
- Generate a 10x10 noise matrix per sample using mx.random.normal() --> (batch_size, 1, 10, 10)
- Concatenate the noise and the label encodings along the channel axis
  (noise first): --> (batch_size, 2, 10, 10)
(The shapes above assume 10x10 inputs, i.e. dim_x = dim_y = 10.)
'''
class GenInputIter(mx.io.DataIter):
    """Endless iterator producing conditional-GAN generator inputs.

    Each batch is an NCHW tensor of shape ``(batch_size, 2, dim_x, dim_y)``:
    channel 0 holds standard-normal noise and channel 1 holds the one-hot
    encoding of a random class label, broadcast across the spatial grid.

    Args:
        batch_size: number of samples per batch.
        dim_x: height of each input. The one-hot vector (length
            ``num_classes``) is laid out along this axis before broadcasting,
            so ``dim_x`` must equal ``num_classes`` or the broadcast fails.
        dim_y: width of each input.
        num_classes: number of label classes; labels are drawn from
            ``[0, num_classes)``. Defaults to 10.
    """

    def __init__(self, batch_size, dim_x, dim_y, num_classes=10):
        # Let the base class record batch_size as well.
        super(GenInputIter, self).__init__(batch_size)
        self.batch_size = batch_size
        # Spatial size (dim_x, dim_y) of each generated input.
        self.ndim = (dim_x, dim_y)
        self.num_classes = num_classes
        self.current_batch = 0
        self.provide_data = [('rand_label', (batch_size, 2, dim_x, dim_y))]
        self.provide_label = [('labels', (batch_size,))]

    def iter_next(self):
        # Samples are generated on the fly, so another batch is always available.
        return True

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def reset(self):
        self.current_batch = 0

    def next(self):
        """Sample and return a fresh batch.

        Returns:
            A tuple ``(DataBatch, label)``. The DataBatch wraps a one-element
            list holding the ``(batch_size, 2, dim_x, dim_y)`` generator input.
            ``label`` is a one-element list holding the ``(batch_size,)``
            class labels. NOTE(review): the labels are returned beside the
            DataBatch rather than inside it; callers presumably unpack the
            tuple, so confirm before changing this.
        """
        # Use the iterator's own batch size; the original read a global
        # ``batch_size`` here by mistake.
        labels = mx.ndarray.random.randint(0, self.num_classes,
                                           shape=(self.batch_size,))
        noise = mx.random.normal(0, 1.0, shape=(self.batch_size, 1,
                                                self.ndim[0], self.ndim[1]))
        # (batch, num_classes) -> (batch, num_classes, 1)
        gen_label = mx.ndarray.one_hot(labels, self.num_classes).expand_dims(axis=2)
        # -> (batch, dim_x, dim_y) -> (batch, 1, dim_x, dim_y)
        gen_label = gen_label.broadcast_to(
            (self.batch_size, self.ndim[0], self.ndim[1])).expand_dims(axis=1)
        # Channel 0 is noise and channel 1 is the label encoding.
        gen_label_batch = mx.ndarray.concat(noise, gen_label, dim=1)
        # ``ctx`` is a device context defined outside this class (module level).
        data = [gen_label_batch.as_in_context(ctx)]
        label = [labels.as_in_context(ctx)]
        return mx.io.DataBatch(data), label
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment