mirror of
https://github.com/katanemo/plano.git
synced 2026-05-07 23:02:43 +02:00
Add support for updating model preferences (#510)
This commit is contained in:
parent
1963020c21
commit
00dc95e034
16 changed files with 437 additions and 53 deletions
|
|
@ -1,7 +1,7 @@
|
|||
use std::sync::Arc;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use common::{
|
||||
configuration::{LlmProvider, LlmRoute},
|
||||
configuration::{LlmProvider, LlmRoute, ModelUsagePreference},
|
||||
consts::ARCH_PROVIDER_HINT_HEADER,
|
||||
};
|
||||
use hermesllm::providers::openai::types::{ChatCompletionsResponse, ContentType, Message};
|
||||
|
|
@ -19,6 +19,7 @@ pub struct RouterService {
|
|||
router_model: Arc<dyn RouterModel>,
|
||||
routing_model_name: String,
|
||||
llm_usage_defined: bool,
|
||||
llm_provider_map: HashMap<String, LlmProvider>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
|
|
@ -55,12 +56,18 @@ impl RouterService {
|
|||
router_model_v1::MAX_TOKEN_LEN,
|
||||
));
|
||||
|
||||
let llm_provider_map: HashMap<String, LlmProvider> = providers
|
||||
.into_iter()
|
||||
.map(|provider| (provider.name.clone(), provider))
|
||||
.collect();
|
||||
|
||||
RouterService {
|
||||
router_url,
|
||||
client: reqwest::Client::new(),
|
||||
router_model,
|
||||
routing_model_name,
|
||||
llm_usage_defined: !providers_with_usage.is_empty(),
|
||||
llm_provider_map,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -68,12 +75,15 @@ impl RouterService {
|
|||
&self,
|
||||
messages: &[Message],
|
||||
trace_parent: Option<String>,
|
||||
usage_preferences: Option<Vec<ModelUsagePreference>>,
|
||||
) -> Result<Option<String>> {
|
||||
if !self.llm_usage_defined {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let router_request = self.router_model.generate_request(messages);
|
||||
let router_request = self
|
||||
.router_model
|
||||
.generate_request(messages, &usage_preferences);
|
||||
|
||||
info!(
|
||||
"sending request to arch-router model: {}, endpoint: {}",
|
||||
|
|
@ -144,13 +154,40 @@ impl RouterService {
|
|||
if let Some(ContentType::Text(content)) =
|
||||
&chat_completion_response.choices[0].message.content
|
||||
{
|
||||
let mut selected_model: Option<String> = None;
|
||||
if let Some(selected_llm_name) = self.router_model.parse_response(content)? {
|
||||
if selected_llm_name != "other" {
|
||||
if let Some(usage_preferences) = usage_preferences {
|
||||
for usage in usage_preferences {
|
||||
if usage.name == selected_llm_name {
|
||||
selected_model = Some(usage.model);
|
||||
break;
|
||||
}
|
||||
}
|
||||
if selected_model.is_none() {
|
||||
warn!(
|
||||
"Selected LLM model not found in usage preferences: {}",
|
||||
selected_llm_name
|
||||
);
|
||||
}
|
||||
} else if let Some(provider) = self.llm_provider_map.get(&selected_llm_name) {
|
||||
selected_model = provider.model.clone();
|
||||
} else {
|
||||
warn!(
|
||||
"Selected LLM model not found in provider map: {}",
|
||||
selected_llm_name
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
info!(
|
||||
"router response: {}, response time: {}ms",
|
||||
"router response: {}, selected_model: {:?}, response time: {}ms",
|
||||
content.replace("\n", "\\n"),
|
||||
selected_model,
|
||||
router_response_time.as_millis()
|
||||
);
|
||||
let selected_llm = self.router_model.parse_response(content)?;
|
||||
Ok(selected_llm)
|
||||
|
||||
Ok(selected_model)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
use common::configuration::ModelUsagePreference;
|
||||
use hermesllm::providers::openai::types::{ChatCompletionsRequest, Message};
|
||||
use thiserror::Error;
|
||||
|
||||
|
|
@ -10,7 +11,11 @@ pub enum RoutingModelError {
|
|||
pub type Result<T> = std::result::Result<T, RoutingModelError>;
|
||||
|
||||
pub trait RouterModel: Send + Sync {
|
||||
fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest;
|
||||
fn generate_request(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
usage_preferences: &Option<Vec<ModelUsagePreference>>,
|
||||
) -> ChatCompletionsRequest;
|
||||
fn parse_response(&self, content: &str) -> Result<Option<String>>;
|
||||
fn get_model_name(&self) -> String;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
use common::{
|
||||
configuration::LlmRoute,
|
||||
configuration::{LlmRoute, ModelUsagePreference},
|
||||
consts::{SYSTEM_ROLE, TOOL_ROLE, USER_ROLE},
|
||||
};
|
||||
use hermesllm::providers::openai::types::{ChatCompletionsRequest, ContentType, Message};
|
||||
|
|
@ -55,7 +55,11 @@ struct LlmRouterResponse {
|
|||
const TOKEN_LENGTH_DIVISOR: usize = 4; // Approximate token length divisor for UTF-8 characters
|
||||
|
||||
impl RouterModel for RouterModelV1 {
|
||||
fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest {
|
||||
fn generate_request(
|
||||
&self,
|
||||
messages: &[Message],
|
||||
usage_preferences: &Option<Vec<ModelUsagePreference>>,
|
||||
) -> ChatCompletionsRequest {
|
||||
// remove system prompt, tool calls, tool call response and messages without content
|
||||
// if content is empty its likely a tool call
|
||||
// when role == tool its tool call response
|
||||
|
|
@ -131,8 +135,22 @@ impl RouterModel for RouterModelV1 {
|
|||
})
|
||||
.collect::<Vec<Message>>();
|
||||
|
||||
let llm_route_json = usage_preferences
|
||||
.as_ref()
|
||||
.map(|prefs| {
|
||||
let llm_route: Vec<LlmRoute> = prefs
|
||||
.iter()
|
||||
.map(|pref| LlmRoute {
|
||||
name: pref.name.clone(),
|
||||
description: pref.usage.clone().unwrap_or_default(),
|
||||
})
|
||||
.collect();
|
||||
serde_json::to_string(&llm_route).unwrap_or_default()
|
||||
})
|
||||
.unwrap_or_else(|| self.llm_route_json_str.clone());
|
||||
|
||||
let messages_content = ARCH_ROUTER_V1_SYSTEM_PROMPT
|
||||
.replace("{routes}", &self.llm_route_json_str)
|
||||
.replace("{routes}", &llm_route_json)
|
||||
.replace(
|
||||
"{conversation}",
|
||||
&serde_json::to_string(&selected_conversation_list).unwrap_or_default(),
|
||||
|
|
@ -204,8 +222,6 @@ impl std::fmt::Debug for dyn RouterModel {
|
|||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::utils::tracing::init_tracer;
|
||||
|
||||
use super::*;
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
|
|
@ -261,7 +277,71 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
"#;
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let req = router.generate_request(&conversation);
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.as_ref().unwrap();
|
||||
|
||||
assert_eq!(expected_prompt, prompt.to_string());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_system_prompt_format_usage_preferences() {
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
<routes>
|
||||
[{"name":"code-generation","description":"generating new code snippets, functions, or boilerplate based on user prompts or requirements"}]
|
||||
</routes>
|
||||
|
||||
<conversation>
|
||||
[{"role":"user","content":"hi"},{"role":"assistant","content":"Hello! How can I assist you today?"},{"role":"user","content":"given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"}]
|
||||
</conversation>
|
||||
|
||||
Your task is to decide which route is best suit with user intent on the conversation in <conversation></conversation> XML tags. Follow the instruction:
|
||||
1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}.
|
||||
2. You must analyze the route descriptions and find the best match route for user latest intent.
|
||||
3. You only response the name of the route that best matches the user's request, use the exact name in the <routes></routes>.
|
||||
|
||||
Based on your analysis, provide your response in the following JSON formats if you decide to match any route:
|
||||
{"route": "route_name"}
|
||||
"#;
|
||||
let routes_str = r#"
|
||||
[
|
||||
{"name": "Image generation", "description": "generating image"},
|
||||
{"name": "image conversion", "description": "convert images to provided format"},
|
||||
{"name": "image search", "description": "search image"},
|
||||
{"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"},
|
||||
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
|
||||
]
|
||||
"#;
|
||||
let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
|
||||
let routing_model = "test-model".to_string();
|
||||
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
|
||||
|
||||
let conversation_str = r#"
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi"
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I assist you today?"
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"
|
||||
}
|
||||
]
|
||||
"#;
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let usage_preferences = Some(vec![ModelUsagePreference {
|
||||
name: "code-generation".to_string(),
|
||||
model: "claude/claude-3-7-sonnet".to_string(),
|
||||
usage: Some("generating new code snippets, functions, or boilerplate based on user prompts or requirements".to_string()),
|
||||
}]);
|
||||
let req = router.generate_request(&conversation, &usage_preferences);
|
||||
|
||||
let prompt = req.messages[0].content.as_ref().unwrap();
|
||||
|
||||
|
|
@ -270,7 +350,6 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
|
||||
#[test]
|
||||
fn test_conversation_exceed_token_count() {
|
||||
let _tracer = init_tracer();
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
|
|
@ -323,7 +402,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let req = router.generate_request(&conversation);
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.as_ref().unwrap();
|
||||
|
||||
|
|
@ -332,7 +411,6 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
|
||||
#[test]
|
||||
fn test_conversation_exceed_token_count_large_single_message() {
|
||||
let _tracer = init_tracer();
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
|
|
@ -385,7 +463,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let req = router.generate_request(&conversation);
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.as_ref().unwrap();
|
||||
|
||||
|
|
@ -394,7 +472,6 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
|
||||
#[test]
|
||||
fn test_conversation_trim_upto_user_message() {
|
||||
let _tracer = init_tracer();
|
||||
let expected_prompt = r#"
|
||||
You are a helpful assistant designed to find the best suited route.
|
||||
You are provided with route description within <routes></routes> XML tags:
|
||||
|
|
@ -455,7 +532,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let req = router.generate_request(&conversation);
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.as_ref().unwrap();
|
||||
|
||||
|
|
@ -525,7 +602,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
"#;
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let req = router.generate_request(&conversation);
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.as_ref().unwrap();
|
||||
|
||||
|
|
@ -621,7 +698,7 @@ Based on your analysis, provide your response in the following JSON formats if y
|
|||
|
||||
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
|
||||
|
||||
let req = router.generate_request(&conversation);
|
||||
let req = router.generate_request(&conversation, &None);
|
||||
|
||||
let prompt = req.messages[0].content.as_ref().unwrap();
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue