From 4373aeb00be832aa3b33e3789334c705ba7a84aa Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Tue, 24 Jun 2025 23:57:28 -0700 Subject: [PATCH] add more changes --- .../src/handlers/chat_completions.rs | 20 +++++++++++- .../brightstaff/src/handlers/preferences.rs | 32 ++++++++----------- crates/brightstaff/src/main.rs | 10 +++--- crates/brightstaff/src/router/llm_router.rs | 7 ++-- crates/brightstaff/src/router/router_model.rs | 7 +++- .../brightstaff/src/router/router_model_v1.rs | 32 +++++++++++-------- crates/common/src/configuration.rs | 8 +++++ .../hermesllm/src/providers/openai/builder.rs | 1 + .../hermesllm/src/providers/openai/types.rs | 2 ++ 9 files changed, 78 insertions(+), 41 deletions(-) diff --git a/crates/brightstaff/src/handlers/chat_completions.rs b/crates/brightstaff/src/handlers/chat_completions.rs index 756e115a..1bd44498 100644 --- a/crates/brightstaff/src/handlers/chat_completions.rs +++ b/crates/brightstaff/src/handlers/chat_completions.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use bytes::Bytes; +use common::configuration::ModelUsagePreference; use common::consts::ARCH_PROVIDER_HINT_HEADER; use hermesllm::providers::openai::types::ChatCompletionsRequest; use http_body_util::combinators::BoxBody; @@ -56,8 +57,25 @@ pub async fn chat_completions( .find(|(ty, _)| ty.as_str() == "traceparent") .map(|(_, value)| value.to_str().unwrap_or_default().to_string()); + let usage_preferences_str: Option = + chat_completion_request.metadata.and_then(|metadata| { + metadata + .get("archgw_preference_config") + .and_then(|value| value.as_str().map(String::from)) + }); + + let usage_preferences: Option> = usage_preferences_str + .as_ref() + .and_then(|s| serde_yaml::from_str(s).ok()); + + debug!("usage preferences: {:?}", usage_preferences); + let mut selected_llm = match router_service - .determine_route(&chat_completion_request.messages, trace_parent.clone()) + .determine_route( + &chat_completion_request.messages, + trace_parent.clone(), + usage_preferences, + ) .await { Ok(route) => route, diff --git a/crates/brightstaff/src/handlers/preferences.rs b/crates/brightstaff/src/handlers/preferences.rs index c4d365e7..9dd68dd1 100644 --- a/crates/brightstaff/src/handlers/preferences.rs +++ b/crates/brightstaff/src/handlers/preferences.rs @@ -1,19 +1,10 @@ use bytes::Bytes; -use common::configuration::LlmProvider; +use common::configuration::{LlmProvider, ModelUsagePreference}; use http_body_util::{combinators::BoxBody, BodyExt, Full}; use hyper::{Request, Response, StatusCode}; -use serde::{Deserialize, Serialize}; use serde_json; -use tracing::{info, warn}; use std::{collections::HashMap, sync::Arc}; -use serde_with::skip_serializing_none; - -#[skip_serializing_none] -#[derive(Serialize, Deserialize)] -struct UsageBasedProvider { - model: String, - usage: Option, -} +use tracing::{info, warn}; pub async fn list_preferences( llm_providers: Arc>>, @@ -22,11 +13,11 @@ pub async fn list_preferences( // convert the LlmProvider to UsageBasedProvider let providers_with_usage = prov .iter() - .map(|provider| UsageBasedProvider { + .map(|provider| ModelUsagePreference { model: provider.name.clone(), usage: provider.usage.clone(), }) - .collect::>(); + .collect::>(); match serde_json::to_string(&providers_with_usage) { Ok(json) => { @@ -60,7 +51,7 @@ pub async fn update_preferences( ) -> Result>, hyper::Error> { let request_body = request.collect().await?.to_bytes(); - let usage: Vec = match serde_json::from_slice(&request_body) { + let usage: Vec = match serde_json::from_slice(&request_body) { Ok(usage) => usage, Err(_) => { let response_body = Full::new(Bytes::from_static(b"Invalid request body: ")) @@ -74,10 +65,13 @@ pub async fn update_preferences( } }; - let usage_model_map: HashMap = + let usage_model_map: HashMap = usage.into_iter().map(|u| (u.model.clone(), u)).collect(); - info!("Updating usage preferences for models: {:?}", usage_model_map.keys()); + info!( + "Updating usage preferences for models: {:?}", + usage_model_map.keys() + ); let mut llm_providers = llm_providers.write().await; @@ -106,7 +100,7 @@ pub async fn update_preferences( for provider in llm_providers.iter_mut() { if let Some(usage_provider) = usage_model_map.get(&provider.name) { provider.usage = usage_provider.usage.clone(); - updated_models_list.push(UsageBasedProvider { + updated_models_list.push(ModelUsagePreference { model: provider.name.clone(), usage: provider.usage.clone(), }); @@ -121,11 +115,11 @@ pub async fn update_preferences( ))) .map_err(|never| match never {}) .boxed(); - return Ok(Response::builder() + Ok(Response::builder() .status(StatusCode::OK) .header("Content-Type", "application/json") .body(response_body) - .unwrap()); + .unwrap()) } else { let response_body = Full::new(Bytes::from_static(b"Provider not found")) .map_err(|never| match never {}) diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index 65b47cd5..25ea72ff 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -103,10 +103,12 @@ async fn main() -> Result<(), Box> { .with_context(parent_cx) .await } - (&Method::GET, "/v1/router/preferences") => Ok(list_preferences(llm_providers).await), - (&Method::PUT, "/v1/router/preferences") => { - update_preferences(req, llm_providers).await - }, + (&Method::GET, "/v1/router/preferences") => { + Ok(list_preferences(llm_providers).await) + } + (&Method::PUT, "/v1/router/preferences") => { + update_preferences(req, llm_providers).await + } (&Method::GET, "/v1/models") => Ok(list_models(llm_providers).await), (&Method::OPTIONS, "/v1/models") => { let mut response = Response::new(empty()); diff --git a/crates/brightstaff/src/router/llm_router.rs b/crates/brightstaff/src/router/llm_router.rs index 4a510caa..c72b19e9 100644 --- a/crates/brightstaff/src/router/llm_router.rs +++ b/crates/brightstaff/src/router/llm_router.rs @@ -1,7 +1,7 @@ use std::sync::Arc; use common::{ - configuration::{LlmProvider, LlmRoute}, + configuration::{LlmProvider, LlmRoute, ModelUsagePreference}, consts::ARCH_PROVIDER_HINT_HEADER, }; use hermesllm::providers::openai::types::{ChatCompletionsResponse, ContentType, Message}; @@ -68,12 +68,15 @@ impl RouterService { &self, messages: &[Message], trace_parent: Option, + usage_preferences: Option>, ) -> Result> { if !self.llm_usage_defined { return Ok(None); } - let router_request = self.router_model.generate_request(messages); + let router_request = self + .router_model + .generate_request(messages, usage_preferences); info!( "sending request to arch-router model: {}, endpoint: {}", diff --git a/crates/brightstaff/src/router/router_model.rs b/crates/brightstaff/src/router/router_model.rs index c2ed43c9..b377b3a3 100644 --- a/crates/brightstaff/src/router/router_model.rs +++ b/crates/brightstaff/src/router/router_model.rs @@ -1,3 +1,4 @@ +use common::configuration::ModelUsagePreference; use hermesllm::providers::openai::types::{ChatCompletionsRequest, Message}; use thiserror::Error; @@ -10,7 +11,11 @@ pub enum RoutingModelError { pub type Result = std::result::Result; pub trait RouterModel: Send + Sync { - fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest; + fn generate_request( + &self, + messages: &[Message], + usage_preferences: Option>, + ) -> ChatCompletionsRequest; fn parse_response(&self, content: &str) -> Result>; fn get_model_name(&self) -> String; } diff --git a/crates/brightstaff/src/router/router_model_v1.rs b/crates/brightstaff/src/router/router_model_v1.rs index dc623f0a..7dd57223 100644 --- a/crates/brightstaff/src/router/router_model_v1.rs +++ b/crates/brightstaff/src/router/router_model_v1.rs @@ -1,5 +1,5 @@ use common::{ - configuration::LlmRoute, + configuration::{LlmRoute, ModelUsagePreference}, consts::{SYSTEM_ROLE, TOOL_ROLE, USER_ROLE}, }; use hermesllm::providers::openai::types::{ChatCompletionsRequest, ContentType, Message}; @@ -55,7 +55,11 @@ struct LlmRouterResponse { const TOKEN_LENGTH_DIVISOR: usize = 4; // Approximate token length divisor for UTF-8 characters impl RouterModel for RouterModelV1 { - fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest { + fn generate_request( + &self, + messages: &[Message], + usage_preferences: Option>, + ) -> ChatCompletionsRequest { // remove system prompt, tool calls, tool call response and messages without content // if content is empty its likely a tool call // when role == tool its tool call response @@ -131,8 +135,13 @@ impl RouterModel for RouterModelV1 { }) .collect::>(); + let llm_route_json = usage_preferences + .as_ref() + .map(|prefs| serde_json::to_string(prefs).unwrap_or_default()) + .unwrap_or_else(|| self.llm_route_json_str.clone()); + let messages_content = ARCH_ROUTER_V1_SYSTEM_PROMPT - .replace("{routes}", &self.llm_route_json_str) + .replace("{routes}", &llm_route_json) .replace( "{conversation}", &serde_json::to_string(&selected_conversation_list).unwrap_or_default(), @@ -204,8 +213,6 @@ impl std::fmt::Debug for dyn RouterModel { #[cfg(test)] mod tests { - use crate::utils::tracing::init_tracer; - use super::*; use pretty_assertions::assert_eq; @@ -261,7 +268,7 @@ Based on your analysis, provide your response in the following JSON formats if y "#; let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); - let req = router.generate_request(&conversation); + let req = router.generate_request(&conversation, None); let prompt = req.messages[0].content.as_ref().unwrap(); @@ -270,7 +277,6 @@ Based on your analysis, provide your response in the following JSON formats if y #[test] fn test_conversation_exceed_token_count() { - let _tracer = init_tracer(); let expected_prompt = r#" You are a helpful assistant designed to find the best suited route. You are provided with route description within XML tags: @@ -323,7 +329,7 @@ Based on your analysis, provide your response in the following JSON formats if y let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); - let req = router.generate_request(&conversation); + let req = router.generate_request(&conversation, None); let prompt = req.messages[0].content.as_ref().unwrap(); @@ -332,7 +338,6 @@ Based on your analysis, provide your response in the following JSON formats if y #[test] fn test_conversation_exceed_token_count_large_single_message() { - let _tracer = init_tracer(); let expected_prompt = r#" You are a helpful assistant designed to find the best suited route. You are provided with route description within XML tags: @@ -385,7 +390,7 @@ Based on your analysis, provide your response in the following JSON formats if y let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); - let req = router.generate_request(&conversation); + let req = router.generate_request(&conversation, None); let prompt = req.messages[0].content.as_ref().unwrap(); @@ -394,7 +399,6 @@ Based on your analysis, provide your response in the following JSON formats if y #[test] fn test_conversation_trim_upto_user_message() { - let _tracer = init_tracer(); let expected_prompt = r#" You are a helpful assistant designed to find the best suited route. You are provided with route description within XML tags: @@ -455,7 +459,7 @@ Based on your analysis, provide your response in the following JSON formats if y let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); - let req = router.generate_request(&conversation); + let req = router.generate_request(&conversation, None); let prompt = req.messages[0].content.as_ref().unwrap(); @@ -525,7 +529,7 @@ Based on your analysis, provide your response in the following JSON formats if y "#; let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); - let req = router.generate_request(&conversation); + let req = router.generate_request(&conversation, None); let prompt = req.messages[0].content.as_ref().unwrap(); @@ -621,7 +625,7 @@ Based on your analysis, provide your response in the following JSON formats if y let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); - let req = router.generate_request(&conversation); + let req = router.generate_request(&conversation, None); let prompt = req.messages[0].content.as_ref().unwrap(); diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 0dbd0b70..f46b6cc6 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -2,6 +2,7 @@ use hermesllm::providers::openai::types::{ModelDetail, ModelObject, Models}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::fmt::Display; +use serde_with::skip_serializing_none; use crate::api::open_ai::{ ChatCompletionTool, FunctionDefinition, FunctionParameter, FunctionParameters, ParameterType, @@ -176,6 +177,13 @@ impl Display for LlmProviderType { } } +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug)] +pub struct ModelUsagePreference { + pub model: String, + pub usage: Option, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct LlmRoute { pub name: String, diff --git a/crates/hermesllm/src/providers/openai/builder.rs b/crates/hermesllm/src/providers/openai/builder.rs index 43c4176f..fa1f325e 100644 --- a/crates/hermesllm/src/providers/openai/builder.rs +++ b/crates/hermesllm/src/providers/openai/builder.rs @@ -101,6 +101,7 @@ impl OpenAIRequestBuilder { frequency_penalty: self.frequency_penalty, stream_options: self.stream_options, tools: self.tools, + metadata: None, }; Ok(request) } diff --git a/crates/hermesllm/src/providers/openai/types.rs b/crates/hermesllm/src/providers/openai/types.rs index 6f4e38d7..d1c4430c 100644 --- a/crates/hermesllm/src/providers/openai/types.rs +++ b/crates/hermesllm/src/providers/openai/types.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::fmt::Display; use serde::{Deserialize, Serialize}; @@ -109,6 +110,7 @@ pub struct ChatCompletionsRequest { pub frequency_penalty: Option, pub stream_options: Option, pub tools: Option>, + pub metadata: Option>, } impl TryFrom<&[u8]> for ChatCompletionsRequest {