Skip to content

Instantly share code, notes, and snippets.

@ywegel
Last active March 31, 2024 22:06
Show Gist options
  • Save ywegel/b475f3aefe5a3c4f1e75aa1666cf9003 to your computer and use it in GitHub Desktop.
Save ywegel/b475f3aefe5a3c4f1e75aa1666cf9003 to your computer and use it in GitHub Desktop.
oauth_token_manager.rs
// DISCLAIMER: This code is provided as is without any warranties. It's primarily for learning and informational purposes.
// The author does not guarantee that the code is fit for any particular purpose and will not be responsible for any damage or loss resulting from the use of this code.
// The author does not assume liability for the code.
use anyhow::Result;
use std::time::{Duration, SystemTime};
use jsonwebtoken::{encode, EncodingKey, Header};
use reqwest::Client;
use serde_json::json;
/// Caches an OAuth access token and fetches a new one once it has expired.
///
/// Construct it with [`TokenManager::new`] and obtain tokens through
/// [`TokenManager::get_token`].
pub struct TokenManager {
    /// Most recently fetched access token; `None` until the first refresh.
    token: Option<String>,
    /// Point in time at which `token` stops being valid; `None` until the
    /// first refresh.
    expires_at: Option<SystemTime>,
    /// Service-account credentials used to sign the JWT that is exchanged
    /// for an access token.
    service_account_key: ServiceAccountKey,
}
/// The fields of a service-account key file (JSON) that this module reads.
#[derive(serde::Deserialize)]
struct ServiceAccountKey {
    /// Private key in PEM form; parsed as an RSA key to sign the JWT.
    private_key: String,
    /// Service-account e-mail address; sent as the JWT `iss` claim.
    client_email: String,
    /// Identifier of `private_key`; sent as the JWT header `kid`.
    private_key_id: String,
}
impl TokenManager {
pub fn new(google_credentials_location: &str) -> Result<Self> {
let service_account_key_json = std::fs::read_to_string(google_credentials_location)?;
let service_account_key: ServiceAccountKey = serde_json::from_str(&service_account_key_json)?;
Ok(TokenManager {
token: None,
expires_at: None,
service_account_key,
})
}
pub async fn get_token(&mut self) -> Result<String> {
if let Some(token) = &self.token {
if !self.is_token_expired() {
return Ok(token.clone());
}
}
self.refresh_token().await
}
fn is_token_expired(&self) -> bool {
if let Some(expires_at) = self.expires_at {
expires_at <= SystemTime::now()
} else {
true
}
}
async fn refresh_token(&mut self) -> Result<String> {
let signed_jwt = create_signed_jwt(&self.service_account_key);
let access_token_response = get_access_token(&signed_jwt).await?;
let new_token = access_token_response.access_token;
self.token = Some(new_token.clone());
self.expires_at = Some(SystemTime::now() + Duration::from_secs(access_token_response.expirese));
Ok(new_token)
}
}
/// Builds and signs the JWT assertion that is exchanged for an access token.
///
/// The JWT is signed with RS256 using the service account's private key and
/// carries the key id in its `kid` header. Its claims request the Firebase
/// Cloud Messaging scope and are valid for one hour from the current time.
///
/// # Panics
///
/// Panics if the system clock is set before the Unix epoch, if
/// `private_key` is not a valid RSA PEM key, or if signing fails.
fn create_signed_jwt(service_account_key: &ServiceAccountKey) -> String {
    let mut header = Header::new(jsonwebtoken::Algorithm::RS256);
    header.kid = Some(service_account_key.private_key_id.clone());

    let now = SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .expect("system clock is set before the Unix epoch")
        .as_secs();

    let claims = json!({
        "iss": service_account_key.client_email,
        "scope": "https://www.googleapis.com/auth/firebase.messaging",
        "aud": "https://oauth2.googleapis.com/token",
        "exp": now + 3600,
        "iat": now
    });

    let encoding_key = EncodingKey::from_rsa_pem(service_account_key.private_key.as_bytes())
        .expect("service account `private_key` is not a valid RSA PEM key");
    encode(&header, &claims, &encoding_key).expect("failed to sign the JWT assertion")
}
#[derive(serde::Deserialize)]
struct AccessTokenResponse {
pub access_token: String,
pub expirese: u64,
}
async fn get_access_token(signed_jwt: &str) -> std::result::Result<AccessTokenResponse, reqwest::Error> {
let client = Client::new();
let params = [
("grant_type", "urn:ietf:params:oauth:grant-type:jwt-bearer"),
("assertion", signed_jwt),
];
let res = client
.post("https://oauth2.googleapis.com/token")
.form(&params)
.send()
.await?
.json::<AccessTokenResponse>()
.await?;
Ok(res)
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment