Created
January 17, 2024 00:50
-
-
Save louis030195/688b2b9716ff3d6757f9ac7b1e827ab0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// End-to-end tests that drive the API router in-process (via `oneshot`, no TCP server),
/// execute the created run once with `try_run_executor`, and inspect the assistant's reply.
///
/// Requires `DATABASE_URL` and `REDIS_URL` (loaded through `dotenv` or the environment).
/// Replies are produced by the configured model (`mistralai/mixtral-8x7b-instruct`), so the
/// reply assertions are intentionally loose keyword checks.
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use crate::{
        assistants::{
            create_assistant_handler, delete_assistant_handler, get_assistant_handler,
            list_assistants_handler, update_assistant_handler,
        },
        messages::{
            add_message_handler, delete_message_handler, get_message_handler,
            list_messages_handler, update_message_handler,
        },
        models::AppState,
        runs::{
            create_run_handler, delete_run_handler, get_run_handler, list_runs_handler,
            update_run_handler,
        },
        threads::{
            create_thread_handler, delete_thread_handler, get_thread_handler, list_threads_handler,
            update_thread_handler,
        },
    };

    use super::*;
    use assistants_core::{
        executor::try_run_executor, file_storage::FileStorage, test_data::OPENAPI_SPEC,
    };
    use async_openai::types::{
        AssistantObject, AssistantTools, AssistantToolsExtra, CreateAssistantRequest,
        CreateMessageRequest, CreateRunRequest, ListMessagesResponse, MessageContent,
        MessageObject, MessageRole, RunObject, RunStatus, ThreadObject,
    };
    use axum::{
        body::Body,
        extract::DefaultBodyLimit,
        http::{self, HeaderName, Request, StatusCode},
        routing::{get, post},
        Router,
    };
    use dotenv::dotenv;
    use hyper::{self, Method};
    use mime;
    use serde_json::json;
    use sqlx::{postgres::PgPoolOptions, PgPool};
    use tower::ServiceExt; // for `oneshot`
    use tower_http::{
        cors::{Any, CorsLayer},
        limit::RequestBodyLimitLayer,
        trace::TraceLayer,
    };
    use wiremock::matchers::{method, path};

    /// Serializes the end-to-end scenarios.
    ///
    /// Each scenario truncates the shared tables in `reset_db` and drives the executor
    /// against a shared Redis instance. libtest runs tests concurrently by default, so
    /// without this lock one test can wipe (or presumably execute) another test's data
    /// mid-flow.
    static E2E_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Placeholder in the `/users` spec templates, replaced by the mock server's base URL.
    const MOCK_URL_PLACEHOLDER: &str = "__MOCK_SERVER_URL__";

    /// JSON OpenAPI spec for a single `POST /users` endpoint.
    ///
    /// The `servers` URL is `MOCK_URL_PLACEHOLDER`, filled in by `with_mock_server`.
    const USERS_SPEC_JSON: &str = r#"
{
  "openapi": "3.0.0",
  "info": {
    "title": "Sample API",
    "description": "API description in Markdown.",
    "version": "1.0.0"
  },
  "servers": [{ "url": "__MOCK_SERVER_URL__" }],
  "paths": {
    "/users": {
      "post": {
        "summary": "Creates a new user.",
        "requestBody": {
          "content": {
            "application/json": {
              "schema": {
                "type": "object",
                "properties": {
                  "name": { "type": "string" },
                  "email": { "type": "string" }
                },
                "required": ["name", "email"]
              }
            }
          }
        },
        "responses": {
          "200": {
            "description": "A user object.",
            "content": {
              "application/json": {
                "schema": {
                  "type": "object",
                  "properties": {
                    "name": { "type": "string" },
                    "email": { "type": "string" }
                  }
                }
              }
            }
          }
        }
      }
    }
  }
}
"#;

    /// YAML version of `USERS_SPEC_JSON`.
    ///
    /// The content is kept flush-left because YAML indentation is significant.
    const USERS_SPEC_YAML: &str = r#"
openapi: "3.0.0"
info:
  title: "Sample API"
  description: "API description in Markdown."
  version: "1.0.0"
servers:
  - url: "__MOCK_SERVER_URL__"
