Skip to content

Instantly share code, notes, and snippets.

@wut0n9
Created January 18, 2019 11:27
Show Gist options
  • Save wut0n9/cacaace930d546cba8fc4efe9360ceaa to your computer and use it in GitHub Desktop.
Save wut0n9/cacaace930d546cba8fc4efe9360ceaa to your computer and use it in GitHub Desktop.
生成批训练样本 (Generate batches of training samples)
def generate_batch(batch_size, data_vec, word_to_int, pad_value=None):
    """Build padded (input, target) batches for next-token training.

    ``data_vec`` is split into ``len(data_vec) // batch_size`` consecutive
    chunks of ``batch_size`` sequences. Trailing sequences that do not fill
    a whole batch are dropped. Each chunk is right-padded to the length of
    its own longest sequence, so different batches may have different widths.

    Args:
        batch_size: Number of sequences per batch; must be positive.
        data_vec: Sliceable sequence of encoded sequences. Each element must
            be assignable into an ``int32`` row slice (e.g. a list of ids).
        word_to_int: Vocabulary mapping. It is only used to look up the
            padding id ``word_to_int[UNK_TOKEN]`` when ``pad_value`` is None.
        pad_value: Optional explicit padding id. When None (the default),
            ``word_to_int[UNK_TOKEN]`` is used, as in the original code.

    Returns:
        ``(x_batches, y_batches)``: two lists of ``np.int32`` arrays of shape
        ``(batch_size, longest_len_in_batch)``. Each ``y`` is its ``x``
        shifted one step left, with the final column repeated::

            x_data          y_data
            [6,2,4,6,9]     [2,4,6,9,9]
            [1,4,2,8,5]     [4,2,8,5,5]

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size!r}")

    n_chunk = len(data_vec) // batch_size
    x_batches = []
    y_batches = []
    if n_chunk == 0:
        # No full batch: return before touching word_to_int,
        # so a vocabulary without UNK_TOKEN does not raise here.
        return x_batches, y_batches

    if pad_value is None:
        # NOTE(review): padding with the UNK id means targets past the end of a
        # short sequence are UNK; confirm this is intended rather than a PAD id.
        pad_value = word_to_int[UNK_TOKEN]

    for start in range(0, n_chunk * batch_size, batch_size):
        batch = data_vec[start:start + batch_size]
        length = max(map(len, batch))
        x_data = np.full((batch_size, length), pad_value, np.int32)
        for row, seq in enumerate(batch):
            x_data[row, :len(seq)] = seq
        # Targets are the inputs shifted left by one step. The last column
        # has no successor, so it keeps x's last column.
        y_data = np.copy(x_data)
        y_data[:, :-1] = x_data[:, 1:]
        x_batches.append(x_data)
        y_batches.append(y_data)
    return x_batches, y_batches
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment