Fixed a bug where we always sent an OpenAIChatCompletions request object to llm_gateway, even when the original client request was an AnthropicMessages request

This commit is contained in:
Salman Paracha 2025-09-09 22:06:25 -07:00
parent 788ff87a0c
commit 32bcb55d97
4 changed files with 64 additions and 26 deletions

View file

@ -5,7 +5,7 @@ use common::configuration::ModelUsagePreference;
use common::consts::ARCH_PROVIDER_HINT_HEADER;
use hermesllm::apis::openai::ChatCompletionsRequest;
use hermesllm::clients::SupportedAPIs;
use hermesllm::ProviderRequestType;
use hermesllm::{ProviderRequest, ProviderRequestType};
use http_body_util::combinators::BoxBody;
use http_body_util::{BodyExt, Full, StreamBody};
use hyper::body::Frame;
@ -29,13 +29,13 @@ pub async fn chat(
router_service: Arc<RouterService>,
full_qualified_llm_provider_url: String,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_path = request.uri().path().to_string();
let mut request_headers = request.headers().clone();
let chat_request_bytes = request.collect().await?.to_bytes();
debug!("Received request body (raw utf8): {}", String::from_utf8_lossy(&chat_request_bytes));
let provider_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &SupportedAPIs::from_endpoint(request_path.as_str()).unwrap())) {
let mut client_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &SupportedAPIs::from_endpoint(request_path.as_str()).unwrap())) {
Ok(request) => request,
Err(err) => {
warn!("Failed to parse request as ProviderRequestType: {}", err);
@ -46,9 +46,18 @@ pub async fn chat(
}
};
// Convert to ChatCompletionsRequest regardless of input type
// Clone metadata for routing and remove archgw_preference_config from original
let routing_metadata = client_request.metadata().clone();
if client_request.remove_metadata_key("archgw_preference_config") {
debug!("Removed archgw_preference_config from metadata");
}
let client_request_bytes_for_upstream = ProviderRequestType::to_bytes(&client_request).unwrap();
// Convert to ChatCompletionsRequest regardless of input type (clone to avoid moving original)
let chat_completions_request_for_arch_router: ChatCompletionsRequest =
match ProviderRequestType::try_from((provider_request, &SupportedAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions))) {
match ProviderRequestType::try_from((client_request, &SupportedAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions))) {
Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req,
Ok(ProviderRequestType::MessagesRequest(_)) => {
// This should not happen after conversion to OpenAI format
@ -67,7 +76,6 @@ pub async fn chat(
}
};
debug!(
"[BRIGHTSTAFF -> ARCH_ROUTER] REQ: {}",
&serde_json::to_string(&chat_completions_request_for_arch_router).unwrap()
@ -79,7 +87,7 @@ pub async fn chat(
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
let usage_preferences_str: Option<String> =
chat_completions_request_for_arch_router.metadata.as_ref().and_then(|metadata| {
routing_metadata.as_ref().and_then(|metadata| {
metadata
.get("archgw_preference_config")
.map(|value| value.to_string())
@ -155,30 +163,13 @@ pub async fn chat(
header::HeaderValue::from_str(&trace_parent).unwrap(),
);
}
// remove metadata from the request for downstream calls
let mut chat_request_user_preferences_removed = chat_completions_request_for_arch_router.clone();
if let Some(ref mut metadata) = chat_request_user_preferences_removed.metadata {
metadata.remove("archgw_preference_config");
debug!("Removed archgw_preference_config from metadata");
// if metadata is empty, remove it
if metadata.is_empty() {
chat_request_user_preferences_removed.metadata = None;
debug!("Removed empty metadata from request");
}
}
let chat_request_parsed_bytes =
serde_json::to_string(&chat_request_user_preferences_removed).unwrap();
// remove content-length header if it exists
request_headers.remove(header::CONTENT_LENGTH);
let llm_response = match reqwest::Client::new()
.post(full_qualified_llm_provider_url)
.headers(request_headers)
.body(chat_request_parsed_bytes)
.body(client_request_bytes_for_upstream)
.send()
.await
{

View file

@ -497,6 +497,18 @@ impl ProviderRequest for MessagesRequest {
source: Some(Box::new(e)),
})
}
/// Return a reference to this request's optional metadata map.
///
/// The value is `None` when the request carries no metadata.
fn metadata(&self) -> &Option<HashMap<String, Value>> {
    &self.metadata
}
fn remove_metadata_key(&mut self, key: &str) -> bool {
if let Some(ref mut metadata) = self.metadata {
metadata.remove(key).is_some()
} else {
false
}
}
}
impl MessagesResponse {

View file

@ -81,7 +81,7 @@ pub struct ChatCompletionsRequest {
// Maximum tokens in the response has been deprecated, but we keep it for compatibility
pub max_tokens: Option<u32>,
pub modalities: Option<Vec<String>>,
pub metadata: Option<HashMap<String, String>>,
pub metadata: Option<HashMap<String, Value>>,
pub n: Option<u32>,
pub presence_penalty: Option<f32>,
pub parallel_tool_calls: Option<bool>,
@ -599,6 +599,18 @@ impl ProviderRequest for ChatCompletionsRequest {
source: Some(Box::new(e)),
})
}
/// Return a reference to this request's optional metadata map.
///
/// The value is `None` when the request carries no metadata.
fn metadata(&self) -> &Option<HashMap<String, Value>> {
    &self.metadata
}
fn remove_metadata_key(&mut self, key: &str) -> bool {
if let Some(ref mut metadata) = self.metadata {
metadata.remove(key).is_some()
} else {
false
}
}
}
/// Implementation of ProviderResponse for ChatCompletionsResponse

View file

@ -1,8 +1,12 @@
use crate::apis::openai::ChatCompletionsRequest;
use crate::apis::anthropic::MessagesRequest;
use crate::clients::endpoints::SupportedAPIs;
use serde_json::Value;
use std::error::Error;
use std::fmt;
use std::collections::HashMap;
#[derive(Clone)]
pub enum ProviderRequestType {
ChatCompletionsRequest(ChatCompletionsRequest),
MessagesRequest(MessagesRequest),
@ -26,6 +30,11 @@ pub trait ProviderRequest: Send + Sync {
/// Convert the request to bytes for transmission
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError>;
/// Return a reference to the request's optional metadata map
fn metadata(&self) -> &Option<HashMap<String, Value>>;
/// Remove a metadata key from the request and return true if the key was present
fn remove_metadata_key(&mut self, key: &str) -> bool;
}
impl ProviderRequest for ProviderRequestType {
@ -70,6 +79,20 @@ impl ProviderRequest for ProviderRequestType {
Self::MessagesRequest(r) => r.to_bytes(),
}
}
/// Borrow the metadata map of whichever concrete request this variant wraps.
fn metadata(&self) -> &Option<HashMap<String, Value>> {
    match self {
        Self::ChatCompletionsRequest(chat_request) => chat_request.metadata(),
        Self::MessagesRequest(messages_request) => messages_request.metadata(),
    }
}
/// Forward the key removal to the wrapped request and report whether the
/// key was present there.
fn remove_metadata_key(&mut self, key: &str) -> bool {
    match self {
        Self::ChatCompletionsRequest(chat_request) => chat_request.remove_metadata_key(key),
        Self::MessagesRequest(messages_request) => messages_request.remove_metadata_key(key),
    }
}
}
/// Parse the client API from a byte slice.