mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
fixing non-streaming responses to transform correctly
This commit is contained in:
parent
d4dfbe600f
commit
2813a8cfa5
2 changed files with 101 additions and 92 deletions
|
|
@ -31,26 +31,43 @@ impl TryFrom<(&[u8], &SupportedAPIs, &ProviderId)> for ProviderResponseType {
|
|||
|
||||
fn try_from((bytes, client_api, provider_id): (&[u8], &SupportedAPIs, &ProviderId)) -> Result<Self, Self::Error> {
|
||||
let upstream_api = provider_id.compatible_api_for_client(client_api);
|
||||
|
||||
// Step 1: Parse bytes using upstream API format (what the provider actually sent)
|
||||
// Step 2: Return response type that matches client API format (what client expects)
|
||||
match (&upstream_api, client_api) {
|
||||
// Upstream sent OpenAI format, client expects OpenAI format - direct pass-through
|
||||
(SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
let resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderResponseType::ChatCompletionsResponse(resp))
|
||||
}
|
||||
// Upstream sent Anthropic format, client expects Anthropic format - direct pass-through
|
||||
(SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
let resp: MessagesResponse = serde_json::from_slice(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderResponseType::MessagesResponse(resp))
|
||||
}
|
||||
(SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
let resp: MessagesResponse = serde_json::from_slice(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderResponseType::MessagesResponse(resp))
|
||||
}
|
||||
// Upstream sent Anthropic format, client expects OpenAI format - need transformation
|
||||
(SupportedAPIs::AnthropicMessagesAPI(_), SupportedAPIs::OpenAIChatCompletions(_)) => {
|
||||
let resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
|
||||
// Parse as Anthropic Messages response first
|
||||
let anthropic_resp: MessagesResponse = serde_json::from_slice(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
Ok(ProviderResponseType::ChatCompletionsResponse(resp))
|
||||
|
||||
// Transform to OpenAI ChatCompletions format using the transformer
|
||||
let chat_resp: ChatCompletionsResponse = anthropic_resp.try_into()
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?;
|
||||
Ok(ProviderResponseType::ChatCompletionsResponse(chat_resp))
|
||||
}
|
||||
// Upstream sent OpenAI format, client expects Anthropic format - need transformation
|
||||
(SupportedAPIs::OpenAIChatCompletions(_), SupportedAPIs::AnthropicMessagesAPI(_)) => {
|
||||
// Parse as OpenAI ChatCompletions response first
|
||||
let openai_resp: ChatCompletionsResponse = ChatCompletionsResponse::try_from(bytes)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
|
||||
|
||||
// Transform to Anthropic Messages format using the transformer
|
||||
let messages_resp: MessagesResponse = openai_resp.try_into()
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, format!("Transformation error: {}", e)))?;
|
||||
Ok(ProviderResponseType::MessagesResponse(messages_resp))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -264,7 +281,39 @@ mod tests {
|
|||
|
||||
#[test]
fn test_anthropic_response_from_bytes_with_openai_provider() {
    // The fixture below is an OpenAI chat completion: the OpenAI provider answers in its own
    // format while the client speaks the Anthropic Messages API.
    // Upstream API = OpenAI, Client API = Anthropic -> parse OpenAI, convert to Anthropic
    let resp = json!({
        "id": "chatcmpl-123",
        "object": "chat.completion",
        "created": 1234567890,
        "model": "gpt-4",
        "choices": [
            {
                "index": 0,
                "message": { "role": "assistant", "content": "Hello! How can I help you today?" },
                "finish_reason": "stop"
            }
        ],
        "usage": { "prompt_tokens": 10, "completion_tokens": 25, "total_tokens": 35 }
    });
    let bytes = serde_json::to_vec(&resp).unwrap();
    let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), &ProviderId::OpenAI));
    assert!(result.is_ok());
    match result.unwrap() {
        ProviderResponseType::MessagesResponse(r) => {
            // The model name is carried over, and OpenAI's prompt/completion token counts
            // surface as Anthropic's input/output token counts.
            assert_eq!(r.model, "gpt-4");
            assert_eq!(r.usage.input_tokens, 10);
            assert_eq!(r.usage.output_tokens, 25);
        },
        _ => panic!("Expected MessagesResponse variant"),
    }
}
|
||||
|
||||
#[test]
|
||||
fn test_openai_response_from_bytes_with_claude_provider() {
|
||||
// Claude provider receives Anthropic response but client expects OpenAI format
|
||||
// Upstream API = Anthropic, Client API = OpenAI -> parse Anthropic, convert to OpenAI
|
||||
let resp = json!({
|
||||
"id": "msg_01ABC123",
|
||||
"type": "message",
|
||||
|
|
@ -277,40 +326,13 @@ mod tests {
|
|||
"usage": { "input_tokens": 10, "output_tokens": 25, "cache_creation_input_tokens": 5, "cache_read_input_tokens": 3 }
|
||||
});
|
||||
let bytes = serde_json::to_vec(&resp).unwrap();
|
||||
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::AnthropicMessagesAPI(AnthropicApi::Messages), &ProviderId::OpenAI));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
ProviderResponseType::MessagesResponse(r) => {
|
||||
assert_eq!(r.model, "claude-3-sonnet-20240229");
|
||||
},
|
||||
_ => panic!("Expected MessagesResponse variant"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_openai_response_from_bytes_with_claude_provider() {
|
||||
// Simulate OpenAI response with Claude provider (should parse as ChatCompletionsResponse)
|
||||
let resp = json!({
|
||||
"id": "chatcmpl-123",
|
||||
"object": "chat.completion",
|
||||
"created": 1234567890,
|
||||
"model": "gpt-4",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": { "role": "assistant", "content": "Hello!" },
|
||||
"finish_reason": "stop"
|
||||
}
|
||||
],
|
||||
"usage": { "prompt_tokens": 5, "completion_tokens": 7, "total_tokens": 12 },
|
||||
"system_fingerprint": null
|
||||
});
|
||||
let bytes = serde_json::to_vec(&resp).unwrap();
|
||||
let result = ProviderResponseType::try_from((bytes.as_slice(), &SupportedAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions), &ProviderId::Claude));
|
||||
assert!(result.is_ok());
|
||||
match result.unwrap() {
|
||||
ProviderResponseType::ChatCompletionsResponse(r) => {
|
||||
assert_eq!(r.model, "gpt-4");
|
||||
assert_eq!(r.model, "claude-3-sonnet-20240229");
|
||||
assert_eq!(r.usage.prompt_tokens, 10);
|
||||
assert_eq!(r.usage.completion_tokens, 25);
|
||||
},
|
||||
_ => panic!("Expected ChatCompletionsResponse variant"),
|
||||
}
|
||||
|
|
|
|||
|
|
@ -343,13 +343,12 @@ impl StreamContext {
|
|||
fn handle_streaming_response(
|
||||
&mut self,
|
||||
body: &[u8],
|
||||
supported_api: SupportedAPIs,
|
||||
provider_id: ProviderId,
|
||||
) -> Result<Vec<u8>, Action> {
|
||||
debug!("processing streaming response");
|
||||
match (Some(supported_api), self.resolved_api.as_ref()) {
|
||||
(Some(supported_api), Some(_)) => {
|
||||
match ProviderStreamResponseIter::try_from((body, &supported_api, &provider_id)) {
|
||||
match self.client_api.as_ref() {
|
||||
Some(client_api) => {
|
||||
match ProviderStreamResponseIter::try_from((body, client_api, &provider_id)) {
|
||||
Ok(mut streaming_response) => {
|
||||
while let Some(chunk_result) = streaming_response.next() {
|
||||
match chunk_result {
|
||||
|
|
@ -376,10 +375,11 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
warn!("Missing supported_api or resolved_api for streaming response");
|
||||
None => {
|
||||
warn!("Missing client_api for non-streaming response");
|
||||
return Err(Action::Continue);
|
||||
}
|
||||
}
|
||||
};
|
||||
// NOTE:
|
||||
// We currently pass-through the original SSE bytes for streaming responses.
|
||||
// Non-streaming responses are parsed into ProviderResponseType and re-serialized to
|
||||
|
|
@ -396,38 +396,36 @@ impl StreamContext {
|
|||
fn handle_non_streaming_response(
|
||||
&mut self,
|
||||
body: &[u8],
|
||||
supported_api: SupportedAPIs,
|
||||
provider_id: ProviderId,
|
||||
) -> Result<Vec<u8>, Action> {
|
||||
let response: ProviderResponseType =
|
||||
match (Some(&supported_api), self.resolved_api.as_ref()) {
|
||||
(Some(supported_api), Some(_)) => {
|
||||
match ProviderResponseType::try_from((body, supported_api, &provider_id)) {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"could not parse response: {}, body str: {}",
|
||||
e,
|
||||
String::from_utf8_lossy(body)
|
||||
);
|
||||
debug!(
|
||||
"on_http_response_body: S[{}], response body: {}",
|
||||
self.context_id,
|
||||
String::from_utf8_lossy(body)
|
||||
);
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!("Response parsing error: {}", e)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Err(Action::Continue);
|
||||
}
|
||||
let response: ProviderResponseType = match self.client_api.as_ref() {
|
||||
Some(client_api) => {
|
||||
match ProviderResponseType::try_from((body, client_api, &provider_id)) {
|
||||
Ok(response) => response,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"could not parse response: {}, body str: {}",
|
||||
e,
|
||||
String::from_utf8_lossy(body)
|
||||
);
|
||||
debug!(
|
||||
"on_http_response_body: S[{}], response body: {}",
|
||||
self.context_id,
|
||||
String::from_utf8_lossy(body)
|
||||
);
|
||||
self.send_server_error(
|
||||
ServerError::LogicError(format!("Response parsing error: {}", e)),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Err(Action::Continue);
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
warn!("Missing supported_api or resolved_api for non-streaming response");
|
||||
return Err(Action::Continue);
|
||||
}
|
||||
};
|
||||
}
|
||||
None => {
|
||||
warn!("Missing client_api for non-streaming response");
|
||||
return Err(Action::Continue);
|
||||
}
|
||||
};
|
||||
|
||||
// Use provider interface to extract usage information
|
||||
if let Some((prompt_tokens, completion_tokens, total_tokens)) =
|
||||
|
|
@ -768,30 +766,19 @@ impl HttpContext for StreamContext {
|
|||
self.debug_log_body(&body);
|
||||
|
||||
let provider_id = self.get_provider_id();
|
||||
let supported_api_opt = self.client_api.clone();
|
||||
|
||||
if self.streaming_response {
|
||||
if let Some(supported_api) = supported_api_opt {
|
||||
match self.handle_streaming_response(&body, supported_api, provider_id) {
|
||||
Ok(serialized_body) => {
|
||||
self.set_http_response_body(0, body_size, &serialized_body);
|
||||
}
|
||||
Err(action) => return action,
|
||||
match self.handle_streaming_response(&body, provider_id) {
|
||||
Ok(serialized_body) => {
|
||||
self.set_http_response_body(0, body_size, &serialized_body);
|
||||
}
|
||||
} else {
|
||||
warn!("Missing supported_api or resolved_api for streaming response");
|
||||
Err(action) => return action,
|
||||
}
|
||||
} else {
|
||||
if let Some(supported_api) = supported_api_opt {
|
||||
match self.handle_non_streaming_response(&body, supported_api, provider_id) {
|
||||
Ok(serialized_body) => {
|
||||
self.set_http_response_body(0, body_size, &serialized_body);
|
||||
}
|
||||
Err(action) => return action,
|
||||
match self.handle_non_streaming_response(&body, provider_id) {
|
||||
Ok(serialized_body) => {
|
||||
self.set_http_response_body(0, body_size, &serialized_body);
|
||||
}
|
||||
} else {
|
||||
warn!("Missing supported_api or resolved_api for non-streaming response");
|
||||
return Action::Continue;
|
||||
Err(action) => return action,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue