from transformers import T5Tokenizer
import numpy as np


class FlaxDataCollatorForT5MLM:
    """
    From https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py
    """

    def __init__(self, tokenizer, noise_density, mean_noise_span_length) -> None:
        self.tokenizer = tokenizer
        self.noise_density = noise_density
        self.mean_noise_span_length = mean_noise_span_length

    def create_sentinel_ids(self, mask_indices):
        """
        Sentinel ids creation given the indices that should be masked.
        The start indices of each mask are replaced by the sentinel ids in increasing
        order. Consecutive mask indices to be deleted are replaced with `-1`.
        """
        start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices
        start_indices[:, 0] = mask_indices[:, 0]
        sentinel_ids = np.where(start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices)
        sentinel_ids = np.where(sentinel_ids != 0, (len(self.tokenizer) - sentinel_ids), 0)
        sentinel_ids -= mask_indices - start_indices
        return sentinel_ids
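
    # Worked example (illustrative; the ids assume the t5-base vocab of 32100):
    # mask_indices = [[0, 1, 1, 0, 1]]
    #   start_indices -> [[0, 1, 0, 0, 1]]   (a 1 only where a masked span begins)
    #   cumsum        -> span numbers 1 and 2 at the two span starts
    #   sentinel_ids  -> [[0, 32099, -1, 0, 32098]]
    # i.e. each span start gets the next sentinel token (32099 = <extra_id_0>),
    # and span-continuation positions become -1 so filter_input_ids can drop them.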

    def filter_input_ids(self, input_ids, sentinel_ids):
        """
        Puts the sentinel mask on `input_ids` and fuses consecutive mask tokens into a single mask token by deleting.
        This will reduce the sequence length from `expanded_inputs_length` to `input_length`.
        """
        batch_size = input_ids.shape[0]
        input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)
        # input_ids tokens and sentinel tokens are >= 0, tokens < 0 are
        # masked tokens coming after sentinel tokens and should be removed
        input_ids = input_ids_full[input_ids_full >= 0].reshape((batch_size, -1))
        input_ids = np.concatenate(
            [input_ids, np.full((batch_size, 1), self.tokenizer.eos_token_id, dtype=np.int32)], axis=-1
        )
        return input_ids
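
    # Worked example (illustrative): with input_ids = [[10, 20, 30, 40, 50]] and
    # the sentinel_ids [[0, 32099, -1, 0, 32098]] from above,
    #   np.where   -> [[10, 32099, -1, 40, 32098]]
    #   drop < 0   -> [[10, 32099, 40, 32098]]
    #   append eos -> [[10, 32099, 40, 32098, 1]]   (1 is T5's </s> id)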

    def random_spans_noise_mask(self, length):
        """This function is a copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
        Noise mask consisting of random spans of noise tokens.
        The number of noise tokens and the number of noise spans and non-noise spans
        are determined deterministically as follows:
        num_noise_tokens = round(length * noise_density)
        num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
        Spans alternate between non-noise and noise, beginning with non-noise.
        Subject to the above restrictions, all masks are equally likely.
        Args:
            length: an int32 scalar (length of the incoming token sequence)
            noise_density: a float - approximate density of output mask
            mean_noise_span_length: a number
        Returns:
            a boolean tensor with shape [length]
        """
        orig_length = length
        num_noise_tokens = int(np.round(length * self.noise_density))
        # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens.
        num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
        num_nonnoise_tokens = length - num_noise_tokens
        # num_noise_spans = int(np.round(num_noise_tokens / self.mean_noise_span_length))  # <--- Bug
        # num_noise_spans must not exceed num_noise_tokens or num_nonnoise_tokens, since
        # both token sequences are partitioned into that many non-empty segments below.
        num_noise_spans = int(np.round(min(num_noise_tokens, num_nonnoise_tokens) / self.mean_noise_span_length))
        # avoid degeneracy by ensuring positive number of noise spans
        num_noise_spans = max(num_noise_spans, 1)
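
        # Why the min() matters (illustrative numbers): with length=10,
        # noise_density=0.55 and mean_noise_span_length=1, num_noise_tokens =
        # round(5.5) = 6 and num_nonnoise_tokens = 4. The buggy formula asks for
        # 6 spans, but _random_segmentation(4, 6) cannot split 4 items into 6
        # non-empty segments (it returns only 4 lengths), and the later np.stack
        # of the two span-length arrays fails on the shape mismatch. Capping at
        # min(6, 4) = 4 spans keeps both partitions valid.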

        # pick the lengths of the noise spans and the non-noise spans
        def _random_segmentation(num_items, num_segments):
            """Partition a sequence of items randomly into non-empty segments.
            Args:
                num_items: an integer scalar > 0
                num_segments: an integer scalar in [1, num_items]
            Returns:
                a Tensor with shape [num_segments] containing positive integers that add
                up to num_items
            """
            if num_segments > num_items:
                print(f"num_segments {num_segments} > num_items {num_items}!! this is going to be a problem")
            mask_indices = np.arange(num_items - 1) < (num_segments - 1)
            np.random.shuffle(mask_indices)
            first_in_segment = np.pad(mask_indices, [[1, 0]])
            segment_id = np.cumsum(first_in_segment)
            # count length of sub segments assuming that list is sorted
            _, segment_length = np.unique(segment_id, return_counts=True)
            return segment_length
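
        # Worked example (illustrative): _random_segmentation(5, 3) draws a boolean
        # mask with exactly 2 of 4 positions set, say [F, T, F, T]; padding a leading
        # 0 gives first_in_segment = [0, 0, 1, 0, 1], so segment_id = [0, 0, 1, 1, 2]
        # and the segment lengths are [2, 2, 1]: positive and summing to 5.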
        noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
        nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens, num_noise_spans)
        interleaved_span_lengths = np.reshape(
            np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), [num_noise_spans * 2]
        )
        span_starts = np.cumsum(interleaved_span_lengths)[:-1]
        span_start_indicator = np.zeros((length,), dtype=np.int8)
        span_start_indicator[span_starts] = True
        span_num = np.cumsum(span_start_indicator)
        is_noise = np.equal(span_num % 2, 1)
        return is_noise[:orig_length]
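
# End-to-end example for random_spans_noise_mask (illustrative): length=8 with
# noise_density=0.35 and mean_noise_span_length=1 gives num_noise_tokens=3,
# num_nonnoise_tokens=5 and num_noise_spans=3. If the segmentations come out as
# noise [1, 1, 1] and non-noise [2, 1, 2], the interleaved lengths are
# [2, 1, 1, 1, 2, 1] and the mask is [F, F, T, F, T, F, F, T]: spans alternate
# starting with non-noise, and exactly 3 positions are noise.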

if __name__ == '__main__':
    model_name = 't5-base'
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    len_tokenizer = len(tokenizer)  # 32100; needed to derive the sentinel ids
    print(f"len_tokenizer={len_tokenizer}")

    # Unsupervised denoising training
    # https://huggingface.co/docs/transformers/main/model_doc/t5#training
    print("-" * 20)
    prompt = "The <extra_id_0> walks in <extra_id_1> park"
    encoded_prompt = tokenizer(prompt, truncation=False, padding=False, return_tensors="pt").input_ids
    print(f"encoded_prompt ={encoded_prompt}")
    labels = "<extra_id_0> cute dog <extra_id_1> the <extra_id_2>"
    encoded_labels = tokenizer(labels, truncation=False, padding=False, return_tensors="pt").input_ids
    print(f"encoded_labels ={encoded_labels}")
    print(f"{encoded_prompt.shape} ={encoded_labels.shape}")
    print("-" * 20)

    # simulating the above with the collator
    prompt = "The cute dog walks in the green park"
    encoded = tokenizer(prompt, truncation=False, padding=False, return_tensors="pt").input_ids
    batch_size = 1
    input_length = encoded.shape[1]
    # denoiser = FlaxDataCollatorForT5MLM(tokenizer, 0.35, 1)  # okay even with the buggy span formula
    denoiser = FlaxDataCollatorForT5MLM(tokenizer, 0.55, 1)  # broke the buggy formula; fine after the fix
    mask_indices = np.asarray([denoiser.random_spans_noise_mask(input_length) for _ in range(batch_size)])
    labels_mask = ~mask_indices
    input_ids_sentinel = denoiser.create_sentinel_ids(mask_indices.astype(np.int8))
    labels_sentinel = denoiser.create_sentinel_ids(labels_mask.astype(np.int8))
    # the torch tensor `encoded` is coerced to numpy inside filter_input_ids
    input_ids = denoiser.filter_input_ids(encoded, input_ids_sentinel)
    labels = denoiser.filter_input_ids(encoded, labels_sentinel)
    print(f"input_ids decoded = {tokenizer.decode(*input_ids, skip_special_tokens=False)}")
    print(f"labels decoded = {tokenizer.decode(*labels, skip_special_tokens=False)}")
    print(f"input_ids.shape {input_ids.shape} labels.shape {labels.shape}")
Output