Created
April 29, 2022 22:09
-
-
Save RomeoV/a6f7221be60a3f6c71f090bb10e7649b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Implementation details:
# `device_ch` is unbuffered. The spawned producer task therefore moves one
# batch's `X` to the device (e.g. the gpu) and then blocks at `put!` until the
# main thread has consumed the previous batch. Only one device-resident batch
# waits ahead of the consumer at any time.
"""
    async_data_prep(train_loader, transform_fn, device; num_threads=4)

Asynchronously apply `transform_fn` to the `(x, y)` batches of `train_loader`.
The transforms run on up to `num_threads` concurrent tasks (see
`throttled_parallel_data_transform`). The inputs are moved to `device` one
batch at a time.

# Arguments
- `train_loader`: iterable of `(x, y)` batches.
- `transform_fn`: function applied to each `(x, y)` tuple; it must return an
  `(X, Y)` pair.
- `device`: function applied to `X` before it is handed to the consumer (e.g. `gpu`).
- `num_threads`: maximum number of transforms in flight at once.

# Returns
An unbuffered `Channel{Batch_t}` that yields `(device(X), Y)` tuples. `Y` is not
passed through `device`. Batches arrive in the order their transforms finish,
which is not necessarily the order of `train_loader`.
"""
async_data_prep(train_loader, transform_fn, device; num_threads=4) =
    Channel{Batch_t}(spawn=true) do device_ch
        transformed_data = throttled_parallel_data_transform(
            train_loader, transform_fn, num_threads)
        for (X, Y) in transformed_data
            # Move to the device *before* blocking on `put!`, so the transfer
            # overlaps with the consumer's work on the previous batch.
            put!(device_ch, (device(X), Y))
        end
    end
# Implementation details:
# We don't want to transform all batches at once. A bounded "throttle" channel
# therefore acts as a counting semaphore. The loop takes a slot before spawning
# each worker task, and the worker releases that slot when it finishes. As a
# result, at most `num_threads` transforms are in flight at any time.
"""
    throttled_parallel_data_transform(train_loader, transform_fn, num_threads)

Apply `transform_fn` to every `(x, y)` batch of `train_loader` on multiple
threads, with at most `num_threads` transforms running concurrently.

# Returns
A `Channel{Batch_t}` with buffer size `num_threads` that yields the transformed
batches. Batches are yielded in completion order, which need not match the
order of `train_loader`. The channel is closed once every batch has been
processed. If a transform throws, the producer task fails and the channel is
closed with that error.

# Throws
- `ArgumentError` if `num_threads < 1`. With a size-0 throttle channel, the
  loop would block forever.
"""
function throttled_parallel_data_transform(train_loader, transform_fn, num_threads)
    num_threads >= 1 ||
        throw(ArgumentError("num_threads must be at least 1, got $num_threads"))
    return Channel{Batch_t}(num_threads; spawn=true) do threaded_ch
        # Bounded channel used as a counting semaphore: `put!` blocks while all
        # `num_threads` slots are taken.
        throttle_channel = Channel{Nothing}(num_threads)
        # `@sync` keeps the channel open until every spawned worker has finished.
        @sync for (x, y) in train_loader
            put!(throttle_channel, nothing)  # take a slot
            Threads.@spawn try
                put!(threaded_ch, transform_fn((x, y)))
            finally
                # Release the slot even if the transform throws. Otherwise the
                # slot leaks and the loop above can block forever.
                take!(throttle_channel)
            end
        end
    end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment