Add ability to stream a response (#50)

Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
José Ulises Niño Rivera 2024-09-17 16:12:41 -07:00 committed by GitHub
parent 118bff7c7e
commit 9f3c845610
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 251 additions and 104 deletions

View file

@ -7,4 +7,5 @@ pub const USER_ROLE: &str = "user";
pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
pub const BOLT_FC_CLUSTER: &str = "bolt_fc_1b";
pub const BOLT_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
pub const OPENAI_CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
pub const MODEL_SERVER_NAME: &str = "model_server";

View file

@ -232,14 +232,12 @@ impl RootContext for FilterContext {
true
}
fn create_http_context(&self, _context_id: u32) -> Option<Box<dyn HttpContext>> {
Some(Box::new(StreamContext {
host_header: None,
ratelimit_selector: None,
callouts: HashMap::new(),
metrics: Rc::clone(&self.metrics),
prompt_targets: Rc::clone(&self.prompt_targets),
}))
fn create_http_context(&self, context_id: u32) -> Option<Box<dyn HttpContext>> {
Some(Box::new(StreamContext::new(
context_id,
Rc::clone(&self.metrics),
Rc::clone(&self.prompt_targets),
)))
}
fn get_type(&self) -> Option<ContextType> {

View file

@ -1,6 +1,6 @@
use crate::consts::{
BOLT_FC_CLUSTER, BOLT_FC_REQUEST_TIMEOUT_MS, DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL,
DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME, OPENAI_CHAT_COMPLETIONS_PATH,
RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
};
use crate::filter_context::{embeddings_store, WasmMetrics};
@ -10,13 +10,16 @@ use crate::stats::IncrementingMetric;
use crate::tokenizer;
use acap::cos;
use http::StatusCode;
use log::{debug, error, info, warn};
use log::{debug, info, warn};
use open_message_format_embeddings::models::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use public_types::common_types::open_ai::{ChatCompletions, Message};
use public_types::common_types::open_ai::{
ChatCompletionChunkResponse, ChatCompletionsRequest, ChatCompletionsResponse, Message,
StreamOptions,
};
use public_types::common_types::{
BoltFCResponse, BoltFCToolsCall, EmbeddingType, ToolParameter, ToolParameters, ToolsDefinition,
ZeroShotClassificationRequest, ZeroShotClassificationResponse,
@ -39,19 +42,40 @@ pub struct CallContext {
response_handler_type: ResponseHandlerType,
user_message: Option<String>,
prompt_target: Option<PromptTarget>,
request_body: ChatCompletions,
request_body: ChatCompletionsRequest,
similarity_scores: Option<Vec<(String, f64)>>,
}
pub struct StreamContext {
pub host_header: Option<String>,
pub ratelimit_selector: Option<Header>,
pub callouts: HashMap<u32, CallContext>,
pub context_id: u32,
pub metrics: Rc<WasmMetrics>,
pub prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
callouts: HashMap<u32, CallContext>,
host_header: Option<String>,
ratelimit_selector: Option<Header>,
streaming_response: bool,
response_tokens: usize,
chat_completions_request: bool,
}
impl StreamContext {
pub fn new(
context_id: u32,
metrics: Rc<WasmMetrics>,
prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
) -> Self {
StreamContext {
context_id,
metrics,
prompt_targets,
callouts: HashMap::new(),
host_header: None,
ratelimit_selector: None,
streaming_response: false,
response_tokens: 0,
chat_completions_request: false,
}
}
fn save_host_header(&mut self) {
// Save the host header to be used by filter logic later on.
self.host_header = self.get_http_request_header(":host");
@ -70,7 +94,8 @@ impl StreamContext {
// The gateway can start gathering information necessary for routing. For now change the path to an
// OpenAI API path.
Some(path) if path == "/llmrouting" => {
self.set_http_request_header(":path", Some("/v1/chat/completions"));
self.set_http_request_header(":path", Some(OPENAI_CHAT_COMPLETIONS_PATH));
self.chat_completions_request = true;
}
// Otherwise let the filter continue.
_ => (),
@ -86,21 +111,26 @@ impl StreamContext {
});
}
fn send_server_error(&self, error: String) {
fn send_server_error(&self, error: String, override_status_code: Option<StatusCode>) {
debug!("server error occurred: {}", error);
self.send_http_response(
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
override_status_code
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
.as_u16()
.into(),
vec![],
Some(error.as_bytes()),
)
);
}
fn embeddings_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) {
Ok(embedding_response) => embedding_response,
Err(e) => {
self.send_server_error(format!("Error deserializing embedding response: {:?}", e));
return;
return self.send_server_error(
format!("Error deserializing embedding response: {:?}", e),
None,
);
}
};
@ -115,19 +145,15 @@ impl StreamContext {
let prompt_target_embeddings = match embeddings_store().read() {
Ok(embeddings) => embeddings,
Err(e) => {
let error_message = format!("Error reading embeddings store: {:?}", e);
warn!("{}", error_message);
self.send_server_error(error_message);
return;
return self
.send_server_error(format!("Error reading embeddings store: {:?}", e), None);
}
};
let prompt_targets = match self.prompt_targets.read() {
Ok(prompt_targets) => prompt_targets,
Err(e) => {
let error_message = format!("Error reading prompt targets: {:?}", e);
warn!("{}", error_message);
self.send_server_error(error_message);
self.send_server_error(format!("Error reading prompt targets: {:?}", e), None);
return;
}
};
@ -220,12 +246,13 @@ impl StreamContext {
match serde_json::from_slice(&body) {
Ok(zeroshot_response) => zeroshot_response,
Err(e) => {
warn!(
"Error deserializing zeroshot intent detection response: {:?}",
e
self.send_server_error(
format!(
"Error deserializing zeroshot intent detection response: {:?}",
e
),
None,
);
info!("body: {:?}", String::from_utf8(body).unwrap());
self.resume_http_request();
return;
}
};
@ -319,10 +346,12 @@ impl StreamContext {
parameters: tools_parameters,
};
let chat_completions = ChatCompletions {
let chat_completions = ChatCompletionsRequest {
model: GPT_35_TURBO.to_string(),
messages: callout_context.request_body.messages.clone(),
tools: Some(vec![tools_defintion]),
stream: false,
stream_options: None,
};
let msg_body = match serde_json::to_string(&chat_completions) {
@ -331,11 +360,10 @@ impl StreamContext {
msg_body
}
Err(e) => {
self.send_server_error(format!(
"Error serializing request_params: {:?}",
e
));
return;
return self.send_server_error(
format!("Error serializing request_params: {:?}", e),
None,
);
}
};
@ -424,12 +452,10 @@ impl StreamContext {
.arguments
.contains_key(&param.name)
{
warn!("boltfc did not extract required parameter: {}", param.name);
return self.send_http_response(
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
vec![],
Some("missing required parameter".as_bytes()),
);
self.send_server_error(
format!("missing required parameter: {}", param.name),
Some(StatusCode::BAD_REQUEST),
)
}
}
});
@ -510,17 +536,19 @@ impl StreamContext {
}
});
let request_message: ChatCompletions = ChatCompletions {
let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
model: GPT_35_TURBO.to_string(),
messages,
tools: None,
stream: callout_context.request_body.stream,
stream_options: callout_context.request_body.stream_options,
};
let json_string = match serde_json::to_string(&request_message) {
let json_string = match serde_json::to_string(&chat_completions_request) {
Ok(json_string) => json_string,
Err(e) => {
self.send_server_error(format!("Error serializing request_body: {:?}", e));
return;
return self
.send_server_error(format!("Error serializing request_body: {:?}", e), None);
}
};
debug!(
@ -528,22 +556,21 @@ impl StreamContext {
json_string
);
let request_body = callout_context.request_body;
// Tokenize and Ratelimit.
if let Some(selector) = self.ratelimit_selector.take() {
if let Ok(token_count) = tokenizer::token_count(&request_body.model, &json_string) {
if let Ok(token_count) =
tokenizer::token_count(&chat_completions_request.model, &json_string)
{
match ratelimit::ratelimits(None).read().unwrap().check_limit(
request_body.model,
chat_completions_request.model,
selector,
NonZero::new(token_count as u32).unwrap(),
) {
Ok(_) => (),
Err(err) => {
self.send_http_response(
StatusCode::TOO_MANY_REQUESTS.as_u16().into(),
vec![],
Some(format!("Exceeded Ratelimit: {}", err).as_bytes()),
self.send_server_error(
format!("Exceeded Ratelimit: {}", err),
Some(StatusCode::TOO_MANY_REQUESTS),
);
self.metrics.ratelimited_rq.increment(1);
return;
@ -583,31 +610,36 @@ impl HttpContext for StreamContext {
// Deserialize body into spec.
// Currently OpenAI API.
let deserialized_body: ChatCompletions = match self.get_http_request_body(0, body_size) {
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
Ok(deserialized) => deserialized,
Err(msg) => {
self.send_http_response(
StatusCode::BAD_REQUEST.as_u16().into(),
vec![],
Some(format!("Failed to deserialize: {}", msg).as_bytes()),
let mut deserialized_body: ChatCompletionsRequest =
match self.get_http_request_body(0, body_size) {
Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
Ok(deserialized) => deserialized,
Err(msg) => {
self.send_server_error(
format!("Failed to deserialize: {}", msg),
Some(StatusCode::BAD_REQUEST),
);
return Action::Pause;
}
},
None => {
self.send_server_error(
format!(
"Failed to obtain body bytes even though body_size is {}",
body_size
),
None,
);
return Action::Pause;
}
},
None => {
self.send_http_response(
StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
vec![],
None,
);
error!(
"Failed to obtain body bytes even though body_size is {}",
body_size
);
return Action::Pause;
}
};
};
self.streaming_response = deserialized_body.stream;
if deserialized_body.stream && deserialized_body.stream_options.is_none() {
deserialized_body.stream_options = Some(StreamOptions {
include_usage: true,
});
}
let user_message = match deserialized_body
.messages
@ -682,6 +714,92 @@ impl HttpContext for StreamContext {
Action::Pause
}
fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
if !self.chat_completions_request {
return Action::Continue;
}
debug!(
"recv [S={}] bytes={} end_stream={}",
self.context_id, body_size, end_of_stream
);
if !end_of_stream && !self.streaming_response {
return Action::Pause;
}
let body = self
.get_http_response_body(0, body_size)
.expect("cant get response body");
let body_str = String::from_utf8(body).expect("body is not utf-8");
if self.streaming_response {
debug!("streaming response");
let chat_completions_data = match body_str.split_once("data: ") {
Some((_, chat_completions_data)) => chat_completions_data,
None => {
self.send_server_error(String::from("parsing error in streaming data"), None);
return Action::Pause;
}
};
let chat_completions_chunk_response: ChatCompletionChunkResponse =
match serde_json::from_str(chat_completions_data) {
Ok(de) => de,
Err(_) => {
if chat_completions_data != "[NONE]" {
self.send_server_error(
String::from("error in streaming response"),
None,
);
return Action::Continue;
}
return Action::Continue;
}
};
if let Some(content) = chat_completions_chunk_response
.choices
.first()
.unwrap()
.delta
.content
.as_ref()
{
let model = &chat_completions_chunk_response.model;
let token_count = tokenizer::token_count(model, content).unwrap_or(0);
self.response_tokens += token_count;
}
} else {
debug!("non streaming response");
let chat_completions_response: ChatCompletionsResponse =
match serde_json::from_str(&body_str) {
Ok(de) => de,
Err(e) => {
self.send_server_error(
format!(
"error in non-streaming response: {}\n response was={}",
e, body_str
),
None,
);
return Action::Pause;
}
};
self.response_tokens += chat_completions_response.usage.completions_tokens;
}
debug!(
"recv [S={}] total_tokens={} end_stream={}",
self.context_id, self.response_tokens, end_of_stream
);
// TODO:: ratelimit based on response tokens.
Action::Continue
}
}
impl Context for StreamContext {
@ -711,9 +829,10 @@ impl Context for StreamContext {
}
}
} else {
let error_message = "No response body in inline HTTP request";
warn!("{}", error_message);
self.send_server_error(error_message.to_owned());
self.send_server_error(
String::from("No response body in inline HTTP request"),
None,
);
}
}
}