mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
fixed bug whereby we were sending an OpenAIChatCompletions request object to llm_gateway even though the request may have been AnthropicMessages
This commit is contained in:
parent
788ff87a0c
commit
32bcb55d97
4 changed files with 64 additions and 26 deletions
|
|
@ -5,7 +5,7 @@ use common::configuration::ModelUsagePreference;
|
|||
use common::consts::ARCH_PROVIDER_HINT_HEADER;
|
||||
use hermesllm::apis::openai::ChatCompletionsRequest;
|
||||
use hermesllm::clients::SupportedAPIs;
|
||||
use hermesllm::ProviderRequestType;
|
||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||
use http_body_util::combinators::BoxBody;
|
||||
use http_body_util::{BodyExt, Full, StreamBody};
|
||||
use hyper::body::Frame;
|
||||
|
|
@ -29,13 +29,13 @@ pub async fn chat(
|
|||
router_service: Arc<RouterService>,
|
||||
full_qualified_llm_provider_url: String,
|
||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||
|
||||
let request_path = request.uri().path().to_string();
|
||||
let mut request_headers = request.headers().clone();
|
||||
|
||||
let chat_request_bytes = request.collect().await?.to_bytes();
|
||||
|
||||
debug!("Received request body (raw utf8): {}", String::from_utf8_lossy(&chat_request_bytes));
|
||||
let provider_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &SupportedAPIs::from_endpoint(request_path.as_str()).unwrap())) {
|
||||
let mut client_request = match ProviderRequestType::try_from((&chat_request_bytes[..], &SupportedAPIs::from_endpoint(request_path.as_str()).unwrap())) {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
warn!("Failed to parse request as ProviderRequestType: {}", err);
|
||||
|
|
@ -46,9 +46,18 @@ pub async fn chat(
|
|||
}
|
||||
};
|
||||
|
||||
// Convert to ChatCompletionsRequest regardless of input type
|
||||
// Clone metadata for routing and remove archgw_preference_config from original
|
||||
let routing_metadata = client_request.metadata().clone();
|
||||
|
||||
if client_request.remove_metadata_key("archgw_preference_config") {
|
||||
debug!("Removed archgw_preference_config from metadata");
|
||||
}
|
||||
|
||||
let client_request_bytes_for_upstream = ProviderRequestType::to_bytes(&client_request).unwrap();
|
||||
|
||||
// Convert to ChatCompletionsRequest regardless of input type (clone to avoid moving original)
|
||||
let chat_completions_request_for_arch_router: ChatCompletionsRequest =
|
||||
match ProviderRequestType::try_from((provider_request, &SupportedAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions))) {
|
||||
match ProviderRequestType::try_from((client_request, &SupportedAPIs::OpenAIChatCompletions(hermesllm::apis::OpenAIApi::ChatCompletions))) {
|
||||
Ok(ProviderRequestType::ChatCompletionsRequest(req)) => req,
|
||||
Ok(ProviderRequestType::MessagesRequest(_)) => {
|
||||
// This should not happen after conversion to OpenAI format
|
||||
|
|
@ -67,7 +76,6 @@ pub async fn chat(
|
|||
}
|
||||
};
|
||||
|
||||
|
||||
debug!(
|
||||
"[BRIGHTSTAFF -> ARCH_ROUTER] REQ: {}",
|
||||
&serde_json::to_string(&chat_completions_request_for_arch_router).unwrap()
|
||||
|
|
@ -79,7 +87,7 @@ pub async fn chat(
|
|||
.map(|(_, value)| value.to_str().unwrap_or_default().to_string());
|
||||
|
||||
let usage_preferences_str: Option<String> =
|
||||
chat_completions_request_for_arch_router.metadata.as_ref().and_then(|metadata| {
|
||||
routing_metadata.as_ref().and_then(|metadata| {
|
||||
metadata
|
||||
.get("archgw_preference_config")
|
||||
.map(|value| value.to_string())
|
||||
|
|
@ -155,30 +163,13 @@ pub async fn chat(
|
|||
header::HeaderValue::from_str(&trace_parent).unwrap(),
|
||||
);
|
||||
}
|
||||
|
||||
// remove metadata from the request for downstream calls
|
||||
let mut chat_request_user_preferences_removed = chat_completions_request_for_arch_router.clone();
|
||||
if let Some(ref mut metadata) = chat_request_user_preferences_removed.metadata {
|
||||
metadata.remove("archgw_preference_config");
|
||||
debug!("Removed archgw_preference_config from metadata");
|
||||
|
||||
// if metadata is empty, remove it
|
||||
if metadata.is_empty() {
|
||||
chat_request_user_preferences_removed.metadata = None;
|
||||
debug!("Removed empty metadata from request");
|
||||
}
|
||||
}
|
||||
|
||||
let chat_request_parsed_bytes =
|
||||
serde_json::to_string(&chat_request_user_preferences_removed).unwrap();
|
||||
|
||||
// remove content-length header if it exists
|
||||
request_headers.remove(header::CONTENT_LENGTH);
|
||||
|
||||
let llm_response = match reqwest::Client::new()
|
||||
.post(full_qualified_llm_provider_url)
|
||||
.headers(request_headers)
|
||||
.body(chat_request_parsed_bytes)
|
||||
.body(client_request_bytes_for_upstream)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
|
|
|
|||
|
|
@ -497,6 +497,18 @@ impl ProviderRequest for MessagesRequest {
|
|||
source: Some(Box::new(e)),
|
||||
})
|
||||
}
|
||||
|
||||
fn metadata(&self) -> &Option<HashMap<String, Value>> {
|
||||
return &self.metadata;
|
||||
}
|
||||
|
||||
fn remove_metadata_key(&mut self, key: &str) -> bool {
|
||||
if let Some(ref mut metadata) = self.metadata {
|
||||
metadata.remove(key).is_some()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl MessagesResponse {
|
||||
|
|
|
|||
|
|
@ -81,7 +81,7 @@ pub struct ChatCompletionsRequest {
|
|||
// Maximum tokens in the response has been deprecated, but we keep it for compatibility
|
||||
pub max_tokens: Option<u32>,
|
||||
pub modalities: Option<Vec<String>>,
|
||||
pub metadata: Option<HashMap<String, String>>,
|
||||
pub metadata: Option<HashMap<String, Value>>,
|
||||
pub n: Option<u32>,
|
||||
pub presence_penalty: Option<f32>,
|
||||
pub parallel_tool_calls: Option<bool>,
|
||||
|
|
@ -599,6 +599,18 @@ impl ProviderRequest for ChatCompletionsRequest {
|
|||
source: Some(Box::new(e)),
|
||||
})
|
||||
}
|
||||
|
||||
fn metadata(&self) -> &Option<HashMap<String, Value>> {
|
||||
return &self.metadata;
|
||||
}
|
||||
|
||||
fn remove_metadata_key(&mut self, key: &str) -> bool {
|
||||
if let Some(ref mut metadata) = self.metadata {
|
||||
metadata.remove(key).is_some()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Implementation of ProviderResponse for ChatCompletionsResponse
|
||||
|
|
|
|||
|
|
@ -1,8 +1,12 @@
|
|||
use crate::apis::openai::ChatCompletionsRequest;
|
||||
use crate::apis::anthropic::MessagesRequest;
|
||||
use crate::clients::endpoints::SupportedAPIs;
|
||||
|
||||
use serde_json::Value;
|
||||
use std::error::Error;
|
||||
use std::fmt;
|
||||
use std::collections::HashMap;
|
||||
#[derive(Clone)]
|
||||
pub enum ProviderRequestType {
|
||||
ChatCompletionsRequest(ChatCompletionsRequest),
|
||||
MessagesRequest(MessagesRequest),
|
||||
|
|
@ -26,6 +30,11 @@ pub trait ProviderRequest: Send + Sync {
|
|||
|
||||
/// Convert the request to bytes for transmission
|
||||
fn to_bytes(&self) -> Result<Vec<u8>, ProviderRequestError>;
|
||||
|
||||
fn metadata(&self) -> &Option<HashMap<String, Value>>;
|
||||
|
||||
/// Remove a metadata key from the request and return true if the key was present
|
||||
fn remove_metadata_key(&mut self, key: &str) -> bool;
|
||||
}
|
||||
|
||||
impl ProviderRequest for ProviderRequestType {
|
||||
|
|
@ -70,6 +79,20 @@ impl ProviderRequest for ProviderRequestType {
|
|||
Self::MessagesRequest(r) => r.to_bytes(),
|
||||
}
|
||||
}
|
||||
|
||||
fn metadata(&self) -> &Option<HashMap<String, Value>> {
|
||||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.metadata(),
|
||||
Self::MessagesRequest(r) => r.metadata(),
|
||||
}
|
||||
}
|
||||
|
||||
fn remove_metadata_key(&mut self, key: &str) -> bool {
|
||||
match self {
|
||||
Self::ChatCompletionsRequest(r) => r.remove_metadata_key(key),
|
||||
Self::MessagesRequest(r) => r.remove_metadata_key(key),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse the client API from a byte slice.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue