Sentence Transformer + Setfit classification head for inference without installing setfit
# Training utilities: patched SetFitModel with layer dropping, class weights, and extra save artifacts.
from datasets import load_dataset, Dataset, DatasetDict
from sentence_transformers.losses import CosineSimilarityLoss
from sentence_transformers import SentenceTransformer
from setfit import SetFitModel, SetFitTrainer, sample_dataset
from sklearn.model_selection import train_test_split
import pandas as pd
import numpy as np
import json
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
from pathlib import Path
import time
from tqdm import trange
import torch
from sklearn.metrics import classification_report
class SetFitModelFixed(SetFitModel):
    @classmethod
    def from_pretrained(cls, *args, include_layers=None, **kwargs):
        # Allow dropping layers similar to miniLM-v2-L3 and L6
        obj = super(SetFitModelFixed, cls).from_pretrained(*args, **kwargs)
        if include_layers is not None:
            auto_model = obj.model_body._modules['0'].auto_model
            auto_model.config.num_hidden_layers = len(include_layers)
            auto_model.encoder.layer = torch.nn.ModuleList([
                l
                for i, l in enumerate(auto_model.encoder.layer)
                if i in include_layers
            ])
            obj.model_body._modules['0'].auto_model = auto_model
        return obj
    def fit(
        self,
        x_train: List[str],
        y_train: Union[List[int], List[List[int]]],
        num_epochs: int,
        batch_size: Optional[int] = None,
        learning_rate: Optional[float] = None,
        body_learning_rate: Optional[float] = None,
        l2_weight: Optional[float] = None,
        max_length: Optional[int] = None,
        show_progress_bar: Optional[bool] = None,
        class_weights: Optional[List[float]] = None,
    ) -> None:
        if self.has_differentiable_head:  # train with PyTorch
            device = self.model_body.device
            self.model_body.train()
            self.model_head.train()

            dataloader = self._prepare_dataloader(x_train, y_train, batch_size, max_length)
            criterion = self.model_head.get_loss_fn()
            if hasattr(self, "class_weights"):
                print(f"Using {self.class_weights=}")
                # This hack allows us to bypass passing class weight via trainer which is TODO
                class_weights = self.class_weights
            if class_weights is not None:
                print(f"Using {class_weights=}")
                criterion.weight = torch.Tensor(class_weights).to(self.model_head.device)
            optimizer = self._prepare_optimizer(learning_rate, body_learning_rate, l2_weight)
            scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)
            for epoch_idx in trange(num_epochs, desc="Epoch", disable=not show_progress_bar):
                for batch in dataloader:
                    features, labels = batch
                    optimizer.zero_grad()

                    # to model's device
                    features = {k: v.to(device) for k, v in features.items()}
                    labels = labels.to(device)

                    outputs = self.model_body(features)
                    if self.normalize_embeddings:
                        embeddings = outputs["sentence_embedding"]
                        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
                        outputs["sentence_embedding"] = embeddings
                    outputs = self.model_head(outputs)
                    logits = outputs["logits"]

                    loss = criterion(logits, labels)
                    loss.backward()
                    optimizer.step()
                scheduler.step()
        else:  # train with sklearn
            embeddings = self.model_body.encode(x_train, normalize_embeddings=self.normalize_embeddings)
            self.model_head.fit(embeddings, y_train)
    def save_pretrained(self, model_save_path, class_names, non_askic_eligible_idx):
        super().save_pretrained(model_save_path)
        print("Saving extra items like model_head, model_head.config, sentence_transformer_classifier")
        torch.save(self.model_head, Path(model_save_path) / "model_head.pt")
        torch.save(self.model_head.state_dict(), Path(model_save_path) / "model_head.state_dict.pt")
        with open(Path(model_save_path) / "model_head.config.json", "w+") as fp:
            json.dump(self.model_head.get_config_dict(), fp)
        with open(Path(model_save_path) / "sentence_transformer_classifier.config.json", "w+") as fp:
            json.dump(dict(
                class_names=[c.item() for c in class_names],
                marginalize_negative=list(non_askic_eligible_idx),
            ), fp)
# class_weights = compute_class_weight("balanced", classes=class_names, y=df_train[label_col])
include_layers = {0, 5, 11}  # default=None keeps all layers; here keep only layers 0, 5, and 11
# model_type, num_classes, normalize_embeddings, multi_target_strategy, and class_weights
# are expected to be defined by the surrounding training script.
model = SetFitModelFixed.from_pretrained(
    model_type,
    use_differentiable_head=True,
    head_params={"out_features": num_classes},
    normalize_embeddings=normalize_embeddings,
    multi_target_strategy=multi_target_strategy,  # e.g. "one-vs-rest"
    model_kwargs={"max_seq_length": 128},
    include_layers=include_layers,
)
model.class_weights = class_weights
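# A minimal, hypothetical training sketch (not part of the original gist). It assumes the
# pre-1.0 setfit trainer API imported above, a `train_dataset` with "text"/"label" columns,
# and that `model_save_path`, `class_names`, and `non_askic_eligible_idx` are defined elsewhere.
trainer = SetFitTrainer(
    model=model,
    train_dataset=train_dataset,
    loss_class=CosineSimilarityLoss,
    batch_size=16,
    num_iterations=20,  # controls how many contrastive pairs are generated per example
    num_epochs=1,
)
trainer.freeze()  # phase 1: contrastive fine-tuning of the sentence-transformer body
trainer.train()
trainer.unfreeze(keep_body_frozen=True)  # phase 2: fit only the differentiable head
trainer.train(num_epochs=10, batch_size=16, learning_rate=1e-3)
model.save_pretrained(model_save_path, class_names=class_names, non_askic_eligible_idx=non_askic_eligible_idx)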
def merge_models(base, other_models):
    """In-place element-wise average of `base`'s weights with those of `other_models` (a simple model "soup")."""
    for k, v in base.state_dict().items():
        print(k, base.state_dict()[k].shape)
        base.state_dict()[k] += sum([other_model.state_dict()[k] for other_model in other_models])
        base.state_dict()[k] /= len(other_models) + 1
        print(base.state_dict()[k].shape)


"""Usage:
merge_models(model.model_body, [model_A.model_body, model_B.model_body])
merge_models(model.model_head, [model_A.model_head, model_B.model_head])
"""
# ---------------------------------------------------------------------------
# Standalone inference module: re-implements the classification head so a saved
# model can be loaded and used without installing setfit.
# ---------------------------------------------------------------------------
import logging
from functools import lru_cache
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import torch
from huggingface_hub import PyTorchModelHubMixin
from sentence_transformers import SentenceTransformer, models
from torch import nn

logger = logging.getLogger(__name__)
class ClassificationHead(models.Dense):
    """
    A classification head that supports multi-class classification for end-to-end training.
    Binary classification is treated as 2-class classification.

    To be compatible with Sentence Transformers, we inherit `Dense` from:
    https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/models/Dense.py

    Args:
        in_features (`int`, *optional*):
            The embedding dimension from the output of the sentence_transformer body. If `None`, defaults to `LazyLinear`.
        out_features (`int`, defaults to `2`):
            The number of targets. If `out_features` is set to 1 for binary classification, it is changed to 2 for 2-class classification.
        temperature (`float`, defaults to `1.0`):
            A logits scaling factor. Higher values make the model less confident and lower values make
            it more confident.
        eps (`float`, defaults to `1e-5`):
            A value for numerical stability when scaling logits.
        bias (`bool`, *optional*, defaults to `True`):
            Whether to add bias to the head.
        device (`torch.device` or `str`, *optional*):
            The device the model will be sent to. If `None`, will check whether GPU is available.
        multitarget (`bool`, defaults to `False`):
            Enable multi-target classification by making `out_features` binary predictions instead
            of a single multinomial prediction.
    """
    def __init__(
        self,
        in_features: Optional[int] = None,
        out_features: int = 2,
        temperature: float = 1.0,
        eps: float = 1e-5,
        bias: bool = True,
        device: Optional[Union[torch.device, str]] = None,
        multitarget: bool = False,
    ) -> None:
        super(models.Dense, self).__init__()  # init on models.Dense's parent: nn.Module
        if out_features == 1:
            logger.warning(
                "Change `out_features` from 1 to 2 since we use `CrossEntropyLoss` for binary classification."
            )
            out_features = 2
        if in_features is not None:
            self.linear = nn.Linear(in_features, out_features, bias=bias)
        else:
            self.linear = nn.LazyLinear(out_features, bias=bias)

        self.in_features = in_features
        self.out_features = out_features
        self.temperature = temperature
        self.eps = eps
        self.bias = bias
        # Parenthesized so an explicitly passed `device` always wins over the CUDA check.
        self._device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.multitarget = multitarget

        self.to(self._device)
        self.apply(self._init_weight)
    def forward(
        self,
        features: Union[Dict[str, torch.Tensor], torch.Tensor],
        temperature: Optional[float] = None,
    ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor]]:
        """
        ClassificationHead can accept embeddings in:
        1. Output format (`dict`) from Sentence-Transformers.
        2. Pure `torch.Tensor`.

        Args:
            features (`Dict[str, torch.Tensor]` or `torch.Tensor`):
                The embeddings from the encoder. If using `dict` format,
                make sure the embeddings are stored under the key 'sentence_embedding';
                the outputs will then be stored under the keys 'logits' and 'probs'.
            temperature (`float`, *optional*):
                A logits scaling factor. Higher values make the model less
                confident and lower values make it more confident.
                Will override the temperature given during initialization.

        Returns:
            [`Dict[str, torch.Tensor]` or `Tuple[torch.Tensor]`]
        """
        temperature = temperature or self.temperature
        is_features_dict = False  # whether `features` is a dict or not
        if isinstance(features, dict):
            assert "sentence_embedding" in features
            is_features_dict = True
        x = features["sentence_embedding"] if is_features_dict else features

        logits = self.linear(x)
        logits = logits / (temperature + self.eps)
        if self.multitarget:  # multiple targets per item
            probs = torch.sigmoid(logits)
        else:  # one target per item
            probs = nn.functional.softmax(logits, dim=-1)
        if is_features_dict:
            features.update(
                {
                    "logits": logits,
                    "probs": probs,
                }
            )
            return features
        return logits, probs
    def predict_proba(self, x_test: torch.Tensor) -> torch.Tensor:
        self.eval()
        return self(x_test)[1]

    def predict(self, x_test: torch.Tensor) -> torch.Tensor:
        probs = self.predict_proba(x_test)
        if self.multitarget:
            return torch.where(probs >= 0.5, 1, 0)
        return torch.argmax(probs, dim=-1)

    def get_loss_fn(self):
        if self.multitarget:  # if sigmoid output
            return torch.nn.BCEWithLogitsLoss()
        return torch.nn.CrossEntropyLoss()

    def get_config_dict(self) -> Dict[str, Optional[Union[int, float, bool]]]:
        return {
            "in_features": self.in_features,
            "out_features": self.out_features,
            "temperature": self.temperature,
            "bias": self.bias,
            "device": self.device.type,  # store the string of the device, instead of `torch.device`
        }

    @property
    def device(self) -> torch.device:
        """
        `torch.device`: The device on which the model is placed.

        Reference from: https://github.com/UKPLab/sentence-transformers/blob/master/sentence_transformers/SentenceTransformer.py#L869
        """
        return next(self.parameters()).device

    @staticmethod
    def _init_weight(module):
        if isinstance(module, nn.Linear):
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                torch.nn.init.constant_(module.bias, 1e-2)

    def __repr__(self):
        return f"{type(self).__name__}({self.get_config_dict()})"
class SentenceTransformerClassifier(nn.Module, PyTorchModelHubMixin):
    def __init__(
        self,
        model_body: Optional[SentenceTransformer] = None,
        model_head: Optional[ClassificationHead] = None,
        multi_target_strategy: Optional[str] = None,
        l2_weight: float = 1e-2,
        normalize_embeddings: bool = False,
        class_names: Optional[List[str]] = None,
        marginalize_negative: Optional[List[int]] = None,
    ):
        super().__init__()
        self.model_body = model_body
        self.model_head = model_head
        self.multi_target_strategy = multi_target_strategy
        self.l2_weight = l2_weight
        self.normalize_embeddings = normalize_embeddings
        self.class_names = class_names
        self.marginalize_negative = marginalize_negative

    def forward(self, inputs):
        embeddings = self.model_body.encode(
            inputs,
            normalize_embeddings=self.normalize_embeddings,
            convert_to_tensor=True,
        )
        outputs = self.model_head(embeddings)
        return outputs
    def _output_type_conversion(
        self, outputs: torch.Tensor, as_numpy: bool = False
    ) -> Union[torch.Tensor, np.ndarray]:
        """Return `outputs` as a NumPy array when `as_numpy=True`, otherwise as a `torch.Tensor`."""
        if as_numpy:
            outputs = outputs.detach().cpu().numpy()
        return outputs

    def _marginalize_prediction(self, preds: torch.Tensor) -> torch.Tensor:
        if not self.marginalize_negative:
            raise RuntimeError(
                f"Please initialize marginalize_negative. Found: {self.marginalize_negative}."
            )
        # Probability of the "positive" outcome = 1 - total probability of the negative classes.
        return 1 - preds[:, self.marginalize_negative].sum(axis=-1)

    def predict(
        self,
        x_test: List[str],
        as_numpy: bool = False,
        marginalize: bool = False,
        threshold: float = 0.5,
    ) -> Union[torch.Tensor, np.ndarray]:
        with torch.no_grad():
            embeddings = self.model_body.encode(
                x_test,
                normalize_embeddings=self.normalize_embeddings,
                convert_to_tensor=True,
            )
            if marginalize:
                outputs = self.model_head.predict_proba(embeddings)
                outputs = self._marginalize_prediction(outputs) > threshold
            else:
                outputs = self.model_head.predict(embeddings)
        return self._output_type_conversion(outputs, as_numpy=as_numpy)

    def predict_proba(
        self, x_test: List[str], as_numpy: bool = False, marginalize: bool = False
    ) -> Union[torch.Tensor, np.ndarray]:
        with torch.no_grad():
            embeddings = self.model_body.encode(
                x_test,
                normalize_embeddings=self.normalize_embeddings,
                convert_to_tensor=True,
            )
            outputs = self.model_head.predict_proba(embeddings)
            if marginalize:
                outputs = self._marginalize_prediction(outputs)
        return self._output_type_conversion(outputs, as_numpy=as_numpy)
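# A hypothetical loader sketch (placeholders: `model_path`, `head`) showing how the
# sentence_transformer_classifier.config.json written by SetFitModelFixed.save_pretrained
# maps onto this class's constructor arguments.
"""Example:
import json
with open(Path(model_path) / "sentence_transformer_classifier.config.json") as fp:
    cfg = json.load(fp)
clf = SentenceTransformerClassifier(
    model_body=SentenceTransformer(model_path),
    model_head=head,  # a ClassificationHead with the saved state dict loaded
    class_names=cfg["class_names"],
    marginalize_negative=cfg["marginalize_negative"],
)
"""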
class_names = [
"Brand Comparison",
"Comparative or Substitutes",
"Complements and Pairings",
"Conceptual Attributes",
"Contextual Ideas",
"Health",
"Instacart",
"Occassions",
"Product",
"Products with attributes",
"Shopping Lists",
]
non_askic_eligible_idx = [5, 6, 8, 9]
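# With the labels above, indices 5, 6, 8, and 9 ("Health", "Instacart", "Product",
# "Products with attributes") are the "negative" classes for marginalized inference:
# _marginalize_prediction returns 1 - (p5 + p6 + p8 + p9), the probability mass left on
# the other classes. E.g. for probs [.05,.05,.05,.05,.05,.20,.10,.05,.20,.10,.10] it
# returns 1 - (0.20 + 0.10 + 0.20 + 0.10) = 0.40, which falls below a 0.5 threshold.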
def load_classifier(model_path):
    head_config = {
        "in_features": 384,
        "out_features": 11,
        "temperature": 1.0,
        "bias": True,
        "device": "cpu",
    }
    model_body = SentenceTransformer(model_path).to("cpu")
    model_head = ClassificationHead(**head_config)
    model_head_state_dict = torch.load(
        Path(model_path) / "model_head.state_dict.pt", map_location=torch.device("cpu")
    )
    model_head.load_state_dict(model_head_state_dict)
    st_model = SentenceTransformerClassifier(
        model_body=model_body,
        model_head=model_head,
        class_names=class_names,
        marginalize_negative=non_askic_eligible_idx,  # required for marginalized (yes/no) predictions
    )
    st_model = st_model.eval()
    return st_model
class Classifier(object):
    def __init__(self, model_path, threshold) -> None:
        self.model_path = model_path
        self.threshold = threshold
        self.model = load_classifier(model_path)

    @lru_cache(maxsize=100_000)  # memoize per-query predictions
    def __call__(self, query):
        prediction = self.model.predict_proba([query], marginalize=True)
        output = dict(
            is_true=bool(prediction[0] > self.threshold),
            true_prob=float(prediction[0]),
        )
        return output
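# A hypothetical end-to-end usage sketch; the model path, threshold, and query are placeholders.
"""Usage:
classifier = Classifier(model_path="path/to/saved/model", threshold=0.5)
classifier("healthy snacks for a road trip")
# -> {"is_true": True or False, "true_prob": <marginalized probability>}
"""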