Add support for updating model preferences (#510)

Adil Hafeez, 2025-07-02 14:08:19 -07:00 (committed by GitHub)
parent 1963020c21
commit 00dc95e034
16 changed files with 437 additions and 53 deletions
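
The companion change to common::configuration is among the 16 changed files but outside the hunks captured below. Judging from how the new test constructs the type, ModelUsagePreference is presumably shaped roughly like this sketch (the derives and comments are assumptions, not the actual definition):

#[derive(Clone, Debug)]
pub struct ModelUsagePreference {
    // Route name; matched against the name the arch-router model selects.
    pub name: String,
    // Model to use when this route wins, e.g. "claude/claude-3-7-sonnet".
    pub model: String,
    // Optional usage description; when present it replaces the route
    // description configured on the provider.
    pub usage: Option<String>,
}

When a request carries these preferences, they override both the route descriptions sent to the router model and the model chosen once a route matches, as the hunks below show.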

@@ -1,7 +1,7 @@
-use std::sync::Arc;
+use std::{collections::HashMap, sync::Arc};

 use common::{
-    configuration::{LlmProvider, LlmRoute},
+    configuration::{LlmProvider, LlmRoute, ModelUsagePreference},
     consts::ARCH_PROVIDER_HINT_HEADER,
 };
 use hermesllm::providers::openai::types::{ChatCompletionsResponse, ContentType, Message};
@@ -19,6 +19,7 @@ pub struct RouterService {
     router_model: Arc<dyn RouterModel>,
     routing_model_name: String,
     llm_usage_defined: bool,
+    llm_provider_map: HashMap<String, LlmProvider>,
 }

 #[derive(Debug, Error)]
@@ -55,12 +56,18 @@ impl RouterService {
             router_model_v1::MAX_TOKEN_LEN,
         ));

+        let llm_provider_map: HashMap<String, LlmProvider> = providers
+            .into_iter()
+            .map(|provider| (provider.name.clone(), provider))
+            .collect();
+
         RouterService {
             router_url,
             client: reqwest::Client::new(),
             router_model,
             routing_model_name,
             llm_usage_defined: !providers_with_usage.is_empty(),
+            llm_provider_map,
         }
     }
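
A side note on the constructor hunk above: collecting (name, provider) pairs into a HashMap keeps the last pair for a duplicated key, so two providers configured with the same name would silently shadow each other in llm_provider_map. A standalone illustration of that collect() behavior (names and values are hypothetical):

use std::collections::HashMap;

fn main() {
    // The later "gpt-4o" pair overwrites the earlier one during collect().
    let pairs = vec![("gpt-4o", 1), ("claude", 2), ("gpt-4o", 3)];
    let map: HashMap<&str, i32> = pairs.into_iter().collect();
    assert_eq!(map["gpt-4o"], 3);
    assert_eq!(map.len(), 2);
}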
@@ -68,12 +75,15 @@
         &self,
         messages: &[Message],
         trace_parent: Option<String>,
+        usage_preferences: Option<Vec<ModelUsagePreference>>,
     ) -> Result<Option<String>> {
         if !self.llm_usage_defined {
             return Ok(None);
         }

-        let router_request = self.router_model.generate_request(messages);
+        let router_request = self
+            .router_model
+            .generate_request(messages, &usage_preferences);

         info!(
             "sending request to arch-router model: {}, endpoint: {}",
@@ -144,13 +154,40 @@ impl RouterService {
         if let Some(ContentType::Text(content)) =
             &chat_completion_response.choices[0].message.content
         {
+            let mut selected_model: Option<String> = None;
+            if let Some(selected_llm_name) = self.router_model.parse_response(content)? {
+                if selected_llm_name != "other" {
+                    if let Some(usage_preferences) = usage_preferences {
+                        for usage in usage_preferences {
+                            if usage.name == selected_llm_name {
+                                selected_model = Some(usage.model);
+                                break;
+                            }
+                        }
+                        if selected_model.is_none() {
+                            warn!(
+                                "Selected LLM model not found in usage preferences: {}",
+                                selected_llm_name
+                            );
+                        }
+                    } else if let Some(provider) = self.llm_provider_map.get(&selected_llm_name) {
+                        selected_model = provider.model.clone();
+                    } else {
+                        warn!(
+                            "Selected LLM model not found in provider map: {}",
+                            selected_llm_name
+                        );
+                    }
+                }
+            }
+
             info!(
-                "router response: {}, response time: {}ms",
+                "router response: {}, selected_model: {:?}, response time: {}ms",
                 content.replace("\n", "\\n"),
+                selected_model,
                 router_response_time.as_millis()
             );
-            let selected_llm = self.router_model.parse_response(content)?;
-            Ok(selected_llm)
+            Ok(selected_model)
         } else {
             Ok(None)
         }
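
Taken together, the block above gives per-request usage preferences priority over the statically configured providers and treats the router's "other" answer as no selection. The same precedence, restated as a standalone sketch (the function name is hypothetical; ModelUsagePreference and LlmProvider are the types from common::configuration, with LlmProvider::model being an Option<String> as the hunk implies):

use std::collections::HashMap;

use common::configuration::{LlmProvider, ModelUsagePreference};

// Resolution order: request-scoped preferences win; otherwise fall back to the
// configured provider map; "other" means the router matched no route.
fn resolve_model(
    route: &str,
    prefs: Option<&[ModelUsagePreference]>,
    providers: &HashMap<String, LlmProvider>,
) -> Option<String> {
    if route == "other" {
        return None;
    }
    match prefs {
        Some(prefs) => prefs.iter().find(|p| p.name == route).map(|p| p.model.clone()),
        None => providers.get(route).and_then(|p| p.model.clone()),
    }
}

One behavioral consequence worth noting: when preferences are supplied, the configured provider map is never consulted, so a preference list that omits the selected route name yields no model at all (the warn! branch above) rather than falling back to the provider's default.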

@@ -1,3 +1,4 @@
+use common::configuration::ModelUsagePreference;
 use hermesllm::providers::openai::types::{ChatCompletionsRequest, Message};
 use thiserror::Error;
@@ -10,7 +11,11 @@ pub enum RoutingModelError {
 pub type Result<T> = std::result::Result<T, RoutingModelError>;

 pub trait RouterModel: Send + Sync {
-    fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest;
+    fn generate_request(
+        &self,
+        messages: &[Message],
+        usage_preferences: &Option<Vec<ModelUsagePreference>>,
+    ) -> ChatCompletionsRequest;
     fn parse_response(&self, content: &str) -> Result<Option<String>>;
     fn get_model_name(&self) -> String;
 }
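
Because the trait now takes the preferences by reference, RouterService can keep ownership of the Option<Vec<ModelUsagePreference>> for the model-resolution step that follows, and existing call sites simply pass &None, as the updated tests below do. Two hypothetical call shapes (function names are illustrative):

use common::configuration::ModelUsagePreference;
use hermesllm::providers::openai::types::{ChatCompletionsRequest, Message};

// Caller with no per-request preferences opts out with &None.
fn request_without_prefs(model: &dyn RouterModel, messages: &[Message]) -> ChatCompletionsRequest {
    model.generate_request(messages, &None)
}

// Caller threading request-scoped preferences through by reference.
fn request_with_prefs(
    model: &dyn RouterModel,
    messages: &[Message],
    prefs: &Option<Vec<ModelUsagePreference>>,
) -> ChatCompletionsRequest {
    model.generate_request(messages, prefs)
}

Option<&[ModelUsagePreference]> would be the more flexible signature for borrowed slices, but &Option<Vec<_>> matches how the value already flows through RouterService.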

@@ -1,5 +1,5 @@
 use common::{
-    configuration::LlmRoute,
+    configuration::{LlmRoute, ModelUsagePreference},
     consts::{SYSTEM_ROLE, TOOL_ROLE, USER_ROLE},
 };
 use hermesllm::providers::openai::types::{ChatCompletionsRequest, ContentType, Message};
@@ -55,7 +55,11 @@ struct LlmRouterResponse {
 const TOKEN_LENGTH_DIVISOR: usize = 4; // Approximate token length divisor for UTF-8 characters

 impl RouterModel for RouterModelV1 {
-    fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest {
+    fn generate_request(
+        &self,
+        messages: &[Message],
+        usage_preferences: &Option<Vec<ModelUsagePreference>>,
+    ) -> ChatCompletionsRequest {
         // remove system prompt, tool calls, tool call response and messages without content
         // if content is empty its likely a tool call
         // when role == tool its tool call response
@@ -131,8 +135,22 @@ impl RouterModel for RouterModelV1 {
             })
             .collect::<Vec<Message>>();

+        let llm_route_json = usage_preferences
+            .as_ref()
+            .map(|prefs| {
+                let llm_route: Vec<LlmRoute> = prefs
+                    .iter()
+                    .map(|pref| LlmRoute {
+                        name: pref.name.clone(),
+                        description: pref.usage.clone().unwrap_or_default(),
+                    })
+                    .collect();
+                serde_json::to_string(&llm_route).unwrap_or_default()
+            })
+            .unwrap_or_else(|| self.llm_route_json_str.clone());
+
         let messages_content = ARCH_ROUTER_V1_SYSTEM_PROMPT
-            .replace("{routes}", &self.llm_route_json_str)
+            .replace("{routes}", &llm_route_json)
             .replace(
                 "{conversation}",
                 &serde_json::to_string(&selected_conversation_list).unwrap_or_default(),
@@ -204,8 +222,6 @@ impl std::fmt::Debug for dyn RouterModel {
 #[cfg(test)]
 mod tests {
-    use crate::utils::tracing::init_tracer;
-
     use super::*;
     use pretty_assertions::assert_eq;
@@ -261,7 +277,71 @@ Based on your analysis, provide your response in the following JSON formats if y
 "#;
         let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
-        let req = router.generate_request(&conversation);
+        let req = router.generate_request(&conversation, &None);
         let prompt = req.messages[0].content.as_ref().unwrap();
         assert_eq!(expected_prompt, prompt.to_string());
     }
+
+    #[test]
+    fn test_system_prompt_format_usage_preferences() {
+        let expected_prompt = r#"
+You are a helpful assistant designed to find the best suited route.
+You are provided with route description within <routes></routes> XML tags:
+<routes>
+[{"name":"code-generation","description":"generating new code snippets, functions, or boilerplate based on user prompts or requirements"}]
+</routes>
+<conversation>
+[{"role":"user","content":"hi"},{"role":"assistant","content":"Hello! How can I assist you today?"},{"role":"user","content":"given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"}]
+</conversation>
+Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
+1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
+2. You must analyze the route descriptions and find the best match route for user latest intent.
+3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
+Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
+{"route": "route_name"}
+"#;
+
+        let routes_str = r#"
+[
+    {"name": "Image generation", "description": "generating image"},
+    {"name": "image conversion", "description": "convert images to provided format"},
+    {"name": "image search", "description": "search image"},
+    {"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
+    {"name": "Speech Recognition", "description": "Converting spoken language into written text"}
+]
+"#;
+
+        let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
+        let routing_model = "test-model".to_string();
+        let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
+
+        let conversation_str = r#"
+[
+    {
+        "role": "user",
+        "content": "hi"
+    },
+    {
+        "role": "assistant",
+        "content": "Hello! How can I assist you today?"
+    },
+    {
+        "role": "user",
+        "content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"
+    }
+]
+"#;
+        let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
+
+        let usage_preferences = Some(vec![ModelUsagePreference {
+            name: "code-generation".to_string(),
+            model: "claude/claude-3-7-sonnet".to_string(),
+            usage: Some("generating new code snippets, functions, or boilerplate based on user prompts or requirements".to_string()),
+        }]);
+
+        let req = router.generate_request(&conversation, &usage_preferences);
+        let prompt = req.messages[0].content.as_ref().unwrap();
@@ -270,7 +350,6 @@ Based on your analysis, provide your response in the following JSON formats if y
     #[test]
     fn test_conversation_exceed_token_count() {
-        let _tracer = init_tracer();
         let expected_prompt = r#"
 You are a helpful assistant designed to find the best suited route.
 You are provided with route description within <routes></routes> XML tags:
@@ -323,7 +402,7 @@ Based on your analysis, provide your response in the following JSON formats if y
         let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
-        let req = router.generate_request(&conversation);
+        let req = router.generate_request(&conversation, &None);

         let prompt = req.messages[0].content.as_ref().unwrap();
@@ -332,7 +411,6 @@ Based on your analysis, provide your response in the following JSON formats if y
     #[test]
     fn test_conversation_exceed_token_count_large_single_message() {
-        let _tracer = init_tracer();
         let expected_prompt = r#"
 You are a helpful assistant designed to find the best suited route.
 You are provided with route description within <routes></routes> XML tags:
@@ -385,7 +463,7 @@ Based on your analysis, provide your response in the following JSON formats if y
         let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
-        let req = router.generate_request(&conversation);
+        let req = router.generate_request(&conversation, &None);

         let prompt = req.messages[0].content.as_ref().unwrap();
@@ -394,7 +472,6 @@ Based on your analysis, provide your response in the following JSON formats if y
     #[test]
     fn test_conversation_trim_upto_user_message() {
-        let _tracer = init_tracer();
         let expected_prompt = r#"
 You are a helpful assistant designed to find the best suited route.
 You are provided with route description within <routes></routes> XML tags:
@@ -455,7 +532,7 @@ Based on your analysis, provide your response in the following JSON formats if y
         let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
-        let req = router.generate_request(&conversation);
+        let req = router.generate_request(&conversation, &None);

         let prompt = req.messages[0].content.as_ref().unwrap();
@@ -525,7 +602,7 @@ Based on your analysis, provide your response in the following JSON formats if y
 "#;
         let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
-        let req = router.generate_request(&conversation);
+        let req = router.generate_request(&conversation, &None);

         let prompt = req.messages[0].content.as_ref().unwrap();
@@ -621,7 +698,7 @@ Based on your analysis, provide your response in the following JSON formats if y
         let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
-        let req = router.generate_request(&conversation);
+        let req = router.generate_request(&conversation, &None);

         let prompt = req.messages[0].content.as_ref().unwrap();