fixing non-streaming responses to tranform correctly

This commit is contained in:
Salman Paracha 2025-09-02 17:42:02 -07:00
parent d4dfbe600f
commit 2813a8cfa5
2 changed files with 101 additions and 92 deletions

View file

@ -31,26 +31,43 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType {
fn try_from((bytes, client_api, provider_id): (&[u8], &SupportedAPIs, &ProviderId)) -> Result<Self, Self::Error> {
let upstream_api = provider_id.compatible_api_for_client(client_api);
// Step 1: Parse bytes using upstream API format (what the provider actually sent)
// Step 2: Return response type that matches client API format (what client expects)
match (&upstream_api, client_api) {
// Upstream sent OpenAI format, client expects OpenAI format - direct pass-through
(SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
let resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderResponseType::ChatCompletionsResponse(resp))
}
// Upstream sent Anthropic format, client expects Anthropic format - direct pass-through
(SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
let resp: MessagesResponse = serde_json::from_slice(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderResponseType::MessagesResponse(resp))
}
(SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
let resp: MessagesResponse = serde_json::from_slice(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderResponseType::MessagesResponse(resp))
}
// Upstream sent Anthropic format, client expects OpenAI format - need transformation
(SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
let resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
// Parse as Anthropic Messages response first
let anthropic_resp: MessagesResponse = serde_json::from_slice(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
Ok(ProviderResponseType::ChatCompletionsResponse(resp))
// Transform to OpenAI ChatCompletions format using the transformer
let chat_resp: ChatCompletionsResponse = anthropic_resp.try_into()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?;
Ok(ProviderResponseType::ChatCompletionsResponse(chat_resp))
}
// Upstream sent OpenAI format, client expects Anthropic format - need transformation
(SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
// Parse as OpenAI ChatCompletions response first
let openai_resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
// Transform to Anthropic Messages format using the transformer
let messages_resp: MessagesResponse = openai_resp.try_into()
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?;
Ok(ProviderResponseType::MessagesResponse(messages_resp))
}
}
}
@ -264,7 +281,39 @@ mod tests {
#[test]
fn test_anthropic_response_from_bytes_with_openai_provider() {
// Simulate Anthropic response with OpenAI provider (should parse as MessagesResponse)
// OpenAI provider receives OpenAI response but client expects Anthropic format
// Upstream API = OpenAI, Client API = Anthropic -> parse OpenAI, convert to Anthropic
let resp = json!({
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1234567890,
"model": "gpt-4",
"choices": [
{
"index": 0,
"message": { "role": "assistant", "content": "Hello! How can I help you today?" },
"finish_reason": "stop"
}
],
"usage": { "prompt_tokens": 10, "completion_tokens": 25, "total_tokens": 35 }
});
let bytes = serde_json::to_vec(&resp).unwrap();
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), &ProviderId::OpenAI));
assert!(result.is_ok());
match result.unwrap() {
ProviderResponseType::MessagesResponse(r) => {
assert_eq!(r.model, "gpt-4");
assert_eq!(r.usage.input_tokens, 10);
assert_eq!(r.usage.output_tokens, 25);
},
_ => panic!("Expected MessagesResponse variant"),
}
}
#[test]
fn test_openai_response_from_bytes_with_claude_provider() {
// Claude provider receives Anthropic response but client expects OpenAI format
// Upstream API = Anthropic, Client API = OpenAI -> parse Anthropic, convert to OpenAI
let resp = json!({
"id": "msg_01ABC123",
"type": "message",
@ -277,40 +326,13 @@ mod tests {
"usage": { "input_tokens": 10, "output_tokens": 25, "cache_creation_input_tokens": 5, "cache_read_input_tokens": 3 }
});
let bytes = serde_json::to_vec(&resp).unwrap();
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), &ProviderId::OpenAI));
assert!(result.is_ok());
match result.unwrap() {
ProviderResponseType::MessagesResponse(r) => {
assert_eq!(r.model, "claude-3-sonnet-20240229");
},
_ => panic!("Expected MessagesResponse variant"),
}
}
#[test]
fn test_openai_response_from_bytes_with_claude_provider() {
// Simulate OpenAI response with Claude provider (should parse as ChatCompletionsResponse)
let resp = json!({
"id": "chatcmpl-123",
"object": "chat.completion",
"created": 1234567890,
"model": "gpt-4",
"choices": [
{
"index": 0,
"message": { "role": "assistant", "content": "Hello!" },
"finish_reason": "stop"
}
],
"usage": { "prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12 },
"system_fingerprint": null
});
let bytes = serde_json::to_vec(&resp).unwrap();
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), &ProviderId::Claude));
assert!(result.is_ok());
match result.unwrap() {
ProviderResponseType::ChatCompletionsResponse(r) => {
assert_eq!(r.model, "gpt-4");
assert_eq!(r.model, "claude-3-sonnet-20240229");
assert_eq!(r.usage.prompt_tokens, 10);
assert_eq!(r.usage.completion_tokens, 25);
},
_ => panic!("Expected ChatCompletionsResponse variant"),
}

View file

@ -343,13 +343,12 @@ impl StreamContext {
fn handle_streaming_response(
&mut self,
body: &[u8],
supported_api: SupportedAPIs,
provider_id: ProviderId,
) -> Result<Vec<u8>, Action> {
debug!("processing streaming response");
match (Some(supported_api), self.resolved_api.as_ref()) {
(Some(supported_api), Some(_)) => {
match ProviderStreamResponseIter::try_from((body, &supported_api, &provider_id)) {
match self.client_api.as_ref() {
Some(client_api) => {
match ProviderStreamResponseIter::try_from((body, client_api, &provider_id)) {
Ok(mut streaming_response) => {
while let Some(chunk_result) = streaming_response.next() {
match chunk_result {
@ -376,10 +375,11 @@ impl StreamContext {
}
}
}
_ => {
warn!("Missing supported_api or resolved_api for streaming response");
None => {
warn!("Missing client_api for non-streaming response");
return Err(Action::Continue);
}
}
};
// NOTE:
// We currently pass-through the original SSE bytes for streaming responses.
// Non-streaming responses are parsed into ProviderResponseType and re-serialized to
@ -396,38 +396,36 @@ impl StreamContext {
fn handle_non_streaming_response(
&mut self,
body: &[u8],
supported_api: SupportedAPIs,
provider_id: ProviderId,
) -> Result<Vec<u8>, Action> {
let response: ProviderResponseType =
match (Some(&supported_api), self.resolved_api.as_ref()) {
(Some(supported_api), Some(_)) => {
match ProviderResponseType::try_from((body, supported_api, &provider_id)) {
Ok(response) => response,
Err(e) => {
warn!(
"could not parse response: {}, body str: {}",
e,
String::from_utf8_lossy(body)
);
debug!(
"on_http_response_body: S[{}], response body: {}",
self.context_id,
String::from_utf8_lossy(body)
);
self.send_server_error(
ServerError::LogicError(format!("Response parsing error: {}", e)),
Some(StatusCode::BAD_REQUEST),
);
return Err(Action::Continue);
}
let response: ProviderResponseType = match self.client_api.as_ref() {
Some(client_api) => {
match ProviderResponseType::try_from((body, client_api, &provider_id)) {
Ok(response) => response,
Err(e) => {
warn!(
"could not parse response: {}, body str: {}",
e,
String::from_utf8_lossy(body)
);
debug!(
"on_http_response_body: S[{}], response body: {}",
self.context_id,
String::from_utf8_lossy(body)
);
self.send_server_error(
ServerError::LogicError(format!("Response parsing error: {}", e)),
Some(StatusCode::BAD_REQUEST),
);
return Err(Action::Continue);
}
}
_ => {
warn!("Missing supported_api or resolved_api for non-streaming response");
return Err(Action::Continue);
}
};
}
None => {
warn!("Missing client_api for non-streaming response");
return Err(Action::Continue);
}
};
// Use provider interface to extract usage information
if let Some((prompt_tokens, completion_tokens, total_tokens)) =
@ -768,30 +766,19 @@ impl HttpContext for StreamContext {
self.debug_log_body(&body);
let provider_id = self.get_provider_id();
let supported_api_opt = self.client_api.clone();
if self.streaming_response {
if let Some(supported_api) = supported_api_opt {
match self.handle_streaming_response(&body, supported_api, provider_id) {
Ok(serialized_body) => {
self.set_http_response_body(0, body_size, &serialized_body);
}
Err(action) => return action,
match self.handle_streaming_response(&body, provider_id) {
Ok(serialized_body) => {
self.set_http_response_body(0, body_size, &serialized_body);
}
} else {
warn!("Missing supported_api or resolved_api for streaming response");
Err(action) => return action,
}
} else {
if let Some(supported_api) = supported_api_opt {
match self.handle_non_streaming_response(&body, supported_api, provider_id) {
Ok(serialized_body) => {
self.set_http_response_body(0, body_size, &serialized_body);
}
Err(action) => return action,
match self.handle_non_streaming_response(&body, provider_id) {
Ok(serialized_body) => {
self.set_http_response_body(0, body_size, &serialized_body);
}
} else {
warn!("Missing supported_api or resolved_api for non-streaming response");
return Action::Continue;
Err(action) => return action,
}
}