diff --git a/crates/brightstaff/src/router/router_model_v1.rs b/crates/brightstaff/src/router/router_model_v1.rs index 1320952d..30945a96 100644 --- a/crates/brightstaff/src/router/router_model_v1.rs +++ b/crates/brightstaff/src/router/router_model_v1.rs @@ -1,7 +1,7 @@ use common::{ api::open_ai::{ChatCompletionsRequest, ContentType, Message}, configuration::LlmRoute, - consts::{SYSTEM_ROLE, USER_ROLE}, + consts::{SYSTEM_ROLE, TOOL_ROLE, USER_ROLE}, }; use serde::{Deserialize, Serialize}; use tracing::{debug, warn}; @@ -56,9 +56,12 @@ const TOKEN_LENGTH_DIVISOR: usize = 4; // Approximate token length divisor for U impl RouterModel for RouterModelV1 { fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest { + // remove system prompt, tool calls, tool call response and messages without content + // if content is empty its likely a tool call + // when role == tool its tool call response let messages_vec = messages .iter() - .filter(|m| m.role != SYSTEM_ROLE) + .filter(|m| m.role != SYSTEM_ROLE && m.role != TOOL_ROLE && m.content.is_some()) .collect::>(); // Following code is to ensure that the conversation does not exceed max token length @@ -114,11 +117,16 @@ impl RouterModel for RouterModelV1 { } // Reverse the selected messages to maintain the conversation order - let selected_conversation_list = selected_messages_list_reversed .iter() .rev() - .map(|message| (*message).clone()) + .map(|message| { + Message::new( + message.role.clone(), + // we can unwrap here because we have already filtered out messages without content + message.content.as_ref().unwrap().to_string(), + ) + }) .collect::>(); let messages_content = ARCH_ROUTER_V1_SYSTEM_PROMPT @@ -457,6 +465,172 @@ Based on your analysis, provide your response in the following JSON formats if y assert_eq!(expected_prompt, prompt.to_string()); } + #[test] + fn test_non_text_input() { + let expected_prompt = r#" +You are a helpful assistant designed to find the best suited route. +You are provided with route description within XML tags: + +[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}] + + + +[{"role":"user","content":"hi"},{"role":"assistant","content":"Hello! How can I assist you today?"},{"role":"user","content":"given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"}] + + +Your task is to decide which route is best suit with user intent on the conversation in XML tags. Follow the instruction: +1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}. +2. You must analyze the route descriptions and find the best match route for user latest intent. +3. You only response the name of the route that best matches the user's request, use the exact name in the . + +Based on your analysis, provide your response in the following JSON formats if you decide to match any route: +{"route": "route_name"} +"#; + let routes_str = r#" + [ + {"name": "Image generation", "description": "generating image"}, + {"name": "image conversion", "description": "convert images to provided format"}, + {"name": "image search", "description": "search image"}, + {"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"}, + {"name": "Speech Recognition", "description": "Converting spoken language into written text"} + ] + "#; + let llm_routes = serde_json::from_str::>(routes_str).unwrap(); + let routing_model = "test-model".to_string(); + let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX); + + let conversation_str = r#" + [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "hi" + }, + { + "type": "image_url", + "image_url": { + "url": "https://example.com/image.png" + } + } + ] + }, + { + "role": "assistant", + "content": "Hello! How can I assist you today?" + }, + { + "role": "user", + "content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson" + } + ] + "#; + let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); + + let req = router.generate_request(&conversation); + + let prompt = req.messages[0].content.as_ref().unwrap(); + + assert_eq!(expected_prompt, prompt.to_string()); + } + + + #[test] + fn test_skip_tool_call() { + let expected_prompt = r#" +You are a helpful assistant designed to find the best suited route. +You are provided with route description within XML tags: + +[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}] + + + +[{"role":"user","content":"What's the weather like in Tokyo?"},{"role":"assistant","content":"The current weather in Tokyo is 22°C and sunny."},{"role":"user","content":"What about in New York?"}] + + +Your task is to decide which route is best suit with user intent on the conversation in XML tags. Follow the instruction: +1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}. +2. You must analyze the route descriptions and find the best match route for user latest intent. +3. You only response the name of the route that best matches the user's request, use the exact name in the . + +Based on your analysis, provide your response in the following JSON formats if you decide to match any route: +{"route": "route_name"} +"#; + let routes_str = r#" + [ + {"name": "Image generation", "description": "generating image"}, + {"name": "image conversion", "description": "convert images to provided format"}, + {"name": "image search", "description": "search image"}, + {"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"}, + {"name": "Speech Recognition", "description": "Converting spoken language into written text"} + ] + "#; + let llm_routes = serde_json::from_str::>(routes_str).unwrap(); + let routing_model = "test-model".to_string(); + let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX); + + let conversation_str = r#" + [ + { + "role": "user", + "content": "What's the weather like in Tokyo?" + }, + { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "toolcall-abc123", + "type": "function", + "function": { + "name": "get_weather", + "arguments": { "location": "Tokyo" } + } + } + ] + }, + { + "role": "tool", + "tool_call_id": "toolcall-abc123", + "content": "{ \"temperature\": \"22°C\", \"condition\": \"Sunny\" }" + }, + { + "role": "assistant", + "content": "The current weather in Tokyo is 22°C and sunny." + }, + { + "role": "user", + "content": "What about in New York?" + } + ] + "#; + + // expects conversation to look like this + +// [ +// { +// "role": "user", +// "content": "What's the weather like in Tokyo?" +// }, +// { +// "role": "assistant", +// "content": "The current weather in Tokyo is 22°C and sunny." +// }, +// { +// "role": "user", +// "content": "What about in New York?" +// } +// ] + let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); + + let req = router.generate_request(&conversation); + + let prompt = req.messages[0].content.as_ref().unwrap(); + + assert_eq!(expected_prompt, prompt.to_string()); + } + #[test] fn test_parse_response() { let routes_str = r#" diff --git a/crates/common/src/api/hallucination.rs b/crates/common/src/api/hallucination.rs index e90ea165..41ccf3d7 100644 --- a/crates/common/src/api/hallucination.rs +++ b/crates/common/src/api/hallucination.rs @@ -6,7 +6,6 @@ use crate::{ }; use serde::{Deserialize, Serialize}; - #[derive(Debug, Clone, Serialize, Deserialize)] pub struct HallucinationClassificationRequest { pub prompt: String, diff --git a/crates/common/src/api/open_ai.rs b/crates/common/src/api/open_ai.rs index d5d4ce2a..080923c1 100644 --- a/crates/common/src/api/open_ai.rs +++ b/crates/common/src/api/open_ai.rs @@ -162,6 +162,8 @@ pub struct StreamOptions { pub enum MultiPartContentType { #[serde(rename = "text")] Text, + #[serde(rename = "image_url")] + ImageUrl, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] @@ -188,6 +190,9 @@ impl Display for ContentType { .filter_map(|part| { if part.content_type == MultiPartContentType::Text { part.text.clone() + } else if part.content_type == MultiPartContentType::ImageUrl { + // skip image URLs or their data in text representation + None } else { panic!("Unsupported content type: {:?}", part.content_type); } @@ -217,6 +222,19 @@ pub struct Message { pub tool_call_id: Option, } +impl Message { + pub fn new(role: String, content: String) -> Self { + let content = Some(ContentType::Text(content)); + Message { + role, + content, + model: None, + tool_calls: None, + tool_call_id: None, + } + } +} + impl Default for Message { fn default() -> Self { Message { diff --git a/crates/prompt_gateway/src/http_context.rs b/crates/prompt_gateway/src/http_context.rs index bb673208..cd251064 100644 --- a/crates/prompt_gateway/src/http_context.rs +++ b/crates/prompt_gateway/src/http_context.rs @@ -237,9 +237,7 @@ impl HttpContext for StreamContext { Duration::from_secs(5), ); - if let Some(content) = - self.user_prompt.as_ref().unwrap().content.as_ref() - { + if let Some(content) = self.user_prompt.as_ref().unwrap().content.as_ref() { let call_context = StreamCallContext { response_handler_type: ResponseHandlerType::ArchFC, user_message: Some(content.to_string()), @@ -262,7 +260,6 @@ impl HttpContext for StreamContext { ); } Action::Pause - } fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { diff --git a/crates/prompt_gateway/tests/integration.rs b/crates/prompt_gateway/tests/integration.rs index 563c9393..e749a007 100644 --- a/crates/prompt_gateway/tests/integration.rs +++ b/crates/prompt_gateway/tests/integration.rs @@ -1,5 +1,6 @@ use common::api::open_ai::{ - ChatCompletionsResponse, Choice, ContentType, FunctionCallDetail, Message, ToolCall, ToolType, Usage + ChatCompletionsResponse, Choice, ContentType, FunctionCallDetail, Message, ToolCall, ToolType, + Usage, }; use common::configuration::Configuration; use http::StatusCode;