diff --git a/envoyfilter/src/consts.rs b/envoyfilter/src/consts.rs
index ffb999b7..6b5f17e2 100644
--- a/envoyfilter/src/consts.rs
+++ b/envoyfilter/src/consts.rs
@@ -7,4 +7,5 @@ pub const USER_ROLE: &str = "user";
 pub const GPT_35_TURBO: &str = "gpt-3.5-turbo";
 pub const BOLT_FC_CLUSTER: &str = "bolt_fc_1b";
 pub const BOLT_FC_REQUEST_TIMEOUT_MS: u64 = 120000; // 2 minutes
+pub const OPENAI_CHAT_COMPLETIONS_PATH: &str = "/v1/chat/completions";
 pub const MODEL_SERVER_NAME: &str = "model_server";
diff --git a/envoyfilter/src/filter_context.rs b/envoyfilter/src/filter_context.rs
index cbe6f0e9..4774b010 100644
--- a/envoyfilter/src/filter_context.rs
+++ b/envoyfilter/src/filter_context.rs
@@ -232,14 +232,12 @@ impl RootContext for FilterContext {
         true
     }
 
-    fn create_http_context(&self, _context_id: u32) -> Option<Box<dyn HttpContext>> {
-        Some(Box::new(StreamContext {
-            host_header: None,
-            ratelimit_selector: None,
-            callouts: HashMap::new(),
-            metrics: Rc::clone(&self.metrics),
-            prompt_targets: Rc::clone(&self.prompt_targets),
-        }))
+    fn create_http_context(&self, context_id: u32) -> Option<Box<dyn HttpContext>> {
+        Some(Box::new(StreamContext::new(
+            context_id,
+            Rc::clone(&self.metrics),
+            Rc::clone(&self.prompt_targets),
+        )))
     }
 
     fn get_type(&self) -> Option<ContextType> {
diff --git a/envoyfilter/src/stream_context.rs b/envoyfilter/src/stream_context.rs
index 9d71099a..771772ec 100644
--- a/envoyfilter/src/stream_context.rs
+++ b/envoyfilter/src/stream_context.rs
@@ -1,6 +1,6 @@
 use crate::consts::{
     BOLT_FC_CLUSTER, BOLT_FC_REQUEST_TIMEOUT_MS, DEFAULT_EMBEDDING_MODEL, DEFAULT_INTENT_MODEL,
-    DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
+    DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME, OPENAI_CHAT_COMPLETIONS_PATH,
     RATELIMIT_SELECTOR_HEADER_KEY, SYSTEM_ROLE, USER_ROLE,
 };
 use crate::filter_context::{embeddings_store, WasmMetrics};
@@ -10,13 +10,16 @@ use crate::stats::IncrementingMetric;
 use crate::tokenizer;
 use acap::cos;
 use http::StatusCode;
-use log::{debug, error, info, warn};
+use log::{debug, info, warn};
 use open_message_format_embeddings::models::{
     CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
 };
 use proxy_wasm::traits::*;
 use proxy_wasm::types::*;
-use public_types::common_types::open_ai::{ChatCompletions, Message};
+use public_types::common_types::open_ai::{
+    ChatCompletionChunkResponse, ChatCompletionsRequest, ChatCompletionsResponse, Message,
+    StreamOptions,
+};
 use public_types::common_types::{
     BoltFCResponse, BoltFCToolsCall, EmbeddingType, ToolParameter, ToolParameters, ToolsDefinition,
     ZeroShotClassificationRequest, ZeroShotClassificationResponse,
@@ -39,19 +42,40 @@ pub struct CallContext {
     response_handler_type: ResponseHandlerType,
     user_message: Option<String>,
     prompt_target: Option<PromptTarget>,
-    request_body: ChatCompletions,
+    request_body: ChatCompletionsRequest,
     similarity_scores: Option<Vec<(String, f64)>>,
 }
 
 pub struct StreamContext {
-    pub host_header: Option<String>,
-    pub ratelimit_selector: Option<Header>,
-    pub callouts: HashMap<u32, CallContext>,
+    pub context_id: u32,
     pub metrics: Rc<WasmMetrics>,
     pub prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
+    callouts: HashMap<u32, CallContext>,
+    host_header: Option<String>,
+    ratelimit_selector: Option<Header>,
+    streaming_response: bool,
+    response_tokens: usize,
+    chat_completions_request: bool,
 }
 
 impl StreamContext {
+    pub fn new(
+        context_id: u32,
+        metrics: Rc<WasmMetrics>,
+        prompt_targets: Rc<RwLock<HashMap<String, PromptTarget>>>,
+    ) -> Self {
+        StreamContext {
+            context_id,
+            metrics,
+            prompt_targets,
+            callouts: HashMap::new(),
+            host_header: None,
+            ratelimit_selector: None,
+            streaming_response: false,
+            response_tokens: 0,
+            chat_completions_request: false,
+        }
+    }
     fn save_host_header(&mut self) {
         // Save the host header to be used by filter logic later on.
         self.host_header = self.get_http_request_header(":host");
@@ -70,7 +94,8 @@
             // The gateway can start gathering information necessary for routing. For now change the path to an
             // OpenAI API path.
             Some(path) if path == "/llmrouting" => {
-                self.set_http_request_header(":path", Some("/v1/chat/completions"));
+                self.set_http_request_header(":path", Some(OPENAI_CHAT_COMPLETIONS_PATH));
+                self.chat_completions_request = true;
             }
             // Otherwise let the filter continue.
             _ => (),
@@ -86,21 +111,26 @@
        });
    }
 
-    fn send_server_error(&self, error: String) {
+    fn send_server_error(&self, error: String, override_status_code: Option<StatusCode>) {
         debug!("server error occurred: {}", error);
         self.send_http_response(
-            StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
+            override_status_code
+                .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR)
+                .as_u16()
+                .into(),
             vec![],
             Some(error.as_bytes()),
-        )
+        );
     }
 
     fn embeddings_handler(&mut self, body: Vec<u8>, mut callout_context: CallContext) {
         let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) {
             Ok(embedding_response) => embedding_response,
             Err(e) => {
-                self.send_server_error(format!("Error deserializing embedding response: {:?}", e));
-                return;
+                return self.send_server_error(
+                    format!("Error deserializing embedding response: {:?}", e),
+                    None,
+                );
             }
         };
 
@@ -115,19 +145,15 @@
         let prompt_target_embeddings = match embeddings_store().read() {
             Ok(embeddings) => embeddings,
             Err(e) => {
-                let error_message = format!("Error reading embeddings store: {:?}", e);
-                warn!("{}", error_message);
-                self.send_server_error(error_message);
-                return;
+                return self
+                    .send_server_error(format!("Error reading embeddings store: {:?}", e), None);
             }
         };
 
         let prompt_targets = match self.prompt_targets.read() {
             Ok(prompt_targets) => prompt_targets,
             Err(e) => {
-                let error_message = format!("Error reading prompt targets: {:?}", e);
-                warn!("{}", error_message);
-                self.send_server_error(error_message);
+                self.send_server_error(format!("Error reading prompt targets: {:?}", e), None);
                 return;
             }
         };
@@ -220,12 +246,13 @@
             match serde_json::from_slice(&body) {
                 Ok(zeroshot_response) => zeroshot_response,
                 Err(e) => {
-                    warn!(
-                        "Error deserializing zeroshot intent detection response: {:?}",
-                        e
+                    self.send_server_error(
+                        format!(
+                            "Error deserializing zeroshot intent detection response: {:?}",
+                            e
+                        ),
+                        None,
                     );
-                    info!("body: {:?}", String::from_utf8(body).unwrap());
-                    self.resume_http_request();
                     return;
                 }
             };
@@ -319,10 +346,12 @@
             parameters: tools_parameters,
         };
 
-        let chat_completions = ChatCompletions {
+        let chat_completions = ChatCompletionsRequest {
             model: GPT_35_TURBO.to_string(),
             messages: callout_context.request_body.messages.clone(),
             tools: Some(vec![tools_defintion]),
+            stream: false,
+            stream_options: None,
         };
 
         let msg_body = match serde_json::to_string(&chat_completions) {
@@ -331,11 +360,10 @@
             Ok(msg_body) => {
                 msg_body
             }
             Err(e) => {
-                self.send_server_error(format!(
-                    "Error serializing request_params: {:?}",
-                    e
-                ));
-                return;
+                return self.send_server_error(
+                    format!("Error serializing request_params: {:?}", e),
+                    None,
+                );
             }
         };
@@ -424,12 +452,10 @@
                     .arguments
                     .contains_key(&param.name)
                 {
-                    warn!("boltfc did not extract required parameter: {}", param.name);
-                    return self.send_http_response(
-                        StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
-                        vec![],
-                        Some("missing required parameter".as_bytes()),
-                    );
+                    self.send_server_error(
+                        format!("missing required parameter: {}", param.name),
+                        Some(StatusCode::BAD_REQUEST),
+                    )
                 }
             }
         });
@@ -510,17 +536,19 @@
             }
         });
 
-        let request_message: ChatCompletions = ChatCompletions {
+        let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest {
             model: GPT_35_TURBO.to_string(),
             messages,
             tools: None,
+            stream: callout_context.request_body.stream,
+            stream_options: callout_context.request_body.stream_options,
         };
 
-        let json_string = match serde_json::to_string(&request_message) {
+        let json_string = match serde_json::to_string(&chat_completions_request) {
             Ok(json_string) => json_string,
             Err(e) => {
-                self.send_server_error(format!("Error serializing request_body: {:?}", e));
-                return;
+                return self
+                    .send_server_error(format!("Error serializing request_body: {:?}", e), None);
             }
         };
         debug!(
@@ -528,22 +556,21 @@
             json_string
         );
 
-        let request_body = callout_context.request_body;
-
         // Tokenize and Ratelimit.
         if let Some(selector) = self.ratelimit_selector.take() {
-            if let Ok(token_count) = tokenizer::token_count(&request_body.model, &json_string) {
+            if let Ok(token_count) =
+                tokenizer::token_count(&chat_completions_request.model, &json_string)
+            {
                 match ratelimit::ratelimits(None).read().unwrap().check_limit(
-                    request_body.model,
+                    chat_completions_request.model,
                     selector,
                     NonZero::new(token_count as u32).unwrap(),
                 ) {
                     Ok(_) => (),
                     Err(err) => {
-                        self.send_http_response(
-                            StatusCode::TOO_MANY_REQUESTS.as_u16().into(),
-                            vec![],
-                            Some(format!("Exceeded Ratelimit: {}", err).as_bytes()),
+                        self.send_server_error(
+                            format!("Exceeded Ratelimit: {}", err),
+                            Some(StatusCode::TOO_MANY_REQUESTS),
                         );
                         self.metrics.ratelimited_rq.increment(1);
                         return;
@@ -583,31 +610,36 @@ impl HttpContext for StreamContext {
 
         // Deserialize body into spec.
         // Currently OpenAI API.
-        let deserialized_body: ChatCompletions = match self.get_http_request_body(0, body_size) {
-            Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
-                Ok(deserialized) => deserialized,
-                Err(msg) => {
-                    self.send_http_response(
-                        StatusCode::BAD_REQUEST.as_u16().into(),
-                        vec![],
-                        Some(format!("Failed to deserialize: {}", msg).as_bytes()),
+        let mut deserialized_body: ChatCompletionsRequest =
+            match self.get_http_request_body(0, body_size) {
+                Some(body_bytes) => match serde_json::from_slice(&body_bytes) {
+                    Ok(deserialized) => deserialized,
+                    Err(msg) => {
+                        self.send_server_error(
+                            format!("Failed to deserialize: {}", msg),
+                            Some(StatusCode::BAD_REQUEST),
+                        );
+                        return Action::Pause;
+                    }
+                },
+                None => {
+                    self.send_server_error(
+                        format!(
+                            "Failed to obtain body bytes even though body_size is {}",
+                            body_size
+                        ),
+                        None,
                     );
                     return Action::Pause;
                 }
-            },
-            None => {
-                self.send_http_response(
-                    StatusCode::INTERNAL_SERVER_ERROR.as_u16().into(),
-                    vec![],
-                    None,
-                );
-                error!(
-                    "Failed to obtain body bytes even though body_size is {}",
-                    body_size
-                );
-                return Action::Pause;
-            }
-        };
+            };
+
+        self.streaming_response = deserialized_body.stream;
+        if deserialized_body.stream && deserialized_body.stream_options.is_none() {
+            deserialized_body.stream_options = Some(StreamOptions {
+                include_usage: true,
+            });
+        }
 
         let user_message = match deserialized_body
             .messages
@@ -682,6 +714,92 @@
 
         Action::Pause
     }
+
+    fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
+        if !self.chat_completions_request {
+            return Action::Continue;
+        }
+
+        debug!(
+            "recv [S={}] bytes={} end_stream={}",
+            self.context_id, body_size, end_of_stream
+        );
+
+        if !end_of_stream && !self.streaming_response {
+            return Action::Pause;
+        }
+
+        let body = self
+            .get_http_response_body(0, body_size)
+            .expect("cant get response body");
+
+        let body_str = String::from_utf8(body).expect("body is not utf-8");
+
+        if self.streaming_response {
+            debug!("streaming response");
+            let chat_completions_data = match body_str.split_once("data: ") {
+                Some((_, chat_completions_data)) => chat_completions_data,
+                None => {
+                    self.send_server_error(String::from("parsing error in streaming data"), None);
+                    return Action::Pause;
+                }
+            };
+
+            let chat_completions_chunk_response: ChatCompletionChunkResponse =
+                match serde_json::from_str(chat_completions_data) {
+                    Ok(de) => de,
+                    Err(_) => {
+                        if chat_completions_data != "[NONE]" {
+                            self.send_server_error(
+                                String::from("error in streaming response"),
+                                None,
+                            );
+                            return Action::Continue;
+                        }
+                        return Action::Continue;
+                    }
+                };
+
+            if let Some(content) = chat_completions_chunk_response
+                .choices
+                .first()
+                .unwrap()
+                .delta
+                .content
+                .as_ref()
+            {
+                let model = &chat_completions_chunk_response.model;
+                let token_count = tokenizer::token_count(model, content).unwrap_or(0);
+                self.response_tokens += token_count;
+            }
+        } else {
+            debug!("non streaming response");
+            let chat_completions_response: ChatCompletionsResponse =
+                match serde_json::from_str(&body_str) {
+                    Ok(de) => de,
+                    Err(e) => {
+                        self.send_server_error(
+                            format!(
+                                "error in non-streaming response: {}\n response was={}",
+                                e, body_str
+                            ),
+                            None,
+                        );
+                        return Action::Pause;
+                    }
+                };
+
+            self.response_tokens += chat_completions_response.usage.completions_tokens;
+        }
+
+        debug!(
+            "recv [S={}] total_tokens={} end_stream={}",
+            self.context_id, self.response_tokens, end_of_stream
+        );
+
+        // TODO:: ratelimit based on response tokens.
+        Action::Continue
+    }
 }
 
 impl Context for StreamContext {
@@ -711,9 +829,10 @@
                 }
             }
         } else {
-            let error_message = "No response body in inline HTTP request";
-            warn!("{}", error_message);
-            self.send_server_error(error_message.to_owned());
+            self.send_server_error(
+                String::from("No response body in inline HTTP request"),
+                None,
+            );
         }
     }
 }
diff --git a/envoyfilter/tests/integration.rs b/envoyfilter/tests/integration.rs
index ecaba388..6b3e80cd 100644
--- a/envoyfilter/tests/integration.rs
+++ b/envoyfilter/tests/integration.rs
@@ -36,14 +36,10 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
         .call_proxy_on_request_headers(http_context, 0, false)
         .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":host"))
         .returning(Some("api.openai.com"))
-        .expect_add_header_map_value(
-            Some(MapType::HttpRequestHeaders),
-            Some("content-length"),
-            Some(""),
-        )
+        .expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
         .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
         .returning(Some("/llmrouting"))
-        .expect_add_header_map_value(
+        .expect_replace_header_map_value(
             Some(MapType::HttpRequestHeaders),
             Some(":path"),
             Some("/v1/chat/completions"),
@@ -196,7 +192,7 @@ prompt_targets:
       - name: city
 
 ratelimits:
-  - provider: gpt-4
+  - provider: gpt-3.5-turbo
     selector:
       key: selector-key
      value: selector-value
@@ -245,14 +241,10 @@ fn successful_request_to_open_ai_chat_completions() {
         .call_proxy_on_request_headers(http_context, 0, false)
         .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":host"))
         .returning(Some("api.openai.com"))
-        .expect_add_header_map_value(
-            Some(MapType::HttpRequestHeaders),
-            Some("content-length"),
-            Some(""),
-        )
+        .expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
         .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
         .returning(Some("/llmrouting"))
-        .expect_add_header_map_value(
+        .expect_replace_header_map_value(
             Some(MapType::HttpRequestHeaders),
             Some(":path"),
             Some("/v1/chat/completions"),
@@ -289,9 +281,9 @@
         )
         .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
         .returning(Some(chat_completions_request_body))
-        // TODO: assert that the model field was added.
-        .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
         .expect_log(Some(LogLevel::Debug), None)
+        .expect_http_call(Some("model_server"), None, None, None, None)
+        .returning(Some(4))
         .expect_metric_increment("active_http_calls", 1)
         .execute_and_expect(ReturnType::Action(Action::Pause))
         .unwrap();
@@ -335,14 +327,10 @@ fn bad_request_to_open_ai_chat_completions() {
         .call_proxy_on_request_headers(http_context, 0, false)
         .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":host"))
         .returning(Some("api.openai.com"))
-        .expect_add_header_map_value(
-            Some(MapType::HttpRequestHeaders),
-            Some("content-length"),
-            Some(""),
-        )
+        .expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
         .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
         .returning(Some("/llmrouting"))
-        .expect_add_header_map_value(
+        .expect_replace_header_map_value(
             Some(MapType::HttpRequestHeaders),
             Some(":path"),
             Some("/v1/chat/completions"),
@@ -377,6 +365,7 @@
         )
         .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
         .returning(Some(incomplete_chat_completions_request_body))
+        .expect_log(Some(LogLevel::Debug), None)
         .expect_send_local_response(
             Some(StatusCode::BAD_REQUEST.as_u16().into()),
             None,
@@ -485,6 +474,10 @@
             None,
         )
         .expect_metric_increment("ratelimited_rq", 1)
+        .expect_log(
+            Some(LogLevel::Debug),
+            Some("server error occurred: Exceeded Ratelimit: Not allowed"),
+        )
         .execute_and_expect(ReturnType::None)
         .unwrap();
 }
diff --git a/public_types/src/common_types.rs b/public_types/src/common_types.rs
index 85c2e2f8..05a17fdc 100644
--- a/public_types/src/common_types.rs
+++ b/public_types/src/common_types.rs
@@ -94,12 +94,21 @@ pub mod open_ai {
     use super::ToolsDefinition;
 
     #[derive(Debug, Clone, Serialize, Deserialize)]
-    pub struct ChatCompletions {
+    pub struct ChatCompletionsRequest {
         #[serde(default)]
         pub model: String,
         pub messages: Vec<Message>,
         #[serde(skip_serializing_if = "Option::is_none")]
         pub tools: Option<Vec<ToolsDefinition>>,
+        #[serde(default)]
+        pub stream: bool,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        pub stream_options: Option<StreamOptions>,
+    }
+
+    #[derive(Debug, Clone, Serialize, Deserialize)]
+    pub struct StreamOptions {
+        pub include_usage: bool,
     }
 
     #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -109,6 +118,33 @@
         #[serde(skip_serializing_if = "Option::is_none")]
         pub model: Option<String>,
     }
+    #[derive(Debug, Clone, Serialize, Deserialize)]
+    pub struct ChatCompletionsResponse {
+        pub usage: Usage,
+    }
+
+    #[derive(Debug, Clone, Serialize, Deserialize)]
+    pub struct Usage {
+        pub completions_tokens: usize,
+    }
+
+    #[derive(Debug, Clone, Serialize, Deserialize)]
+    pub struct ChatCompletionChunkResponse {
+        pub model: String,
+        pub choices: Vec<Choice>,
+    }
+
+    #[derive(Debug, Clone, Serialize, Deserialize)]
+    pub struct Choice {
+        pub delta: Delta,
+        // TODO: could this be an enum?
+        pub finish_reason: Option<String>,
+    }
+
+    #[derive(Debug, Clone, Serialize, Deserialize)]
+    pub struct Delta {
+        pub content: Option<String>,
+    }
 }
 
 #[derive(Debug, Clone, Serialize, Deserialize)]