updated the stream_context to update response bytes

This commit is contained in:
Salman Paracha 2025-08-28 22:55:12 -07:00
parent 9f6d2464f6
commit e7238fb7fd
4 changed files with 314 additions and 234 deletions

View file

@ -22,10 +22,10 @@ fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
.boxed()
}
pub async fn chat_completions(
pub async fn chat(
request: Request<hyper::body::Incoming>,
router_service: Arc<RouterService>,
llm_provider_endpoint: String,
full_qualified_llm_provider_url: String,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
let request_path = request.uri().path().to_string();
let mut request_headers = request.headers().clone();
@ -152,7 +152,7 @@ pub async fn chat_completions(
debug!(
"sending request to llm provider: {}, with model hint: {}",
llm_provider_endpoint, model_name
full_qualified_llm_provider_url, model_name
);
request_headers.insert(
@ -174,7 +174,7 @@ pub async fn chat_completions(
request_headers.remove(header::CONTENT_LENGTH);
let llm_response = match reqwest::Client::new()
.post(llm_provider_endpoint)
.post(full_qualified_llm_provider_url)
.headers(request_headers)
.body(chat_request_parsed_bytes)
.send()

View file

@ -1,9 +1,10 @@
use brightstaff::handlers::chat_completions::chat_completions;
use brightstaff::handlers::chat_completions::chat;
use brightstaff::handlers::models::list_models;
use brightstaff::router::llm_router::RouterService;
use brightstaff::utils::tracing::init_tracer;
use bytes::Bytes;
use common::configuration::Configuration;
use common::consts::CHAT_COMPLETIONS_PATH;
use http_body_util::{combinators::BoxBody, BodyExt, Empty};
use hyper::body::Incoming;
use hyper::server::conn::http1;
@ -67,10 +68,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
&serde_json::to_string(arch_config.as_ref()).unwrap()
);
let llm_provider_endpoint = env::var("LLM_PROVIDER_ENDPOINT")
.unwrap_or_else(|_| "http://localhost:12001/v1/chat/completions".to_string());
let llm_provider_url = env::var("LLM_PROVIDER_ENDPOINT")
.unwrap_or_else(|_| "http://localhost:12001".to_string());
info!("llm provider endpoint: {}", llm_provider_endpoint);
info!("llm provider url: {}", llm_provider_url);
info!("listening on http://{}", bind_address);
let listener = TcpListener::bind(bind_address).await?;
@ -88,7 +89,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let router_service: Arc<RouterService> = Arc::new(RouterService::new(
arch_config.llm_providers.clone(),
llm_provider_endpoint.clone(),
llm_provider_url.clone() + CHAT_COMPLETIONS_PATH,
routing_model_name,
routing_llm_provider,
));
@ -99,19 +100,21 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
let io = TokioIo::new(stream);
let router_service: Arc<RouterService> = Arc::clone(&router_service);
let llm_provider_endpoint = llm_provider_endpoint.clone();
let llm_provider_url = llm_provider_url.clone();
let llm_providers = llm_providers.clone();
let service = service_fn(move |req| {
let router_service = Arc::clone(&router_service);
let parent_cx = extract_context_from_request(&req);
let llm_provider_endpoint = llm_provider_endpoint.clone();
let llm_provider_url = llm_provider_url.clone();
let llm_providers = llm_providers.clone();
async move {
match (req.method(), req.uri().path()) {
(&Method::POST, "/v1/chat/completions") => {
chat_completions(req, router_service, llm_provider_endpoint)
(&Method::POST, "/v1/chat/completions" | "/v1/messages") => {
let fully_qualified_url = format!("{}{}", llm_provider_url, req.uri().path());
chat(req, router_service, fully_qualified_url)
.with_context(parent_cx)
.await
}

View file

@ -1,5 +1,4 @@
use crate::providers::id::ProviderId;
use serde::Serialize;
use std::error::Error;
use std::fmt;

View file

@ -1,3 +1,14 @@
use http::StatusCode;
use log::{debug, info, warn};
use proxy_wasm::hostcalls::get_current_time;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use std::collections::VecDeque;
use std::num::NonZero;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use crate::metrics::Metrics;
use common::configuration::{LlmProvider, LlmProviderType, Overrides};
use common::consts::{
@ -13,16 +24,6 @@ use common::{ratelimit, routing, tokenizer};
use hermesllm::clients::endpoints::SupportedAPIs;
use hermesllm::providers::response::{ProviderResponse, ProviderStreamResponseIter};
use hermesllm::{ProviderId, ProviderRequest, ProviderRequestType, ProviderResponseType};
use http::StatusCode;
use log::{debug, info, warn};
use proxy_wasm::hostcalls::get_current_time;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use std::collections::VecDeque;
use std::num::NonZero;
use std::rc::Rc;
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
pub struct StreamContext {
context_id: u32,
@ -30,6 +31,7 @@ pub struct StreamContext {
ratelimit_selector: Option<Header>,
streaming_response: bool,
response_tokens: usize,
/// The API that is requested by the client (before compatibility mapping)
client_api: Option<SupportedAPIs>,
/// The API that should be used for the upstream provider (after compatibility mapping)
resolved_api: Option<SupportedAPIs>,
@ -191,6 +193,270 @@ impl StreamContext {
Ok(())
}
// === Helper methods extracted from on_http_response_body (no behavior change) ===
#[inline]
fn record_ttft_if_needed(&mut self) {
if self.ttft_duration.is_none() {
let current_time = get_current_time().unwrap();
self.ttft_time = Some(current_time_ns());
match current_time.duration_since(self.start_time) {
Ok(duration) => {
let duration_ms = duration.as_millis();
info!(
"on_http_response_body: time to first token: {}ms",
duration_ms
);
self.ttft_duration = Some(duration);
self.metrics.time_to_first_token.record(duration_ms as u64);
}
Err(e) => {
warn!("SystemTime error: {:?}", e);
}
}
}
}
fn handle_end_of_stream_metrics_and_traces(&mut self, current_time: SystemTime) {
// All streaming responses end with bytes=0 and end_stream=true
// Record the latency for the request
match current_time.duration_since(self.start_time) {
Ok(duration) => {
// Convert the duration to milliseconds
let duration_ms = duration.as_millis();
info!("on_http_response_body: request latency: {}ms", duration_ms);
// Record the latency to the latency histogram
self.metrics.request_latency.record(duration_ms as u64);
if self.response_tokens > 0 {
// Compute the time per output token
let tpot = duration_ms as u64 / self.response_tokens as u64;
// Record the time per output token
self.metrics.time_per_output_token.record(tpot);
debug!(
"time per token: {}ms, tokens per second: {}",
tpot,
1000 / tpot
);
// Record the tokens per second
self.metrics.tokens_per_second.record(1000 / tpot);
}
}
Err(e) => {
warn!("SystemTime error: {:?}", e);
}
}
// Record the output sequence length
self.metrics
.output_sequence_length
.record(self.response_tokens as u64);
if let Some(traceparent) = self.traceparent.as_ref() {
let current_time_ns = current_time_ns();
match Traceparent::try_from(traceparent.to_string()) {
Err(e) => {
warn!("traceparent header is invalid: {}", e);
}
Ok(traceparent) => {
let mut trace_data = common::tracing::TraceData::new();
let mut llm_span = Span::new(
"egress_traffic".to_string(),
Some(traceparent.trace_id),
Some(traceparent.parent_id),
self.request_body_sent_time.unwrap(),
current_time_ns,
);
llm_span
.add_attribute("model".to_string(), self.llm_provider().name.to_string());
if let Some(user_message) = &self.user_message {
llm_span.add_attribute("user_message".to_string(), user_message.clone());
}
if self.ttft_time.is_some() {
llm_span.add_event(Event::new(
"time_to_first_token".to_string(),
self.ttft_time.unwrap(),
));
trace_data.add_span(llm_span);
}
self.traces_queue.lock().unwrap().push_back(trace_data);
}
};
}
}
fn read_response_body(&mut self, body_size: usize) -> Result<Vec<u8>, Action> {
if self.streaming_response {
let chunk_start = 0;
let chunk_size = body_size;
debug!(
"on_http_response_body: streaming response reading, {}..{}",
chunk_start, chunk_size
);
let streaming_chunk = match self.get_http_response_body(0, chunk_size) {
Some(chunk) => chunk,
None => {
warn!(
"response body empty, chunk_start: {}, chunk_size: {}",
chunk_start, chunk_size
);
return Err(Action::Continue);
}
};
if streaming_chunk.len() != chunk_size {
warn!(
"chunk size mismatch: read: {} != requested: {}",
streaming_chunk.len(),
chunk_size
);
}
Ok(streaming_chunk)
} else {
if body_size == 0 {
return Err(Action::Continue);
}
debug!("non streaming response bytes read: 0:{}", body_size);
match self.get_http_response_body(0, body_size) {
Some(body) => Ok(body),
None => {
warn!("non streaming response body empty");
Err(Action::Continue)
}
}
}
}
fn debug_log_body(&self, body: &[u8]) {
if log::log_enabled!(log::Level::Debug) {
debug!(
"response data (converted to utf8): {}",
String::from_utf8_lossy(body)
);
}
}
fn handle_streaming_response(
&mut self,
body: &[u8],
supported_api: SupportedAPIs,
provider_id: ProviderId,
) -> Result<Vec<u8>, Action> {
debug!("processing streaming response");
match (Some(supported_api), self.resolved_api.as_ref()) {
(Some(supported_api), Some(_)) => {
match ProviderStreamResponseIter::try_from((body, &supported_api, &provider_id)) {
Ok(mut streaming_response) => {
while let Some(chunk_result) = streaming_response.next() {
match chunk_result {
Ok(chunk) => {
self.record_ttft_if_needed();
if chunk.is_final() {
debug!("Received final streaming chunk");
}
if let Some(content) = chunk.content_delta() {
let estimated_tokens = content.len() / 4;
self.response_tokens += estimated_tokens.max(1);
}
}
Err(e) => {
warn!("Error processing streaming chunk: {}", e);
return Err(Action::Continue);
}
}
}
}
Err(e) => {
warn!("Failed to parse streaming response: {}", e);
}
}
}
_ => {
warn!("Missing supported_api or resolved_api for streaming response");
}
}
// NOTE:
// We currently pass-through the original SSE bytes for streaming responses.
// Non-streaming responses are parsed into ProviderResponseType and re-serialized to
// normalize the payload to the client API. Doing the same for streaming would require
// a streaming serializer that emits normalized SSE events for the target client API.
// That doesn't exist yet in hermesllm; implementing it is a follow-up.
// TODO(salmanap): Add a normalized SSE serializer in hermesllm and use it here so both
// streaming and non-streaming paths perform the same compatibility mapping.
// Until then, we keep behavior unchanged and forward upstream SSE as-is.
// For consistency of the method contract, still return Vec<u8>.
Ok(body.to_vec())
}
fn handle_non_streaming_response(
&mut self,
body: &[u8],
supported_api: SupportedAPIs,
provider_id: ProviderId,
) -> Result<Vec<u8>, Action> {
debug!("non streaming response");
let response: ProviderResponseType =
match (Some(&supported_api), self.resolved_api.as_ref()) {
(Some(supported_api), Some(_)) => {
match ProviderResponseType::try_from((body, supported_api, &provider_id)) {
Ok(response) => response,
Err(e) => {
warn!(
"could not parse response: {}, body str: {}",
e,
String::from_utf8_lossy(body)
);
debug!(
"on_http_response_body: S[{}], response body: {}",
self.context_id,
String::from_utf8_lossy(body)
);
self.send_server_error(
ServerError::LogicError(format!("Response parsing error: {}", e)),
Some(StatusCode::BAD_REQUEST),
);
return Err(Action::Continue);
}
}
}
_ => {
warn!("Missing supported_api or resolved_api for non-streaming response");
return Err(Action::Continue);
}
};
// Use provider interface to extract usage information
if let Some((prompt_tokens, completion_tokens, total_tokens)) =
response.extract_usage_counts()
{
debug!(
"Response usage: prompt={}, completion={}, total={}",
prompt_tokens, completion_tokens, total_tokens
);
self.response_tokens = completion_tokens;
} else {
warn!("No usage information found in response");
}
// Serialize the normalized response back to JSON bytes
match serde_json::to_vec(&response) {
Ok(bytes) => Ok(bytes),
Err(e) => {
warn!("Failed to serialize normalized response: {}", e);
self.send_server_error(
ServerError::LogicError(format!("Response serialization error: {}", e)),
Some(StatusCode::INTERNAL_SERVER_ERROR),
);
Err(Action::Continue)
}
}
}
}
// HttpContext is the trait that allows the Rust code to interact with HTTP objects.
@ -457,232 +723,44 @@ impl HttpContext for StreamContext {
let current_time = get_current_time().unwrap();
if end_of_stream && body_size == 0 {
// All streaming responses end with bytes=0 and end_stream=true
// Record the latency for the request
match current_time.duration_since(self.start_time) {
Ok(duration) => {
// Convert the duration to milliseconds
let duration_ms = duration.as_millis();
info!("on_http_response_body: request latency: {}ms", duration_ms);
// Record the latency to the latency histogram
self.metrics.request_latency.record(duration_ms as u64);
if self.response_tokens > 0 {
// Compute the time per output token
let tpot = duration_ms as u64 / self.response_tokens as u64;
// Record the time per output token
self.metrics.time_per_output_token.record(tpot);
debug!(
"time per token: {}ms, tokens per second: {}",
tpot,
1000 / tpot
);
// Record the tokens per second
self.metrics.tokens_per_second.record(1000 / tpot);
}
}
Err(e) => {
warn!("SystemTime error: {:?}", e);
}
}
// Record the output sequence length
self.metrics
.output_sequence_length
.record(self.response_tokens as u64);
if let Some(traceparent) = self.traceparent.as_ref() {
let current_time_ns = current_time_ns();
match Traceparent::try_from(traceparent.to_string()) {
Err(e) => {
warn!("traceparent header is invalid: {}", e);
}
Ok(traceparent) => {
let mut trace_data = common::tracing::TraceData::new();
let mut llm_span = Span::new(
"egress_traffic".to_string(),
Some(traceparent.trace_id),
Some(traceparent.parent_id),
self.request_body_sent_time.unwrap(),
current_time_ns,
);
llm_span.add_attribute(
"model".to_string(),
self.llm_provider().name.to_string(),
);
if let Some(user_message) = &self.user_message {
llm_span
.add_attribute("user_message".to_string(), user_message.clone());
}
if self.ttft_time.is_some() {
llm_span.add_event(Event::new(
"time_to_first_token".to_string(),
self.ttft_time.unwrap(),
));
trace_data.add_span(llm_span);
}
self.traces_queue.lock().unwrap().push_back(trace_data);
}
};
}
self.handle_end_of_stream_metrics_and_traces(current_time);
return Action::Continue;
}
let body = if self.streaming_response {
let chunk_start = 0;
let chunk_size = body_size;
debug!(
"on_http_response_body: streaming response reading, {}..{}",
chunk_start, chunk_size
);
let streaming_chunk = match self.get_http_response_body(0, chunk_size) {
Some(chunk) => chunk,
None => {
warn!(
"response body empty, chunk_start: {}, chunk_size: {}",
chunk_start, chunk_size
);
return Action::Continue;
}
};
if streaming_chunk.len() != chunk_size {
warn!(
"chunk size mismatch: read: {} != requested: {}",
streaming_chunk.len(),
chunk_size
);
}
streaming_chunk
} else {
if body_size == 0 {
return Action::Continue;
}
debug!("non streaming response bytes read: 0:{}", body_size);
match self.get_http_response_body(0, body_size) {
Some(body) => body,
None => {
warn!("non streaming response body empty");
return Action::Continue;
}
}
let body = match self.read_response_body(body_size) {
Ok(b) => b,
Err(action) => return action,
};
if log::log_enabled!(log::Level::Debug) {
debug!(
"response data (converted to utf8): {}",
String::from_utf8_lossy(&body)
);
}
self.debug_log_body(&body);
let provider_id = self.get_provider_id();
let supported_api = self.client_api.as_ref();
let supported_api_opt = self.client_api.clone();
if self.streaming_response {
debug!("processing streaming response");
match (supported_api, self.resolved_api.as_ref()) {
(Some(supported_api), Some(_)) => {
match ProviderStreamResponseIter::try_from((
&body[..],
supported_api,
&provider_id,
)) {
Ok(mut streaming_response) => {
while let Some(chunk_result) = streaming_response.next() {
match chunk_result {
Ok(chunk) => {
if self.ttft_duration.is_none() {
let current_time = get_current_time().unwrap();
self.ttft_time = Some(current_time_ns());
match current_time.duration_since(self.start_time) {
Ok(duration) => {
let duration_ms = duration.as_millis();
info!("on_http_response_body: time to first token: {}ms", duration_ms);
self.ttft_duration = Some(duration);
self.metrics
.time_to_first_token
.record(duration_ms as u64);
}
Err(e) => {
warn!("SystemTime error: {:?}", e);
}
}
}
if chunk.is_final() {
debug!("Received final streaming chunk");
}
if let Some(content) = chunk.content_delta() {
let estimated_tokens = content.len() / 4;
self.response_tokens += estimated_tokens.max(1);
}
}
Err(e) => {
warn!("Error processing streaming chunk: {}", e);
return Action::Continue;
}
}
}
}
Err(e) => {
warn!("Failed to parse streaming response: {}", e);
}
if let Some(supported_api) = supported_api_opt {
match self.handle_streaming_response(&body, supported_api, provider_id) {
Ok(serialized_body) => {
// Write the normalized body back to the wire using the original body_size
self.set_http_response_body(0, body_size, &serialized_body);
}
Err(action) => return action,
}
_ => {
warn!("Missing supported_api or resolved_api for streaming response");
}
} else {
warn!("Missing supported_api or resolved_api for streaming response");
}
} else {
debug!("non streaming response");
let provider_id = self.get_provider_id();
let supported_api = self.client_api.as_ref();
let response: ProviderResponseType = match (supported_api, self.resolved_api.as_ref()) {
(Some(supported_api), Some(_)) => {
match ProviderResponseType::try_from((&body[..], supported_api, &provider_id)) {
Ok(response) => response,
Err(e) => {
warn!(
"could not parse response: {}, body str: {}",
e,
String::from_utf8_lossy(&body)
);
debug!(
"on_http_response_body: S[{}], response body: {}",
self.context_id,
String::from_utf8_lossy(&body)
);
self.send_server_error(
ServerError::LogicError(format!("Response parsing error: {}", e)),
Some(StatusCode::BAD_REQUEST),
);
return Action::Continue;
}
if let Some(supported_api) = supported_api_opt {
match self.handle_non_streaming_response(&body, supported_api, provider_id) {
Ok(serialized_body) => {
// Write the normalized body back to the wire using the original body_size
self.set_http_response_body(0, body_size, &serialized_body);
}
Err(action) => return action,
}
_ => {
warn!("Missing supported_api or resolved_api for non-streaming response");
return Action::Continue;
}
};
// Use provider interface to extract usage information
if let Some((prompt_tokens, completion_tokens, total_tokens)) =
response.extract_usage_counts()
{
debug!(
"Response usage: prompt={}, completion={}, total={}",
prompt_tokens, completion_tokens, total_tokens
);
self.response_tokens = completion_tokens;
} else {
warn!("No usage information found in response");
warn!("Missing supported_api or resolved_api for non-streaming response");
return Action::Continue;
}
}