This commit is contained in:
Musa 2026-06-03 17:15:11 +00:00 committed by GitHub
commit b5eaa541e6
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 154 additions and 77 deletions

View file

@ -56,6 +56,8 @@ pub struct StreamContext {
http_protocol: Option<String>,
sse_buffer: Option<SseStreamBuffer>,
sse_chunk_processor: Option<SseChunkProcessor>,
/// Accumulates upstream non-streaming response chunks until end of stream.
non_streaming_response_buffer: Vec<u8>,
}
impl StreamContext {
@ -87,6 +89,7 @@ impl StreamContext {
http_protocol: None,
sse_buffer: None,
sse_chunk_processor: None,
non_streaming_response_buffer: Vec::new(),
}
}
@ -816,6 +819,31 @@ impl StreamContext {
}
}
}
/// Collects any streaming-response bytes that are still pending when the
/// stream ends.
///
/// If the SSE chunk processor reports buffered data, the streaming handler is
/// invoked with an empty chunk: a non-empty result is returned immediately,
/// while a handler error yields `None`. Otherwise (or when the handler produced
/// nothing) the bytes currently held by the SSE buffer are returned, provided
/// there are any.
fn flush_streaming_response_tail(&mut self) -> Option<Vec<u8>> {
    let provider_id = self.get_provider_id();

    let processor_has_pending = match self.sse_chunk_processor.as_ref() {
        Some(processor) => processor.has_buffered_data(),
        None => false,
    };

    if processor_has_pending {
        match self.handle_streaming_response(&[], provider_id) {
            Err(_) => return None,
            Ok(flushed) => {
                if !flushed.is_empty() {
                    return Some(flushed);
                }
            }
        }
    }

    // Nothing came out of the processor; fall back to the SSE buffer contents.
    let buffer = self.sse_buffer.as_mut()?;
    let remaining = buffer.to_bytes();
    Some(remaining).filter(|bytes| !bytes.is_empty())
}
}
// HttpContext is the trait that allows the Rust code to interact with HTTP objects.
@ -1174,6 +1202,21 @@ impl HttpContext for StreamContext {
let current_time = get_current_time().unwrap();
if end_of_stream && body_size == 0 {
if self.streaming_response {
if let Some(serialized_body) = self.flush_streaming_response_tail() {
self.set_http_response_body(0, 0, &serialized_body);
}
} else if !self.non_streaming_response_buffer.is_empty() {
let body = std::mem::take(&mut self.non_streaming_response_buffer);
let provider_id = self.get_provider_id();
match self.handle_non_streaming_response(&body, provider_id) {
Ok(serialized_body) => {
self.set_http_response_body(0, 0, &serialized_body);
}
Err(action) => return action,
}
}
debug!(
"request_id={}: response body complete, total_bytes={}",
self.request_identifier(),
@ -1248,7 +1291,15 @@ impl HttpContext for StreamContext {
Err(action) => return action,
}
} else {
match self.handle_non_streaming_response(&body, provider_id) {
self.non_streaming_response_buffer.extend_from_slice(&body);
if !end_of_stream {
// Hold chunks until the full JSON body arrives.
self.set_http_response_body(0, body_size, &[]);
return Action::Continue;
}
let complete_body = std::mem::take(&mut self.non_streaming_response_buffer);
match self.handle_non_streaming_response(&complete_body, provider_id) {
Ok(serialized_body) => {
self.set_http_response_body(0, body_size, &serialized_body);
}

View file

@ -7,8 +7,7 @@ use common::{
ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_ROUTING_HEADER,
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, HEALTHZ_PATH,
MODEL_SERVER_NAME, MODEL_SERVER_REQUEST_TIMEOUT_MS, REQUEST_ID_HEADER, TOOL_ROLE,
TRACE_PARENT_HEADER, USER_ROLE, X_ARCH_API_RESPONSE, X_ARCH_FC_MODEL_RESPONSE,
X_ARCH_STATE_HEADER, X_ARCH_TOOL_CALL,
TRACE_PARENT_HEADER, USER_ROLE, X_ARCH_STATE_HEADER,
},
errors::ServerError,
http::{CallArgs, Client},
@ -17,7 +16,6 @@ use common::{
use http::StatusCode;
use log::{debug, info, warn};
use proxy_wasm::{traits::HttpContext, types::Action};
use serde_json::Value;
use std::{
collections::HashMap,
time::{Duration, SystemTime, UNIX_EPOCH},
@ -291,6 +289,10 @@ impl HttpContext for StreamContext {
}
if end_of_stream && body_size == 0 {
if !self.streaming_response && !self.non_streaming_response_buffer.is_empty() {
let body = std::mem::take(&mut self.non_streaming_response_buffer);
self.process_non_streaming_response_body(&body, 0);
}
return Action::Continue;
}
@ -326,15 +328,15 @@ impl HttpContext for StreamContext {
}
};
let body_utf8 = match String::from_utf8(body) {
Ok(body_utf8) => body_utf8,
Err(e) => {
info!("could not convert to utf8: {}", e);
return Action::Continue;
}
};
if self.streaming_response {
let body_utf8 = match String::from_utf8(body) {
Ok(body_utf8) => body_utf8,
Err(e) => {
info!("could not convert to utf8: {}", e);
return Action::Continue;
}
};
debug!("streaming response");
if self.tool_calls.is_some() && !self.tool_calls.as_ref().unwrap().is_empty() {
@ -359,70 +361,15 @@ impl HttpContext for StreamContext {
self.set_http_response_body(0, body_size, response_str.as_bytes());
self.tool_calls = None;
}
} else if let Some(tool_calls) = self.tool_calls.as_ref() {
if !tool_calls.is_empty() {
if self.arch_state.is_none() {
self.arch_state = Some(Vec::new());
}
let mut data = match serde_json::from_str(&body_utf8) {
Ok(data) => data,
Err(e) => {
warn!(
"could not deserialize response, sending data as it is: {}",
e
);
return Action::Continue;
}
};
// use serde::Value to manipulate the json object and ensure that we don't lose any data
if let Value::Object(ref mut map) = data {
// serialize arch state and add to metadata
let metadata = map
.entry("metadata")
.or_insert(Value::Object(serde_json::Map::new()));
if metadata == &Value::Null {
*metadata = Value::Object(serde_json::Map::new());
}
let tool_call_message = self.generate_tool_call_message();
let tool_call_message_str = serde_json::to_string(&tool_call_message).unwrap();
metadata.as_object_mut().unwrap().insert(
X_ARCH_TOOL_CALL.to_string(),
serde_json::Value::String(tool_call_message_str),
);
let api_response_message = self.generate_api_response_message();
let api_response_message_str =
serde_json::to_string(&api_response_message).unwrap();
metadata.as_object_mut().unwrap().insert(
X_ARCH_API_RESPONSE.to_string(),
serde_json::Value::String(api_response_message_str),
);
let fc_messages = vec![tool_call_message, api_response_message];
let fc_messages_str = serde_json::to_string(&fc_messages).unwrap();
let arch_state = HashMap::from([("messages".to_string(), fc_messages_str)]);
let arch_state_str = serde_json::to_string(&arch_state).unwrap();
metadata.as_object_mut().unwrap().insert(
X_ARCH_STATE_HEADER.to_string(),
serde_json::Value::String(arch_state_str),
);
if let Some(arch_fc_response) = self.arch_fc_response.as_ref() {
metadata.as_object_mut().unwrap().insert(
X_ARCH_FC_MODEL_RESPONSE.to_string(),
serde_json::Value::String(
serde_json::to_string(arch_fc_response).unwrap(),
),
);
}
let data_serialized = serde_json::to_string(&data).unwrap();
info!("plano <= developer: {}", data_serialized);
self.set_http_response_body(0, body_size, data_serialized.as_bytes());
};
} else {
self.non_streaming_response_buffer.extend_from_slice(&body);
if !end_of_stream {
self.set_http_response_body(0, body_size, &[]);
return Action::Continue;
}
let complete_body = std::mem::take(&mut self.non_streaming_response_buffer);
self.process_non_streaming_response_body(&complete_body, body_size);
}
debug!("recv [S={}] end_stream={}", self.context_id, end_of_stream);

View file

@ -9,7 +9,7 @@ use common::consts::{
API_REQUEST_TIMEOUT_MS, ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME,
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, DEFAULT_TARGET_REQUEST_TIMEOUT_MS, MESSAGES_KEY,
REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, TRACE_PARENT_HEADER, USER_ROLE,
X_ARCH_FC_MODEL_RESPONSE,
X_ARCH_API_RESPONSE, X_ARCH_FC_MODEL_RESPONSE, X_ARCH_STATE_HEADER, X_ARCH_TOOL_CALL,
};
use common::errors::ServerError;
use common::http::{CallArgs, Client};
@ -18,6 +18,7 @@ use derivative::Derivative;
use http::StatusCode;
use log::{debug, info, warn};
use proxy_wasm::traits::*;
use serde_json::Value;
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;
@ -66,6 +67,8 @@ pub struct StreamContext {
pub traceparent: Option<String>,
pub _tracing: Rc<Option<Tracing>>,
pub arch_fc_response: Option<String>,
/// Accumulates upstream non-streaming response chunks until end of stream.
pub non_streaming_response_buffer: Vec<u8>,
}
impl StreamContext {
@ -100,6 +103,7 @@ impl StreamContext {
start_upstream_llm_request_time: 0,
time_to_first_token: None,
arch_fc_response: None,
non_streaming_response_buffer: Vec::new(),
}
}
@ -803,6 +807,80 @@ impl StreamContext {
self.set_http_request_body(0, self.request_body_size, json_resp.as_bytes());
self.resume_http_request();
}
pub fn process_non_streaming_response_body(&mut self, body: &[u8], body_size: usize) {
let body_utf8 = match String::from_utf8(body.to_vec()) {
Ok(body_utf8) => body_utf8,
Err(e) => {
info!("could not convert to utf8: {}", e);
return;
}
};
if let Some(tool_calls) = self.tool_calls.as_ref() {
if !tool_calls.is_empty() {
if self.arch_state.is_none() {
self.arch_state = Some(Vec::new());
}
let mut data = match serde_json::from_str(&body_utf8) {
Ok(data) => data,
Err(e) => {
warn!(
"could not deserialize response, sending data as it is: {}",
e
);
return;
}
};
if let Value::Object(ref mut map) = data {
let metadata = map
.entry("metadata")
.or_insert(Value::Object(serde_json::Map::new()));
if metadata == &Value::Null {
*metadata = Value::Object(serde_json::Map::new());
}
let tool_call_message = self.generate_tool_call_message();
let tool_call_message_str = serde_json::to_string(&tool_call_message).unwrap();
metadata.as_object_mut().unwrap().insert(
X_ARCH_TOOL_CALL.to_string(),
serde_json::Value::String(tool_call_message_str),
);
let api_response_message = self.generate_api_response_message();
let api_response_message_str =
serde_json::to_string(&api_response_message).unwrap();
metadata.as_object_mut().unwrap().insert(
X_ARCH_API_RESPONSE.to_string(),
serde_json::Value::String(api_response_message_str),
);
let fc_messages = vec![tool_call_message, api_response_message];
let fc_messages_str = serde_json::to_string(&fc_messages).unwrap();
let arch_state = HashMap::from([("messages".to_string(), fc_messages_str)]);
let arch_state_str = serde_json::to_string(&arch_state).unwrap();
metadata.as_object_mut().unwrap().insert(
X_ARCH_STATE_HEADER.to_string(),
serde_json::Value::String(arch_state_str),
);
if let Some(arch_fc_response) = self.arch_fc_response.as_ref() {
metadata.as_object_mut().unwrap().insert(
X_ARCH_FC_MODEL_RESPONSE.to_string(),
serde_json::Value::String(
serde_json::to_string(arch_fc_response).unwrap(),
),
);
}
let data_serialized = serde_json::to_string(&data).unwrap();
info!("plano <= developer: {}", data_serialized);
self.set_http_response_body(0, body_size, data_serialized.as_bytes());
}
}
}
}
}
fn check_intent_matched(model_server_response: &ChatCompletionsResponse) -> bool {

View file

@ -19,7 +19,8 @@ run_hurl_with_retries() {
local max_attempts=1
local attempt=1
if [ "$demo_name" = "llm_routing/preference_based_routing" ]; then
if [ "$demo_name" = "llm_routing/preference_based_routing" ] \
|| [ "$demo_name" = "advanced/currency_exchange" ]; then
max_attempts=3
fi