import logging
import warnings
from typing import List, Optional, Union

import numpy as np
import torch
from torch.nn import functional as F
from tqdm.auto import trange
from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizer, RwkvModel
# helper class: pools RWKV state tensors and the final hidden state with a configurable strategy
class MultiLevelPooler:
@staticmethod
def get_pooling_modes():
return [
"mean",
"max",
"last_token",
"weighted_mean",
"geometric_mean",
"harmonic_mean",
]
def __init__(self, pooling_strategy="mean", max_states: int = 3):
self.POOLING_METHODS = {
"mean": (self._mean_pool, True),
"max": (self._max_pool, True),
"last_token": (self._last_token_pool, False),
"weighted_mean": (self._weighted_mean_pool, True),
"geometric_mean": (self._geometric_mean_pool, True),
"harmonic_mean": (self._harmonic_mean_pool, True),
}
if pooling_strategy not in self.POOLING_METHODS:
raise ValueError(f"Unknown pooling strategy: {pooling_strategy}")
self.pooling_strategy = pooling_strategy
self.max_states = max_states
def _mean_pool(self, last_hidden_states, attention_mask):
attention_mask_expanded = attention_mask.unsqueeze(1).expand(
last_hidden_states.shape
).float()
last_hidden = last_hidden_states * attention_mask_expanded
return last_hidden.sum(dim=-2) / attention_mask_expanded.sum(dim=-2)
def _last_token_pool(self, last_hidden_states):
return last_hidden_states[:, :, -1, :]
def _max_pool(self, last_hidden_states, attention_mask):
attention_mask_expanded = (
attention_mask.unsqueeze(1).expand(
last_hidden_states.shape
).float()
)
last_hidden = last_hidden_states.masked_fill(attention_mask_expanded == 0, -1e9)
return last_hidden.max(dim=-2).values
def _weighted_mean_pool(self, last_hidden_states, attention_mask):
attention_mask_expanded = attention_mask.unsqueeze(1).expand(
-1, last_hidden_states.shape[1], -1
).float()
weights = (
torch.arange(start=1, end=last_hidden_states.shape[2] + 1)
.unsqueeze(0)
.unsqueeze(-1)
.expand(last_hidden_states.size()[:-1])
.float()
.to(last_hidden_states.device)
)
sum_embeddings = torch.sum(
last_hidden_states * attention_mask_expanded * weights, dim=2
)
sum_mask = torch.sum(attention_mask_expanded * weights, dim=2)
return sum_embeddings / sum_mask
def _geometric_mean_pool(self, last_hidden_states, attention_mask):
attention_mask_expanded = (
attention_mask.unsqueeze(1).repeat(1, last_hidden_states.shape[1], 1).float()
)
last_hidden = last_hidden_states * attention_mask_expanded
log_last_hidden = torch.log(last_hidden.clamp(min=1e-9))
geometric_mean_log = log_last_hidden.sum(dim=-2) / attention_mask_expanded.sum(dim=-2)
return geometric_mean_log.exp()
def _harmonic_mean_pool(self, last_hidden_states, attention_mask):
attention_mask_expanded = attention_mask.unsqueeze(1).expand(
last_hidden_states.shape
).float()
last_hidden = last_hidden_states * attention_mask_expanded
last_hidden_reciprocal = 1.0 / last_hidden.clamp(min=1e-9)
harmonic_mean_reciprocal = last_hidden_reciprocal.sum(dim=-2) / attention_mask_expanded.sum(dim=-2)
return 1.0 / harmonic_mean_reciprocal.clamp(min=1e-9)
def forward(self, states, last_hidden_state, attention_mask, max_states: int = None, pooling_strategy: str = None):
if max_states is None:
max_states = self.max_states
if pooling_strategy is None:
pooling_strategy = self.pooling_strategy
pooling_mode, needs_attention_mask = self.POOLING_METHODS[pooling_strategy]
pooled_states = []
        for state in states[:max_states]:
            if needs_attention_mask:
                state_pool = pooling_mode(last_hidden_states=state, attention_mask=attention_mask)
            else:
                # pool the state itself; last_hidden_state is pooled separately below
                state_pool = pooling_mode(state)
            pooled_states.append(state_pool)
pooled_states_tensor = torch.stack(pooled_states, dim=-1)
final_state = pooled_states_tensor.mean(dim=-1)
last_hidden_state_pooled = (
pooling_mode(last_hidden_state, attention_mask)
if needs_attention_mask
else pooling_mode(last_hidden_state)
)
output = (final_state + last_hidden_state_pooled) / 2
return output
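# NOTE: `expand_attention_mask` is used by `RwkvEmbedder.state_aware_embedding` below but is
# not defined anywhere in this gist. The stand-in here is a hedged sketch, not the original
# helper: in transformers' RwkvModel, `outputs.state` is a list of tensors shaped roughly
# (batch_size, hidden_size, num_hidden_layers), i.e. there is no sequence dimension to mask,
# so this sketch simply returns an all-ones mask matching the state's last dimension so that
# MultiLevelPooler's masked arithmetic stays well-defined. Swap in the real helper if you have it.
def expand_attention_mask(attention_mask: torch.Tensor, state_shape: torch.Size) -> torch.Tensor:
    batch_size = attention_mask.shape[0]
    return torch.ones(
        batch_size,
        state_shape[-1],
        device=attention_mask.device,
        dtype=attention_mask.dtype,
    )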
# @title define RwkvEmbedder
class RwkvEmbedder:
"""Generate sentence embedding for given text using RWKV models.
Models: ttps://huggingface.co/RWKV
cODE: https://github.com/BlinkDL/ChatRWKV/
Parameters:
model_name : str, optional
The name of the pre-trained RWKV model (default is "RWKV/rwkv-4-169m-pile").
batch_size: int, optional
The size of the batches to use when generating embeddings. (default is 16).
Attributes:
model : transformers.PreTrainedModel
The pre-trained model loaded from Hugging Face's model hub.
tokenizer : transformers.PreTrainedTokenizer
The tokenizer corresponding to the pre-trained model.
dimension : int
The dimension of the embeddings produced by the model.
batch_size : int
The size of the batches to use when generating embeddings.
"""
def __init__(
self,
model_name: str = "RWKV/rwkv-4-169m-pile",
batch_size: int = 16,
torch_dtype=None,
device_map="auto",
pooling_strategy="mean",
):
        self.model = self.load_model(model_name, torch_dtype=torch_dtype, device_map=device_map)
self.tokenizer = self.load_tokenizer(model_name)
# If the tokenizer does not have a padding token, set it to be the eos_token
if self.tokenizer.pad_token is None:
if self.tokenizer.eos_token is not None:
self.tokenizer.pad_token = self.tokenizer.eos_token
else:
# Add a new pad token if eos_token is not available
self.tokenizer.add_tokens(["[PAD]"])
self.tokenizer.pad_token = "[PAD]"
self.dimension = self.model.config.hidden_size
self.batch_size = batch_size
self.pooling_strategy = pooling_strategy
self.max_length = self.model.config.context_length or 1024
if (
self.tokenizer.model_max_length is None
or self.tokenizer.model_max_length != self.max_length
):
self.tokenizer.model_max_length = self.max_length
@staticmethod
def load_model(
model_name: str, torch_dtype=None, device_map="auto"
) -> PreTrainedModel:
"""Load a pre-trained model from Hugging Face's model hub.
Parameters:
model_name : str
The name of the pre-trained model.
Returns:
transformers.PreTrainedModel
The loaded pre-trained model.
"""
try:
model = RwkvModel.from_pretrained(
model_name, torch_dtype=torch_dtype, device_map=device_map
)
model.eval()
return model
except Exception as e:
raise ValueError(f"Unable to load model {model_name}. Error: {str(e)}")
@staticmethod
def load_tokenizer(model_name: str, use_fast: bool = True) -> PreTrainedTokenizer:
"""Load a tokenizer from Hugging Face's model hub.
Parameters:
model_name : str
The name of the pre-trained model.
Returns:
transformers.PreTrainedTokenizer
The loaded tokenizer.
"""
try:
return AutoTokenizer.from_pretrained(model_name, use_fast=use_fast)
except Exception as e:
raise ValueError(
f"Unable to load tokenizer for model {model_name}. Error: {str(e)}"
)
@staticmethod
def normalize_embeddings(embeddings: torch.Tensor):
"""
        Normalize the embedding matrix so that each sentence embedding has unit length.
"""
return torch.nn.functional.normalize(embeddings, p=2, dim=1)
@staticmethod
def _mean_pool(
last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
# We mask the padding tokens, so they don't affect the average.
last_hidden = last_hidden_states.masked_fill(
~attention_mask[..., None].bool(), 0.0
)
return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]
@staticmethod
def _last_token_pool(last_hidden_states: torch.Tensor) -> torch.Tensor:
return last_hidden_states[:, -1, :]
@staticmethod
def _max_pool(
last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
# We add a very large negative number for padding tokens,
# so they effectively never get chosen by the max operation.
last_hidden = last_hidden_states.masked_fill(
~attention_mask[..., None].bool(), -1e9
)
return last_hidden.max(dim=1).values
@staticmethod
def _weighted_mean_pool(
last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
"""
weighted mean token embeddings for autoregressive LMs
Credit: SGPT @ https://github.com/Muennighoff/sgpt
"""
weights = (
torch.arange(start=1, end=last_hidden_states.shape[1] + 1)
.unsqueeze(0)
.unsqueeze(-1)
.expand(last_hidden_states.size())
.float()
.to(last_hidden_states.device)
)
input_mask_expanded = (
attention_mask.unsqueeze(-1).expand(last_hidden_states.size()).float()
)
sum_embeddings = torch.sum(
last_hidden_states * input_mask_expanded * weights, dim=1
)
sum_mask = torch.sum(input_mask_expanded * weights, dim=1)
return sum_embeddings / sum_mask
@staticmethod
def _geometric_mean_pool(
last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
# We add a very small value for padding tokens,
# so they don't affect the geometric mean calculation.
last_hidden = last_hidden_states.masked_fill(
~attention_mask[..., None].bool(), 1e-9
)
# Compute the geometric mean.
log_last_hidden = torch.log(last_hidden)
geometric_mean_log = log_last_hidden.mean(dim=1)
return geometric_mean_log.exp()
@staticmethod
def _harmonic_mean_pool(
last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
# We add a very small value for padding tokens,
# so they don't affect the harmonic mean calculation.
last_hidden = last_hidden_states.masked_fill(
~attention_mask[..., None].bool(), 1e-9
)
# Compute the harmonic mean.
last_hidden_reciprocal = 1.0 / last_hidden
harmonic_mean_reciprocal = last_hidden_reciprocal.mean(dim=1)
return 1.0 / harmonic_mean_reciprocal
    # note: calling staticmethod objects stored this way requires Python >= 3.10
    POOLING_METHODS = {
        "mean": (_mean_pool, True),
        "max": (_max_pool, True),
        "last_token": (_last_token_pool, False),
        "weighted_mean": (_weighted_mean_pool, True),
        "geometric_mean": (_geometric_mean_pool, True),
        "harmonic_mean": (_harmonic_mean_pool, True),
    }
def _text_length(self, text: Union[List[int], List[List[int]]]):
"""
Help function to get the length for the input text. Text can be either
a list of ints (which means a single text as input), or a tuple of list of ints
(representing several text inputs to the model).
Credit: SBERT
https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/SentenceTransformer.py#L324
"""
if isinstance(text, dict): # {key: value} case
return len(next(iter(text.values())))
elif not hasattr(text, "__len__"): # Object has no len() method
return 1
elif len(text) == 0 or isinstance(text[0], int): # Empty string or list of ints
return len(text)
else:
return sum([len(t) for t in text]) # Sum of length of individual strings
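    # A few illustrative values (examples added for clarity, not from the original gist):
    #   _text_length("hello world")  -> 11  (sum of character lengths)
    #   _text_length([101, 102])     -> 2   (a single tokenized text)
    #   _text_length([[1, 2, 3]])    -> 3   (sum over several tokenized texts)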
def simple_embeddings(
self, texts: Union[str, list], pooling_strategy=None, NORMALIZE: bool = False
) -> np.ndarray:
"""Generate embeddings for given text or list of texts.
Parameters:
texts : str or list
The input text(s) to generate embeddings for.
Returns:
np.ndarray
The embeddings for each text in the shape of (len(texts), dimension).
"""
if pooling_strategy is None:
pooling_strategy = self.pooling_strategy
if pooling_strategy not in self.POOLING_METHODS:
raise ValueError(f"Unknown pooling strategy: {pooling_strategy}")
pooling_method, needs_attention_mask = self.POOLING_METHODS[pooling_strategy]
if isinstance(texts, str):
texts = [texts]
# Compute lengths and sort by length
lengths = [self._text_length(text) for text in texts]
sorted_indices = np.argsort(lengths)
sorted_texts = [texts[i] for i in sorted_indices]
embeddings = []
for i in trange(0, len(sorted_texts), self.batch_size):
batch = sorted_texts[i : i + self.batch_size]
            inputs = self.tokenizer(
                batch,
                return_tensors="pt",
                truncation=True,
                padding="max_length",  # or True
                max_length=self.max_length,
            ).to(self.model.device)
with torch.no_grad():
outputs = self.model(
**inputs, output_attentions=True, output_hidden_states=True
)
if needs_attention_mask:
emb = pooling_method(
outputs.last_hidden_state, inputs["attention_mask"]
)
else:
emb = pooling_method(outputs.last_hidden_state)
if NORMALIZE:
emb = self.normalize_embeddings(emb)
embeddings.extend(emb.cpu().numpy().astype("float32"))
# Resort the embeddings to their original order
embeddings = [embeddings[i] for i in np.argsort(sorted_indices)]
return np.array(embeddings)
def summary(self):
"""
Provides a summary of the key parameters of the model.
"""
print(f"Model Name: {self.model.name_or_path}")
print(f"Batch Size: {self.batch_size}")
print(f"Token Length: {self.max_length}")
print(f"Pooling Strategy: {self.pooling_strategy}")
print(f"Device: {next(self.model.parameters()).device}")
def state_aware_embedding(
self,
texts: Union[str, list],
pooling_strategy: str = None,
max_states: int = 3,
NORMALIZE: bool = False,
) -> np.ndarray:
"""Generate embeddings for given text or list of texts.
Parameters:
texts : str or list
The input text(s) to generate embeddings for.
Returns:
np.ndarray
The embeddings for each text in the shape of (len(texts), dimension).
"""
        pooling_strategy = self.pooling_strategy if pooling_strategy is None else pooling_strategy
        if pooling_strategy not in MultiLevelPooler.get_pooling_modes():
            logging.warning(
                f"Unknown pooling strategy: {pooling_strategy}. Using 'max' pooling instead."
            )
            pooling_strategy = "max"
if isinstance(texts, str):
texts = [texts]
pooler = MultiLevelPooler(pooling_strategy=pooling_strategy, max_states=max_states)
# Compute lengths and sort by length
lengths = [self._text_length(text) for text in texts]
sorted_indices = np.argsort(lengths)
sorted_texts = [texts[i] for i in sorted_indices]
embeddings = []
for i in trange(0, len(sorted_texts), self.batch_size):
batch = sorted_texts[i : i + self.batch_size]
inputs = self.tokenizer(
batch,
return_tensors="pt",
truncation=True,
padding="longest", # or True
max_length=self.max_length,
).to(self.model.device)
with torch.no_grad():
outputs = self.model(
**inputs, output_attentions=True, output_hidden_states=True
)
            attention_mask = expand_attention_mask(inputs["attention_mask"], outputs.state[0].shape)
            emb = pooler.forward(outputs.state, outputs.last_hidden_state, attention_mask)
            if NORMALIZE:
                emb = self.normalize_embeddings(emb)
            embeddings.extend(emb.cpu().numpy().astype("float32"))
# Resort the embeddings to their original order
embeddings = [embeddings[i] for i in np.argsort(sorted_indices)]
return np.array(embeddings)
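# Minimal usage sketch (not part of the original gist). It exercises the classes above with
# the default checkpoint named in the docstring; running it downloads the model weights.
# The function is defined but not called, so nothing extra runs at import time.
def example_usage():
    embedder = RwkvEmbedder(model_name="RWKV/rwkv-4-169m-pile", pooling_strategy="mean")
    embedder.summary()
    sentences = ["RWKV is an RNN with transformer-level performance.", "Hello, world."]
    # token-level pooling over last_hidden_state
    embs = embedder.simple_embeddings(sentences, NORMALIZE=True)
    print(embs.shape)  # expected: (2, embedder.dimension)
    # embedder.state_aware_embedding(sentences) additionally pools the recurrent state
    # tensors via MultiLevelPooler; it is experimental and relies on expand_attention_mask.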
"""
RWKV embeddings with transformers (original starting point)
⚠️ embeddings from this are pretty poor because RWKV has special outputs ⚠️
credit: GPTcache
https://gptcache.readthedocs.io/en/latest/_modules/gptcache/embedding/rwkv.html#Rwkv
"""
import numpy as np
from gptcache.embedding.base import BaseEmbedding
from gptcache.utils import import_huggingface
import_huggingface()
from transformers import AutoTokenizer, RwkvModel # pylint: disable=C0413
class Rwkv(BaseEmbedding):
"""Generate sentence embedding for given text using RWKV models.
:param model: model name, defaults to 'RWKV/rwkv-4-169m-pile'. Check
        https://huggingface.co/docs/transformers/model_doc/rwkv for more available models.
:type model: str
Example:
.. code-block:: python
from gptcache.embedding import Rwkv
test_sentence = 'Hello, world.'
encoder = Rwkv(model='RWKV/rwkv-4-169m-pile')
embed = encoder.to_embeddings(test_sentence)
"""
def __init__(self, model: str = "RWKV/rwkv-4-169m-pile"):
self.model = RwkvModel.from_pretrained(model)
self.model.eval()
self.tokenizer = AutoTokenizer.from_pretrained(model)
try:
self.__dimension = self.model.config.hidden_size
except Exception: # pylint: disable=W0703
from transformers import AutoConfig # pylint: disable=C0415
config = AutoConfig.from_pretrained(model)
self.__dimension = config.hidden_size
def to_embeddings(self, data, **_):
"""Generate embedding given text input
:param data: text in string.
:type data: str
:return: a text embedding in shape of (dim,).
"""
inputs = self.tokenizer(data, return_tensors="pt")
outputs = self.model(inputs["input_ids"])
emb = outputs.last_hidden_state[0, 0, :].detach().numpy()
return np.array(emb).astype("float32")
@property
def dimension(self):
"""Embedding dimension.
:return: embedding dimension
"""
return self.__dimension