Last active
September 8, 2022 17:57
-
-
Save o-hanna/95b13a8f7ab750f09b656b1726ce9e70 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import numpy as np | |
def get_memory_plan(events, threshold = 1):
    """Compute one memory plan per entry of ``events``.

    ``events[i]`` is handed to ``memory_plan_for_device`` together with
    ``threshold``.  Returns ``(plans, sizes)`` where ``plans[i]`` is a shallow
    copy of the i-th plan and ``sizes[i]`` is the matching plan size.
    """
    # Index explicitly so any container supporting len() and integer
    # indexing is accepted, not only plain iterables.
    per_entry = [
        memory_plan_for_device(events[index], threshold)
        for index in range(len(events))
    ]
    plans = [plan.copy() for plan, _ in per_entry]
    sizes = [size for _, size in per_entry]
    return (plans, sizes)
class Event:
    """Lifetime record of a single allocation.

    ``size`` is the allocation size, ``start`` and ``end`` delimit its
    lifetime, and ``ptr`` is its offset inside the memory plan.
    """

    def __init__(self, size, start, end) -> None:
        self.size, self.start, self.end = size, start, end
        # -1 until a planner (find_plan / greedy_allocate) assigns an offset.
        self.ptr = -1
def memory_plan_for_device(events, threshold):  # returns a memory (plan, plan_size) to be passed to torch.cuda.set_memory_plan
    """Build a memory plan from one recorded sequence of alloc/free events.

    Strategy:
      1. Allocations whose lifetime ends at or after
         ``threshold * (len(events) - 1)`` are stacked back to back starting
         at offset 0 (see ``find_left_allocations`` / ``find_plan``).
      2. The remaining allocations are visited in order of increasing end
         index (earlier start first on equal end) and placed with
         ``greedy_allocate`` inside the current plan size.
      3. While some allocations are still unplaced after a pass, the plan
         size is grown by the size of the first unplaced allocation and the
         pass is repeated.

    Returns ``(memory_plan, plan_size)``.  ``memory_plan`` holds one
    ``torch.cuda.createAllocFreeEvent(ptr, size)`` result per allocation,
    ordered by start index; ``plan_size`` is the largest ``ptr + size`` over
    all allocations.  An empty ``events`` yields ``([], 0)``.

    NOTE(review): ``torch.cuda.createAllocFreeEvent`` is not part of stock
    PyTorch as far as this file shows -- presumably a custom build; confirm.
    """
    if len(events) == 0:
        return ([], 0)

    lifetimes = find_end_pts(events)  # list of allocation events of type Event
    placed, pending = find_left_allocations(
        lifetimes, threshold=threshold, seq_len=len(events)
    )
    # find_plan assigns offsets in place; `placed` is the growing plan.
    arena_size = find_plan(placed)

    # Greedily allocate the remaining allocations.
    # Complexity of this part could be reduced by using a tree instead of a list.
    pending.sort(key=lambda ev: ev.end - 1 / (2 + ev.start))
    while len(pending) > 0:
        still_pending = []
        for candidate in pending:
            # greedy_allocate appends the candidate to `placed` on success.
            if not greedy_allocate(candidate, placed, arena_size):
                still_pending.append(candidate)
        pending = still_pending
        if len(pending) > 0:
            # Make room for at least the first allocation that did not fit.
            arena_size += pending[0].size

    placed.sort(key=lambda ev: ev.start)
    plan_size = 0
    memory_plan = []
    for ev in placed:
        plan_size = max(plan_size, ev.ptr + ev.size)
        memory_plan.append(torch.cuda.createAllocFreeEvent(ev.ptr, ev.size))
    return (memory_plan, plan_size)
