diff --git a/crates/Cargo.lock b/crates/Cargo.lock index 789bfc69..1f33b4c2 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -225,6 +225,7 @@ dependencies = [ "reqwest", "serde", "serde_json", + "serde_with", "serde_yaml", "thiserror 2.0.12", "tokio", @@ -335,6 +336,7 @@ dependencies = [ "rand 0.8.5", "serde", "serde_json", + "serde_with", "serde_yaml", "thiserror 1.0.69", "tiktoken-rs", @@ -664,6 +666,12 @@ dependencies = [ "serde", ] +[[package]] +name = "dyn-clone" +version = "1.0.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c7a8fb8a9fbf66c1f703fe16184d10ca0ee9d23be5b4436400408ba54a95005" + [[package]] name = "either" version = "1.15.0" @@ -2305,6 +2313,26 @@ dependencies = [ "thiserror 1.0.69", ] +[[package]] +name = "ref-cast" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4a0ae411dbe946a674d89546582cea4ba2bb8defac896622d6496f14c23ba5cf" +dependencies = [ + "ref-cast-impl", +] + +[[package]] +name = "ref-cast-impl" +version = "1.0.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1165225c21bff1f3bbce98f5a1f889949bc902d3575308cc7b0de30b4f6d27c7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.101", +] + [[package]] name = "regalloc2" version = "0.9.3" @@ -2566,6 +2594,18 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "schemars" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd191f9397d57d581cddd31014772520aa448f65ef991055d7f61582c65165f" +dependencies = [ + "dyn-clone", + "ref-cast", + "serde", + "serde_json", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -2675,15 +2715,16 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.12.0" +version = "3.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d6b6f7f2fcb69f747921f79f3926bd1e203fce4fef62c268dd3abfb6d86029aa" +checksum = "bf65a400f8f66fb7b0552869ad70157166676db75ed8181f8104ea91cf9d0b42" dependencies = [ "base64 0.22.1", "chrono", "hex", "indexmap 1.9.3", "indexmap 2.9.0", + "schemars", "serde", "serde_derive", "serde_json", @@ -2693,9 +2734,9 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.12.0" +version = "3.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8d00caa5193a3c8362ac2b73be6b9e768aa5a4b2f721d8f4b339600c3cb51f8e" +checksum = "81679d9ed988d5e9a5e6531dc3f2c28efbd639cbd1dfb628df08edea6004da77" dependencies = [ "darling", "proc-macro2", diff --git a/crates/brightstaff/Cargo.toml b/crates/brightstaff/Cargo.toml index b8c2582c..23f3e8b1 100644 --- a/crates/brightstaff/Cargo.toml +++ b/crates/brightstaff/Cargo.toml @@ -24,6 +24,7 @@ pretty_assertions = "1.4.1" reqwest = { version = "0.12.15", features = ["stream"] } serde = { version = "1.0.219", features = ["derive"] } serde_json = "1.0.140" +serde_with = "3.13.0" serde_yaml = "0.9.34" thiserror = "2.0.12" tokio = { version = "1.44.2", features = ["full"] } diff --git a/crates/brightstaff/src/handlers/chat_completions.rs b/crates/brightstaff/src/handlers/chat_completions.rs index 756e115a..55f6d475 100644 --- a/crates/brightstaff/src/handlers/chat_completions.rs +++ b/crates/brightstaff/src/handlers/chat_completions.rs @@ -1,6 +1,7 @@ use std::sync::Arc; use bytes::Bytes; +use common::configuration::ModelUsagePreference; use common::consts::ARCH_PROVIDER_HINT_HEADER; use hermesllm::providers::openai::types::ChatCompletionsRequest; use http_body_util::combinators::BoxBody; @@ -11,7 +12,7 @@ use hyper::{Request, Response, StatusCode}; use tokio::sync::mpsc; use tokio_stream::wrappers::ReceiverStream; use tokio_stream::StreamExt; -use tracing::{debug, info, warn}; +use tracing::{debug, info, trace, warn}; use crate::router::llm_router::RouterService; @@ -30,23 +31,57 @@ pub async fn chat_completions( let chat_request_bytes = request.collect().await?.to_bytes(); - let chat_completion_request: ChatCompletionsRequest = - match ChatCompletionsRequest::try_from(chat_request_bytes.as_ref()) { - Ok(request) => request, - Err(err) => { - warn!( - "arch-router request body string: {}", - String::from_utf8_lossy(&chat_request_bytes) - ); - let err_msg = format!("Failed to parse request body: {}", err); - warn!("{}", err_msg); - let mut bad_request = Response::new(full(err_msg)); - *bad_request.status_mut() = StatusCode::BAD_REQUEST; - return Ok(bad_request); - } - }; + let chat_request_parsed = serde_json::from_slice::(&chat_request_bytes) + .inspect_err(|err| { + warn!( + "Failed to parse request body as JSON: err: {}, str: {}", + err, + String::from_utf8_lossy(&chat_request_bytes) + ) + }) + .unwrap_or_else(|_| { + warn!( + "Failed to parse request body as JSON: {}", + String::from_utf8_lossy(&chat_request_bytes) + ); + serde_json::Value::Null + }); - debug!( + if chat_request_parsed == serde_json::Value::Null { + warn!("Request body is not valid JSON"); + let err_msg = "Request body is not valid JSON".to_string(); + let mut bad_request = Response::new(full(err_msg)); + *bad_request.status_mut() = StatusCode::BAD_REQUEST; + return Ok(bad_request); + } + + let chat_completion_request: ChatCompletionsRequest = + serde_json::from_value(chat_request_parsed.clone()).unwrap(); + + // remove metadata from the request + let mut chat_request_user_preferences_removed = chat_request_parsed; + if let Some(metadata) = chat_request_user_preferences_removed.get_mut("metadata") { + info!("Removing metadata from request"); + if let Some(m) = metadata.as_object_mut() { + m.remove("archgw_preference_config"); + info!("Removed archgw_preference_config from metadata"); + } + + // metadata.as_object_mut().map(|m| { + // m.remove("archgw_preference_config"); + // info!("Removed archgw_preference_config from metadata"); + // }); + + // if metadata is empty, remove it + if metadata.as_object().map_or(false, |m| m.is_empty()) { + info!("Removing empty metadata from request"); + chat_request_user_preferences_removed + .as_object_mut() + .map(|m| m.remove("metadata")); + } + } + + trace!( "arch-router request body: {}", &serde_json::to_string(&chat_completion_request).unwrap() ); @@ -56,8 +91,25 @@ pub async fn chat_completions( .find(|(ty, _)| ty.as_str() == "traceparent") .map(|(_, value)| value.to_str().unwrap_or_default().to_string()); + let usage_preferences_str: Option = + chat_completion_request.metadata.and_then(|metadata| { + metadata + .get("archgw_preference_config") + .and_then(|value| value.as_str().map(String::from)) + }); + + let usage_preferences: Option> = usage_preferences_str + .as_ref() + .and_then(|s| serde_yaml::from_str(s).ok()); + + debug!("usage preferences: {:?}", usage_preferences); + let mut selected_llm = match router_service - .determine_route(&chat_completion_request.messages, trace_parent.clone()) + .determine_route( + &chat_completion_request.messages, + trace_parent.clone(), + usage_preferences, + ) .await { Ok(route) => route, @@ -93,10 +145,16 @@ pub async fn chat_completions( ); } + let chat_request_parsed_bytes = + serde_json::to_string(&chat_request_user_preferences_removed).unwrap(); + + // remove content-length header if it exists + request_headers.remove(header::CONTENT_LENGTH); + let llm_response = match reqwest::Client::new() .post(llm_provider_endpoint) .headers(request_headers) - .body(chat_request_bytes) + .body(chat_request_parsed_bytes) .send() .await { diff --git a/crates/brightstaff/src/handlers/mod.rs b/crates/brightstaff/src/handlers/mod.rs index 6de38b5b..febab6c2 100644 --- a/crates/brightstaff/src/handlers/mod.rs +++ b/crates/brightstaff/src/handlers/mod.rs @@ -1,2 +1,3 @@ pub mod chat_completions; pub mod models; +pub mod preferences; diff --git a/crates/brightstaff/src/handlers/models.rs b/crates/brightstaff/src/handlers/models.rs index 5e4b55b2..3a4662a6 100644 --- a/crates/brightstaff/src/handlers/models.rs +++ b/crates/brightstaff/src/handlers/models.rs @@ -7,10 +7,10 @@ use serde_json; use std::sync::Arc; pub async fn list_models( - llm_providers: Arc>, + llm_providers: Arc>>, ) -> Response> { - let prov = llm_providers.clone(); - let providers = (*prov).clone(); + let prov = llm_providers.read().await; + let providers = prov.clone(); let openai_models: Models = providers.into_models(); match serde_json::to_string(&openai_models) { diff --git a/crates/brightstaff/src/handlers/preferences.rs b/crates/brightstaff/src/handlers/preferences.rs new file mode 100644 index 00000000..a9c5a65d --- /dev/null +++ b/crates/brightstaff/src/handlers/preferences.rs @@ -0,0 +1,135 @@ +use bytes::Bytes; +use common::configuration::{LlmProvider, ModelUsagePreference}; +use http_body_util::{combinators::BoxBody, BodyExt, Full}; +use hyper::{Request, Response, StatusCode}; +use serde_json; +use std::{collections::HashMap, sync::Arc}; +use tracing::{info, warn}; + +pub async fn list_preferences( + llm_providers: Arc>>, +) -> Response> { + let prov = llm_providers.read().await; + // convert the LlmProvider to UsageBasedProvider + let providers_with_usage = prov + .iter() + .map(|provider| ModelUsagePreference { + name: provider.name.clone(), + model: provider.model.clone().unwrap_or_default(), + usage: provider.usage.clone(), + }) + .collect::>(); + + match serde_json::to_string(&providers_with_usage) { + Ok(json) => { + let body = Full::new(Bytes::from(json)) + .map_err(|never| match never {}) + .boxed(); + Response::builder() + .status(StatusCode::OK) + .header("Content-Type", "application/json") + .body(body) + .unwrap() + } + Err(_) => { + let body = Full::new(Bytes::from_static( + b"{\"error\":\"Failed to serialize models\"}", + )) + .map_err(|never| match never {}) + .boxed(); + Response::builder() + .status(StatusCode::INTERNAL_SERVER_ERROR) + .header("Content-Type", "application/json") + .body(body) + .unwrap() + } + } +} + +pub async fn update_preferences( + request: Request, + llm_providers: Arc>>, +) -> Result>, hyper::Error> { + let request_body = request.collect().await?.to_bytes(); + + let usage: Vec = match serde_json::from_slice(&request_body) { + Ok(usage) => usage, + Err(_) => { + let response_body = Full::new(Bytes::from_static(b"Invalid request body: ")) + .map_err(|never| match never {}) + .boxed(); + return Ok(Response::builder() + .status(StatusCode::BAD_REQUEST) + .header("Content-Type", "text/plain") + .body(response_body) + .unwrap()); + } + }; + + let usage_model_map: HashMap = + usage.into_iter().map(|u| (u.model.clone(), u)).collect(); + + info!( + "Updating usage preferences for models: {:?}", + usage_model_map.keys() + ); + + let mut llm_providers = llm_providers.write().await; + + // ensure that models coming in the request are valid + let llm_provider_names: Vec = llm_providers + .iter() + .map(|provider| provider.name.clone()) + .collect(); + + for model in usage_model_map.keys() { + if !llm_provider_names.contains(model) { + let model_not_found = format!("model not found: {}", model); + warn!("updating preferences: {}", model_not_found); + let response_body = Full::new(model_not_found.into()) + .map_err(|never| match never {}) + .boxed(); + return Ok(Response::builder() + .status(StatusCode::BAD_REQUEST) + .header("Content-Type", "text/plain") + .body(response_body) + .unwrap()); + } + } + + let mut updated_models_list = Vec::new(); + for provider in llm_providers.iter_mut() { + if let Some(usage_provider) = usage_model_map.get(&provider.name) { + provider.usage = usage_provider.usage.clone(); + updated_models_list.push(ModelUsagePreference { + name: provider.name.clone(), + model: provider.model.clone().unwrap_or_default(), + usage: provider.usage.clone(), + }); + } + } + + if !updated_models_list.is_empty() { + // return list of updated models + let response_body = Full::new(Bytes::from(format!( + "{{\"updated_models\": {}}}", + serde_json::to_string(&updated_models_list).unwrap() + ))) + .map_err(|never| match never {}) + .boxed(); + Ok(Response::builder() + .status(StatusCode::OK) + .header("Content-Type", "application/json") + .body(response_body) + .unwrap()) + } else { + let response_body = Full::new(Bytes::from_static(b"Provider not found")) + .map_err(|never| match never {}) + .boxed(); + Ok(Response::builder() + .status(StatusCode::NOT_FOUND) + .header("Content-Type", "text/plain") + .body(response_body) + .unwrap()) + } +} diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index 5502c983..25ea72ff 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -1,5 +1,6 @@ use brightstaff::handlers::chat_completions::chat_completions; use brightstaff::handlers::models::list_models; +use brightstaff::handlers::preferences::{list_preferences, update_preferences}; use brightstaff::router::llm_router::RouterService; use brightstaff::utils::tracing::init_tracer; use bytes::Bytes; @@ -16,7 +17,8 @@ use opentelemetry_http::HeaderExtractor; use std::sync::Arc; use std::{env, fs}; use tokio::net::TcpListener; -use tracing::{debug, info}; +use tokio::sync::RwLock; +use tracing::{debug, info, warn}; pub mod router; @@ -53,7 +55,7 @@ async fn main() -> Result<(), Box> { let arch_config = Arc::new(config); - let llm_providers = Arc::new(arch_config.llm_providers.clone()); + let llm_providers = Arc::new(RwLock::new(arch_config.llm_providers.clone())); debug!( "arch_config: {:?}", @@ -101,6 +103,12 @@ async fn main() -> Result<(), Box> { .with_context(parent_cx) .await } + (&Method::GET, "/v1/router/preferences") => { + Ok(list_preferences(llm_providers).await) + } + (&Method::PUT, "/v1/router/preferences") => { + update_preferences(req, llm_providers).await + } (&Method::GET, "/v1/models") => Ok(list_models(llm_providers).await), (&Method::OPTIONS, "/v1/models") => { let mut response = Response::new(empty()); @@ -141,7 +149,7 @@ async fn main() -> Result<(), Box> { .serve_connection(io, service) .await { - info!("Error serving connection: {:?}", err); + warn!("Error serving connection: {:?}", err); } }); } diff --git a/crates/brightstaff/src/router/llm_router.rs b/crates/brightstaff/src/router/llm_router.rs index 4a510caa..0dab0a18 100644 --- a/crates/brightstaff/src/router/llm_router.rs +++ b/crates/brightstaff/src/router/llm_router.rs @@ -1,7 +1,7 @@ -use std::sync::Arc; +use std::{collections::HashMap, sync::Arc}; use common::{ - configuration::{LlmProvider, LlmRoute}, + configuration::{LlmProvider, LlmRoute, ModelUsagePreference}, consts::ARCH_PROVIDER_HINT_HEADER, }; use hermesllm::providers::openai::types::{ChatCompletionsResponse, ContentType, Message}; @@ -19,6 +19,7 @@ pub struct RouterService { router_model: Arc, routing_model_name: String, llm_usage_defined: bool, + llm_provider_map: HashMap, } #[derive(Debug, Error)] @@ -55,12 +56,18 @@ impl RouterService { router_model_v1::MAX_TOKEN_LEN, )); + let llm_provider_map: HashMap = providers + .into_iter() + .map(|provider| (provider.name.clone(), provider)) + .collect(); + RouterService { router_url, client: reqwest::Client::new(), router_model, routing_model_name, llm_usage_defined: !providers_with_usage.is_empty(), + llm_provider_map, } } @@ -68,12 +75,15 @@ impl RouterService { &self, messages: &[Message], trace_parent: Option, + usage_preferences: Option>, ) -> Result> { if !self.llm_usage_defined { return Ok(None); } - let router_request = self.router_model.generate_request(messages); + let router_request = self + .router_model + .generate_request(messages, &usage_preferences); info!( "sending request to arch-router model: {}, endpoint: {}", @@ -144,13 +154,40 @@ impl RouterService { if let Some(ContentType::Text(content)) = &chat_completion_response.choices[0].message.content { + let mut selected_model: Option = None; + if let Some(selected_llm_name) = self.router_model.parse_response(content)? { + if selected_llm_name != "other" { + if let Some(usage_preferences) = usage_preferences { + for usage in usage_preferences { + if usage.name == selected_llm_name { + selected_model = Some(usage.model); + break; + } + } + if selected_model.is_none() { + warn!( + "Selected LLM model not found in usage preferences: {}", + selected_llm_name + ); + } + } else if let Some(provider) = self.llm_provider_map.get(&selected_llm_name) { + selected_model = provider.model.clone(); + } else { + warn!( + "Selected LLM model not found in provider map: {}", + selected_llm_name + ); + } + } + } info!( - "router response: {}, response time: {}ms", + "router response: {}, selected_model: {:?}, response time: {}ms", content.replace("\n", "\\n"), + selected_model, router_response_time.as_millis() ); - let selected_llm = self.router_model.parse_response(content)?; - Ok(selected_llm) + + Ok(selected_model) } else { Ok(None) } diff --git a/crates/brightstaff/src/router/router_model.rs b/crates/brightstaff/src/router/router_model.rs index c2ed43c9..dafa8776 100644 --- a/crates/brightstaff/src/router/router_model.rs +++ b/crates/brightstaff/src/router/router_model.rs @@ -1,3 +1,4 @@ +use common::configuration::ModelUsagePreference; use hermesllm::providers::openai::types::{ChatCompletionsRequest, Message}; use thiserror::Error; @@ -10,7 +11,11 @@ pub enum RoutingModelError { pub type Result = std::result::Result; pub trait RouterModel: Send + Sync { - fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest; + fn generate_request( + &self, + messages: &[Message], + usage_preferences: &Option>, + ) -> ChatCompletionsRequest; fn parse_response(&self, content: &str) -> Result>; fn get_model_name(&self) -> String; } diff --git a/crates/brightstaff/src/router/router_model_v1.rs b/crates/brightstaff/src/router/router_model_v1.rs index dc623f0a..e6ccd912 100644 --- a/crates/brightstaff/src/router/router_model_v1.rs +++ b/crates/brightstaff/src/router/router_model_v1.rs @@ -1,5 +1,5 @@ use common::{ - configuration::LlmRoute, + configuration::{LlmRoute, ModelUsagePreference}, consts::{SYSTEM_ROLE, TOOL_ROLE, USER_ROLE}, }; use hermesllm::providers::openai::types::{ChatCompletionsRequest, ContentType, Message}; @@ -55,7 +55,11 @@ struct LlmRouterResponse { const TOKEN_LENGTH_DIVISOR: usize = 4; // Approximate token length divisor for UTF-8 characters impl RouterModel for RouterModelV1 { - fn generate_request(&self, messages: &[Message]) -> ChatCompletionsRequest { + fn generate_request( + &self, + messages: &[Message], + usage_preferences: &Option>, + ) -> ChatCompletionsRequest { // remove system prompt, tool calls, tool call response and messages without content // if content is empty its likely a tool call // when role == tool its tool call response @@ -131,8 +135,22 @@ impl RouterModel for RouterModelV1 { }) .collect::>(); + let llm_route_json = usage_preferences + .as_ref() + .map(|prefs| { + let llm_route: Vec = prefs + .iter() + .map(|pref| LlmRoute { + name: pref.name.clone(), + description: pref.usage.clone().unwrap_or_default(), + }) + .collect(); + serde_json::to_string(&llm_route).unwrap_or_default() + }) + .unwrap_or_else(|| self.llm_route_json_str.clone()); + let messages_content = ARCH_ROUTER_V1_SYSTEM_PROMPT - .replace("{routes}", &self.llm_route_json_str) + .replace("{routes}", &llm_route_json) .replace( "{conversation}", &serde_json::to_string(&selected_conversation_list).unwrap_or_default(), @@ -204,8 +222,6 @@ impl std::fmt::Debug for dyn RouterModel { #[cfg(test)] mod tests { - use crate::utils::tracing::init_tracer; - use super::*; use pretty_assertions::assert_eq; @@ -261,7 +277,71 @@ Based on your analysis, provide your response in the following JSON formats if y "#; let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); - let req = router.generate_request(&conversation); + let req = router.generate_request(&conversation, &None); + + let prompt = req.messages[0].content.as_ref().unwrap(); + + assert_eq!(expected_prompt, prompt.to_string()); + } + + #[test] + fn test_system_prompt_format_usage_preferences() { + let expected_prompt = r#" +You are a helpful assistant designed to find the best suited route. +You are provided with route description within XML tags: + +[{"name":"code-generation","description":"generating new code snippets, functions, or boilerplate based on user prompts or requirements"}] + + + +[{"role":"user","content":"hi"},{"role":"assistant","content":"Hello! How can I assist you today?"},{"role":"user","content":"given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson"}] + + +Your task is to decide which route is best suit with user intent on the conversation in XML tags. Follow the instruction: +1. If the latest intent from user is irrelevant or user intent is full filled, response with other route {"route": "other"}. +2. You must analyze the route descriptions and find the best match route for user latest intent. +3. You only response the name of the route that best matches the user's request, use the exact name in the . + +Based on your analysis, provide your response in the following JSON formats if you decide to match any route: +{"route": "route_name"} +"#; + let routes_str = r#" + [ + {"name": "Image generation", "description": "generating image"}, + {"name": "image conversion", "description": "convert images to provided format"}, + {"name": "image search", "description": "search image"}, + {"name": "Audio Processing", "description": "Analyzing and interpreting audio input including speech, music, and environmental sounds"}, + {"name": "Speech Recognition", "description": "Converting spoken language into written text"} + ] + "#; + let llm_routes = serde_json::from_str::>(routes_str).unwrap(); + let routing_model = "test-model".to_string(); + let router = RouterModelV1::new(llm_routes, routing_model.clone(), usize::MAX); + + let conversation_str = r#" + [ + { + "role": "user", + "content": "hi" + }, + { + "role": "assistant", + "content": "Hello! How can I assist you today?" + }, + { + "role": "user", + "content": "given the image In style of Andy Warhol, portrait of Bart and Lisa Simpson" + } + ] + "#; + let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); + + let usage_preferences = Some(vec![ModelUsagePreference { + name: "code-generation".to_string(), + model: "claude/claude-3-7-sonnet".to_string(), + usage: Some("generating new code snippets, functions, or boilerplate based on user prompts or requirements".to_string()), + }]); + let req = router.generate_request(&conversation, &usage_preferences); let prompt = req.messages[0].content.as_ref().unwrap(); @@ -270,7 +350,6 @@ Based on your analysis, provide your response in the following JSON formats if y #[test] fn test_conversation_exceed_token_count() { - let _tracer = init_tracer(); let expected_prompt = r#" You are a helpful assistant designed to find the best suited route. You are provided with route description within XML tags: @@ -323,7 +402,7 @@ Based on your analysis, provide your response in the following JSON formats if y let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); - let req = router.generate_request(&conversation); + let req = router.generate_request(&conversation, &None); let prompt = req.messages[0].content.as_ref().unwrap(); @@ -332,7 +411,6 @@ Based on your analysis, provide your response in the following JSON formats if y #[test] fn test_conversation_exceed_token_count_large_single_message() { - let _tracer = init_tracer(); let expected_prompt = r#" You are a helpful assistant designed to find the best suited route. You are provided with route description within XML tags: @@ -385,7 +463,7 @@ Based on your analysis, provide your response in the following JSON formats if y let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); - let req = router.generate_request(&conversation); + let req = router.generate_request(&conversation, &None); let prompt = req.messages[0].content.as_ref().unwrap(); @@ -394,7 +472,6 @@ Based on your analysis, provide your response in the following JSON formats if y #[test] fn test_conversation_trim_upto_user_message() { - let _tracer = init_tracer(); let expected_prompt = r#" You are a helpful assistant designed to find the best suited route. You are provided with route description within XML tags: @@ -455,7 +532,7 @@ Based on your analysis, provide your response in the following JSON formats if y let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); - let req = router.generate_request(&conversation); + let req = router.generate_request(&conversation, &None); let prompt = req.messages[0].content.as_ref().unwrap(); @@ -525,7 +602,7 @@ Based on your analysis, provide your response in the following JSON formats if y "#; let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); - let req = router.generate_request(&conversation); + let req = router.generate_request(&conversation, &None); let prompt = req.messages[0].content.as_ref().unwrap(); @@ -621,7 +698,7 @@ Based on your analysis, provide your response in the following JSON formats if y let conversation: Vec = serde_json::from_str(conversation_str).unwrap(); - let req = router.generate_request(&conversation); + let req = router.generate_request(&conversation, &None); let prompt = req.messages[0].content.as_ref().unwrap(); diff --git a/crates/common/Cargo.toml b/crates/common/Cargo.toml index 4696b43b..aa95e2e4 100644 --- a/crates/common/Cargo.toml +++ b/crates/common/Cargo.toml @@ -19,6 +19,7 @@ hex = "0.4.3" urlencoding = "2.1.3" url = "2.5.4" hermesllm = { version = "0.1.0", path = "../hermesllm" } +serde_with = "3.13.0" [dev-dependencies] pretty_assertions = "1.4.1" diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 0dbd0b70..80ec98bb 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -1,5 +1,6 @@ use hermesllm::providers::openai::types::{ModelDetail, ModelObject, Models}; use serde::{Deserialize, Serialize}; +use serde_with::skip_serializing_none; use std::collections::HashMap; use std::fmt::Display; @@ -176,6 +177,14 @@ impl Display for LlmProviderType { } } +#[skip_serializing_none] +#[derive(Serialize, Deserialize, Debug)] +pub struct ModelUsagePreference { + pub name: String, + pub model: String, + pub usage: Option, +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct LlmRoute { pub name: String, diff --git a/crates/hermesllm/src/providers/openai/builder.rs b/crates/hermesllm/src/providers/openai/builder.rs index 43c4176f..fa1f325e 100644 --- a/crates/hermesllm/src/providers/openai/builder.rs +++ b/crates/hermesllm/src/providers/openai/builder.rs @@ -101,6 +101,7 @@ impl OpenAIRequestBuilder { frequency_penalty: self.frequency_penalty, stream_options: self.stream_options, tools: self.tools, + metadata: None, }; Ok(request) } diff --git a/crates/hermesllm/src/providers/openai/types.rs b/crates/hermesllm/src/providers/openai/types.rs index 6f4e38d7..d1c4430c 100644 --- a/crates/hermesllm/src/providers/openai/types.rs +++ b/crates/hermesllm/src/providers/openai/types.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::fmt::Display; use serde::{Deserialize, Serialize}; @@ -109,6 +110,7 @@ pub struct ChatCompletionsRequest { pub frequency_penalty: Option, pub stream_options: Option, pub tools: Option>, + pub metadata: Option>, } impl TryFrom<&[u8]> for ChatCompletionsRequest { diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 050afce9..6eebb398 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -569,7 +569,11 @@ impl HttpContext for StreamContext { match SseChatCompletionIter::try_from((body.as_slice(), &hermes_llm_provider)) { Ok(events) => events, Err(e) => { - warn!("could not parse response: {}", e); + warn!( + "could not parse response: {}, body str: {}", + e, + String::from_utf8_lossy(&body) + ); return Action::Continue; } }; @@ -614,7 +618,11 @@ impl HttpContext for StreamContext { match ChatCompletionsResponse::try_from((body.as_slice(), &hermes_llm_provider)) { Ok(de) => de, Err(e) => { - warn!("could not parse response: {}", e); + warn!( + "could not parse response: {}, body str: {}", + e, + String::from_utf8_lossy(&body) + ); debug!( "on_http_response_body: S[{}], response body: {}", self.context_id, diff --git a/crates/prompt_gateway/src/http_context.rs b/crates/prompt_gateway/src/http_context.rs index cd251064..fc66de12 100644 --- a/crates/prompt_gateway/src/http_context.rs +++ b/crates/prompt_gateway/src/http_context.rs @@ -62,7 +62,7 @@ impl HttpContext for StreamContext { return Action::Continue; } - self.is_chat_completions_request = CHAT_COMPLETIONS_PATH.contains(&request_path.as_str()); + self.is_chat_completions_request = CHAT_COMPLETIONS_PATH.contains(request_path.as_str()); debug!( "on_http_request_headers S[{}] req_headers={:?}",