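# Demo: recovering from a CUDA out-of-memory error with a memory plan.
# The script allocates through the caching allocator until it OOMs, builds a
# memory plan from the recorded alloc/free events, retries the same allocation
# pattern under the plan, and plots the memory layout before and after planning.
#
# NOTE: torch.cuda.enable_memory_tracker, torch.cuda.set_memory_plan,
# torch.cuda.get_alloc_free_events and examples.memory_plan are not part of
# stock PyTorch; this gist assumes a patched build that exposes the
# memory-planner prototype.
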
import torch
from examples.memory_plan import get_memory_plan, print_allocs
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import numpy
import random

def plot_memory_events(AllocFreeEvents, device, filename='memory_events.pdf', title='Memory Events', alloc_free_seq_lengths=[], num_events=0, distance=int(1e6)):
    events_for_device = AllocFreeEvents[device]
    if num_events <= 0:
        num_events = len(events_for_device)
    events = [[i, events_for_device[i]] for i in range(num_events)]  # attach sequence numbers so the sort can be reversed later
    sorted_events = sorted(events, key=lambda x: x[1].ptr)  # sort by pointer
    # subtract the minimum pointer value so addresses start from 0
    next_ptr_dist = sorted_events[0][1].ptr
    if next_ptr_dist > distance:
        for j in range(len(sorted_events)):
            sorted_events[j][1].ptr -= next_ptr_dist
    # cap large gaps between blocks at "distance" and compute the length of the sequence
    T = 1  # length of sequence
    for i in range(len(sorted_events) - 1):
        next_ptr_dist = sorted_events[i + 1][1].ptr - (sorted_events[i][1].ptr + abs(sorted_events[i][1].size))
        if next_ptr_dist > distance:
            for j in range(i + 1, len(sorted_events)):
                sorted_events[j][1].ptr += distance - next_ptr_dist
        if numpy.sign(events[i][1].size) != numpy.sign(events[i + 1][1].size):
            T += 1
    # reverse the sort
    for x in sorted_events:
        events[x[0]] = x[1]
    min_address = 0
    max_address = max(x[1].ptr + abs(x[1].size) for x in sorted_events)
    # plot events: one row per run of consecutive allocs/frees, one rectangle per event
    _, ax = plt.subplots()
    t = 0
    batch_idx = 0
    for j in range(num_events):
        if (j > 0) and (numpy.sign(events[j].size) != numpy.sign(events[j - 1].size)):
            t = t + 1
        if numpy.sign(events[j].size) == -1:
            col = "white"  # frees are drawn in white to blank out the block
        else:
            col = (random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1))
        ax.add_patch(Rectangle((events[j].ptr, t), abs(events[j].size) - 1, T - t, color=col))
        col = 'k'
        linewidth = 0.5
        if (batch_idx < len(alloc_free_seq_lengths)) and (j == alloc_free_seq_lengths[batch_idx]):
            col = 'red'
            linewidth = 1
            batch_idx += 1
        plt.axhline(y=t, xmin=min_address, xmax=max_address, color=col, linewidth=linewidth)
    plt.xlim([min_address, max_address])
    plt.ylim([0, t + 1])
    plt.xlabel("Memory status")
    plt.ylabel("Sequence no.")
    plt.title(title)
    plt.savefig(filename)  # save before show(), which may clear the figure
    plt.show()
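
# Optional, self-contained sanity check for plot_memory_events (not part of the
# original gist). Assumption: the allocator's event records only need `.ptr`
# and `.size` attributes, with a negative size denoting a free, so a
# SimpleNamespace stand-in is enough to exercise the plot. Call _demo_plot()
# manually to try it without the patched PyTorch build.
def _demo_plot():
    from types import SimpleNamespace
    trace = [
        SimpleNamespace(ptr=0, size=100),    # alloc block A at address 0
        SimpleNamespace(ptr=100, size=100),  # alloc block B right after A
        SimpleNamespace(ptr=0, size=-100),   # free block A
        SimpleNamespace(ptr=0, size=50),     # smaller alloc reusing A's slot
    ]
    plot_memory_events([trace], 0, filename='demo.pdf', title='Synthetic trace')
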
current_device = torch.cuda.current_device()
torch.cuda.enable_memory_tracker()
size = int(5e8)
memory_plan = []
recovered = False
while True:
    try:
        if len(memory_plan) > 0:
            print("with memory plan")
            torch.cuda.set_memory_plan(plan=memory_plan, plan_size=plan_size)
        mem = []
        num_allocs = 82   # allocate num_allocs blocks of `size` bytes
        frac = 0.6        # then free them one by one, replacing each with a block of frac*size
        memory = num_allocs * size * frac  # total memory held at the end
        for i in range(num_allocs):
            mem.append(torch.cuda.caching_allocator_alloc(size * 1))
        for i in range(num_allocs):
            # free the oldest block and allocate a smaller one, fragmenting memory
            torch.cuda.caching_allocator_delete(mem[0])
            del mem[0]
            mem.append(torch.cuda.caching_allocator_alloc(int(size * frac)))
        # one large allocation that fails on fragmented memory
        mem.append(torch.cuda.caching_allocator_alloc(int(size * 32)))
        if recovered:
            print("OOM is avoided by plan")
            break
    except RuntimeError:
        print("OOM occurred at memory usage:", memory / torch.cuda.get_device_properties(0).total_memory)
        events = torch.cuda.get_alloc_free_events()[current_device]
        l1 = len(events)
        (memory_plan, plan_size) = get_memory_plan([events])
        for mem_del in mem:
            torch.cuda.caching_allocator_delete(mem_del)
        if recovered:
            break
        recovered = True

events = torch.cuda.get_alloc_free_events()[current_device]
plot_memory_events([events[:l1]], 0, filename='Normal.pdf', title='Memory Events: CudaCachingAllocator')
plot_memory_events([events[l1:]], 0, filename='Ours.pdf', title='Memory Events: Our Memory Planner')
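
# The two saved figures compare the caching allocator's layout leading up to
# the OOM ('Normal.pdf') with the layout produced under the memory plan
# ('Ours.pdf').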