def find_end_pts(events):  # find start and end points of allocation events
    """Turn a raw alloc/free event sequence into allocation lifetimes.

    Each entry of ``events`` must expose ``.size`` and ``.ptr``.  An entry
    with ``size > 0`` is treated as an allocation at ``ptr``; any other entry
    is treated as a free of ``ptr`` and closes the most recent allocation
    recorded at that address.  Frees of unknown addresses are ignored.

    Returns a list of ``Event`` objects, one per allocation, in allocation
    order.  ``start`` is the index of the allocating entry and ``end`` is the
    index of the matching free, or ``len(events) - 1`` when the allocation is
    never freed.

    Lifetimes are collected in their own list instead of being re-filed in
    the pointer dictionary under synthetic keys (0, -1, -2, ...): those keys
    could collide with a real ``ptr`` value and silently drop a lifetime.
    """
    last_index = len(events) - 1
    lifetimes = []       # every allocation seen, in order of appearance
    latest_by_ptr = {}   # ptr -> most recent allocation made at that address
    for index, event in enumerate(events):
        if event.size > 0:
            lifetime = Event(event.size, index, last_index)
            # A reused address simply rebinds; the older lifetime stays in
            # `lifetimes` with whatever end index it already has.
            latest_by_ptr[event.ptr] = lifetime
            lifetimes.append(lifetime)
        elif event.ptr in latest_by_ptr:
            latest_by_ptr[event.ptr].end = index
    return lifetimes
def find_left_allocations(events_with_end_pts, threshold, seq_len):
    """Split allocations into (left, remaining) by how late they end.

    An allocation goes to the first list when its ``end`` is at least
    ``threshold * (seq_len - 1)``; otherwise it goes to the second list.
    Relative order within each list follows the input order.
    """
    long_lived, short_lived = [], []
    for alloc in events_with_end_pts:
        ends_late = alloc.end >= threshold * (seq_len - 1)
        bucket = long_lived if ends_late else short_lived
        bucket.append(alloc)
    return (long_lived, short_lived)
def find_plan(allocations):
    """Stack ``allocations`` back to back from offset 0 and return the total.

    The list is sorted in place by decreasing ``end`` (earlier ``start``
    first on equal ``end``); each allocation's ``ptr`` is then set to the sum
    of the sizes that precede it.  Returns the sum of all sizes.
    """
    allocations.sort(key=lambda alloc: -alloc.end - 1 / (2 + alloc.start))
    next_free = 0
    for alloc in allocations:
        alloc.ptr, next_free = next_free, next_free + alloc.size
    return next_free
def greedy_allocate(event, side_plan, side_size):
    """Try to place ``event`` inside ``[0, side_size)`` without any conflict.

    Scans upward from offset 0.  Whenever the candidate range
    ``[offset, offset + event.size)`` collides with an allocation in
    ``side_plan`` whose lifetime overlaps ``event``'s (``start``/``end``
    treated as inclusive), the offset jumps past that allocation and the
    scan restarts.

    On success sets ``event.ptr``, appends ``event`` to ``side_plan`` and
    returns ``True``.  Returns ``False`` (leaving both untouched) when no
    conflict-free offset exists within ``side_size``.
    """
    offset = 0
    # The candidate occupies units offset .. offset+size-1, so it fits as
    # long as offset+size <= side_size.  A strict `<` would reject a
    # placement that ends exactly at the plan boundary and force the caller
    # to grow the plan needlessly.
    while offset + event.size <= side_size:
        # find overlap
        overlap = False
        for alloc in side_plan:
            alive_together = (alloc.start <= event.end) and (alloc.end >= event.start)
            if alive_together:
                shares_space = (alloc.ptr <= offset + event.size - 1) and (
                    alloc.ptr + alloc.size - 1 >= offset
                )
                if shares_space:
                    # Skip past the conflicting allocation; later entries are
                    # checked against the new offset, earlier ones are
                    # re-checked on the next pass of the while loop.
                    offset = alloc.ptr + alloc.size
                    overlap = True
        if not overlap:
            event.ptr = offset
            side_plan.append(event)
            return True
    return False
def print_allocs(allocs):
    """Print ``ptr`` and ``size`` of every allocation, one per line."""
    for alloc in allocs:
        offset, extent = alloc.ptr, alloc.size
        print(offset, extent)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment