Last active
May 24, 2019 20:17
-
-
Save suyashdamle/a2afb1baed28ed664fa46c8f4107f103 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
- For generator: (creates a set of fake images in NCHW format;
  the 2nd channel holds the label encodings)
    - Generate random integers between 0 and 9 (shape: (batch_size,)).
      These are the labels for the fake images.
    - Convert these into one-hot encodings: --> (batch_size, 10)
    - Broadcast to (batch_size X 10 X 10) shape: --> (batch_size, 10, 10)
    - Add another dimension for the channel: --> (batch_size, 1, 10, 10)
    - Generate a 10X10 noise matrix using mx.random.normal(): --> (batch_size, 1, 10, 10)
    - Concatenate the noise with the label encodings: --> (batch_size, 2, 10, 10)
'''
class GenInputIter(mx.io.DataIter):
    """Endless iterator that produces input batches for a conditional generator.

    Each batch is a ``(batch_size, 2, dim_x, dim_y)`` array: channel 0 is
    standard-normal noise and channel 1 is a one-hot encoding of a randomly
    drawn class label (10 classes) broadcast across the last axis.

    The iterator never raises ``StopIteration`` (``iter_next`` always returns
    True), so callers must bound their own loops.

    NOTE: ``next`` moves its outputs to ``ctx``, which is not defined in this
    class and must exist as a module-level name.
    NOTE(review): the one-hot depth is fixed at 10 and is broadcast to
    ``(batch_size, dim_x, dim_y)``, so this appears to require
    ``dim_x == 10`` -- confirm against callers.
    """

    def __init__(self, batch_size, dim_x, dim_y):
        """Store the batch geometry and advertise the data/label shapes.

        Args:
            batch_size: number of samples generated per call to ``next``.
            dim_x: height of each generated plane.
            dim_y: width of each generated plane.
        """
        super(GenInputIter, self).__init__(batch_size)
        self.batch_size = batch_size
        self.ndim = (dim_x, dim_y)
        # Initialise here so the attribute exists even before reset() is called.
        self.current_batch = 0
        self.provide_data = [('rand_label', (batch_size, 2, dim_x, dim_y))]
        self.provide_label = [('labels', (batch_size,))]

    def iter_next(self):
        """Always report that another batch is available."""
        return True

    def __iter__(self):
        return self

    def __next__(self):
        return self.next()

    def reset(self):
        """Reset the batch counter."""
        self.current_batch = 0

    def next(self):
        """Generate one batch of noise-plus-label input.

        Returns:
            A tuple ``(batch, label)`` where ``batch`` is an
            ``mx.io.DataBatch`` wrapping a single
            ``(batch_size, 2, dim_x, dim_y)`` array and ``label`` is a
            one-element list holding the ``(batch_size,)`` integer labels.
            NOTE(review): the labels are returned next to the DataBatch
            rather than inside it; kept as-is because callers may unpack
            the tuple.
        """
        # Use the instance's batch size; the bare name `batch_size` is not
        # defined in this scope.
        labels = mx.ndarray.random.randint(0, 10, shape=(self.batch_size,))
        rand = mx.random.normal(
            0, 1.0,
            shape=(self.batch_size, 1, self.ndim[0], self.ndim[1]))
        # (batch, 10) -> (batch, 10, 1) so the last axis can be broadcast.
        gen_label = mx.ndarray.one_hot(
            mx.nd.array(labels), 10).expand_dims(axis=2)
        # (batch, 10, 1) -> (batch, dim_x, dim_y) -> (batch, 1, dim_x, dim_y)
        gen_label = gen_label.broadcast_to(
            (self.batch_size, self.ndim[0], self.ndim[1])
        ).expand_dims(axis=1)
        # Stack along the channel axis: channel 0 = noise, channel 1 = labels.
        gen_label_batch = mx.ndarray.concat(rand, gen_label, dim=1)
        data = [gen_label_batch.as_in_context(ctx)]
        label = [labels.as_in_context(ctx)]
        return mx.io.DataBatch(data), label
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment