From cd3b5111028f61f37df3caa2ae6f07914f7287ee Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Fri, 18 Jul 2025 18:01:14 -0700 Subject: [PATCH] In request path use same format for usage preferences as arch_config --- crates/brightstaff/src/handlers/mod.rs | 1 - .../brightstaff/src/handlers/preferences.rs | 135 ------------------ crates/brightstaff/src/main.rs | 7 - .../brightstaff/src/router/router_model_v1.rs | 70 ++++++--- crates/common/src/configuration.rs | 5 +- 5 files changed, 51 insertions(+), 167 deletions(-) delete mode 100644 crates/brightstaff/src/handlers/preferences.rs diff --git a/crates/brightstaff/src/handlers/mod.rs b/crates/brightstaff/src/handlers/mod.rs index febab6c2..6de38b5b 100644 --- a/crates/brightstaff/src/handlers/mod.rs +++ b/crates/brightstaff/src/handlers/mod.rs @@ -1,3 +1,2 @@ pub mod chat_completions; pub mod models; -pub mod preferences; diff --git a/crates/brightstaff/src/handlers/preferences.rs b/crates/brightstaff/src/handlers/preferences.rs deleted file mode 100644 index a9c5a65d..00000000 --- a/crates/brightstaff/src/handlers/preferences.rs +++ /dev/null @@ -1,135 +0,0 @@ -use bytes::Bytes; -use common::configuration::{LlmProvider, ModelUsagePreference}; -use http_body_util::{combinators::BoxBody, BodyExt, Full}; -use hyper::{Request, Response, StatusCode}; -use serde_json; -use std::{collections::HashMap, sync::Arc}; -use tracing::{info, warn}; - -pub async fn list_preferences( - llm_providers: Arc>>, -) -> Response> { - let prov = llm_providers.read().await; - // convert the LlmProvider to UsageBasedProvider - let providers_with_usage = prov - .iter() - .map(|provider| ModelUsagePreference { - name: provider.name.clone(), - model: provider.model.clone().unwrap_or_default(), - usage: provider.usage.clone(), - }) - .collect::>(); - - match serde_json::to_string(&providers_with_usage) { - Ok(json) => { - let body = Full::new(Bytes::from(json)) - .map_err(|never| match never {}) - .boxed(); - Response::builder() - .status(StatusCode::OK) - .header("Content-Type", "application/json") - .body(body) - .unwrap() - } - Err(_) => { - let body = Full::new(Bytes::from_static( - b"{\"error\":\"Failed to serialize models\"}", - )) - .map_err(|never| match never {}) - .boxed(); - Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .header("Content-Type", "application/json") - .body(body) - .unwrap() - } - } -} - -pub async fn update_preferences( - request: Request, - llm_providers: Arc>>, -) -> Result>, hyper::Error> { - let request_body = request.collect().await?.to_bytes(); - - let usage: Vec = match serde_json::from_slice(&request_body) { - Ok(usage) => usage, - Err(_) => { - let response_body = Full::new(Bytes::from_static(b"Invalid request body: ")) - .map_err(|never| match never {}) - .boxed(); - return Ok(Response::builder() - .status(StatusCode::BAD_REQUEST) - .header("Content-Type", "text/plain") - .body(response_body) - .unwrap()); - } - }; - - let usage_model_map: HashMap = - usage.into_iter().map(|u| (u.model.clone(), u)).collect(); - - info!( - "Updating usage preferences for models: {:?}", - usage_model_map.keys() - ); - - let mut llm_providers = llm_providers.write().await; - - // ensure that models coming in the request are valid - let llm_provider_names: Vec = llm_providers - .iter() - .map(|provider| provider.name.clone()) - .collect(); - - for model in usage_model_map.keys() { - if !llm_provider_names.contains(model) { - let model_not_found = format!("model not found: {}", model); - warn!("updating preferences: {}", model_not_found); - let response_body = Full::new(model_not_found.into()) - .map_err(|never| match never {}) - .boxed(); - return Ok(Response::builder() - .status(StatusCode::BAD_REQUEST) - .header("Content-Type", "text/plain") - .body(response_body) - .unwrap()); - } - } - - let mut updated_models_list = Vec::new(); - for provider in llm_providers.iter_mut() { - if let Some(usage_provider) = usage_model_map.get(&provider.name) { - provider.usage = usage_provider.usage.clone(); - updated_models_list.push(ModelUsagePreference { - name: provider.name.clone(), - model: provider.model.clone().unwrap_or_default(), - usage: provider.usage.clone(), - }); - } - } - - if !updated_models_list.is_empty() { - // return list of updated models - let response_body = Full::new(Bytes::from(format!( - "{{\"updated_models\": {}}}", - serde_json::to_string(&updated_models_list).unwrap() - ))) - .map_err(|never| match never {}) - .boxed(); - Ok(Response::builder() - .status(StatusCode::OK) - .header("Content-Type", "application/json") - .body(response_body) - .unwrap()) - } else { - let response_body = Full::new(Bytes::from_static(b"Provider not found")) - .map_err(|never| match never {}) - .boxed(); - Ok(Response::builder() - .status(StatusCode::NOT_FOUND) - .header("Content-Type", "text/plain") - .body(response_body) - .unwrap()) - } -} diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index dbcc9124..b5bf0204 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -1,6 +1,5 @@ use brightstaff::handlers::chat_completions::chat_completions; use brightstaff::handlers::models::list_models; -use brightstaff::handlers::preferences::{list_preferences, update_preferences}; use brightstaff::router::llm_router::RouterService; use brightstaff::utils::tracing::init_tracer; use bytes::Bytes; @@ -116,12 +115,6 @@ async fn main() -> Result<(), Box> { .with_context(parent_cx) .await } - (&Method::GET, "/v1/router/preferences") => { - Ok(list_preferences(llm_providers).await) - } - (&Method::PUT, "/v1/router/preferences") => { - update_preferences(req, llm_providers).await - } (&Method::GET, "/v1/models") => Ok(list_models(llm_providers).await), (&Method::OPTIONS, "/v1/models") => { let mut response = Response::new(empty()); diff --git a/crates/brightstaff/src/router/router_model_v1.rs b/crates/brightstaff/src/router/router_model_v1.rs index dc0e1563..2dccd61f 100644 --- a/crates/brightstaff/src/router/router_model_v1.rs +++ b/crates/brightstaff/src/router/router_model_v1.rs @@ -155,9 +155,14 @@ impl RouterModel for RouterModelV1 { .map(|prefs| { let llm_route: Vec = prefs .iter() - .map(|pref| RoutingPreference { - name: pref.name.clone(), - description: pref.usage.clone().unwrap_or_default(), + .flat_map(|pref| { + let routing_preferences = pref.routing_preferences.clone(); + routing_preferences + .into_iter() + .map(|routing_pref| RoutingPreference { + name: routing_pref.name, + description: routing_pref.description, + }) }) .collect(); serde_json::to_string(&llm_route).unwrap_or_default() @@ -201,12 +206,18 @@ impl RouterModel for RouterModelV1 { if let Some(usage_preferences) = usage_preferences { // If usage preferences are defined, we need to find the model that matches the selected route - let matching_preference = usage_preferences + let model_name: Option = usage_preferences .iter() - .find(|pref| pref.name == selected_route); + .map(|pref| { + pref.routing_preferences + .iter() + .find(|routing_pref| routing_pref.name == selected_route) + .map(|_| pref.model.clone()) + }) + .find_map(|model| model); - if let Some(preference) = matching_preference { - return Ok(Some((selected_route, preference.model.clone()))); + if let Some(model_name) = model_name { + return Ok(Some((selected_route, model_name))); } else { warn!( "No matching model found for route: {}, usage preferences: {:?}", @@ -299,7 +310,8 @@ Based on your analysis, provide your response in the following JSON formats if y ] } "#; - let llm_routes = serde_json::from_str::>>(routes_str).unwrap(); + let llm_routes = + serde_json::from_str::>>(routes_str).unwrap(); let routing_model = "test-model".to_string(); let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX); @@ -356,7 +368,8 @@ Based on your analysis, provide your response in the following JSON formats if y ] } "#; - let llm_routes = serde_json::from_str::>>(routes_str).unwrap(); + let llm_routes = + serde_json::from_str::>>(routes_str).unwrap(); let routing_model = "test-model".to_string(); let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX); @@ -379,9 +392,11 @@ Based on your analysis, provide your response in the following JSON formats if y let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); let usage_preferences = Some(vec![ModelUsagePreference { - name: "code-generation".to_string(), model: "claude/claude-3-7-sonnet".to_string(), - usage: Some("generating new code snippets, functions, or boilerplate based on user prompts or requirements".to_string()), + routing_preferences: vec![RoutingPreference { + name: "code-generation".to_string(), + description: "generating new code snippets, functions, or boilerplate based on user prompts or requirements".to_string(), + }], }]); let req = router.generate_request(&conversation, &usage_preferences); @@ -419,7 +434,8 @@ Based on your analysis, provide your response in the following JSON formats if y ] } "#; - let llm_routes = serde_json::from_str::>>(routes_str).unwrap(); + let llm_routes = + serde_json::from_str::>>(routes_str).unwrap(); let routing_model = "test-model".to_string(); let router = RouterModelV1::new(llm_routes, routing_model.clone(), 235); @@ -478,7 +494,8 @@ Based on your analysis, provide your response in the following JSON formats if y ] } "#; - let llm_routes = serde_json::from_str::>>(routes_str).unwrap(); + let llm_routes = + serde_json::from_str::>>(routes_str).unwrap(); let routing_model = "test-model".to_string(); let router = RouterModelV1::new(llm_routes, routing_model.clone(), 200); @@ -538,7 +555,8 @@ Based on your analysis, provide your response in the following JSON formats if y ] } "#; - let llm_routes = serde_json::from_str::>>(routes_str).unwrap(); + let llm_routes = + serde_json::from_str::>>(routes_str).unwrap(); let routing_model = "test-model".to_string(); let router = RouterModelV1::new(llm_routes, routing_model.clone(), 230); @@ -604,7 +622,8 @@ Based on your analysis, provide your response in the following JSON formats if y ] } "#; - let llm_routes = serde_json::from_str::>>(routes_str).unwrap(); + let llm_routes = + serde_json::from_str::>>(routes_str).unwrap(); let routing_model = "test-model".to_string(); let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX); @@ -672,7 +691,8 @@ Based on your analysis, provide your response in the following JSON formats if y ] } "#; - let llm_routes = serde_json::from_str::>>(routes_str).unwrap(); + let llm_routes = + serde_json::from_str::>>(routes_str).unwrap(); let routing_model = "test-model".to_string(); let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX); @@ -747,14 +767,18 @@ Based on your analysis, provide your response in the following JSON formats if y ] } "#; - let llm_routes = serde_json::from_str::>>(routes_str).unwrap(); + let llm_routes = + serde_json::from_str::>>(routes_str).unwrap(); let router = RouterModelV1::new(llm_routes, "test-model".to_string(), 2000); // Case 1: Valid JSON with non-empty route let input = r#"{"route": "Image generation"}"#; let result = router.parse_response(input, &None).unwrap(); - assert_eq!(result, Some(("Image generation".to_string(), "gpt-4o".to_string()))); + assert_eq!( + result, + Some(("Image generation".to_string(), "gpt-4o".to_string())) + ); // Case 2: Valid JSON with empty route let input = r#"{"route": ""}"#; @@ -784,11 +808,17 @@ Based on your analysis, provide your response in the following JSON formats if y // Case 6: Single quotes and \n in JSON let input = "{'route': 'Image generation'}\\n"; let result = router.parse_response(input, &None).unwrap(); - assert_eq!(result, Some(("Image generation".to_string(), "gpt-4o".to_string()))); + assert_eq!( + result, + Some(("Image generation".to_string(), "gpt-4o".to_string())) + ); // Case 7: Code block marker let input = "```json\n{\"route\": \"Image generation\"}\n```"; let result = router.parse_response(input, &None).unwrap(); - assert_eq!(result, Some(("Image generation".to_string(), "gpt-4o".to_string()))); + assert_eq!( + result, + Some(("Image generation".to_string(), "gpt-4o".to_string())) + ); } } diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 0693c09b..186691dc 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -1,6 +1,5 @@ use hermesllm::providers::openai::types::{ModelDetail, ModelObject, Models}; use serde::{Deserialize, Serialize}; -use serde_with::skip_serializing_none; use std::collections::HashMap; use std::fmt::Display; @@ -178,12 +177,10 @@ impl Display for LlmProviderType { } } -#[skip_serializing_none] #[derive(Serialize, Deserialize, Debug)] pub struct ModelUsagePreference { - pub name: String, pub model: String, - pub usage: Option, + pub routing_preferences: Vec, } #[derive(Debug, Clone, Serialize, Deserialize)]