In request path use same format for usage preferences as arch_config

This commit is contained in:
Adil Hafeez 2025-07-18 18:01:14 -07:00
parent 83f4d33434
commit cd3b511102
No known key found for this signature in database
GPG key ID: 9B18EF7691369645
5 changed files with 51 additions and 167 deletions

View file

@ -1,3 +1,2 @@
pub mod chat_completions;
pub mod models;
pub mod preferences;

View file

@ -1,135 +0,0 @@
use bytes::Bytes;
use common::configuration::{LlmProvider, ModelUsagePreference};
use http_body_util::{combinators::BoxBody, BodyExt, Full};
use hyper::{Request, Response, StatusCode};
use serde_json;
use std::{collections::HashMap, sync::Arc};
use tracing::{info, warn};
pub async fn list_preferences(
llm_providers: Arc<tokio::sync::RwLock<Vec<LlmProvider>>>,
) -> Response<BoxBody<Bytes, hyper::Error>> {
let prov = llm_providers.read().await;
// convert the LlmProvider to UsageBasedProvider
let providers_with_usage = prov
.iter()
.map(|provider| ModelUsagePreference {
name: provider.name.clone(),
model: provider.model.clone().unwrap_or_default(),
usage: provider.usage.clone(),
})
.collect::<Vec<ModelUsagePreference>>();
match serde_json::to_string(&providers_with_usage) {
Ok(json) => {
let body = Full::new(Bytes::from(json))
.map_err(|never| match never {})
.boxed();
Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(body)
.unwrap()
}
Err(_) => {
let body = Full::new(Bytes::from_static(
b"{\"error\":\"Failed to serialize models\"}",
))
.map_err(|never| match never {})
.boxed();
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.header("Content-Type", "application/json")
.body(body)
.unwrap()
}
}
}
pub async fn update_preferences(
request: Request<hyper::body::Incoming>,
llm_providers: Arc<tokio::sync::RwLock<Vec<LlmProvider>>>,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_body = request.collect().await?.to_bytes();
let usage: Vec<ModelUsagePreference> = match serde_json::from_slice(&request_body) {
Ok(usage) => usage,
Err(_) => {
let response_body = Full::new(Bytes::from_static(b"Invalid request body: "))
.map_err(|never| match never {})
.boxed();
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.header("Content-Type", "text/plain")
.body(response_body)
.unwrap());
}
};
let usage_model_map: HashMap<String, ModelUsagePreference> =
usage.into_iter().map(|u| (u.model.clone(), u)).collect();
info!(
"Updating usage preferences for models: {:?}",
usage_model_map.keys()
);
let mut llm_providers = llm_providers.write().await;
// ensure that models coming in the request are valid
let llm_provider_names: Vec<String> = llm_providers
.iter()
.map(|provider| provider.name.clone())
.collect();
for model in usage_model_map.keys() {
if !llm_provider_names.contains(model) {
let model_not_found = format!("model not found: {}", model);
warn!("updating preferences: {}", model_not_found);
let response_body = Full::new(model_not_found.into())
.map_err(|never| match never {})
.boxed();
return Ok(Response::builder()
.status(StatusCode::BAD_REQUEST)
.header("Content-Type", "text/plain")
.body(response_body)
.unwrap());
}
}
let mut updated_models_list = Vec::new();
for provider in llm_providers.iter_mut() {
if let Some(usage_provider) = usage_model_map.get(&provider.name) {
provider.usage = usage_provider.usage.clone();
updated_models_list.push(ModelUsagePreference {
name: provider.name.clone(),
model: provider.model.clone().unwrap_or_default(),
usage: provider.usage.clone(),
});
}
}
if !updated_models_list.is_empty() {
// return list of updated models
let response_body = Full::new(Bytes::from(format!(
"{{\"updated_models\": {}}}",
serde_json::to_string(&updated_models_list).unwrap()
)))
.map_err(|never| match never {})
.boxed();
Ok(Response::builder()
.status(StatusCode::OK)
.header("Content-Type", "application/json")
.body(response_body)
.unwrap())
} else {
let response_body = Full::new(Bytes::from_static(b"Provider not found"))
.map_err(|never| match never {})
.boxed();
Ok(Response::builder()
.status(StatusCode::NOT_FOUND)
.header("Content-Type", "text/plain")
.body(response_body)
.unwrap())
}
}

View file

