Created by @pythonlessons, September 4, 2023 15:04
transformers_training

preprocess_inputs converts a batch of source and target sentences into fixed-length, zero-padded NumPy arrays: the padded encoder input, a decoder input that keeps [START] but drops the final [END] token, and a decoder target that drops [START] but keeps [END]. This one-token shift between decoder input and output is what implements teacher forcing during training.
import numpy as np

def preprocess_inputs(data_batch, label_batch):
    # Pre-allocate fixed-length arrays; positions left at 0 act as padding.
    encoder_input = np.zeros((len(data_batch), tokenizer.max_length), dtype=np.int64)
    decoder_input = np.zeros((len(label_batch), detokenizer.max_length), dtype=np.int64)
    decoder_output = np.zeros((len(label_batch), detokenizer.max_length), dtype=np.int64)

    # Tokenize the source sentences and the target sentences.
    data_batch_tokens = tokenizer.texts_to_sequences(data_batch)
    label_batch_tokens = detokenizer.texts_to_sequences(label_batch)

    # Copy each token sequence into its padded row, shifting the decoder
    # sequences by one token for teacher forcing.
    for index, (data, label) in enumerate(zip(data_batch_tokens, label_batch_tokens)):
        encoder_input[index][:len(data)] = data
        decoder_input[index][:len(label)-1] = label[:-1]  # Drop the [END] token
        decoder_output[index][:len(label)-1] = label[1:]  # Drop the [START] token

    return (encoder_input, decoder_input), decoder_output
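
For context, here is a minimal usage sketch. The real tokenizer and detokenizer come from the rest of the tutorial code and are not shown in this gist; the StubTokenizer below is a hypothetical stand-in that only provides the two members preprocess_inputs relies on: a max_length attribute and a texts_to_sequences() method that frames each sequence with [START]/[END] ids.

class StubTokenizer:
    # Hypothetical stand-in for the tutorial's tokenizer/detokenizer.
    # Id scheme assumed here: 0 = padding, 1 = [START], 2 = [END], 3+ = characters.
    def __init__(self, max_length, start=1, end=2):
        self.max_length = max_length
        self.start, self.end = start, end

    def texts_to_sequences(self, texts):
        # Map each character to a small integer id and frame the
        # sequence with [START]/[END] tokens.
        return [[self.start] + [3 + (ord(c) % 100) for c in t] + [self.end]
                for t in texts]

# preprocess_inputs looks these names up as globals, so defining them
# before the call is enough.
tokenizer = StubTokenizer(max_length=16)
detokenizer = StubTokenizer(max_length=16)

(encoder_input, decoder_input), decoder_output = preprocess_inputs(["hello"], ["world"])
print(encoder_input.shape)    # (1, 16)
print(decoder_input[0][:6])   # [START] id followed by the "world" ids, [END] dropped
print(decoder_output[0][:6])  # the "world" ids followed by the [END] id, [START] dropped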