mirror of
https://github.com/katanemo/plano.git
synced 2026-06-08 14:55:14 +02:00
Merge 5276dfde2f into f3d6ea41ad
This commit is contained in:
commit
b5eaa541e6
4 changed files with 154 additions and 77 deletions
|
|
@ -56,6 +56,8 @@ pub struct StreamContext {
|
|||
http_protocol: Option<String>,
|
||||
sse_buffer: Option<SseStreamBuffer>,
|
||||
sse_chunk_processor: Option<SseChunkProcessor>,
|
||||
/// Accumulates upstream non-streaming response chunks until end of stream.
|
||||
non_streaming_response_buffer: Vec<u8>,
|
||||
}
|
||||
|
||||
impl StreamContext {
|
||||
|
|
@ -87,6 +89,7 @@ impl StreamContext {
|
|||
http_protocol: None,
|
||||
sse_buffer: None,
|
||||
sse_chunk_processor: None,
|
||||
non_streaming_response_buffer: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -816,6 +819,31 @@ impl StreamContext {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn flush_streaming_response_tail(&mut self) -> Option<Vec<u8>> {
|
||||
let provider_id = self.get_provider_id();
|
||||
let has_buffered_sse = self
|
||||
.sse_chunk_processor
|
||||
.as_ref()
|
||||
.is_some_and(|processor| processor.has_buffered_data());
|
||||
|
||||
if has_buffered_sse {
|
||||
match self.handle_streaming_response(&[], provider_id) {
|
||||
Ok(bytes) if !bytes.is_empty() => return Some(bytes),
|
||||
Ok(_) => {}
|
||||
Err(_) => return None,
|
||||
}
|
||||
}
|
||||
|
||||
self.sse_buffer.as_mut().and_then(|buffer| {
|
||||
let bytes = buffer.to_bytes();
|
||||
if bytes.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(bytes)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// HttpContext is the trait that allows the Rust code to interact with HTTP objects.
|
||||
|
|
@ -1174,6 +1202,21 @@ impl HttpContext for StreamContext {
|
|||
|
||||
let current_time = get_current_time().unwrap();
|
||||
if end_of_stream && body_size == 0 {
|
||||
if self.streaming_response {
|
||||
if let Some(serialized_body) = self.flush_streaming_response_tail() {
|
||||
self.set_http_response_body(0, 0, &serialized_body);
|
||||
}
|
||||
} else if !self.non_streaming_response_buffer.is_empty() {
|
||||
let body = std::mem::take(&mut self.non_streaming_response_buffer);
|
||||
let provider_id = self.get_provider_id();
|
||||
match self.handle_non_streaming_response(&body, provider_id) {
|
||||
Ok(serialized_body) => {
|
||||
self.set_http_response_body(0, 0, &serialized_body);
|
||||
}
|
||||
Err(action) => return action,
|
||||
}
|
||||
}
|
||||
|
||||
debug!(
|
||||
"request_id={}: response body complete, total_bytes={}",
|
||||
self.request_identifier(),
|
||||
|
|
@ -1248,7 +1291,15 @@ impl HttpContext for StreamContext {
|
|||
Err(action) => return action,
|
||||
}
|
||||
} else {
|
||||
match self.handle_non_streaming_response(&body, provider_id) {
|
||||
self.non_streaming_response_buffer.extend_from_slice(&body);
|
||||
if !end_of_stream {
|
||||
// Hold chunks until the full JSON body arrives.
|
||||
self.set_http_response_body(0, body_size, &[]);
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
let complete_body = std::mem::take(&mut self.non_streaming_response_buffer);
|
||||
match self.handle_non_streaming_response(&complete_body, provider_id) {
|
||||
Ok(serialized_body) => {
|
||||
self.set_http_response_body(0, body_size, &serialized_body);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,8 +7,7 @@ use common::{
|
|||
ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME, ARCH_ROUTING_HEADER,
|
||||
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, CHAT_COMPLETIONS_PATH, HEALTHZ_PATH,
|
||||
MODEL_SERVER_NAME, MODEL_SERVER_REQUEST_TIMEOUT_MS, REQUEST_ID_HEADER, TOOL_ROLE,
|
||||
TRACE_PARENT_HEADER, USER_ROLE, X_ARCH_API_RESPONSE, X_ARCH_FC_MODEL_RESPONSE,
|
||||
X_ARCH_STATE_HEADER, X_ARCH_TOOL_CALL,
|
||||
TRACE_PARENT_HEADER, USER_ROLE, X_ARCH_STATE_HEADER,
|
||||
},
|
||||
errors::ServerError,
|
||||
http::{CallArgs, Client},
|
||||
|
|
@ -17,7 +16,6 @@ use common::{
|
|||
use http::StatusCode;
|
||||
use log::{debug, info, warn};
|
||||
use proxy_wasm::{traits::HttpContext, types::Action};
|
||||
use serde_json::Value;
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
time::{Duration, SystemTime, UNIX_EPOCH},
|
||||
|
|
@ -291,6 +289,10 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
|
||||
if end_of_stream && body_size == 0 {
|
||||
if !self.streaming_response && !self.non_streaming_response_buffer.is_empty() {
|
||||
let body = std::mem::take(&mut self.non_streaming_response_buffer);
|
||||
self.process_non_streaming_response_body(&body, 0);
|
||||
}
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
|
|
@ -326,15 +328,15 @@ impl HttpContext for StreamContext {
|
|||
}
|
||||
};
|
||||
|
||||
let body_utf8 = match String::from_utf8(body) {
|
||||
Ok(body_utf8) => body_utf8,
|
||||
Err(e) => {
|
||||
info!("could not convert to utf8: {}", e);
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
|
||||
if self.streaming_response {
|
||||
let body_utf8 = match String::from_utf8(body) {
|
||||
Ok(body_utf8) => body_utf8,
|
||||
Err(e) => {
|
||||
info!("could not convert to utf8: {}", e);
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
|
||||
debug!("streaming response");
|
||||
|
||||
if self.tool_calls.is_some() && !self.tool_calls.as_ref().unwrap().is_empty() {
|
||||
|
|
@ -359,70 +361,15 @@ impl HttpContext for StreamContext {
|
|||
self.set_http_response_body(0, body_size, response_str.as_bytes());
|
||||
self.tool_calls = None;
|
||||
}
|
||||
} else if let Some(tool_calls) = self.tool_calls.as_ref() {
|
||||
if !tool_calls.is_empty() {
|
||||
if self.arch_state.is_none() {
|
||||
self.arch_state = Some(Vec::new());
|
||||
}
|
||||
|
||||
let mut data = match serde_json::from_str(&body_utf8) {
|
||||
Ok(data) => data,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"could not deserialize response, sending data as it is: {}",
|
||||
e
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
// use serde::Value to manipulate the json object and ensure that we don't lose any data
|
||||
if let Value::Object(ref mut map) = data {
|
||||
// serialize arch state and add to metadata
|
||||
let metadata = map
|
||||
.entry("metadata")
|
||||
.or_insert(Value::Object(serde_json::Map::new()));
|
||||
if metadata == &Value::Null {
|
||||
*metadata = Value::Object(serde_json::Map::new());
|
||||
}
|
||||
|
||||
let tool_call_message = self.generate_tool_call_message();
|
||||
let tool_call_message_str = serde_json::to_string(&tool_call_message).unwrap();
|
||||
metadata.as_object_mut().unwrap().insert(
|
||||
X_ARCH_TOOL_CALL.to_string(),
|
||||
serde_json::Value::String(tool_call_message_str),
|
||||
);
|
||||
|
||||
let api_response_message = self.generate_api_response_message();
|
||||
let api_response_message_str =
|
||||
serde_json::to_string(&api_response_message).unwrap();
|
||||
metadata.as_object_mut().unwrap().insert(
|
||||
X_ARCH_API_RESPONSE.to_string(),
|
||||
serde_json::Value::String(api_response_message_str),
|
||||
);
|
||||
|
||||
let fc_messages = vec![tool_call_message, api_response_message];
|
||||
|
||||
let fc_messages_str = serde_json::to_string(&fc_messages).unwrap();
|
||||
let arch_state = HashMap::from([("messages".to_string(), fc_messages_str)]);
|
||||
let arch_state_str = serde_json::to_string(&arch_state).unwrap();
|
||||
metadata.as_object_mut().unwrap().insert(
|
||||
X_ARCH_STATE_HEADER.to_string(),
|
||||
serde_json::Value::String(arch_state_str),
|
||||
);
|
||||
|
||||
if let Some(arch_fc_response) = self.arch_fc_response.as_ref() {
|
||||
metadata.as_object_mut().unwrap().insert(
|
||||
X_ARCH_FC_MODEL_RESPONSE.to_string(),
|
||||
serde_json::Value::String(
|
||||
serde_json::to_string(arch_fc_response).unwrap(),
|
||||
),
|
||||
);
|
||||
}
|
||||
let data_serialized = serde_json::to_string(&data).unwrap();
|
||||
info!("plano <= developer: {}", data_serialized);
|
||||
self.set_http_response_body(0, body_size, data_serialized.as_bytes());
|
||||
};
|
||||
} else {
|
||||
self.non_streaming_response_buffer.extend_from_slice(&body);
|
||||
if !end_of_stream {
|
||||
self.set_http_response_body(0, body_size, &[]);
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
let complete_body = std::mem::take(&mut self.non_streaming_response_buffer);
|
||||
self.process_non_streaming_response_body(&complete_body, body_size);
|
||||
}
|
||||
|
||||
debug!("recv [S={}] end_stream={}", self.context_id, end_of_stream);
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ use common::consts::{
|
|||
API_REQUEST_TIMEOUT_MS, ARCH_FC_MODEL_NAME, ARCH_INTERNAL_CLUSTER_NAME,
|
||||
ARCH_UPSTREAM_HOST_HEADER, ASSISTANT_ROLE, DEFAULT_TARGET_REQUEST_TIMEOUT_MS, MESSAGES_KEY,
|
||||
REQUEST_ID_HEADER, SYSTEM_ROLE, TOOL_ROLE, TRACE_PARENT_HEADER, USER_ROLE,
|
||||
X_ARCH_FC_MODEL_RESPONSE,
|
||||
X_ARCH_API_RESPONSE, X_ARCH_FC_MODEL_RESPONSE, X_ARCH_STATE_HEADER, X_ARCH_TOOL_CALL,
|
||||
};
|
||||
use common::errors::ServerError;
|
||||
use common::http::{CallArgs, Client};
|
||||
|
|
@ -18,6 +18,7 @@ use derivative::Derivative;
|
|||
use http::StatusCode;
|
||||
use log::{debug, info, warn};
|
||||
use proxy_wasm::traits::*;
|
||||
use serde_json::Value;
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
use std::rc::Rc;
|
||||
|
|
@ -66,6 +67,8 @@ pub struct StreamContext {
|
|||
pub traceparent: Option<String>,
|
||||
pub _tracing: Rc<Option<Tracing>>,
|
||||
pub arch_fc_response: Option<String>,
|
||||
/// Accumulates upstream non-streaming response chunks until end of stream.
|
||||
pub non_streaming_response_buffer: Vec<u8>,
|
||||
}
|
||||
|
||||
impl StreamContext {
|
||||
|
|
@ -100,6 +103,7 @@ impl StreamContext {
|
|||
start_upstream_llm_request_time: 0,
|
||||
time_to_first_token: None,
|
||||
arch_fc_response: None,
|
||||
non_streaming_response_buffer: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -803,6 +807,80 @@ impl StreamContext {
|
|||
self.set_http_request_body(0, self.request_body_size, json_resp.as_bytes());
|
||||
self.resume_http_request();
|
||||
}
|
||||
|
||||
pub fn process_non_streaming_response_body(&mut self, body: &[u8], body_size: usize) {
|
||||
let body_utf8 = match String::from_utf8(body.to_vec()) {
|
||||
Ok(body_utf8) => body_utf8,
|
||||
Err(e) => {
|
||||
info!("could not convert to utf8: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(tool_calls) = self.tool_calls.as_ref() {
|
||||
if !tool_calls.is_empty() {
|
||||
if self.arch_state.is_none() {
|
||||
self.arch_state = Some(Vec::new());
|
||||
}
|
||||
|
||||
let mut data = match serde_json::from_str(&body_utf8) {
|
||||
Ok(data) => data,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"could not deserialize response, sending data as it is: {}",
|
||||
e
|
||||
);
|
||||
return;
|
||||
}
|
||||
};
|
||||
if let Value::Object(ref mut map) = data {
|
||||
let metadata = map
|
||||
.entry("metadata")
|
||||
.or_insert(Value::Object(serde_json::Map::new()));
|
||||
if metadata == &Value::Null {
|
||||
*metadata = Value::Object(serde_json::Map::new());
|
||||
}
|
||||
|
||||
let tool_call_message = self.generate_tool_call_message();
|
||||
let tool_call_message_str = serde_json::to_string(&tool_call_message).unwrap();
|
||||
metadata.as_object_mut().unwrap().insert(
|
||||
X_ARCH_TOOL_CALL.to_string(),
|
||||
serde_json::Value::String(tool_call_message_str),
|
||||
);
|
||||
|
||||
let api_response_message = self.generate_api_response_message();
|
||||
let api_response_message_str =
|
||||
serde_json::to_string(&api_response_message).unwrap();
|
||||
metadata.as_object_mut().unwrap().insert(
|
||||
X_ARCH_API_RESPONSE.to_string(),
|
||||
serde_json::Value::String(api_response_message_str),
|
||||
);
|
||||
|
||||
let fc_messages = vec![tool_call_message, api_response_message];
|
||||
|
||||
let fc_messages_str = serde_json::to_string(&fc_messages).unwrap();
|
||||
let arch_state = HashMap::from([("messages".to_string(), fc_messages_str)]);
|
||||
let arch_state_str = serde_json::to_string(&arch_state).unwrap();
|
||||
metadata.as_object_mut().unwrap().insert(
|
||||
X_ARCH_STATE_HEADER.to_string(),
|
||||
serde_json::Value::String(arch_state_str),
|
||||
);
|
||||
|
||||
if let Some(arch_fc_response) = self.arch_fc_response.as_ref() {
|
||||
metadata.as_object_mut().unwrap().insert(
|
||||
X_ARCH_FC_MODEL_RESPONSE.to_string(),
|
||||
serde_json::Value::String(
|
||||
serde_json::to_string(arch_fc_response).unwrap(),
|
||||
),
|
||||
);
|
||||
}
|
||||
let data_serialized = serde_json::to_string(&data).unwrap();
|
||||
info!("plano <= developer: {}", data_serialized);
|
||||
self.set_http_response_body(0, body_size, data_serialized.as_bytes());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn check_intent_matched(model_server_response: &ChatCompletionsResponse) -> bool {
|
||||
|
|
|
|||
|
|
@ -19,7 +19,8 @@ run_hurl_with_retries() {
|
|||
local max_attempts=1
|
||||
local attempt=1
|
||||
|
||||
if [ "$demo_name" = "llm_routing/preference_based_routing" ]; then
|
||||
if [ "$demo_name" = "llm_routing/preference_based_routing" ] \
|
||||
|| [ "$demo_name" = "advanced/currency_exchange" ]; then
|
||||
max_attempts=3
|
||||
fi
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue