add more tests

This commit is contained in:
Adil Hafeez 2025-05-30 11:03:58 -07:00
parent cdb1ca9697
commit b07e837a20
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
5 changed files with 199 additions and 10 deletions

View file

@ -1,7 +1,7 @@
use common::{
api::open_ai::{ChatCompletionsRequest, ContentType, Message},
configuration::LlmRoute,
consts::{SYSTEM_ROLE, USER_ROLE},
consts::{SYSTEM_ROLE, TOOL_ROLE, USER_ROLE},
};
use serde::{Deserialize, Serialize};
use tracing::{debug, warn};
@ -56,9 +56,12 @@ const TOKEN_LENGTH_DIVISOR: usize = 4; // Approximate token length divisor for U
impl RouterModel for RouterModelV1 {
fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest {
// remove system prompt, tool calls, tool call response and messages without content
// if content is empty its likely a tool call
// when role == tool its tool call response
let messages_vec = messages
.iter()
.filter(|m| m.role != SYSTEM_ROLE)
.filter(|m| m.role != SYSTEM_ROLE && m.role != TOOL_ROLE && m.content.is_some())
.collect::<Vec<&Message>>();
// Following code is to ensure that the conversation does not exceed max token length
@ -114,11 +117,16 @@ impl RouterModel for RouterModelV1 {
}
// Reverse the selected messages to maintain the conversation order
let selected_conversation_list = selected_messages_list_reversed
.iter()
.rev()
.map(|message| (*message).clone())
.map(|message| {
Message::new(
message.role.clone(),
// we can unwrap here because we have already filtered out messages without content
message.content.as_ref().unwrap().to_string(),
)
})
.collect::<Vec<Message>>();
let messages_content = ARCH_ROUTER_V1_SYSTEM_PROMPT
@ -457,6 +465,172 @@ Based on your analysis, provide your response in the following JSON formats if y
assert_eq!(expected_prompt, prompt.to_string());
}
#[test]
fn test_non_text_input() {
let expected_prompt = r#"
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>
[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}]
</routes>
<conversation>
[{"role":"user","content":"hi"},{"role":"assistant","content":"Hello! How can I assist you today?"},{"role":"user","content":"given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"}]
</conversation>
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
2. You must analyze the route descriptions and find the best match route for user latest intent.
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
{"route": "route_name"}
"#;
let routes_str = r#"
[
{"name": "Image generation", "description": "generating image"},
{"name": "image conversion", "description": "convert images to provided format"},
{"name": "image search", "description": "search image"},
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
]
"#;
let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
let conversation_str = r#"
[
{
"role": "user",
"content": [
{
"type": "text",
"text": "hi"
},
{
"type": "image_url",
"image_url": {
"url": "https://example.com/image.png"
}
}
]
},
{
"role": "assistant",
"content": "Hello! How can I assist you today?"
},
{
"role": "user",
"content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"
}
]
"#;
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let req = router.generate_request(&conversation);
let prompt = req.messages[0].content.as_ref().unwrap();
assert_eq!(expected_prompt, prompt.to_string());
}
#[test]
fn test_skip_tool_call() {
let expected_prompt = r#"
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>
[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}]
</routes>
<conversation>
[{"role":"user","content":"What's the weather like in Tokyo?"},{"role":"assistant","content":"The current weather in Tokyo is 22°C and sunny."},{"role":"user","content":"What about in New York?"}]
</conversation>
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
2. You must analyze the route descriptions and find the best match route for user latest intent.
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
{"route": "route_name"}
"#;
let routes_str = r#"
[
{"name": "Image generation", "description": "generating image"},
{"name": "image conversion", "description": "convert images to provided format"},
{"name": "image search", "description": "search image"},
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
]
"#;
let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
let conversation_str = r#"
[
{
"role": "user",
"content": "What's the weather like in Tokyo?"
},
{
"role": "assistant",
"content": null,
"tool_calls": [
{
"id": "toolcall-abc123",
"type": "function",
"function": {
"name": "get_weather",
"arguments": { "location": "Tokyo" }
}
}
]
},
{
"role": "tool",
"tool_call_id": "toolcall-abc123",
"content": "{ \"temperature\": \"22°C\", \"condition\": \"Sunny\" }"
},
{
"role": "assistant",
"content": "The current weather in Tokyo is 22°C and sunny."
},
{
"role": "user",
"content": "What about in New York?"
}
]
"#;
// expects conversation to look like this
// [
// {
// "role": "user",
// "content": "What's the weather like in Tokyo?"
// },
// {
// "role": "assistant",
// "content": "The current weather in Tokyo is 22°C and sunny."
// },
// {
// "role": "user",
// "content": "What about in New York?"
// }
// ]
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let req = router.generate_request(&conversation);
let prompt = req.messages[0].content.as_ref().unwrap();
assert_eq!(expected_prompt, prompt.to_string());
}
#[test]
fn test_parse_response() {
let routes_str = r#"

View file

@ -6,7 +6,6 @@ use crate::{
};
use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HallucinationClassificationRequest {
pub prompt: String,

View file

@ -162,6 +162,8 @@ pub struct StreamOptions {
pub enum MultiPartContentType {
#[serde(rename = "text")]
Text,
#[serde(rename = "image_url")]
ImageUrl,
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
@ -188,6 +190,9 @@ impl Display for ContentType {
.filter_map(|part| {
if part.content_type == MultiPartContentType::Text {
part.text.clone()
} else if part.content_type == MultiPartContentType::ImageUrl {
// skip image URLs or their data in text representation
None
} else {
panic!("Unsupported content type: {:?}", part.content_type);
}
@ -217,6 +222,19 @@ pub struct Message {
pub tool_call_id: Option<String>,
}
impl Message {
pub fn new(role: String, content: String) -> Self {
let content = Some(ContentType::Text(content));
Message {
role,
content,
model: None,
tool_calls: None,
tool_call_id: None,
}
}
}
impl Default for Message {
fn default() -> Self {
Message {

View file

@ -237,9 +237,7 @@ impl HttpContext for StreamContext {
Duration::from_secs(5),
);
if let Some(content) =
self.user_prompt.as_ref().unwrap().content.as_ref()
{
if let Some(content) = self.user_prompt.as_ref().unwrap().content.as_ref() {
let call_context = StreamCallContext {
response_handler_type: ResponseHandlerType::ArchFC,
user_message: Some(content.to_string()),
@ -262,7 +260,6 @@ impl HttpContext for StreamContext {
);
}
Action::Pause
}
fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {

View file

@ -1,5 +1,6 @@
use common::api::open_ai::{
ChatCompletionsResponse, Choice, ContentType, FunctionCallDetail, Message, ToolCall, ToolType, Usage
ChatCompletionsResponse, Choice, ContentType, FunctionCallDetail, Message, ToolCall, ToolType,
Usage,
};
use common::configuration::Configuration;
use http::StatusCode;