Last active
July 25, 2019 11:37
-
-
Save ben0it8/bd350b67be2ba6eec622d2345df31a76 to your computer and use it in GitHub Desktop.
Create dataloaders
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from concurrent.futures import ProcessPoolExecutor
from multiprocessing import cpu_count
from itertools import repeat

# Number of worker processes for the ProcessPoolExecutor used when building
# dataloaders: one per CPU reported by multiprocessing.cpu_count().
num_cores = cpu_count()
def process_row(processor, row, label_col=None, text_col=None):
    """Run ``processor.process_example`` on one ``(index, row)`` pair.

    Args:
        processor: Object exposing ``process_example(example)``; it is called
            with a single ``(label, text)`` tuple.
        row: An ``(index, row)`` pair such as those yielded by
            ``DataFrame.iterrows()``; only ``row[1]`` is read.
        label_col: Key of the label in ``row[1]``. Defaults to the global
            ``LABEL_COL`` when omitted.
        text_col: Key of the text in ``row[1]``. Defaults to the global
            ``TEXT_COL`` when omitted.

    Returns:
        Whatever ``processor.process_example`` returns.
    """
    # Resolve the globals lazily (at call time, as the original did) so the
    # constants need not exist when this function is defined or when callers
    # pass explicit column names.
    if label_col is None:
        label_col = LABEL_COL
    if text_col is None:
        text_col = TEXT_COL
    values = row[1]
    return processor.process_example((values[label_col], values[text_col]))
def create_dataloader(df: pd.DataFrame,
                      processor: TextProcessor,
                      batch_size: int = 32,
                      shuffle: bool = False,
                      valid_pct: float = None,
                      text_col: str = "text",
                      label_col: str = "label"):
    """Process rows in `df` with `num_cores` workers and wrap them in DataLoader(s).

    Each row's ``(label, text)`` pair is passed to
    ``processor.process_example`` in a pool of ``num_cores`` worker processes.
    Element 0 of every result is collected as the features and element 1 as
    the label; both are stacked into ``torch.long`` tensors.

    Args:
        df: Source frame; must contain the columns ``text_col`` and
            ``label_col``.
        processor: Object exposing ``process_example((label, text))``. It is
            sent to the worker processes, so it must be picklable.
        batch_size: Batch size of every returned loader.
        shuffle: Whether to shuffle the single loader returned when
            ``valid_pct`` is None. Not used when splitting: the training
            loader is then always shuffled and the validation loader never is.
        valid_pct: If given, the fraction of rows randomly held out (via
            ``random_split``) for validation.
        text_col: Name of the column holding the text.
        label_col: Name of the column holding the label.

    Returns:
        A single ``DataLoader`` when ``valid_pct`` is None, otherwise a
        ``(train_loader, valid_loader)`` tuple.
    """
    n_rows = len(df)
    # Read just the two needed columns: this honours text_col/label_col and
    # avoids iterrows(), which materialises and pickles a full Series per row.
    examples = zip(df[label_col], df[text_col])
    # Aim for ~10 chunks overall, but never 0: ProcessPoolExecutor.map raises
    # ValueError for chunksize < 1, which happened for frames under 10 rows.
    chunksize = max(1, n_rows // 10)
    with ProcessPoolExecutor(max_workers=num_cores) as executor:
        # The bound method is pickled to the workers along with `processor`;
        # each call receives one (label, text) tuple, as before.
        result = list(
            tqdm(executor.map(processor.process_example,
                              examples,
                              chunksize=chunksize),
                 desc=f"Processing {n_rows} examples on {num_cores} cores",
                 total=n_rows))

    features = [r[0] for r in result]
    labels = [r[1] for r in result]
    dataset = TensorDataset(torch.tensor(features, dtype=torch.long),
                            torch.tensor(labels, dtype=torch.long))
    pin_memory = torch.cuda.is_available()

    if valid_pct is None:
        return DataLoader(dataset,
                          batch_size=batch_size,
                          num_workers=0,
                          shuffle=shuffle,
                          pin_memory=pin_memory)

    valid_size = int(valid_pct * n_rows)
    train_size = n_rows - valid_size
    valid_dataset, train_dataset = random_split(dataset,
                                                [valid_size, train_size])
    # Validation order stays fixed; training data is always shuffled.
    valid_loader = DataLoader(valid_dataset,
                              batch_size=batch_size,
                              num_workers=0,
                              shuffle=False,
                              pin_memory=pin_memory)
    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              num_workers=0,
                              shuffle=True,
                              pin_memory=pin_memory)
    return train_loader, valid_loader
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment