diff --git a/crates/brightstaff/src/handlers/chat_completions.rs b/crates/brightstaff/src/handlers/chat_completions.rs index 240dc25c..5536471e 100644 --- a/crates/brightstaff/src/handlers/chat_completions.rs +++ b/crates/brightstaff/src/handlers/chat_completions.rs @@ -6,9 +6,12 @@ use common::consts::ARCH_PROVIDER_HINT_HEADER; use http_body_util::combinators::BoxBody; use http_body_util::{BodyExt, Full, StreamBody}; use hyper::body::Body; +use hyper::body::Frame; use hyper::header::{self}; use hyper::{Request, Response, StatusCode}; use tokio::sync::mpsc; +use tokio_stream::wrappers::ReceiverStream; +use tokio_stream::StreamExt; use tracing::{info, warn}; use crate::router::llm_router::RouterService; @@ -112,7 +115,13 @@ pub async fn chat_completion( } }; + // copy over the headers from the original response let response_headers = llm_response.headers().clone(); + let mut response = Response::builder(); + let headers = response.headers_mut().unwrap(); + for (header_name, header_value) in response_headers.iter() { + headers.insert(header_name, header_value.clone()); + } if chat_completion_request.stream { // Create a channel to send data @@ -126,35 +135,31 @@ pub async fn chat_completion( let item = match item { Ok(item) => item, Err(err) => { - //TODO: use mpsc to send result with error to the receiver warn!("Error receiving chunk: {:?}", err); break; } }; - //TODO: send error to the receiver - tx.send(item).await.unwrap(); + if tx.send(item).await.is_err() { + warn!("Receiver dropped"); + break; + } } }); - use bytes::Bytes; - use hyper::body::Frame; - use hyper::Response; - use tokio_stream::wrappers::ReceiverStream; - use tokio_stream::StreamExt; - let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk))); let stream_body = BoxBody::new(StreamBody::new(stream)); - let mut res = Response::builder(); - let headers = res.headers_mut().unwrap(); - - for (header_name, header_value) in response_headers.iter() { - headers.insert(header_name, header_value.clone()); + match response.body(stream_body) { + Ok(response) => Ok(response), + Err(err) => { + let err_msg = format!("Failed to create response: {}", err); + let mut internal_error = Response::new(full(err_msg)); + *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + Ok(internal_error) + } } - - Ok(res.body(stream_body).unwrap()) } else { let body = match llm_response.text().await { Ok(body) => body, @@ -166,9 +171,14 @@ pub async fn chat_completion( } }; - let mut ok_response = Response::new(full(body)); - *ok_response.status_mut() = StatusCode::OK; - - Ok(ok_response) + match response.body(full(body)) { + Ok(response) => Ok(response), + Err(err) => { + let err_msg = format!("Failed to create response: {}", err); + let mut internal_error = Response::new(full(err_msg)); + *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + Ok(internal_error) + } + } } }