From 8fbdddca479704d000a313938ec8fb136b4b5cdf Mon Sep 17 00:00:00 2001 From: Salman Paracha Date: Tue, 2 Dec 2025 06:04:52 -0800 Subject: [PATCH] fixed changes based on code review --- .../anthropic_streaming_buffer.rs | 36 ++--- .../responses_api_streaming_buffer.rs | 18 ++- crates/hermesllm/src/providers/id.rs | 10 +- crates/llm_gateway/src/stream_context.rs | 15 +- tests/e2e/test_openai_responses_api_client.py | 151 ++++++++++++++++++ 5 files changed, 197 insertions(+), 33 deletions(-) diff --git a/crates/hermesllm/src/apis/streaming_shapes/anthropic_streaming_buffer.rs b/crates/hermesllm/src/apis/streaming_shapes/anthropic_streaming_buffer.rs index 5879712b..07ff0140 100644 --- a/crates/hermesllm/src/apis/streaming_shapes/anthropic_streaming_buffer.rs +++ b/crates/hermesllm/src/apis/streaming_shapes/anthropic_streaming_buffer.rs @@ -24,7 +24,7 @@ pub struct AnthropicMessagesStreamBuffer { needs_content_block_stop: bool, /// Model name to use when generating message_start events - model: String, + model: Option, } impl AnthropicMessagesStreamBuffer { @@ -34,12 +34,12 @@ impl AnthropicMessagesStreamBuffer { message_started: false, content_block_started: false, needs_content_block_stop: false, - model: "unknown".to_string(), + model: None, } } /// Helper to create and format a ContentBlockStart SSE event - fn create_content_block_start_event(&self) -> SseEvent { + fn create_content_block_start_event() -> SseEvent { let content_block_start = MessagesStreamEvent::ContentBlockStart { index: 0, content_block: crate::apis::anthropic::MessagesContentBlock::Text { @@ -59,14 +59,14 @@ impl AnthropicMessagesStreamBuffer { } /// Helper to create and format a MessageStart SSE event - fn create_message_start_event(&self) -> SseEvent { + fn create_message_start_event(model: &str) -> SseEvent { let message_start = MessagesStreamEvent::MessageStart { message: crate::apis::anthropic::MessagesStreamMessage { id: format!("msg_{}", uuid::Uuid::new_v4().to_string().replace("-", "")), 
obj_type: "message".to_string(), role: crate::apis::anthropic::MessagesRole::Assistant, content: vec![], - model: self.model.clone(), + model: model.to_string(), stop_reason: None, stop_sequence: None, usage: crate::apis::anthropic::MessagesUsage { @@ -89,7 +89,7 @@ impl AnthropicMessagesStreamBuffer { } /// Helper to create and format a ContentBlockStop SSE event - fn create_content_block_stop_event(&self) -> SseEvent { + fn create_content_block_stop_event() -> SseEvent { let content_block_stop = MessagesStreamEvent::ContentBlockStop { index: 0 }; let sse_string: String = content_block_stop.into(); @@ -113,12 +113,12 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer { // FIRST: Try to extract model name from the raw event data before transformation // The provider_stream_response has already been transformed to Anthropic format, // so we need to extract the model from the original raw data if available - if self.model == "unknown" { + if self.model.is_none() { if let Some(data) = &event.data { // Try to parse as JSON and extract model field if let Ok(json) = serde_json::from_str::(data) { if let Some(model) = json.get("model").and_then(|m| m.as_str()) { - self.model = model.to_string(); + self.model = Some(model.to_string()); } } } @@ -132,12 +132,12 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer { // Add the message_start event self.buffered_events.push(event); self.message_started = true; - return; } "content_block_start" => { // If we haven't seen message_start yet, inject it first if !self.message_started { - let message_start = self.create_message_start_event(); + let model = self.model.as_deref().unwrap_or("unknown"); + let message_start = AnthropicMessagesStreamBuffer::create_message_start_event(model); self.buffered_events.push(message_start); self.message_started = true; } @@ -146,21 +146,21 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer { self.buffered_events.push(event); self.content_block_started = true; 
self.needs_content_block_stop = true; - return; } "content_block_delta" => { // If this is the first content delta and we haven't started yet, // inject message_start and content_block_start first if !self.message_started { // Create and inject message_start event - let message_start = self.create_message_start_event(); + let model = self.model.as_deref().unwrap_or("unknown"); + let message_start = AnthropicMessagesStreamBuffer::create_message_start_event(model); self.buffered_events.push(message_start); self.message_started = true; } if !self.content_block_started { // Inject ContentBlockStart after message_start - let content_block_start = self.create_content_block_start_event(); + let content_block_start = AnthropicMessagesStreamBuffer::create_content_block_start_event(); self.buffered_events.push(content_block_start); self.content_block_started = true; self.needs_content_block_stop = true; @@ -168,26 +168,24 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer { // Content deltas are between ContentBlockStart and ContentBlockStop self.buffered_events.push(event); - return; } "message_delta" => { // Inject ContentBlockStop before message_delta if self.needs_content_block_stop { - let content_block_stop = self.create_content_block_stop_event(); + let content_block_stop = AnthropicMessagesStreamBuffer::create_content_block_stop_event(); self.buffered_events.push(content_block_stop); self.needs_content_block_stop = false; } // Add the message_delta event self.buffered_events.push(event); - return; } _ => { - // Other event types, just accumulate + // Other event types, just accumulate the event self.buffered_events.push(event); - return; } } + return; } } @@ -198,7 +196,7 @@ impl SseStreamBufferTrait for AnthropicMessagesStreamBuffer { fn into_bytes(&mut self) -> Vec { // Inject ContentBlockStop if needed before flushing if self.needs_content_block_stop { - let content_block_stop = self.create_content_block_stop_event(); + let content_block_stop = 
AnthropicMessagesStreamBuffer::create_content_block_stop_event(); self.buffered_events.push(content_block_stop); self.needs_content_block_stop = false; } diff --git a/crates/hermesllm/src/apis/streaming_shapes/responses_api_streaming_buffer.rs b/crates/hermesllm/src/apis/streaming_shapes/responses_api_streaming_buffer.rs index 27e6a199..84854af3 100644 --- a/crates/hermesllm/src/apis/streaming_shapes/responses_api_streaming_buffer.rs +++ b/crates/hermesllm/src/apis/streaming_shapes/responses_api_streaming_buffer.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use log::debug; use crate::apis::openai_responses::{ ResponsesAPIStreamEvent, ResponsesAPIResponse, OutputItem, OutputItemStatus, ResponseStatus, TextConfig, TextFormat, Reasoning, @@ -17,10 +18,19 @@ fn event_to_sse(event: ResponsesAPIStreamEvent) -> SseEvent { ResponsesAPIStreamEvent::ResponseOutputTextDone { .. } => "response.output_text.done", ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDelta { .. } => "response.function_call_arguments.delta", ResponsesAPIStreamEvent::ResponseFunctionCallArgumentsDone { .. 
} => "response.function_call_arguments.done", - _ => "unknown", + unknown => { + debug!("Unknown ResponsesAPIStreamEvent type encountered: {:?}", unknown); + "unknown" + } }; - let json_data = serde_json::to_string(&event).unwrap_or_default(); + let json_data = match serde_json::to_string(&event) { + Ok(data) => data, + Err(e) => { + debug!("Error serializing ResponsesAPIStreamEvent to JSON: {}", e); + String::new() + } + }; let wire_format: String = event.into(); SseEvent { @@ -95,7 +105,7 @@ impl ResponsesAPIStreamBuffer { seq } - fn generate_item_id(&self, prefix: &str) -> String { + fn generate_item_id(prefix: &str) -> String { format!("{}_{}", prefix, uuid::Uuid::new_v4().to_string().replace("-", "")) } @@ -103,7 +113,7 @@ impl ResponsesAPIStreamBuffer { if let Some(id) = self.item_ids.get(&output_index) { return id.clone(); } - let id = self.generate_item_id(prefix); + let id = ResponsesAPIStreamBuffer::generate_item_id(prefix); self.item_ids.insert(output_index, id.clone()); id } diff --git a/crates/hermesllm/src/providers/id.rs b/crates/hermesllm/src/providers/id.rs index 69455eaf..3c964240 100644 --- a/crates/hermesllm/src/providers/id.rs +++ b/crates/hermesllm/src/providers/id.rs @@ -111,11 +111,6 @@ impl ProviderId { SupportedUpstreamAPIs::OpenAIResponsesAPI(OpenAIApi::Responses) } - // Non-OpenAI providers: if client requested the Responses API, fall back to Chat Completions - (_, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => { - SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions) - } - // Amazon Bedrock natively supports Bedrock APIs (ProviderId::AmazonBedrock, SupportedAPIsFromClient::OpenAIChatCompletions(_)) => { if is_streaming { @@ -135,6 +130,11 @@ impl ProviderId { SupportedUpstreamAPIs::AmazonBedrockConverse(AmazonBedrockApi::Converse) } } + + // Non-OpenAI providers: if client requested the Responses API, fall back to Chat Completions + (_, SupportedAPIsFromClient::OpenAIResponsesAPI(_)) => { + 
SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions) + } } } } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 13fbcde4..9541ae45 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -543,14 +543,19 @@ impl StreamContext { } // Add transformed event to buffer (buffer may inject lifecycle events) - self.sse_buffer - .as_mut() - .unwrap() - .add_transformed_event(transformed_event); + if let Some(buffer) = self.sse_buffer.as_mut() { + buffer.add_transformed_event(transformed_event); + } } // Get accumulated bytes from buffer and return - Ok(self.sse_buffer.as_mut().unwrap().into_bytes()) + match self.sse_buffer.as_mut() { + Some(buffer) => Ok(buffer.into_bytes()), + None => { + warn!("SSE buffer unexpectedly missing after initialization"); + Err(Action::Continue) + } + } } None => { warn!("Missing client_api for non-streaming response"); diff --git a/tests/e2e/test_openai_responses_api_client.py b/tests/e2e/test_openai_responses_api_client.py index c01c6454..2d9fa248 100644 --- a/tests/e2e/test_openai_responses_api_client.py +++ b/tests/e2e/test_openai_responses_api_client.py @@ -325,3 +325,154 @@ def test_openai_responses_api_streaming_with_tools_upstream_chat_completions(): assert ( full_text or tool_calls ), "Expected streamed text or tool call argument deltas from Responses tools stream" + + +def test_openai_responses_api_non_streaming_upstream_anthropic(): + """Send a v1/responses request for an Anthropic (Claude) model to verify translation/routing""" + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1") + + resp = client.responses.create( + model="claude-sonnet-4-20250514", input="Hello, translate this via grok alias" + ) + + # Print the response content - handle both responses format and chat completions format + print(f"\n{'='*80}") + print(f"Model: 
{resp.model}") + print(f"Output: {resp.output_text}") + print(f"{'='*80}\n") + + assert resp is not None + assert resp.id is not None + + +def test_openai_responses_api_with_streaming_upstream_anthropic(): + """Build a v1/responses API streaming request (pass-through) and ensure gateway accepts it""" + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1") + + # Simple streaming responses API request using a direct model (pass-through) + stream = client.responses.create( + model="claude-sonnet-4-20250514", + input="Write a short haiku about coding", + stream=True, + ) + + # Collect streamed content using the official Responses API streaming shape + text_chunks = [] + final_message = None + + for event in stream: + # The Python SDK surfaces a high-level Responses streaming interface. + # We rely on its typed helpers instead of digging into model_extra. + if getattr(event, "type", None) == "response.output_text.delta" and getattr( + event, "delta", None + ): + # Each delta contains a text fragment + text_chunks.append(event.delta) + + # Track the final response message if provided by the SDK + if getattr(event, "type", None) == "response.completed" and getattr( + event, "response", None + ): + final_message = event.response + + full_content = "".join(text_chunks) + + # Print the streaming response + print(f"\n{'='*80}") + print( + f"Model: {getattr(final_message, 'model', 'unknown') if final_message else 'unknown'}" + ) + print(f"Streamed Output: {full_content}") + print(f"{'='*80}\n") + + assert len(text_chunks) > 0, "Should have received streaming text deltas" + assert len(full_content) > 0, "Should have received content" + + +def test_openai_responses_api_non_streaming_with_tools_upstream_anthropic(): + """Responses API with tools routed to an Anthropic (Claude) model""" + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI(api_key="test-key", 
base_url=f"{base_url}/v1") + + tools = [ + { + "type": "function", + "name": "echo_tool", + "description": "Echo back the provided input: hello_world", + "parameters": { + "type": "object", + "properties": {"text": {"type": "string"}}, + "required": ["text"], + }, + } + ] + + resp = client.responses.create( + model="claude-sonnet-4-20250514", + input="Call the echo tool", + tools=tools, + ) + + assert resp.id is not None + + print(f"\n{'='*80}") + print(f"Model: {resp.model}") + print(f"Output: {resp.output_text}") + print(f"{'='*80}\n") + + +def test_openai_responses_api_streaming_with_tools_upstream_anthropic(): + """Responses API with a function/tool definition (streaming, pass-through)""" + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1", max_retries=0) + + tools = [ + { + "type": "function", + "name": "echo_tool", + "description": "Echo back the provided input: hello_world", + "parameters": { + "type": "object", + "properties": {"text": {"type": "string"}}, + "required": ["text"], + }, + } + ] + + stream = client.responses.create( + model="claude-sonnet-4-20250514", + input="Call the echo tool", + tools=tools, + stream=True, + ) + + text_chunks = [] + tool_calls = [] + + for event in stream: + etype = getattr(event, "type", None) + + # Collect streamed text output + if etype == "response.output_text.delta" and getattr(event, "delta", None): + text_chunks.append(event.delta) + + # Collect streamed tool call arguments + if etype == "response.function_call_arguments.delta" and getattr( + event, "delta", None + ): + tool_calls.append(event.delta) + + full_text = "".join(text_chunks) + + print(f"\n{'='*80}") + print("Responses tools streaming test") + print(f"Streamed text: {full_text}") + print(f"Tool call argument chunks: {len(tool_calls)}") + print(f"{'='*80}\n") + + # We expect either streamed text output or streamed tool-call arguments + assert ( + full_text or 
tool_calls + ), "Expected streamed text or tool call argument deltas from Responses tools stream"