From 2813a8cfa53837320f0e3848143273474e4e77f5 Mon Sep 17 00:00:00 2001 From: Salman Paracha Date: Tue, 2 Sep 2025 17:42:02 -0700 Subject: [PATCH] fixing non-streaming responses to transform correctly --- crates/hermesllm/src/providers/response.rs | 98 +++++++++++++--------- crates/llm_gateway/src/stream_context.rs | 95 +++++++++------------ 2 files changed, 101 insertions(+), 92 deletions(-) diff --git a/crates/hermesllm/src/providers/response.rs b/crates/hermesllm/src/providers/response.rs index a4f1c889..a153962b 100644 --- a/crates/hermesllm/src/providers/response.rs +++ b/crates/hermesllm/src/providers/response.rs @@ -31,26 +31,43 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType { fn try_from((bytes, client_api, provider_id): (&[u8], &SupportedAPIs, &ProviderId)) -> Result { let upstream_api = provider_id.compatible_api_for_client(client_api); + + // Step 1: Parse bytes using upstream API format (what the provider actually sent) + // Step 2: Return response type that matches client API format (what client expects) match (&upstream_api, client_api) { + // Upstream sent OpenAI format, client expects OpenAI format - direct pass-through (SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => { let resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; Ok(ProviderResponseType::ChatCompletionsResponse(resp)) } + // Upstream sent Anthropic format, client expects Anthropic format - direct pass-through (SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { let resp: MessagesResponse = serde_json::from_slice(bytes) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; Ok(ProviderResponseType::MessagesResponse(resp)) } - (SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { - let resp: MessagesResponse = serde_json::from_slice(bytes) - 
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; - Ok(ProviderResponseType::MessagesResponse(resp)) - } + // Upstream sent Anthropic format, client expects OpenAI format - need transformation (SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => { - let resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes) + // Parse as Anthropic Messages response first + let anthropic_resp: MessagesResponse = serde_json::from_slice(bytes) .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; - Ok(ProviderResponseType::ChatCompletionsResponse(resp)) + + // Transform to OpenAI ChatCompletions format using the transformer + let chat_resp: ChatCompletionsResponse = anthropic_resp.try_into() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?; + Ok(ProviderResponseType::ChatCompletionsResponse(chat_resp)) + } + // Upstream sent OpenAI format, client expects Anthropic format - need transformation + (SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => { + // Parse as OpenAI ChatCompletions response first + let openai_resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes) + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?; + + // Transform to Anthropic Messages format using the transformer + let messages_resp: MessagesResponse = openai_resp.try_into() + .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?; + Ok(ProviderResponseType::MessagesResponse(messages_resp)) } } } @@ -264,7 +281,39 @@ mod tests { #[test] fn test_anthropic_response_from_bytes_with_openai_provider() { - // Simulate Anthropic response with OpenAI provider (should parse as MessagesResponse) + // OpenAI provider receives OpenAI response but client expects Anthropic format + // Upstream API = OpenAI, Client API = Anthropic -> parse OpenAI, convert to 
Anthropic + let resp = json!({ + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1234567890, + "model": "gpt-4", + "choices": [ + { + "index": 0, + "message": { "role": "assistant", "content": "Hello! How can I help you today?" }, + "finish_reason": "stop" + } + ], + "usage": { "prompt_tokens": 10, "completion_tokens": 25, "total_tokens": 35 } + }); + let bytes = serde_json::to_vec(&resp).unwrap(); + let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), &ProviderId::OpenAI)); + assert!(result.is_ok()); + match result.unwrap() { + ProviderResponseType::MessagesResponse(r) => { + assert_eq!(r.model, "gpt-4"); + assert_eq!(r.usage.input_tokens, 10); + assert_eq!(r.usage.output_tokens, 25); + }, + _ => panic!("Expected MessagesResponse variant"), + } + } + + #[test] + fn test_openai_response_from_bytes_with_claude_provider() { + // Claude provider receives Anthropic response but client expects OpenAI format + // Upstream API = Anthropic, Client API = OpenAI -> parse Anthropic, convert to OpenAI let resp = json!({ "id": "msg_01ABC123", "type": "message", @@ -277,40 +326,13 @@ mod tests { "usage": { "input_tokens": 10, "output_tokens": 25, "cache_creation_input_tokens": 5, "cache_read_input_tokens": 3 } }); let bytes = serde_json::to_vec(&resp).unwrap(); - let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), &ProviderId::OpenAI)); - assert!(result.is_ok()); - match result.unwrap() { - ProviderResponseType::MessagesResponse(r) => { - assert_eq!(r.model, "claude-3-sonnet-20240229"); - }, - _ => panic!("Expected MessagesResponse variant"), - } - } - - #[test] - fn test_openai_response_from_bytes_with_claude_provider() { - // Simulate OpenAI response with Claude provider (should parse as ChatCompletionsResponse) - let resp = json!({ - "id": "chatcmpl-123", - "object": "chat.completion", - "created": 1234567890, - 
"model": "gpt-4", - "choices": [ - { - "index": 0, - "message": { "role": "assistant", "content": "Hello!" }, - "finish_reason": "stop" - } - ], - "usage": { "prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12 }, - "system_fingerprint": null - }); - let bytes = serde_json::to_vec(&resp).unwrap(); let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), &ProviderId::Claude)); assert!(result.is_ok()); match result.unwrap() { ProviderResponseType::ChatCompletionsResponse(r) => { - assert_eq!(r.model, "gpt-4"); + assert_eq!(r.model, "claude-3-sonnet-20240229"); + assert_eq!(r.usage.prompt_tokens, 10); + assert_eq!(r.usage.completion_tokens, 25); }, _ => panic!("Expected ChatCompletionsResponse variant"), } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 7ab7a1fc..415e89f3 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -343,13 +343,12 @@ impl StreamContext { fn handle_streaming_response( &mut self, body: &[u8], - supported_api: SupportedAPIs, provider_id: ProviderId, ) -> Result, Action> { debug!("processing streaming response"); - match (Some(supported_api), self.resolved_api.as_ref()) { - (Some(supported_api), Some(_)) => { - match ProviderStreamResponseIter::try_from((body, &supported_api, &provider_id)) { + match self.client_api.as_ref() { + Some(client_api) => { + match ProviderStreamResponseIter::try_from((body, client_api, &provider_id)) { Ok(mut streaming_response) => { while let Some(chunk_result) = streaming_response.next() { match chunk_result { @@ -376,10 +375,11 @@ impl StreamContext { } } } - _ => { - warn!("Missing supported_api or resolved_api for streaming response"); + None => { + warn!("Missing client_api for non-streaming response"); + return Err(Action::Continue); } - } + }; // NOTE: // We currently pass-through the original SSE bytes for streaming 
responses. // Non-streaming responses are parsed into ProviderResponseType and re-serialized to @@ -396,38 +396,36 @@ impl StreamContext { fn handle_non_streaming_response( &mut self, body: &[u8], - supported_api: SupportedAPIs, provider_id: ProviderId, ) -> Result, Action> { - let response: ProviderResponseType = - match (Some(&supported_api), self.resolved_api.as_ref()) { - (Some(supported_api), Some(_)) => { - match ProviderResponseType::try_from((body, supported_api, &provider_id)) { - Ok(response) => response, - Err(e) => { - warn!( - "could not parse response: {}, body str: {}", - e, - String::from_utf8_lossy(body) - ); - debug!( - "on_http_response_body: S[{}], response body: {}", - self.context_id, - String::from_utf8_lossy(body) - ); - self.send_server_error( - ServerError::LogicError(format!("Response parsing error: {}", e)), - Some(StatusCode::BAD_REQUEST), - ); - return Err(Action::Continue); - } + let response: ProviderResponseType = match self.client_api.as_ref() { + Some(client_api) => { + match ProviderResponseType::try_from((body, client_api, &provider_id)) { + Ok(response) => response, + Err(e) => { + warn!( + "could not parse response: {}, body str: {}", + e, + String::from_utf8_lossy(body) + ); + debug!( + "on_http_response_body: S[{}], response body: {}", + self.context_id, + String::from_utf8_lossy(body) + ); + self.send_server_error( + ServerError::LogicError(format!("Response parsing error: {}", e)), + Some(StatusCode::BAD_REQUEST), + ); + return Err(Action::Continue); } } - _ => { - warn!("Missing supported_api or resolved_api for non-streaming response"); - return Err(Action::Continue); - } - }; + } + None => { + warn!("Missing client_api for non-streaming response"); + return Err(Action::Continue); + } + }; // Use provider interface to extract usage information if let Some((prompt_tokens, completion_tokens, total_tokens)) = @@ -768,30 +766,19 @@ impl HttpContext for StreamContext { self.debug_log_body(&body); let provider_id = 
self.get_provider_id(); - let supported_api_opt = self.client_api.clone(); - if self.streaming_response { - if let Some(supported_api) = supported_api_opt { - match self.handle_streaming_response(&body, supported_api, provider_id) { - Ok(serialized_body) => { - self.set_http_response_body(0, body_size, &serialized_body); - } - Err(action) => return action, + match self.handle_streaming_response(&body, provider_id) { + Ok(serialized_body) => { + self.set_http_response_body(0, body_size, &serialized_body); } - } else { - warn!("Missing supported_api or resolved_api for streaming response"); + Err(action) => return action, } } else { - if let Some(supported_api) = supported_api_opt { - match self.handle_non_streaming_response(&body, supported_api, provider_id) { - Ok(serialized_body) => { - self.set_http_response_body(0, body_size, &serialized_body); - } - Err(action) => return action, + match self.handle_non_streaming_response(&body, provider_id) { + Ok(serialized_body) => { + self.set_http_response_body(0, body_size, &serialized_body); } - } else { - warn!("Missing supported_api or resolved_api for non-streaming response"); - return Action::Continue; + Err(action) => return action, } }