TORCH MULTIPROCESSING
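A small benchmark of two ways to feed per-slice samples from 3D patient volumes to a model:
the standard torch.utils.data.DataLoader (with a Dataset that caches the most recently loaded
volume in each worker) versus a hand-rolled dataloader built directly on torch.multiprocessing,
which pins each patient's slices to a single worker so the (simulated) 1 s volume load happens
once per patient. The gist has three files; the first, below, is the benchmark script, which
imports the other two as `torchDataloader` and `myDataloader`.

--------------------------------------------------------------------------------
File 1 - benchmark script (imports the two modules below as torchDataloader / myDataloader)
--------------------------------------------------------------------------------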
import torchDataloader   # the torch.utils.data.Dataset wrapper (second file in this gist)
import myDataloader      # the custom torch.multiprocessing dataloader (third file in this gist)

import pdb
import copy
import time
import tqdm
import pprint
import traceback
import numpy as np

import torch

KEY_WORKERS    = 'workers'
KEY_TIMELIST   = 'timeList'
KEY_ITERPERSEC = 'iterPerSec'
KEY_EPOCHS     = 'epochs'
KEY_BATCH_SIZE = 'batch_size'
KEY_TYPE       = 'dataloader-type'
KEY_WB         = 'Workers-BatchSize'

STR_TORCHDATALOADER = 'torchDataloader'
STR_MYDATALOADER    = 'myDataloader'

nameCPU   = 'cpu'
nameGPU   = 'cuda:0'
deviceCPU = torch.device("cpu")
deviceGPU = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def calculateTime(datasetLen, dataloader, epochs, workerCount, batchSize, device=deviceCPU, meta=''):
    timeForEpochs = []
    try:
        # Step 1 - Loop over epochs
        print('')
        for epoch in range(epochs):
            with tqdm.tqdm(total=datasetLen, desc=' --------------------- [{}][W={}, B={}] Epoch {}/{}'.format(meta, workerCount, batchSize, epoch+1, epochs)) as pbar:
                t1 = time.time()
                # Step 2 - Loop over dataloader
                for i, (patientSliceArray, batchMeta) in enumerate(dataloader):
                    patientSliceArray = patientSliceArray.to(device)
                    pbar.update(patientSliceArray.shape[0])
                t2 = time.time()
                timeForEpochs.append(t2 - t1)
    except:
        traceback.print_exc()
        pdb.set_trace()

    return timeForEpochs
if __name__ == "__main__":

    try:

        ############################################################### Step 1 - Declare results object and other variables
        resultsObj = {KEY_WORKERS: [], KEY_BATCH_SIZE: [], KEY_TIMELIST: [], KEY_ITERPERSEC: []}

        # Step 1.1 - Setup patient slices (fixed count of slices per patient)
        patientSlicesList = {
            'P1': [45, 62, 32, 21, 69]
            , 'P2': [13, 23, 87, 54, 5]
            , 'P3': [34, 56, 78, 90, 12]
            , 'P4': [34, 56, 78, 90, 12]
            , 'P5': [45, 62, 32, 21, 69]
            , 'P6': [13, 23, 87, 54, 5]
            , 'P7': [34, 56, 78, 90, 12]
            , 'P8': [34, 56, 78, 90, 12, 21]
        }

        # workerCountList, batchSizeList, totalEpochs = [1, 2, 4, 8], [1, 2, 4, 8], 6
        workerCountList, batchSizeList, totalEpochs = [1, 2, 4, 8], [1, 2, 3, 4, 6], 2
        device, deviceName = deviceCPU, nameCPU  # deviceGPU
        ############################################################## Step 2 - Test torchDataloader (the standard torch.utils.data.DataLoader)
        torchDataloaderResultsObj = copy.deepcopy(resultsObj)
        if 1:
            torchDatasetObj = torchDataloader.myDataset(patientSlicesList)
            torchDatasetLen = len(torchDatasetObj)
            for workerCount in workerCountList:
                for batchSize in batchSizeList:
                    torchDataloaderResultsObj[KEY_WORKERS].append(workerCount)
                    torchDataloaderResultsObj[KEY_BATCH_SIZE].append(batchSize)
                    dataloader1 = torch.utils.data.DataLoader(torchDatasetObj, batch_size=batchSize, num_workers=workerCount)
                    timeForEpochs = calculateTime(torchDatasetLen, dataloader1, totalEpochs, workerCount, batchSize, device, meta='torchDataloader')
                    itersPerSec = [torchDatasetLen / timeForEpoch for timeForEpoch in timeForEpochs]
                    torchDataloaderResultsObj[KEY_TIMELIST].append(timeForEpochs)
                    torchDataloaderResultsObj[KEY_ITERPERSEC].append(itersPerSec)
        ############################################################## Step 3 - Test myDataloader (the custom torch.multiprocessing dataloader)
        myDataloaderResultsObj = copy.deepcopy(resultsObj)
        if 1:
            for workerCount in workerCountList:
                for batchSize in batchSizeList:
                    myDataloaderResultsObj[KEY_BATCH_SIZE].append(batchSize)
                    myDataloaderResultsObj[KEY_WORKERS].append(workerCount)
                    myDataloaderObj = myDataloader.myDataloaderClass(patientSlicesList, numWorkers=workerCount, batchSize=batchSize)
                    myDataloaderLen = len(myDataloaderObj)
                    timeForEpochs = calculateTime(myDataloaderLen, myDataloaderObj, totalEpochs, workerCount, batchSize, device, meta='myDataloader')
                    itersPerSec = [myDataloaderLen / timeForEpoch for timeForEpoch in timeForEpochs]
                    myDataloaderResultsObj[KEY_TIMELIST].append(timeForEpochs)
                    myDataloaderResultsObj[KEY_ITERPERSEC].append(itersPerSec)
                    time.sleep(2)
                    myDataloaderObj.closeProcesses()  # shut down this run's worker processes before spawning the next set
        ############################################################### Step 4 - Print and plot results
        if 1:
            # Step 4.1 - Convert to dataframes
            import pandas as pd
            torchDataloaderResultsObj = pd.DataFrame(torchDataloaderResultsObj)
            myDataloaderResultsObj    = pd.DataFrame(myDataloaderResultsObj)

            # Step 4.2 - Average KEY_TIMELIST and KEY_ITERPERSEC over epochs
            torchDataloaderResultsObj[KEY_TIMELIST]   = torchDataloaderResultsObj[KEY_TIMELIST].apply(lambda x: np.mean(x))
            myDataloaderResultsObj[KEY_TIMELIST]      = myDataloaderResultsObj[KEY_TIMELIST].apply(lambda x: np.mean(x))
            torchDataloaderResultsObj[KEY_ITERPERSEC] = torchDataloaderResultsObj[KEY_ITERPERSEC].apply(lambda x: np.mean(x))
            myDataloaderResultsObj[KEY_ITERPERSEC]    = myDataloaderResultsObj[KEY_ITERPERSEC].apply(lambda x: np.mean(x))

            # Step 4.3 - Merge the KEY_WORKERS and KEY_BATCH_SIZE columns into a single 'W-B' label
            torchDataloaderResultsObj[KEY_WORKERS]    = torchDataloaderResultsObj[KEY_WORKERS].astype(str)
            torchDataloaderResultsObj[KEY_BATCH_SIZE] = torchDataloaderResultsObj[KEY_BATCH_SIZE].astype(str)
            torchDataloaderResultsObj[KEY_WB]         = torchDataloaderResultsObj[KEY_WORKERS] + '-' + torchDataloaderResultsObj[KEY_BATCH_SIZE]
            torchDataloaderResultsObj.drop([KEY_WORKERS, KEY_BATCH_SIZE], axis=1, inplace=True)
            myDataloaderResultsObj[KEY_WORKERS]    = myDataloaderResultsObj[KEY_WORKERS].astype(str)
            myDataloaderResultsObj[KEY_BATCH_SIZE] = myDataloaderResultsObj[KEY_BATCH_SIZE].astype(str)
            myDataloaderResultsObj[KEY_WB]         = myDataloaderResultsObj[KEY_WORKERS] + '-' + myDataloaderResultsObj[KEY_BATCH_SIZE]
            myDataloaderResultsObj.drop([KEY_WORKERS, KEY_BATCH_SIZE], axis=1, inplace=True)

            # Step 4.4 - Add meta info
            torchDataloaderResultsObj[KEY_TYPE] = STR_TORCHDATALOADER
            myDataloaderResultsObj[KEY_TYPE]    = STR_MYDATALOADER

            # Step 4.5 - Concatenate the two dataframes
            vizObj = pd.concat([torchDataloaderResultsObj, myDataloaderResultsObj], axis=0)

            # Step 4.6 - Plot (with seaborn) the dataframe using KEY_WB as the x-axis
            import seaborn as sns
            import matplotlib.pyplot as plt
            f, axarr = plt.subplots(1, 2, figsize=(15, 5))
            sns.barplot(x=KEY_WB, y=KEY_TIMELIST, data=vizObj, hue=KEY_TYPE, ax=axarr[0])
            axarr[0].set_title('Total Time (average over {} epochs)'.format(totalEpochs))
            sns.barplot(x=KEY_WB, y=KEY_ITERPERSEC, data=vizObj, hue=KEY_TYPE, ax=axarr[1])
            axarr[1].set_title('Iterations/second (average over {} epochs)'.format(totalEpochs))
            plt.savefig('dataloaderCompare__W{}__B{}.png'.format('-'.join(map(str, workerCountList)), '-'.join(map(str, batchSizeList))))
            plt.show(block=False)
            pdb.set_trace()

    except:
        traceback.print_exc()
        pdb.set_trace()
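--------------------------------------------------------------------------------
File 2 - torchDataloader.py (name inferred from the benchmark script's imports):
a standard torch.utils.data.Dataset that caches the most recently loaded patient
volume inside each DataLoader worker
--------------------------------------------------------------------------------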
import tqdm
import time
import torch  # v1.12.1
import numpy as np

##################################################
# myDataset
##################################################

def getPatientArray(patientName, verbose=False):
    time.sleep(1.0)  # emulates disk read time for one 3D patient volume
    workerId = torch.utils.data.get_worker_info().id if torch.utils.data.get_worker_info() else 0
    if verbose:
        print(' - [getPatientArray()][worker={}] Loading volumes for patient: {}'.format(workerId, patientName))
    return torch.tensor(np.random.rand(1, 128, 128, 128))
def getPatientSliceArray(patientName, sliceId, patientArray=None, verbose=False):
    time.sleep(0.1)  # emulates processing time for one 2D slice
    if patientArray is None:
        patientArray = getPatientArray(patientName, verbose=verbose)
        sliceArray = patientArray[:, sliceId, :, :]
        return patientArray, sliceArray  # return the full volume so the caller can cache it
    else:
        sliceArray = patientArray[:, sliceId, :, :]
        return None, sliceArray  # the caller already holds the volume; nothing new to cache
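# NOTE on the caching idea: each DataLoader worker keeps the last `patientsInMemory`
# volume(s) in self.patientObj, so consecutive slices of the same patient can skip the
# (simulated) 1 s disk read. With the standard DataLoader, however, one patient's slices
# may be scattered across workers, so the same volume may be reloaded by several workers.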
class myDataset(torch.utils.data.Dataset):
    def __init__(self, patientSlicesList, verbose=False, patientsInMemory=1):
        self.patientSlicesList = patientSlicesList
        self.verbose = verbose
        self.patientsInMemory = patientsInMemory
        self.patientObj = {}  # stores one patient's 3D array; caching more patients means more memory usage
        self.inputQueue = []
        for patientName in self.patientSlicesList:
            for sliceId in self.patientSlicesList[patientName]:
                self.inputQueue.append((patientName, sliceId))

    def __len__(self):
        return len(self.inputQueue)

    def _managePatientObj(self):
        # Evict the oldest cached patient volume once the cache exceeds patientsInMemory
        if len(self.patientObj) > self.patientsInMemory:
            self.patientObj.pop(list(self.patientObj.keys())[0])

    def __getitem__(self, idx):
        # workerId = torch.utils.data.get_worker_info().id if torch.utils.data.get_worker_info() else 0
        # print (' - [__getitem__()][worker={}] idx: {}'.format(workerId, idx))

        # Step 0 - Init
        patientName, sliceId = self.inputQueue[idx]

        # Step 1 - Get patient slice array (loading and caching the volume if this worker lacks it)
        patientArrayThis = self.patientObj.get(patientName, None)
        patientArray, patientSliceArray = getPatientSliceArray(patientName, sliceId, patientArray=patientArrayThis, verbose=self.verbose)
        if patientArray is not None:
            self.patientObj[patientName] = patientArray
            self._managePatientObj()

        return patientSliceArray, [patientName, sliceId]
##################################################
# Main
##################################################

if __name__ == '__main__':

    # Step 1 - Setup patient slices (fixed count of slices per patient)
    patientSlicesList = {
        'P1': [45, 62, 32, 21, 69]
        , 'P2': [13, 23, 87, 54, 5]
        , 'P3': [34, 56, 78, 90, 12]
        , 'P4': [34, 56, 78, 90, 12]
    }
    workerCount, batchSize, epochs = 4, 3, 3

    # Step 2.1 - Create dataset and dataloader
    dataset = myDataset(patientSlicesList)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batchSize, num_workers=workerCount)

    # Step 2.2 - Iterate over dataloader
    print('\n - [main] Iterating over (my) dataloader...')
    t0 = time.time()
    epochTimes = []
    for epochId in range(epochs):
        print(' - [main] --------------------------------------- Epoch {}/{}'.format(epochId+1, epochs))
        with tqdm.tqdm(total=len(dataset), desc=' - Epoch {}/{}'.format(epochId+1, epochs)) as pbar:
            t1 = time.time()
            for i, (patientSliceArray, meta) in enumerate(dataloader):
                print(' - [main] meta: ', meta)
                pbar.update(patientSliceArray.shape[0])
        epochTimes.append(time.time() - t1)

    print(' - [main] --------------------------------------- ')
    print(' - [main] Total time: {:.2f} s (avg = {:.2f} s)'.format(time.time() - t0, np.mean(epochTimes)))
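--------------------------------------------------------------------------------
File 3 - myDataloader.py (name inferred from the benchmark script's imports):
the custom dataloader built directly on torch.multiprocessing queues, events, and processes
--------------------------------------------------------------------------------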
import time
import tqdm
import queue      # for queue.Empty, raised by multiprocessing queues on timeout
import torch
import traceback
import numpy as np
import torch.multiprocessing as torchMP

def getPatientArray(patientName, workerId, verbose=False):
    time.sleep(1.0)  # emulates disk read time for one 3D patient volume
    # NOTE: torch.utils.data.get_worker_info() only works inside DataLoader workers,
    # so the workerId passed in by the spawning process is used here instead
    if verbose:
        print(' - [getPatientArray()][worker={}] Loading volumes for patient: {}'.format(workerId, patientName))
    return torch.tensor(np.random.rand(1, 128, 128, 128))

def getPatientSliceArray(patientName, sliceId, workerId, patientArray=None, verbose=False):
    time.sleep(0.1)  # emulates processing time for one 2D slice
    if patientArray is None:
        patientArray = getPatientArray(patientName, workerId, verbose=verbose)
        sliceArray = patientArray[:, sliceId, :, :]
        return patientArray, sliceArray  # return the full volume so the caller can cache it
    else:
        sliceArray = patientArray[:, sliceId, :, :]
        return None, sliceArray  # the caller already holds the volume; nothing new to cache
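##################################################
# myDataloaderClass
##################################################
# A rough picture of the design, as inferred from the code below:
#   - one input queue per worker process, refilled once per epoch by fillInputQueues()
#   - a single shared output queue that __iter__() drains to assemble batches
#   - per-worker "processing" Events so __iter__() can tell busy workers from idle ones
#   - a shared "stop" Event that closeProcesses() sets to end the workers' loops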
QUEUE_TIMEOUT = 5.0

class myDataloaderClass:

    def __init__(self, patientSlicesList, numWorkers, batchSize, verbose=False) -> None:
        self.patientSlicesList = patientSlicesList
        self.numWorkers = numWorkers
        self.batchSize = batchSize
        self.verbose = verbose
        self._initWorkers()

    def _initWorkers(self):
        # Step 1 - Initialize vars
        self.workerProcesses = []
        self.workerInputQueues = [torchMP.Queue() for _ in range(self.numWorkers)]
        self.workerOutputQueue = torchMP.Queue()
        self.workerProcessingEvent = [torchMP.Event() for _ in range(self.numWorkers)]
        self.workerStopEvent = torchMP.Event()  # used in getSlice() and self.closeProcesses()
        self.mpLock = torchMP.Lock()  # used in self.closeProcesses()

        # Step 2 - Start one daemon process per worker
        for workerId in range(self.numWorkers):
            if self.verbose:
                print(' - [myDataloaderClass._initWorkers()] Starting worker {}'.format(workerId))
            p = torchMP.Process(target=getSlice, args=(workerId, self.workerInputQueues[workerId], self.workerOutputQueue
                                    , self.workerProcessingEvent[workerId], self.workerStopEvent, self.verbose)
                                , daemon=True)
            p.start()
            self.workerProcesses.append(p)

    def __len__(self):
        # Total number of slices (i.e., data points) across all patients
        counter = 0
        for patientName in self.patientSlicesList:
            counter += len(self.patientSlicesList[patientName])
        return counter
    def fillInputQueues(self):
        """
        Splits the patients (and hence their slices) contiguously across the worker
        input queues, so that each patient's slices all land on the same worker.
        """
        patientNames = list(self.patientSlicesList.keys())
        for workerId in range(self.numWorkers):
            startIdx = workerId * len(patientNames) // self.numWorkers
            endIdx = (workerId + 1) * len(patientNames) // self.numWorkers
            # print (' - [myDataloaderClass.fillInputQueues()] Worker {} will process patients {} to {}: {}'.format(workerId, startIdx, endIdx, patientNames[startIdx:endIdx]))
            for patientName in patientNames[startIdx:endIdx]:
                for sliceId in self.patientSlicesList[patientName]:
                    self.workerInputQueues[workerId].put((patientName, sliceId))

    def emptyAllQueues(self):
        try:
            for workerId in range(self.numWorkers):
                while not self.workerInputQueues[workerId].empty():
                    try: _ = self.workerInputQueues[workerId].get_nowait()
                    except queue.Empty: break
            while not self.workerOutputQueue.empty():
                try: _ = self.workerOutputQueue.get_nowait()
                except queue.Empty: break
            if self.verbose:
                print(' - [myDataloaderClass.emptyAllQueues()] workerInputQueues: {} || workerOutputQueue: {}'.format(
                    [self.workerInputQueues[i].qsize() for i in range(self.numWorkers)], self.workerOutputQueue.qsize())
                )
        except:
            traceback.print_exc()
    def __iter__(self):
        try:
            # Step 0 - Init (refill the input queues once per epoch)
            self.fillInputQueues()
            batchArray, batchMeta = [], []

            # Step 1 - Continuously poll the output queue and yield batches
            while True:
                if not self.workerOutputQueue.empty():

                    # Step 2.1 - Get a data point
                    patientSliceArray, patientName, sliceId = self.workerOutputQueue.get(timeout=QUEUE_TIMEOUT)

                    # Step 2.2 - Append to batch
                    if len(batchArray) < self.batchSize:
                        batchArray.append(patientSliceArray)
                        batchMeta.append([patientName, sliceId])

                    # Step 2.3 - Yield batch once full
                    if len(batchArray) == self.batchSize:
                        batchArray = collate_tensor_fn(batchArray)
                        batchMeta = np.vstack(batchMeta).T
                        yield batchArray, batchMeta
                        batchArray, batchMeta = [], []

                # Step 3 - End condition: all input queues drained, no worker mid-task, output queue empty
                # (NOTE: there is a small race window between a worker popping a task and setting its
                #  processingEvent, during which the iterator could, in the worst case, end an epoch early)
                if np.all([self.workerInputQueues[i].empty() for i in range(self.numWorkers)]) and np.all([not self.workerProcessingEvent[i].is_set() for i in range(self.numWorkers)]) and self.workerOutputQueue.empty():
                    break

            # Step 4 - Yield any pending (partial) batch
            if len(batchArray) > 0 and len(batchMeta) > 0:
                batchArray = collate_tensor_fn(batchArray)
                batchMeta = np.vstack(batchMeta).T
                yield batchArray, batchMeta

        except GeneratorExit:
            self.emptyAllQueues()
        except KeyboardInterrupt:
            self.closeProcesses()
        except:
            traceback.print_exc()
    def closeProcesses(self):
        try:
            # Step 1 - Set stop event
            with self.mpLock:
                self.workerStopEvent.set()  # this should break the while loop in all workers

            # Step 2 - Join all workers
            for workerId in range(len(self.workerProcesses)):
                self.workerProcesses[workerId].join()

            # Step 3 - Close all queues
            for workerId in range(self.numWorkers):
                # cancel_join_thread() stops this process from blocking at exit while the
                # queue's background feeder thread flushes any remaining buffered data
                self.workerInputQueues[workerId].cancel_join_thread()
                self.workerInputQueues[workerId].close()
            self.workerOutputQueue.cancel_join_thread()
            self.workerOutputQueue.close()

        finally:
            # Step 4 - Terminate any worker that is still alive
            for workerId in range(self.numWorkers):
                if self.workerProcesses[workerId].is_alive():
                    print(' - [myDataloaderClass.closeProcesses()] Worker {} is still alive. Terminating...'.format(workerId))
                    self.workerProcesses[workerId].terminate()
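##################################################
# Worker loop
##################################################
# Each worker process runs getSlice(): it blocks on its own input queue (with a timeout,
# so it can notice stopEvent), sets its processingEvent while handling a task, keeps a
# local cache of the last `patientsInMemory` patient volume(s), and pushes each slice to
# the shared output queue. It exits once stopEvent is set.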
def getSlice(workerId, inputQueue, outputQueue, processingEvent, stopEvent, verbose, patientsInMemory=1):
    try:
        # Step 0 - Init
        torch.set_num_threads(1)
        patientObj = {}  # internal (per-worker) cache of patient volumes

        def managePatientObj():
            # Evict the oldest cached patient volume once the cache exceeds patientsInMemory
            if len(patientObj) > patientsInMemory:
                patientObj.pop(list(patientObj.keys())[0])

        while not stopEvent.is_set():
            try:
                # Step 1 - Get a task (patientName, sliceId)
                patientName, sliceId = inputQueue.get(timeout=QUEUE_TIMEOUT)

                # Step 2 - Get patient slice array (loading and caching the volume if not cached)
                processingEvent.set()
                patientArrayThis = patientObj.get(patientName, None)
                patientArray, patientSliceArray = getPatientSliceArray(patientName, sliceId, workerId, patientArray=patientArrayThis, verbose=verbose)
                if patientArray is not None:
                    patientObj[patientName] = patientArray
                    managePatientObj()

                # Step 3 - Put the result on the shared output queue
                outputQueue.put((patientSliceArray, patientName, sliceId))
                processingEvent.clear()

            except queue.Empty:
                continue  # no task arrived within QUEUE_TIMEOUT; re-check stopEvent and keep polling

    except KeyboardInterrupt:
        print("\n - [getSlice()][worker={}] function was interrupted.".format(workerId))
    except:
        traceback.print_exc()
    finally:
        if stopEvent.is_set():
            if verbose: print(' - [getSlice] Stopping worker {}...'.format(workerId))
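# NOTE: the collate function below appears adapted from PyTorch's default tensor collate
# (torch.utils.data._utils.collate). Since this custom loader collates in the main process
# rather than inside a DataLoader worker, get_worker_info() is None here and the
# shared-memory branch is effectively skipped.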
def collate_tensor_fn(batch):
    elem = batch[0]
    out = None
    if torch.utils.data.get_worker_info() is not None:
        # If we're in a background process, concatenate directly into a
        # shared memory tensor to avoid an extra copy
        numel = sum(x.numel() for x in batch)
        storage = elem._typed_storage()._new_shared(numel, device=elem.device)
        out = elem.new(storage).resize_(len(batch), *list(elem.size()))
    return torch.stack(batch, 0, out=out)
##################################################
# Main
##################################################

if __name__ == '__main__':  # this guard is required for multiprocessing on spawn-based platforms (Windows/macOS)

    # Step 1 - Setup patient slices (fixed count of slices per patient)
    patientSlicesList = {
        'P1': [45, 62, 32, 21, 69]
        , 'P2': [13, 23, 87, 54, 5]
        , 'P3': [34, 56, 78, 90, 12]
        , 'P4': [34, 56, 78, 90, 12, 21]
        , 'P5': [45, 62, 32, 21, 69]
        , 'P6': [13, 23, 87, 54, 5]
        , 'P7': [34, 56, 78, 90, 12]
        , 'P8': [34, 56, 78, 90, 12, 21]
    }
    workerCount, batchSize, epochs = 4, 3, 3

    # Step 2.1 - Create the custom dataloader (this also spawns the worker processes)
    try:
        dataloaderNew = None
        dataloaderNew = myDataloaderClass(patientSlicesList, numWorkers=workerCount, batchSize=batchSize, verbose=False)

        # Step 2.2 - Iterate over dataloader
        t0 = time.time()
        epochTimes = []
        print('\n - [main] Iterating over (myNew) dataloader...')
        for epochId in range(epochs):
            print(' - [main] --------------------------------------- Epoch {}/{}'.format(epochId+1, epochs))
            with tqdm.tqdm(total=len(dataloaderNew), desc=' - Epoch {}/{}'.format(epochId+1, epochs)) as pbar:
                t1 = time.time()
                for i, (X, meta) in enumerate(dataloaderNew):
                    # print (' - [main] {}'.format(meta.tolist()))
                    pbar.update(X.shape[0])
            epochTimes.append(time.time() - t1)
            print('')
        dataloaderNew.closeProcesses()
    except KeyboardInterrupt:
        if dataloaderNew is not None: dataloaderNew.closeProcesses()
    except:
        traceback.print_exc()
        if dataloaderNew is not None: dataloaderNew.closeProcesses()

    print(' - [main] --------------------------------------- ')
    print(' - [main] Total time: {:.2f} s (avg = {:.2f} s)'.format(time.time() - t0, np.mean(epochTimes)))