When working with imbalanced data for machine learning tasks in PyTorch, a simple random split might not divide under-represented classes proportionally between the resulting subsets. The sample splits might then fail to reflect the real-world population, leading to poor predictive performance in the resulting model.
Therefore, I have created a simple function for conducting a stratified split with random shuffling, similar to StratifiedShuffleSplit from scikit-learn (https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedShuffleSplit.html).
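To make the idea concrete, here is a minimal sketch of stratification on a plain list of labels. It is an illustration only and not the function described above: the name `stratified_indices`, its parameters, and the choice to round each class's test count up are all assumptions made for this example.

```python
import math
import random
from collections import defaultdict


def stratified_indices(labels, test_fraction=0.25, seed=0):
    """Split the positions of `labels` into (train, test) lists, class by class."""
    rng = random.Random(seed)  # local RNG, so the global random state is untouched

    # Group sample positions by their class label.
    by_class = defaultdict(list)
    for index, label in enumerate(labels):
        by_class[label].append(index)

    train, test = [], []
    for indices in by_class.values():
        rng.shuffle(indices)
        # Assumption of this sketch: round up, so every class
        # contributes at least one test sample.
        n_test = math.ceil(len(indices) * test_fraction)
        test.extend(indices[:n_test])
        train.extend(indices[n_test:])
    return train, test


labels = [0] * 80 + [1] * 20  # hypothetical 80/20 class imbalance
train_idx, test_idx = stratified_indices(labels, test_fraction=0.25)
minority_in_test = sum(labels[i] for i in test_idx)
print(minority_in_test, "of", len(test_idx), "test samples belong to class 1")
```

Because each class is shuffled and cut separately, the 80/20 ratio carries over to both subsets. Index lists like these could then be wrapped with `torch.utils.data.Subset(dataset, train_idx)` to obtain the actual train and test datasets.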
import random
import math
import torch.utils.data