@densumesh
Created December 14, 2023 19:26
Rust Splade Embeddings generation
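
For reference, the pooling performed by the function below can be written as follows (notation introduced here, not taken from the gist): with $s_{ij}$ the model's logit for token position $i$ and vocabulary entry $j$, and $m_i \in \{0, 1\}$ the attention mask,

$$w_j = \max_i \, m_i \cdot \log\bigl(1 + \mathrm{ReLU}(s_{ij})\bigr).$$

This corresponds to the relu, add-one, log, mask, and max steps in the code.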
use candle_core::{Tensor, D};
use tokenizers::Tokenizer;

// `Model` (an enum with `Doc` and `Query` variants wrapping the SPLADE
// document/query encoders) and `ServiceError` are application types defined
// elsewhere in the service.
pub fn get_splade_vector(
    input: String,
    model: Model,
    tokenizer: &Tokenizer,
) -> Result<Vec<f32>, ServiceError> {
    // Tokenize the input without adding special tokens and build a
    // (1, seq_len) tensor of token ids on the CPU.
    let tokenized_inputs = tokenizer.encode(input, false).unwrap();
    let tokens = tokenized_inputs.get_ids().to_vec();
    let token_ids = Tensor::new(tokens.as_slice(), &candle_core::Device::Cpu)
        .map_err(|e| ServiceError::BadRequest(format!("Could not create tensor: {}", e)))?;
    let token_ids = token_ids.unsqueeze(0).unwrap();

    // Single-segment input, so the token type ids are all zeros.
    let token_type_ids = token_ids
        .zeros_like()
        .map_err(|e| ServiceError::BadRequest(format!("Could not create tensor: {}", e)))?;

    // Mask out padding positions (token id 0). Deriving the mask from the
    // all-zero `token_type_ids` would zero out every position, so it is
    // derived from the token ids instead.
    let attention_mask = token_ids
        .ne(0u32)
        .map_err(|e| ServiceError::BadRequest(format!("Could not run ne: {}", e)))?;

    log::info!("token_ids: {:?}", token_ids);
    log::info!("token_type_ids: {:?}", token_type_ids);

    // Run the appropriate SPLADE encoder to get per-token logits.
    let embeddings = match model {
        Model::Doc(model) => model
            .forward(&token_ids, &attention_mask)
            .map_err(|e| ServiceError::BadRequest(format!("Could not run model: {}", e)))?,
        Model::Query(model) => model
            .forward(&token_ids, &token_type_ids)
            .map_err(|e| ServiceError::BadRequest(format!("Could not run model: {}", e)))?,
    };

    // SPLADE activation: log(1 + ReLU(logits)).
    let logits = embeddings.to_dtype(candle_core::DType::F32).unwrap();
    let relu = logits
        .relu()
        .map_err(|e| ServiceError::BadRequest(format!("Could not run relu on logits: {}", e)))?;
    let relu_log = relu
        .add(
            &Tensor::ones(
                relu.shape(),
                candle_core::DType::F32,
                &candle_core::Device::Cpu,
            )
            .unwrap(),
        )
        .unwrap()
        .log()
        .map_err(|e| ServiceError::BadRequest(format!("Could not run log on logits: {}", e)))?;

    // Zero out masked positions, then max-pool over the sequence dimension to
    // get one weight per output dimension.
    let weighted_log = relu_log
        .broadcast_mul(
            &attention_mask
                .unsqueeze(D::Minus1)
                .unwrap()
                .to_dtype(candle_core::DType::F32)
                .unwrap(),
        )
        .map_err(|e| ServiceError::BadRequest(format!("Could not run mul: {}", e)))?;
    let max_val = weighted_log
        .max(1)
        .map_err(|e| ServiceError::BadRequest(format!("Could not run max: {}", e)))?;
    log::info!("max_val: {:?}", max_val);

    // Drop the batch dimension and return the sparse vector as a Vec<f32>.
    max_val
        .squeeze(0)
        .unwrap()
        .to_vec1::<f32>()
        .map_err(|e| ServiceError::BadRequest(format!("Could not run to_vec1: {}", e)))
}
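
A minimal usage sketch, assuming the surrounding service provides the `Model` and `ServiceError` types from the gist. `load_doc_model` is a hypothetical placeholder for however the service constructs its `Model::Doc` variant; `Tokenizer::from_file` is a real `tokenizers` API.

```rust
use tokenizers::Tokenizer;

fn main() {
    // Load the tokenizer.json shipped with the SPLADE checkpoint.
    let tokenizer = Tokenizer::from_file("splade/tokenizer.json").unwrap();

    // Hypothetical helper standing in for however the service builds its
    // `Model::Doc` variant from the SPLADE weights.
    let model = load_doc_model("splade/model.safetensors");

    // The returned vector has one weight per output dimension (the vocabulary
    // for an MLM-style SPLADE head); most entries are expected to be zero.
    match get_splade_vector("rust splade embeddings".to_string(), model, &tokenizer) {
        Ok(sparse) => {
            let non_zero = sparse.iter().filter(|w| **w > 0.0).count();
            println!("dims: {}, non-zero: {}", sparse.len(), non_zero);
        }
        Err(_) => eprintln!("embedding generation failed"),
    }
}
```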