Skip to content

Instantly share code, notes, and snippets.

@o-hanna
Last active September 8, 2022 17:57
Show Gist options
  • Save o-hanna/95b13a8f7ab750f09b656b1726ce9e70 to your computer and use it in GitHub Desktop.
import torch
import numpy as np
def get_memory_plan(events, threshold = 1):
    """Build one memory plan per device.

    `events` is a sequence of per-device event traces; each trace is handed
    to `memory_plan_for_device`. Returns (plans, sizes), where plans[i] is a
    copy of the plan for device i and sizes[i] its required size.
    """
    per_device = [memory_plan_for_device(device_events, threshold)
                  for device_events in events]
    plans = [plan.copy() for plan, _ in per_device]
    sizes = [size for _, size in per_device]
    return (plans, sizes)
class Event:
    """Lifetime record of one allocation: byte size, first and last
    timestep it is live, and its assigned offset in the plan."""

    def __init__(self, size, start, end) -> None:
        self.size, self.start, self.end = size, start, end
        # Offset inside the memory plan; -1 means "not yet placed".
        self.ptr = -1
def memory_plan_for_device(events, threshold): # returns a memory (plan, plan_size) to be passed to torch.cuda.set_memory_plan
    # Strategy:
    # - Allocations that stay live until at least threshold*(len(events)-1) are
    #   stacked on the left side of memory (find_left_allocations / find_plan).
    # - The remaining allocations are then greedily first-fit packed into that
    #   plan without increasing its size (greedy_allocate).
    # - If some allocations still cannot be placed, the plan size is grown and
    #   the greedy pass repeats until everything is placed.
    # NOTE(review): torch.cuda.createAllocFreeEvent / set_memory_plan are not
    # public torch APIs — this presumably targets a patched build; confirm.
    if len(events) == 0:
        # No events: empty plan of size 0.
        return ([],0)
    seq_len = len(events)
    events_with_end_pts = find_end_pts(events) # list of allocation events of type Event
    (left_allocations, remaining_allocations) = find_left_allocations(events_with_end_pts, threshold=threshold, seq_len=seq_len)
    left_size = find_plan(left_allocations)
    # NOTE: alias, not a copy — greedy_allocate below appends into this same list.
    left_plan = left_allocations
    # Greedily allocate the remaining allocations.
    # Complexity of the following part can be reduced using a tree instead of a list.
    remaining_allocations.sort(key = lambda x: x.end - 1/(2+x.start))
    while(len(remaining_allocations)>0):
        new_remaining_allocations = []
        for i in range(len(remaining_allocations)):
            # Keep the ones that did not fit for the next round.
            if not greedy_allocate(remaining_allocations[i], left_plan, left_size):
                new_remaining_allocations.append(remaining_allocations[i])
        remaining_allocations = new_remaining_allocations
        if len(remaining_allocations)>0:
            # Nothing more fits: grow the plan by the first leftover's size
            # so the next greedy pass can make progress (guarantees termination).
            left_size += remaining_allocations[0].size
    left_plan.sort(key = lambda x: x.start)
    memory_plan = []
    plan_size = 0
    for event in left_plan:
        # Actual plan size is the highest end offset of any placed event.
        plan_size = max(plan_size, event.ptr + event.size)
        memory_plan.append(torch.cuda.createAllocFreeEvent(event.ptr, event.size))
    return (memory_plan, plan_size)
def find_end_pts(events): # find start and end points of allocation events
    """Collapse a raw alloc/free trace into Event records with lifetimes.

    An entry with size > 0 starts a lifetime at timestep i; a later entry
    with the same ptr (size <= 0) closes it. Events never freed keep the
    default end of len(events)-1. When a ptr value is reused, the finished
    record is moved to a surrogate dict key so it is not overwritten.
    """
    events_with_end_pts = {}
    # Surrogate keys must never collide with a real pointer value. The
    # original started at 0, which can clash with a genuine ptr == 0;
    # start below zero instead (assumes real ptrs are non-negative).
    num_changed_hash = -1
    for i, event in enumerate(events):
        if event.size > 0:
            if event.ptr in events_with_end_pts:
                # ptr reused: park the completed record under a surrogate key.
                events_with_end_pts[num_changed_hash] = events_with_end_pts[event.ptr]
                num_changed_hash -= 1
            events_with_end_pts[event.ptr] = Event(event.size, i, len(events)-1)
        elif event.ptr in events_with_end_pts:
            # Free event: close the currently-live record for this ptr.
            events_with_end_pts[event.ptr].end = i
    return list(events_with_end_pts.values())
def find_left_allocations(events_with_end_pts, threshold, seq_len):
    """Split events into (left_allocations, remaining_allocations).

    Events that stay live until at least threshold*(seq_len-1) go to the
    left side of memory; everything else is packed greedily later.
    """
    cutoff = threshold * (seq_len - 1)
    left_allocations = [ev for ev in events_with_end_pts if ev.end >= cutoff]
    remaining_allocations = [ev for ev in events_with_end_pts if ev.end < cutoff]
    return (left_allocations, remaining_allocations)
def find_plan(allocations):
    """Stack the given allocations contiguously from offset 0.

    Sorts in place (latest end first; start as tie-breaker), writes each
    event's offset into event.ptr, and returns the total size used.
    """
    allocations.sort(key=lambda ev: -ev.end - 1 / (2 + ev.start))
    cursor = 0
    for alloc in allocations:
        alloc.ptr, cursor = cursor, cursor + alloc.size
    return cursor
def greedy_allocate(event, side_plan, side_size):
    """First-fit placement of `event` into a plan of size `side_size`.

    Scans candidate offsets from 0 upward. Whenever `event` would overlap
    an existing allocation in both time ([start, end]) and address space,
    the candidate offset is bumped past that allocation and the scan
    repeats. On success sets event.ptr, appends event to side_plan and
    returns True; returns False when no offset fits within side_size.
    """
    offset = 0
    # Fix: use <= so an event that exactly fills the tail [offset, side_size)
    # is accepted; the original `<` rejected an exact fit and forced the
    # caller to grow the plan unnecessarily.
    while offset + event.size <= side_size:
        overlap = False
        for alloc in side_plan:
            # Lifetimes intersect?
            if (alloc.start <= event.end) and (alloc.end >= event.start):
                # Address ranges intersect?
                if (alloc.ptr <= offset + event.size - 1) and (alloc.ptr + alloc.size - 1 >= offset):
                    # Jump past this allocation; outer while rescans from here.
                    offset = alloc.ptr + alloc.size
                    overlap = True
        if not overlap:
            event.ptr = offset
            side_plan.append(event)
            return True
    return False
def print_allocs(allocs):
    """Debug helper: print "ptr size" for every allocation, one per line."""
    for alloc in allocs:
        print(alloc.ptr, alloc.size)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment