import os
import time
import queue
import threading
import random
import torch
import torchvision
from torchvision.transforms import v2
# Testing Hyperparameters #
# Batch geometry for the demo: items per batch and number of stacked synthetic input images.
demo_batchsize = 96
num_demo_layers = 18
# Size of the synthetic source images, and the side length of each final (square) training crop.
image_height, image_width = 14800, 18400
input_image_size = 512
# Number of untimed (warmup) and timed iterations in each benchmark loop.
num_warmup_iters = 20
num_timing_iters = 40
# Element type and placement of the synthetic source data, plus the accelerator target.
dtype = torch.float16
data_device = 'cpu'
device = 'cuda'
# Class Definitions #
class BatchedPreprocessingFunction():
    """Batched augmentation: random rotation, center crop, then random left-right flip.

    The input is expected to be pre-padded (height and width larger than
    ``image_size``) so that the center crop taken after the rotation removes the
    empty corners that the rotation introduces.
    """

    def __init__(self, *args, image_size, **kwargs):
        """Set up the augmentation.

        Args:
            image_size: Side length of the square crop returned by ``__call__``.
        """
        super().__init__(*args, **kwargs)
        # NOTE(review): torchvision v2 transforms appear to sample their random parameters once per
        # call, which would mean every item of a batch is rotated by the SAME angle. Confirm against
        # the torchvision docs; if per-item angles are required this op has to be applied per item.
        self.batched_rotate_op = torchvision.transforms.v2.RandomRotation(degrees=180, interpolation=torchvision.transforms.InterpolationMode.BILINEAR)
        self.flip_lr_chance = .5
        self.image_size = image_size

    def __call__(self, stacked_input_targets, cast_pre_rotate=False):
        """Augment a batch of stacked (input, target) image pairs.

        Args:
            stacked_input_targets: Tensor of shape (batchsize, 2, height, width), with
                height and width both >= ``image_size``.
            cast_pre_rotate: Rotate in float32 and cast the result back to the input's
                dtype afterwards (the CPU rotate does not support half precision).

        Returns:
            Tensor of shape (batchsize, 2, image_size, image_size).

        Raises:
            ValueError: If the input is spatially smaller than ``image_size``.
        """
        original_dtype = stacked_input_targets.dtype
        to_rotate = stacked_input_targets.float() if cast_pre_rotate else stacked_input_targets
        rotated = self.batched_rotate_op(to_rotate)
        if cast_pre_rotate:
            # Cast the ROTATED tensor back (not the already-consumed input), so the output dtype matches the input dtype.
            rotated = rotated.to(original_dtype)

        height_padding = rotated.shape[2] - self.image_size
        width_padding = rotated.shape[3] - self.image_size
        if height_padding < 0 or width_padding < 0:
            raise ValueError(f"input of spatial size {tuple(rotated.shape[2:])} is smaller than image_size={self.image_size}")
        # Center crop using explicit start/end indices; a negative end index would produce an empty
        # slice whenever there is no padding to trim (x[0:-0] == x[0:0]).
        height_start = height_padding // 2
        width_start = width_padding // 2
        rotated = rotated[:, :, height_start:height_start + self.image_size, width_start:width_start + self.image_size]

        # One independent flip decision per batch item, broadcast over channels, height and width.
        to_flip = (torch.rand((rotated.shape[0],), device=rotated.device) < self.flip_lr_chance).view(-1, 1, 1, 1)
        # Flip along the last (width) axis. torch.fliplr reverses dim 1, which is NOT the width axis
        # for tensors with more than two dimensions.
        return torch.where(to_flip, torch.flip(rotated, dims=(-1,)), rotated)
