@akshay-3apr
Created December 16, 2019 18:13
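import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset

# Hold out 10% of the data for validation with a single random split.
# Note: train_test_split without `stratify` is neither stratified nor k-fold.
# input_ids, attention_masks, labels and SEED are assumed to be defined earlier.
train_inputs, validation_inputs, train_labels, validation_labels = train_test_split(
    input_ids, labels, random_state=SEED, test_size=0.1)
# Reusing the same integer random_state and test_size reproduces the same row
# selection, so the masks stay aligned with the inputs split above.
train_masks, validation_masks, _, _ = train_test_split(
    attention_masks, input_ids, random_state=SEED, test_size=0.1)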
# Convert all the arrays to torch tensors, the data type the model expects
train_inputs = torch.tensor(train_inputs)
validation_inputs = torch.tensor(validation_inputs)
train_labels = torch.tensor(train_labels)
validation_labels = torch.tensor(validation_labels)
train_masks = torch.tensor(train_masks)
validation_masks = torch.tensor(validation_masks)
# Select a batch size for training. For fine-tuning BERT on a specific task, the authors recommend a batch size of 16 or 32
batch_size = 32
# Wrap the tensors in a DataLoader, which yields one mini-batch at a time.
# The tensors above are already held in memory; batching only limits how much
# data is handed to the model (and moved to the GPU) at each step.
train_data = TensorDataset(train_inputs, train_masks, train_labels)
train_sampler = RandomSampler(train_data)  # reshuffle the training set every epoch
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)
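validation_data = TensorDataset(validation_inputs, validation_masks, validation_labels)
# Evaluation does not need shuffling, so read the validation set in order
# (SequentialSampler comes from torch.utils.data).
validation_sampler = SequentialSampler(validation_data)
validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=batch_size)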
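
The snippet above keeps inputs, masks and labels aligned by calling `train_test_split` twice with the same seed. A possible alternative, sketched here with hypothetical toy data, is to split the row indices once and index every array with them; passing `stratify=labels` also makes the split stratified, which the plain call above is not. The `SEED` value, the toy arrays and the `make_loader` helper are assumptions for illustration and are not part of the gist.

```python
import numpy as np
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, TensorDataset

SEED = 42  # assumed value; the gist does not show how SEED is defined

# Hypothetical toy data standing in for input_ids, attention_masks and labels.
# Row i of input_ids is filled with i, so alignment is easy to check afterwards.
n_rows, seq_len = 100, 8
input_ids = np.repeat(np.arange(n_rows)[:, None], seq_len, axis=1)
attention_masks = (np.arange(seq_len)[None, :] <= (np.arange(n_rows) % seq_len)[:, None]).astype(np.int64)
labels = (np.arange(n_rows) >= 70).astype(np.int64)  # 70 zeros, 30 ones

# Split the row indices once, stratified by label, and reuse them everywhere.
train_idx, val_idx = train_test_split(
    np.arange(n_rows), test_size=0.1, random_state=SEED, stratify=labels
)


def make_loader(idx, shuffle, batch_size=32):
    """Build a DataLoader over the rows selected by idx (illustrative helper)."""
    dataset = TensorDataset(
        torch.tensor(input_ids[idx]),
        torch.tensor(attention_masks[idx]),
        torch.tensor(labels[idx]),
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)


train_dataloader = make_loader(train_idx, shuffle=True)        # shuffled each epoch
validation_dataloader = make_loader(val_idx, shuffle=False)    # fixed order
```

Because every array is indexed with the same `train_idx` and `val_idx`, alignment no longer depends on two separate calls producing the same selection.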