From 7e39d048b4b036c234d5fb05c6320c7b55acc229 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Tue, 13 May 2025 16:27:47 -0700 Subject: [PATCH] add more changes --- crates/Cargo.lock | 1 + crates/brightstaff/Cargo.toml | 1 + .../src/handlers/chat_completions.rs | 89 ++++++++++++------- crates/brightstaff/src/main.rs | 4 +- crates/common/src/configuration.rs | 2 +- crates/common/src/consts.rs | 3 +- crates/common/src/utils.rs | 10 +-- crates/llm_gateway/src/stream_context.rs | 2 +- crates/prompt_gateway/src/stream_context.rs | 13 ++- 9 files changed, 78 insertions(+), 47 deletions(-) diff --git a/crates/Cargo.lock b/crates/Cargo.lock index 8142085f..b2501deb 100644 --- a/crates/Cargo.lock +++ b/crates/Cargo.lock @@ -196,6 +196,7 @@ dependencies = [ "eventsource-stream", "futures", "futures-util", + "http-body 1.0.1", "http-body-util", "hyper 1.6.0", "hyper-util", diff --git a/crates/brightstaff/Cargo.toml b/crates/brightstaff/Cargo.toml index 821c766a..1ad17dc8 100644 --- a/crates/brightstaff/Cargo.toml +++ b/crates/brightstaff/Cargo.toml @@ -10,6 +10,7 @@ eventsource-client = "0.15.0" eventsource-stream = "0.2.3" futures = "0.3.31" futures-util = "0.3.31" +http-body = "1.0.1" http-body-util = "0.1.3" hyper = { version = "1.6.0", features = ["full"] } hyper-util = "0.1.11" diff --git a/crates/brightstaff/src/handlers/chat_completions.rs b/crates/brightstaff/src/handlers/chat_completions.rs index 49848b37..240dc25c 100644 --- a/crates/brightstaff/src/handlers/chat_completions.rs +++ b/crates/brightstaff/src/handlers/chat_completions.rs @@ -4,11 +4,12 @@ use bytes::Bytes; use common::api::open_ai::ChatCompletionsRequest; use common::consts::ARCH_PROVIDER_HINT_HEADER; use http_body_util::combinators::BoxBody; -use http_body_util::{BodyExt, Full}; +use http_body_util::{BodyExt, Full, StreamBody}; use hyper::body::Body; -use hyper::header; +use hyper::header::{self}; use hyper::{Request, Response, StatusCode}; -use tracing::info; +use tokio::sync::mpsc; +use tracing::{info, warn}; use crate::router::llm_router::RouterService; @@ -111,41 +112,63 @@ pub async fn chat_completion( } }; - // if chat_completion_request.stream { - // let mut byte_stream = llm_response.bytes_stream(); + let response_headers = llm_response.headers().clone(); - // while let Some(item) = byte_stream.next().await { - // let item = match item { - // Ok(item) => item, - // Err(err) => { - // let err_msg = format!("Failed to read stream: {}", err); - // let mut internal_error = Response::new(full(err_msg)); - // *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - // return Ok(internal_error); - // } - // }; + if chat_completion_request.stream { + // Create a channel to send data + let (tx, rx) = mpsc::channel::(16); - // info!("Received chunk: {:?}", item); - // } + // Spawn a task to send data as it becomes available + tokio::spawn(async move { + let mut byte_stream = llm_response.bytes_stream(); - // let mut ok_response = Response::new(empty()); - // *ok_response.status_mut() = StatusCode::OK; + while let Some(item) = byte_stream.next().await { + let item = match item { + Ok(item) => item, + Err(err) => { + //TODO: use mpsc to send result with error to the receiver + warn!("Error receiving chunk: {:?}", err); + break; + } + }; - // return Ok(ok_response); - // } else { - let body = match llm_response.text().await { - Ok(body) => body, - Err(err) => { - let err_msg = format!("Failed to read response: {}", err); - let mut internal_error = Response::new(full(err_msg)); - *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; - return Ok(internal_error); + //TODO: send error to the receiver + tx.send(item).await.unwrap(); + } + }); + + use bytes::Bytes; + use hyper::body::Frame; + use hyper::Response; + use tokio_stream::wrappers::ReceiverStream; + use tokio_stream::StreamExt; + + let stream = ReceiverStream::new(rx).map(|chunk| Ok::<_, hyper::Error>(Frame::data(chunk))); + + let stream_body = BoxBody::new(StreamBody::new(stream)); + + let mut res = Response::builder(); + let headers = res.headers_mut().unwrap(); + + for (header_name, header_value) in response_headers.iter() { + headers.insert(header_name, header_value.clone()); } - }; - let mut ok_response = Response::new(full(body)); - *ok_response.status_mut() = StatusCode::OK; + Ok(res.body(stream_body).unwrap()) + } else { + let body = match llm_response.text().await { + Ok(body) => body, + Err(err) => { + let err_msg = format!("Failed to read response: {}", err); + let mut internal_error = Response::new(full(err_msg)); + *internal_error.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + return Ok(internal_error); + } + }; - Ok(ok_response) - // } + let mut ok_response = Response::new(full(body)); + *ok_response.status_mut() = StatusCode::OK; + + Ok(ok_response) + } } diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index ac8f7b16..9bd51b93 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -1,3 +1,5 @@ +use brightstaff::handlers::chat_completions::chat_completion; +use brightstaff::router::llm_router::RouterService; use bytes::Bytes; use common::configuration::Configuration; use common::utils::shorten_string; @@ -22,8 +24,6 @@ use std::{env, fs}; use tokio::net::TcpListener; use tracing::info; use tracing_subscriber::EnvFilter; -use brightstaff::handlers::chat_completions::chat_completion; -use brightstaff::router::llm_router::RouterService; pub mod router; diff --git a/crates/common/src/configuration.rs b/crates/common/src/configuration.rs index 2fb0238f..71c13f8b 100644 --- a/crates/common/src/configuration.rs +++ b/crates/common/src/configuration.rs @@ -8,7 +8,7 @@ use crate::api::open_ai::{ #[derive(Debug, Clone, Serialize, Deserialize)] pub struct Routing { - pub model: String, + pub model: String, } #[derive(Debug, Clone, Serialize, Deserialize)] diff --git a/crates/common/src/consts.rs b/crates/common/src/consts.rs index 1ea0063d..9eee4693 100644 --- a/crates/common/src/consts.rs +++ b/crates/common/src/consts.rs @@ -11,7 +11,8 @@ pub const MODEL_SERVER_NAME: &str = "model_server"; pub const ARCH_ROUTING_HEADER: &str = "x-arch-llm-provider"; pub const MESSAGES_KEY: &str = "messages"; pub const ARCH_PROVIDER_HINT_HEADER: &str = "x-arch-llm-provider-hint"; -pub const CHAT_COMPLETIONS_PATH: [&str; 2] = ["/v1/chat/completions", "/openai/v1/chat/completions"]; +pub const CHAT_COMPLETIONS_PATH: [&str; 2] = + ["/v1/chat/completions", "/openai/v1/chat/completions"]; pub const HEALTHZ_PATH: &str = "/healthz"; pub const X_ARCH_STATE_HEADER: &str = "x-arch-state"; pub const X_ARCH_API_RESPONSE: &str = "x-arch-api-response-message"; diff --git a/crates/common/src/utils.rs b/crates/common/src/utils.rs index fa31d166..5c5793da 100644 --- a/crates/common/src/utils.rs +++ b/crates/common/src/utils.rs @@ -1,7 +1,7 @@ pub fn shorten_string(s: &str) -> String { - if s.len() > 80 { - format!("{}...", &s[..80]) - } else { - s.to_string() - } + if s.len() > 80 { + format!("{}...", &s[..80]) + } else { + s.to_string() + } } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index 326ccbb8..446c2bcd 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -425,7 +425,7 @@ impl HttpContext for StreamContext { self.request_body = Some(chat_completion_request_str); self.request_size = Some(body_size); - return Action::Continue; + Action::Continue } fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs index 7345586a..3f486862 100644 --- a/crates/prompt_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/stream_context.rs @@ -103,7 +103,11 @@ impl StreamContext { } } - pub (crate) fn send_server_error(&self, error: ServerError, override_status_code: Option) { + pub(crate) fn send_server_error( + &self, + error: ServerError, + override_status_code: Option, + ) { self.send_http_response( override_status_code .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR) @@ -777,18 +781,19 @@ impl StreamContext { fn check_intent_matched(model_server_response: &ChatCompletionsResponse) -> bool { let content = model_server_response - .choices.first() + .choices + .first() .and_then(|choice| choice.message.content.as_ref()); let content_has_value = content.is_some() && !content.unwrap().is_empty(); let tool_calls = model_server_response - .choices.first() + .choices + .first() .and_then(|choice| choice.message.tool_calls.as_ref()); // intent was matched if content has some value or tool_calls is empty - content_has_value || (tool_calls.is_some() && !tool_calls.unwrap().is_empty()) }