SPLADE on Modal (splade.py, forked from rileytomasek/splade.py)

A Modal web endpoint that serves SPLADE sparse vectors (vocabulary indices plus weights) for a piece of text, for use as the sparse half of hybrid search.
from fastapi.responses import JSONResponse
from modal import Image, NetworkFileSystem, Secret, Stub, web_endpoint
from pydantic import BaseModel

# This is copied from: https://github.com/pinecone-io/examples/blob/2f51ddfd12a08f2963cc2849661fab51afdeedc6/learn/search/semantic-search/sparse/splade/splade-vector-generation.ipynb#L10
# Which is recommended here: https://docs.pinecone.io/docs/hybrid-search
stub = Stub("splade")
image = Image.debian_slim().pip_install("torch", "transformers")
volume = NetworkFileSystem.persisted("splade-model-cache-vol-gcp", cloud="gcp")
CACHE_DIR = "/cache"
class Body(BaseModel):
    text: str
@stub.cls(
    image=image,
    cloud="gcp",
    cpu=2,
    memory=2048,
    keep_warm=40,
    container_idle_timeout=120,
    network_file_systems={CACHE_DIR: volume},
    secret=Secret.from_dict({"TORCH_HOME": CACHE_DIR, "TRANSFORMERS_CACHE": CACHE_DIR}),
)
class SPLADE:
    def __enter__(self):
        import torch
        from transformers import AutoModelForMaskedLM, AutoTokenizer

        model = "naver/splade-cocondenser-ensembledistil"
        # check device
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer = AutoTokenizer.from_pretrained(model)
        self.model = AutoModelForMaskedLM.from_pretrained(model)
        # move to gpu if available
        self.model.to(self.device)
@web_endpoint(method="POST")
def vector(self, body: Body):
import torch
from transformers.tokenization_utils_base import TruncationStrategy
text = body.text
max_length = self.tokenizer.model_max_length
inputs = self.tokenizer(
text,
truncation=TruncationStrategy.LONGEST_FIRST,
max_length=max_length,
return_tensors="pt",
).to(self.device)
with torch.no_grad():
logits = self.model(**inputs).logits
inter = torch.log1p(torch.relu(logits[0]))
token_max = torch.max(inter, dim=0) # sum over input tokens
nz_tokens = torch.where(token_max.values > 0)[0]
nz_weights = token_max.values[nz_tokens]
order = torch.sort(nz_weights, descending=True)
nz_weights = nz_weights[order[1]]
nz_tokens = nz_tokens[order[1]]
response = {
"indices": nz_tokens.cpu().numpy().tolist(),
"values": nz_weights.cpu().numpy().tolist(),
}
return JSONResponse(content=response)
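
For reference, here is a minimal client sketch. The URL below is a placeholder: Modal prints the real web endpoint URL when you run modal deploy splade.py. The decoding step assumes the client uses the same naver/splade-cocondenser-ensembledistil tokenizer that the server loads, so the returned vocabulary ids can be mapped back to readable tokens.

# Hypothetical client for the endpoint above; ENDPOINT_URL is a placeholder.
import requests
from transformers import AutoTokenizer

ENDPOINT_URL = "https://example--splade-vector.modal.run"

resp = requests.post(ENDPOINT_URL, json={"text": "programming in rust"})
resp.raise_for_status()
sparse = resp.json()  # {"indices": [...], "values": [...]}

# Map vocabulary ids back to tokens to inspect the expansion
# (assumes the same tokenizer as the server).
tokenizer = AutoTokenizer.from_pretrained("naver/splade-cocondenser-ensembledistil")
for idx, weight in zip(sparse["indices"], sparse["values"]):
    print(f"{tokenizer.convert_ids_to_tokens(idx)}\t{weight:.3f}")

The indices/values pair returned by the endpoint matches the shape Pinecone's hybrid search docs use for sparse vectors, so the response should be usable as the sparse component of an upsert or query.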