mirror of
https://github.com/katanemo/plano.git
synced 2026-05-18 13:45:15 +02:00
Add ability to stream a response (#50)
Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
parent
118bff7c7e
commit
9f3c845610
5 changed files with 251 additions and 104 deletions
|
|
@ -7,4 +7,5 @@ pub const USER_ROLE: &str = "user";
|
|||
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
|
||||
pub const BOLT_FC_CLUSTER: &str = "bolt_fc_1b";
|
||||
pub const BOLT_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
|
||||
pub const OPENAI_CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
|
||||
pub const MODEL_SERVER_NAME: &str = "model_server";
|
||||
|
|
|
|||
|
|
@ -232,14 +232,12 @@ impl RootContext for FilterContext {
|
|||
true
|
||||
}
|
||||
|
||||
fn create_http_context(&self, _context_id: u32) -> Option<Box<dyn HttpContext>> {
|
||||
Some(Box::new(StreamContext {
|
||||
host_header: None,
|
||||
ratelimit_selector: None,
|
||||
callouts: HashMap::new(),
|
||||
metrics: Rc::clone(&self.metrics),
|
||||
prompt_targets: Rc::clone(&self.prompt_targets),
|
||||
}))
|
||||
fn create_http_context(&self, context_id: u32) -> Option<Box<dyn HttpContext>> {
|
||||
Some(Box::new(StreamContext::new(
|
||||
context_id,
|
||||
Rc::clone(&self.metrics),
|
||||
Rc::clone(&self.prompt_targets),
|
||||
)))
|
||||
}
|
||||
|
||||
fn get_type(&self) -> Option<ContextType> {
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
use crate::consts::{
|
||||
BOLT_FC_CLUSTER, BOLT_FC_REQUEST_TIMEOUT_MS, DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL,
|
||||
DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
|
||||
DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME, OPENAI_CHAT_COMPLETIONS_PATH,
|
||||
RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
|
||||
};
|
||||
use crate::filter_context::{embeddings_store, WasmMetrics};
|
||||
|
|
@ -10,13 +10,16 @@ use crate::stats::IncrementingMetric;
|
|||
use crate::tokenizer;
|
||||
use acap::cos;
|
||||
use http::StatusCode;
|
||||
use log::{debug, error, info, warn};
|
||||
use log::{debug, info, warn};
|
||||
use open_message_format_embeddings::models::{
|
||||
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
|
||||
};
|
||||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use public_types::common_types::open_ai::{ChatCompletions, Message};
|
||||
use public_types::common_types::open_ai::{
|
||||
ChatCompletionChunkResponse, ChatCompletionsRequest, ChatCompletionsResponse, Message,
|
||||
StreamOptions,
|
||||
};
|
||||
use public_types::common_types::{
|
||||
BoltFCResponse, BoltFCToolsCall, EmbeddingType, ToolParameter, ToolParameters, ToolsDefinition,
|
||||
ZeroShotClassificationRequest, ZeroShotClassificationResponse,
|
||||
|
|
@ -39,19 +42,40 @@ pub struct CallContext {
|
|||
response_handler_type: ResponseHandlerType,
|
||||
user_message: Option<String>,
|
||||
prompt_target: Option<PromptTarget>,
|
||||
request_body: ChatCompletions,
|
||||
request_body: ChatCompletionsRequest,
|
||||
similarity_scores: Option<Vec<(String, f64)>>,
|
||||
}
|
||||
|
||||
pub struct StreamContext {
|
||||
pub host_header: Option<String>,
|
||||
pub ratelimit_selector: Option<Header>,
|
||||
pub callouts: HashMap<u32, CallContext>,
|
||||
pub context_id: u32,
|
||||
pub metrics: Rc<WasmMetrics>,
|
||||
pub prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
|
||||
callouts: HashMap<u32, CallContext>,
|
||||
host_header: Option<String>,
|
||||
ratelimit_selector: Option<Header>,
|
||||
streaming_response: bool,
|
||||
response_tokens: usize,
|
||||
chat_completions_request: bool,
|
||||
}
|
||||
|
||||
impl StreamContext {
|
||||
pub fn new(
|
||||
context_id: u32,
|
||||
metrics: Rc<WasmMetrics>,
|
||||
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
|
||||
) -> Self {
|
||||
StreamContext {
|
||||
context_id,
|
||||
metrics,
|
||||
prompt_targets,
|
||||
callouts: HashMap::new(),
|
||||
host_header: None,
|
||||
ratelimit_selector: None,
|
||||
streaming_response: false,
|
||||
response_tokens: 0,
|
||||
chat_completions_request: false,
|
||||
}
|
||||
}
|
||||
fn save_host_header(&mut self) {
|
||||
// Save the host header to be used by filter logic later on.
|
||||
self.host_header = self.get_http_request_header(":host");
|
||||
|
|
@ -70,7 +94,8 @@ impl StreamContext {
|
|||
// The gateway can start gathering information necessary for routing. For now change the path to an
|
||||
// OpenAI API path.
|
||||
Some(path) if path == "/llmrouting" => {
|
||||
self.set_http_request_header(":path", Some("/v1/chat/completions"));
|
||||
self.set_http_request_header(":path", Some(OPENAI_CHAT_COMPLETIONS_PATH));
|
||||
self.chat_completions_request = true;
|
||||
}
|
||||
// Otherwise let the filter continue.
|
||||
_ => (),
|
||||
|
|
@ -86,21 +111,26 @@ impl StreamContext {
|
|||
});
|
||||
}
|
||||
|
||||
fn send_server_error(&self, error: String) {
|
||||
fn send_server_error(&self, error: String, override_status_code: Option<StatusCode>) {
|
||||
debug!("server error occurred: {}", error);
|
||||
self.send_http_response(
|
||||
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
|
||||
override_status_code
|
||||
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
|
||||
.as_u16()
|
||||
.into(),
|
||||
vec![],
|
||||
Some(error.as_bytes()),
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
fn embeddings_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
|
||||
let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) {
|
||||
Ok(embedding_response) => embedding_response,
|
||||
Err(e) => {
|
||||
self.send_server_error(format!("Error deserializing embedding response: {:?}", e));
|
||||
return;
|
||||
return self.send_server_error(
|
||||
format!("Error deserializing embedding response: {:?}", e),
|
||||
None,
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -115,19 +145,15 @@ impl StreamContext {
|
|||
let prompt_target_embeddings = match embeddings_store().read() {
|
||||
Ok(embeddings) => embeddings,
|
||||
Err(e) => {
|
||||
let error_message = format!("Error reading embeddings store: {:?}", e);
|
||||
warn!("{}", error_message);
|
||||
self.send_server_error(error_message);
|
||||
return;
|
||||
return self
|
||||
.send_server_error(format!("Error reading embeddings store: {:?}", e), None);
|
||||
}
|
||||
};
|
||||
|
||||
let prompt_targets = match self.prompt_targets.read() {
|
||||
Ok(prompt_targets) => prompt_targets,
|
||||
Err(e) => {
|
||||
let error_message = format!("Error reading prompt targets: {:?}", e);
|
||||
warn!("{}", error_message);
|
||||
self.send_server_error(error_message);
|
||||
self.send_server_error(format!("Error reading prompt targets: {:?}", e), None);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
|
@ -220,12 +246,13 @@ impl StreamContext {
|
|||
match serde_json::from_slice(&body) {
|
||||
Ok(zeroshot_response) => zeroshot_response,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Error deserializing zeroshot intent detection response: {:?}",
|
||||
e
|
||||
self.send_server_error(
|
||||
format!(
|
||||
"Error deserializing zeroshot intent detection response: {:?}",
|
||||
e
|
||||
),
|
||||
None,
|
||||
);
|
||||
info!("body: {:?}", String::from_utf8(body).unwrap());
|
||||
self.resume_http_request();
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
|
@ -319,10 +346,12 @@ impl StreamContext {
|
|||
parameters: tools_parameters,
|
||||
};
|
||||
|
||||
let chat_completions = ChatCompletions {
|
||||
let chat_completions = ChatCompletionsRequest {
|
||||
model: GPT_35_TURBO.to_string(),
|
||||
messages: callout_context.request_body.messages.clone(),
|
||||
tools: Some(vec![tools_defintion]),
|
||||
stream: false,
|
||||
stream_options: None,
|
||||
};
|
||||
|
||||
let msg_body = match serde_json::to_string(&chat_completions) {
|
||||
|
|
@ -331,11 +360,10 @@ impl StreamContext {
|
|||
msg_body
|
||||
}
|
||||
Err(e) => {
|
||||
self.send_server_error(format!(
|
||||
"Error serializing request_params: {:?}",
|
||||
e
|
||||
));
|
||||
return;
|
||||
return self.send_server_error(
|
||||
format!("Error serializing request_params: {:?}", e),
|
||||
None,
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -424,12 +452,10 @@ impl StreamContext {
|
|||
.arguments
|
||||
.contains_key(&param.name)
|
||||
{
|
||||
warn!("boltfc did not extract required parameter: {}", param.name);
|
||||
return self.send_http_response(
|
||||
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
|
||||
vec![],
|
||||
Some("missing required parameter".as_bytes()),
|
||||
);
|
||||
self.send_server_error(
|
||||
format!("missing required parameter: {}", param.name),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
)
|
||||
}
|
||||
}
|
||||
});
|
||||
|
|
@ -510,17 +536,19 @@ impl StreamContext {
|
|||
}
|
||||
});
|
||||
|
||||
let request_message: ChatCompletions = ChatCompletions {
|
||||
let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
|
||||
model: GPT_35_TURBO.to_string(),
|
||||
messages,
|
||||
tools: None,
|
||||
stream: callout_context.request_body.stream,
|
||||
stream_options: callout_context.request_body.stream_options,
|
||||
};
|
||||
|
||||
let json_string = match serde_json::to_string(&request_message) {
|
||||
let json_string = match serde_json::to_string(&chat_completions_request) {
|
||||
Ok(json_string) => json_string,
|
||||
Err(e) => {
|
||||
self.send_server_error(format!("Error serializing request_body: {:?}", e));
|
||||
return;
|
||||
return self
|
||||
.send_server_error(format!("Error serializing request_body: {:?}", e), None);
|
||||
}
|
||||
};
|
||||
debug!(
|
||||
|
|
@ -528,22 +556,21 @@ impl StreamContext {
|
|||
json_string
|
||||
);
|
||||
|
||||
let request_body = callout_context.request_body;
|
||||
|
||||
// Tokenize and Ratelimit.
|
||||
if let Some(selector) = self.ratelimit_selector.take() {
|
||||
if let Ok(token_count) = tokenizer::token_count(&request_body.model, &json_string) {
|
||||
if let Ok(token_count) =
|
||||
tokenizer::token_count(&chat_completions_request.model, &json_string)
|
||||
{
|
||||
match ratelimit::ratelimits(None).read().unwrap().check_limit(
|
||||
request_body.model,
|
||||
chat_completions_request.model,
|
||||
selector,
|
||||
NonZero::new(token_count as u32).unwrap(),
|
||||
) {
|
||||
Ok(_) => (),
|
||||
Err(err) => {
|
||||
self.send_http_response(
|
||||
StatusCode::TOO_MANY_REQUESTS.as_u16().into(),
|
||||
vec![],
|
||||
Some(format!("Exceeded Ratelimit: {}", err).as_bytes()),
|
||||
self.send_server_error(
|
||||
format!("Exceeded Ratelimit: {}", err),
|
||||
Some(StatusCode::TOO_MANY_REQUESTS),
|
||||
);
|
||||
self.metrics.ratelimited_rq.increment(1);
|
||||
return;
|
||||
|
|
@ -583,31 +610,36 @@ impl HttpContext for StreamContext {
|
|||
|
||||
// Deserialize body into spec.
|
||||
// Currently OpenAI API.
|
||||
let deserialized_body: ChatCompletions = match self.get_http_request_body(0, body_size) {
|
||||
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(msg) => {
|
||||
self.send_http_response(
|
||||
StatusCode::BAD_REQUEST.as_u16().into(),
|
||||
vec![],
|
||||
Some(format!("Failed to deserialize: {}", msg).as_bytes()),
|
||||
let mut deserialized_body: ChatCompletionsRequest =
|
||||
match self.get_http_request_body(0, body_size) {
|
||||
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
|
||||
Ok(deserialized) => deserialized,
|
||||
Err(msg) => {
|
||||
self.send_server_error(
|
||||
format!("Failed to deserialize: {}", msg),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
},
|
||||
None => {
|
||||
self.send_server_error(
|
||||
format!(
|
||||
"Failed to obtain body bytes even though body_size is {}",
|
||||
body_size
|
||||
),
|
||||
None,
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
},
|
||||
None => {
|
||||
self.send_http_response(
|
||||
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
|
||||
vec![],
|
||||
None,
|
||||
);
|
||||
error!(
|
||||
"Failed to obtain body bytes even though body_size is {}",
|
||||
body_size
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
};
|
||||
|
||||
self.streaming_response = deserialized_body.stream;
|
||||
if deserialized_body.stream && deserialized_body.stream_options.is_none() {
|
||||
deserialized_body.stream_options = Some(StreamOptions {
|
||||
include_usage: true,
|
||||
});
|
||||
}
|
||||
|
||||
let user_message = match deserialized_body
|
||||
.messages
|
||||
|
|
@ -682,6 +714,92 @@ impl HttpContext for StreamContext {
|
|||
|
||||
Action::Pause
|
||||
}
|
||||
|
||||
fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
|
||||
if !self.chat_completions_request {
|
||||
return Action::Continue;
|
||||
}
|
||||
|
||||
debug!(
|
||||
"recv [S={}] bytes={} end_stream={}",
|
||||
self.context_id, body_size, end_of_stream
|
||||
);
|
||||
|
||||
if !end_of_stream && !self.streaming_response {
|
||||
return Action::Pause;
|
||||
}
|
||||
|
||||
let body = self
|
||||
.get_http_response_body(0, body_size)
|
||||
.expect("cant get response body");
|
||||
|
||||
let body_str = String::from_utf8(body).expect("body is not utf-8");
|
||||
|
||||
if self.streaming_response {
|
||||
debug!("streaming response");
|
||||
let chat_completions_data = match body_str.split_once("data: ") {
|
||||
Some((_, chat_completions_data)) => chat_completions_data,
|
||||
None => {
|
||||
self.send_server_error(String::from("parsing error in streaming data"), None);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
let chat_completions_chunk_response: ChatCompletionChunkResponse =
|
||||
match serde_json::from_str(chat_completions_data) {
|
||||
Ok(de) => de,
|
||||
Err(_) => {
|
||||
if chat_completions_data != "[NONE]" {
|
||||
self.send_server_error(
|
||||
String::from("error in streaming response"),
|
||||
None,
|
||||
);
|
||||
return Action::Continue;
|
||||
}
|
||||
return Action::Continue;
|
||||
}
|
||||
};
|
||||
|
||||
if let Some(content) = chat_completions_chunk_response
|
||||
.choices
|
||||
.first()
|
||||
.unwrap()
|
||||
.delta
|
||||
.content
|
||||
.as_ref()
|
||||
{
|
||||
let model = &chat_completions_chunk_response.model;
|
||||
let token_count = tokenizer::token_count(model, content).unwrap_or(0);
|
||||
self.response_tokens += token_count;
|
||||
}
|
||||
} else {
|
||||
debug!("non streaming response");
|
||||
let chat_completions_response: ChatCompletionsResponse =
|
||||
match serde_json::from_str(&body_str) {
|
||||
Ok(de) => de,
|
||||
Err(e) => {
|
||||
self.send_server_error(
|
||||
format!(
|
||||
"error in non-streaming response: {}\n response was={}",
|
||||
e, body_str
|
||||
),
|
||||
None,
|
||||
);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
self.response_tokens += chat_completions_response.usage.completions_tokens;
|
||||
}
|
||||
|
||||
debug!(
|
||||
"recv [S={}] total_tokens={} end_stream={}",
|
||||
self.context_id, self.response_tokens, end_of_stream
|
||||
);
|
||||
|
||||
// TODO:: ratelimit based on response tokens.
|
||||
Action::Continue
|
||||
}
|
||||
}
|
||||
|
||||
impl Context for StreamContext {
|
||||
|
|
@ -711,9 +829,10 @@ impl Context for StreamContext {
|
|||
}
|
||||
}
|
||||
} else {
|
||||
let error_message = "No response body in inline HTTP request";
|
||||
warn!("{}", error_message);
|
||||
self.send_server_error(error_message.to_owned());
|
||||
self.send_server_error(
|
||||
String::from("No response body in inline HTTP request"),
|
||||
None,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue