Bert in Rust
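// BERT embedding layer implemented in Rust with the `tch` crate (Rust bindings
// to libtorch). The code below mirrors the embeddings block of BERT:
// word + position + token-type embeddings, followed by LayerNorm and dropout.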
use std::borrow::Borrow;

use tch::nn::{self, ModuleT, VarStore};
use tch::{Device, Kind, Tensor};
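/// Thin dropout wrapper implementing `ModuleT`, so the `train` flag can be
/// threaded through `forward_t` / `apply_t`.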
#[derive(Debug)]
pub struct Dropout {
    dropout_prob: f64,
}

impl Dropout {
    pub fn new(p: f64) -> Dropout {
        Dropout { dropout_prob: p }
    }
}

impl ModuleT for Dropout {
    fn forward_t(&self, input: &Tensor, train: bool) -> Tensor {
        input.dropout(self.dropout_prob, train)
    }
}
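/// BERT hyperparameters; field names follow the usual BERT configuration keys.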
pub struct Config {
    pub vocab_size: i64,
    pub hidden_size: i64,
    pub num_hidden_layers: i64,
    pub num_attention_heads: i64,
    pub intermediate_size: i64,
    pub hidden_act: String,
    pub hidden_dropout_prob: f64,
    pub attention_probs_dropout_prob: f64,
    pub max_position_embeddings: i64,
    pub type_vocab_size: i64,
    pub initializer_range: f64,
    pub layer_norm_eps: f64,
}
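/// Embedding block: word, position, and token-type embeddings summed together,
/// then layer-normalized and passed through dropout.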
pub struct BertEmbeddings {
    word_embeddings: nn::Embedding,
    position_embeddings: nn::Embedding,
    token_type_embeddings: nn::Embedding,
    layer_norm: nn::LayerNorm,
    dropout: Dropout,
}
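// The defaults below correspond to the bert-base-uncased configuration.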
impl Default for Config {
    fn default() -> Self {
        Config {
            vocab_size: 30522,
            hidden_size: 768,
            num_hidden_layers: 12,
            num_attention_heads: 12,
            intermediate_size: 3072,
            hidden_act: "gelu".to_string(),
            hidden_dropout_prob: 0.1,
            attention_probs_dropout_prob: 0.1,
            max_position_embeddings: 512,
            type_vocab_size: 2,
            initializer_range: 0.02,
            layer_norm_eps: 1e-12,
        }
    }
}
impl Config {
    /// Builds a Config from explicit hyperparameters.
    pub fn new(
        vocab_size: i64,
        hidden_size: i64,
        num_hidden_layers: i64,
        num_attention_heads: i64,
        intermediate_size: i64,
        hidden_act: String,
        hidden_dropout_prob: f64,
        attention_probs_dropout_prob: f64,
        max_position_embeddings: i64,
        type_vocab_size: i64,
        initializer_range: f64,
        layer_norm_eps: f64,
    ) -> Self {
        Config {
            vocab_size,
            hidden_size,
            num_hidden_layers,
            num_attention_heads,
            intermediate_size,
            hidden_act,
            hidden_dropout_prob,
            attention_probs_dropout_prob,
            max_position_embeddings,
            type_vocab_size,
            initializer_range,
            layer_norm_eps,
        }
    }
}
impl BertEmbeddings {
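    /// Creates the embedding tables and LayerNorm under the given variable-store path.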
    pub fn new<'p, P>(p: P, config: &Config) -> BertEmbeddings
    where
        P: Borrow<nn::Path<'p>>,
    {
        let p = p.borrow();

        let word_embeddings = nn::embedding(
            p / "word_embeddings",
            config.vocab_size,
            config.hidden_size,
            Default::default(),
        );
        let position_embeddings = nn::embedding(
            p / "position_embeddings",
            config.max_position_embeddings,
            config.hidden_size,
            Default::default(),
        );
        let token_type_embeddings = nn::embedding(
            p / "token_type_embeddings",
            config.type_vocab_size,
            config.hidden_size,
            Default::default(),
        );

        let layer_norm_config = nn::LayerNormConfig {
            eps: config.layer_norm_eps,
            ..Default::default()
        };
        let layer_norm =
            nn::layer_norm(p / "LayerNorm", vec![config.hidden_size], layer_norm_config);
        let dropout = Dropout::new(config.hidden_dropout_prob);

        Self {
            word_embeddings,
            position_embeddings,
            token_type_embeddings,
            layer_norm,
            dropout,
        }
    }
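    /// Forward pass. `position_ids` and `token_type_ids` are optional; when
    /// absent they default to 0..seq_length and all zeros respectively.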
    pub fn forward_t(
        &self,
        input_ids: &Tensor,
        token_type_ids: Option<&Tensor>,
        position_ids: Option<&Tensor>,
        train: bool,
    ) -> Result<Tensor, &'static str> {
        let input_shape = input_ids.size();
        let seq_length = input_shape[1];
        let device = input_ids.device();

        let input_ids = input_ids.view((-1, seq_length));
        // Default position ids are 0..seq_length, broadcast over the batch.
        let position_ids = match position_ids {
            Some(position_ids) => position_ids.view((-1, seq_length)),
            None => Tensor::arange(seq_length, (Kind::Int64, device))
                .unsqueeze(0)
                .expand(&input_shape, true),
        };
        // Default token type ids are all zeros (a single segment).
        let token_type_ids = match token_type_ids {
            Some(token_type_ids) => token_type_ids.view((-1, seq_length)),
            None => Tensor::zeros(&input_shape, (Kind::Int64, device)),
        };

        let word_embeddings = input_ids.apply(&self.word_embeddings);
        let position_embeddings = position_ids.apply(&self.position_embeddings);
        let token_type_embeddings = token_type_ids.apply(&self.token_type_embeddings);

        let embeddings = word_embeddings + position_embeddings + token_type_embeddings;
        Ok(embeddings
            .apply(&self.layer_norm)
            .apply_t(&self.dropout, train))
    }
}
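// Smoke test: runs a batch of 10 token ids through the embeddings and checks
// that the output shape is [1, 10, hidden_size].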
fn test_bert_embeddings() {
    let config = Config::new(
        30522,
        768,
        12,
        12,
        3072,
        "gelu".to_string(),
        0.1,
        0.1,
        512,
        2,
        0.02,
        1e-12,
    );
    let vs = VarStore::new(Device::Cpu);
    let root = vs.root();
    let embeddings = BertEmbeddings::new(&root, &config);

    // Embedding indices are built as i64 tensors, matching the kind used for
    // the default position and token-type ids.
    let input_ids = Tensor::of_slice(&[0_i64, 1, 2, 3, 4, 5, 6, 7, 8, 9]).unsqueeze(0);
    let token_type_ids = Tensor::of_slice(&[0_i64, 0, 0, 0, 0, 0, 0, 0, 0, 0]).unsqueeze(0);
    let position_ids = Tensor::of_slice(&[0_i64, 1, 2, 3, 4, 0, 1, 2, 3, 4]).unsqueeze(0);

    let output = embeddings.forward_t(
        &input_ids,
        Some(&token_type_ids),
        Some(&position_ids),
        false,
    );

    let expected_shape = vec![1, 10, 768];
    assert_eq!(output.unwrap().size(), expected_shape.as_slice());
}
fn main() {
    test_bert_embeddings();
}