Skip to content

Instantly share code, notes, and snippets.

@guillaume-be
Last active December 15, 2020 19:38
Show Gist options
  • Save guillaume-be/76e0d287dc125592e8a2088cc48f7066 to your computer and use it in GitHub Desktop.
Save guillaume-be/76e0d287dc125592e8a2088cc48f7066 to your computer and use it in GitHub Desktop.
albert_loading_inference_loop.rs
// Copyright 2018 Google AI and Google Brain team.
// Copyright 2020-present, the HuggingFace Inc. team.
// Copyright 2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
extern crate anyhow;
use rust_bert::albert::{
AlbertConfig, AlbertConfigResources, AlbertForMaskedLM, AlbertModelResources,
AlbertVocabResources,
};
use rust_bert::resources::{RemoteResource, Resource};
use rust_bert::Config;
use rust_tokenizers::tokenizer::{AlbertTokenizer, Tokenizer, TruncationStrategy};
use rust_tokenizers::vocab::Vocab;
use tch::{nn, no_grad, Device, Tensor, Kind};
fn main() -> anyhow::Result<()> {
// Resources paths
let config_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertConfigResources::ALBERT_BASE_V2,
));
let vocab_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertVocabResources::ALBERT_BASE_V2,
));
let weights_resource = Resource::Remote(RemoteResource::from_pretrained(
AlbertModelResources::ALBERT_BASE_V2,
));
let config_path = config_resource.get_local_path()?;
let vocab_path = vocab_resource.get_local_path()?;
let weights_path = weights_resource.get_local_path()?;
let mut i = 0;
loop {
// Set-up masked LM model
let device = Device::cuda_if_available();
let mut vs = nn::VarStore::new(device);
let config = AlbertConfig::from_file(&config_path);
let albert_model = AlbertForMaskedLM::new(&vs.root(), &config);
vs.load(&weights_path)?;
let input_tensor = Tensor::randint(10, &[2, 10], (Kind::Int64, device));
// // Forward pass
let model_output =
no_grad(|| albert_model.forward_t(Some(input_tensor), None, None, None, None, false));
println!(
"{:?}",
model_output.prediction_scores.double_value(&[0, 0, 0])
);
i += 1;
println!("iteration {}", i);
}
Ok(())
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment