# Batch_t is not defined in this snippet; a plausible definition
# (an assumption -- adjust to your data layout):
const Batch_t = Tuple{AbstractArray, AbstractArray}

# Implementation details:
# Note that device_ch is unbuffered.
# Thus, the spawned task moves batch X to the device and then blocks at the
# `put!` call until the previous batch has been consumed by the main thread.
"Asynchronously apply the data transform to batches and stage one batch on the device."
async_data_prep(train_loader, transform_fn, device; num_threads=4) =
    Channel{Batch_t}(spawn=true) do device_ch
        transformed_data = throttled_parallel_data_transform(
            train_loader, transform_fn, num_threads)
        for (X, Y) in transformed_data
            X = X |> device
            put!(device_ch, (X, Y))
        end
    end
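
# A tiny standalone sketch (not from the original gist) of the unbuffered-
# channel backpressure that async_data_prep relies on: the producer cannot
# run ahead of the consumer by more than the single in-flight `put!`.
demo_ch = Channel{Int}(spawn=true) do ch
    for i in 1:3
        put!(ch, i)        # rendezvous: blocks until the consumer takes i
        @info "produced" i
    end
end
for i in demo_ch
    @info "consumed" i
end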
# Implementation details:
# We don't want to transform all the batches at once.
# Therefore, we use a "throttle channel" that only lets through
# a fixed number of workers at a time.
"Apply the data transform on multiple threads."
throttled_parallel_data_transform(train_loader, transform_fn, num_threads) =
    Channel{Batch_t}(num_threads; spawn=true) do threaded_ch
        throttle_channel = Channel{Tuple}(num_threads)  # ensure only num_threads tasks run at a time
        @sync for (x, y) in train_loader
            put!(throttle_channel, ())  # "take a spot"; blocks while all spots are taken
            Threads.@spawn begin
                # note: transformed batches may be put! out of order
                put!(threaded_ch, transform_fn((x, y)))
                take!(throttle_channel)  # "release a spot" when finished
            end
        end
    end
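
# Minimal end-to-end sketch (not part of the original gist). Assumptions:
# the "loader" is a plain vector of (X, Y) batches, `double_x` is a toy
# transform, and `identity` stands in for a real device move such as
# Flux's `gpu`.
dummy_loader = [(rand(Float32, 4, 8), rand(Float32, 2, 8)) for _ in 1:10]
double_x((x, y)) = (2 .* x, y)  # toy transform; runs on worker threads

for (X, Y) in async_data_prep(dummy_loader, double_x, identity)
    @assert size(X) == (4, 8) && size(Y) == (2, 8)  # consume prepared batches
end

# Design note: the throttle channel is a hand-rolled counting semaphore;
# Base.Semaphore with Base.acquire/Base.release would be the standard-library
# equivalent.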