@ -1,6 +1,5 @@
use brightstaff::handlers::chat_completions::chat_completions;
use brightstaff::handlers::models::list_models;
use brightstaff::handlers::preferences::{list_preferences, update_preferences};
use brightstaff::router::llm_router::RouterService;
use brightstaff::utils::tracing::init_tracer;
use bytes::Bytes;
@ -116,12 +115,6 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
.with_context(parent_cx)
.await
}
(&Method::GET, "/v1/router/preferences") => {
Ok(list_preferences(llm_providers).await)
}
(&Method::PUT, "/v1/router/preferences") => {
update_preferences(req, llm_providers).await
}
(&Method::GET, "/v1/models") => Ok(list_models(llm_providers).await),
(&Method::OPTIONS, "/v1/models") => {
let mut response = Response::new(empty());

View file

@ -155,9 +155,14 @@ impl RouterModel for RouterModelV1 {
.map(|prefs| {
let llm_route: Vec<RoutingPreference> = prefs
.iter()
.map(|pref| RoutingPreference {
name: pref.name.clone(),
description: pref.usage.clone().unwrap_or_default(),
.flat_map(|pref| {
let routing_preferences = pref.routing_preferences.clone();
routing_preferences
.into_iter()
.map(|routing_pref| RoutingPreference {
name: routing_pref.name,
description: routing_pref.description,
})
})
.collect();
serde_json::to_string(&llm_route).unwrap_or_default()
@ -201,12 +206,18 @@ impl RouterModel for RouterModelV1 {
if let Some(usage_preferences) = usage_preferences {
// If usage preferences are defined, we need to find the model that matches the selected route
let matching_preference = usage_preferences
let model_name: Option<String> = usage_preferences
.iter()
.find(|pref| pref.name == selected_route);
.map(|pref| {
pref.routing_preferences
.iter()
.find(|routing_pref| routing_pref.name == selected_route)
.map(|_| pref.model.clone())
})
.find_map(|model| model);
if let Some(preference) = matching_preference {
return Ok(Some((selected_route, preference.model.clone())));
if let Some(model_name) = model_name {
return Ok(Some((selected_route, model_name)));
} else {
warn!(
"No matching model found for route: {}, usage preferences: {:?}",
@ -299,7 +310,8 @@ Based on your analysis, provide your response in the following JSON formats if y
]
}
"#;
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
@ -356,7 +368,8 @@ Based on your analysis, provide your response in the following JSON formats if y
]
}
"#;
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
@ -379,9 +392,11 @@ Based on your analysis, provide your response in the following JSON formats if y
let conversation: Vec<Message> = serde_json::from_str(conversation_str).unwrap();
let usage_preferences = Some(vec![ModelUsagePreference {
name: "code-generation".to_string(),
model: "claude/claude-3-7-sonnet".to_string(),
usage: Some("generating new code snippets, functions, or boilerplate based on user prompts or requirements".to_string()),
routing_preferences: vec![RoutingPreference {
name: "code-generation".to_string(),
description: "generating new code snippets, functions, or boilerplate based on user prompts or requirements".to_string(),
}],
}]);
let req = router.generate_request(&conversation, &usage_preferences);
@ -419,7 +434,8 @@ Based on your analysis, provide your response in the following JSON formats if y
]
}
"#;
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), 235);
@ -478,7 +494,8 @@ Based on your analysis, provide your response in the following JSON formats if y
]
}
"#;
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), 200);
@ -538,7 +555,8 @@ Based on your analysis, provide your response in the following JSON formats if y
]
}
"#;
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), 230);
@ -604,7 +622,8 @@ Based on your analysis, provide your response in the following JSON formats if y
]
}
"#;
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
@ -672,7 +691,8 @@ Based on your analysis, provide your response in the following JSON formats if y
]
}
"#;
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let routing_model = "test-model".to_string();
let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX);
@ -747,14 +767,18 @@ Based on your analysis, provide your response in the following JSON formats if y
]
}
"#;
let llm_routes = serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let llm_routes =
serde_json::from_str::<HashMap<String, Vec<RoutingPreference>>>(routes_str).unwrap();
let router = RouterModelV1::new(llm_routes, "test-model".to_string(), 2000);
// Case 1: Valid JSON with non-empty route
let input = r#"{"route": "Image generation"}"#;
let result = router.parse_response(input, &None).unwrap();
assert_eq!(result, Some(("Image generation".to_string(), "gpt-4o".to_string())));
assert_eq!(
result,
Some(("Image generation".to_string(), "gpt-4o".to_string()))
);
// Case 2: Valid JSON with empty route
let input = r#"{"route": ""}"#;
@ -784,11 +808,17 @@ Based on your analysis, provide your response in the following JSON formats if y
// Case 6: Single quotes and \n in JSON
let input = "{'route': 'Image generation'}\\n";
let result = router.parse_response(input, &None).unwrap();
assert_eq!(result, Some(("Image generation".to_string(), "gpt-4o".to_string())));
assert_eq!(
result,
Some(("Image generation".to_string(), "gpt-4o".to_string()))
);
// Case 7: Code block marker
let input = "```json\n{\"route\": \"Image generation\"}\n```";
let result = router.parse_response(input, &None).unwrap();
assert_eq!(result, Some(("Image generation".to_string(), "gpt-4o".to_string())));
assert_eq!(
result,
Some(("Image generation".to_string(), "gpt-4o".to_string()))
);
}
}

View file

@ -1,6 +1,5 @@
use hermesllm::providers::openai::types::{ModelDetail, ModelObject, Models};
use serde::{Deserialize, Serialize};
use serde_with::skip_serializing_none;
use std::collections::HashMap;
use std::fmt::Display;
@ -178,12 +177,10 @@ impl Display for LlmProviderType {
}
}
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug)]
pub struct ModelUsagePreference {
pub name: String,
pub model: String,
pub usage: Option<String>,
pub routing_preferences: Vec<RoutingPreference>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]