Skip to content

Instantly share code, notes, and snippets.

Created January 10, 2020 07:41
Show Gist options
Save david-macleod/2b933d28fd3ac09766785728ee191f09 to your computer and use it in GitHub Desktop.
def plot_timings(loader, n_batches, model_time=0.2, max_time=2.5):
    """Draw a timeline of per-item loading work against a simulated model step.

    Each item produced by ``loader`` is drawn as a box on the row assigned to
    the worker that produced it. Each consumed batch is drawn as a red box on
    row -1, spanning the simulated model step (a ``time.sleep(model_time)``).

    Args:
        loader: Iterable of results. If ``result[0]`` is a ``torch.Tensor``,
            the result is treated as already batched and transposed with
            ``zip(*result)`` into per-item tuples. Otherwise it is iterated
            directly. Every item must unpack into ``(data, worker, t1, t2)``.
            ``t1``/``t2`` are offset by a ``time.time()`` reference, so they
            are presumably wall-clock timestamps (confirm against the dataset).
        n_batches: Maximum number of results to consume from ``loader``.
        model_time: Seconds to sleep per batch, standing in for model compute.
        max_time: Upper x-axis limit, in seconds since plotting started.
    """
    # Local imports: the snippet ships without a module import block.
    import time
    from itertools import count, islice

    import matplotlib.pyplot as plt
    import numpy as np
    import torch

    fig, ax = plt.subplots()
    ax.yaxis.grid(which="major", color='black', linewidth=1)
    zero_time = time.time()
    worker_ids = {}
    worker_count = count()
    for result in islice(loader, n_batches):
        start = time.time()
        # Simulate the model consuming the batch; without this the batch box
        # would have zero width and model_time would be ignored.
        time.sleep(model_time)
        end = time.time()
        # check if already batched
        if isinstance(result[0], torch.Tensor):
            result = zip(*result)
        batch = []
        for item in result:
            data, worker, t1, t2 = tuple(map(scalar, item))
            # Map raw worker ids to compact 0..k-1 rows, in first-seen order.
            # An id of -1 is left on row -1, which it shares with the batch
            # boxes (presumably main-process loading; confirm).
            if worker != -1:
                if worker not in worker_ids:
                    worker_ids[worker] = next(worker_count)
                worker = worker_ids[worker]
            plot_time_box(data, worker, t1 - zero_time, t2 - zero_time, ax)
            batch.append(data)
        batch_str = ",".join(map(str, batch))
        plot_time_box(batch_str, -1, start - zero_time, end - zero_time, ax, color='firebrick')
    max_worker = len(worker_ids) - 1
    ax.set_xlim(0, max_time)
    ax.set_ylim(-1.5, max_worker + 0.5)
    ax.set_xticks(np.arange(0, max_time, 0.2))
    ax.set_yticks(np.arange(-1, max_worker + 1, 1))
    # Fully transparent colour hides y tick marks and labels but keeps the grid.
    ax.tick_params(axis='y', colors=(0, 0, 0, 0))
    fig.set_figheight((max_worker + 2) * 0.5)
    ax.tick_params(axis='x', colors='gray')
    for spine in ax.spines.values():
        # NOTE(review): the original loop body was lost. Hiding the frame
        # matches the transparent y-axis styling above; confirm intent.
        spine.set_visible(False)
def scalar(x):
    """Unwrap ``x`` to a plain Python value when it exposes ``item()``.

    Objects with an ``item`` attribute (e.g. one-element tensors) are
    converted by calling it. Everything else is returned unchanged.
    """
    if not hasattr(x, 'item'):
        return x
    return x.item()
def plot_time_box(data, worker, t1, t2, ax, color='steelblue'):
    """Draw one labelled box spanning ``[t1, t2]`` on row ``worker`` of ``ax``.

    Args:
        data: Value rendered (via ``str``) as the box label.
        worker: Row on the y-axis. The box covers ``worker - 0.25`` to
            ``worker + 0.35``.
        t1: Box start on the x-axis.
        t2: Box end on the x-axis.
        ax: Matplotlib axes to draw on.
        color: Box fill colour.
    """
    # Local import: the snippet ships without a module import block.
    from matplotlib.patches import Rectangle

    x = t1
    y = worker - 0.25
    w = t2 - t1
    h = 0.6
    rect = Rectangle((x, y), w, h, linewidth=2, edgecolor='black', facecolor=color)
    # The patch must be attached to the axes. Otherwise only the (white)
    # label is drawn and the box itself never appears.
    ax.add_patch(rect)
    ax.text(x + (w * 0.5), y + (h * 0.5), str(data), va='center', ha='center', color='white', weight='bold')
Copy link

hminle commented Sep 4, 2020

Hi, this is itertools.count()

Great! Thank you a lot. I can run it now 👍

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment