mirror of
https://github.com/katanemo/plano.git
synced 2026-05-07 06:42:42 +02:00
pass model name in header when a route is selected when using usage preferences (#531)
This commit is contained in:
parent
2340a45353
commit
f819ee3507
4 changed files with 150 additions and 117 deletions
|
|
@ -104,7 +104,7 @@ pub async fn chat_completions(
|
||||||
|
|
||||||
debug!("usage preferences from request: {:?}", usage_preferences);
|
debug!("usage preferences from request: {:?}", usage_preferences);
|
||||||
|
|
||||||
let mut determined_route = match router_service
|
let determined_model = match router_service
|
||||||
.determine_route(
|
.determine_route(
|
||||||
&chat_completion_request.messages,
|
&chat_completion_request.messages,
|
||||||
trace_parent.clone(),
|
trace_parent.clone(),
|
||||||
|
|
@ -121,14 +121,17 @@ pub async fn chat_completions(
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
if determined_route.is_none() {
|
|
||||||
debug!("No LLM model selected, using default from request");
|
|
||||||
determined_route = Some(chat_completion_request.model.clone());
|
|
||||||
}
|
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
"sending request to llm provider: {} with llm model: {:?}",
|
"sending request to llm provider: {} determined_model: {:?}, model from request: {}",
|
||||||
llm_provider_endpoint, determined_route
|
llm_provider_endpoint, determined_model, chat_completion_request.model
|
||||||
|
);
|
||||||
|
|
||||||
|
request_headers.insert(
|
||||||
|
ARCH_PROVIDER_HINT_HEADER,
|
||||||
|
header::HeaderValue::from_str(
|
||||||
|
&determined_model.unwrap_or(chat_completion_request.model.clone()),
|
||||||
|
)
|
||||||
|
.unwrap(),
|
||||||
);
|
);
|
||||||
|
|
||||||
if let Some(trace_parent) = trace_parent {
|
if let Some(trace_parent) = trace_parent {
|
||||||
|
|
@ -138,13 +141,6 @@ pub async fn chat_completions(
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(selected_route) = determined_route {
|
|
||||||
request_headers.insert(
|
|
||||||
ARCH_PROVIDER_HINT_HEADER,
|
|
||||||
header::HeaderValue::from_str(&selected_route).unwrap(),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
let chat_request_parsed_bytes =
|
let chat_request_parsed_bytes =
|
||||||
serde_json::to_string(&chat_request_user_preferences_removed).unwrap();
|
serde_json::to_string(&chat_request_user_preferences_removed).unwrap();
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
use std::sync::Arc;
|
use std::{collections::HashMap, sync::Arc};
|
||||||
|
|
||||||
use common::{
|
use common::{
|
||||||
configuration::{LlmProvider, ModelUsagePreference, RoutingPreference},
|
configuration::{LlmProvider, ModelUsagePreference, RoutingPreference},
|
||||||
|
|
@ -48,9 +48,14 @@ impl RouterService {
|
||||||
.cloned()
|
.cloned()
|
||||||
.collect::<Vec<LlmProvider>>();
|
.collect::<Vec<LlmProvider>>();
|
||||||
|
|
||||||
let llm_routes: Vec<RoutingPreference> = providers_with_usage
|
let llm_routes: HashMap<String, Vec<RoutingPreference>> = providers_with_usage
|
||||||
.iter()
|
.iter()
|
||||||
.flat_map(|provider| provider.routing_preferences.clone().unwrap_or_default())
|
.filter_map(|provider| {
|
||||||
|
provider
|
||||||
|
.routing_preferences
|
||||||
|
.as_ref()
|
||||||
|
.map(|prefs| (provider.name.clone(), prefs.clone()))
|
||||||
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let router_model = Arc::new(router_model_v1::RouterModelV1::new(
|
let router_model = Arc::new(router_model_v1::RouterModelV1::new(
|
||||||
|
|
@ -151,21 +156,22 @@ impl RouterService {
|
||||||
if let Some(ContentType::Text(content)) =
|
if let Some(ContentType::Text(content)) =
|
||||||
&chat_completion_response.choices[0].message.content
|
&chat_completion_response.choices[0].message.content
|
||||||
{
|
{
|
||||||
let route_name = self.router_model.parse_response(content)?;
|
let parsed_response = self
|
||||||
|
.router_model
|
||||||
|
.parse_response(content, &usage_preferences)?;
|
||||||
info!(
|
info!(
|
||||||
"router response: {}, selected_model: {:?}, response time: {}ms",
|
"router response: {}, selected_model: {:?}, response time: {}ms",
|
||||||
content.replace("\n", "\\n"),
|
content.replace("\n", "\\n"),
|
||||||
route_name,
|
parsed_response,
|
||||||
router_response_time.as_millis()
|
router_response_time.as_millis()
|
||||||
);
|
);
|
||||||
|
|
||||||
if let Some(ref route) = route_name {
|
if let Some(ref route) = parsed_response {
|
||||||
if route == "other" {
|
// return model name if route is found
|
||||||
return Ok(None);
|
return Ok(Some(route.1.clone()));
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(route_name)
|
Ok(None)
|
||||||
} else {
|
} else {
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,10 @@ pub trait RouterModel: Send + Sync {
|
||||||
messages: &[Message],
|
messages: &[Message],
|
||||||
usage_preferences: &Option<Vec<ModelUsagePreference>>,
|
usage_preferences: &Option<Vec<ModelUsagePreference>>,
|
||||||
) -> ChatCompletionsRequest;
|
) -> ChatCompletionsRequest;
|
||||||
fn parse_response(&self, content: &str) -> Result<Option<String>>;
|
fn parse_response(
|
||||||
|
&self,
|
||||||
|
content: &str,
|
||||||
|
usage_preferences: &Option<Vec<ModelUsagePreference>>,
|
||||||
|
) -> Result<Option<(String, String)>>;
|
||||||
fn get_model_name(&self) -> String;
|
fn get_model_name(&self) -> String;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use common::{
|
use common::{
|
||||||
configuration::{ModelUsagePreference, RoutingPreference},
|
configuration::{ModelUsagePreference, RoutingPreference},
|
||||||
consts::{SYSTEM_ROLE, TOOL_ROLE, USER_ROLE},
|
consts::{SYSTEM_ROLE, TOOL_ROLE, USER_ROLE},
|
||||||
|
|
@ -32,21 +34,30 @@ Based on your analysis, provide your response in the following JSON formats if y
|
||||||
pub type Result<T> = std::result::Result<T, RoutingModelError>;
|
pub type Result<T> = std::result::Result<T, RoutingModelError>;
|
||||||
pub struct RouterModelV1 {
|
pub struct RouterModelV1 {
|
||||||
llm_route_json_str: String,
|
llm_route_json_str: String,
|
||||||
|
llm_route_to_model_map: HashMap<String, String>,
|
||||||
routing_model: String,
|
routing_model: String,
|
||||||
max_token_length: usize,
|
max_token_length: usize,
|
||||||
}
|
}
|
||||||
impl RouterModelV1 {
|
impl RouterModelV1 {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
llm_routes: Vec<RoutingPreference>,
|
llm_routes: HashMap<String, Vec<RoutingPreference>>,
|
||||||
routing_model: String,
|
routing_model: String,
|
||||||
max_token_length: usize,
|
max_token_length: usize,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
|
let llm_route_values: Vec<RoutingPreference> =
|
||||||
|
llm_routes.values().flatten().cloned().collect();
|
||||||
let llm_route_json_str =
|
let llm_route_json_str =
|
||||||
serde_json::to_string(&llm_routes).unwrap_or_else(|_| "[]".to_string());
|
serde_json::to_string(&llm_route_values).unwrap_or_else(|_| "[]".to_string());
|
||||||
|
let llm_route_to_model_map: HashMap<String, String> = llm_routes
|
||||||
|
.iter()
|
||||||
|
.flat_map(|(model, prefs)| prefs.iter().map(|pref| (pref.name.clone(), model.clone())))
|
||||||
|
.collect();
|
||||||
|
|
||||||
RouterModelV1 {
|
RouterModelV1 {
|
||||||
routing_model,
|
routing_model,
|
||||||
max_token_length,
|
max_token_length,
|
||||||
llm_route_json_str,
|
llm_route_json_str,
|
||||||
|
llm_route_to_model_map,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -171,20 +182,51 @@ impl RouterModel for RouterModelV1 {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_response(&self, content: &str) -> Result<Option<String>> {
|
fn parse_response(
|
||||||
|
&self,
|
||||||
|
content: &str,
|
||||||
|
usage_preferences: &Option<Vec<ModelUsagePreference>>,
|
||||||
|
) -> Result<Option<(String, String)>> {
|
||||||
if content.is_empty() {
|
if content.is_empty() {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
let router_resp_fixed = fix_json_response(content);
|
let router_resp_fixed = fix_json_response(content);
|
||||||
let router_response: LlmRouterResponse = serde_json::from_str(router_resp_fixed.as_str())?;
|
let router_response: LlmRouterResponse = serde_json::from_str(router_resp_fixed.as_str())?;
|
||||||
|
|
||||||
let selected_llm = router_response.route.unwrap_or_default().to_string();
|
let selected_route = router_response.route.unwrap_or_default().to_string();
|
||||||
|
|
||||||
if selected_llm.is_empty() {
|
if selected_route.is_empty() || selected_route == "other" {
|
||||||
return Ok(None);
|
return Ok(None);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Some(selected_llm))
|
if let Some(usage_preferences) = usage_preferences {
|
||||||
|
// If usage preferences are defined, we need to find the model that matches the selected route
|
||||||
|
let matching_preference = usage_preferences
|
||||||
|
.iter()
|
||||||
|
.find(|pref| pref.name == selected_route);
|
||||||
|
|
||||||
|
if let Some(preference) = matching_preference {
|
||||||
|
return Ok(Some((selected_route, preference.model.clone())));
|
||||||
|
} else {
|
||||||
|
warn!(
|
||||||
|
"No matching model found for route: {}, usage preferences: {:?}",
|
||||||
|
selected_route, usage_preferences
|
||||||
|
);
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If no usage preferences are defined, fall back to the model mapped to the selected route in llm_route_to_model_map (built from the configured providers' routing preferences)
|
||||||
|
if let Some(model) = self.llm_route_to_model_map.get(&selected_route).cloned() {
|
||||||
|
return Ok(Some((selected_route, model)));
|
||||||
|
}
|
||||||
|
|
||||||
|
warn!(
|
||||||
|
"No model found for route: {}, router model preferences: {:?}",
|
||||||
|
selected_route, self.llm_route_to_model_map
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(None)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_model_name(&self) -> String {
|
fn get_model_name(&self) -> String {
|
||||||
|
|
@ -235,7 +277,7 @@ mod tests {
|
||||||
You are a helpful assistant designed to find the best suited route.
|
You are a helpful assistant designed to find the best suited route.
|
||||||
You are provided with route description within <routes></routes> XML tags:
|
You are provided with route description within <routes></routes> XML tags:
|
||||||
<routes>
|
<routes>
|
||||||
[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}]
|
[{"name":"Image generation","description":"generating image"}]
|
||||||
</routes>
|
</routes>
|
||||||
|
|
||||||
<conversation>
|
<conversation>
|
||||||
|
|
@ -251,15 +293,13 @@ Based on your analysis, provide your response in the following JSON formats if y
|
||||||
{"route": "route_name"}
|
{"route": "route_name"}
|
||||||
"#;
|
"#;
|
||||||
let routes_str = r#"
|
let routes_str = r#"
|
||||||
[
|
{
|
||||||
{"name": "Image generation", "description": "generating image"},
|
"gpt-4o": [
|
||||||
{"name": "image conversion", "description": "convert images to provided format"},
|
{"name": "Image generation", "description": "generating image"}
|
||||||
{"name": "image search", "description": "search image"},
|
]
|
||||||
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
|
}
|
||||||
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
|
|
||||||
]
|
|
||||||
"#;
|
"#;
|
||||||
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
|
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||||
let routing_model = "test-model".to_string();
|
let routing_model = "test-model".to_string();
|
||||||
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
|
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
|
||||||
|
|
||||||
|
|
@ -310,15 +350,13 @@ Based on your analysis, provide your response in the following JSON formats if y
|
||||||
{"route": "route_name"}
|
{"route": "route_name"}
|
||||||
"#;
|
"#;
|
||||||
let routes_str = r#"
|
let routes_str = r#"
|
||||||
[
|
{
|
||||||
{"name": "Image generation", "description": "generating image"},
|
"gpt-4o": [
|
||||||
{"name": "image conversion", "description": "convert images to provided format"},
|
{"name": "Image generation", "description": "generating image"}
|
||||||
{"name": "image search", "description": "search image"},
|
]
|
||||||
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
|
}
|
||||||
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
|
|
||||||
]
|
|
||||||
"#;
|
"#;
|
||||||
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
|
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||||
let routing_model = "test-model".to_string();
|
let routing_model = "test-model".to_string();
|
||||||
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
|
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
|
||||||
|
|
||||||
|
|
@ -358,7 +396,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
||||||
You are a helpful assistant designed to find the best suited route.
|
You are a helpful assistant designed to find the best suited route.
|
||||||
You are provided with route description within <routes></routes> XML tags:
|
You are provided with route description within <routes></routes> XML tags:
|
||||||
<routes>
|
<routes>
|
||||||
[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}]
|
[{"name":"Image generation","description":"generating image"}]
|
||||||
</routes>
|
</routes>
|
||||||
|
|
||||||
<conversation>
|
<conversation>
|
||||||
|
|
@ -375,15 +413,13 @@ Based on your analysis, provide your response in the following JSON formats if y
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let routes_str = r#"
|
let routes_str = r#"
|
||||||
[
|
{
|
||||||
{"name": "Image generation", "description": "generating image"},
|
"gpt-4o": [
|
||||||
{"name": "image conversion", "description": "convert images to provided format"},
|
{"name": "Image generation", "description": "generating image"}
|
||||||
{"name": "image search", "description": "search image"},
|
]
|
||||||
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
|
}
|
||||||
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
|
|
||||||
]
|
|
||||||
"#;
|
"#;
|
||||||
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
|
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||||
let routing_model = "test-model".to_string();
|
let routing_model = "test-model".to_string();
|
||||||
let router = RouterModelV1::new(llm_routes, routing_model.clone(), 235);
|
let router = RouterModelV1::new(llm_routes, routing_model.clone(), 235);
|
||||||
|
|
||||||
|
|
@ -419,7 +455,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
||||||
You are a helpful assistant designed to find the best suited route.
|
You are a helpful assistant designed to find the best suited route.
|
||||||
You are provided with route description within <routes></routes> XML tags:
|
You are provided with route description within <routes></routes> XML tags:
|
||||||
<routes>
|
<routes>
|
||||||
[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}]
|
[{"name":"Image generation","description":"generating image"}]
|
||||||
</routes>
|
</routes>
|
||||||
|
|
||||||
<conversation>
|
<conversation>
|
||||||
|
|
@ -436,15 +472,14 @@ Based on your analysis, provide your response in the following JSON formats if y
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let routes_str = r#"
|
let routes_str = r#"
|
||||||
[
|
{
|
||||||
{"name": "Image generation", "description": "generating image"},
|
"gpt-4o": [
|
||||||
{"name": "image conversion", "description": "convert images to provided format"},
|
{"name": "Image generation", "description": "generating image"}
|
||||||
{"name": "image search", "description": "search image"},
|
]
|
||||||
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
|
}
|
||||||
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
|
|
||||||
]
|
|
||||||
"#;
|
"#;
|
||||||
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
|
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||||
|
|
||||||
let routing_model = "test-model".to_string();
|
let routing_model = "test-model".to_string();
|
||||||
let router = RouterModelV1::new(llm_routes, routing_model.clone(), 200);
|
let router = RouterModelV1::new(llm_routes, routing_model.clone(), 200);
|
||||||
|
|
||||||
|
|
@ -480,7 +515,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
||||||
You are a helpful assistant designed to find the best suited route.
|
You are a helpful assistant designed to find the best suited route.
|
||||||
You are provided with route description within <routes></routes> XML tags:
|
You are provided with route description within <routes></routes> XML tags:
|
||||||
<routes>
|
<routes>
|
||||||
[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}]
|
[{"name":"Image generation","description":"generating image"}]
|
||||||
</routes>
|
</routes>
|
||||||
|
|
||||||
<conversation>
|
<conversation>
|
||||||
|
|
@ -497,15 +532,13 @@ Based on your analysis, provide your response in the following JSON formats if y
|
||||||
"#;
|
"#;
|
||||||
|
|
||||||
let routes_str = r#"
|
let routes_str = r#"
|
||||||
[
|
{
|
||||||
{"name": "Image generation", "description": "generating image"},
|
"gpt-4o": [
|
||||||
{"name": "image conversion", "description": "convert images to provided format"},
|
{"name": "Image generation", "description": "generating image"}
|
||||||
{"name": "image search", "description": "search image"},
|
]
|
||||||
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
|
}
|
||||||
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
|
|
||||||
]
|
|
||||||
"#;
|
"#;
|
||||||
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
|
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||||
let routing_model = "test-model".to_string();
|
let routing_model = "test-model".to_string();
|
||||||
let router = RouterModelV1::new(llm_routes, routing_model.clone(), 230);
|
let router = RouterModelV1::new(llm_routes, routing_model.clone(), 230);
|
||||||
|
|
||||||
|
|
@ -549,7 +582,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
||||||
You are a helpful assistant designed to find the best suited route.
|
You are a helpful assistant designed to find the best suited route.
|
||||||
You are provided with route description within <routes></routes> XML tags:
|
You are provided with route description within <routes></routes> XML tags:
|
||||||
<routes>
|
<routes>
|
||||||
[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}]
|
[{"name":"Image generation","description":"generating image"}]
|
||||||
</routes>
|
</routes>
|
||||||
|
|
||||||
<conversation>
|
<conversation>
|
||||||
|
|
@ -565,15 +598,13 @@ Based on your analysis, provide your response in the following JSON formats if y
|
||||||
{"route": "route_name"}
|
{"route": "route_name"}
|
||||||
"#;
|
"#;
|
||||||
let routes_str = r#"
|
let routes_str = r#"
|
||||||
[
|
{
|
||||||
{"name": "Image generation", "description": "generating image"},
|
"gpt-4o": [
|
||||||
{"name": "image conversion", "description": "convert images to provided format"},
|
{"name": "Image generation", "description": "generating image"}
|
||||||
{"name": "image search", "description": "search image"},
|
]
|
||||||
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
|
}
|
||||||
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
|
|
||||||
]
|
|
||||||
"#;
|
"#;
|
||||||
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
|
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||||
let routing_model = "test-model".to_string();
|
let routing_model = "test-model".to_string();
|
||||||
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
|
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
|
||||||
|
|
||||||
|
|
@ -619,7 +650,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
||||||
You are a helpful assistant designed to find the best suited route.
|
You are a helpful assistant designed to find the best suited route.
|
||||||
You are provided with route description within <routes></routes> XML tags:
|
You are provided with route description within <routes></routes> XML tags:
|
||||||
<routes>
|
<routes>
|
||||||
[{"name":"Image generation","description":"generating image"},{"name":"image conversion","description":"convert images to provided format"},{"name":"image search","description":"search image"},{"name":"Audio Processing","description":"Analyzing and interpreting audio input including speech, music, and environmental sounds"},{"name":"Speech Recognition","description":"Converting spoken language into written text"}]
|
[{"name":"Image generation","description":"generating image"}]
|
||||||
</routes>
|
</routes>
|
||||||
|
|
||||||
<conversation>
|
<conversation>
|
||||||
|
|
@ -635,15 +666,13 @@ Based on your analysis, provide your response in the following JSON formats if y
|
||||||
{"route": "route_name"}
|
{"route": "route_name"}
|
||||||
"#;
|
"#;
|
||||||
let routes_str = r#"
|
let routes_str = r#"
|
||||||
[
|
{
|
||||||
{"name": "Image generation", "description": "generating image"},
|
"gpt-4o": [
|
||||||
{"name": "image conversion", "description": "convert images to provided format"},
|
{"name": "Image generation", "description": "generating image"}
|
||||||
{"name": "image search", "description": "search image"},
|
]
|
||||||
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
|
}
|
||||||
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
|
|
||||||
]
|
|
||||||
"#;
|
"#;
|
||||||
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
|
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||||
let routing_model = "test-model".to_string();
|
let routing_model = "test-model".to_string();
|
||||||
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
|
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
|
||||||
|
|
||||||
|
|
@ -712,56 +741,54 @@ Based on your analysis, provide your response in the following JSON formats if y
|
||||||
#[test]
|
#[test]
|
||||||
fn test_parse_response() {
|
fn test_parse_response() {
|
||||||
let routes_str = r#"
|
let routes_str = r#"
|
||||||
[
|
{
|
||||||
{"name": "Image generation", "description": "generating image"},
|
"gpt-4o": [
|
||||||
{"name": "image conversion", "description": "convert images to provided format"},
|
{"name": "Image generation", "description": "generating image"}
|
||||||
{"name": "image search", "description": "search image"},
|
]
|
||||||
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
|
}
|
||||||
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
|
"#;
|
||||||
]
|
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
|
||||||
"#;
|
|
||||||
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
|
|
||||||
|
|
||||||
let router = RouterModelV1::new(llm_routes, "test-model".to_string(), 2000);
|
let router = RouterModelV1::new(llm_routes, "test-model".to_string(), 2000);
|
||||||
|
|
||||||
// Case 1: Valid JSON with non-empty route
|
// Case 1: Valid JSON with non-empty route
|
||||||
let input = r#"{"route": "route1"}"#;
|
let input = r#"{"route": "Image generation"}"#;
|
||||||
let result = router.parse_response(input).unwrap();
|
let result = router.parse_response(input, &None).unwrap();
|
||||||
assert_eq!(result, Some("route1".to_string()));
|
assert_eq!(result, Some(("Image generation".to_string(), "gpt-4o".to_string())));
|
||||||
|
|
||||||
// Case 2: Valid JSON with empty route
|
// Case 2: Valid JSON with empty route
|
||||||
let input = r#"{"route": ""}"#;
|
let input = r#"{"route": ""}"#;
|
||||||
let result = router.parse_response(input).unwrap();
|
let result = router.parse_response(input, &None).unwrap();
|
||||||
assert_eq!(result, None);
|
assert_eq!(result, None);
|
||||||
|
|
||||||
// Case 3: Valid JSON with null route
|
// Case 3: Valid JSON with null route
|
||||||
let input = r#"{"route": null}"#;
|
let input = r#"{"route": null}"#;
|
||||||
let result = router.parse_response(input).unwrap();
|
let result = router.parse_response(input, &None).unwrap();
|
||||||
assert_eq!(result, None);
|
assert_eq!(result, None);
|
||||||
|
|
||||||
// Case 4: JSON missing route field
|
// Case 4: JSON missing route field
|
||||||
let input = r#"{}"#;
|
let input = r#"{}"#;
|
||||||
let result = router.parse_response(input).unwrap();
|
let result = router.parse_response(input, &None).unwrap();
|
||||||
assert_eq!(result, None);
|
assert_eq!(result, None);
|
||||||
|
|
||||||
// Case 4.1: empty string
|
// Case 4.1: empty string
|
||||||
let input = r#""#;
|
let input = r#""#;
|
||||||
let result = router.parse_response(input).unwrap();
|
let result = router.parse_response(input, &None).unwrap();
|
||||||
assert_eq!(result, None);
|
assert_eq!(result, None);
|
||||||
|
|
||||||
// Case 5: Malformed JSON
|
// Case 5: Malformed JSON
|
||||||
let input = r#"{"route": "route1""#; // missing closing }
|
let input = r#"{"route": "route1""#; // missing closing }
|
||||||
let result = router.parse_response(input);
|
let result = router.parse_response(input, &None);
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
|
|
||||||
// Case 6: Single quotes and \n in JSON
|
// Case 6: Single quotes and \n in JSON
|
||||||
let input = "{'route': 'route2'}\\n";
|
let input = "{'route': 'Image generation'}\\n";
|
||||||
let result = router.parse_response(input).unwrap();
|
let result = router.parse_response(input, &None).unwrap();
|
||||||
assert_eq!(result, Some("route2".to_string()));
|
assert_eq!(result, Some(("Image generation".to_string(), "gpt-4o".to_string())));
|
||||||
|
|
||||||
// Case 7: Code block marker
|
// Case 7: Code block marker
|
||||||
let input = "```json\n{\"route\": \"route1\"}\n```";
|
let input = "```json\n{\"route\": \"Image generation\"}\n```";
|
||||||
let result = router.parse_response(input).unwrap();
|
let result = router.parse_response(input, &None).unwrap();
|
||||||
assert_eq!(result, Some("route1".to_string()));
|
assert_eq!(result, Some(("Image generation".to_string(), "gpt-4o".to_string())));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue