diff --git a/crates/brightstaff/src/handlers/chat_completions.rs b/crates/brightstaff/src/handlers/chat_completions.rs index c2e7e78b..fff07c22 100644 --- a/crates/brightstaff/src/handlers/chat_completions.rs +++ b/crates/brightstaff/src/handlers/chat_completions.rs @@ -5,7 +5,7 @@ use common::configuration::ModelUsagePreference; use common::consts::ARCH_PROVIDER_HINT_HEADER; use hermesllm::apis::openai::ChatCompletionsRequest; use hermesllm::clients::SupportedAPIs; -use hermesllm::ProviderRequestType; +use hermesllm::{ProviderRequest, ProviderRequestType}; use http_body_util::combinators::BoxBody; use http_body_util::{BodyExt, Full, StreamBody}; use hyper::body::Frame; @@ -29,13 +29,13 @@ pub async fn chat( router_service: Arc, full_qualified_llm_provider_url: String, ) -> Result>, hyper::Error> { + let request_path = request.uri().path().to_string(); let mut request_headers = request.headers().clone(); - let chat_request_bytes = request.collect().await?.to_bytes(); debug!("Received request body (raw utf8): {}", String::from_utf8_lossy(&chat_request_bytes)); - let provider_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &SupportedAPIs::from_endpoint(request_path.as_str()).unwrap())) { + let mut client_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &SupportedAPIs::from_endpoint(request_path.as_str()).unwrap())) { Ok(request) => request, Err(err) => { warn!("Failed to parse request as ProviderRequestType: {}", err); @@ -46,9 +46,18 @@ pub async fn chat( } }; - // Convert to ChatCompletionsRequest regardless of input type + // Clone metadata for routing and remove archgw_preference_config from original + let routing_metadata = client_request.metadata().clone(); + + if client_request.remove_metadata_key("archgw_preference_config") { + debug!("Removed archgw_preference_config from metadata"); + } + + let client_request_bytes_for_upstream = ProviderRequestType::to_bytes(&client_request).unwrap(); + + // Convert to ChatCompletionsRequest regardless of input type (clone to avoid moving original) let chat_completions_request_for_arch_router: ChatCompletionsRequest = - match ProviderRequestType::try_from((provider_request, &SupportedAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions))) { + match ProviderRequestType::try_from((client_request, &SupportedAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions))) { Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req, Ok(ProviderRequestType::MessagesRequest(_)) => { // This should not happen after conversion to OpenAI format @@ -67,7 +76,6 @@ pub async fn chat( } }; - debug!( "[BRIGHTSTAFF -> ARCH_ROUTER] REQ: {}", &serde_json::to_string(&chat_completions_request_for_arch_router).unwrap() @@ -79,7 +87,7 @@ pub async fn chat( .map(|(_, value)| value.to_str().unwrap_or_default().to_string()); let usage_preferences_str: Option = - chat_completions_request_for_arch_router.metadata.as_ref().and_then(|metadata| { + routing_metadata.as_ref().and_then(|metadata| { metadata .get("archgw_preference_config") .map(|value| value.to_string()) @@ -155,30 +163,13 @@ pub async fn chat( header::HeaderValue::from_str(&trace_parent).unwrap(), ); } - - // remove metadata from the request for downstream calls - let mut chat_request_user_preferences_removed = chat_completions_request_for_arch_router.clone(); - if let Some(ref mut metadata) = chat_request_user_preferences_removed.metadata { - metadata.remove("archgw_preference_config"); - debug!("Removed archgw_preference_config from metadata"); - - // if metadata is empty, remove it - if metadata.is_empty() { - chat_request_user_preferences_removed.metadata = None; - debug!("Removed empty metadata from request"); - } - } - - let chat_request_parsed_bytes = - serde_json::to_string(&chat_request_user_preferences_removed).unwrap(); - // remove content-length header if it exists request_headers.remove(header::CONTENT_LENGTH); let llm_response = match reqwest::Client::new() .post(full_qualified_llm_provider_url) .headers(request_headers) - .body(chat_request_parsed_bytes) + .body(client_request_bytes_for_upstream) .send() .await { diff --git a/crates/hermesllm/src/apis/anthropic.rs b/crates/hermesllm/src/apis/anthropic.rs index 4125f8d7..ae61e2fe 100644 --- a/crates/hermesllm/src/apis/anthropic.rs +++ b/crates/hermesllm/src/apis/anthropic.rs @@ -497,6 +497,18 @@ impl ProviderRequest for MessagesRequest { source: Some(Box::new(e)), }) } + + fn metadata(&self) -> &Option> { + return &self.metadata; + } + + fn remove_metadata_key(&mut self, key: &str) -> bool { + if let Some(ref mut metadata) = self.metadata { + metadata.remove(key).is_some() + } else { + false + } + } } impl MessagesResponse { diff --git a/crates/hermesllm/src/apis/openai.rs b/crates/hermesllm/src/apis/openai.rs index 85e0db34..7e89acd2 100644 --- a/crates/hermesllm/src/apis/openai.rs +++ b/crates/hermesllm/src/apis/openai.rs @@ -81,7 +81,7 @@ pub struct ChatCompletionsRequest { // Maximum tokens in the response has been deprecated, but we keep it for compatibility pub max_tokens: Option, pub modalities: Option>, - pub metadata: Option>, + pub metadata: Option>, pub n: Option, pub presence_penalty: Option, pub parallel_tool_calls: Option, @@ -599,6 +599,18 @@ impl ProviderRequest for ChatCompletionsRequest { source: Some(Box::new(e)), }) } + + fn metadata(&self) -> &Option> { + return &self.metadata; + } + + fn remove_metadata_key(&mut self, key: &str) -> bool { + if let Some(ref mut metadata) = self.metadata { + metadata.remove(key).is_some() + } else { + false + } + } } /// Implementation of ProviderResponse for ChatCompletionsResponse diff --git a/crates/hermesllm/src/providers/request.rs b/crates/hermesllm/src/providers/request.rs index b3f46ee1..adde81f4 100644 --- a/crates/hermesllm/src/providers/request.rs +++ b/crates/hermesllm/src/providers/request.rs @@ -1,8 +1,12 @@ use crate::apis::openai::ChatCompletionsRequest; use crate::apis::anthropic::MessagesRequest; use crate::clients::endpoints::SupportedAPIs; + +use serde_json::Value; use std::error::Error; use std::fmt; +use std::collections::HashMap; +#[derive(Clone)] pub enum ProviderRequestType { ChatCompletionsRequest(ChatCompletionsRequest), MessagesRequest(MessagesRequest), @@ -26,6 +30,11 @@ pub trait ProviderRequest: Send + Sync { /// Convert the request to bytes for transmission fn to_bytes(&self) -> Result, ProviderRequestError>; + + fn metadata(&self) -> &Option>; + + /// Remove a metadata key from the request and return true if the key was present + fn remove_metadata_key(&mut self, key: &str) -> bool; } impl ProviderRequest for ProviderRequestType { @@ -70,6 +79,20 @@ impl ProviderRequest for ProviderRequestType { Self::MessagesRequest(r) => r.to_bytes(), } } + + fn metadata(&self) -> &Option> { + match self { + Self::ChatCompletionsRequest(r) => r.metadata(), + Self::MessagesRequest(r) => r.metadata(), + } + } + + fn remove_metadata_key(&mut self, key: &str) -> bool { + match self { + Self::ChatCompletionsRequest(r) => r.remove_metadata_key(key), + Self::MessagesRequest(r) => r.remove_metadata_key(key), + } + } } /// Parse the client API from a byte slice.