@prerakmody
Last active June 1, 2024 14:12
TORCH MULTIPROCESSING
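Benchmarks torch.utils.data.DataLoader against a hand-rolled torch.multiprocessing dataloader on a toy medical-imaging workload (load a patient's 3D volume once, then serve individual 2D slices from it). The gist contains three files, reproduced in order below.
##################################################
# File 1/3 - Benchmark script (imports the two modules below as torchDataloader and myDataloader)
##################################################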
import torchDataloader
import myDataloader
import pdb
import copy
import time
import tqdm
import pprint
import traceback
import numpy as np
import torch
KEY_WORKERS = 'workers'
KEY_TIMELIST = 'timeList'
KEY_ITERPERSEC = 'iterPerSec'
KEY_EPOCHS = 'epochs'
KEY_BATCH_SIZE = 'batch_size'
KEY_TYPE = 'dataloader-type'
KEY_WB = 'Workers-BatchSize'
STR_TORCHDATALOADER = 'torchDataloader'
STR_MYDATALOADER = 'myDataloader'
nameCPU = 'cpu'
nameGPU = 'cuda:0'
deviceCPU = torch.device("cpu")
deviceGPU = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def calculateTime(datasetLen, dataloader, epochs, workerCount, batchSize, device=deviceCPU, meta=''):
timeForEpochs = []
try:
# Step 1 - Loop over epochs
print ('')
for epoch in range(epochs):
with tqdm.tqdm(total=datasetLen, desc=' --------------------- [{}][W={}, B={}] Epoch {}/{}'.format(meta, workerCount, batchSize, epoch+1, epochs)) as pbar:
t1 = time.time()
# Step 2 - Loop over dataloader
for i, (patientSliceArray, batchMeta) in enumerate(dataloader):
patientSliceArray = patientSliceArray.to(device)
pbar.update(patientSliceArray.shape[0])
t2 = time.time()
timeForEpochs.append(t2-t1)
pbar.close()
except:
traceback.print_exc()
pdb.set_trace()
return timeForEpochs
if __name__ == "__main__":
try:
############################################################### Step 1 - Declare results object and other variables
resultsObj = {KEY_WORKERS: [], KEY_BATCH_SIZE: [], KEY_TIMELIST: [], KEY_ITERPERSEC: []}
# Step 1 - Setup patient slices (hard-coded slice IDs per patient)
patientSlicesList = {
'P1': [45, 62, 32, 21, 69]
, 'P2': [13, 23, 87, 54, 5]
, 'P3': [34, 56, 78, 90, 12]
, 'P4': [34, 56, 78, 90, 12]
, 'P5': [45, 62, 32, 21, 69]
, 'P6': [13, 23, 87, 54, 5]
, 'P7': [34, 56, 78, 90, 12]
, 'P8': [34, 56, 78, 90, 12, 21]
}
# workerCountList, batchSizeList, totalEpochs = [1, 2, 4, 8], [1, 2, 4, 8], 6
workerCountList, batchSizeList, totalEpochs = [1,2,4, 8], [1,2,3,4,6], 2
device, deviceName = deviceCPU, nameCPU # deviceGPU
############################################################## Step 2 - Test torchDataLoader
torchDataloaderResultsObj = copy.deepcopy(resultsObj)
if 1:
torchDatasetObj = torchDataloader.myDataset(patientSlicesList)
torchDatasetLen = len(torchDatasetObj)
for workerCount in workerCountList:
for batchSize in batchSizeList:
torchDataloaderResultsObj[KEY_WORKERS].append(workerCount)
torchDataloaderResultsObj[KEY_BATCH_SIZE].append(batchSize)
dataloader1 = torch.utils.data.DataLoader(torchDatasetObj, batch_size=batchSize, num_workers=workerCount)
timeForEpochs = calculateTime(torchDatasetLen, dataloader1, totalEpochs, workerCount, batchSize, device, meta='torchDataloader')
itersPerSec = [torchDatasetLen / timeForEpoch for timeForEpoch in timeForEpochs]
torchDataloaderResultsObj[KEY_TIMELIST].append(timeForEpochs)
torchDataloaderResultsObj[KEY_ITERPERSEC].append(itersPerSec)
############################################################## Step 3 - Test myDataLoader
myDataloaderResultsObj = copy.deepcopy(resultsObj)
if 1:
for workerCount in workerCountList:
for batchSize in batchSizeList:
myDataloaderResultsObj[KEY_BATCH_SIZE].append(batchSize)
myDataloaderResultsObj[KEY_WORKERS].append(workerCount)
myDataloaderObj = myDataloader.myDataloaderClass(patientSlicesList, numWorkers=workerCount, batchSize=batchSize)
myDataloaderLen = len(myDataloaderObj)
timeForEpochs = calculateTime(myDataloaderLen, myDataloaderObj, totalEpochs, workerCount, batchSize, device, meta='myDataloader')
itersPerSec = [myDataloaderLen / timeForEpoch for timeForEpoch in timeForEpochs]
myDataloaderResultsObj[KEY_TIMELIST].append(timeForEpochs)
myDataloaderResultsObj[KEY_ITERPERSEC].append(itersPerSec)
time.sleep(2)
myDataloaderObj.closeProcesses()
############################################################### Step 5 - Print results
if 1:
# Step 5.1 - convert to dataframe
import pandas as pd
torchDataloaderResultsObj = pd.DataFrame(torchDataloaderResultsObj)
myDataloaderResultsObj = pd.DataFrame(myDataloaderResultsObj)
# Step 5.2 - Perform a mean operation on KEY_TIMELIST and KEY_ITERPERSEC
torchDataloaderResultsObj[KEY_TIMELIST] = torchDataloaderResultsObj[KEY_TIMELIST].apply(lambda x: np.mean(x))
myDataloaderResultsObj[KEY_TIMELIST] = myDataloaderResultsObj[KEY_TIMELIST].apply(lambda x: np.mean(x))
torchDataloaderResultsObj[KEY_ITERPERSEC] = torchDataloaderResultsObj[KEY_ITERPERSEC].apply(lambda x: np.mean(x))
myDataloaderResultsObj[KEY_ITERPERSEC] = myDataloaderResultsObj[KEY_ITERPERSEC].apply(lambda x: np.mean(x))
# Step 5.3 - Merge the KEY_WORKERS and KEY_BATCH_SIZE columns
torchDataloaderResultsObj[KEY_WORKERS] = torchDataloaderResultsObj[KEY_WORKERS].astype(str)
torchDataloaderResultsObj[KEY_BATCH_SIZE] = torchDataloaderResultsObj[KEY_BATCH_SIZE].astype(str)
torchDataloaderResultsObj[KEY_WB] = torchDataloaderResultsObj[KEY_WORKERS] + '-' + torchDataloaderResultsObj[KEY_BATCH_SIZE]
torchDataloaderResultsObj.drop([KEY_WORKERS, KEY_BATCH_SIZE], axis=1, inplace=True)
myDataloaderResultsObj[KEY_WORKERS] = myDataloaderResultsObj[KEY_WORKERS].astype(str)
myDataloaderResultsObj[KEY_BATCH_SIZE] = myDataloaderResultsObj[KEY_BATCH_SIZE].astype(str)
myDataloaderResultsObj[KEY_WB] = myDataloaderResultsObj[KEY_WORKERS] + '-' + myDataloaderResultsObj[KEY_BATCH_SIZE]
myDataloaderResultsObj.drop([KEY_WORKERS, KEY_BATCH_SIZE], axis=1, inplace=True)
# Step 5.4 - Add meta info
torchDataloaderResultsObj[KEY_TYPE] = STR_TORCHDATALOADER
myDataloaderResultsObj[KEY_TYPE] = STR_MYDATALOADER
# Step 5.5 - Concatenate the two dataframes
vizObj = pd.concat([torchDataloaderResultsObj, myDataloaderResultsObj], axis=0)
# Step 5.6 - Plot (with seaborn) the dataframe using KEY_WB as the x-axis
import seaborn as sns
import matplotlib.pyplot as plt
f,axarr = plt.subplots(1,2, figsize=(15, 5))
sns.barplot(x=KEY_WB, y=KEY_TIMELIST, data=vizObj, hue=KEY_TYPE, ax=axarr[0])
axarr[0].set_title('Total Time (average over {} epochs)'.format(totalEpochs))
sns.barplot(x=KEY_WB, y=KEY_ITERPERSEC, data=vizObj, hue=KEY_TYPE, ax=axarr[1])
axarr[1].set_title('Iterations/second (average over {} epochs)'.format(totalEpochs))
plt.savefig('dataloaderCompare__W{}__B{}.png'.format('-'.join(map(str, workerCountList)), '-'.join(map(str, batchSizeList))))
plt.show(block=False)
pdb.set_trace()
except:
traceback.print_exc()
pdb.set_trace()
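##################################################
# File 2/3 - torchDataloader.py (module name inferred from the `import torchDataloader` in the benchmark script)
##################################################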
import tqdm
import time
import torch # v1.12.1
import numpy as np
##################################################
# myDataset
##################################################
def getPatientArray(patientName, verbose=False):
time.sleep(1.0) # emulates disk read time
workerId = torch.utils.data.get_worker_info().id if torch.utils.data.get_worker_info() else 0
if verbose:
print (' - [getPatientArray()][worker={}] Loading volumes for patient: {}'.format(workerId, patientName))
return torch.tensor(np.random.rand(1, 128, 128, 128))
def getPatientSliceArray(patientName, sliceId, patientArray=None, verbose=False):
time.sleep(0.1) # emulates processing time
if patientArray is None:
patientArray = getPatientArray(patientName, verbose=verbose)
sliceArray = patientArray[:, sliceId, :, :]
return patientArray, sliceArray
else:
sliceArray = patientArray[:, sliceId, :, :]
return None, sliceArray
class myDataset(torch.utils.data.Dataset):
def __init__(self, patientSlicesList, verbose=False, patientsInMemory=1):
self.patientSlicesList = patientSlicesList
self.verbose = verbose
self.patientsInMemory = patientsInMemory
self.patientObj = {} # Stores one patient's 3D array; keeping more patients in memory increases memory usage.
self.inputQueue = []
for patientName in self.patientSlicesList:
for sliceId in self.patientSlicesList[patientName]:
self.inputQueue.append((patientName, sliceId))
def __len__(self):
return len(self.inputQueue)
def _managePatientObj(self):
if len(self.patientObj) > self.patientsInMemory:
self.patientObj.pop(list(self.patientObj.keys())[0])
def __getitem__(self, idx):
# workerId = torch.utils.data.get_worker_info().id if torch.utils.data.get_worker_info() else 0
# print (' - [__getitem__()][worker={}] idx: {}'.format(workerId, idx))
# Step 0 - Init
patientName, sliceId = self.inputQueue[idx]
# Step 1 - Get patient slice array
patientArrayThis = self.patientObj.get(patientName, None)
patientArray, patientSliceArray = getPatientSliceArray(patientName, sliceId, patientArray=patientArrayThis, verbose=self.verbose)
if patientArray is not None:
self.patientObj[patientName] = patientArray
self._managePatientObj()
return patientSliceArray, [patientName, sliceId]
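# A quick illustration (hypothetical slice IDs) of why consecutive slices of the same
# patient are cheap: the second lookup reuses the 3D volume cached in self.patientObj.
#   ds = myDataset({'P1': [10, 11]}, verbose=True, patientsInMemory=1)
#   _ = ds[0]   # ~1.1 s: loads the P1 volume (1.0 s) and slices it (0.1 s)
#   _ = ds[1]   # ~0.1 s: volume already cached, only the slicing cost remains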
##################################################
# Main
##################################################
if __name__ == '__main__':
# Step 1 - Setup patient slices (hard-coded slice IDs per patient)
patientSlicesList = {
'P1': [45, 62, 32, 21, 69]
, 'P2': [13, 23, 87, 54, 5]
, 'P3': [34, 56, 78, 90, 12]
, 'P4': [34, 56, 78, 90, 12]
}
workerCount, batchSize, epochs = 4, 3, 3
# Step 2.1 - Create dataset and dataloader
dataset = myDataset(patientSlicesList)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batchSize, num_workers=workerCount)
# Step 2.2 - Iterate over dataloader
print ('\n - [main] Iterating over (my) dataloader...')
t0 = time.time()
epochTimes = []
for epochId in range(epochs):
print (' - [main] --------------------------------------- Epoch {}/{}'.format(epochId+1, epochs))
with tqdm.tqdm(total=len(dataset), desc=' - Epoch {}/{}'.format(epochId+1, epochs)) as pbar:
t1 = time.time()
for i, (patientSliceArray, meta) in enumerate(dataloader):
print (' - [main] meta: ', meta)
pbar.update(patientSliceArray.shape[0])
epochTimes.append(time.time()-t1)
print (' - [main] --------------------------------------- ')
print (' - [main] Total time: {:.2f} s (avg = {:.2f} s)'.format(time.time()-t0, np.mean(epochTimes)))
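##################################################
# File 3/3 - myDataloader.py (module name inferred from the `import myDataloader` in the benchmark script)
##################################################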
import time
import tqdm
import torch
import traceback
import numpy as np
import torch.multiprocessing as torchMP
import queue # for queue.Empty, raised by Queue.get()/get_nowait() on timeout
def getPatientArray(patientName, workerId, verbose=False):
time.sleep(1.0) # emulates disk read time
# Note: workerId is passed in by the spawning worker process (torch.utils.data.get_worker_info() is not set outside DataLoader workers)
if verbose:
print (' - [getPatientArray()][worker={}] Loading volumes for patient: {}'.format(workerId, patientName))
return torch.tensor(np.random.rand(1, 128, 128, 128))
def getPatientSliceArray(patientName, sliceId, workerId, patientArray=None, verbose=False):
time.sleep(0.1) # emulates processing time
if patientArray is None:
patientArray = getPatientArray(patientName, workerId, verbose=verbose)
sliceArray = patientArray[:, sliceId, :, :]
return patientArray, sliceArray
else:
sliceArray = patientArray[:, sliceId, :, :]
return None, sliceArray
QUEUE_TIMEOUT = 5.0
class myDataloaderClass:
def __init__(self, patientSlicesList, numWorkers, batchSize, verbose=False) -> None:
self.patientSlicesList = patientSlicesList
self.numWorkers = numWorkers
self.batchSize = batchSize
self.verbose = verbose
self._initWorkers()
def _initWorkers(self):
# Step 1 - Initialize vars
self.workerProcesses = []
self.workerInputQueues = [torchMP.Queue() for _ in range(self.numWorkers)]
self.workerOutputQueue = torchMP.Queue()
self.workerProcessingEvent = [torchMP.Event() for _ in range(self.numWorkers)]
self.workerStopEvent = torchMP.Event() # used in getSlice() and self.closeProcesses()
self.mpLock = torchMP.Lock() # used in self.closeProcesses()
for workerId in range(self.numWorkers):
if self.verbose:
print (' - [myDataloaderClass._initWorkers()] Starting worker {}'.format(workerId))
p = torchMP.Process(target=getSlice, args=(workerId, self.workerInputQueues[workerId], self.workerOutputQueue
, self.workerProcessingEvent[workerId], self.workerStopEvent, self.verbose)
, daemon=True)
p.start()
self.workerProcesses.append(p)
def __len__(self):
counter = 0
for patientName in self.patientSlicesList:
counter += len(self.patientSlicesList[patientName])
return counter
def fillInputQueues(self):
"""
Splits the patients (and their slices) across the per-worker input queues
"""
patientNames = list(self.patientSlicesList.keys())
for workerId in range(self.numWorkers):
startIdx = workerId * len(patientNames) // self.numWorkers
endIdx = (workerId + 1) * len(patientNames) // self.numWorkers
# print (' - [myDataloaderClass.fillInputQueues()] Worker {} will process patients {} to {}: {}'.format(workerId, startIdx, endIdx, patientNames[startIdx:endIdx]))
for patientName in patientNames[startIdx:endIdx]:
for sliceId in self.patientSlicesList[patientName]:
self.workerInputQueues[workerId].put((patientName, sliceId))
def emptyAllQueues(self):
try:
for workerId in range(self.numWorkers):
while not self.workerInputQueues[workerId].empty():
try: _ = self.workerInputQueues[workerId].get_nowait()
except queue.Empty: break
while not self.workerOutputQueue.empty():
try: _ = self.workerOutputQueue.get_nowait()
except queue.Empty: break
if self.verbose:
print (' - [myDataloaderClass.emptyAllQueues()] workerInputQueues: {} || workerOutputQueue: {}'.format(
[self.workerInputQueues[i].qsize() for i in range(self.numWorkers)], self.workerOutputQueue.qsize())
)
except:
traceback.print_exc()
def __iter__(self):
try:
# Step 0 - Init
self.fillInputQueues() # once for each epoch
batchArray, batchMeta = [], []
# Step 1 - Continuously yield results
while True:
if not self.workerOutputQueue.empty():
# Step 2.1 - Get data point
patientSliceArray, patientName, sliceId = self.workerOutputQueue.get(timeout=QUEUE_TIMEOUT)
# Step 2.2 - Append to batch
if len(batchArray) < self.batchSize:
batchArray.append(patientSliceArray)
batchMeta.append([patientName, sliceId])
# Step 2.3 - Yield batch
if len(batchArray) == self.batchSize:
batchArray = collate_tensor_fn(batchArray)
batchMeta = np.vstack(batchMeta).T
yield batchArray, batchMeta
batchArray, batchMeta = [], []
# Step 3 - End condition: every input queue drained, no worker mid-item, and the output queue empty
if np.all([self.workerInputQueues[i].empty() for i in range(self.numWorkers)]) and np.all([not self.workerProcessingEvent[i].is_set() for i in range(self.numWorkers)]) and self.workerOutputQueue.empty():
break
# Step 4 - If there are any pending data points
if len(batchArray) > 0 and len(batchMeta) > 0:
batchArray = collate_tensor_fn(batchArray)
batchMeta = np.vstack(batchMeta).T
yield batchArray, batchMeta
except GeneratorExit:
self.emptyAllQueues()
except KeyboardInterrupt:
self.closeProcesses()
except:
traceback.print_exc()
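# Note: because all workers push into the single shared workerOutputQueue, item order (and hence
# batch composition) depends on which worker finishes first; unlike torch.utils.data.DataLoader,
# no attempt is made here to preserve sample order across epochs.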
def closeProcesses(self):
try:
# Step 1 - Set stop event
with self.mpLock:
self.workerStopEvent.set() # this should break the while loop in all workers
# Step 2 - Join all workers
for workerId in range(len(self.workerProcesses)):
self.workerProcesses[workerId].join()
# Step 3 - Close all queues
for workerId in range(self.numWorkers):
self.workerInputQueues[workerId].cancel_join_thread() # The cancel_join_thread() method is used to prevent the background thread associated with a queue from joining the main thread when the program exits. By default, when a program terminates, it waits for all non-daemon threads to complete before exiting.
self.workerInputQueues[workerId].close()
self.workerOutputQueue.cancel_join_thread()
self.workerOutputQueue.close()
finally:
for workerId in range(self.numWorkers):
if self.workerProcesses[workerId].is_alive():
print (' - [myDataloaderClass.closeProcesses()] Worker {} is still alive. Terminating...'.format(workerId))
self.workerProcesses[workerId].terminate()
def getSlice(workerId, inputQueue, outputQueue, processingEvent, stopEvent, verbose, patientsInMemory=1):
try:
# Step 0 - Init
torch.set_num_threads(1)
patientObj = {} # Internal memory (per-worker cache) of this workerId
def managePatientObj():
if len(patientObj) > patientsInMemory:
patientObj.pop(list(patientObj.keys())[0])
while not stopEvent.is_set():
try:
# Step 0 - Get data point
patientName, sliceId = inputQueue.get(timeout=QUEUE_TIMEOUT)
# Step 1 - Get patient slice array
processingEvent.set()
patientArrayThis = patientObj.get(patientName, None)
patientArray, patientSliceArray = getPatientSliceArray(patientName, sliceId, workerId, patientArray=patientArrayThis, verbose=verbose)
if patientArray is not None:
patientObj[patientName] = patientArray
managePatientObj()
# Step 2 - Put data point
outputQueue.put((patientSliceArray, patientName, sliceId))
processingEvent.clear()
except: # most commonly queue.Empty from the timed-out get() above; keep polling until stopEvent is set
if inputQueue.qsize() == 0:
continue
except KeyboardInterrupt:
print("\n - [getSlice()][worker={}] function was interrupted.".format(workerId))
except:
traceback.print_exc()
finally:
if stopEvent.is_set():
if verbose: print (' - [getSlice] Stopping worker {}...'.format(workerId))
def collate_tensor_fn(batch):
elem = batch[0]
out = None
if torch.utils.data.get_worker_info() is not None:
# If we're in a background process, concatenate directly into a shared memory tensor to avoid an extra copy
numel = sum(x.numel() for x in batch)
storage = elem._typed_storage()._new_shared(numel, device=elem.device)
out = elem.new(storage).resize_(len(batch), *list(elem.size()))
return torch.stack(batch, 0, out=out)
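# Note: this mirrors PyTorch's default tensor collation. In this gist it is only called from
# myDataloaderClass.__iter__ in the main process, where torch.utils.data.get_worker_info() is None,
# so the shared-memory fast path above is effectively never taken.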
##################################################
# Main
##################################################
if __name__ == '__main__':
# Step 1 - Setup patient slices (hard-coded slice IDs per patient)
patientSlicesList = {
'P1': [45, 62, 32, 21, 69]
, 'P2': [13, 23, 87, 54, 5]
, 'P3': [34, 56, 78, 90, 12]
, 'P4': [34, 56, 78, 90, 12, 21]
, 'P5': [45, 62, 32, 21, 69]
, 'P6': [13, 23, 87, 54, 5]
, 'P7': [34, 56, 78, 90, 12]
, 'P8': [34, 56, 78, 90, 12, 21]
}
workerCount, batchSize, epochs = 4, 3, 3
# Step 2.1 - Create dataset and dataloader
try:
dataloaderNew = None
dataloaderNew = myDataloaderClass(patientSlicesList, numWorkers=workerCount, batchSize=batchSize, verbose=False)
# Step 2.2 - Iterate over dataloader
t0 = time.time()
epochTimes = []
print ('\n - [main] Iterating over (myNew) dataloader...')
for epochId in range(epochs):
print (' - [main] --------------------------------------- Epoch {}/{}'.format(epochId+1, epochs))
with tqdm.tqdm(total=len(dataloaderNew), desc=' - Epoch {}/{}'.format(epochId+1, epochs)) as pbar:
t1 = time.time()
for i, (X, meta) in enumerate(dataloaderNew):
# print (' - [main] {}'.format(meta.tolist()))
pbar.update(X.shape[0])
epochTimes.append(time.time()-t1)
print ('')
dataloaderNew.closeProcesses()
except KeyboardInterrupt:
if dataloaderNew is not None: dataloaderNew.closeProcesses()
except:
traceback.print_exc()
if dataloaderNew is not None: dataloaderNew.closeProcesses()
print (' - [main] --------------------------------------- ')
print (' - [main] Total time: {:.2f} s (avg = {:.2f} s)'.format(time.time()-t0, np.mean(epochTimes)))