Skip to content

Instantly share code, notes, and snippets.

@Tony363
Last active October 3, 2022 05:56
Show Gist options
  • Save Tony363/a308639c164551216d42bae448710543 to your computer and use it in GitHub Desktop.
class Loader(object):
    """Multiprocess video-batch loader.

    Producer processes read whole videos from disk and push
    ``(label, frames)`` pairs onto a shared queue; consumer processes drain
    that queue and assemble fixed-size batches which ``__next__`` returns
    as ``(X, y)`` torch tensors.
    """

    def __init__(
        self,
        non_flicker_dir: str,
        flicker_dir: str,
        labels: dict,
        batch_size: int,
        in_mem_batches: int,
    ) -> None:
        """
        Args:
            non_flicker_dir: directory of negative-class ``.mp4`` files.
            flicker_dir: directory of positive-class ``.mp4`` files.
            labels: maps video name -> container of flicker frame indices
                (membership of the frame index string yields the label).
            batch_size: samples per yielded batch (half drawn from each class).
            in_mem_batches: number of batches' worth of videos loaded per chunk.
        """
        # "spawn" gives clean child interpreters; NOTE set_start_method raises
        # RuntimeError if a start method was already set in this process.
        mp.set_start_method("spawn")
        self.labels = labels
        self.batch_size = batch_size
        self.batch_idx = self.cur_batch = 0
        self.in_mem_batches = in_mem_batches
        self.non_flicker_dir = non_flicker_dir
        self.flicker_dir = flicker_dir
        self.non_flicker_lst = os.listdir(non_flicker_dir)
        self.flicker_lst = os.listdir(flicker_dir)
        # Manager-backed queues/locks are shareable across spawned processes.
        self.manager = mp.Manager()
        self.producer_q = self.manager.Queue()
        self.out_x = self.manager.Queue()
        self.out_y = self.manager.Queue()
        self.lock = self.manager.Lock()
        self.event = None
        self.producers = self.consumers = None

    def __len__(self) -> int:
        # Number of chunk refills needed to sweep the negative file list once.
        return len(self.non_flicker_lst) // ((self.batch_size // 2) * self.in_mem_batches) + 1

    def __iter__(self):
        return self

    def __next__(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the next ``(X, y)`` batch, refilling workers as needed.

        Raises:
            StopIteration: once ``batch_idx`` has swept past the negative
                file list (one epoch).
        """
        print("EPOCH STOP CONDITION ", self.batch_idx,
              len(self.non_flicker_lst), flush=True)
        if self.batch_idx > len(self.non_flicker_lst):
            gc.collect()
            raise StopIteration
        print(
            f"OUT {not self.out_x.empty()} - {not self.out_y.empty()}", flush=True)
        if not bool(self.cur_batch):
            # Current chunk exhausted: tear down the old worker pool and
            # spawn a fresh one over the next slice of file paths.
            if self.consumers is not None and self.event is not None and self.producers is not None:
                self.event.set()
                for p in self.producers:
                    p.terminate()
                for c in self.consumers:
                    c.terminate()
            self.event = mp.Event()
            non_flickers = [
                os.path.join(self.non_flicker_dir,
                             self.non_flicker_lst[i % len(self.non_flicker_lst)])
                for i in range(self.batch_idx, self.batch_idx + (self.batch_size // 2) * self.in_mem_batches)
            ]
            flickers = [
                os.path.join(self.flicker_dir,
                             self.flicker_lst[i % len(self.flicker_lst)])
                for i in range(self.batch_idx, self.batch_idx + (self.batch_size // 2) * self.in_mem_batches)
            ]
            chunk_lst = non_flickers + flickers
            random.shuffle(chunk_lst)
            self.producers, self.consumers = self._load(chunk_lst)
            self.cur_batch = self.in_mem_batches
            self.batch_idx += (self.batch_size // 2) * self.in_mem_batches
            gc.collect()
        self.cur_batch -= 1
        X, y = self.out_x.get(), self.out_y.get()
        return torch.from_numpy(X.astype(np.float32)), torch.from_numpy(y.astype(np.uint8))

    def _shuffle(self) -> None:
        """Shuffle both class file lists in place (call between epochs)."""
        random.shuffle(self.non_flicker_lst)
        random.shuffle(self.flicker_lst)
        gc.collect()

    def _load(
        self,
        chunk_lst: list,
    ) -> tuple:
        """Spawn producer and consumer processes over ``chunk_lst`` paths.

        Returns:
            (producers, consumers) tuples of started ``mp.Process`` objects.
        """
        split = np.split(np.array(chunk_lst), self.in_mem_batches)
        producers = tuple(
            mp.Process(
                target=self._producers,
                args=(chunk, self.labels, self.producer_q, self.lock)
            )
            for _, chunk in zip(range(self.in_mem_batches), split)
        )
        # Consumers take whatever cores the producers do not occupy.
        consumers = tuple(mp.Process(
            target=self._consumers,
            args=(self.producer_q, self.out_x, self.out_y,
                  self.batch_size, self.lock, self.event))
            for _ in range(os.cpu_count() - self.in_mem_batches)
        )
        for c in consumers:
            c.daemon = True
            c.start()
        for p in producers:
            p.start()
        print("LOADED")
        return producers, consumers

    @staticmethod
    def _producers(
        cur_batch_lst: list,
        labels: dict,
        producer_q: mp.Queue,
        lock: mp.Lock,
    ) -> None:
        """Read each video file and enqueue ``(label, frames)`` pairs.

        Filenames are expected to look like ``<frame_idx>_<video_name>.mp4``.
        """
        for path in cur_batch_lst:
            # basename is robust to any directory depth and separator;
            # the previous path.split("/")[3] assumed an exact depth.
            idx, vid_name = os.path.basename(path).replace(
                ".mp4", "").split("_", 1)
            producer_q.put(
                (int(idx in labels[vid_name]), skvideo.io.vread(path)))
            with lock:
                print(f"PRODUCER {vid_name} {os.getpid()}", flush=True)

    @staticmethod
    def _consumers(
        q: mp.Queue,
        out_x: mp.Queue,
        out_y: mp.Queue,
        batch_size: int,
        lock: mp.Lock,
        event: mp.Event
    ) -> None:
        """Accumulate producer output into batches on out_x/out_y until
        ``event`` is set."""
        X = y = ()
        while not event.is_set():
            if len(X) == len(y) == batch_size:
                out_x.put(np.array(X))
                out_y.put(np.array(y))
                X = y = ()
                with lock:
                    print(
                        f"CONSUMER NEW BATCH {os.getpid()}", flush=True)
            try:
                label, frames = q.get(timeout=5)
            except queue.Empty:
                # Fix: previously fell through and appended unbound names,
                # raising NameError on the first queue timeout.
                cpu_stats()
                continue
            X += (frames,)
            y += (label,)
def cpu_stats():
    """Print system CPU load, virtual-memory stats, and this process's
    resident memory (in GiB) — a quick debugging aid for the loader workers."""
    print("CPU USAGE - ", psutil.cpu_percent())
    print("MEMORY USAGE - ", psutil.virtual_memory())  # physical memory usage
    proc = psutil.Process(os.getpid())
    # Resident set size, converted from bytes to gibibytes.
    rss_gb = proc.memory_info()[0] / 2. ** 30
    print('memory GB:', rss_gb)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment