Mirror of https://github.com/katanemo/plano.git (synced 2026-05-08 15:22:43 +02:00)

better model names (#517)

commit a7fddf30f9 (parent 4e2355965b)
55 changed files with 979 additions and 483 deletions
@@ -1,7 +1,7 @@
-use std::{collections::HashMap, sync::Arc};
+use std::sync::Arc;
 
 use common::{
-    configuration::{LlmProvider, LlmRoute, ModelUsagePreference},
+    configuration::{LlmProvider, ModelUsagePreference, RoutingPreference},
     consts::ARCH_PROVIDER_HINT_HEADER,
 };
 use hermesllm::providers::openai::types::{ChatCompletionsResponse, ContentType, Message};
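For reference, the shape of the RoutingPreference configuration type can be inferred from how this commit uses it: the tests below deserialize it from JSON objects with name and description keys, RouterModelV1::new serializes a list of them with serde_json, and the prompt builder constructs one from a name plus a description. A minimal sketch, assuming the real definition in common::configuration has exactly these two fields:

    use serde::{Deserialize, Serialize};

    // Minimal sketch inferred from this diff; the actual definition lives in
    // the `common` crate's configuration module and may carry more derives.
    #[derive(Debug, Clone, Default, Serialize, Deserialize)]
    pub struct RoutingPreference {
        pub name: String,        // route label the router model picks, e.g. "Speech Recognition"
        pub description: String, // natural-language description of when to use the route
    }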
@@ -19,7 +19,6 @@ pub struct RouterService {
     router_model: Arc<dyn RouterModel>,
     routing_provider_name: String,
     llm_usage_defined: bool,
-    llm_provider_map: HashMap<String, LlmProvider>,
 }
 
 #[derive(Debug, Error)]
@@ -45,11 +44,14 @@ impl RouterService {
     ) -> Self {
         let providers_with_usage = providers
             .iter()
-            .filter(|provider| provider.usage.is_some())
+            .filter(|provider| provider.routing_preferences.is_some())
             .cloned()
             .collect::<Vec<LlmProvider>>();
 
-        let llm_routes: Vec<LlmRoute> = providers_with_usage.iter().map(LlmRoute::from).collect();
+        let llm_routes: Vec<RoutingPreference> = providers_with_usage
+            .iter()
+            .flat_map(|provider| provider.routing_preferences.clone().unwrap_or_default())
+            .collect();
 
         let router_model = Arc::new(router_model_v1::RouterModelV1::new(
             llm_routes,
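The constructor now flattens every provider's routing preferences into a single list instead of mapping each provider to one LlmRoute. A self-contained sketch of that collection step, with stand-in types (the real LlmProvider has more fields):

    // Stand-in types; the real ones live in common::configuration.
    #[derive(Clone, Default)]
    struct RoutingPreference {
        name: String,
        description: String,
    }

    #[derive(Clone)]
    struct LlmProvider {
        routing_preferences: Option<Vec<RoutingPreference>>,
    }

    // Mirrors the change above: keep only providers that declare routing
    // preferences, then flatten all their preference lists into one Vec.
    fn collect_routes(providers: &[LlmProvider]) -> Vec<RoutingPreference> {
        providers
            .iter()
            .filter(|p| p.routing_preferences.is_some())
            .flat_map(|p| p.routing_preferences.clone().unwrap_or_default())
            .collect()
    }

One consequence of flat_map plus unwrap_or_default is that a provider with routing_preferences: Some(vec![]) contributes nothing, the same as one with None.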
@@ -57,18 +59,12 @@ impl RouterService {
             router_model_v1::MAX_TOKEN_LEN,
         ));
 
-        let llm_provider_map: HashMap<String, LlmProvider> = providers
-            .into_iter()
-            .map(|provider| (provider.name.clone(), provider))
-            .collect();
-
         RouterService {
             router_url,
             client: reqwest::Client::new(),
             router_model,
             routing_provider_name,
             llm_usage_defined: !providers_with_usage.is_empty(),
-            llm_provider_map,
         }
     }
 
@@ -155,40 +151,21 @@ impl RouterService {
         if let Some(ContentType::Text(content)) =
             &chat_completion_response.choices[0].message.content
         {
-            let mut selected_model: Option<String> = None;
-            if let Some(selected_llm_name) = self.router_model.parse_response(content)? {
-                if selected_llm_name != "other" {
-                    if let Some(usage_preferences) = usage_preferences {
-                        for usage in usage_preferences {
-                            if usage.name == selected_llm_name {
-                                selected_model = Some(usage.model);
-                                break;
-                            }
-                        }
-                        if selected_model.is_none() {
-                            warn!(
-                                "Selected LLM model not found in usage preferences: {}",
-                                selected_llm_name
-                            );
-                        }
-                    } else if let Some(provider) = self.llm_provider_map.get(&selected_llm_name) {
-                        selected_model = provider.model.clone();
-                    } else {
-                        warn!(
-                            "Selected LLM model not found in provider map: {}",
-                            selected_llm_name
-                        );
-                    }
-                }
-            }
+            let route_name = self.router_model.parse_response(content)?;
             info!(
                 "router response: {}, selected_model: {:?}, response time: {}ms",
                 content.replace("\n", "\\n"),
-                selected_model,
+                route_name,
                 router_response_time.as_millis()
             );
 
-            Ok(selected_model)
+            if let Some(ref route) = route_name {
+                if route == "other" {
+                    return Ok(None);
+                }
+            }
+
+            Ok(route_name)
         } else {
             Ok(None)
         }
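With this hunk the service stops resolving a concrete model itself: it returns the parsed route name and only normalizes the sentinel label "other" (the router model's answer when no route fits) to None. The new control flow in isolation, as a runnable sketch:

    // Sketch of the simplified selection logic above: pass the parsed route
    // name through unchanged, except that the "other" sentinel becomes None.
    fn resolve_route(route_name: Option<String>) -> Option<String> {
        if let Some(ref route) = route_name {
            if route == "other" {
                return None;
            }
        }
        route_name
    }

    fn main() {
        assert_eq!(resolve_route(Some("other".to_string())), None);
        assert_eq!(
            resolve_route(Some("Speech Recognition".to_string())),
            Some("Speech Recognition".to_string())
        );
        assert_eq!(resolve_route(None), None);
    }

The remaining hunks are in the router model and its tests, carrying the same LlmRoute to RoutingPreference rename through the prompt builder.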
@@ -1,5 +1,5 @@
 use common::{
-    configuration::{LlmRoute, ModelUsagePreference},
+    configuration::{ModelUsagePreference, RoutingPreference},
     consts::{SYSTEM_ROLE, TOOL_ROLE, USER_ROLE},
 };
 use hermesllm::providers::openai::types::{ChatCompletionsRequest, ContentType, Message};
@@ -36,7 +36,11 @@ pub struct RouterModelV1 {
     max_token_length: usize,
 }
 impl RouterModelV1 {
-    pub fn new(llm_routes: Vec<LlmRoute>, routing_model: String, max_token_length: usize) -> Self {
+    pub fn new(
+        llm_routes: Vec<RoutingPreference>,
+        routing_model: String,
+        max_token_length: usize,
+    ) -> Self {
         let llm_route_json_str =
             serde_json::to_string(&llm_routes).unwrap_or_else(|_| "[]".to_string());
         RouterModelV1 {
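RouterModelV1::new serializes the route list to JSON for embedding in the router prompt, degrading to "[]" if serialization fails. A runnable sketch of that step, assuming the two-field RoutingPreference shape above and serde plus serde_json as dependencies:

    use serde::Serialize;

    // Assumed two-field shape; the real type is in common::configuration.
    #[derive(Serialize)]
    struct RoutingPreference {
        name: String,
        description: String,
    }

    fn main() {
        let routes = vec![RoutingPreference {
            name: "Speech Recognition".to_string(),
            description: "Converting spoken language into written text".to_string(),
        }];
        // Mirrors the constructor: serialize for the prompt, fall back to "[]".
        let json = serde_json::to_string(&routes).unwrap_or_else(|_| "[]".to_string());
        println!("{json}");
        // [{"name":"Speech Recognition","description":"Converting spoken language into written text"}]
    }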
@@ -138,9 +142,9 @@ impl RouterModel for RouterModelV1 {
         let llm_route_json = usage_preferences
             .as_ref()
             .map(|prefs| {
-                let llm_route: Vec<LlmRoute> = prefs
+                let llm_route: Vec<RoutingPreference> = prefs
                     .iter()
-                    .map(|pref| LlmRoute {
+                    .map(|pref| RoutingPreference {
                         name: pref.name.clone(),
                         description: pref.usage.clone().unwrap_or_default(),
                     })
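Per-request usage preferences are converted into the same RoutingPreference shape before being serialized into the prompt. Judging from the field accesses in this diff (pref.name, pref.usage, and usage.model in the code removed earlier), ModelUsagePreference carries at least a name, a model, and an optional usage string; the conversion then amounts to:

    // Stand-in types, assumed from the field accesses visible in this diff.
    struct ModelUsagePreference {
        name: String,
        model: String,          // concrete model id; unused by this conversion
        usage: Option<String>,  // free-text usage description, may be absent
    }

    struct RoutingPreference {
        name: String,
        description: String,
    }

    // A missing usage description degrades to an empty string rather than
    // dropping the preference from the prompt.
    fn to_routing_preference(pref: &ModelUsagePreference) -> RoutingPreference {
        RoutingPreference {
            name: pref.name.clone(),
            description: pref.usage.clone().unwrap_or_default(),
        }
    }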
@@ -255,7 +259,7 @@ Based on your analysis, provide your response in the following JSON formats if y
 {"name": "Speech Recognition", "description": "Converting spoken language into written text"}
 ]
 "#;
-        let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
+        let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
         let routing_model = "test-model".to_string();
         let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
 
@@ -314,7 +318,7 @@ Based on your analysis, provide your response in the following JSON formats if y
 {"name": "Speech Recognition", "description": "Converting spoken language into written text"}
 ]
 "#;
-        let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
+        let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
         let routing_model = "test-model".to_string();
         let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
 
@@ -379,7 +383,7 @@ Based on your analysis, provide your response in the following JSON formats if y
 {"name": "Speech Recognition", "description": "Converting spoken language into written text"}
 ]
 "#;
-        let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
+        let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
         let routing_model = "test-model".to_string();
         let router = RouterModelV1::new(llm_routes, routing_model.clone(), 235);
 
@@ -440,7 +444,7 @@ Based on your analysis, provide your response in the following JSON formats if y
 {"name": "Speech Recognition", "description": "Converting spoken language into written text"}
 ]
 "#;
-        let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
+        let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
         let routing_model = "test-model".to_string();
         let router = RouterModelV1::new(llm_routes, routing_model.clone(), 200);
 
@@ -501,7 +505,7 @@ Based on your analysis, provide your response in the following JSON formats if y
 {"name": "Speech Recognition", "description": "Converting spoken language into written text"}
 ]
 "#;
-        let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
+        let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
         let routing_model = "test-model".to_string();
        let router = RouterModelV1::new(llm_routes, routing_model.clone(), 230);
 
@@ -569,7 +573,7 @@ Based on your analysis, provide your response in the following JSON formats if y
 {"name": "Speech Recognition", "description": "Converting spoken language into written text"}
 ]
 "#;
-        let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
+        let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
         let routing_model = "test-model".to_string();
         let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
 
@@ -639,7 +643,7 @@ Based on your analysis, provide your response in the following JSON formats if y
 {"name": "Speech Recognition", "description": "Converting spoken language into written text"}
 ]
 "#;
-        let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
+        let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
         let routing_model = "test-model".to_string();
         let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
 
@@ -716,7 +720,7 @@ Based on your analysis, provide your response in the following JSON formats if y
 {"name": "Speech Recognition", "description": "Converting spoken language into written text"}
 ]
 "#;
-        let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
+        let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
 
         let router = RouterModelV1::new(llm_routes, "test-model".to_string(), 2000);
 
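Since RouterService now returns a route name rather than a model, any mapping from the selected preference back to a concrete model has to happen in the caller. A hypothetical caller-side lookup, not part of this commit, using the same fields the removed code read:

    // Stand-in for the configuration type; name and model are both read by
    // the pre-change code in this diff, so the fields are grounded there.
    struct ModelUsagePreference {
        name: String,
        model: String,
    }

    // Hypothetical helper: resolve the router's chosen route name to the
    // model declared in the matching usage preference, if any.
    fn model_for_route(route_name: &str, prefs: &[ModelUsagePreference]) -> Option<String> {
        prefs
            .iter()
            .find(|pref| pref.name == route_name)
            .map(|pref| pref.model.clone())
    }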