paths:
  /users:
    post:
      summary: "Creates a new user."
      requestBody:
        content:
          application/json:
            schema:
              type: "object"
              properties:
                name:
                  type: "string"
                email:
                  type: "string"
              required:
                - "name"
                - "email"
      responses:
        '200':
          description: "A user object."
          content:
            application/json:
              schema:
                type: "object"
                properties:
                  name:
                    type: "string"
                  email:
                    type: "string"
"#;

    /// Builds the full API router around `app_state`.
    ///
    /// Tests call it with `oneshot` instead of starting an HTTP server.
    fn app(app_state: AppState) -> Router {
        // DELETE is allowed alongside GET/POST because DELETE routes are registered below.
        // Tests send no `Origin` header, so CORS does not affect them.
        let cors = CorsLayer::new()
            .allow_methods([Method::GET, Method::POST, Method::DELETE])
            .allow_origin(Any)
            .allow_headers(vec![HeaderName::from_static("content-type")]);

        Router::new()
            .route(
                "/assistants",
                get(list_assistants_handler).post(create_assistant_handler),
            )
            .route(
                "/assistants/:assistant_id",
                get(get_assistant_handler)
                    .post(update_assistant_handler)
                    .delete(delete_assistant_handler),
            )
            .route(
                "/threads",
                get(list_threads_handler).post(create_thread_handler),
            )
            .route(
                "/threads/:thread_id",
                get(get_thread_handler)
                    .post(update_thread_handler)
                    .delete(delete_thread_handler),
            )
            .route(
                "/threads/:thread_id/messages",
                get(list_messages_handler).post(add_message_handler),
            )
            .route(
                "/threads/:thread_id/messages/:message_id",
                get(get_message_handler)
                    .post(update_message_handler)
                    .delete(delete_message_handler),
            )
            .route(
                "/threads/:thread_id/runs",
                get(list_runs_handler).post(create_run_handler),
            )
            .route(
                "/threads/:thread_id/runs/:run_id",
                get(get_run_handler)
                    .post(update_run_handler)
                    .delete(delete_run_handler),
            )
            // Replace axum's default body cap with an explicit 250 MiB request limit.
            .layer(DefaultBodyLimit::disable())
            .layer(RequestBodyLimitLayer::new(250 * 1024 * 1024))
            .layer(TraceLayer::new_for_http())
            .layer(cors)
            .with_state(app_state)
    }

    /// Loads `.env`, connects to Postgres at `DATABASE_URL`, installs the logger, and
    /// returns the application state.
    ///
    /// # Panics
    /// If `DATABASE_URL` is unset or the connection pool cannot be created.
    async fn setup() -> AppState {
        dotenv().ok();
        let database_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
        let pool = PgPoolOptions::new()
            .max_connections(5)
            .connect(&database_url)
            .await
            .expect("Failed to create pool.");

        // `try_init` errors if a global logger is already installed (e.g. by an earlier
        // test in this process). That is expected, so the error is deliberately ignored.
        let _ = env_logger::builder()
            .filter_level(log::LevelFilter::Info)
            .try_init();

        AppState {
            pool: Arc::new(pool),
            file_storage: Arc::new(FileStorage::new().await),
        }
    }

    /// Truncates every table the tests touch and restarts their identity sequences.
    async fn reset_db(pool: &PgPool) {
        sqlx::query!(
            "TRUNCATE assistants, threads, messages, runs, functions, tool_calls RESTART IDENTITY"
        )
        .execute(pool)
        .await
        .unwrap();
    }

    /// Sends `body` as JSON to `uri` with HTTP method `verb` and returns the buffered
    /// response body.
    ///
    /// Panics with the method, URI, status, and response text unless the status is 200 OK.
    async fn send_ok(
        router: &Router,
        verb: http::Method,
        uri: &str,
        body: Body,
    ) -> hyper::body::Bytes {
        let request = Request::builder()
            .method(verb.clone())
            .uri(uri)
            .header(http::header::CONTENT_TYPE, mime::APPLICATION_JSON.as_ref())
            .body(body)
            .unwrap();
        let response = router.clone().oneshot(request).await.unwrap();
        let status = response.status();
        let bytes = hyper::body::to_bytes(response.into_body()).await.unwrap();
        assert_eq!(
            status,
            StatusCode::OK,
            "{} {} returned {}: {}",
            verb,
            uri,
            status,
            String::from_utf8_lossy(&bytes)
        );
        bytes
    }

    /// Inputs that differ between the action-tool end-to-end scenarios.
    struct ActionScenario {
        assistant_name: &'static str,
        assistant_instructions: &'static str,
        /// OpenAPI document (JSON or YAML) attached as the assistant's single `action` tool.
        openapi_spec: String,
        user_message: &'static str,
        run_instructions: &'static str,
    }

    /// Runs one full action-tool conversation and returns the assistant's reply text.
    ///
    /// Steps:
    /// 1. Reset the DB and create an assistant with one `action` tool.
    /// 2. Create a thread and post the user message.
    /// 3. Create a run and execute it once with `try_run_executor`.
    /// 4. Assert the run completed.
    /// 5. Assert the thread holds exactly two messages, the second being an assistant
    ///    message whose first content item is text.
    ///
    /// The std mutex guard is held across awaits on purpose. Each `#[tokio::test]` runs
    /// its own runtime on its own test thread, so waiting on the lock only parks other
    /// tests.
    #[allow(clippy::await_holding_lock)]
    async fn run_action_scenario(scenario: ActionScenario) -> String {
        // Recover from poisoning so one failed scenario does not fail every later one.
        let _serial = E2E_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        let app_state = setup().await;
        let router = app(app_state.clone());
        reset_db(&app_state.pool).await;

        let assistant_request = CreateAssistantRequest {
            instructions: Some(scenario.assistant_instructions.to_string()),
            name: Some(scenario.assistant_name.to_string()),
            tools: Some(vec![AssistantTools::Extra(AssistantToolsExtra {
                r#type: "action".to_string(),
                data: scenario.openapi_spec,
            })]),
            model: "mistralai/mixtral-8x7b-instruct".to_string(),
            file_ids: None,
            description: None,
            metadata: None,
        };
        let body = send_ok(
            &router,
            http::Method::POST,
            "/assistants",
            Body::from(serde_json::to_vec(&assistant_request).unwrap()),
        )
        .await;
        let assistant: AssistantObject = serde_json::from_slice(&body).unwrap();

        let body = send_ok(&router, http::Method::POST, "/threads", Body::empty()).await;
        let thread: ThreadObject = serde_json::from_slice(&body).unwrap();

        let message_request = CreateMessageRequest {
            file_ids: None,
            metadata: None,
            role: "user".to_string(),
            content: scenario.user_message.to_string(),
        };
        let messages_uri = format!("/threads/{}/messages", thread.id);
        let body = send_ok(
            &router,
            http::Method::POST,
            &messages_uri,
            Body::from(serde_json::to_vec(&message_request).unwrap()),
        )
        .await;
        let message: MessageObject = serde_json::from_slice(&body).unwrap();
        assert_eq!(message.thread_id, thread.id);

        let run_request = CreateRunRequest {
            assistant_id: assistant.id,
            instructions: Some(scenario.run_instructions.to_string()),
            additional_instructions: None,
            model: None,
            tools: None,
            metadata: None,
        };
        let body = send_ok(
            &router,
            http::Method::POST,
            &format!("/threads/{}/runs", thread.id),
            Body::from(serde_json::to_vec(&run_request).unwrap()),
        )
        .await;
        let run: RunObject = serde_json::from_slice(&body).unwrap();

        // Drive the executor once from the test so the run is processed before we check it.
        let redis_url = std::env::var("REDIS_URL").expect("REDIS_URL must be set");
        let client = redis::Client::open(redis_url).unwrap();
        let mut con = client.get_async_connection().await.unwrap();
        let result = try_run_executor(&app_state.pool, &mut con).await;
        assert!(result.is_ok(), "{:?}", result);

        let body = send_ok(
            &router,
            http::Method::GET,
            &format!("/threads/{}/runs/{}", thread.id, run.id),
            Body::empty(),
        )
        .await;
        let run: RunObject = serde_json::from_slice(&body).unwrap();
        assert_eq!(run.status, RunStatus::Completed);

        let body = send_ok(&router, http::Method::GET, &messages_uri, Body::empty()).await;
        let messages: ListMessagesResponse = serde_json::from_slice(&body).unwrap();
        assert_eq!(messages.data.len(), 2);
        let reply = &messages.data[1];
        assert_eq!(reply.role, MessageRole::Assistant);
        match reply.content.first() {
            Some(MessageContent::Text(text_object)) => text_object.text.value.clone(),
            _ => panic!("Expected a Text message, but got something else."),
        }
    }

    /// Starts a wiremock server that answers `POST /users` with a fixed user object.
    ///
    /// The caller must keep the returned server alive; wiremock stops serving once it
    /// is dropped.
    async fn start_users_mock() -> wiremock::MockServer {
        let mock_server = wiremock::MockServer::start().await;
        wiremock::Mock::given(method("POST"))
            .and(path("/users"))
            .respond_with(wiremock::ResponseTemplate::new(200).set_body_json(json!({
                "name": "John Doe",
                "email": "john.doe@example.com"
            })))
            .mount(&mock_server)
            .await;
        mock_server
    }

    /// Replaces `MOCK_URL_PLACEHOLDER` in a spec template with `mock_server`'s base URI.
    ///
    /// This lets an executor that honors the OpenAPI `servers` entry reach the local mock.
    fn with_mock_server(template: &str, mock_server: &wiremock::MockServer) -> String {
        template.replace(MOCK_URL_PLACEHOLDER, &mock_server.uri())
    }

    /// Loose check on the reply to the create-user prompt.
    ///
    /// NOTE(review): the prompt itself contains both words, so this can pass even if the
    /// tool was never called.
    fn assert_mentions_user(reply: &str) {
        assert!(
            reply.contains("name") || reply.contains("email"),
            "Expected the assistant to return a text containing either 'name' or 'email', but got something else: {}",
            reply
        );
    }

    #[tokio::test]
    async fn test_end_to_end_wikipedia_action_tool() {
        let reply = run_action_scenario(ActionScenario {
            assistant_name: "Action Tool Assistant",
            assistant_instructions:
                "You are a personal assistant. Use the MediaWiki API to fetch random facts.",
            openapi_spec: OPENAPI_SPEC.to_string(),
            user_message: "Give me a random fact. Also provide the exact output from the API",
            run_instructions: "Please help me find a random fact.",
        })
        .await;

        let expected = ["ID", "id", "batchcomplete", "talk"];
        assert!(
            expected.iter().any(|needle| reply.contains(*needle)),
            "Expected the assistant to return a text containing either 'ID', 'id', 'batchcomplete', or 'talk', but got something else: {}",
            reply
        );
    }

    #[tokio::test]
    async fn test_end_to_end_with_json_openapi_and_post() {
        // Bound to a named variable so the mock keeps serving for the whole test.
        let mock_server = start_users_mock().await;
        let reply = run_action_scenario(ActionScenario {
            assistant_name: "OpenAPI Assistant",
            assistant_instructions:
                "You are a personal assistant. Use the OpenAPI spec to create a new user.",
            openapi_spec: with_mock_server(USERS_SPEC_JSON, &mock_server),
            user_message:
                "Create a new user with name 'John Doe' and email 'john.doe@example.com'",
            run_instructions: "Please help me create a new user.",
        })
        .await;

        assert_mentions_user(&reply);
    }

    #[tokio::test]
    async fn test_end_to_end_with_post() {
        // Bound to a named variable so the mock keeps serving for the whole test.
        let mock_server = start_users_mock().await;
        let reply = run_action_scenario(ActionScenario {
            assistant_name: "OpenAPI Assistant",
            assistant_instructions:
                "You are a personal assistant. Use the OpenAPI spec to create a new user.",
            openapi_spec: with_mock_server(USERS_SPEC_YAML, &mock_server),
            user_message:
                "Create a new user with name 'John Doe' and email 'john.doe@example.com'",
            run_instructions: "Please help me create a new user.",
        })
        .await;

        assert_mentions_user(&reply);
    }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment