From d4dfbe600f4d55d1b21f464f6a35b652fc91f5ab Mon Sep 17 00:00:00 2001 From: Salman Paracha Date: Tue, 2 Sep 2025 16:19:45 -0700 Subject: [PATCH] making sure that we convert the raw bytes to the correct provider type upstream --- crates/hermesllm/src/providers/request.rs | 101 +++++++++++++++------- crates/llm_gateway/src/stream_context.rs | 58 +++++++++---- 2 files changed, 110 insertions(+), 49 deletions(-) diff --git a/crates/hermesllm/src/providers/request.rs b/crates/hermesllm/src/providers/request.rs index 2cd33fee..69da73dd 100644 --- a/crates/hermesllm/src/providers/request.rs +++ b/crates/hermesllm/src/providers/request.rs @@ -8,39 +8,6 @@ pub enum ProviderRequestType { MessagesRequest(MessagesRequest), //add more request types here } - -impl TryFrom<&[u8]> for ProviderRequestType { - type Error = std::io::Error; - - // if passing bytes without provider id we assume the request is in OpenAI format - fn try_from(bytes: &[u8]) -> Result { - let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; - Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request)) - } -} - -/// Parse request based on endpoint and provider information -impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType { - type Error = std::io::Error; - - fn try_from((bytes, endpoint): (&[u8], &SupportedAPIs)) -> Result { - // Use SupportedApi to determine the appropriate request type - match endpoint { - SupportedAPIs::OpenAIChatCompletions(_) => { - let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; - Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request)) - } - SupportedAPIs::AnthropicMessagesAPI(_) => { - let messages_request: MessagesRequest = MessagesRequest::try_from(bytes) - .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; - Ok(ProviderRequestType::MessagesRequest(messages_request)) - } - } - } -} - pub trait ProviderRequest: Send + Sync { /// Extract the model name from the request fn model(&self) -> &str; @@ -61,6 +28,74 @@ pub trait ProviderRequest: Send + Sync { fn to_bytes(&self) -> Result, ProviderRequestError>; } + +impl TryFrom<&[u8]> for ProviderRequestType { + type Error = std::io::Error; + + // if passing bytes without provider id we assume the request is in OpenAI format + fn try_from(bytes: &[u8]) -> Result { + let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request)) + } +} + +/// Parse request based on api +impl TryFrom<(&[u8], &SupportedAPIs)> for ProviderRequestType { + type Error = std::io::Error; + + fn try_from((bytes, the_api_type): (&[u8], &SupportedAPIs)) -> Result { + // Use SupportedApi to determine the appropriate request type + match the_api_type { + SupportedAPIs::OpenAIChatCompletions(_) => { + let chat_completion_request: ChatCompletionsRequest = ChatCompletionsRequest::try_from(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + Ok(ProviderRequestType::ChatCompletionsRequest(chat_completion_request)) + } + SupportedAPIs::AnthropicMessagesAPI(_) => { + let messages_request: MessagesRequest = MessagesRequest::try_from(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + Ok(ProviderRequestType::MessagesRequest(messages_request)) + } + } + } +} + +impl TryFrom<(&ProviderRequestType, &SupportedAPIs)> for ProviderRequestType { + type Error = ProviderRequestError; + + fn try_from((r, target_api): (&ProviderRequestType, &SupportedAPIs)) -> Result { + match (r, target_api) { + // Same API - no conversion needed, just clone the reference + (ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedAPIs::OpenAIChatCompletions(_)) => { + Ok(ProviderRequestType::ChatCompletionsRequest(chat_req.clone())) + } + (ProviderRequestType::MessagesRequest(messages_req), SupportedAPIs::AnthropicMessagesAPI(_)) => { + Ok(ProviderRequestType::MessagesRequest(messages_req.clone())) + } + + // Cross-API conversion - cloning is necessary for transformation + (ProviderRequestType::ChatCompletionsRequest(chat_req), SupportedAPIs::AnthropicMessagesAPI(_)) => { + let messages_req = MessagesRequest::try_from(chat_req.clone()) + .map_err(|e| ProviderRequestError { + message: format!("Failed to convert ChatCompletionsRequest to MessagesRequest: {}", e), + source: Some(Box::new(e)) + })?; + Ok(ProviderRequestType::MessagesRequest(messages_req)) + } + + (ProviderRequestType::MessagesRequest(messages_req), SupportedAPIs::OpenAIChatCompletions(_)) => { + let chat_req = ChatCompletionsRequest::try_from(messages_req.clone()) + .map_err(|e| ProviderRequestError { + message: format!("Failed to convert MessagesRequest to ChatCompletionsRequest: {}", e), + source: Some(Box::new(e)) + })?; + Ok(ProviderRequestType::ChatCompletionsRequest(chat_req)) + } + } + } +} + impl ProviderRequest for ProviderRequestType { fn model(&self) -> &str { match self { diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index d0b42c52..7ab7a1fc 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -289,7 +289,7 @@ impl StreamContext { } } - fn read_response_body(&mut self, body_size: usize) -> Result, Action> { + fn read_raw_response_body(&mut self, body_size: usize) -> Result, Action> { if self.streaming_response { let chunk_start = 0; let chunk_size = body_size; @@ -583,9 +583,10 @@ impl HttpContext for StreamContext { } }; - let mut deserialized_body = match self.resolved_api.as_ref() { - Some(resolved_api) => { - match ProviderRequestType::try_from((&body_bytes[..], resolved_api)) { + //We need to deserialize the request body based on the resolved API + let mut deserialized_client_request: ProviderRequestType = match self.client_api.as_ref() { + Some(the_client_api) => { + match ProviderRequestType::try_from((&body_bytes[..], the_client_api)) { Ok(deserialized) => deserialized, Err(e) => { debug!( @@ -620,7 +621,7 @@ impl HttpContext for StreamContext { }; // Store the original model for logging - let model_requested = deserialized_body.model().to_string(); + let model_requested = deserialized_client_request.model().to_string(); // Apply model name resolution logic using the trait method let resolved_model = match model_name { @@ -646,10 +647,10 @@ impl HttpContext for StreamContext { }; // Set the resolved model using the trait method - deserialized_body.set_model(resolved_model.clone()); + deserialized_client_request.set_model(resolved_model.clone()); // Extract user message for tracing - self.user_message = deserialized_body.get_recent_user_message(); + self.user_message = deserialized_client_request.get_recent_user_message(); info!( "on_http_request_body: provider: {}, model requested (in body): {}, model selected: {}", @@ -659,10 +660,10 @@ impl HttpContext for StreamContext { ); // Use provider interface for streaming detection and setup - self.streaming_response = deserialized_body.is_streaming(); + self.streaming_response = deserialized_client_request.is_streaming(); // Use provider interface for text extraction (after potential mutation) - let input_tokens_str = deserialized_body.extract_messages_text(); + let input_tokens_str = deserialized_client_request.extract_messages_text(); // enforce ratelimits on ingress if let Err(e) = self.enforce_ratelimits(&resolved_model, input_tokens_str.as_str()) { self.send_server_error( @@ -674,19 +675,44 @@ impl HttpContext for StreamContext { } // Convert chat completion request to llm provider specific request using provider interface - let deserialized_body_bytes = match deserialized_body.to_bytes() { - Ok(bytes) => bytes, - Err(e) => { - warn!("Failed to serialize request body: {}", e); + let serialized_body_bytes_upstream = match self.resolved_api.as_ref() { + Some(upstream) => { + match ProviderRequestType::try_from((&deserialized_client_request, upstream)) { + Ok(request) => match request.to_bytes() { + Ok(bytes) => bytes, + Err(e) => { + warn!("Failed to serialize request body: {}", e); + self.send_server_error( + ServerError::LogicError(format!( + "Request serialization error: {}", + e + )), + Some(StatusCode::BAD_REQUEST), + ); + return Action::Pause; + } + }, + Err(e) => { + warn!("Failed to create provider request: {}", e); + self.send_server_error( + ServerError::LogicError(format!("Provider request error: {}", e)), + Some(StatusCode::BAD_REQUEST), + ); + return Action::Pause; + } + } + } + None => { + warn!("No upstream API resolved"); self.send_server_error( - ServerError::LogicError(format!("Request serialization error: {}", e)), + ServerError::LogicError("No upstream API resolved".into()), Some(StatusCode::BAD_REQUEST), ); return Action::Pause; } }; - self.set_http_request_body(0, body_size, &deserialized_body_bytes); + self.set_http_request_body(0, body_size, &serialized_body_bytes_upstream); Action::Continue } @@ -734,7 +760,7 @@ impl HttpContext for StreamContext { return Action::Continue; } - let body = match self.read_response_body(body_size) { + let body = match self.read_raw_response_body(body_size) { Ok(bytes) => bytes, Err(action) => return action, };