@alexcpn
Last active April 22, 2023 14:01
from transformers import T5Tokenizer
import numpy as np


class FlaxDataCollatorForT5MLM:
    """
    From https://github.com/huggingface/transformers/blob/main/examples/flax/language-modeling/run_t5_mlm_flax.py
    """

    def __init__(self, tokenizer, noise_density, mean_noise_span_length) -> None:
        self.tokenizer = tokenizer
        self.noise_density = noise_density
        self.mean_noise_span_length = mean_noise_span_length
    def create_sentinel_ids(self, mask_indices):
        """
        Sentinel ids creation given the indices that should be masked.
        The start indices of each mask are replaced by the sentinel ids in increasing
        order. Consecutive mask indices to be deleted are replaced with `-1`.
        """
        start_indices = mask_indices - np.roll(mask_indices, 1, axis=-1) * mask_indices
        start_indices[:, 0] = mask_indices[:, 0]

        sentinel_ids = np.where(start_indices != 0, np.cumsum(start_indices, axis=-1), start_indices)
        sentinel_ids = np.where(sentinel_ids != 0, (len(self.tokenizer) - sentinel_ids), 0)
        sentinel_ids -= mask_indices - start_indices

        return sentinel_ids
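    # Illustration (hypothetical values): for mask_indices [[0, 1, 1, 0, 1]] with a
    # 32100-entry vocabulary, this returns [[0, 32099, -1, 0, 32098]] -- each span
    # start becomes a sentinel id counting down from len(tokenizer) - 1, and
    # positions that continue a masked span become -1.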
    def filter_input_ids(self, input_ids, sentinel_ids):
        """
        Puts the sentinel mask on `input_ids` and fuses consecutive mask tokens into a
        single mask token by deleting. This reduces the sequence length from
        `expanded_inputs_length` to `input_length`.
        """
        batch_size = input_ids.shape[0]

        input_ids_full = np.where(sentinel_ids != 0, sentinel_ids, input_ids)
        # input_ids tokens and sentinel tokens are >= 0; tokens < 0 are
        # masked tokens coming after sentinel tokens and should be removed
        input_ids = input_ids_full[input_ids_full >= 0].reshape((batch_size, -1))
        input_ids = np.concatenate(
            [input_ids, np.full((batch_size, 1), self.tokenizer.eos_token_id, dtype=np.int32)], axis=-1
        )
        return input_ids
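    # Illustration (hypothetical values): with input_ids [[37, 5295, 1782, 10681, 16]]
    # and sentinel_ids [[0, 32099, -1, 0, 0]], the -1 position is dropped and the
    # result is [[37, 32099, 10681, 16, 1]], assuming eos_token_id == 1 as in t5-base.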
    def random_spans_noise_mask(self, length):
        """This function is a copy of `random_spans_helper <https://github.com/google-research/text-to-text-transfer-transformer/blob/84f8bcc14b5f2c03de51bd3587609ba8f6bbd1cd/t5/data/preprocessors.py#L2682>`__ .
        Noise mask consisting of random spans of noise tokens.
        The number of noise tokens and the number of noise spans and non-noise spans
        are determined deterministically as follows:
            num_noise_tokens = round(length * noise_density)
            num_nonnoise_spans = num_noise_spans = round(num_noise_tokens / mean_noise_span_length)
        Spans alternate between non-noise and noise, beginning with non-noise.
        Subject to the above restrictions, all masks are equally likely.
        Args:
            length: an int32 scalar (length of the incoming token sequence)
            noise_density: a float - approximate density of output mask
            mean_noise_span_length: a number
        Returns:
            a boolean tensor with shape [length]
        """
        orig_length = length

        num_noise_tokens = int(np.round(length * self.noise_density))
        # avoid degeneracy by ensuring positive numbers of noise and nonnoise tokens
        num_noise_tokens = min(max(num_noise_tokens, 1), length - 1)
        num_nonnoise_tokens = length - num_noise_tokens
        # num_noise_spans = int(np.round(num_noise_tokens / self.mean_noise_span_length))  # <--- Bug:
        # num_noise_spans must not exceed num_noise_tokens or num_nonnoise_tokens
        num_noise_spans = int(np.round(min(num_noise_tokens, num_nonnoise_tokens) / self.mean_noise_span_length))
        # avoid degeneracy by ensuring positive number of noise spans
        num_noise_spans = max(num_noise_spans, 1)
        # pick the lengths of the noise spans and the non-noise spans
        def _random_segmentation(num_items, num_segments):
            """Partition a sequence of items randomly into non-empty segments.
            Args:
                num_items: an integer scalar > 0
                num_segments: an integer scalar in [1, num_items]
            Returns:
                a Tensor with shape [num_segments] containing positive integers that
                add up to num_items
            """
            if num_segments > num_items:
                print(f"num_segments {num_segments} > num_items {num_items}!! this is going to be a problem")
            mask_indices = np.arange(num_items - 1) < (num_segments - 1)
            np.random.shuffle(mask_indices)
            first_in_segment = np.pad(mask_indices, [[1, 0]])
            segment_id = np.cumsum(first_in_segment)
            # count length of sub segments assuming that list is sorted
            _, segment_length = np.unique(segment_id, return_counts=True)
            return segment_length
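        # Illustration (hypothetical values): _random_segmentation(5, 2) returns a
        # random two-way split of 5 items such as [2, 3] or [4, 1]; every segment is
        # non-empty as long as num_segments <= num_items, which the fix above ensures
        # for mean_noise_span_length >= 1.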
        noise_span_lengths = _random_segmentation(num_noise_tokens, num_noise_spans)
        nonnoise_span_lengths = _random_segmentation(num_nonnoise_tokens, num_noise_spans)

        interleaved_span_lengths = np.reshape(
            np.stack([nonnoise_span_lengths, noise_span_lengths], axis=1), [num_noise_spans * 2]
        )
        span_starts = np.cumsum(interleaved_span_lengths)[:-1]
        span_start_indicator = np.zeros((length,), dtype=np.int8)
        span_start_indicator[span_starts] = True
        span_num = np.cumsum(span_start_indicator)
        is_noise = np.equal(span_num % 2, 1)

        return is_noise[:orig_length]
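
# Worked example of the bug (assuming the demo prompt below tokenizes to 9 ids,
# including the trailing </s>): with noise_density=0.55 and mean_noise_span_length=1,
# num_noise_tokens = round(9 * 0.55) = 5 and num_nonnoise_tokens = 9 - 5 = 4. The
# original formula gives num_noise_spans = round(5 / 1) = 5, but 4 non-noise tokens
# cannot be partitioned into 5 non-empty segments, so _random_segmentation misbehaves;
# min(num_noise_tokens, num_nonnoise_tokens) caps the span count at 4.
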
if __name__ == '__main__':
    model_name = 't5-base'
    tokenizer = T5Tokenizer.from_pretrained(model_name)
    len_tokenizer = len(tokenizer)  # 32100; used to derive the sentinel ids
    print(f"len_tokenizer={len_tokenizer}")

    # Unsupervised denoising training
    # https://huggingface.co/docs/transformers/main/model_doc/t5#training
    print("-" * 20)
    prompt = "The <extra_id_0> walks in <extra_id_1> park"
    encoded_prompt = tokenizer(prompt, truncation=False, padding=False, return_tensors="pt").input_ids
    print(f"encoded_prompt ={encoded_prompt}")
    labels = "<extra_id_0> cute dog <extra_id_1> the <extra_id_2>"
    encoded_labels = tokenizer(labels, truncation=False, padding=False, return_tensors="pt").input_ids
    print(f"encoded_labels ={encoded_labels}")
    print(f"{encoded_prompt.shape} ={encoded_labels.shape}")
    print("-" * 20)

    # simulating the above with the collator
    prompt = "The cute dog walks in the green park"
    encoded = tokenizer(prompt, truncation=False, padding=False, return_tensors="pt").input_ids
    batch_size = 1
    input_length = encoded.shape[1]
    # denoiser = FlaxDataCollatorForT5MLM(tokenizer, .35, 1)  # okay even with the original formula
    denoiser = FlaxDataCollatorForT5MLM(tokenizer, .55, 1)  # breaks with the original formula; fine with the fix
    mask_indices = np.asarray([denoiser.random_spans_noise_mask(input_length) for _ in range(batch_size)])
    labels_mask = ~mask_indices

    input_ids_sentinel = denoiser.create_sentinel_ids(mask_indices.astype(np.int8))
    labels_sentinel = denoiser.create_sentinel_ids(labels_mask.astype(np.int8))

    input_ids = denoiser.filter_input_ids(encoded, input_ids_sentinel)
    labels = denoiser.filter_input_ids(encoded, labels_sentinel)
    print(f"input_ids decoded = {tokenizer.decode(*input_ids, skip_special_tokens=False)}")
    print(f"labels decoded = {tokenizer.decode(*labels, skip_special_tokens=False)}")
    print(f"input_ids.shape {input_ids.shape} labels.shape {labels.shape}")
alexcpn commented Mar 18, 2023

Output:

len_tokenizer=32100
--------------------
encoded_prompt =tensor([[   37, 32099, 10681,    16, 32098,  2447,     1]])
encoded_labels =tensor([[32099,  5295,  1782, 32098,     8, 32097,     1]])
torch.Size([1, 7]) =torch.Size([1, 7])
--------------------
input_ids decoded = The<extra_id_0> dog walks in the<extra_id_1> park<extra_id_2></s>
labels decoded   = <extra_id_0> cute<extra_id_1> green<extra_id_2></s></s>
input_ids.shape (1, 10) labels.shape (1, 7)  <-- note: the shapes are not equal
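
Reading the output: <extra_id_0> stands in for "cute" and <extra_id_1> for "green", while <extra_id_2> happened to land on the original </s> token, which is why the decoded labels end in two </s> tokens (the masked one plus the one appended by filter_input_ids). The shapes differ because the input keeps the 6 unmasked tokens while the labels keep only the 3 masked ones, each side plus 3 sentinels and a final </s>.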
