mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
add more tests
This commit is contained in:
parent
cdb1ca9697
commit
b07e837a20
5 changed files with 199 additions and 10 deletions
|
|
@ -1,7 +1,7 @@
|
|||
use common::{
|
||||
api::open_ai::{ChatCompletionsRequest, ContentType, Message},
|
||||
configuration::LlmRoute,
|
||||
consts::{SYSTEM_ROLE, USER_ROLE},
|
||||
consts::{SYSTEM_ROLE, TOOL_ROLE, USER_ROLE},
|
||||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use tracing::{debug, warn};
|
||||
|
|
@ -56,9 +56,12 @@ const TOKEN_LENGTH_DIVISOR: usize = 4; // Approximate token length divisor for U
|
|||
|
||||
impl RouterModel for RouterModelV1 {
|
||||
fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest {
|
||||
// remove system prompt, tool calls, tool call response and messages without content
|
||||
// if content is empty its likely a tool call
|
||||
// when role == tool its tool call response
|
||||
let messages_vec = messages
|
||||
.iter()
|
||||
.filter(|m| m.role != SYSTEM_ROLE)
|
||||
.filter(|m| m.role != SYSTEM_ROLE && m.role != TOOL_ROLE && m.content.is_some())
|
||||
.collect::<Vec<&Message>>();
|
||||
|
||||
// Following code is to ensure that the conversation does not exceed max token length
|
||||
|
|
@ -114,11 +117,16 @@ impl RouterModel for RouterModelV1 {
|
|||
}
|
||||
|
||||
// Reverse the selected messages to maintain the conversation order
|
||||
|
||||
let selected_conversation_list = selected_messages_list_reversed
|
||||
.iter()
|
||||
.rev()
|
||||
.map(|message| (*message).clone())
|
||||
.map(|message| {
|
||||
Message::new(
|
||||
message.role.clone(),
|
||||
// we can unwrap here because we have already filtered out messages without content
|
||||
message.content.as_ref().unwrap().to_string(),
|
||||
)
|
||||
})
|
||||
.collect::<Vec<Message>>();
|
||||
|
||||
let messages_content = ARCH_ROUTER_V1_SYSTEM_PROMPT
|
||||
|
|
@ -457,6 +465,172 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
assert_eq!(expected_prompt, prompt.to_string());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_non_text_input() {
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}]
|
||||
</routes>
|
||||
|
||||
<conversation>
|
||||
[{"role":"user","content":"hi"},{"role":"assistant","content":"Hello! How can I assist you today?"},{"role":"user","content":"given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"}]
|
||||
</conversation>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
|
||||
2. You must analyze the route descriptions and find the best match route for user latest intent.
|
||||
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
"#;
|
||||
let routes_str = r#"
|
||||
[
|
||||
{"name": "Image generation", "description": "generating image"},
|
||||
{"name": "image conversion", "description": "convert images to provided format"},
|
||||
{"name": "image search", "description": "search image"},
|
||||
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
|
||||
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
|
||||
]
|
||||
"#;
|
||||
let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "hi"
|
||||
},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://example.com/image.png"
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I assist you today?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"
|
||||
}
|
||||
]
|
||||
"#;
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let req = router.generate_request(&conversation);
|
||||
|
||||
let prompt = req.messages[0].content.as_ref().unwrap();
|
||||
|
||||
assert_eq!(expected_prompt, prompt.to_string());
|
||||
}
|
||||
|
||||
|
||||
#[test]
|
||||
fn test_skip_tool_call() {
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}]
|
||||
</routes>
|
||||
|
||||
<conversation>
|
||||
[{"role":"user","content":"What's the weather like in Tokyo?"},{"role":"assistant","content":"The current weather in Tokyo is 22°C and sunny."},{"role":"user","content":"What about in New York?"}]
|
||||
</conversation>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
|
||||
2. You must analyze the route descriptions and find the best match route for user latest intent.
|
||||
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
"#;
|
||||
let routes_str = r#"
|
||||
[
|
||||
{"name": "Image generation", "description": "generating image"},
|
||||
{"name": "image conversion", "description": "convert images to provided format"},
|
||||
{"name": "image search", "description": "search image"},
|
||||
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
|
||||
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
|
||||
]
|
||||
"#;
|
||||
let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What's the weather like in Tokyo?"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": null,
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "toolcall-abc123",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "get_weather",
|
||||
"arguments": { "location": "Tokyo" }
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"tool_call_id": "toolcall-abc123",
|
||||
"content": "{ \"temperature\": \"22°C\", \"condition\": \"Sunny\" }"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "The current weather in Tokyo is 22°C and sunny."
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "What about in New York?"
|
||||
}
|
||||
]
|
||||
"#;
|
||||
|
||||
// expects conversation to look like this
|
||||
|
||||
// [
|
||||
// {
|
||||
// "role": "user",
|
||||
// "content": "What's the weather like in Tokyo?"
|
||||
// },
|
||||
// {
|
||||
// "role": "assistant",
|
||||
// "content": "The current weather in Tokyo is 22°C and sunny."
|
||||
// },
|
||||
// {
|
||||
// "role": "user",
|
||||
// "content": "What about in New York?"
|
||||
// }
|
||||
// ]
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let req = router.generate_request(&conversation);
|
||||
|
||||
let prompt = req.messages[0].content.as_ref().unwrap();
|
||||
|
||||
assert_eq!(expected_prompt, prompt.to_string());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_parse_response() {
|
||||
let routes_str = r#"
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@ use crate::{
|
|||
};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HallucinationClassificationRequest {
|
||||
pub prompt: String,
|
||||
|
|
|
|||
|
|
@ -162,6 +162,8 @@ pub struct StreamOptions {
|
|||
pub enum MultiPartContentType {
|
||||
#[serde(rename = "text")]
|
||||
Text,
|
||||
#[serde(rename = "image_url")]
|
||||
ImageUrl,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
|
||||
|
|
@ -188,6 +190,9 @@ impl Display for ContentType {
|
|||
.filter_map(|part| {
|
||||
if part.content_type == MultiPartContentType::Text {
|
||||
part.text.clone()
|
||||
} else if part.content_type == MultiPartContentType::ImageUrl {
|
||||
// skip image URLs or their data in text representation
|
||||
None
|
||||
} else {
|
||||
panic!("Unsupported content type: {:?}", part.content_type);
|
||||
}
|
||||
|
|
@ -217,6 +222,19 @@ pub struct Message {
|
|||
pub tool_call_id: Option<String>,
|
||||
}
|
||||
|
||||
impl Message {
|
||||
pub fn new(role: String, content: String) -> Self {
|
||||
let content = Some(ContentType::Text(content));
|
||||
Message {
|
||||
role,
|
||||
content,
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
tool_call_id: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for Message {
|
||||
fn default() -> Self {
|
||||
Message {
|
||||
|
|
|
|||
|
|
@ -237,9 +237,7 @@ impl HttpContext for StreamContext {
|
|||
Duration::from_secs(5),
|
||||
);
|
||||
|
||||
if let Some(content) =
|
||||
self.user_prompt.as_ref().unwrap().content.as_ref()
|
||||
{
|
||||
if let Some(content) = self.user_prompt.as_ref().unwrap().content.as_ref() {
|
||||
let call_context = StreamCallContext {
|
||||
response_handler_type: ResponseHandlerType::ArchFC,
|
||||
user_message: Some(content.to_string()),
|
||||
|
|
@ -262,7 +260,6 @@ impl HttpContext for StreamContext {
|
|||
);
|
||||
}
|
||||
Action::Pause
|
||||
|
||||
}
|
||||
|
||||
fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
use common::api::open_ai::{
|
||||
ChatCompletionsResponse, Choice, ContentType, FunctionCallDetail, Message, ToolCall, ToolType, Usage
|
||||
ChatCompletionsResponse, Choice, ContentType, FunctionCallDetail, Message, ToolCall, ToolType,
|
||||
Usage,
|
||||
};
|
||||
use common::configuration::Configuration;
|
||||
use http::StatusCode;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue