Last active
May 11, 2022 17:54
-
-
Save breakds/d39bee85c81afdbee0a490aa690b2849 to your computer and use it in GitHub Desktop.
Demonstrate multi-threading in PyTorch on tensor slicing (CPU)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Benchmark CPU advanced indexing in PyTorch across different thread counts.

For each thread count, the script times ``buffer[env_ids, idx]`` (gathering the
first ``T`` columns of ``B`` randomly chosen rows) and prints the mean time.
"""
import time

import torch

# Default experiment configuration (names kept from the original script).
num_envs = 40               # number of rows in the source buffer
buffer_length = 1_000_000   # number of columns in the source buffer
B = 5120                    # rows gathered per iteration
T = 100                     # leading columns gathered from each chosen row
N = 100                     # timed iterations per thread count


def average_gather_time(buffer, result, num_iters):
    """Return the mean wall-clock seconds of one ``buffer[env_ids, idx]`` gather.

    Args:
        buffer: 2-D CPU tensor of shape ``[num_envs, buffer_length]``.
        result: 2-D tensor of shape ``[B, T]``; every gathered batch is added
            to it in place so the gather result is actually consumed.
        num_iters: number of timed iterations; must be positive.

    Returns:
        The average time in seconds (a ``float``) spent in the indexing
        expression alone; index generation is excluded from the timing.

    Raises:
        ValueError: if ``num_iters`` is not positive.
    """
    if num_iters <= 0:
        raise ValueError(f'num_iters must be positive, got {num_iters}')

    env_count = buffer.shape[0]
    batch_size, horizon = result.shape
    # The column index never changes, so build it once outside the timed loop.
    idx = torch.arange(0, horizon)  # [T,]

    total_time_elapsed = 0.
    for _ in range(num_iters):
        # Fresh random rows every iteration; [B,] -> [B, 1] so that it
        # broadcasts against idx ([T,]) to a [B, T] gather.
        env_ids = torch.randint(env_count, size=(batch_size, )).unsqueeze(-1)

        start = time.perf_counter()
        batch = buffer[env_ids, idx]
        total_time_elapsed += time.perf_counter() - start

        result += batch
    return total_time_elapsed / num_iters


def main(num_envs=num_envs, buffer_length=buffer_length, batch_size=B,
         horizon=T, num_iters=N, thread_counts=range(1, 24)):
    """Run the benchmark once per thread count and print the average timings.

    Args:
        num_envs: number of rows in the zero-filled source buffer.
        buffer_length: number of columns in the source buffer.
        batch_size: rows gathered per iteration.
        horizon: leading columns gathered from each chosen row.
        num_iters: timed iterations per thread count.
        thread_counts: iterable of values passed to ``torch.set_num_threads``.

    Returns:
        The ``[batch_size, horizon]`` accumulator tensor that all gathered
        batches were added to.
    """
    # torch.zeros already allocates on the CPU, so no .cpu() call is needed.
    buffer = torch.zeros([num_envs, buffer_length])
    result = torch.zeros([batch_size, horizon])

    for num_threads in thread_counts:
        torch.set_num_threads(num_threads)
        average = average_gather_time(buffer, result, num_iters)
        print(f'Average: {average} seconds with {num_threads} threads')

    print(result.shape)
    return result


if __name__ == '__main__':
    main()
# Output with experiment on Ryzen 9 3900X
# Average: 0.0027781074051745234 seconds with 1 threads
# Average: 0.0027162336080800743 seconds with 2 threads
# Average: 0.0027230232767760753 seconds with 3 threads
# Average: 0.0027171976200770585 seconds with 4 threads
# Average: 0.0029364847659599036 seconds with 5 threads
# Average: 0.002757620553020388 seconds with 6 threads
# Average: 0.002800966005306691 seconds with 7 threads
# Average: 0.001432317856233567 seconds with 8 threads
# Average: 0.0011749703669920563 seconds with 9 threads
# Average: 0.0010148775530979038 seconds with 10 threads
# Average: 0.0007950463821180165 seconds with 11 threads
# Average: 0.0008191179751884192 seconds with 12 threads
# Average: 0.0005180383194237948 seconds with 13 threads
# Average: 0.0005265644623432308 seconds with 14 threads
# Average: 0.00043936547241173684 seconds with 15 threads
# Average: 0.0006872055784333497 seconds with 16 threads
# Average: 0.0005627814284525812 seconds with 17 threads
# Average: 0.000571787329390645 seconds with 18 threads
# Average: 0.000550786901731044 seconds with 19 threads
# Average: 0.0004963862628210336 seconds with 20 threads
# Average: 0.0004958512110169977 seconds with 21 threads
# Average: 0.0004888727166689933 seconds with 22 threads
# Average: 0.0005681470630224795 seconds with 23 threads
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment