Skip to content

Instantly share code, notes, and snippets.

@Andy-P
Created December 14, 2015 05:30
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Andy-P/f661dabfca7805715a72 to your computer and use it in GitHub Desktop.
Save Andy-P/f661dabfca7805715a72 to your computer and use it in GitHub Desktop.
"""
    generate_batch(batch_size, num_skips, skip_window, data_index; corpus=data)

Build one skip-gram training batch from the token sequence `corpus`, reading it as a
circular sequence starting at position `data_index`.

For each of `batch_size ÷ num_skips` center words, `num_skips` distinct context
positions are drawn at random from the `skip_window` words on either side of it.

# Arguments
- `batch_size`: number of (center, context) pairs; must be a multiple of `num_skips`.
- `num_skips`: context samples taken per center word; at most `2 * skip_window`.
- `skip_window`: number of words considered on each side of the center word.
- `data_index`: 1-based position in `corpus` at which reading starts.

# Keywords
- `corpus`: indexable token sequence to sample from; defaults to the global `data`
  (the original behaviour).

# Returns
`(batch, labels, data_index)`:
- `batch` is a `Vector{Int32}` of center words.
- `labels` is a `batch_size × 1` `Matrix{Int32}` of context words.
- `data_index` is the position to pass to the next call.

# Throws
- `ArgumentError` if `batch_size` is not a multiple of `num_skips`, or if
  `num_skips > 2 * skip_window`.
"""
function generate_batch(batch_size, num_skips, skip_window, data_index; corpus=data)
    batch_size % num_skips == 0 || throw(ArgumentError(
        "batch_size ($batch_size) must be a multiple of num_skips ($num_skips)"))
    num_skips <= 2 * skip_window || throw(ArgumentError(
        "num_skips ($num_skips) must be at most 2 * skip_window ($(2 * skip_window))"))

    n = length(corpus)
    batch = zeros(Int32, batch_size)
    labels = zeros(Int32, batch_size, 1)
    span = 2 * skip_window + 1
    center = skip_window + 1  # position of the center word inside `buffer`

    # Sliding window holding `span` consecutive tokens; the corpus wraps around.
    buffer = Int[]
    sizehint!(buffer, span + 1)
    for _ in 1:span
        push!(buffer, corpus[data_index])
        data_index = mod1(data_index + 1, n)  # circular pointer
    end

    for i in 0:(batch_size ÷ num_skips - 1)
        target = center
        targets_to_avoid = [center]  # never pair the center word with itself
        for j in 0:(num_skips - 1)
            # Terminates because num_skips <= 2 * skip_window leaves unused positions.
            while target in targets_to_avoid
                target = rand(1:span)
            end
            push!(targets_to_avoid, target)
            k = i * num_skips + j + 1
            batch[k] = buffer[center]
            labels[k, 1] = buffer[target]
        end
        # Slide the window forward by one token (buffer is span + 1 long after push).
        push!(buffer, corpus[data_index])
        popfirst!(buffer)
        data_index = mod1(data_index + 1, n)
    end
    return batch, labels, data_index
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment