Last active: November 9, 2023 08:23
-
-
Save yongkangc/88382bbd8e2c4f7adc50f1cd84dcd92c to your computer and use it in GitHub Desktop.
ChatGPT in Rust
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
use std::env;

use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE};
use reqwest::Client;
use serde::{Deserialize, Serialize};
#[derive(Serialize, Deserialize, Debug)] | |
struct OAIRequest { | |
model: String, | |
messages: Vec<Message>, | |
} | |
#[derive(Debug, Deserialize)] | |
struct OAIResponse { | |
choices: Vec<Choice>, | |
} | |
#[derive(Debug, Deserialize)] | |
struct Choice { | |
message: Message, | |
} | |
#[derive(Serialize, Deserialize, Debug)] | |
struct Message { | |
role: String, | |
content: String, | |
} | |
/// Model identifier placed in every chat-completion request.
const MODEL: &str = "gpt-3.5-turbo";
/// Chat Completions endpoint that `openai_query` posts to.
const URI: &str = "https://api.openai.com/v1/chat/completions";
pub async fn openai_query(text: &str) -> Result<String, Box<dyn std::error::Error>> { | |
let oai_token = env::var("OPENAI_API").expect("OPENAI_API must be set"); | |
let client = Client::new(); | |
let mut headers = HeaderMap::new(); | |
headers.insert(CONTENT_TYPE, "application/json".parse().unwrap()); | |
headers.insert( | |
AUTHORIZATION, | |
format!("Bearer {}", oai_token).parse().unwrap(), | |
); | |
let prompt_message: Message = Message { | |
role: String::from("system"), | |
content: String::from( | |
"You're a helpful assistant. Solve the following LeetCode problem:\n\n", | |
), | |
}; | |
let req = OAIRequest { | |
model: String::from(MODEL), | |
messages: vec![ | |
prompt_message, | |
Message { | |
role: String::from("user"), | |
content: String::from(text), | |
}, | |
], | |
}; | |
let res = client | |
.post(URI) | |
.headers(headers) | |
.json(&req) | |
.send() | |
.await? | |
.json::<OAIResponse>() | |
.await?; | |
let message = res | |
.choices | |
.last() | |
.ok_or("No choices returned")? | |
.message | |
.content | |
.clone(); | |
Ok(message) | |
} | |
#[cfg(test)]
mod tests {
    use super::*;

    /// Live smoke test against the OpenAI API.
    ///
    /// Ignored by default because it needs network access and a valid
    /// `OPENAI_API` key; run it explicitly with `cargo test -- --ignored`.
    #[tokio::test]
    #[ignore = "requires network access and the OPENAI_API environment variable"]
    async fn test_openai_query() {
        let text = "Given an array of integers nums = [2,7,11,15] and an integer target = 9, return indices of the two numbers such that they add up to target.";
        let res = openai_query(text).await.unwrap();
        // The reply is free-form model output (prose and/or code), so an exact
        // comparison such as `res == "[0,1]"` is not reliable. Only check that
        // a non-empty answer came back.
        assert!(!res.trim().is_empty(), "expected a non-empty reply");
    }
}
#[tokio::main] | |
async fn main() { | |
let problem = "Given an array of integers nums = [2,7,11,15] and an integer target = 9, return indices of the two numbers such that they add up to target."; | |
match openai_query(problem).await { | |
Ok(solution) => println!("Solution: {}", solution), | |
Err(e) => eprintln!("Error: {}", e), | |
} | |
} |
Sign up for free to join this conversation on GitHub.
Already have an account?
Sign in to comment