# Slight hack for now: the source data is defined at module scope (outside of this class) to make the forking easier.
class CustomImageDataset(torch.utils.data.Dataset):
    """Dataset that returns a randomly positioned (input, label) crop pair on every access.

    NOTE(review): the base class was reconstructed from a truncated class header;
    ``torch.utils.data.Dataset`` is assumed -- confirm against the original file.

    The source images are read from the module-level ``demo_input_label_data_pairs``
    rather than stored on the instance, so that forked worker processes do not each
    copy them.
    """

    def __init__(self, rotation_padding, per_image_preprocessing_fn=None):
        """Store the crop padding and the optional per-item preprocessing function.

        Args:
            rotation_padding: Extra pixels added to each crop side, leaving room for a
                later rotate-then-center-crop step.
            per_image_preprocessing_fn: Optional callable applied to every item; it is
                called with a batch of one and ``cast_pre_rotate=True``.
        """
        # TODO: 'rotation_padding' is used somewhat inconsistently in different places; clean it up.
        self.rotation_padding = rotation_padding
        self.per_image_preprocessing_fn = per_image_preprocessing_fn

    def __len__(self):
        """Return the nominal number of samples (4x what warmup plus timing consume)."""
        # Parenthesized so that both iteration counts are scaled to samples; previously the timing
        # iteration count was added as a bare number of samples.
        return 4 * demo_batchsize * (num_warmup_iters + num_timing_iters)

    def __getitem__(self, idx):
        """Return a stacked (input, label) crop; ``idx`` is ignored, sampling is random."""
        # Chooses randomly per label image, so sampling is uneven when label images have different
        # numbers of paired input images.
        label_image, input_images = random.choice(demo_input_label_data_pairs)
        input_image_randomly_selected = random.choice(input_images)
        # TODO: replace with a vectorized version.
        inputs_and_targets = slow_cpu_iterative_image_slicer(input_image_randomly_selected, label_image, rotation_padding=self.rotation_padding, chunk_size=input_image_size)
        if self.per_image_preprocessing_fn is not None:
            # Add a batch dimension of one for the batched function, then remove exactly that
            # dimension again (a bare squeeze() would also drop any other size-1 dimension).
            inputs_and_targets = self.per_image_preprocessing_fn(inputs_and_targets.unsqueeze(0), cast_pre_rotate=True).squeeze(0)
        return inputs_and_targets
