diff --git a/crates/brightstaff/src/handlers/response_handler.rs b/crates/brightstaff/src/handlers/response_handler.rs index ce1a20d2..d8b3bedf 100644 --- a/crates/brightstaff/src/handlers/response_handler.rs +++ b/crates/brightstaff/src/handlers/response_handler.rs @@ -1,7 +1,7 @@ use bytes::Bytes; -use hermesllm::SseEvent; use hermesllm::apis::OpenAIApi; use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs}; +use hermesllm::SseEvent; use http_body_util::combinators::BoxBody; use http_body_util::{BodyExt, Full, StreamBody}; use hyper::body::Frame; @@ -130,18 +130,25 @@ impl ResponseHandler { ) -> Result { use hermesllm::apis::streaming_shapes::sse::SseStreamIter; + let response_headers = llm_response.headers(); + let is_sse_streaming = response_headers + .get(hyper::header::CONTENT_TYPE) + .map_or(false, |v| { + v.to_str().unwrap_or("").contains("text/event-stream") + }); + let response_bytes = llm_response .bytes() .await .map_err(|e| ResponseError::StreamError(format!("Failed to read response: {}", e)))?; + if is_sse_streaming { + let client_api = + SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions); + let upstream_api = + SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); - let client_api = - SupportedAPIsFromClient::OpenAIChatCompletions(OpenAIApi::ChatCompletions); - let upstream_api = SupportedUpstreamAPIs::OpenAIChatCompletions(OpenAIApi::ChatCompletions); - - // Try to parse as SSE streaming response - if let Ok(sse_iter) = SseStreamIter::try_from(response_bytes.as_ref()) { + let sse_iter = SseStreamIter::try_from(response_bytes.as_ref()).unwrap(); let mut accumulated_text = String::new(); for sse_event in sse_iter { @@ -150,7 +157,8 @@ impl ResponseHandler { continue; } - let transformed_event = SseEvent::try_from((sse_event, &client_api, &upstream_api)).unwrap(); + let transformed_event = + SseEvent::try_from((sse_event, &client_api, &upstream_api)).unwrap(); // Try to get provider response and extract content delta match transformed_event.provider_response() { @@ -166,15 +174,15 @@ impl ResponseHandler { } } } - return Ok(accumulated_text); + } else { + // If not SSE, treat as regular text response + let response_text = String::from_utf8(response_bytes.to_vec()).map_err(|e| { + ResponseError::StreamError(format!("Failed to decode response: {}", e)) + })?; + + Ok(response_text) } - - // If not SSE, treat as regular text response - let response_text = String::from_utf8(response_bytes.to_vec()) - .map_err(|e| ResponseError::StreamError(format!("Failed to decode response: {}", e)))?; - - Ok(response_text) } }