Last active
October 3, 2022 05:56
-
-
Save Tony363/a308639c164551216d42bae448710543 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Loader(object):
    """Iterator yielding ``(X, y)`` tensor batches of flicker / non-flicker videos.

    Files are consumed in "chunks" of ``(batch_size // 2) * in_mem_batches``
    paths from each of the two directories.  For every chunk a set of producer
    processes decodes the videos and a set of consumer processes stacks the
    decoded samples into batches of exactly ``batch_size`` items, which
    ``__next__`` then pulls from manager-backed queues.

    NOTE(review): because each class contributes ``batch_size // 2`` samples
    per batch, an odd ``batch_size`` leaves every chunk short of
    ``in_mem_batches`` full batches -- presumably callers always pass an even
    value; confirm.
    NOTE(review): every consumer fills its own private batch, so with several
    consumers a chunk can end up spread over partial batches that never reach
    ``batch_size`` -- verify against the intended worker layout.
    """

    def __init__(
        self,
        non_flicker_dir: str,
        flicker_dir: str,
        labels: dict,
        batch_size: int,
        in_mem_batches: int,
    ) -> None:
        """Index both directories and create the shared queues.

        Args:
            non_flicker_dir: directory listing the non-flicker video files.
            flicker_dir: directory listing the flicker video files.
            labels: mapping ``video name -> container of index strings``; a
                file ``<idx>_<video name>.mp4`` is labelled 1 when ``idx`` is
                in ``labels[video name]``, otherwise 0.
            batch_size: number of samples per emitted batch (at least 2).
            in_mem_batches: number of batches prepared per chunk (at least 1).

        Raises:
            ValueError: if ``batch_size`` or ``in_mem_batches`` is too small
                to ever produce a batch.
        """
        # Fail fast: with these values a chunk is empty (or the chunk size is
        # zero) and iteration could only hang or divide by zero later on.
        if batch_size < 2:
            raise ValueError("batch_size must be at least 2")
        if in_mem_batches < 1:
            raise ValueError("in_mem_batches must be at least 1")
        try:
            mp.set_start_method("spawn")
        except RuntimeError:
            # The start method can only be chosen once per interpreter; a
            # second Loader (or earlier user code) must not crash here.
            pass
        self.labels = labels
        self.batch_size = batch_size
        # batch_idx: offset of the next chunk in the file lists.
        # cur_batch: batches of the current chunk not yet handed out.
        self.batch_idx = self.cur_batch = 0
        self.in_mem_batches = in_mem_batches
        self.non_flicker_dir = non_flicker_dir
        self.flicker_dir = flicker_dir
        self.non_flicker_lst = os.listdir(non_flicker_dir)
        self.flicker_lst = os.listdir(flicker_dir)
        self.manager = mp.Manager()
        self.producer_q = self.manager.Queue()  # (label, decoded video) pairs
        self.out_x = self.manager.Queue()       # stacked inputs, one per batch
        self.out_y = self.manager.Queue()       # stacked labels, one per batch
        self.lock = self.manager.Lock()
        self.event = None
        self.producers = self.consumers = None

    def __len__(self) -> int:
        """Return the number of chunks needed to walk the non-flicker list."""
        return len(self.non_flicker_lst) // self._chunk_size() + 1

    def __iter__(self):
        """Return ``self``; the position counters are not reset here."""
        return self

    def __next__(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the next ``(float32 inputs, uint8 labels)`` batch.

        Raises:
            StopIteration: once ``batch_idx`` has moved past the end of the
                non-flicker list.
        """
        print("EPOCH STOP CONDITION ", self.batch_idx,
              len(self.non_flicker_lst), flush=True)
        if self.batch_idx > len(self.non_flicker_lst):
            gc.collect()
            raise StopIteration
        print(
            f"OUT {not self.out_x.empty()} - {not self.out_y.empty()}", flush=True)
        if not self.cur_batch:
            # Current chunk is used up: retire its workers and start the next.
            self._stop_workers()
            self.event = mp.Event()
            chunk_lst = self._chunk_paths()
            random.shuffle(chunk_lst)
            self.producers, self.consumers = self._load(chunk_lst)
            self.cur_batch = self.in_mem_batches
            self.batch_idx += self._chunk_size()
            gc.collect()
        self.cur_batch -= 1
        # Blocks until a consumer has published a complete batch.
        X, y = self.out_x.get(), self.out_y.get()
        return torch.from_numpy(X.astype(np.float32)), torch.from_numpy(y.astype(np.uint8))

    def _chunk_size(self) -> int:
        """Return how many files of each class make up one chunk."""
        return (self.batch_size // 2) * self.in_mem_batches

    def _chunk_paths(self) -> list:
        """Return the paths of the next chunk, non-flicker files first.

        Indices wrap around with a modulo, so either list may be shorter than
        the requested range.
        """
        span = range(self.batch_idx, self.batch_idx + self._chunk_size())
        non_flickers = [
            os.path.join(self.non_flicker_dir,
                         self.non_flicker_lst[i % len(self.non_flicker_lst)])
            for i in span
        ]
        flickers = [
            os.path.join(self.flicker_dir,
                         self.flicker_lst[i % len(self.flicker_lst)])
            for i in span
        ]
        return non_flickers + flickers

    def _stop_workers(self) -> None:
        """Signal, terminate and reap the workers of the previous chunk."""
        if self.consumers is None or self.event is None or self.producers is None:
            return
        self.event.set()
        workers = tuple(self.producers) + tuple(self.consumers)
        for worker in workers:
            worker.terminate()
        for worker in workers:
            # Reap the terminated children so they do not linger as zombies.
            worker.join()

    def _shuffle(self) -> None:
        """Shuffle both file lists in place."""
        random.shuffle(self.non_flicker_lst)
        random.shuffle(self.flicker_lst)
        gc.collect()

    def _load(
        self,
        chunk_lst: list,
    ) -> tuple:
        """Start the producer and consumer processes for one chunk.

        Args:
            chunk_lst: paths of the chunk; its length must be divisible by
                ``in_mem_batches`` (``np.split`` requires equal sections).

        Returns:
            ``(producers, consumers)`` as tuples of started processes.
        """
        split = np.split(np.array(chunk_lst), self.in_mem_batches)
        producers = tuple(
            mp.Process(
                target=self._producers,
                args=(chunk, self.labels, self.producer_q, self.lock)
            )
            for chunk in split
        )
        # os.cpu_count() may be None, and on small machines the difference can
        # be zero or negative; without a consumer __next__ would block forever.
        n_consumers = max(1, (os.cpu_count() or 1) - self.in_mem_batches)
        consumers = tuple(
            mp.Process(
                target=self._consumers,
                args=(self.producer_q, self.out_x, self.out_y,
                      self.batch_size, self.lock, self.event))
            for _ in range(n_consumers)
        )
        # Consumers go first so they are ready when the first sample arrives.
        for c in consumers:
            c.daemon = True
            c.start()
        for p in producers:
            p.start()
        print("LOADED")
        return producers, consumers

    @staticmethod
    def _parse_name(path: str) -> tuple:
        """Split ``.../<idx>_<video name>.mp4`` into ``(idx, video name)``.

        Only the final path component is inspected, so the result does not
        depend on how deeply the data directory is nested.

        Raises:
            ValueError: if the file name contains no underscore.
        """
        idx, vid_name = os.path.basename(str(path)).replace(
            ".mp4", "").split("_", 1)
        return idx, vid_name

    @staticmethod
    def _producers(
        cur_batch_lst: list,
        labels: dict,
        producer_q: mp.Queue,
        lock: mp.Lock,
    ) -> None:
        """Decode every video in ``cur_batch_lst`` and queue it with its label.

        Each queue item is ``(label, video)`` where ``label`` is 1 when the
        file's index is listed in ``labels[video name]`` and 0 otherwise.
        """
        for path in cur_batch_lst:
            idx, vid_name = Loader._parse_name(path)
            producer_q.put(
                (int(idx in labels[vid_name]), skvideo.io.vread(path)))
            with lock:
                print(f"PRODUCER {vid_name} {os.getpid()}", flush=True)

    @staticmethod
    def _consumers(
        q: mp.Queue,
        out_x: mp.Queue,
        out_y: mp.Queue,
        batch_size: int,
        lock: mp.Lock,
        event: mp.Event
    ) -> None:
        """Collect samples from ``q`` and publish batches of ``batch_size``.

        Runs until ``event`` is set.  A batch is published as one array on
        ``out_x`` and one on ``out_y``.
        """
        X, y = [], []
        while not event.is_set():
            if len(X) == len(y) == batch_size:
                batch_x, batch_y = np.array(X), np.array(y)
                # Publish both halves under the lock: the reader pops out_x
                # and out_y independently, so interleaved puts from two
                # consumers would pair inputs with the wrong labels.
                with lock:
                    out_x.put(batch_x)
                    out_y.put(batch_y)
                    print(
                        f"CONSUMER NEW BATCH {os.getpid()}", flush=True)
                X, y = [], []
            try:
                label, sample = q.get(timeout=5)
            except queue.Empty:
                cpu_stats()
                # Nothing was received; never append stale or unbound values.
                continue
            X.append(sample)
            y.append(label)
def cpu_stats():
    """Print CPU load, system memory statistics and this process's memory use."""
    print("CPU USAGE - ", psutil.cpu_percent())
    # System-wide physical memory statistics as reported by psutil.
    print("MEMORY USAGE - ", psutil.virtual_memory())
    current = psutil.Process(os.getpid())
    # First field of memory_info(), scaled from bytes by 2**30.
    used = current.memory_info()[0] / 2. ** 30
    print('memory GB:', used)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment