# Source: GitHub Gist jbrry/ae58e02eb5bc881fddad9a27b01e6946 by @jbrry,
# created October 21, 2021 15:23 — "debug labels in batch".
# replace function: https://github.com/huggingface/transformers/blob/f9c16b02e3f5d2ee0a1cadb6f50dc9e3281e2536/src/transformers/data/data_collator.py#L78
def torch_default_data_collator(
    features: List[InputDataClass],
    *,
    active_labels: tuple = (27, 35, 29, 3, 8),
    label_key: str = "labels_rels",
) -> Dict[str, Any]:
    """Collate features into a batch of tensors and report how many label
    positions in the batch carry one of the labels of interest.

    Place this function in transformers/data/data_collator.py, replacing the
    upstream ``torch_default_data_collator``. The collation logic is the
    upstream one; a debugging step is appended that counts how many entries
    of ``batch[label_key]`` are in ``active_labels`` and prints that count.

    Args:
        features: Examples to collate. Dicts / ``BatchEncoding`` objects are
            used directly; any other object is converted with ``vars()``.
        active_labels: label2id indices to count (e.g. the id of a 'verb'
            label or of a multiword-expression label).
        label_key: Batch key holding the label tensor to inspect. Defaults to
            ``"labels_rels"`` as in the original gist; for most models the
            key is ``"labels"``.

    Returns:
        Dict mapping every non-``None``, non-string feature key to a tensor
        built from all features; ``label`` / ``label_ids`` are stored under
        ``"labels"``.
    """
    import warnings

    import torch

    if not isinstance(features[0], (dict, BatchEncoding)):
        features = [vars(f) for f in features]
    first = features[0]
    batch = {}

    # Special handling for labels.
    # Ensure that tensor is created with the correct type
    # (it should be automatically the case, but let's make sure of it.)
    if "label" in first and first["label"] is not None:
        label = first["label"].item() if isinstance(first["label"], torch.Tensor) else first["label"]
        dtype = torch.long if isinstance(label, int) else torch.float
        batch["labels"] = torch.tensor([f["label"] for f in features], dtype=dtype)
    elif "label_ids" in first and first["label_ids"] is not None:
        if isinstance(first["label_ids"], torch.Tensor):
            batch["labels"] = torch.stack([f["label_ids"] for f in features])
        else:
            dtype = torch.long if type(first["label_ids"][0]) is int else torch.float
            batch["labels"] = torch.tensor([f["label_ids"] for f in features], dtype=dtype)

    # Handling of all other possible keys.
    # Again, we will use the first element to figure out which key/values are not None for this model.
    for k, v in first.items():
        if k not in ("label", "label_ids") and v is not None and not isinstance(v, str):
            if isinstance(v, torch.Tensor):
                batch[k] = torch.stack([f[k] for f in features])
            else:
                batch[k] = torch.tensor([f[k] for f in features])

    # ** Investigate certain labels in batch **
    labels = batch.get(label_key)
    if labels is None:
        # A debugging hook must not abort training: the original indexed
        # batch["labels_rels"] directly and raised KeyError for any model
        # without that key. warnings.warn reports it once per location.
        warnings.warn(
            f"torch_default_data_collator: label key {label_key!r} not found in batch; "
            f"available keys: {sorted(batch)}"
        )
    else:
        # Compare every label position against every label of interest in one
        # broadcast: (..., 1) == (K,) -> (..., K), then "any" over K.
        # Equivalent to the original per-token `label in ACTIVE_LABELS` loop
        # for 2-D labels, but vectorised and independent of label rank.
        active = torch.as_tensor(active_labels, dtype=labels.dtype, device=labels.device)
        num_batch_active_labels = int((labels.unsqueeze(-1) == active).any(dim=-1).sum())
        print(f"Number of active labels: \t {num_batch_active_labels}")  # alternatively log this to a file
    return batch
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment