better model names (#517)

This commit is contained in:
Adil Hafeez 2025-07-11 16:42:16 -07:00 committed by GitHub
parent 4e2355965b
commit a7fddf30f9
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
55 changed files with 979 additions and 483 deletions

View file

@ -12,7 +12,7 @@ use hyper::{Request, Response, StatusCode};
use tokio::sync::mpsc;
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt;
use tracing::{debug, info, trace, warn};
use tracing::{debug, info, warn};
use crate::router::llm_router::RouterService;
@ -81,8 +81,8 @@ pub async fn chat_completions(
}
}
trace!(
"arch-router request body: {}",
debug!(
"arch-router request received: {}",
&serde_json::to_string(&chat_completion_request).unwrap()
);
@ -102,9 +102,9 @@ pub async fn chat_completions(
.as_ref()
.and_then(|s| serde_yaml::from_str(s).ok());
debug!("usage preferences: {:?}", usage_preferences);
debug!("usage preferences from request: {:?}", usage_preferences);
let mut selected_llm = match router_service
let mut determined_route = match router_service
.determine_route(
&chat_completion_request.messages,
trace_parent.clone(),
@ -121,14 +121,14 @@ pub async fn chat_completions(
}
};
if selected_llm.is_none() {
if determined_route.is_none() {
debug!("No LLM model selected, using default from request");
selected_llm = Some(chat_completion_request.model.clone());
determined_route = Some(chat_completion_request.model.clone());
}
info!(
"sending request to llm provider: {} with llm model: {:?}",
llm_provider_endpoint, selected_llm
llm_provider_endpoint, determined_route
);
if let Some(trace_parent) = trace_parent {
@ -138,10 +138,10 @@ pub async fn chat_completions(
);
}
if let Some(selected_llm) = selected_llm {
if let Some(selected_route) = determined_route {
request_headers.insert(
ARCH_PROVIDER_HINT_HEADER,
header::HeaderValue::from_str(&selected_llm).unwrap(),
header::HeaderValue::from_str(&selected_route).unwrap(),
);
}

View file

@ -44,9 +44,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let _tracer_provider = init_tracer();
let bind_address = env::var("BIND_ADDRESS").unwrap_or_else(|_| BIND_ADDRESS.to_string());
info!(
"current working directory: {}",
env::current_dir().unwrap().display()
);
// loading arch_config.yaml file
let arch_config_path =
env::var("ARCH_CONFIG_PATH").unwrap_or_else(|_| "./arch_config.yaml".to_string());
let arch_config_path = env::var("ARCH_CONFIG_PATH_RENDERED")
.unwrap_or_else(|_| "./arch_config_rendered.yaml".to_string());
info!("Loading arch_config.yaml from {}", arch_config_path);
let config_contents =

View file

@ -1,7 +1,7 @@
use std::{collections::HashMap, sync::Arc};
use std::sync::Arc;
use common::{
configuration::{LlmProvider, LlmRoute, ModelUsagePreference},
configuration::{LlmProvider, ModelUsagePreference, RoutingPreference},
consts::ARCH_PROVIDER_HINT_HEADER,
};
use hermesllm::providers::openai::types::{ChatCompletionsResponse, ContentType, Message};
@ -19,7 +19,6 @@ pub struct RouterService {
router_model: Arc<dyn RouterModel>,
routing_provider_name: String,
llm_usage_defined: bool,
llm_provider_map: HashMap<String, LlmProvider>,
}
#[derive(Debug, Error)]
@ -45,11 +44,14 @@ impl RouterService {
) -> Self {
let providers_with_usage = providers
.iter()
.filter(|provider| provider.usage.is_some())
.filter(|provider| provider.routing_preferences.is_some())
.cloned()
.collect::<Vec<LlmProvider>>();
let llm_routes: Vec<LlmRoute> = providers_with_usage.iter().map(LlmRoute::from).collect();
let llm_routes: Vec<RoutingPreference> = providers_with_usage
.iter()
.flat_map(|provider| provider.routing_preferences.clone().unwrap_or_default())
.collect();
let router_model = Arc::new(router_model_v1::RouterModelV1::new(
llm_routes,
@ -57,18 +59,12 @@ impl RouterService {
router_model_v1::MAX_TOKEN_LEN,
));
let llm_provider_map: HashMap<String, LlmProvider> = providers
.into_iter()
.map(|provider| (provider.name.clone(), provider))
.collect();
RouterService {
router_url,
client: reqwest::Client::new(),
router_model,
routing_provider_name,
llm_usage_defined: !providers_with_usage.is_empty(),
llm_provider_map,
}
}
@ -155,40 +151,21 @@ impl RouterService {
if let Some(ContentType::Text(content)) =
&chat_completion_response.choices[0].message.content
{
let mut selected_model: Option<String> = None;
if let Some(selected_llm_name) = self.router_model.parse_response(content)? {
if selected_llm_name != "other" {
if let Some(usage_preferences) = usage_preferences {
for usage in usage_preferences {
if usage.name == selected_llm_name {
selected_model = Some(usage.model);
break;
}
}
if selected_model.is_none() {
warn!(
"Selected LLM model not found in usage preferences: {}",
selected_llm_name
);
}
} else if let Some(provider) = self.llm_provider_map.get(&selected_llm_name) {
selected_model = provider.model.clone();
} else {
warn!(
"Selected LLM model not found in provider map: {}",
selected_llm_name
);
}
}
}
let route_name = self.router_model.parse_response(content)?;
info!(
"router response: {}, selected_model: {:?}, response time: {}ms",
content.replace("\n", "\\n"),
selected_model,
route_name,
router_response_time.as_millis()
);
Ok(selected_model)
if let Some(ref route) = route_name {
if route == "other" {
return Ok(None);
}
}
Ok(route_name)
} else {
Ok(None)
}

View file

@ -1,5 +1,5 @@
use common::{
configuration::{LlmRoute, ModelUsagePreference},
configuration::{ModelUsagePreference, RoutingPreference},
consts::{SYSTEM_ROLE, TOOL_ROLE, USER_ROLE},
};
use hermesllm::providers::openai::types::{ChatCompletionsRequest, ContentType, Message};
@ -36,7 +36,11 @@ pub struct RouterModelV1 {
max_token_length: usize,
}
impl RouterModelV1 {
pub fn new(llm_routes: Vec<LlmRoute>, routing_model: String, max_token_length: usize) -> Self {
pub fn new(
llm_routes: Vec<RoutingPreference>,
routing_model: String,
max_token_length: usize,
) -> Self {
let llm_route_json_str =
serde_json::to_string(&llm_routes).unwrap_or_else(|_| "[]".to_string());
RouterModelV1 {
@ -138,9 +142,9 @@ impl RouterModel for RouterModelV1 {
let llm_route_json = usage_preferences
.as_ref()
.map(|prefs| {
let llm_route: Vec<LlmRoute> = prefs
let llm_route: Vec<RoutingPreference> = prefs
.iter()
.map(|pref| LlmRoute {
.map(|pref| RoutingPreference {
name: pref.name.clone(),
description: pref.usage.clone().unwrap_or_default(),
})
@ -255,7 +259,7 @@ Based on your analysis, provide your response in the following JSON formats if y
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
]
"#;
let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
@ -314,7 +318,7 @@ Based on your analysis, provide your response in the following JSON formats if y
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
]
"#;
let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
@ -379,7 +383,7 @@ Based on your analysis, provide your response in the following JSON formats if y
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
]
"#;
let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), 235);
@ -440,7 +444,7 @@ Based on your analysis, provide your response in the following JSON formats if y
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
]
"#;
let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), 200);
@ -501,7 +505,7 @@ Based on your analysis, provide your response in the following JSON formats if y
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
]
"#;
let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), 230);
@ -569,7 +573,7 @@ Based on your analysis, provide your response in the following JSON formats if y
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
]
"#;
let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
@ -639,7 +643,7 @@ Based on your analysis, provide your response in the following JSON formats if y
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
]
"#;
let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
@ -716,7 +720,7 @@ Based on your analysis, provide your response in the following JSON formats if y
{"name": "Speech Recognition", "description": "Converting spoken language into written text"}
]
"#;
let llm_routes = serde_json::from_str::<Vec<LlmRoute>>(routes_str).unwrap();
let llm_routes = serde_json::from_str::<Vec<RoutingPreference>>(routes_str).unwrap();
let router = RouterModelV1::new(llm_routes, "test-model".to_string(), 2000);