# Co-written and modified w/ ChatGPT 4.0
class GPUPrefetcher:
    """Iterator that prefetches loader batches onto the GPU from background threads.

    Each worker thread pulls a batch from the shared loader iterator, copies it to
    the GPU on its own CUDA stream, applies ``batched_fn_to_apply`` there, and puts
    the finished batch into a bounded queue that ``__next__`` drains.

    NOTE(review): parts of this class (thread start-up, the body of ``_prefetch``,
    the stream handling) were reconstructed from a truncated source -- confirm
    against the original file.
    """

    def __init__(self, loader, batched_fn_to_apply, queue_size=12, num_threads=2):
        """Create the queue and streams, then start the prefetch threads.

        Args:
            loader: Iterable of CPU batch tensors.
            batched_fn_to_apply: Callable applied to every batch once it is on the GPU.
            queue_size: Maximum number of prepared batches held in the queue.
            num_threads: Number of prefetch threads (one CUDA stream each).
        """
        self.loader = iter(loader)
        self.queue = queue.Queue(maxsize=queue_size)  # Adjust size as needed
        self.is_running = True
        self.threads = []
        self.streams = [torch.cuda.Stream() for _ in range(num_threads)]
        self.batched_fn_to_apply = batched_fn_to_apply
        # The loader iterator is shared by all prefetch threads, so serialize access to it.
        self._loader_lock = threading.Lock()
        # Number of end-of-data sentinels (None) consumed so far, see __next__.
        self._producers_done = 0
        for i in range(num_threads):
            thread = threading.Thread(target=self._prefetch, args=(i,))
            thread.daemon = True
            self.threads.append(thread)
        # Start only after every thread is registered, so len(self.threads) is final before any sentinel can arrive.
        for thread in self.threads:
            thread.start()

    def _put(self, item):
        """Enqueue ``item``, giving up once the prefetcher is stopped; return whether it was enqueued."""
        while self.is_running:
            try:
                # Bounded wait so that a producer stuck on a full queue still notices a stop request.
                self.queue.put(item, timeout=0.1)
                return True
            except queue.Full:
                continue
        return False

    def _prefetch(self, thread_id):
        """Worker loop: fetch, upload and preprocess batches until the loader is exhausted or stopped."""
        stream = self.streams[thread_id]
        try:
            while self.is_running:
                try:
                    with self._loader_lock:
                        input_target_tensor = next(self.loader)
                except StopIteration:
                    break
                with torch.no_grad(), torch.cuda.stream(stream):
                    cuda_tensor = input_target_tensor.to(device='cuda', memory_format=torch.channels_last, non_blocking=True)
                    batch_prepped = self.batched_fn_to_apply(cuda_tensor)
                # Wait for this batch's work BEFORE publishing it, so the consumer never receives a batch that is still being computed.
                stream.synchronize()
                if not self._put(batch_prepped):
                    break
        finally:
            # Always announce that this producer is done (also when it dies with an error), so the consumer cannot wait forever.
            self._put(None)

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next prepared batch.

        Raises:
            StopIteration: Once every prefetch thread has finished, or the prefetcher
                was stopped and the queue has been drained.
        """
        with torch.no_grad():
            while True:
                if not self.is_running and self.queue.empty():
                    raise StopIteration
                # Logging self.queue.qsize() here is a handy way to watch queue health over a training run.
                next_input_target_tensor = self.queue.get()
                if next_input_target_tensor is not None:
                    return next_input_target_tensor
                # One sentinel per finished producer: only stop once ALL of them are done, otherwise a
                # batch that another thread is still finishing would be dropped.
                self._producers_done += 1
                if self._producers_done >= len(self.threads):
                    self.is_running = False
                    raise StopIteration

    def __del__(self):
        self.is_running = False
        current_thread = threading.current_thread()
        # getattr: __init__ may have failed before self.threads was assigned.
        for thread in getattr(self, 'threads', ()):
            if thread is not current_thread and thread.is_alive():
                # Bounded wait: a worker may be blocked inside the loader; the threads are daemons anyway.
                thread.join(timeout=1.0)

    def stop(self):
        """Ask the prefetch threads to stop and wake up a consumer waiting in ``__next__``."""
        self.is_running = False
        for _ in self.threads:
            try:
                self.queue.put_nowait(None)  # Signal end-of-data to the consumer
            except queue.Full:
                # A full queue means the consumer is not blocked waiting, so no sentinel is needed.
                break
# Checks a whole batch of candidate crop positions at once to see whether each crop's center pixel is (near) zero
# -- a very crude test for 'mostly dark' crops -- and returns the positions that pass.
def get_valid_image_offsets_batched(image, height_size, width_size, batchsize, mult_to_check=16):
    """Sample top-left offsets of crops whose center pixel is not (near) zero.

    Args:
        image: Tensor that is 2-D (height, width) once size-1 dimensions are squeezed.
        height_size: Crop height.
        width_size: Crop width.
        batchsize: Number of offsets wanted.
        mult_to_check: ``batchsize * mult_to_check`` candidates are drawn; increase it
            for images with many dark pixels.

    Returns:
        ``(height_offsets, width_offsets)``: two integer tensors with at most
        ``batchsize`` entries each (fewer, possibly none, if too few candidates were
        valid). For ``batchsize == 1`` both are squeezed to 0-dim tensors.
    """
    with torch.no_grad():
        zero_val = 1e-2
        # Work on the squeezed view throughout, so the shape that is read and the tensor that is indexed agree.
        image_2d = image.squeeze()  # Assuming 1 input image for now.
        max_height, max_width = image_2d.shape
        num_candidates = batchsize * mult_to_check
        # randint's upper bound is exclusive: +1 makes the last valid offset reachable and allows
        # images that are exactly crop-sized. Offsets live on the image's device so the masking below works there too.
        height_offsets = torch.randint(max_height - height_size + 1, size=(num_candidates,), device=image.device)
        width_offsets = torch.randint(max_width - width_size + 1, size=(num_candidates,), device=image.device)
        center_pixels = image_2d[height_offsets + height_size // 2, width_offsets + width_size // 2]
        valid_pixels = ~(center_pixels < zero_val)
        heights_filtered = torch.masked_select(height_offsets, valid_pixels)[:batchsize]
        widths_filtered = torch.masked_select(width_offsets, valid_pixels)[:batchsize]
        if batchsize == 1:  # single item: drop the batch dimension
            heights_filtered = heights_filtered.squeeze()
            widths_filtered = widths_filtered.squeeze()
        return heights_filtered, widths_filtered
# TODO: fold and fuse a few of these kernels.
# This should eventually be vmapped/vectorized.
def slow_cpu_iterative_image_slicer(input_image, label_image, chunk_size, rotation_padding, fliplr_chance=.5, rotate_degrees=180):
    """Cut the same randomly placed square window out of an input image and its label image.

    The window side is ``chunk_size + rotation_padding``; its position comes from
    ``get_valid_image_offsets_batched`` evaluated on ``input_image``.
    ``fliplr_chance`` and ``rotate_degrees`` are accepted but not used here.

    Returns:
        The two crops stacked along a new leading dimension, input first, label second.
    """
    window = chunk_size + rotation_padding
    with torch.no_grad():
        top, left = get_valid_image_offsets_batched(input_image, window, window, batchsize=1)
        rows = slice(top, top + window)
        cols = slice(left, left + window)
        crops = [image.squeeze()[rows, cols] for image in (input_image, label_image)]
        return torch.stack(crops, dim=0)
# The source data should live on the CPU so that the benchmark exercises the host-to-GPU path.
demo_input_data_raw = torch.ones((num_demo_layers, image_height, image_width), dtype=dtype, device=data_device)
demo_label_data_raw = torch.ones((1, image_height, image_width), dtype=dtype, device=data_device)
# Make only one label/image pair for now (one label as the index for a stacked tensor containing all training images of a dataset).
demo_input_label_data_pairs = ((demo_label_data_raw, demo_input_data_raw),)
# Required for the multiprocessing, to avoid duplicating these source images in each process.
# Both the label image and the stacked input images are moved to shared memory.
for label_image, input_images in demo_input_label_data_pairs:
    label_image.share_memory_()
    input_images.share_memory_()

# Dataloader Setup #
num_cpus = os.cpu_count()
# os.cpu_count() may return None, and prefetch_factor is only accepted together with at least one worker.
num_loader_workers = max(1, (num_cpus or 1) // 2)
# The multiplier is hardcoded for now; adjusting it trades memory/processing against rotation headroom.
rotation_padding = round(.35*input_image_size)
print("num cpus available! : <3 :'))))", num_cpus)
batched_preprocessing_fn = BatchedPreprocessingFunction(image_size=input_image_size)
non_prefetched_image_dataset = CustomImageDataset(rotation_padding=rotation_padding, per_image_preprocessing_fn=batched_preprocessing_fn)
prefetched_image_dataset = CustomImageDataset(rotation_padding=rotation_padding)
# NOTE(review): the DataLoader construction was reconstructed from truncated lines; torch.utils.data.DataLoader
# over the matching dataset is assumed -- confirm against the original file.
non_prefetched_train_dataset_gpu_loader = iter(torch.utils.data.DataLoader(non_prefetched_image_dataset, batch_size=demo_batchsize, drop_last=True, shuffle=True, num_workers=num_loader_workers, pin_memory=False, persistent_workers=False, prefetch_factor=2))

# Sorta hacky for now to unroll the speed tests like this, but just keeping a flat-ish structure for flexibility.
# Non-prefetched dataloader
for _ in range(num_warmup_iters):
    inputs, targets = next(non_prefetched_train_dataset_gpu_loader).to(device=device, memory_format=torch.channels_last, non_blocking=True).unsqueeze(2).unbind(1)
# CUDA work is asynchronous: synchronize before reading the clock so that only finished work is timed.
torch.cuda.synchronize()
non_prefetched_begin = time.perf_counter()
for _ in range(num_timing_iters):
    inputs, targets = next(non_prefetched_train_dataset_gpu_loader).to(device=device, memory_format=torch.channels_last, non_blocking=True).unsqueeze(2).unbind(1)
torch.cuda.synchronize()
non_prefetched_end = time.perf_counter()

# Prefetched dataloader
# Created only now: the prefetcher starts working as soon as it is constructed and would otherwise
# compete with the non-prefetched measurement above.
train_dataset_gpu_loader = iter(torch.utils.data.DataLoader(prefetched_image_dataset, batch_size=demo_batchsize, drop_last=True, shuffle=True, num_workers=num_loader_workers, pin_memory=False, persistent_workers=False, prefetch_factor=2))
train_dataset_gpu_prefetcher = GPUPrefetcher(train_dataset_gpu_loader, batched_fn_to_apply=batched_preprocessing_fn, queue_size=6, num_threads=2)
for _ in range(num_warmup_iters):
    inputs, targets = next(train_dataset_gpu_prefetcher).unsqueeze(2).unbind(1)
torch.cuda.synchronize()
prefetched_begin = time.perf_counter()
for _ in range(num_timing_iters):
    inputs, targets = next(train_dataset_gpu_prefetcher).unsqueeze(2).unbind(1)
torch.cuda.synchronize()
prefetched_end = time.perf_counter()
# Shut the background prefetching down now that the measurement is over.
train_dataset_gpu_prefetcher.stop()

non_prefetched_seconds_per_step = (non_prefetched_end-non_prefetched_begin)/num_timing_iters
prefetched_seconds_per_step = (prefetched_end-prefetched_begin)/num_timing_iters
print(f"Avg non-prefetched train dataset gpu loader time per step (in seconds):\t {non_prefetched_seconds_per_step}\t|")
print(f"Avg prefetched train dataset gpu loader time per step (in seconds):\t {prefetched_seconds_per_step}\t|")
print("\nSpeed factor of improvement: ", non_prefetched_seconds_per_step/prefetched_seconds_per_step)
