add more changes

This commit is contained in:
Adil Hafeez 2025-06-25 13:44:15 -07:00
parent 4373aeb00b
commit 6ba7140f62
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
6 changed files with 128 additions and 18 deletions

View file

@ -12,7 +12,7 @@ use hyper::{Request, Response, StatusCode};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
use tracing::{debug, info, warn};
use tracing::{debug, info, trace, warn};
use crate::router::llm_router::RouterService;
@ -47,7 +47,7 @@ pub async fn chat_completions(
}
};
debug!(
trace!(
"arch-router request body: {}",
&serde_json::to_string(&chat_completion_request).unwrap()
);

View file

@ -14,7 +14,8 @@ pub async fn list_preferences(
let providers_with_usage = prov
.iter()
.map(|provider| ModelUsagePreference {
model: provider.name.clone(),
name: provider.name.clone(),
model: provider.model.clone().unwrap_or_default(),
usage: provider.usage.clone(),
})
.collect::<Vec<ModelUsagePreference>>();
@ -101,7 +102,8 @@ pub async fn update_preferences(
if let Some(usage_provider) = usage_model_map.get(&provider.name) {
provider.usage = usage_provider.usage.clone();
updated_models_list.push(ModelUsagePreference {
model: provider.name.clone(),
name: provider.name.clone(),
model: provider.model.clone().unwrap_or_default(),
usage: provider.usage.clone(),
});
}

View file

@ -1,4 +1,4 @@
use std::sync::Arc;
use std::{collections::HashMap, sync::Arc};
use common::{
configuration::{LlmProvider, LlmRoute, ModelUsagePreference},
@ -19,6 +19,7 @@ pub struct RouterService {
router_model: Arc<dyn RouterModel>,
routing_model_name: String,
llm_usage_defined: bool,
llm_provider_map: HashMap<String, LlmProvider>,
}
#[derive(Debug, Error)]
@ -55,12 +56,18 @@ impl RouterService {
router_model_v1::MAX_TOKEN_LEN,
));
let llm_provider_map: HashMap<String, LlmProvider> = providers
.into_iter()
.map(|provider| (provider.name.clone(), provider))
.collect();
RouterService {
router_url,
client: reqwest::Client::new(),
router_model,
routing_model_name,
llm_usage_defined: !providers_with_usage.is_empty(),
llm_provider_map,
}
}
@ -76,7 +83,7 @@ impl RouterService {
let router_request = self
.router_model
.generate_request(messages, usage_preferences);
.generate_request(messages, &usage_preferences);
info!(
"sending request to arch-router model: {}, endpoint: {}",
@ -147,13 +154,40 @@ impl RouterService {
if let Some(ContentType::Text(content)) =
&chat_completion_response.choices[0].message.content
{
let mut selected_model: Option<String> = None;
if let Some(selected_llm_name) = self.router_model.parse_response(content)? {
if selected_llm_name != "other" {
if let Some(usage_preferences) = usage_preferences {
for usage in usage_preferences {
if usage.name == selected_llm_name {
selected_model = Some(usage.model);
break;
}
}
if selected_model.is_none() {
warn!(
"Selected LLM model not found in usage preferences: {}",
selected_llm_name
);
}
} else if let Some(provider) = self.llm_provider_map.get(&selected_llm_name) {
selected_model = provider.model.clone();
} else {
warn!(
"Selected LLM model not found in provider map: {}",
selected_llm_name
);
}
}
}
info!(
"router response: {}, response time: {}ms",
"router response: {}, selected_model: {:?}, response time: {}ms",
content.replace("\n", "\\n"),
selected_model,
router_response_time.as_millis()
);
let selected_llm = self.router_model.parse_response(content)?;
Ok(selected_llm)
Ok(selected_model)
} else {
Ok(None)
}

View file

@ -14,7 +14,7 @@ pub trait RouterModel: Send + Sync {
fn generate_request(
&self,
messages: &[Message],
usage_preferences: Option<Vec<ModelUsagePreference>>,
usage_preferences: &Option<Vec<ModelUsagePreference>>,
) -> ChatCompletionsRequest;
fn parse_response(&self, content: &str) -> Result<Option<String>>;
fn get_model_name(&self) -> String;

View file

@ -58,7 +58,7 @@ impl RouterModel for RouterModelV1 {
fn generate_request(
&self,
messages: &[Message],
usage_preferences: Option<Vec<ModelUsagePreference>>,
usage_preferences: &Option<Vec<ModelUsagePreference>>,
) -> ChatCompletionsRequest {
// remove system prompt, tool calls, tool call response and messages without content
// if content is empty its likely a tool call
@ -137,7 +137,16 @@ impl RouterModel for RouterModelV1 {
let llm_route_json = usage_preferences
.as_ref()
.map(|prefs| serde_json::to_string(prefs).unwrap_or_default())
.map(|prefs| {
let llm_route: Vec<LlmRoute> = prefs
.iter()
.map(|pref| LlmRoute {
name: pref.name.clone(),
description: pref.usage.clone().unwrap_or_default(),
})
.collect();
serde_json::to_string(&llm_route).unwrap_or_default()
})
.unwrap_or_else(|| self.llm_route_json_str.clone());
let messages_content = ARCH_ROUTER_V1_SYSTEM_PROMPT
@ -268,7 +277,71 @@ Based on your analysis, provide your response in the following JSON formats if y
"#;
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let req = router.generate_request(&conversation, None);
let req = router.generate_request(&conversation, &None);
let prompt = req.messages[0].content.as_ref().unwrap();
assert_eq!(expected_prompt, prompt.to_string());
}
#[test]
fn test_system_prompt_format_usage_preferences() {
let expected_prompt = r#"
You are a helpful assistant designed to find the best suited route.
You are provided with route description within <routes></routes> XML tags:
<routes>
[{"name":"code-generation","description":"generating new code snippets, functions, or boilerplate based on user prompts or requirements"}]
</routes>
<conversation>
[{"role":"user","content":"hi"},{"role":"assistant","content":"Hello! How can I assist you today?"},{"role":"user","content":"given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"}]
</conversation>
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
2. You must analyze the route descriptions and find the best match route for user latest intent.
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
{"route": "route_name"}
"#;
let routes_str = r#"
[
{"name": "Image generation", "description": "generating image"},
{"name": "image conversion", "description": "convert images to provided format"},
{"name": "image search", "description": "search image"},
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
]
"#;
let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
let conversation_str = r#"
[
{
"role": "user",
"content": "hi"
},
{
"role": "assistant",
"content": "Hello! How can I assist you today?"
},
{
"role": "user",
"content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"
}
]
"#;
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let usage_preferences = Some(vec![ModelUsagePreference {
name: "code-generation".to_string(),
model: "claude/claude-3-7-sonnet".to_string(),
usage: Some("generating new code snippets, functions, or boilerplate based on user prompts or requirements".to_string()),
}]);
let req = router.generate_request(&conversation, &usage_preferences);
let prompt = req.messages[0].content.as_ref().unwrap();
@ -329,7 +402,7 @@ Based on your analysis, provide your response in the following JSON formats if y
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let req = router.generate_request(&conversation, None);
let req = router.generate_request(&conversation, &None);
let prompt = req.messages[0].content.as_ref().unwrap();
@ -390,7 +463,7 @@ Based on your analysis, provide your response in the following JSON formats if y
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let req = router.generate_request(&conversation, None);
let req = router.generate_request(&conversation, &None);
let prompt = req.messages[0].content.as_ref().unwrap();
@ -459,7 +532,7 @@ Based on your analysis, provide your response in the following JSON formats if y
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let req = router.generate_request(&conversation, None);
let req = router.generate_request(&conversation, &None);
let prompt = req.messages[0].content.as_ref().unwrap();
@ -529,7 +602,7 @@ Based on your analysis, provide your response in the following JSON formats if y
"#;
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let req = router.generate_request(&conversation, None);
let req = router.generate_request(&conversation, &None);
let prompt = req.messages[0].content.as_ref().unwrap();
@ -625,7 +698,7 @@ Based on your analysis, provide your response in the following JSON formats if y
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let req = router.generate_request(&conversation, None);
let req = router.generate_request(&conversation, &None);
let prompt = req.messages[0].content.as_ref().unwrap();

View file

@ -180,6 +180,7 @@ impl Display for LlmProviderType {
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug)]
pub struct ModelUsagePreference {
pub name: String,
pub model: String,
pub usage: Option<String>,
}