@Helw150
Last active April 27, 2023 22:02
OT TADA Loss
from typing import List, Optional, Tuple, Union

import torch
from torch import nn
from torchtyping import TensorType
from geomloss import SamplesLoss

from transformers.adapters.modeling import Adapter
from transformers.adapters import (
    BartAdapterModel,
    RobertaAdapterModel,
    BertAdapterModel,
    AdapterConfig,
)
class AlignmentMixin(nn.Module):
    def __init__(self, config):
        # Disable dropout so representations are deterministic during alignment.
        config.hidden_dropout_prob = 0.0
        config.attention_probs_dropout_prob = 0.0
        super().__init__(config)
        # Sinkhorn divergence (entropy-regularized optimal transport) between token clouds.
        self.earth_mover_loss = SamplesLoss(loss="sinkhorn", p=2)

    @torch.no_grad()
    def produce_original_embeddings(
        self,
        input_ids: TensorType["batch", "seq_len"],
        attention_mask: TensorType["batch", "seq_len"],
        token_type_ids: Optional[TensorType["batch", "seq_len"]] = None,
        position_ids: Optional[TensorType["batch", "seq_len"]] = None,
        head_mask: Optional[TensorType["layers", "heads"]] = None,
    ) -> Tuple[
        TensorType["batch", "seq_len", "hidden_size"],
        TensorType["batch", "seq_len"],
    ]:
        self.train(False)
        outputs = super().forward(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,
        )
        # Encoder-decoder models (e.g. BART) expose the encoder states under a different key.
        if "last_hidden_state" in outputs:
            hidden_mat = outputs.last_hidden_state
        else:
            hidden_mat = outputs.encoder_last_hidden_state
        self.train(True)
        return hidden_mat, attention_mask

    def get_weight(self, mask):
        # Normalize the attention mask into a per-sequence probability distribution:
        # uniform weight over real tokens, zero weight on padding.
        probs = mask / mask.sum(1).reshape(-1, 1)
        return probs
    def forward(
        self,
        input_ids: TensorType["batch", "seq_len"],
        attention_mask: TensorType["batch", "seq_len"],
        original_embedding: Optional[
            TensorType["batch", "seq_len", "hidden_size"]
        ] = None,
        original_mask: Optional[TensorType["batch", "seq_len"]] = None,
        token_type_ids: Optional[TensorType["batch", "seq_len"]] = None,
        position_ids: Optional[TensorType["batch", "seq_len"]] = None,
        head_mask: Optional[TensorType["layers", "heads"]] = None,
        **kwargs
    ):
        if original_embedding is not None:
            outputs = super().forward(
                input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids,
                position_ids=position_ids,
                head_mask=head_mask,
                output_attentions=False,
                output_hidden_states=True,
                return_dict=True,
            )
            if "last_hidden_state" in outputs:
                hidden_mat = outputs.last_hidden_state
            else:
                hidden_mat = outputs.encoder_last_hidden_state
            # Sinkhorn divergence between the weighted token cloud of the adapted
            # model and that of the frozen original model.
            alignment_loss = self.earth_mover_loss(
                self.get_weight(attention_mask),
                hidden_mat,
                self.get_weight(original_mask),
                original_embedding,
            )
            return (alignment_loss,)
class BartAdapterModelForAlignment(AlignmentMixin, BartAdapterModel):
    def __init__(self, config):
        config.dropout = 0.0
        config.activation_dropout = 0.0
        config.attention_dropout = 0.0
        config.classifier_dropout = 0.0
        super().__init__(config)


class RobertaAdapterModelForAlignment(AlignmentMixin, RobertaAdapterModel):
    def __init__(self, config):
        config.hidden_dropout_prob = 0.0
        config.attention_probs_dropout_prob = 0.0
        super().__init__(config)


class BertAdapterModelForAlignment(AlignmentMixin, BertAdapterModel):
    def __init__(self, config):
        config.hidden_dropout_prob = 0.0
        config.attention_probs_dropout_prob = 0.0
        super().__init__(config)
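For context, below is a minimal usage sketch of the classes above. It assumes the adapter-transformers package and a `roberta-base` checkpoint; the checkpoint name, example sentence, and adapter name "tada" are illustrative placeholders, not part of the original gist.

# Usage sketch (placeholder checkpoint, sentence, and adapter name).
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("roberta-base")
model = RobertaAdapterModelForAlignment.from_pretrained("roberta-base")

batch = tokenizer(
    ["an example sentence from the target domain"],
    return_tensors="pt",
    padding=True,
)

# 1) Cache frozen reference embeddings from the unadapted model (no gradients).
original_embedding, original_mask = model.produce_original_embeddings(
    batch["input_ids"], batch["attention_mask"]
)

# 2) Activate a trainable adapter; the forward pass then returns the Sinkhorn
#    alignment loss between adapted and original token embeddings.
model.add_adapter("tada")
model.train_adapter("tada")

(alignment_loss,) = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    original_embedding=original_embedding,
    original_mask=original_mask,
)
alignment_loss.backward()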