import logging
from typing import List, Union

import numpy as np
import torch
from tqdm.auto import trange
from transformers import AutoTokenizer, PreTrainedModel, PreTrainedTokenizer, RwkvModel
# helper class
class MultiLevelPooler:
    @staticmethod
    def get_pooling_modes():
        return [
            "mean",
            "max",
            "last_token",
            "weighted_mean",
            "geometric_mean",
            "harmonic_mean",
        ]

    def __init__(self, pooling_strategy="mean", max_states: int = 3):
        self.POOLING_METHODS = {
            "mean": (self._mean_pool, True),
            "max": (self._max_pool, True),
            "last_token": (self._last_token_pool, False),
            "weighted_mean": (self._weighted_mean_pool, True),
            "geometric_mean": (self._geometric_mean_pool, True),
            "harmonic_mean": (self._harmonic_mean_pool, True),
        }
        if pooling_strategy not in self.POOLING_METHODS:
            raise ValueError(f"Unknown pooling strategy: {pooling_strategy}")
        self.pooling_strategy = pooling_strategy
        self.max_states = max_states
    def _mean_pool(self, last_hidden_states, attention_mask):
        attention_mask_expanded = (
            attention_mask.unsqueeze(1).expand(last_hidden_states.shape).float()
        )
        last_hidden = last_hidden_states * attention_mask_expanded
        return last_hidden.sum(dim=-2) / attention_mask_expanded.sum(dim=-2)

    def _last_token_pool(self, last_hidden_states):
        return last_hidden_states[:, :, -1, :]

    def _max_pool(self, last_hidden_states, attention_mask):
        attention_mask_expanded = (
            attention_mask.unsqueeze(1).expand(last_hidden_states.shape).float()
        )
        last_hidden = last_hidden_states.masked_fill(attention_mask_expanded == 0, -1e9)
        return last_hidden.max(dim=-2).values

    def _weighted_mean_pool(self, last_hidden_states, attention_mask):
        attention_mask_expanded = (
            attention_mask.unsqueeze(1)
            .expand(-1, last_hidden_states.shape[1], -1)
            .float()
        )
        weights = (
            torch.arange(start=1, end=last_hidden_states.shape[2] + 1)
            .unsqueeze(0)
            .unsqueeze(-1)
            .expand(last_hidden_states.size()[:-1])
            .float()
            .to(last_hidden_states.device)
        )
        sum_embeddings = torch.sum(
            last_hidden_states * attention_mask_expanded * weights, dim=2
        )
        sum_mask = torch.sum(attention_mask_expanded * weights, dim=2)
        return sum_embeddings / sum_mask

    def _geometric_mean_pool(self, last_hidden_states, attention_mask):
        attention_mask_expanded = (
            attention_mask.unsqueeze(1)
            .repeat(1, last_hidden_states.shape[1], 1)
            .float()
        )
        last_hidden = last_hidden_states * attention_mask_expanded
        log_last_hidden = torch.log(last_hidden.clamp(min=1e-9))
        geometric_mean_log = log_last_hidden.sum(dim=-2) / attention_mask_expanded.sum(
            dim=-2
        )
        return geometric_mean_log.exp()

    def _harmonic_mean_pool(self, last_hidden_states, attention_mask):
        attention_mask_expanded = (
            attention_mask.unsqueeze(1).expand(last_hidden_states.shape).float()
        )
        last_hidden = last_hidden_states * attention_mask_expanded
        last_hidden_reciprocal = 1.0 / last_hidden.clamp(min=1e-9)
        harmonic_mean_reciprocal = last_hidden_reciprocal.sum(
            dim=-2
        ) / attention_mask_expanded.sum(dim=-2)
        return 1.0 / harmonic_mean_reciprocal.clamp(min=1e-9)
    def forward(
        self,
        states,
        last_hidden_state,
        attention_mask,
        max_states: int = None,
        pooling_strategy: str = None,
    ):
        if max_states is None:
            max_states = self.max_states
        if pooling_strategy is None:
            pooling_strategy = self.pooling_strategy
        pooling_mode, needs_attention_mask = self.POOLING_METHODS[pooling_strategy]
        pooled_states = []
        for state in states[:max_states]:
            # pool each RWKV state tensor, not the last hidden state
            if needs_attention_mask:
                state_pool = pooling_mode(
                    last_hidden_states=state, attention_mask=attention_mask
                )
            else:
                state_pool = pooling_mode(state)
            pooled_states.append(state_pool)
        pooled_states_tensor = torch.stack(pooled_states, dim=-1)
        final_state = pooled_states_tensor.mean(dim=-1)
        last_hidden_state_pooled = (
            pooling_mode(last_hidden_state, attention_mask)
            if needs_attention_mask
            else pooling_mode(last_hidden_state)
        )
        output = (final_state + last_hidden_state_pooled) / 2
        return output
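

# NOTE: `expand_attention_mask` is called by `state_aware_embedding` below but
# was never defined in the original gist. The sketch here is an assumption
# about the intended behavior: reconcile the (batch, seq_len) attention mask
# with the trailing dimension of an RWKV state tensor. For transformers'
# RwkvModel the states are shaped (batch, hidden_size, num_hidden_layers) and
# carry no per-token padding information, so the fallback is an all-ones mask.
def expand_attention_mask(
    attention_mask: torch.Tensor, state_shape: torch.Size
) -> torch.Tensor:
    batch_size, seq_len = attention_mask.shape
    target_len = state_shape[-1]
    if target_len == seq_len:
        return attention_mask
    if target_len < seq_len:
        # keep the most recent positions; RWKV states summarize the full prefix
        return attention_mask[:, -target_len:]
    # no per-token padding info is available at this width: mask nothing
    return torch.ones(
        batch_size,
        target_len,
        dtype=attention_mask.dtype,
        device=attention_mask.device,
    )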
# @title define RwkvEmbedder
class RwkvEmbedder:
    """Generate sentence embeddings for given text using RWKV models.

    Models: https://huggingface.co/RWKV
    Code: https://github.com/BlinkDL/ChatRWKV/

    Parameters:
        model_name : str, optional
            The name of the pre-trained RWKV model (default is "RWKV/rwkv-4-169m-pile").
        batch_size : int, optional
            The size of the batches to use when generating embeddings (default is 16).

    Attributes:
        model : transformers.PreTrainedModel
            The pre-trained model loaded from Hugging Face's model hub.
        tokenizer : transformers.PreTrainedTokenizer
            The tokenizer corresponding to the pre-trained model.
        dimension : int
            The dimension of the embeddings produced by the model.
        batch_size : int
            The size of the batches to use when generating embeddings.
    """
    def __init__(
        self,
        model_name: str = "RWKV/rwkv-4-169m-pile",
        batch_size: int = 16,
        torch_dtype=None,
        device_map="auto",
        pooling_strategy="mean",
    ):
        self.model = self.load_model(
            model_name, torch_dtype=torch_dtype, device_map=device_map
        )
        self.tokenizer = self.load_tokenizer(model_name)
        # If the tokenizer does not have a padding token, set it to the eos_token
        if self.tokenizer.pad_token is None:
            if self.tokenizer.eos_token is not None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            else:
                # Add a new pad token if eos_token is not available
                self.tokenizer.add_tokens(["[PAD]"])
                self.tokenizer.pad_token = "[PAD]"
        self.dimension = self.model.config.hidden_size
        self.batch_size = batch_size
        self.pooling_strategy = pooling_strategy
        self.max_length = self.model.config.context_length or 1024
        if (
            self.tokenizer.model_max_length is None
            or self.tokenizer.model_max_length != self.max_length
        ):
            self.tokenizer.model_max_length = self.max_length
    @staticmethod
    def load_model(
        model_name: str, torch_dtype=None, device_map="auto"
    ) -> PreTrainedModel:
        """Load a pre-trained model from Hugging Face's model hub.

        Parameters:
            model_name : str
                The name of the pre-trained model.

        Returns:
            transformers.PreTrainedModel
                The loaded pre-trained model.
        """
        try:
            model = RwkvModel.from_pretrained(
                model_name, torch_dtype=torch_dtype, device_map=device_map
            )
            model.eval()
            return model
        except Exception as e:
            raise ValueError(f"Unable to load model {model_name}. Error: {str(e)}")

    @staticmethod
    def load_tokenizer(model_name: str, use_fast: bool = True) -> PreTrainedTokenizer:
        """Load a tokenizer from Hugging Face's model hub.

        Parameters:
            model_name : str
                The name of the pre-trained model.

        Returns:
            transformers.PreTrainedTokenizer
                The loaded tokenizer.
        """
        try:
            return AutoTokenizer.from_pretrained(model_name, use_fast=use_fast)
        except Exception as e:
            raise ValueError(
                f"Unable to load tokenizer for model {model_name}. Error: {str(e)}"
            )
    @staticmethod
    def normalize_embeddings(embeddings: torch.Tensor):
        """
        Normalize the embedding matrix so that each sentence embedding has unit length.
        """
        return torch.nn.functional.normalize(embeddings, p=2, dim=1)

    @staticmethod
    def _mean_pool(
        last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        # Mask the padding tokens so they don't affect the average.
        last_hidden = last_hidden_states.masked_fill(
            ~attention_mask[..., None].bool(), 0.0
        )
        return last_hidden.sum(dim=1) / attention_mask.sum(dim=1)[..., None]

    @staticmethod
    def _last_token_pool(last_hidden_states: torch.Tensor) -> torch.Tensor:
        return last_hidden_states[:, -1, :]

    @staticmethod
    def _max_pool(
        last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        # Fill padding tokens with a very large negative number
        # so they effectively never get chosen by the max operation.
        last_hidden = last_hidden_states.masked_fill(
            ~attention_mask[..., None].bool(), -1e9
        )
        return last_hidden.max(dim=1).values
    @staticmethod
    def _weighted_mean_pool(
        last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """
        Weighted mean of token embeddings for autoregressive LMs
        Credit: SGPT @ https://github.com/Muennighoff/sgpt
        """
        weights = (
            torch.arange(start=1, end=last_hidden_states.shape[1] + 1)
            .unsqueeze(0)
            .unsqueeze(-1)
            .expand(last_hidden_states.size())
            .float()
            .to(last_hidden_states.device)
        )
        input_mask_expanded = (
            attention_mask.unsqueeze(-1).expand(last_hidden_states.size()).float()
        )
        sum_embeddings = torch.sum(
            last_hidden_states * input_mask_expanded * weights, dim=1
        )
        sum_mask = torch.sum(input_mask_expanded * weights, dim=1)
        return sum_embeddings / sum_mask
    @staticmethod
    def _geometric_mean_pool(
        last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        # Clamp to a small positive value so log() is defined, then average the
        # logs over non-padding positions only (a masked geometric mean).
        log_hidden = torch.log(last_hidden_states.clamp(min=1e-9))
        mask = attention_mask[..., None].float()
        geometric_mean_log = (log_hidden * mask).sum(dim=1) / mask.sum(dim=1)
        return geometric_mean_log.exp()

    @staticmethod
    def _harmonic_mean_pool(
        last_hidden_states: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        # Clamp to a small positive value so the reciprocals are finite, then
        # average the reciprocals over non-padding positions only.
        reciprocal = 1.0 / last_hidden_states.clamp(min=1e-9)
        mask = attention_mask[..., None].float()
        harmonic_mean_reciprocal = (reciprocal * mask).sum(dim=1) / mask.sum(dim=1)
        return 1.0 / harmonic_mean_reciprocal.clamp(min=1e-9)
    # note: calling staticmethod objects stored in a class-level dict
    # requires Python >= 3.10 (where staticmethod became callable)
    POOLING_METHODS = {
        "mean": (_mean_pool, True),
        "max": (_max_pool, True),
        "last_token": (_last_token_pool, False),
        "weighted_mean": (_weighted_mean_pool, True),
        "geometric_mean": (_geometric_mean_pool, True),
        "harmonic_mean": (_harmonic_mean_pool, True),
    }
    def _text_length(self, text: Union[List[int], List[List[int]]]):
        """
        Helper function to get the length of the input text. Text can be either
        a list of ints (a single tokenized text) or a list of lists of ints
        (several tokenized texts).
        Credit: SBERT
        https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/SentenceTransformer.py#L324
        """
        if isinstance(text, dict):  # {key: value} case
            return len(next(iter(text.values())))
        elif not hasattr(text, "__len__"):  # Object has no len() method
            return 1
        elif len(text) == 0 or isinstance(text[0], int):  # Empty string or list of ints
            return len(text)
        else:
            return sum([len(t) for t in text])  # Sum of lengths of individual strings
    def simple_embeddings(
        self, texts: Union[str, list], pooling_strategy=None, NORMALIZE: bool = False
    ) -> np.ndarray:
        """Generate embeddings for a given text or list of texts.

        Parameters:
            texts : str or list
                The input text(s) to generate embeddings for.

        Returns:
            np.ndarray
                The embeddings for each text, with shape (len(texts), dimension).
        """
        if pooling_strategy is None:
            pooling_strategy = self.pooling_strategy
        if pooling_strategy not in self.POOLING_METHODS:
            raise ValueError(f"Unknown pooling strategy: {pooling_strategy}")
        pooling_method, needs_attention_mask = self.POOLING_METHODS[pooling_strategy]
        if isinstance(texts, str):
            texts = [texts]
        # Sort by length so each batch pads to a similar length
        lengths = [self._text_length(text) for text in texts]
        sorted_indices = np.argsort(lengths)
        sorted_texts = [texts[i] for i in sorted_indices]
        embeddings = []
        for i in trange(0, len(sorted_texts), self.batch_size):
            batch = sorted_texts[i : i + self.batch_size]
            inputs = self.tokenizer(
                batch,
                return_tensors="pt",
                truncation=True,
                padding="longest",  # or True
                max_length=self.max_length,
            ).to(self.model.device)
            with torch.no_grad():
                outputs = self.model(
                    **inputs, output_attentions=True, output_hidden_states=True
                )
            if needs_attention_mask:
                emb = pooling_method(
                    outputs.last_hidden_state, inputs["attention_mask"]
                )
            else:
                emb = pooling_method(outputs.last_hidden_state)
            if NORMALIZE:
                emb = self.normalize_embeddings(emb)
            embeddings.extend(emb.cpu().numpy().astype("float32"))
        # Restore the embeddings to their original order
        embeddings = [embeddings[i] for i in np.argsort(sorted_indices)]
        return np.array(embeddings)
    def summary(self):
        """
        Provides a summary of the key parameters of the model.
        """
        print(f"Model Name: {self.model.name_or_path}")
        print(f"Batch Size: {self.batch_size}")
        print(f"Token Length: {self.max_length}")
        print(f"Pooling Strategy: {self.pooling_strategy}")
        print(f"Device: {next(self.model.parameters()).device}")
    def state_aware_embedding(
        self,
        texts: Union[str, list],
        pooling_strategy: str = None,
        max_states: int = 3,
        NORMALIZE: bool = False,
    ) -> np.ndarray:
        """Generate embeddings that also pool the model's recurrent states.

        Pools up to `max_states` RWKV state tensors with MultiLevelPooler and
        averages the result with the pooled last hidden state.

        Parameters:
            texts : str or list
                The input text(s) to generate embeddings for.

        Returns:
            np.ndarray
                The embeddings for each text, with shape (len(texts), dimension).
        """
        pooling_strategy = (
            self.pooling_strategy if pooling_strategy is None else pooling_strategy
        )
        if pooling_strategy not in MultiLevelPooler.get_pooling_modes():
            logging.warning(
                f"Unknown pooling strategy: {pooling_strategy}. Using 'max' pooling instead."
            )
            pooling_strategy = "max"
        if isinstance(texts, str):
            texts = [texts]
        pooler = MultiLevelPooler(
            pooling_strategy=pooling_strategy, max_states=max_states
        )
        # Sort by length so each batch pads to a similar length
        lengths = [self._text_length(text) for text in texts]
        sorted_indices = np.argsort(lengths)
        sorted_texts = [texts[i] for i in sorted_indices]
        embeddings = []
        for i in trange(0, len(sorted_texts), self.batch_size):
            batch = sorted_texts[i : i + self.batch_size]
            inputs = self.tokenizer(
                batch,
                return_tensors="pt",
                truncation=True,
                padding="longest",  # or True
                max_length=self.max_length,
            ).to(self.model.device)
            with torch.no_grad():
                outputs = self.model(
                    **inputs, output_attentions=True, output_hidden_states=True
                )
            # Align the attention mask with the state tensors' trailing dimension
            attention_mask = expand_attention_mask(
                inputs["attention_mask"], outputs.state[0].shape
            )
            emb = pooler.forward(
                outputs.state, outputs.last_hidden_state, attention_mask
            )
            if NORMALIZE:
                emb = self.normalize_embeddings(emb)
            embeddings.extend(emb.cpu().numpy().astype("float32"))
        # Restore the embeddings to their original order
        embeddings = [embeddings[i] for i in np.argsort(sorted_indices)]
        return np.array(embeddings)
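

# Minimal usage sketch (an assumption, not part of the original gist: it
# presumes the "RWKV/rwkv-4-169m-pile" checkpoint can be fetched from the
# Hugging Face Hub, and the sentences are placeholders):
if __name__ == "__main__":
    embedder = RwkvEmbedder(model_name="RWKV/rwkv-4-169m-pile", batch_size=8)
    embedder.summary()
    embs = embedder.simple_embeddings(
        ["Hello, world.", "RWKV is an RNN with transformer-like quality."],
        pooling_strategy="weighted_mean",
        NORMALIZE=True,
    )
    print(embs.shape)  # -> (2, embedder.dimension)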
""" | |
RWKV embeddings with transformers (original starting point) | |
⚠️ embeddings from this are pretty poor because RWKV has special outputs ⚠️ | |
credit: GPTcache | |
https://gptcache.readthedocs.io/en/latest/_modules/gptcache/embedding/rwkv.html#Rwkv | |
""" | |
import numpy as np | |
from gptcache.embedding.base import BaseEmbedding | |
from gptcache.utils import import_huggingface | |
import_huggingface() | |
from transformers import AutoTokenizer, RwkvModel # pylint: disable=C0413 | |
class Rwkv(BaseEmbedding):
    """Generate sentence embedding for given text using RWKV models.

    :param model: model name, defaults to 'RWKV/rwkv-4-169m-pile'. Check
        https://huggingface.co/docs/transformers/model_doc/rwkv for more available models.
    :type model: str

    Example:
        .. code-block:: python

            from gptcache.embedding import Rwkv

            test_sentence = 'Hello, world.'
            encoder = Rwkv(model='RWKV/rwkv-4-169m-pile')
            embed = encoder.to_embeddings(test_sentence)
    """
    def __init__(self, model: str = "RWKV/rwkv-4-169m-pile"):
        self.model = RwkvModel.from_pretrained(model)
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        try:
            self.__dimension = self.model.config.hidden_size
        except Exception:  # pylint: disable=W0703
            from transformers import AutoConfig  # pylint: disable=C0415

            config = AutoConfig.from_pretrained(model)
            self.__dimension = config.hidden_size

    def to_embeddings(self, data, **_):
        """Generate embedding given text input.

        :param data: text in string.
        :type data: str

        :return: a text embedding in shape of (dim,).
        """
        inputs = self.tokenizer(data, return_tensors="pt")
        outputs = self.model(inputs["input_ids"])
        # takes the hidden state of the *first* token only, which is why
        # these embeddings are poor (see header warning)
        emb = outputs.last_hidden_state[0, 0, :].detach().numpy()
        return np.array(emb).astype("float32")

    @property
    def dimension(self):
        """Embedding dimension.

        :return: embedding dimension
        """
        return self.__dimension