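"""Reproduce an out-of-memory failure with the CUDA caching allocator, retry with a
precomputed memory plan to show the OOM is avoided, and plot the allocation/free
events of both runs.

Note: torch.cuda.enable_memory_tracker, torch.cuda.get_alloc_free_events, and
torch.cuda.set_memory_plan are not part of stock PyTorch; this script assumes a
patched build that exposes them alongside the accompanying examples.memory_plan module.
"""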
import torch
from examples.memory_plan import get_memory_plan, print_allocs
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
import numpy
import random

def plot_memory_events(AllocFreeEvents, device, filename='memory_events.pdf', title='Memory Events',
                       alloc_free_seq_lengths=[], num_events=0, distance=int(1e6)):
    """Plot allocation/free events for one device as rectangles over (address, sequence step)."""
    events_for_device = AllocFreeEvents[device]
    if num_events <= 0:
        num_events = len(events_for_device)
    events = [[i, events_for_device[i]] for i in range(num_events)]  # attach sequence number to sort, then reverse the sort
    sorted_events = sorted(events, key=lambda x: x[1].ptr)  # sort by pointer
    # subtract the minimum pointer value so addresses start from 0
    next_ptr_dist = sorted_events[0][1].ptr
    if next_ptr_dist > distance:
        for j in range(len(sorted_events)):
            sorted_events[j][1].ptr -= next_ptr_dist
    # cap large gaps between blocks at "distance" and compute the length of the sequence
    T = 1  # length of sequence
    for i in range(len(sorted_events) - 1):
        next_ptr_dist = sorted_events[i + 1][1].ptr - (sorted_events[i][1].ptr + abs(sorted_events[i][1].size))
        if next_ptr_dist > distance:
            for j in range(i + 1, len(sorted_events)):
                sorted_events[j][1].ptr += distance - next_ptr_dist
        if numpy.sign(events[i][1].size) != numpy.sign(events[i + 1][1].size):
            T += 1  # a sign change marks the boundary between an alloc phase and a free phase
    # reverse the sort: restore the original event order
    for x in sorted_events:
        events[x[0]] = x[1]
    min_address = 0
    max_address = max(x[1].ptr + abs(x[1].size) for x in sorted_events)
    # plot events
    _, ax = plt.subplots()
    t = 0
    batch_idx = 0
    for j in range(num_events):
        if (j > 0) and (numpy.sign(events[j].size) != numpy.sign(events[j - 1].size)):
            t = t + 1
        if numpy.sign(events[j].size) == -1:
            col = "white"  # frees are drawn in white so they erase the allocated block
        else:
            col = (random.uniform(0, 1), random.uniform(0, 1), random.uniform(0, 1))
        ax.add_patch(Rectangle((events[j].ptr, t), abs(events[j].size) - 1, T - t, color=col))
        col = 'k'
        linewidth = 0.5
        if (len(alloc_free_seq_lengths) > 0) and (j == alloc_free_seq_lengths[batch_idx]):
            col = 'red'  # highlight batch boundaries
            linewidth = 1
            batch_idx += 1
        plt.axhline(y=t, xmin=min_address, xmax=max_address, color=col, linewidth=linewidth)
    plt.xlim([min_address, max_address])
    plt.ylim([0, t + 1])
    plt.xlabel("Memory status")
    plt.ylabel("Sequence no.")
    plt.title(title)
    plt.savefig(filename)  # save before show(), which clears the current figure
    plt.show()

current_device = torch.cuda.current_device()
torch.cuda.enable_memory_tracker()

size = int(5e8)
memory_plan = []
recovered = False
while True:
    try:
        if len(memory_plan) > 0:
            print("with memory plan")
            torch.cuda.set_memory_plan(plan=memory_plan, plan_size=plan_size)
        mem = []
        num_allocs = 82  # allocate num_allocs events with size
        frac = 0.6  # free num_allocs events and allocate num_allocs events with frac*size
        memory = num_allocs * size * frac  # total memory used at the end
        for i in range(num_allocs):
            mem.append(torch.cuda.caching_allocator_alloc(size * 1))
        for i in range(num_allocs):
            torch.cuda.caching_allocator_delete(mem[0])
            del mem[0]
            mem.append(torch.cuda.caching_allocator_alloc(int(size * frac)))
        mem.append(torch.cuda.caching_allocator_alloc(int(size * 32)))  # one more, much larger allocation (32*size)
        if recovered:
            print("OOM is avoided by plan")
        break
    except RuntimeError:
        print("OOM occurred at memory usage:", memory / torch.cuda.get_device_properties(0).total_memory)
        events = torch.cuda.get_alloc_free_events()[current_device]
        l1 = len(events)
        (memory_plan, plan_size) = get_memory_plan([events])
        for mem_del in mem:
            torch.cuda.caching_allocator_delete(mem_del)
        if recovered:
            break  # the plan did not help either; give up
        recovered = True

events = torch.cuda.get_alloc_free_events()[current_device]
plot_memory_events([events[:l1]], 0, filename='Normal', title='Memory Events: CudaCachingAllocator')
plot_memory_events([events[l1:]], 0, filename='Ours', title='Memory Events: Our Memory Planner')