diff --git a/crates/brightstaff/src/handlers/chat_completions.rs b/crates/brightstaff/src/handlers/chat_completions.rs index d0e5910a..afabb64f 100644 --- a/crates/brightstaff/src/handlers/chat_completions.rs +++ b/crates/brightstaff/src/handlers/chat_completions.rs @@ -22,10 +22,10 @@ fn full>(chunk: T) -> BoxBody { .boxed() } -pub async fn chat_completions( +pub async fn chat( request: Request, router_service: Arc, - llm_provider_endpoint: String, + full_qualified_llm_provider_url: String, ) -> Result>, hyper::Error> { let request_path = request.uri().path().to_string(); let mut request_headers = request.headers().clone(); @@ -152,7 +152,7 @@ pub async fn chat_completions( debug!( "sending request to llm provider: {}, with model hint: {}", - llm_provider_endpoint, model_name + full_qualified_llm_provider_url, model_name ); request_headers.insert( @@ -174,7 +174,7 @@ pub async fn chat_completions( request_headers.remove(header::CONTENT_LENGTH); let llm_response = match reqwest::Client::new() - .post(llm_provider_endpoint) + .post(full_qualified_llm_provider_url) .headers(request_headers) .body(chat_request_parsed_bytes) .send() diff --git a/crates/brightstaff/src/main.rs b/crates/brightstaff/src/main.rs index 34fa3aa3..4f5357eb 100644 --- a/crates/brightstaff/src/main.rs +++ b/crates/brightstaff/src/main.rs @@ -1,9 +1,10 @@ -use brightstaff::handlers::chat_completions::chat_completions; +use brightstaff::handlers::chat_completions::chat; use brightstaff::handlers::models::list_models; use brightstaff::router::llm_router::RouterService; use brightstaff::utils::tracing::init_tracer; use bytes::Bytes; use common::configuration::Configuration; +use common::consts::CHAT_COMPLETIONS_PATH; use http_body_util::{combinators::BoxBody, BodyExt, Empty}; use hyper::body::Incoming; use hyper::server::conn::http1; @@ -67,10 +68,10 @@ async fn main() -> Result<(), Box> { &serde_json::to_string(arch_config.as_ref()).unwrap() ); - let llm_provider_endpoint = env::var("LLM_PROVIDER_ENDPOINT") - .unwrap_or_else(|_| "http://localhost:12001/v1/chat/completions".to_string()); + let llm_provider_url = env::var("LLM_PROVIDER_ENDPOINT") + .unwrap_or_else(|_| "http://localhost:12001".to_string()); - info!("llm provider endpoint: {}", llm_provider_endpoint); + info!("llm provider url: {}", llm_provider_url); info!("listening on http://{}", bind_address); let listener = TcpListener::bind(bind_address).await?; @@ -88,7 +89,7 @@ async fn main() -> Result<(), Box> { let router_service: Arc = Arc::new(RouterService::new( arch_config.llm_providers.clone(), - llm_provider_endpoint.clone(), + llm_provider_url.clone() + CHAT_COMPLETIONS_PATH, routing_model_name, routing_llm_provider, )); @@ -99,19 +100,21 @@ async fn main() -> Result<(), Box> { let io = TokioIo::new(stream); let router_service: Arc = Arc::clone(&router_service); - let llm_provider_endpoint = llm_provider_endpoint.clone(); + let llm_provider_url = llm_provider_url.clone(); let llm_providers = llm_providers.clone(); let service = service_fn(move |req| { + let router_service = Arc::clone(&router_service); let parent_cx = extract_context_from_request(&req); - let llm_provider_endpoint = llm_provider_endpoint.clone(); + let llm_provider_url = llm_provider_url.clone(); let llm_providers = llm_providers.clone(); async move { match (req.method(), req.uri().path()) { - (&Method::POST, "/v1/chat/completions") => { - chat_completions(req, router_service, llm_provider_endpoint) + (&Method::POST, "/v1/chat/completions" | "/v1/messages") => { + let fully_qualified_url = format!("{}{}", llm_provider_url, req.uri().path()); + chat(req, router_service, fully_qualified_url) .with_context(parent_cx) .await } diff --git a/crates/hermesllm/src/providers/response.rs b/crates/hermesllm/src/providers/response.rs index 610d16eb..636d2a7c 100644 --- a/crates/hermesllm/src/providers/response.rs +++ b/crates/hermesllm/src/providers/response.rs @@ -1,5 +1,4 @@ use crate::providers::id::ProviderId; - use serde::Serialize; use std::error::Error; use std::fmt; diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/llm_gateway/src/stream_context.rs index fbd7faf5..9109fbeb 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/llm_gateway/src/stream_context.rs @@ -1,3 +1,14 @@ +use http::StatusCode; +use log::{debug, info, warn}; +use proxy_wasm::hostcalls::get_current_time; +use proxy_wasm::traits::*; +use proxy_wasm::types::*; +use std::collections::VecDeque; +use std::num::NonZero; +use std::rc::Rc; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, SystemTime, UNIX_EPOCH}; + use crate::metrics::Metrics; use common::configuration::{LlmProvider, LlmProviderType, Overrides}; use common::consts::{ @@ -13,16 +24,6 @@ use common::{ratelimit, routing, tokenizer}; use hermesllm::clients::endpoints::SupportedAPIs; use hermesllm::providers::response::{ProviderResponse, ProviderStreamResponseIter}; use hermesllm::{ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType}; -use http::StatusCode; -use log::{debug, info, warn}; -use proxy_wasm::hostcalls::get_current_time; -use proxy_wasm::traits::*; -use proxy_wasm::types::*; -use std::collections::VecDeque; -use std::num::NonZero; -use std::rc::Rc; -use std::sync::{Arc, Mutex}; -use std::time::{Duration, SystemTime, UNIX_EPOCH}; pub struct StreamContext { context_id: u32, @@ -30,6 +31,7 @@ pub struct StreamContext { ratelimit_selector: Option
, streaming_response: bool, response_tokens: usize, + /// The API that is requested by the client (before compatibility mapping) client_api: Option, /// The API that should be used for the upstream provider (after compatibility mapping) resolved_api: Option, @@ -191,6 +193,270 @@ impl StreamContext { Ok(()) } + + // === Helper methods extracted from on_http_response_body (no behavior change) === + #[inline] + fn record_ttft_if_needed(&mut self) { + if self.ttft_duration.is_none() { + let current_time = get_current_time().unwrap(); + self.ttft_time = Some(current_time_ns()); + match current_time.duration_since(self.start_time) { + Ok(duration) => { + let duration_ms = duration.as_millis(); + info!( + "on_http_response_body: time to first token: {}ms", + duration_ms + ); + self.ttft_duration = Some(duration); + self.metrics.time_to_first_token.record(duration_ms as u64); + } + Err(e) => { + warn!("SystemTime error: {:?}", e); + } + } + } + } + fn handle_end_of_stream_metrics_and_traces(&mut self, current_time: SystemTime) { + // All streaming responses end with bytes=0 and end_stream=true + // Record the latency for the request + match current_time.duration_since(self.start_time) { + Ok(duration) => { + // Convert the duration to milliseconds + let duration_ms = duration.as_millis(); + info!("on_http_response_body: request latency: {}ms", duration_ms); + // Record the latency to the latency histogram + self.metrics.request_latency.record(duration_ms as u64); + + if self.response_tokens > 0 { + // Compute the time per output token + let tpot = duration_ms as u64 / self.response_tokens as u64; + + // Record the time per output token + self.metrics.time_per_output_token.record(tpot); + + debug!( + "time per token: {}ms, tokens per second: {}", + tpot, + 1000 / tpot + ); + // Record the tokens per second + self.metrics.tokens_per_second.record(1000 / tpot); + } + } + Err(e) => { + warn!("SystemTime error: {:?}", e); + } + } + // Record the output sequence length + self.metrics + .output_sequence_length + .record(self.response_tokens as u64); + + if let Some(traceparent) = self.traceparent.as_ref() { + let current_time_ns = current_time_ns(); + + match Traceparent::try_from(traceparent.to_string()) { + Err(e) => { + warn!("traceparent header is invalid: {}", e); + } + Ok(traceparent) => { + let mut trace_data = common::tracing::TraceData::new(); + let mut llm_span = Span::new( + "egress_traffic".to_string(), + Some(traceparent.trace_id), + Some(traceparent.parent_id), + self.request_body_sent_time.unwrap(), + current_time_ns, + ); + llm_span + .add_attribute("model".to_string(), self.llm_provider().name.to_string()); + + if let Some(user_message) = &self.user_message { + llm_span.add_attribute("user_message".to_string(), user_message.clone()); + } + + if self.ttft_time.is_some() { + llm_span.add_event(Event::new( + "time_to_first_token".to_string(), + self.ttft_time.unwrap(), + )); + trace_data.add_span(llm_span); + } + + self.traces_queue.lock().unwrap().push_back(trace_data); + } + }; + } + } + + fn read_response_body(&mut self, body_size: usize) -> Result, Action> { + if self.streaming_response { + let chunk_start = 0; + let chunk_size = body_size; + debug!( + "on_http_response_body: streaming response reading, {}..{}", + chunk_start, chunk_size + ); + let streaming_chunk = match self.get_http_response_body(0, chunk_size) { + Some(chunk) => chunk, + None => { + warn!( + "response body empty, chunk_start: {}, chunk_size: {}", + chunk_start, chunk_size + ); + return Err(Action::Continue); + } + }; + + if streaming_chunk.len() != chunk_size { + warn!( + "chunk size mismatch: read: {} != requested: {}", + streaming_chunk.len(), + chunk_size + ); + } + Ok(streaming_chunk) + } else { + if body_size == 0 { + return Err(Action::Continue); + } + debug!("non streaming response bytes read: 0:{}", body_size); + match self.get_http_response_body(0, body_size) { + Some(body) => Ok(body), + None => { + warn!("non streaming response body empty"); + Err(Action::Continue) + } + } + } + } + + fn debug_log_body(&self, body: &[u8]) { + if log::log_enabled!(log::Level::Debug) { + debug!( + "response data (converted to utf8): {}", + String::from_utf8_lossy(body) + ); + } + } + + fn handle_streaming_response( + &mut self, + body: &[u8], + supported_api: SupportedAPIs, + provider_id: ProviderId, + ) -> Result, Action> { + debug!("processing streaming response"); + match (Some(supported_api), self.resolved_api.as_ref()) { + (Some(supported_api), Some(_)) => { + match ProviderStreamResponseIter::try_from((body, &supported_api, &provider_id)) { + Ok(mut streaming_response) => { + while let Some(chunk_result) = streaming_response.next() { + match chunk_result { + Ok(chunk) => { + self.record_ttft_if_needed(); + + if chunk.is_final() { + debug!("Received final streaming chunk"); + } + if let Some(content) = chunk.content_delta() { + let estimated_tokens = content.len() / 4; + self.response_tokens += estimated_tokens.max(1); + } + } + Err(e) => { + warn!("Error processing streaming chunk: {}", e); + return Err(Action::Continue); + } + } + } + } + Err(e) => { + warn!("Failed to parse streaming response: {}", e); + } + } + } + _ => { + warn!("Missing supported_api or resolved_api for streaming response"); + } + } + // NOTE: + // We currently pass-through the original SSE bytes for streaming responses. + // Non-streaming responses are parsed into ProviderResponseType and re-serialized to + // normalize the payload to the client API. Doing the same for streaming would require + // a streaming serializer that emits normalized SSE events for the target client API. + // That doesn't exist yet in hermesllm; implementing it is a follow-up. + // TODO(salmanap): Add a normalized SSE serializer in hermesllm and use it here so both + // streaming and non-streaming paths perform the same compatibility mapping. + // Until then, we keep behavior unchanged and forward upstream SSE as-is. + // For consistency of the method contract, still return Vec. + Ok(body.to_vec()) + } + + fn handle_non_streaming_response( + &mut self, + body: &[u8], + supported_api: SupportedAPIs, + provider_id: ProviderId, + ) -> Result, Action> { + debug!("non streaming response"); + + let response: ProviderResponseType = + match (Some(&supported_api), self.resolved_api.as_ref()) { + (Some(supported_api), Some(_)) => { + match ProviderResponseType::try_from((body, supported_api, &provider_id)) { + Ok(response) => response, + Err(e) => { + warn!( + "could not parse response: {}, body str: {}", + e, + String::from_utf8_lossy(body) + ); + debug!( + "on_http_response_body: S[{}], response body: {}", + self.context_id, + String::from_utf8_lossy(body) + ); + self.send_server_error( + ServerError::LogicError(format!("Response parsing error: {}", e)), + Some(StatusCode::BAD_REQUEST), + ); + return Err(Action::Continue); + } + } + } + _ => { + warn!("Missing supported_api or resolved_api for non-streaming response"); + return Err(Action::Continue); + } + }; + + // Use provider interface to extract usage information + if let Some((prompt_tokens, completion_tokens, total_tokens)) = + response.extract_usage_counts() + { + debug!( + "Response usage: prompt={}, completion={}, total={}", + prompt_tokens, completion_tokens, total_tokens + ); + self.response_tokens = completion_tokens; + } else { + warn!("No usage information found in response"); + } + + // Serialize the normalized response back to JSON bytes + match serde_json::to_vec(&response) { + Ok(bytes) => Ok(bytes), + Err(e) => { + warn!("Failed to serialize normalized response: {}", e); + self.send_server_error( + ServerError::LogicError(format!("Response serialization error: {}", e)), + Some(StatusCode::INTERNAL_SERVER_ERROR), + ); + Err(Action::Continue) + } + } + } } // HttpContext is the trait that allows the Rust code to interact with HTTP objects. @@ -457,232 +723,44 @@ impl HttpContext for StreamContext { let current_time = get_current_time().unwrap(); if end_of_stream && body_size == 0 { - // All streaming responses end with bytes=0 and end_stream=true - // Record the latency for the request - match current_time.duration_since(self.start_time) { - Ok(duration) => { - // Convert the duration to milliseconds - let duration_ms = duration.as_millis(); - info!("on_http_response_body: request latency: {}ms", duration_ms); - // Record the latency to the latency histogram - self.metrics.request_latency.record(duration_ms as u64); - - if self.response_tokens > 0 { - // Compute the time per output token - let tpot = duration_ms as u64 / self.response_tokens as u64; - - // Record the time per output token - self.metrics.time_per_output_token.record(tpot); - - debug!( - "time per token: {}ms, tokens per second: {}", - tpot, - 1000 / tpot - ); - // Record the tokens per second - self.metrics.tokens_per_second.record(1000 / tpot); - } - } - Err(e) => { - warn!("SystemTime error: {:?}", e); - } - } - // Record the output sequence length - self.metrics - .output_sequence_length - .record(self.response_tokens as u64); - - if let Some(traceparent) = self.traceparent.as_ref() { - let current_time_ns = current_time_ns(); - - match Traceparent::try_from(traceparent.to_string()) { - Err(e) => { - warn!("traceparent header is invalid: {}", e); - } - Ok(traceparent) => { - let mut trace_data = common::tracing::TraceData::new(); - let mut llm_span = Span::new( - "egress_traffic".to_string(), - Some(traceparent.trace_id), - Some(traceparent.parent_id), - self.request_body_sent_time.unwrap(), - current_time_ns, - ); - llm_span.add_attribute( - "model".to_string(), - self.llm_provider().name.to_string(), - ); - - if let Some(user_message) = &self.user_message { - llm_span - .add_attribute("user_message".to_string(), user_message.clone()); - } - - if self.ttft_time.is_some() { - llm_span.add_event(Event::new( - "time_to_first_token".to_string(), - self.ttft_time.unwrap(), - )); - trace_data.add_span(llm_span); - } - - self.traces_queue.lock().unwrap().push_back(trace_data); - } - }; - } - + self.handle_end_of_stream_metrics_and_traces(current_time); return Action::Continue; } - let body = if self.streaming_response { - let chunk_start = 0; - let chunk_size = body_size; - debug!( - "on_http_response_body: streaming response reading, {}..{}", - chunk_start, chunk_size - ); - let streaming_chunk = match self.get_http_response_body(0, chunk_size) { - Some(chunk) => chunk, - None => { - warn!( - "response body empty, chunk_start: {}, chunk_size: {}", - chunk_start, chunk_size - ); - return Action::Continue; - } - }; - - if streaming_chunk.len() != chunk_size { - warn!( - "chunk size mismatch: read: {} != requested: {}", - streaming_chunk.len(), - chunk_size - ); - } - streaming_chunk - } else { - if body_size == 0 { - return Action::Continue; - } - debug!("non streaming response bytes read: 0:{}", body_size); - match self.get_http_response_body(0, body_size) { - Some(body) => body, - None => { - warn!("non streaming response body empty"); - return Action::Continue; - } - } + let body = match self.read_response_body(body_size) { + Ok(b) => b, + Err(action) => return action, }; - if log::log_enabled!(log::Level::Debug) { - debug!( - "response data (converted to utf8): {}", - String::from_utf8_lossy(&body) - ); - } + self.debug_log_body(&body); let provider_id = self.get_provider_id(); - let supported_api = self.client_api.as_ref(); + let supported_api_opt = self.client_api.clone(); if self.streaming_response { - debug!("processing streaming response"); - match (supported_api, self.resolved_api.as_ref()) { - (Some(supported_api), Some(_)) => { - match ProviderStreamResponseIter::try_from(( - &body[..], - supported_api, - &provider_id, - )) { - Ok(mut streaming_response) => { - while let Some(chunk_result) = streaming_response.next() { - match chunk_result { - Ok(chunk) => { - if self.ttft_duration.is_none() { - let current_time = get_current_time().unwrap(); - self.ttft_time = Some(current_time_ns()); - match current_time.duration_since(self.start_time) { - Ok(duration) => { - let duration_ms = duration.as_millis(); - info!("on_http_response_body: time to first token: {}ms", duration_ms); - self.ttft_duration = Some(duration); - self.metrics - .time_to_first_token - .record(duration_ms as u64); - } - Err(e) => { - warn!("SystemTime error: {:?}", e); - } - } - } - if chunk.is_final() { - debug!("Received final streaming chunk"); - } - if let Some(content) = chunk.content_delta() { - let estimated_tokens = content.len() / 4; - self.response_tokens += estimated_tokens.max(1); - } - } - Err(e) => { - warn!("Error processing streaming chunk: {}", e); - return Action::Continue; - } - } - } - } - Err(e) => { - warn!("Failed to parse streaming response: {}", e); - } + if let Some(supported_api) = supported_api_opt { + match self.handle_streaming_response(&body, supported_api, provider_id) { + Ok(serialized_body) => { + // Write the normalized body back to the wire using the original body_size + self.set_http_response_body(0, body_size, &serialized_body); } + Err(action) => return action, } - _ => { - warn!("Missing supported_api or resolved_api for streaming response"); - } + } else { + warn!("Missing supported_api or resolved_api for streaming response"); } } else { - debug!("non streaming response"); - let provider_id = self.get_provider_id(); - let supported_api = self.client_api.as_ref(); - - let response: ProviderResponseType = match (supported_api, self.resolved_api.as_ref()) { - (Some(supported_api), Some(_)) => { - match ProviderResponseType::try_from((&body[..], supported_api, &provider_id)) { - Ok(response) => response, - Err(e) => { - warn!( - "could not parse response: {}, body str: {}", - e, - String::from_utf8_lossy(&body) - ); - debug!( - "on_http_response_body: S[{}], response body: {}", - self.context_id, - String::from_utf8_lossy(&body) - ); - self.send_server_error( - ServerError::LogicError(format!("Response parsing error: {}", e)), - Some(StatusCode::BAD_REQUEST), - ); - return Action::Continue; - } + if let Some(supported_api) = supported_api_opt { + match self.handle_non_streaming_response(&body, supported_api, provider_id) { + Ok(serialized_body) => { + // Write the normalized body back to the wire using the original body_size + self.set_http_response_body(0, body_size, &serialized_body); } + Err(action) => return action, } - _ => { - warn!("Missing supported_api or resolved_api for non-streaming response"); - return Action::Continue; - } - }; - - // Use provider interface to extract usage information - if let Some((prompt_tokens, completion_tokens, total_tokens)) = - response.extract_usage_counts() - { - debug!( - "Response usage: prompt={}, completion={}, total={}", - prompt_tokens, completion_tokens, total_tokens - ); - self.response_tokens = completion_tokens; } else { - warn!("No usage information found in response"); + warn!("Missing supported_api or resolved_api for non-streaming response"); + return Action::Continue; } }