Skip to content

Instantly share code, notes, and snippets.

@breakds
Last active May 11, 2022 17:54
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save breakds/d39bee85c81afdbee0a490aa690b2849 to your computer and use it in GitHub Desktop.
Save breakds/d39bee85c81afdbee0a490aa690b2849 to your computer and use it in GitHub Desktop.
Demonstrate multi-threading in PyTorch on tensor slicing (CPU)
import torch
import time
# Benchmark configuration (the sizes used for the results recorded below).
num_envs = 40  # rows of the source buffer
buffer_length = 1_000_000  # columns of the source buffer
B = 5120  # rows gathered per iteration
T = 100  # consecutive columns gathered per row
N = 100  # timed iterations per thread count


def average_gather_time(buffer, result, num_threads, num_iters=N):
    """Return the mean wall time of ``buffer[env_ids, idx]`` using *num_threads*.

    Each iteration draws a fresh random ``[B, 1]`` tensor of row indices
    (``env_ids``) and pairs it with ``idx = arange(T)``.  Advanced indexing
    broadcasts the two into a ``[B, T]`` gather, where ``B, T = result.shape``.
    Only the gather itself is timed.  Every gathered batch is added in place
    into *result*, so the output is actually consumed.

    Timing uses ``time.perf_counter`` with no device synchronisation, so it is
    only meaningful for CPU tensors.  The intra-op thread count is set to
    *num_threads* for the duration of the call and restored afterwards.

    Raises:
        ValueError: if *num_iters* is less than 1 (the average is undefined).
    """
    if num_iters < 1:
        raise ValueError(f"num_iters must be >= 1, got {num_iters}")
    num_rows = buffer.shape[0]
    batch_size, seq_len = result.shape
    # The column indices never change between iterations, so build them once.
    idx = torch.arange(0, seq_len, device=buffer.device)  # [T,]

    previous_threads = torch.get_num_threads()
    torch.set_num_threads(num_threads)
    try:
        # Untimed warm-up. Any one-time cost right after changing the thread
        # count may otherwise be charged to the first timed iteration.
        warmup_ids = torch.randint(num_rows, size=(batch_size,), device=buffer.device)
        buffer[warmup_ids.unsqueeze(-1), idx]

        total_time_elapsed = 0.0
        for _ in range(num_iters):
            env_ids = torch.randint(num_rows, size=(batch_size,), device=buffer.device)  # [B,]
            env_ids = env_ids.unsqueeze(-1)  # [B, 1]
            start = time.perf_counter()
            batch = buffer[env_ids, idx]  # [B, T]
            total_time_elapsed += time.perf_counter() - start
            result += batch
    finally:
        torch.set_num_threads(previous_threads)
    return total_time_elapsed / num_iters


def main(max_threads=23):
    """Benchmark the gather for 1..*max_threads* intra-op threads and print averages."""
    buffer = torch.zeros([num_envs, buffer_length], device="cpu")
    result = torch.zeros([B, T], device="cpu")
    for num_threads in range(1, max_threads + 1):
        average = average_gather_time(buffer, result, num_threads)
        print(f'Average: {average} seconds with {num_threads} threads')
    print(result.shape)


if __name__ == "__main__":
    main()
# Output from a run on a Ryzen 9 3900X (12 cores / 24 threads)
# Average: 0.0027781074051745234 seconds with 1 threads
# Average: 0.0027162336080800743 seconds with 2 threads
# Average: 0.0027230232767760753 seconds with 3 threads
# Average: 0.0027171976200770585 seconds with 4 threads
# Average: 0.0029364847659599036 seconds with 5 threads
# Average: 0.002757620553020388 seconds with 6 threads
# Average: 0.002800966005306691 seconds with 7 threads
# Average: 0.001432317856233567 seconds with 8 threads
# Average: 0.0011749703669920563 seconds with 9 threads
# Average: 0.0010148775530979038 seconds with 10 threads
# Average: 0.0007950463821180165 seconds with 11 threads
# Average: 0.0008191179751884192 seconds with 12 threads
# Average: 0.0005180383194237948 seconds with 13 threads
# Average: 0.0005265644623432308 seconds with 14 threads
# Average: 0.00043936547241173684 seconds with 15 threads
# Average: 0.0006872055784333497 seconds with 16 threads
# Average: 0.0005627814284525812 seconds with 17 threads
# Average: 0.000571787329390645 seconds with 18 threads
# Average: 0.000550786901731044 seconds with 19 threads
# Average: 0.0004963862628210336 seconds with 20 threads
# Average: 0.0004958512110169977 seconds with 21 threads
# Average: 0.0004888727166689933 seconds with 22 threads
# Average: 0.0005681470630224795 seconds with 23 threads
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment