diff --git a/crates/brightstaff/src/handlers/mod.rs b/crates/brightstaff/src/handlers/mod.rs index febab6c2..6de38b5b 100644 --- a/crates/brightstaff/src/handlers/mod.rs +++ b/crates/brightstaff/src/handlers/mod.rs @@ -1,3 +1,2 @@ pub mod chat_completions; pub mod models; -pub mod preferences; diff --git a/crates/brightstaff/src/handlers/preferences.rs b/crates/brightstaff/src/handlers/preferences.rs deleted file mode 100644 index a9c5a65d..00000000 --- a/crates/brightstaff/src/handlers/preferences.rs +++ /dev/null @@ -1,135 +0,0 @@ -use bytes::Bytes; -use common::configuration::{LlmProvider, ModelUsagePreference}; -use http_body_util::{combinators::BoxBody, BodyExt, Full}; -use hyper::{Request, Response, StatusCode}; -use serde_json; -use std::{collections::HashMap, sync::Arc}; -use tracing::{info, warn}; - -pub async fn list_preferences( - llm_providers: Arc>>, -) -> Response> { - let prov = llm_providers.read().await; - // convert the LlmProvider to UsageBasedProvider - let providers_with_usage = prov - .iter() - .map(|provider| ModelUsagePreference { - name: provider.name.clone(), - model: provider.model.clone().unwrap_or_default(), - usage: provider.usage.clone(), - }) - .collect::>(); - - match serde_json::to_string(&providers_with_usage) { - Ok(json) => { - let body = Full::new(Bytes::from(json)) - .map_err(|never| match never {}) - .boxed(); - Response::builder() - .status(StatusCode::OK) - .header("Content-Type", "application/json") - .body(body) - .unwrap() - } - Err(_) => { - let body = Full::new(Bytes::from_static( - b"{\"error\":\"Failed to serialize models\"}", - )) - .map_err(|never| match never {}) - .boxed(); - Response::builder() - .status(StatusCode::INTERNAL_SERVER_ERROR) - .header("Content-Type", "application/json") - .body(body) - .unwrap() - } - } -} - -pub async fn update_preferences( - request: Request, - llm_providers: Arc>>, -) -> Result>, hyper::Error> { - let request_body = request.collect().await?.to_bytes(); - - let usage: Vec = match serde_json::from_slice(&request_body) { - Ok(usage) => usage, - Err(_) => { - let response_body = Full::new(Bytes::from_static(b"Invalid request body: ")) - .map_err(|never| match never {}) - .boxed(); - return Ok(Response::builder() - .status(StatusCode::BAD_REQUEST) - .header("Content-Type", "text/plain") - .body(response_body) - .unwrap()); - } - }; - - let usage_model_map: HashMap = - usage.into_iter().map(|u| (u.model.clone(), u)).collect(); - - info!( - "Updating usage preferences for models: {:?}", - usage_model_map.keys() - ); - - let mut llm_providers = llm_providers.write().await; - - // ensure that models coming in the request are valid - let llm_provider_names: Vec = llm_providers - .iter() - .map(|provider| provider.name.clone()) - .collect(); - - for model in usage_model_map.keys() { - if !llm_provider_names.contains(model) { - let model_not_found = format!("model not found: {}", model); - warn!("updating preferences: {}", model_not_found); - let response_body = Full::new(model_not_found.into()) - .map_err(|never| match never {}) - .boxed(); - return Ok(Response::builder() - .status(StatusCode::BAD_REQUEST) - .header("Content-Type", "text/plain") - .body(response_body) - .unwrap()); - } - } - - let mut updated_models_list = Vec::new(); - for provider in llm_providers.iter_mut() { - if let Some(usage_provider) = usage_model_map.get(&provider.name) { - provider.usage = usage_provider.usage.clone(); - updated_models_list.push(ModelUsagePreference { - name: provider.name.clone(), - model: provider.model.clone().unwrap_or_default(), - usage: provider.usage.clone(), - }); - } - } - - if !updated_models_list.is_empty() { - // return list of updated models - let response_body = Full::new(Bytes::from(format!( - "{{\"updated_models\": {}}}", - serde_json::to_string(&updated_models_list).unwrap() - ))) - .map_err(|never| match never {}) - .boxed(); - Ok(Response::builder() - .status(StatusCode::OK) - .header("Content-Type", "application/json") - .body(response_body) - .unwrap()) - } else { - let response_body = Full::new(Bytes::from_static(b"Provider not found")) - .map_err(|never| match never {}) - .boxed(); - Ok(Response::builder() - .status(StatusCode::NOT_FOUND) - .header("Content-Type", "text/plain") - .body(response_body) - .unwrap()) - } -} diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index dbcc9124..b5bf0204 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -1,6 +1,5 @@ use brightstaff::handlers::chat_completions::chat_completions; use brightstaff::handlers::models::list_models; -use brightstaff::handlers::preferences::{list_preferences, update_preferences}; use brightstaff::router::llm_router::RouterService; use brightstaff::utils::tracing::init_tracer; use bytes::Bytes; @@ -116,12 +115,6 @@ async fn main() -> Result<(), Box> { .with_context(parent_cx) .await } - (&Method::GET, "/v1/router/preferences") => { - Ok(list_preferences(llm_providers).await) - } - (&Method::PUT, "/v1/router/preferences") => { - update_preferences(req, llm_providers).await - } (&Method::GET, "/v1/models") => Ok(list_models(llm_providers).await), (&Method::OPTIONS, "/v1/models") => { let mut response = Response::new(empty()); diff --git a/crates/brightstaff/src/router/router_model_v1.rs b/crates/brightstaff/src/router/router_model_v1.rs index dc0e1563..bd06b525 100644 --- a/crates/brightstaff/src/router/router_model_v1.rs +++ b/crates/brightstaff/src/router/router_model_v1.rs @@ -73,7 +73,7 @@ impl RouterModel for RouterModelV1 { fn generate_request( &self, messages: &[Message], - usage_preferences: &Option>, + usage_preferences_from_request: &Option>, ) -> ChatCompletionsRequest { // remove system prompt, tool calls, tool call response and messages without content // if content is empty its likely a tool call @@ -150,31 +150,17 @@ impl RouterModel for RouterModelV1 { }) .collect::>(); - let llm_route_json = usage_preferences - .as_ref() - .map(|prefs| { - let llm_route: Vec = prefs - .iter() - .map(|pref| RoutingPreference { - name: pref.name.clone(), - description: pref.usage.clone().unwrap_or_default(), - }) - .collect(); - serde_json::to_string(&llm_route).unwrap_or_default() - }) - .unwrap_or_else(|| self.llm_route_json_str.clone()); - - let messages_content = ARCH_ROUTER_V1_SYSTEM_PROMPT - .replace("{routes}", &llm_route_json) - .replace( - "{conversation}", - &serde_json::to_string(&selected_conversation_list).unwrap_or_default(), - ); + // Generate the router request message based on the usage preferences. + // If preferences are passed in request then we use them otherwise we use the default routing model preferences. + let router_message = match convert_to_router_preferences(usage_preferences_from_request) { + Some(prefs) => generate_router_message(&prefs, &selected_conversation_list), + None => generate_router_message(&self.llm_route_json_str, &selected_conversation_list), + }; ChatCompletionsRequest { model: self.routing_model.clone(), messages: vec![Message { - content: Some(ContentType::Text(messages_content)), + content: Some(ContentType::Text(router_message)), role: USER_ROLE.to_string(), }], temperature: Some(0.01), @@ -201,12 +187,18 @@ impl RouterModel for RouterModelV1 { if let Some(usage_preferences) = usage_preferences { // If usage preferences are defined, we need to find the model that matches the selected route - let matching_preference = usage_preferences + let model_name: Option = usage_preferences .iter() - .find(|pref| pref.name == selected_route); + .map(|pref| { + pref.routing_preferences + .iter() + .find(|routing_pref| routing_pref.name == selected_route) + .map(|_| pref.model.clone()) + }) + .find_map(|model| model); - if let Some(preference) = matching_preference { - return Ok(Some((selected_route, preference.model.clone()))); + if let Some(model_name) = model_name { + return Ok(Some((selected_route, model_name))); } else { warn!( "No matching model found for route: {}, usage preferences: {:?}", @@ -216,7 +208,7 @@ impl RouterModel for RouterModelV1 { } } - // If no usage preferences are defined, we return the route with the routing model + // If no usage preferences are passed in request then use the default routing model preferences if let Some(model) = self.llm_route_to_model_map.get(&selected_route).cloned() { return Ok(Some((selected_route, model))); } @@ -234,6 +226,37 @@ impl RouterModel for RouterModelV1 { } } +fn generate_router_message(prefs: &str, selected_conversation_list: &Vec) -> String { + ARCH_ROUTER_V1_SYSTEM_PROMPT + .replace("{routes}", prefs) + .replace( + "{conversation}", + &serde_json::to_string(&selected_conversation_list).unwrap_or_default(), + ) +} + +fn convert_to_router_preferences( + prefs_from_request: &Option>, +) -> Option { + if let Some(usage_preferences) = prefs_from_request { + let routing_preferences = usage_preferences + .iter() + .flat_map(|pref| { + pref.routing_preferences + .iter() + .map(|routing_pref| RoutingPreference { + name: routing_pref.name.clone(), + description: routing_pref.description.clone(), + }) + }) + .collect::>(); + + return Some(serde_json::to_string(&routing_preferences).unwrap_or_default()); + } + + None +} + fn fix_json_response(body: &str) -> String { let mut updated_body = body.to_string(); @@ -299,7 +322,8 @@ Based on your analysis, provide your response in the following JSON formats if y ] } "#; - let llm_routes = serde_json::from_str::>>(routes_str).unwrap(); + let llm_routes = + serde_json::from_str::>>(routes_str).unwrap(); let routing_model = "test-model".to_string(); let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX); @@ -356,7 +380,8 @@ Based on your analysis, provide your response in the following JSON formats if y ] } "#; - let llm_routes = serde_json::from_str::>>(routes_str).unwrap(); + let llm_routes = + serde_json::from_str::>>(routes_str).unwrap(); let routing_model = "test-model".to_string(); let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX); @@ -379,9 +404,11 @@ Based on your analysis, provide your response in the following JSON formats if y let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); let usage_preferences = Some(vec![ModelUsagePreference { - name: "code-generation".to_string(), model: "claude/claude-3-7-sonnet".to_string(), - usage: Some("generating new code snippets, functions, or boilerplate based on user prompts or requirements".to_string()), + routing_preferences: vec![RoutingPreference { + name: "code-generation".to_string(), + description: "generating new code snippets, functions, or boilerplate based on user prompts or requirements".to_string(), + }], }]); let req = router.generate_request(&conversation, &usage_preferences); @@ -419,7 +446,8 @@ Based on your analysis, provide your response in the following JSON formats if y ] } "#; - let llm_routes = serde_json::from_str::>>(routes_str).unwrap(); + let llm_routes = + serde_json::from_str::>>(routes_str).unwrap(); let routing_model = "test-model".to_string(); let router = RouterModelV1::new(llm_routes, routing_model.clone(), 235); @@ -478,7 +506,8 @@ Based on your analysis, provide your response in the following JSON formats if y ] } "#; - let llm_routes = serde_json::from_str::>>(routes_str).unwrap(); + let llm_routes = + serde_json::from_str::>>(routes_str).unwrap(); let routing_model = "test-model".to_string(); let router = RouterModelV1::new(llm_routes, routing_model.clone(), 200); @@ -538,7 +567,8 @@ Based on your analysis, provide your response in the following JSON formats if y ] } "#; - let llm_routes = serde_json::from_str::>>(routes_str).unwrap(); + let llm_routes = + serde_json::from_str::>>(routes_str).unwrap(); let routing_model = "test-model".to_string(); let router = RouterModelV1::new(llm_routes, routing_model.clone(), 230); @@ -604,7 +634,8 @@ Based on your analysis, provide your response in the following JSON formats if y ] } "#; - let llm_routes = serde_json::from_str::>>(routes_str).unwrap(); + let llm_routes = + serde_json::from_str::>>(routes_str).unwrap(); let routing_model = "test-model".to_string(); let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX); @@ -672,7 +703,8 @@ Based on your analysis, provide your response in the following JSON formats if y ] } "#; - let llm_routes = serde_json::from_str::>>(routes_str).unwrap(); + let llm_routes = + serde_json::from_str::>>(routes_str).unwrap(); let routing_model = "test-model".to_string(); let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX); @@ -747,14 +779,18 @@ Based on your analysis, provide your response in the following JSON formats if y ] } "#; - let llm_routes = serde_json::from_str::>>(routes_str).unwrap(); + let llm_routes = + serde_json::from_str::>>(routes_str).unwrap(); let router = RouterModelV1::new(llm_routes, "test-model".to_string(), 2000); // Case 1: Valid JSON with non-empty route let input = r#"{"route": "Image generation"}"#; let result = router.parse_response(input, &None).unwrap(); - assert_eq!(result, Some(("Image generation".to_string(), "gpt-4o".to_string()))); + assert_eq!( + result, + Some(("Image generation".to_string(), "gpt-4o".to_string())) + ); // Case 2: Valid JSON with empty route let input = r#"{"route": ""}"#; @@ -784,11 +820,17 @@ Based on your analysis, provide your response in the following JSON formats if y // Case 6: Single quotes and \n in JSON let input = "{'route': 'Image generation'}\\n"; let result = router.parse_response(input, &None).unwrap(); - assert_eq!(result, Some(("Image generation".to_string(), "gpt-4o".to_string()))); + assert_eq!( + result, + Some(("Image generation".to_string(), "gpt-4o".to_string())) + ); // Case 7: Code block marker let input = "```json\n{\"route\": \"Image generation\"}\n```"; let result = router.parse_response(input, &None).unwrap(); - assert_eq!(result, Some(("Image generation".to_string(), "gpt-4o".to_string()))); + assert_eq!( + result, + Some(("Image generation".to_string(), "gpt-4o".to_string())) + ); } } diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 0693c09b..186691dc 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -1,6 +1,5 @@ use hermesllm::providers::openai::types::{ModelDetail, ModelObject, Models}; use serde::{Deserialize, Serialize}; -use serde_with::skip_serializing_none; use std::collections::HashMap; use std::fmt::Display; @@ -178,12 +177,10 @@ impl Display for LlmProviderType { } } -#[skip_serializing_none] #[derive(Serialize, Deserialize, Debug)] pub struct ModelUsagePreference { - pub name: String, pub model: String, - pub usage: Option, + pub routing_preferences: Vec, } #[derive(Debug, Clone, Serialize, Deserialize)]