Created
December 16, 2019 18:13
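import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

# Hold out 10% of the data for validation with a single random split.
# Note: train_test_split shuffles but does not stratify unless `stratify=` is passed.
# input_ids, attention_masks, labels and SEED are assumed to be defined earlier.
train_inputs, validation_inputs, train_labels, validation_labels = train_test_split(input_ids, labels, random_state=SEED, test_size=0.1)
# Reusing the same random_state and test_size keeps the masks aligned with the inputs above
train_masks, validation_masks, _, _ = train_test_split(attention_masks, input_ids, random_state=SEED, test_size=0.1)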
# Convert all our data into torch tensors, the data type our model requires
train_inputs = torch.tensor(train_inputs)
validation_inputs = torch.tensor(validation_inputs)
train_labels = torch.tensor(train_labels)
validation_labels = torch.tensor(validation_labels)
train_masks = torch.tensor(train_masks)
validation_masks = torch.tensor(validation_masks)
# Select a batch size for training. For fine-tuning BERT on a specific task, the authors recommend a batch size of 16 or 32
batch_size = 32
# Wrap the tensors in torch DataLoaders. A DataLoader yields one mini-batch at a time,
# so only the current batch has to be moved to the training device rather than the whole dataset
train_data = TensorDataset(train_inputs, train_masks, train_labels)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)
# Sample order does not change aggregate validation metrics, so a sequential sampler would also work here
validation_data = TensorDataset(validation_inputs, validation_masks, validation_labels)
validation_sampler = RandomSampler(validation_data)
validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=batch_size)
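
To make the batching step concrete, here is a minimal plain-Python sketch of what a random sampler combined with a batching loader does conceptually: shuffle the row indices once, then yield aligned slices of every column. It is not PyTorch's implementation. The helper name `iter_minibatches` and the toy token ids, masks and labels are made up for illustration.

```python
import random


def iter_minibatches(columns, batch_size, seed=None):
    """Yield shuffled mini-batches from parallel sequences (hypothetical helper)."""
    n = len(columns[0])
    if any(len(col) != n for col in columns):
        raise ValueError("all columns must have the same length")
    order = list(range(n))
    # Shuffle the row indices once per pass, as a random sampler would.
    random.Random(seed).shuffle(order)
    for start in range(0, n, batch_size):
        idx = order[start:start + batch_size]
        # Index every column with the same positions so rows stay aligned.
        yield tuple([col[i] for i in idx] for col in columns)


# Toy stand-ins for input_ids, attention_masks and labels (illustrative values only).
toy_inputs = [[101, 7, 102], [101, 8, 102], [101, 9, 102], [101, 10, 102], [101, 11, 102]]
toy_masks = [[1, 1, 1] for _ in toy_inputs]
toy_labels = [0, 1, 0, 1, 1]

batches = list(iter_minibatches((toy_inputs, toy_masks, toy_labels), batch_size=2, seed=0))
for inputs, masks, labels in batches:
    print(len(inputs), labels)
```

The last batch is smaller than `batch_size` when the dataset size is not a multiple of it, which matches the default `DataLoader` behaviour (`drop_last=False`).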