@iiSeymour
Created March 8, 2023 13:55
LSTM.py
import torch
import time
# Let cuDNN autotune its kernel choices; the warmup rounds below absorb the tuning cost
torch.backends.cudnn.benchmark = True
BATCH_SIZE = 64
LAYER_SIZE = 384
TIME_SERIES_LEN = 1600
NUM_LAYERS = 5
WARMUP_ROUNDS = 5
BENCHMARK_ROUNDS = 10
NUM_STREAMS = 5
def calculate_flops(batch_size, num_features, num_layers, time_series_len, num_gates):
    num_weight_matrices = 2  # hidden-hidden weights as well as the input weights
    # trailing 2 counts each multiply-accumulate as two FLOPs; every layer does its own matmuls
    return (time_series_len * batch_size * num_features * num_features
            * num_gates * num_weight_matrices * num_layers * 2)
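# For the constants above, one forward pass of one model works out to
#   1600 * 64 * 384 * 384 * 4 * 2 * 5 * 2 ~ 1.21e12 FLOPs,
# so the timed region below covers 10 rounds * 5 streams ~ 60.4 TFLOP of work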
# One CUDA stream per model so the five forward passes can overlap on the GPU
streams = [torch.cuda.Stream() for _ in range(NUM_STREAMS)]
# Instantiate one FP16 LSTM per stream
lstms = [
    torch.nn.LSTM(LAYER_SIZE, LAYER_SIZE, NUM_LAYERS, bias=False, batch_first=False).cuda().half().eval()
    for _ in range(NUM_STREAMS)
]
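# With half-precision weights, cuDNN can execute the LSTM matmuls on the V100's
# tensor cores; the 125 TFLOPS peak used below is the tensor-core figure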
# Create some input data (seq_len, batch, features), one tensor per stream
datas = [torch.rand(TIME_SERIES_LEN, BATCH_SIZE, LAYER_SIZE).cuda().half() for _ in range(NUM_STREAMS)]
# Warmup: let cuDNN finish algorithm selection before timing
for i in range(WARMUP_ROUNDS):
    for s in range(NUM_STREAMS):
        lstms[s](datas[s])
# Benchmark
torch.cuda.synchronize()
t0 = time.time()
for s in range(NUM_STREAMS):
    for i in range(BENCHMARK_ROUNDS):
        with torch.cuda.stream(streams[s]):
            # Kernel launches are asynchronous, so each stream queues all of
            # its rounds and the five streams execute concurrently
            lstms[s](datas[s])
# Wait for every stream to drain before stopping the clock
torch.cuda.synchronize()
tf = time.time()
t_total = tf - t0
num_samples = TIME_SERIES_LEN * BATCH_SIZE * BENCHMARK_ROUNDS * NUM_STREAMS # at ONT, we use "sample" to mean "time point"
flops = calculate_flops(BATCH_SIZE, LAYER_SIZE, NUM_LAYERS, TIME_SERIES_LEN, 4)
TFLOPS = (flops * BENCHMARK_ROUNDS * NUM_STREAMS) / t_total / 1e12
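# (calculate_flops is per forward pass of one model, hence the scaling by
#  BENCHMARK_ROUNDS * NUM_STREAMS above)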
v100_peak_tflops = 125  # NVIDIA's quoted FP16 tensor-core peak for the V100 (SXM2)
print("Took", t_total, "seconds")
print("MSample/s =", num_samples/t_total/1e6)
print("TFLOPS = ", TFLOPS)
print("% peak teoretical (V100) = ", TFLOPS/v100_peak_tflops * 100)