Skip to content

Instantly share code, notes, and snippets.

@XinyueZ
Created November 12, 2022 00:08
Show Gist options
  • Save XinyueZ/6a309f6696fd1a1a841342f82b7a2f8e to your computer and use it in GitHub Desktop.
Save XinyueZ/6a309f6696fd1a1a841342f82b7a2f8e to your computer and use it in GitHub Desktop.
PixelRNNs, dataset logic
def pad_nulls(arraylike, seq_len, input_size):
    """Left-pad ``arraylike`` with zero rows so its length is a multiple of ``seq_len``.

    Zero ("null") rows are PREPENDED, not appended, so the original rows stay
    right-aligned at the tail of the result.

    Args:
        arraylike: Sequence of rows to pad. Each row must have ``input_size``
            columns, because the rows are concatenated with
            ``(k, input_size)`` zero rows along axis 0.
        seq_len: The length multiple to pad up to.
        input_size: Number of columns in each zero row.

    Returns:
        - An all-zero float32 array of shape ``(seq_len, input_size)`` if
          ``arraylike`` is empty.
        - ``arraylike`` itself (the same object, not a copy) if its length is
          already a multiple of ``seq_len``.
        - Otherwise, a new array consisting of ``seq_len - len % seq_len``
          float32 zero rows followed by the rows of ``arraylike``. The dtype
          follows NumPy's promotion of float32 with the input's dtype.

    Raises:
        ZeroDivisionError: if ``seq_len`` is 0.
    """
    if len(arraylike) == 0:
        # An empty input becomes one full block of nulls; it then falls
        # through with remainder 0 and is returned as-is.
        arraylike = np.zeros((seq_len, input_size), dtype=np.float32)
    remainder = len(arraylike) % seq_len
    if remainder != 0:
        nulls = np.zeros((seq_len - remainder, input_size), dtype=np.float32)
        # Nulls go FIRST so the real rows end up at the end of the result.
        arraylike = np.concatenate((nulls, arraylike), axis=0)
    return arraylike
# Build next-row prediction samples from the rows of ``image_flatten``.
# For each row index ``idx``:
#   x -> rows 0..idx (inclusive), left-padded with zero rows by pad_nulls
#        to a multiple of seq_len rows (here exactly seq_len rows)
#   y -> row idx + 1, or a zero row of width input_size after the last row
# NOTE(review): assumes ``image_flatten`` is a 2-D (rows, input_size) array
# and that ``np`` and ``pad_nulls`` are defined earlier — confirm at call site.
X, Y = list(), list()
seq_len = image_flatten.shape[0]
input_size = image_flatten.shape[-1]
for idx in range(seq_len):  # seq_len == len(image_flatten)
    # Prefix of all rows seen so far, including the current one.
    x = pad_nulls(image_flatten[0 : idx + 1, :], seq_len, input_size)
    if idx + 1 < seq_len:
        # Target is the row that follows the prefix.
        y = image_flatten[idx + 1, :]
    else:
        # No row follows the last one, so its target is all zeros.
        y = np.zeros((input_size), dtype=np.float32)
    # Copy so entries of X / Y never alias image_flatten: slices are views,
    # and pad_nulls returns its input unchanged when no padding is needed.
    x, y = np.array(x), np.array(y)
    X.append(x)
    Y.append(y)
# Stack into arrays: X_np has one padded prefix per row, Y_np one target per row.
X_np, Y_np = np.array(X), np.array(Y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment