From 21e7fe2cef563b46f55fcd369a571b1dca8c9eab Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Thu, 17 Oct 2024 10:16:40 -0700 Subject: [PATCH] Split arch wasm filter code into prompt and llm gateway filters (#190) --- .pre-commit-config.yaml | 2 +- arch/envoy.template.yaml | 26 +- crates/llm_gateway/src/lib.rs | 8 +- crates/llm_gateway/src/llm_filter_context.rs | 108 ++ crates/llm_gateway/src/llm_stream_context.rs | 421 +++++ crates/llm_gateway/tests/integration.rs | 531 +----- crates/prompt_gateway/src/filter_context.rs | 322 ---- crates/prompt_gateway/src/lib.rs | 8 +- .../src/prompt_filter_context.rs} | 32 +- .../src/prompt_stream_context.rs} | 245 +-- crates/prompt_gateway/src/stream_context.rs | 1576 ----------------- crates/prompt_gateway/tests/integration.rs | 216 +-- gateway.code-workspace | 2 +- 13 files changed, 696 insertions(+), 2801 deletions(-) create mode 100644 crates/llm_gateway/src/llm_filter_context.rs create mode 100644 crates/llm_gateway/src/llm_stream_context.rs delete mode 100644 crates/prompt_gateway/src/filter_context.rs rename crates/{llm_gateway/src/filter_context.rs => prompt_gateway/src/prompt_filter_context.rs} (92%) rename crates/{llm_gateway/src/stream_context.rs => prompt_gateway/src/prompt_stream_context.rs} (86%) delete mode 100644 crates/prompt_gateway/src/stream_context.rs diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 868c7548..1e577bbf 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -13,7 +13,7 @@ repos: name: cargo-fmt language: system types: [file, rust] - entry: bash -c "cd crates/llm_gateway && cargo fmt -- --check" + entry: bash -c "cd crates/llm_gateway && cargo fmt" - id: cargo-clippy name: cargo-clippy diff --git a/arch/envoy.template.yaml b/arch/envoy.template.yaml index 300d71d1..4c716e3a 100644 --- a/arch/envoy.template.yaml +++ b/arch/envoy.template.yaml @@ -45,34 +45,12 @@ static_resources: domains: - "*" routes: - {% for provider in arch_llm_providers %} - match: prefix: "/" - headers: - - name: "x-arch-llm-provider" - string_match: - exact: {{ provider.name }} - route: - auto_host_rewrite: true - cluster: {{ provider.provider }} - timeout: 60s - {% endfor %} - - match: - prefix: "/" - headers: - - name: "x-arch-upstream" - string_match: - exact: arch_llm_listener route: auto_host_rewrite: true cluster: arch_llm_listener timeout: 60s - - match: - prefix: "/" - direct_response: - status: 400 - body: - inline_string: "x-arch-llm-provider or x-arch-upstream header not set, cannot perform routing\n" http_filters: - name: envoy.filters.http.wasm typed_config: @@ -232,7 +210,7 @@ static_resources: direct_response: status: 400 body: - inline_string: "x-arch-llm-provider header not set, cannot perform routing\n" + inline_string: "x-arch-llm-provider header not set, llm gateway cannot perform routing\n" http_filters: - name: envoy.filters.http.wasm typed_config: @@ -250,7 +228,7 @@ static_resources: runtime: "envoy.wasm.runtime.v8" code: local: - filename: "/etc/envoy/proxy-wasm-plugins/prompt_gateway.wasm" + filename: "/etc/envoy/proxy-wasm-plugins/llm_gateway.wasm" - name: envoy.filters.http.router typed_config: "@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router diff --git a/crates/llm_gateway/src/lib.rs b/crates/llm_gateway/src/lib.rs index e2ad9025..766d32bb 100644 --- a/crates/llm_gateway/src/lib.rs +++ b/crates/llm_gateway/src/lib.rs @@ -1,13 +1,13 @@ -use filter_context::FilterContext; +use llm_filter_context::LlmGatewayFilterContext; use proxy_wasm::traits::*; use proxy_wasm::types::*; -mod filter_context; -mod stream_context; +mod llm_filter_context; +mod llm_stream_context; proxy_wasm::main! {{ proxy_wasm::set_log_level(LogLevel::Trace); proxy_wasm::set_root_context(|_| -> Box { - Box::new(FilterContext::new()) + Box::new(LlmGatewayFilterContext::new()) }); }} diff --git a/crates/llm_gateway/src/llm_filter_context.rs b/crates/llm_gateway/src/llm_filter_context.rs new file mode 100644 index 00000000..e1ed2620 --- /dev/null +++ b/crates/llm_gateway/src/llm_filter_context.rs @@ -0,0 +1,108 @@ +use crate::llm_stream_context::LlmGatewayStreamContext; +use common::configuration::Configuration; +use common::http::Client; +use common::llm_providers::LlmProviders; +use common::ratelimit; +use common::stats::Counter; +use common::stats::Gauge; +use log::debug; +use proxy_wasm::traits::*; +use proxy_wasm::types::*; +use std::cell::RefCell; +use std::collections::HashMap; +use std::rc::Rc; + +#[derive(Copy, Clone, Debug)] +pub struct WasmMetrics { + pub active_http_calls: Gauge, + pub ratelimited_rq: Counter, +} + +impl WasmMetrics { + fn new() -> WasmMetrics { + WasmMetrics { + active_http_calls: Gauge::new(String::from("active_http_calls")), + ratelimited_rq: Counter::new(String::from("ratelimited_rq")), + } + } +} + +#[derive(Debug)] +pub struct FilterCallContext {} + +#[derive(Debug)] +pub struct LlmGatewayFilterContext { + metrics: Rc, + // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request. + callouts: RefCell>, + llm_providers: Option>, +} + +impl LlmGatewayFilterContext { + pub fn new() -> LlmGatewayFilterContext { + LlmGatewayFilterContext { + callouts: RefCell::new(HashMap::new()), + metrics: Rc::new(WasmMetrics::new()), + llm_providers: None, + } + } +} + +impl Client for LlmGatewayFilterContext { + type CallContext = FilterCallContext; + + fn callouts(&self) -> &RefCell> { + &self.callouts + } + + fn active_http_calls(&self) -> &Gauge { + &self.metrics.active_http_calls + } +} + +impl Context for LlmGatewayFilterContext {} + +// RootContext allows the Rust code to reach into the Envoy Config +impl RootContext for LlmGatewayFilterContext { + fn on_configure(&mut self, _: usize) -> bool { + let config_bytes = self + .get_plugin_configuration() + .expect("Arch config cannot be empty"); + + let config: Configuration = match serde_yaml::from_slice(&config_bytes) { + Ok(config) => config, + Err(err) => panic!("Invalid arch config \"{:?}\"", err), + }; + + ratelimit::ratelimits(Some(config.ratelimits.unwrap_or_default())); + + match config.llm_providers.try_into() { + Ok(llm_providers) => self.llm_providers = Some(Rc::new(llm_providers)), + Err(err) => panic!("{err}"), + } + + true + } + + fn create_http_context(&self, context_id: u32) -> Option> { + debug!( + "||| create_http_context called with context_id: {:?} |||", + context_id + ); + + // No StreamContext can be created until the Embedding Store is fully initialized. + Some(Box::new(LlmGatewayStreamContext::new( + context_id, + Rc::clone(&self.metrics), + Rc::clone( + self.llm_providers + .as_ref() + .expect("LLM Providers must exist when Streams are being created"), + ), + ))) + } + + fn get_type(&self) -> Option { + Some(ContextType::HttpContext) + } +} diff --git a/crates/llm_gateway/src/llm_stream_context.rs b/crates/llm_gateway/src/llm_stream_context.rs new file mode 100644 index 00000000..6c585a72 --- /dev/null +++ b/crates/llm_gateway/src/llm_stream_context.rs @@ -0,0 +1,421 @@ +use crate::llm_filter_context::WasmMetrics; +use common::common_types::open_ai::{ + ArchState, ChatCompletionChunkResponse, ChatCompletionsRequest, ChatCompletionsResponse, + Message, ToolCall, ToolCallState, +}; +use common::configuration::LlmProvider; +use common::consts::{ + ARCH_PROVIDER_HINT_HEADER, ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, CHAT_COMPLETIONS_PATH, + RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, USER_ROLE, +}; +use common::llm_providers::LlmProviders; +use common::ratelimit::Header; +use common::{ratelimit, routing, tokenizer}; +use http::StatusCode; +use log::debug; +use proxy_wasm::traits::*; +use proxy_wasm::types::*; +use serde_json::Value; +use sha2::{Digest, Sha256}; +use std::num::NonZero; +use std::rc::Rc; + +use common::stats::IncrementingMetric; + +#[derive(thiserror::Error, Debug)] +pub enum ServerError { + #[error(transparent)] + Deserialization(serde_json::Error), + #[error("{0}")] + LogicError(String), + #[error(transparent)] + ExceededRatelimit(ratelimit::Error), + #[error("{why}")] + BadRequest { why: String }, +} + +pub struct LlmGatewayStreamContext { + context_id: u32, + metrics: Rc, + tool_calls: Option>, + tool_call_response: Option, + arch_state: Option>, + request_body_size: usize, + ratelimit_selector: Option
, + streaming_response: bool, + user_prompt: Option, + response_tokens: usize, + is_chat_completions_request: bool, + chat_completions_request: Option, + llm_providers: Rc, + llm_provider: Option>, + request_id: Option, +} + +impl LlmGatewayStreamContext { + #[allow(clippy::too_many_arguments)] + pub fn new(context_id: u32, metrics: Rc, llm_providers: Rc) -> Self { + LlmGatewayStreamContext { + context_id, + metrics, + chat_completions_request: None, + tool_calls: None, + tool_call_response: None, + arch_state: None, + request_body_size: 0, + ratelimit_selector: None, + streaming_response: false, + user_prompt: None, + response_tokens: 0, + is_chat_completions_request: false, + llm_providers, + llm_provider: None, + request_id: None, + } + } + fn llm_provider(&self) -> &LlmProvider { + self.llm_provider + .as_ref() + .expect("the provider should be set when asked for it") + } + + fn select_llm_provider(&mut self) { + let provider_hint = self + .get_http_request_header(ARCH_PROVIDER_HINT_HEADER) + .map(|provider_name| provider_name.into()); + + debug!("llm provider hint: {:?}", provider_hint); + self.llm_provider = Some(routing::get_llm_provider( + &self.llm_providers, + provider_hint, + )); + debug!("selected llm: {}", self.llm_provider.as_ref().unwrap().name); + } + + fn modify_auth_headers(&mut self) -> Result<(), ServerError> { + let llm_provider_api_key_value = + self.llm_provider() + .access_key + .as_ref() + .ok_or(ServerError::BadRequest { + why: format!( + "No access key configured for selected LLM Provider \"{}\"", + self.llm_provider() + ), + })?; + + let authorization_header_value = format!("Bearer {}", llm_provider_api_key_value); + + self.set_http_request_header("Authorization", Some(&authorization_header_value)); + + Ok(()) + } + + fn delete_content_length_header(&mut self) { + // Remove the Content-Length header because further body manipulations in the gateway logic will invalidate it. + // Server's generally throw away requests whose body length do not match the Content-Length header. + // However, a missing Content-Length header is not grounds for bad requests given that intermediary hops could + // manipulate the body in benign ways e.g., compression. + self.set_http_request_header("content-length", None); + } + + fn save_ratelimit_header(&mut self) { + self.ratelimit_selector = self + .get_http_request_header(RATELIMIT_SELECTOR_HEADER_KEY) + .and_then(|key| { + self.get_http_request_header(&key) + .map(|value| Header { key, value }) + }); + } + + fn send_server_error(&self, error: ServerError, override_status_code: Option) { + debug!("server error occurred: {}", error); + self.send_http_response( + override_status_code + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR) + .as_u16() + .into(), + vec![], + Some(format!("{error}").as_bytes()), + ); + } + + fn enforce_ratelimits( + &mut self, + model: &str, + json_string: &str, + ) -> Result<(), ratelimit::Error> { + if let Some(selector) = self.ratelimit_selector.take() { + // Tokenize and Ratelimit. + if let Ok(token_count) = tokenizer::token_count(model, json_string) { + ratelimit::ratelimits(None).read().unwrap().check_limit( + model.to_owned(), + selector, + NonZero::new(token_count as u32).unwrap(), + )?; + } + } + Ok(()) + } +} + +// HttpContext is the trait that allows the Rust code to interact with HTTP objects. +impl HttpContext for LlmGatewayStreamContext { + // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto + // the lifecycle of the http request and response. + fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { + self.select_llm_provider(); + self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name); + + if let Err(error) = self.modify_auth_headers() { + self.send_server_error(error, Some(StatusCode::BAD_REQUEST)); + } + self.delete_content_length_header(); + self.save_ratelimit_header(); + + self.is_chat_completions_request = + self.get_http_request_header(":path").unwrap_or_default() == CHAT_COMPLETIONS_PATH; + + debug!( + "on_http_request_headers S[{}] req_headers={:?}", + self.context_id, + self.get_http_request_headers() + ); + + self.request_id = self.get_http_request_header(REQUEST_ID_HEADER); + + Action::Continue + } + + fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action { + // Let the client send the gateway all the data before sending to the LLM_provider. + // TODO: consider a streaming API. + if !end_of_stream { + return Action::Pause; + } + + if body_size == 0 { + return Action::Continue; + } + + self.request_body_size = body_size; + + // Deserialize body into spec. + // Currently OpenAI API. + let mut deserialized_body: ChatCompletionsRequest = + match self.get_http_request_body(0, body_size) { + Some(body_bytes) => match serde_json::from_slice(&body_bytes) { + Ok(deserialized) => deserialized, + Err(e) => { + self.send_server_error( + ServerError::Deserialization(e), + Some(StatusCode::BAD_REQUEST), + ); + return Action::Pause; + } + }, + None => { + self.send_server_error( + ServerError::LogicError(format!( + "Failed to obtain body bytes even though body_size is {}", + body_size + )), + None, + ); + return Action::Pause; + } + }; + self.is_chat_completions_request = true; + + // remove metadata from the request body + deserialized_body.metadata = None; + // delete model key from message array + for message in deserialized_body.messages.iter_mut() { + message.model = None; + } + + // override model name from the llm provider + deserialized_body + .model + .clone_from(&self.llm_provider.as_ref().unwrap().model); + let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap(); + + // enforce ratelimits on ingress + if let Err(e) = + self.enforce_ratelimits(&deserialized_body.model, &chat_completion_request_str) + { + self.send_server_error( + ServerError::ExceededRatelimit(e), + Some(StatusCode::TOO_MANY_REQUESTS), + ); + self.metrics.ratelimited_rq.increment(1); + return Action::Continue; + } + + debug!( + "arch => {:?}, body: {}", + deserialized_body.model, chat_completion_request_str + ); + self.set_http_request_body(0, body_size, chat_completion_request_str.as_bytes()); + + Action::Continue + } + + fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action { + debug!( + "recv [S={}] bytes={} end_stream={}", + self.context_id, body_size, end_of_stream + ); + + if !self.is_chat_completions_request { + if let Some(body_str) = self + .get_http_response_body(0, body_size) + .and_then(|bytes| String::from_utf8(bytes).ok()) + { + debug!("recv [S={}] body_str={}", self.context_id, body_str); + } + return Action::Continue; + } + + if !end_of_stream { + return Action::Pause; + } + + let body = self + .get_http_response_body(0, body_size) + .expect("cant get response body"); + + if self.streaming_response { + let body_str = String::from_utf8(body).expect("body is not utf-8"); + debug!("streaming response"); + let chat_completions_data = match body_str.split_once("data: ") { + Some((_, chat_completions_data)) => chat_completions_data, + None => { + self.send_server_error( + ServerError::LogicError(String::from("parsing error in streaming data")), + None, + ); + return Action::Pause; + } + }; + + let chat_completions_chunk_response: ChatCompletionChunkResponse = + match serde_json::from_str(chat_completions_data) { + Ok(de) => de, + Err(_) => { + if chat_completions_data != "[NONE]" { + self.send_server_error( + ServerError::LogicError(String::from( + "error in streaming response", + )), + None, + ); + return Action::Continue; + } + return Action::Continue; + } + }; + + if let Some(content) = chat_completions_chunk_response + .choices + .first() + .unwrap() + .delta + .content + .as_ref() + { + let model = &chat_completions_chunk_response.model; + let token_count = tokenizer::token_count(model, content).unwrap_or(0); + self.response_tokens += token_count; + } + } else { + debug!("non streaming response"); + let chat_completions_response: ChatCompletionsResponse = + match serde_json::from_slice(&body) { + Ok(de) => de, + Err(e) => { + debug!("invalid response: {}", String::from_utf8_lossy(&body)); + self.send_server_error(ServerError::Deserialization(e), None); + return Action::Pause; + } + }; + + if chat_completions_response.usage.is_some() { + self.response_tokens += chat_completions_response + .usage + .as_ref() + .unwrap() + .completion_tokens; + } + + if let Some(tool_calls) = self.tool_calls.as_ref() { + if !tool_calls.is_empty() { + if self.arch_state.is_none() { + self.arch_state = Some(Vec::new()); + } + + // compute sha hash from message history + let mut hasher = Sha256::new(); + let prompts: Vec = self + .chat_completions_request + .as_ref() + .unwrap() + .messages + .iter() + .filter(|msg| msg.role == USER_ROLE) + .map(|msg| msg.content.clone().unwrap()) + .collect(); + let prompts_merged = prompts.join("#.#"); + hasher.update(prompts_merged.clone()); + let hash_key = hasher.finalize(); + // conver hash to hex string + let hash_key_str = format!("{:x}", hash_key); + debug!("hash key: {}, prompts: {}", hash_key_str, prompts_merged); + + // create new tool call state + let tool_call_state = ToolCallState { + key: hash_key_str, + message: self.user_prompt.clone(), + tool_call: tool_calls[0].function.clone(), + tool_response: self.tool_call_response.clone().unwrap(), + }; + + // push tool call state to arch state + self.arch_state + .as_mut() + .unwrap() + .push(ArchState::ToolCall(vec![tool_call_state])); + + let mut data: Value = serde_json::from_slice(&body).unwrap(); + // use serde::Value to manipulate the json object and ensure that we don't lose any data + if let Value::Object(ref mut map) = data { + // serialize arch state and add to metadata + let arch_state_str = serde_json::to_string(&self.arch_state).unwrap(); + debug!("arch_state: {}", arch_state_str); + let metadata = map + .entry("metadata") + .or_insert(Value::Object(serde_json::Map::new())); + metadata.as_object_mut().unwrap().insert( + ARCH_STATE_HEADER.to_string(), + serde_json::Value::String(arch_state_str), + ); + + let data_serialized = serde_json::to_string(&data).unwrap(); + debug!("arch => user: {}", data_serialized); + self.set_http_response_body(0, body_size, data_serialized.as_bytes()); + }; + } + } + } + + debug!( + "recv [S={}] total_tokens={} end_stream={}", + self.context_id, self.response_tokens, end_of_stream + ); + + // TODO:: ratelimit based on response tokens. + Action::Continue + } +} + +impl Context for LlmGatewayStreamContext {} diff --git a/crates/llm_gateway/tests/integration.rs b/crates/llm_gateway/tests/integration.rs index 5821a79a..80ff8d9f 100644 --- a/crates/llm_gateway/tests/integration.rs +++ b/crates/llm_gateway/tests/integration.rs @@ -1,19 +1,9 @@ -use common::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage}; -use common::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType}; -use common::common_types::{HallucinationClassificationResponse, PromptGuardResponse}; -use common::embeddings::{ - create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage, - Embedding, -}; -use common::{common_types::ZeroShotClassificationResponse, configuration::Configuration}; use http::StatusCode; use proxy_wasm_test_framework::tester::{self, Tester}; use proxy_wasm_test_framework::types::{ Action, BufferType, LogLevel, MapType, MetricType, ReturnType, }; -use serde_yaml::Value; use serial_test::serial; -use std::collections::HashMap; use std::path::Path; fn wasm_module() -> String { @@ -34,11 +24,6 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) { ) .returning(Some("default")) .expect_log(Some(LogLevel::Debug), None) - .expect_add_header_map_value( - Some(MapType::HttpRequestHeaders), - Some("x-arch-upstream"), - Some("arch_llm_listener"), - ) .expect_add_header_map_value( Some(MapType::HttpRequestHeaders), Some("x-arch-llm-provider"), @@ -61,6 +46,8 @@ fn request_headers_expectations(module: &mut Tester, http_context: i32) { .returning(None) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path")) .returning(Some("/v1/chat/completions")) + .expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders)) + .returning(None) .expect_log(Some(LogLevel::Debug), None) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-request-id")) .returning(None) @@ -76,181 +63,6 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { .unwrap(); request_headers_expectations(module, http_context); - - // Request Body - let chat_completions_request_body = "\ -{\ - \"messages\": [\ - {\ - \"role\": \"system\",\ - \"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\ - },\ - {\ - \"role\": \"user\",\ - \"content\": \"Compose a poem that explains the concept of recursion in programming.\"\ - }\ - ],\ - \"model\": \"gpt-4\"\ -}"; - - module - .call_proxy_on_request_body( - http_context, - chat_completions_request_body.len() as i32, - true, - ) - .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) - .returning(Some(chat_completions_request_body)) - // The actual call is not important in this test, we just need to grab the token_id - .expect_http_call( - Some("arch_internal"), - Some(vec![ - ("x-arch-upstream", "model_server"), - (":method", "POST"), - (":path", "/guard"), - (":authority", "model_server"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]), - None, - None, - None, - ) - .returning(Some(1)) - .expect_log(Some(LogLevel::Debug), None) - .expect_metric_increment("active_http_calls", 1) - .execute_and_expect(ReturnType::Action(Action::Pause)) - .unwrap(); - - let prompt_guard_response = PromptGuardResponse { - toxic_prob: None, - toxic_verdict: None, - jailbreak_prob: None, - jailbreak_verdict: None, - }; - let prompt_guard_response_buffer = serde_json::to_string(&prompt_guard_response).unwrap(); - module - .call_proxy_on_http_call_response( - http_context, - 1, - 0, - prompt_guard_response_buffer.len() as i32, - 0, - ) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&prompt_guard_response_buffer)) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - ("x-arch-upstream", "model_server"), - (":method", "POST"), - (":path", "/embeddings"), - (":authority", "model_server"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]), - None, - None, - None, - ) - .returning(Some(2)) - .expect_metric_increment("active_http_calls", 1) - .expect_log(Some(LogLevel::Debug), None) - .execute_and_expect(ReturnType::None) - .unwrap(); - - let embedding_response = CreateEmbeddingResponse { - data: vec![Embedding { - index: 0, - embedding: vec![], - object: embedding::Object::default(), - }], - model: String::from("test"), - object: create_embedding_response::Object::default(), - usage: Box::new(CreateEmbeddingResponseUsage::new(0, 0)), - }; - let embeddings_response_buffer = serde_json::to_string(&embedding_response).unwrap(); - module - .call_proxy_on_http_call_response( - http_context, - 2, - 0, - embeddings_response_buffer.len() as i32, - 0, - ) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&embeddings_response_buffer)) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - ("x-arch-upstream", "model_server"), - (":method", "POST"), - (":path", "/zeroshot"), - (":authority", "model_server"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]), - None, - None, - None, - ) - .returning(Some(3)) - .expect_metric_increment("active_http_calls", 1) - .expect_log(Some(LogLevel::Debug), None) - .execute_and_expect(ReturnType::None) - .unwrap(); - - let zero_shot_response = ZeroShotClassificationResponse { - predicted_class: "weather_forecast".to_string(), - predicted_class_score: 0.1, - scores: HashMap::new(), - model: "test-model".to_string(), - }; - let zeroshot_intent_detection_buffer = serde_json::to_string(&zero_shot_response).unwrap(); - module - .call_proxy_on_http_call_response( - http_context, - 3, - 0, - zeroshot_intent_detection_buffer.len() as i32, - 0, - ) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&zeroshot_intent_detection_buffer)) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Info), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - (":method", "POST"), - ("x-arch-upstream", "arch_fc"), - (":path", "/v1/chat/completions"), - (":authority", "arch_fc"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "120000"), - ]), - None, - None, - None, - ) - .returning(Some(4)) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_metric_increment("active_http_calls", 1) - .execute_and_expect(ReturnType::None) - .unwrap(); } fn setup_filter(module: &mut Tester, config: &str) -> i32 { @@ -270,69 +82,6 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 { .execute_and_expect(ReturnType::Bool(true)) .unwrap(); - module - .call_proxy_on_tick(filter_context) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - ("x-arch-upstream", "model_server"), - (":method", "POST"), - (":path", "/embeddings"), - (":authority", "model_server"), - ("content-type", "application/json"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]), - None, - None, - None, - ) - .returning(Some(101)) - .expect_metric_increment("active_http_calls", 1) - .expect_set_tick_period_millis(Some(0)) - .execute_and_expect(ReturnType::None) - .unwrap(); - - let embedding_response = CreateEmbeddingResponse { - data: vec![Embedding { - embedding: vec![], - index: 0, - object: embedding::Object::default(), - }], - model: String::from("test"), - object: create_embedding_response::Object::default(), - usage: Box::new(CreateEmbeddingResponseUsage { - prompt_tokens: 0, - total_tokens: 0, - }), - }; - let embedding_response_str = serde_json::to_string(&embedding_response).unwrap(); - module - .call_proxy_on_http_call_response( - filter_context, - 101, - 0, - embedding_response_str.len() as i32, - 0, - ) - .expect_log( - Some(LogLevel::Debug), - Some( - format!( - "filter_context: on_http_call_response called with token_id: {:?}", - 101 - ) - .as_str(), - ), - ) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&embedding_response_str)) - .expect_log(Some(LogLevel::Debug), None) - .execute_and_expect(ReturnType::None) - .unwrap(); - filter_context } @@ -357,6 +106,10 @@ llm_providers: access_key: secret_key model: gpt-4 default: true + - name: open-ai-gpt-4o + provider: openai + access_key: secret_key + model: gpt-4o overrides: # confidence threshold for prompt target intent matching @@ -396,7 +149,7 @@ ratelimits: key: selector-key value: selector-value limit: - tokens: 1 + tokens: 50 unit: minute "# } @@ -440,7 +193,7 @@ fn successful_request_to_open_ai_chat_completions() { },\ {\ \"role\": \"user\",\ - \"content\": \"Compose a poem that explains the concept of recursion in programming.\"\ + \"content\": \"Compose a poem.\"\ }\ ],\ \"model\": \"gpt-4\"\ @@ -455,10 +208,10 @@ fn successful_request_to_open_ai_chat_completions() { .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .returning(Some(chat_completions_request_body)) .expect_log(Some(LogLevel::Debug), None) - .expect_http_call(Some("arch_internal"), None, None, None, None) - .returning(Some(4)) - .expect_metric_increment("active_http_calls", 1) - .execute_and_expect(ReturnType::Action(Action::Pause)) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) + .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap(); } @@ -547,111 +300,35 @@ fn request_ratelimited() { normal_flow(&mut module, filter_context, http_context); - let arch_fc_resp = ChatCompletionsResponse { - usage: Some(Usage { - completion_tokens: 0, - }), - choices: vec![Choice { - finish_reason: "test".to_string(), - index: 0, - message: Message { - role: "system".to_string(), - content: None, - tool_calls: Some(vec![ToolCall { - id: String::from("test"), - tool_type: ToolType::Function, - function: FunctionCallDetail { - name: String::from("weather_forecast"), - arguments: HashMap::from([( - String::from("city"), - Value::String(String::from("seattle")), - )]), - }, - }]), - model: None, - }, - }], - model: String::from("test"), - metadata: None, - }; + // Request Body + let chat_completions_request_body = "\ +{\ + \"messages\": [\ + {\ + \"role\": \"system\",\ + \"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\ + },\ + {\ + \"role\": \"user\",\ + \"content\": \"Compose a poem that explains the concept of recursion in programming. Compose a poem that explains the concept of recursion in programming. Compose a poem that explains the concept of recursion in programming. \"\ + }\ + ],\ + \"model\": \"gpt-4\"\ +}"; - let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap(); module - .call_proxy_on_http_call_response(http_context, 4, 0, arch_fc_resp_str.len() as i32, 0) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&arch_fc_resp_str)) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - ("x-arch-upstream", "model_server"), - (":method", "POST"), - (":path", "/hallucination"), - (":authority", "model_server"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]), - None, - None, - None, + .call_proxy_on_request_body( + http_context, + chat_completions_request_body.len() as i32, + true, ) - .returning(Some(5)) - .expect_metric_increment("active_http_calls", 1) - .execute_and_expect(ReturnType::None) - .unwrap(); - - let hallucatination_body = HallucinationClassificationResponse { - params_scores: HashMap::from([("city".to_string(), 0.99)]), - model: "nli-model".to_string(), - }; - - let body_text = serde_json::to_string(&hallucatination_body).unwrap(); - - module - .call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&body_text)) - .expect_log(Some(LogLevel::Debug), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - ("x-arch-upstream", "api_server"), - (":method", "POST"), - (":path", "/weather"), - (":authority", "api_server"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ]), - None, - None, - None, - ) - .returning(Some(6)) - .expect_metric_increment("active_http_calls", 1) - .execute_and_expect(ReturnType::None) - .unwrap(); - - let body_text = String::from("test body"); - module - .call_proxy_on_http_call_response(http_context, 6, 0, body_text.len() as i32, 0) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&body_text)) - .expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status")) - .returning(Some("200")) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) + .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) + .returning(Some(chat_completions_request_body)) + // The actual call is not important in this test, we just need to grab the token_id .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) + // .expect_metric_increment("active_http_calls", 1) .expect_send_local_response( Some(StatusCode::TOO_MANY_REQUESTS.as_u16().into()), None, @@ -659,7 +336,7 @@ fn request_ratelimited() { None, ) .expect_metric_increment("ratelimited_rq", 1) - .execute_and_expect(ReturnType::None) + .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap(); } @@ -679,127 +356,49 @@ fn request_not_ratelimited() { .unwrap(); // Setup Filter - let mut config: Configuration = serde_yaml::from_str(default_config()).unwrap(); - config.ratelimits.as_mut().unwrap()[0].limit.tokens += 1000; - let config_str = serde_json::to_string(&config).unwrap(); - - let filter_context = setup_filter(&mut module, &config_str); + let filter_context = setup_filter(&mut module, default_config()); // Setup HTTP Stream let http_context = 2; normal_flow(&mut module, filter_context, http_context); - let arch_fc_resp = ChatCompletionsResponse { - usage: Some(Usage { - completion_tokens: 0, - }), - choices: vec![Choice { - finish_reason: "test".to_string(), - index: 0, - message: Message { - role: "system".to_string(), - content: None, - tool_calls: Some(vec![ToolCall { - id: String::from("test"), - tool_type: ToolType::Function, - function: FunctionCallDetail { - name: String::from("weather_forecast"), - arguments: HashMap::from([( - String::from("city"), - Value::String(String::from("seattle")), - )]), - }, - }]), - model: None, - }, - }], - model: String::from("test"), - metadata: None, - }; + // give shorter body to avoid rate limiting + let chat_completions_request_body = "\ +{\ + \"messages\": [\ + {\ + \"role\": \"system\",\ + \"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\ + },\ + {\ + \"role\": \"user\",\ + \"content\": \"Compose a poem that explains the concept of recursion in programming.\"\ + }\ + ],\ + \"model\": \"gpt-4\"\ +}"; - let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap(); module - .call_proxy_on_http_call_response(http_context, 4, 0, arch_fc_resp_str.len() as i32, 0) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&arch_fc_resp_str)) + .call_proxy_on_request_body( + http_context, + chat_completions_request_body.len() as i32, + true, + ) + .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) + .returning(Some(chat_completions_request_body)) + // The actual call is not important in this test, we just need to grab the token_id .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - ("x-arch-upstream", "model_server"), - (":method", "POST"), - (":path", "/hallucination"), - (":authority", "model_server"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]), + // .expect_metric_increment("active_http_calls", 1) + .expect_send_local_response( + Some(StatusCode::TOO_MANY_REQUESTS.as_u16().into()), None, None, None, ) - .returning(Some(5)) - .expect_metric_increment("active_http_calls", 1) - .execute_and_expect(ReturnType::None) - .unwrap(); - - // hallucination should return that parameters were not halliucinated - // prompt: str - // parameters: dict - // model: str - - let hallucatination_body = HallucinationClassificationResponse { - params_scores: HashMap::from([("city".to_string(), 0.99)]), - model: "nli-model".to_string(), - }; - - let body_text = serde_json::to_string(&hallucatination_body).unwrap(); - - module - .call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&body_text)) - .expect_log(Some(LogLevel::Debug), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - ("x-arch-upstream", "api_server"), - (":method", "POST"), - (":path", "/weather"), - (":authority", "api_server"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ]), - None, - None, - None, - ) - .returning(Some(6)) - .expect_metric_increment("active_http_calls", 1) - .execute_and_expect(ReturnType::None) - .unwrap(); - - let body_text = String::from("test body"); - module - .call_proxy_on_http_call_response(http_context, 6, 0, body_text.len() as i32, 0) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&body_text)) - .expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status")) - .returning(Some("200")) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) - .execute_and_expect(ReturnType::None) + .expect_metric_increment("ratelimited_rq", 1) + .execute_and_expect(ReturnType::Action(Action::Continue)) .unwrap(); } diff --git a/crates/prompt_gateway/src/filter_context.rs b/crates/prompt_gateway/src/filter_context.rs deleted file mode 100644 index 5d0090a7..00000000 --- a/crates/prompt_gateway/src/filter_context.rs +++ /dev/null @@ -1,322 +0,0 @@ -use crate::stream_context::StreamContext; -use common::common_types::EmbeddingType; -use common::configuration::{Configuration, GatewayMode, Overrides, PromptGuards, PromptTarget}; -use common::consts::ARCH_INTERNAL_CLUSTER_NAME; -use common::consts::ARCH_UPSTREAM_HOST_HEADER; -use common::consts::DEFAULT_EMBEDDING_MODEL; -use common::consts::MODEL_SERVER_NAME; -use common::embeddings::{ - CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, -}; -use common::http::CallArgs; -use common::http::Client; -use common::llm_providers::LlmProviders; -use common::ratelimit; -use common::stats::Counter; -use common::stats::Gauge; -use common::stats::IncrementingMetric; -use log::debug; -use proxy_wasm::traits::*; -use proxy_wasm::types::*; -use std::cell::RefCell; -use std::collections::hash_map::Entry; -use std::collections::HashMap; -use std::rc::Rc; -use std::time::Duration; - -#[derive(Copy, Clone, Debug)] -pub struct WasmMetrics { - pub active_http_calls: Gauge, - pub ratelimited_rq: Counter, -} - -impl WasmMetrics { - fn new() -> WasmMetrics { - WasmMetrics { - active_http_calls: Gauge::new(String::from("active_http_calls")), - ratelimited_rq: Counter::new(String::from("ratelimited_rq")), - } - } -} - -pub type EmbeddingTypeMap = HashMap>; -pub type EmbeddingsStore = HashMap; - -#[derive(Debug)] -pub struct FilterCallContext { - pub prompt_target_name: String, - pub embedding_type: EmbeddingType, -} - -#[derive(Debug)] -pub struct FilterContext { - metrics: Rc, - // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request. - callouts: RefCell>, - overrides: Rc>, - system_prompt: Rc>, - prompt_targets: Rc>, - mode: GatewayMode, - prompt_guards: Rc, - llm_providers: Option>, - embeddings_store: Option>, - temp_embeddings_store: EmbeddingsStore, -} - -impl FilterContext { - pub fn new() -> FilterContext { - FilterContext { - callouts: RefCell::new(HashMap::new()), - metrics: Rc::new(WasmMetrics::new()), - system_prompt: Rc::new(None), - prompt_targets: Rc::new(HashMap::new()), - overrides: Rc::new(None), - prompt_guards: Rc::new(PromptGuards::default()), - mode: GatewayMode::Prompt, - llm_providers: None, - embeddings_store: Some(Rc::new(HashMap::new())), - temp_embeddings_store: HashMap::new(), - } - } - - fn process_prompt_targets(&self) { - for values in self.prompt_targets.iter() { - let prompt_target = values.1; - self.schedule_embeddings_call( - &prompt_target.name, - &prompt_target.description, - EmbeddingType::Description, - ); - } - } - - fn schedule_embeddings_call( - &self, - prompt_target_name: &str, - input: &str, - embedding_type: EmbeddingType, - ) { - let embeddings_input = CreateEmbeddingRequest { - input: Box::new(CreateEmbeddingRequestInput::String(String::from(input))), - model: String::from(DEFAULT_EMBEDDING_MODEL), - encoding_format: None, - dimensions: None, - user: None, - }; - let json_data = serde_json::to_string(&embeddings_input).unwrap(); - - let call_args = CallArgs::new( - ARCH_INTERNAL_CLUSTER_NAME, - "/embeddings", - vec![ - (ARCH_UPSTREAM_HOST_HEADER, MODEL_SERVER_NAME), - (":method", "POST"), - (":path", "/embeddings"), - (":authority", MODEL_SERVER_NAME), - ("content-type", "application/json"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ], - Some(json_data.as_bytes()), - vec![], - Duration::from_secs(60), - ); - - let call_context = crate::filter_context::FilterCallContext { - prompt_target_name: String::from(prompt_target_name), - embedding_type, - }; - - if let Err(error) = self.http_call(call_args, call_context) { - panic!("{error}") - } - } - - fn embedding_response_handler( - &mut self, - body_size: usize, - embedding_type: EmbeddingType, - prompt_target_name: String, - ) { - let prompt_target = self - .prompt_targets - .get(&prompt_target_name) - .unwrap_or_else(|| { - panic!( - "Received embeddings response for unknown prompt target name={}", - prompt_target_name - ) - }); - - let body = self - .get_http_call_response_body(0, body_size) - .expect("No body in response"); - if !body.is_empty() { - let mut embedding_response: CreateEmbeddingResponse = - match serde_json::from_slice(&body) { - Ok(response) => response, - Err(e) => { - panic!( - "Error deserializing embedding response. body: {:?}: {:?}", - String::from_utf8(body).unwrap(), - e - ); - } - }; - - let embeddings = embedding_response.data.remove(0).embedding; - debug!( - "Adding embeddings for prompt target name: {:?}, description: {:?}, embedding type: {:?}", - prompt_target.name, - prompt_target.description, - embedding_type - ); - - let entry = self.temp_embeddings_store.entry(prompt_target_name); - match entry { - Entry::Occupied(_) => { - entry.and_modify(|e| { - if let Entry::Vacant(e) = e.entry(embedding_type) { - e.insert(embeddings); - } else { - panic!( - "Duplicate {:?} for prompt target with name=\"{}\"", - &embedding_type, prompt_target.name - ) - } - }); - } - Entry::Vacant(_) => { - entry.or_insert(HashMap::from([(embedding_type, embeddings)])); - } - } - - if self.prompt_targets.len() == self.temp_embeddings_store.len() { - self.embeddings_store = - Some(Rc::new(std::mem::take(&mut self.temp_embeddings_store))) - } - } - } -} - -impl Client for FilterContext { - type CallContext = FilterCallContext; - - fn callouts(&self) -> &RefCell> { - &self.callouts - } - - fn active_http_calls(&self) -> &Gauge { - &self.metrics.active_http_calls - } -} - -impl Context for FilterContext { - fn on_http_call_response( - &mut self, - token_id: u32, - _num_headers: usize, - body_size: usize, - _num_trailers: usize, - ) { - debug!( - "filter_context: on_http_call_response called with token_id: {:?}", - token_id - ); - let callout_data = self - .callouts - .borrow_mut() - .remove(&token_id) - .expect("invalid token_id"); - - self.metrics.active_http_calls.increment(-1); - - self.embedding_response_handler( - body_size, - callout_data.embedding_type, - callout_data.prompt_target_name, - ) - } -} - -// RootContext allows the Rust code to reach into the Envoy Config -impl RootContext for FilterContext { - fn on_configure(&mut self, _: usize) -> bool { - let config_bytes = self - .get_plugin_configuration() - .expect("Arch config cannot be empty"); - - let config: Configuration = match serde_yaml::from_slice(&config_bytes) { - Ok(config) => config, - Err(err) => panic!("Invalid arch config \"{:?}\"", err), - }; - - self.overrides = Rc::new(config.overrides); - - let mut prompt_targets = HashMap::new(); - for pt in config.prompt_targets { - prompt_targets.insert(pt.name.clone(), pt.clone()); - } - self.system_prompt = Rc::new(config.system_prompt); - self.prompt_targets = Rc::new(prompt_targets); - self.mode = config.mode.unwrap_or_default(); - - ratelimit::ratelimits(Some(config.ratelimits.unwrap_or_default())); - - if let Some(prompt_guards) = config.prompt_guards { - self.prompt_guards = Rc::new(prompt_guards) - } - - match config.llm_providers.try_into() { - Ok(llm_providers) => self.llm_providers = Some(Rc::new(llm_providers)), - Err(err) => panic!("{err}"), - } - - true - } - - fn create_http_context(&self, context_id: u32) -> Option> { - debug!( - "||| create_http_context called with context_id: {:?} |||", - context_id - ); - - // No StreamContext can be created until the Embedding Store is fully initialized. - let embedding_store = match self.mode { - GatewayMode::Llm => None, - GatewayMode::Prompt => Some(Rc::clone(self.embeddings_store.as_ref().unwrap())), - }; - Some(Box::new(StreamContext::new( - context_id, - Rc::clone(&self.metrics), - Rc::clone(&self.system_prompt), - Rc::clone(&self.prompt_targets), - Rc::clone(&self.prompt_guards), - Rc::clone(&self.overrides), - Rc::clone( - self.llm_providers - .as_ref() - .expect("LLM Providers must exist when Streams are being created"), - ), - embedding_store, - self.mode.clone(), - ))) - } - - fn get_type(&self) -> Option { - Some(ContextType::HttpContext) - } - - fn on_vm_start(&mut self, _: usize) -> bool { - self.set_tick_period(Duration::from_secs(1)); - true - } - - fn on_tick(&mut self) { - debug!("starting up arch filter in mode: {:?}", self.mode); - if self.mode == GatewayMode::Prompt { - self.process_prompt_targets(); - } - - self.set_tick_period(Duration::from_secs(0)); - } -} diff --git a/crates/prompt_gateway/src/lib.rs b/crates/prompt_gateway/src/lib.rs index e2ad9025..75edea5d 100644 --- a/crates/prompt_gateway/src/lib.rs +++ b/crates/prompt_gateway/src/lib.rs @@ -1,13 +1,13 @@ -use filter_context::FilterContext; +use prompt_filter_context::PromptGatewayFilterContext; use proxy_wasm::traits::*; use proxy_wasm::types::*; -mod filter_context; -mod stream_context; +mod prompt_filter_context; +mod prompt_stream_context; proxy_wasm::main! {{ proxy_wasm::set_log_level(LogLevel::Trace); proxy_wasm::set_root_context(|_| -> Box { - Box::new(FilterContext::new()) + Box::new(PromptGatewayFilterContext::new()) }); }} diff --git a/crates/llm_gateway/src/filter_context.rs b/crates/prompt_gateway/src/prompt_filter_context.rs similarity index 92% rename from crates/llm_gateway/src/filter_context.rs rename to crates/prompt_gateway/src/prompt_filter_context.rs index 5d0090a7..0c25ee5c 100644 --- a/crates/llm_gateway/src/filter_context.rs +++ b/crates/prompt_gateway/src/prompt_filter_context.rs @@ -1,4 +1,4 @@ -use crate::stream_context::StreamContext; +use crate::prompt_stream_context::PromptStreamContext; use common::common_types::EmbeddingType; use common::configuration::{Configuration, GatewayMode, Overrides, PromptGuards, PromptTarget}; use common::consts::ARCH_INTERNAL_CLUSTER_NAME; @@ -11,8 +11,6 @@ use common::embeddings::{ use common::http::CallArgs; use common::http::Client; use common::llm_providers::LlmProviders; -use common::ratelimit; -use common::stats::Counter; use common::stats::Gauge; use common::stats::IncrementingMetric; use log::debug; @@ -27,14 +25,12 @@ use std::time::Duration; #[derive(Copy, Clone, Debug)] pub struct WasmMetrics { pub active_http_calls: Gauge, - pub ratelimited_rq: Counter, } impl WasmMetrics { fn new() -> WasmMetrics { WasmMetrics { active_http_calls: Gauge::new(String::from("active_http_calls")), - ratelimited_rq: Counter::new(String::from("ratelimited_rq")), } } } @@ -49,7 +45,7 @@ pub struct FilterCallContext { } #[derive(Debug)] -pub struct FilterContext { +pub struct PromptGatewayFilterContext { metrics: Rc, // callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request. callouts: RefCell>, @@ -63,9 +59,9 @@ pub struct FilterContext { temp_embeddings_store: EmbeddingsStore, } -impl FilterContext { - pub fn new() -> FilterContext { - FilterContext { +impl PromptGatewayFilterContext { + pub fn new() -> PromptGatewayFilterContext { + PromptGatewayFilterContext { callouts: RefCell::new(HashMap::new()), metrics: Rc::new(WasmMetrics::new()), system_prompt: Rc::new(None), @@ -121,7 +117,7 @@ impl FilterContext { Duration::from_secs(60), ); - let call_context = crate::filter_context::FilterCallContext { + let call_context = crate::prompt_filter_context::FilterCallContext { prompt_target_name: String::from(prompt_target_name), embedding_type, }; @@ -198,7 +194,7 @@ impl FilterContext { } } -impl Client for FilterContext { +impl Client for PromptGatewayFilterContext { type CallContext = FilterCallContext; fn callouts(&self) -> &RefCell> { @@ -210,7 +206,7 @@ impl Client for FilterContext { } } -impl Context for FilterContext { +impl Context for PromptGatewayFilterContext { fn on_http_call_response( &mut self, token_id: u32, @@ -239,7 +235,7 @@ impl Context for FilterContext { } // RootContext allows the Rust code to reach into the Envoy Config -impl RootContext for FilterContext { +impl RootContext for PromptGatewayFilterContext { fn on_configure(&mut self, _: usize) -> bool { let config_bytes = self .get_plugin_configuration() @@ -260,8 +256,6 @@ impl RootContext for FilterContext { self.prompt_targets = Rc::new(prompt_targets); self.mode = config.mode.unwrap_or_default(); - ratelimit::ratelimits(Some(config.ratelimits.unwrap_or_default())); - if let Some(prompt_guards) = config.prompt_guards { self.prompt_guards = Rc::new(prompt_guards) } @@ -285,20 +279,14 @@ impl RootContext for FilterContext { GatewayMode::Llm => None, GatewayMode::Prompt => Some(Rc::clone(self.embeddings_store.as_ref().unwrap())), }; - Some(Box::new(StreamContext::new( + Some(Box::new(PromptStreamContext::new( context_id, Rc::clone(&self.metrics), Rc::clone(&self.system_prompt), Rc::clone(&self.prompt_targets), Rc::clone(&self.prompt_guards), Rc::clone(&self.overrides), - Rc::clone( - self.llm_providers - .as_ref() - .expect("LLM Providers must exist when Streams are being created"), - ), embedding_store, - self.mode.clone(), ))) } diff --git a/crates/llm_gateway/src/stream_context.rs b/crates/prompt_gateway/src/prompt_stream_context.rs similarity index 86% rename from crates/llm_gateway/src/stream_context.rs rename to crates/prompt_gateway/src/prompt_stream_context.rs index 5e4e6149..d208f5e8 100644 --- a/crates/llm_gateway/src/stream_context.rs +++ b/crates/prompt_gateway/src/prompt_stream_context.rs @@ -1,33 +1,28 @@ -use crate::filter_context::{EmbeddingsStore, WasmMetrics}; +use crate::prompt_filter_context::{EmbeddingsStore, WasmMetrics}; use acap::cos; use common::common_types::open_ai::{ - ArchState, ChatCompletionChunkResponse, ChatCompletionTool, ChatCompletionsRequest, - ChatCompletionsResponse, Choice, FunctionDefinition, FunctionParameter, FunctionParameters, - Message, ParameterType, StreamOptions, ToolCall, ToolCallState, ToolType, + ArchState, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice, + FunctionDefinition, FunctionParameter, FunctionParameters, Message, ParameterType, + StreamOptions, ToolCall, ToolCallState, ToolType, }; use common::common_types::{ EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse, PromptGuardRequest, PromptGuardResponse, PromptGuardTask, ZeroShotClassificationRequest, ZeroShotClassificationResponse, }; -use common::configuration::{GatewayMode, LlmProvider}; use common::configuration::{Overrides, PromptGuards, PromptTarget}; use common::consts::{ - ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, - ARCH_LLM_UPSTREAM_LISTENER, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_PROVIDER_HINT_HEADER, - ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ARC_FC_CLUSTER, + ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY, + ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ARC_FC_CLUSTER, CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME, - RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, USER_ROLE, + REQUEST_ID_HEADER, SYSTEM_ROLE, USER_ROLE, }; use common::embeddings::{ CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, }; use common::http::{CallArgs, Client, ClientError}; -use common::llm_providers::LlmProviders; -use common::ratelimit::Header; use common::stats::Gauge; -use common::{ratelimit, routing, tokenizer}; use http::StatusCode; use log::{debug, info, warn}; use proxy_wasm::traits::*; @@ -36,7 +31,6 @@ use serde_json::Value; use sha2::{Digest, Sha256}; use std::cell::RefCell; use std::collections::HashMap; -use std::num::NonZero; use std::rc::Rc; use std::time::Duration; @@ -81,17 +75,13 @@ pub enum ServerError { path: String, status: String, }, - #[error(transparent)] - ExceededRatelimit(ratelimit::Error), #[error("jailbreak detected: {0}")] Jailbreak(String), #[error("{why}")] - BadRequest { why: String }, - #[error("{why}")] NoMessagesFound { why: String }, } -pub struct StreamContext { +pub struct PromptStreamContext { context_id: u32, metrics: Rc, system_prompt: Rc>, @@ -103,20 +93,16 @@ pub struct StreamContext { tool_call_response: Option, arch_state: Option>, request_body_size: usize, - ratelimit_selector: Option
, streaming_response: bool, user_prompt: Option, response_tokens: usize, is_chat_completions_request: bool, chat_completions_request: Option, prompt_guards: Rc, - llm_providers: Rc, - llm_provider: Option>, request_id: Option, - mode: GatewayMode, } -impl StreamContext { +impl PromptStreamContext { #[allow(clippy::too_many_arguments)] pub fn new( context_id: u32, @@ -125,11 +111,9 @@ impl StreamContext { prompt_targets: Rc>, prompt_guards: Rc, overrides: Rc>, - llm_providers: Rc, embeddings_store: Option>, - mode: GatewayMode, ) -> Self { - StreamContext { + PromptStreamContext { context_id, metrics, system_prompt, @@ -141,75 +125,21 @@ impl StreamContext { tool_call_response: None, arch_state: None, request_body_size: 0, - ratelimit_selector: None, streaming_response: false, user_prompt: None, response_tokens: 0, is_chat_completions_request: false, - llm_providers, - llm_provider: None, prompt_guards, overrides, request_id: None, - mode, } } - fn llm_provider(&self) -> &LlmProvider { - self.llm_provider - .as_ref() - .expect("the provider should be set when asked for it") - } - fn embeddings_store(&self) -> &EmbeddingsStore { self.embeddings_store .as_ref() .expect("embeddings store is not set") } - fn select_llm_provider(&mut self) { - let provider_hint = self - .get_http_request_header(ARCH_PROVIDER_HINT_HEADER) - .map(|provider_name| provider_name.into()); - - debug!("llm provider hint: {:?}", provider_hint); - self.llm_provider = Some(routing::get_llm_provider( - &self.llm_providers, - provider_hint, - )); - debug!("selected llm: {}", self.llm_provider.as_ref().unwrap().name); - } - - fn add_routing_header(&mut self) { - match self.mode { - GatewayMode::Prompt => { - // in prompt gateway mode, we need to route to llm upstream listener - self.add_http_request_header(ARCH_UPSTREAM_HOST_HEADER, ARCH_LLM_UPSTREAM_LISTENER); - } - _ => { - self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name); - } - } - } - - fn modify_auth_headers(&mut self) -> Result<(), ServerError> { - let llm_provider_api_key_value = - self.llm_provider() - .access_key - .as_ref() - .ok_or(ServerError::BadRequest { - why: format!( - "No access key configured for selected LLM Provider \"{}\"", - self.llm_provider() - ), - })?; - - let authorization_header_value = format!("Bearer {}", llm_provider_api_key_value); - - self.set_http_request_header("Authorization", Some(&authorization_header_value)); - - Ok(()) - } - fn delete_content_length_header(&mut self) { // Remove the Content-Length header because further body manipulations in the gateway logic will invalidate it. // Server's generally throw away requests whose body length do not match the Content-Length header. @@ -218,15 +148,6 @@ impl StreamContext { self.set_http_request_header("content-length", None); } - fn save_ratelimit_header(&mut self) { - self.ratelimit_selector = self - .get_http_request_header(RATELIMIT_SELECTOR_HEADER_KEY) - .and_then(|key| { - self.get_http_request_header(&key) - .map(|value| Header { key, value }) - }); - } - fn send_server_error(&self, error: ServerError, override_status_code: Option) { debug!("server error occurred: {}", error); self.send_http_response( @@ -682,6 +603,7 @@ impl StreamContext { } let tool_calls = model_resp.message.tool_calls.as_ref().unwrap(); + self.tool_calls = Some(tool_calls.clone()); // TODO CO: pass nli check // If hallucination, pass chat template to check parameters @@ -954,40 +876,12 @@ impl StreamContext { return self.send_server_error(ServerError::Serialization(e), None); } }; - debug!("arch => openai request body: {}", json_string); - - // Tokenize and Ratelimit. - if let Err(e) = self.enforce_ratelimits(&chat_completions_request.model, &json_string) { - self.send_server_error( - ServerError::ExceededRatelimit(e), - Some(StatusCode::TOO_MANY_REQUESTS), - ); - self.metrics.ratelimited_rq.increment(1); - return; - } + debug!("arch => upstream llm request body: {}", json_string); self.set_http_request_body(0, self.request_body_size, &json_string.into_bytes()); self.resume_http_request(); } - fn enforce_ratelimits( - &mut self, - model: &str, - json_string: &str, - ) -> Result<(), ratelimit::Error> { - if let Some(selector) = self.ratelimit_selector.take() { - // Tokenize and Ratelimit. - if let Ok(token_count) = tokenizer::token_count(model, json_string) { - ratelimit::ratelimits(None).read().unwrap().check_limit( - model.to_owned(), - selector, - NonZero::new(token_count as u32).unwrap(), - )?; - } - } - Ok(()) - } - fn arch_guard_handler(&mut self, body: Vec, callout_context: StreamCallContext) { debug!("response received for arch guard"); let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap(); @@ -1137,23 +1031,17 @@ impl StreamContext { } // HttpContext is the trait that allows the Rust code to interact with HTTP objects. -impl HttpContext for StreamContext { +impl HttpContext for PromptStreamContext { // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto // the lifecycle of the http request and response. fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { - self.select_llm_provider(); - self.add_routing_header(); - if let Err(error) = self.modify_auth_headers() { - self.send_server_error(error, Some(StatusCode::BAD_REQUEST)); - } self.delete_content_length_header(); - self.save_ratelimit_header(); self.is_chat_completions_request = self.get_http_request_header(":path").unwrap_or_default() == CHAT_COMPLETIONS_PATH; debug!( - "S[{}] req_headers={:?}", + "on_http_request_headers S[{}] req_headers={:?}", self.context_id, self.get_http_request_headers() ); @@ -1176,6 +1064,11 @@ impl HttpContext for StreamContext { self.request_body_size = body_size; + debug!( + "on_http_request_body S[{}] body_size={}", + self.context_id, body_size + ); + // Deserialize body into spec. // Currently OpenAI API. let mut deserialized_body: ChatCompletionsRequest = @@ -1203,40 +1096,6 @@ impl HttpContext for StreamContext { }; self.is_chat_completions_request = true; - if self.mode == GatewayMode::Llm { - debug!("llm gateway mode, skipping over all prompt targets"); - - // remove metadata from the request body - deserialized_body.metadata = None; - // delete model key from message array - for message in deserialized_body.messages.iter_mut() { - message.model = None; - } - deserialized_body - .model - .clone_from(&self.llm_provider.as_ref().unwrap().model); - let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap(); - - // enforce ratelimits - if let Err(e) = - self.enforce_ratelimits(&deserialized_body.model, &chat_completion_request_str) - { - self.send_server_error( - ServerError::ExceededRatelimit(e), - Some(StatusCode::TOO_MANY_REQUESTS), - ); - self.metrics.ratelimited_rq.increment(1); - return Action::Continue; - } - - debug!( - "arch => {:?}, body: {}", - deserialized_body.model, chat_completion_request_str - ); - self.set_http_request_body(0, body_size, chat_completion_request_str.as_bytes()); - return Action::Continue; - } - self.arch_state = match deserialized_body.metadata { Some(ref metadata) => { if metadata.contains_key(ARCH_STATE_HEADER) { @@ -1250,9 +1109,6 @@ impl HttpContext for StreamContext { None => None, }; - // Set the model based on the chosen LLM Provider - deserialized_body.model = String::from(&self.llm_provider().model); - self.streaming_response = deserialized_body.stream; if deserialized_body.stream && deserialized_body.stream_options.is_none() { deserialized_body.stream_options = Some(StreamOptions { @@ -1285,7 +1141,7 @@ impl HttpContext for StreamContext { self.chat_completions_request = Some(deserialized_body); if !prompt_guard_jailbreak_task { - debug!("Missing input guard. Making inline call to retrieve"); + debug!("Missing input guard. Making inline call to retrieve embeddings"); let callout_context = StreamCallContext { response_handler_type: ResponseHandlerType::ArchGuard, user_message: user_message_str.clone(), @@ -1360,6 +1216,18 @@ impl HttpContext for StreamContext { Action::Pause } + fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { + debug!( + "on_http_response_headers recv [S={}] headers={:?}", + self.context_id, + self.get_http_response_headers() + ); + // delete content-lenght header let envoy calculate it, because we modify the response body + // that would result in a different content-length + self.set_http_response_header("content-length", None); + Action::Continue + } + fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action { debug!( "recv [S={}] bytes={} end_stream={}", @@ -1385,48 +1253,7 @@ impl HttpContext for StreamContext { .expect("cant get response body"); if self.streaming_response { - let body_str = String::from_utf8(body).expect("body is not utf-8"); debug!("streaming response"); - let chat_completions_data = match body_str.split_once("data: ") { - Some((_, chat_completions_data)) => chat_completions_data, - None => { - self.send_server_error( - ServerError::LogicError(String::from("parsing error in streaming data")), - None, - ); - return Action::Pause; - } - }; - - let chat_completions_chunk_response: ChatCompletionChunkResponse = - match serde_json::from_str(chat_completions_data) { - Ok(de) => de, - Err(_) => { - if chat_completions_data != "[NONE]" { - self.send_server_error( - ServerError::LogicError(String::from( - "error in streaming response", - )), - None, - ); - return Action::Continue; - } - return Action::Continue; - } - }; - - if let Some(content) = chat_completions_chunk_response - .choices - .first() - .unwrap() - .delta - .content - .as_ref() - { - let model = &chat_completions_chunk_response.model; - let token_count = tokenizer::token_count(model, content).unwrap_or(0); - self.response_tokens += token_count; - } } else { debug!("non streaming response"); let chat_completions_response: ChatCompletionsResponse = @@ -1494,6 +1321,9 @@ impl HttpContext for StreamContext { let metadata = map .entry("metadata") .or_insert(Value::Object(serde_json::Map::new())); + if metadata == &Value::Null { + *metadata = Value::Object(serde_json::Map::new()); + } metadata.as_object_mut().unwrap().insert( ARCH_STATE_HEADER.to_string(), serde_json::Value::String(arch_state_str), @@ -1512,12 +1342,11 @@ impl HttpContext for StreamContext { self.context_id, self.response_tokens, end_of_stream ); - // TODO:: ratelimit based on response tokens. Action::Continue } } -impl Context for StreamContext { +impl Context for PromptStreamContext { fn on_http_call_response( &mut self, token_id: u32, @@ -1563,7 +1392,7 @@ impl Context for StreamContext { } } -impl Client for StreamContext { +impl Client for PromptStreamContext { type CallContext = StreamCallContext; fn callouts(&self) -> &RefCell> { diff --git a/crates/prompt_gateway/src/stream_context.rs b/crates/prompt_gateway/src/stream_context.rs deleted file mode 100644 index 5e4e6149..00000000 --- a/crates/prompt_gateway/src/stream_context.rs +++ /dev/null @@ -1,1576 +0,0 @@ -use crate::filter_context::{EmbeddingsStore, WasmMetrics}; -use acap::cos; -use common::common_types::open_ai::{ - ArchState, ChatCompletionChunkResponse, ChatCompletionTool, ChatCompletionsRequest, - ChatCompletionsResponse, Choice, FunctionDefinition, FunctionParameter, FunctionParameters, - Message, ParameterType, StreamOptions, ToolCall, ToolCallState, ToolType, -}; -use common::common_types::{ - EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse, - PromptGuardRequest, PromptGuardResponse, PromptGuardTask, ZeroShotClassificationRequest, - ZeroShotClassificationResponse, -}; -use common::configuration::{GatewayMode, LlmProvider}; -use common::configuration::{Overrides, PromptGuards, PromptTarget}; -use common::consts::{ - ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, - ARCH_LLM_UPSTREAM_LISTENER, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_PROVIDER_HINT_HEADER, - ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ARC_FC_CLUSTER, - CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD, - DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME, - RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, USER_ROLE, -}; -use common::embeddings::{ - CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse, -}; -use common::http::{CallArgs, Client, ClientError}; -use common::llm_providers::LlmProviders; -use common::ratelimit::Header; -use common::stats::Gauge; -use common::{ratelimit, routing, tokenizer}; -use http::StatusCode; -use log::{debug, info, warn}; -use proxy_wasm::traits::*; -use proxy_wasm::types::*; -use serde_json::Value; -use sha2::{Digest, Sha256}; -use std::cell::RefCell; -use std::collections::HashMap; -use std::num::NonZero; -use std::rc::Rc; -use std::time::Duration; - -use common::stats::IncrementingMetric; - -#[derive(Debug, Clone)] -enum ResponseHandlerType { - GetEmbeddings, - FunctionResolver, - FunctionCall, - ZeroShotIntent, - HallucinationDetect, - ArchGuard, - DefaultTarget, -} - -#[derive(Debug, Clone)] -pub struct StreamCallContext { - response_handler_type: ResponseHandlerType, - user_message: Option, - prompt_target_name: Option, - request_body: ChatCompletionsRequest, - tool_calls: Option>, - similarity_scores: Option>, - upstream_cluster: Option, - upstream_cluster_path: Option, -} - -#[derive(thiserror::Error, Debug)] -pub enum ServerError { - #[error(transparent)] - HttpDispatch(ClientError), - #[error(transparent)] - Deserialization(serde_json::Error), - #[error(transparent)] - Serialization(serde_json::Error), - #[error("{0}")] - LogicError(String), - #[error("upstream error response authority={authority}, path={path}, status={status}")] - Upstream { - authority: String, - path: String, - status: String, - }, - #[error(transparent)] - ExceededRatelimit(ratelimit::Error), - #[error("jailbreak detected: {0}")] - Jailbreak(String), - #[error("{why}")] - BadRequest { why: String }, - #[error("{why}")] - NoMessagesFound { why: String }, -} - -pub struct StreamContext { - context_id: u32, - metrics: Rc, - system_prompt: Rc>, - prompt_targets: Rc>, - embeddings_store: Option>, - overrides: Rc>, - callouts: RefCell>, - tool_calls: Option>, - tool_call_response: Option, - arch_state: Option>, - request_body_size: usize, - ratelimit_selector: Option
, - streaming_response: bool, - user_prompt: Option, - response_tokens: usize, - is_chat_completions_request: bool, - chat_completions_request: Option, - prompt_guards: Rc, - llm_providers: Rc, - llm_provider: Option>, - request_id: Option, - mode: GatewayMode, -} - -impl StreamContext { - #[allow(clippy::too_many_arguments)] - pub fn new( - context_id: u32, - metrics: Rc, - system_prompt: Rc>, - prompt_targets: Rc>, - prompt_guards: Rc, - overrides: Rc>, - llm_providers: Rc, - embeddings_store: Option>, - mode: GatewayMode, - ) -> Self { - StreamContext { - context_id, - metrics, - system_prompt, - prompt_targets, - embeddings_store, - callouts: RefCell::new(HashMap::new()), - chat_completions_request: None, - tool_calls: None, - tool_call_response: None, - arch_state: None, - request_body_size: 0, - ratelimit_selector: None, - streaming_response: false, - user_prompt: None, - response_tokens: 0, - is_chat_completions_request: false, - llm_providers, - llm_provider: None, - prompt_guards, - overrides, - request_id: None, - mode, - } - } - fn llm_provider(&self) -> &LlmProvider { - self.llm_provider - .as_ref() - .expect("the provider should be set when asked for it") - } - - fn embeddings_store(&self) -> &EmbeddingsStore { - self.embeddings_store - .as_ref() - .expect("embeddings store is not set") - } - - fn select_llm_provider(&mut self) { - let provider_hint = self - .get_http_request_header(ARCH_PROVIDER_HINT_HEADER) - .map(|provider_name| provider_name.into()); - - debug!("llm provider hint: {:?}", provider_hint); - self.llm_provider = Some(routing::get_llm_provider( - &self.llm_providers, - provider_hint, - )); - debug!("selected llm: {}", self.llm_provider.as_ref().unwrap().name); - } - - fn add_routing_header(&mut self) { - match self.mode { - GatewayMode::Prompt => { - // in prompt gateway mode, we need to route to llm upstream listener - self.add_http_request_header(ARCH_UPSTREAM_HOST_HEADER, ARCH_LLM_UPSTREAM_LISTENER); - } - _ => { - self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name); - } - } - } - - fn modify_auth_headers(&mut self) -> Result<(), ServerError> { - let llm_provider_api_key_value = - self.llm_provider() - .access_key - .as_ref() - .ok_or(ServerError::BadRequest { - why: format!( - "No access key configured for selected LLM Provider \"{}\"", - self.llm_provider() - ), - })?; - - let authorization_header_value = format!("Bearer {}", llm_provider_api_key_value); - - self.set_http_request_header("Authorization", Some(&authorization_header_value)); - - Ok(()) - } - - fn delete_content_length_header(&mut self) { - // Remove the Content-Length header because further body manipulations in the gateway logic will invalidate it. - // Server's generally throw away requests whose body length do not match the Content-Length header. - // However, a missing Content-Length header is not grounds for bad requests given that intermediary hops could - // manipulate the body in benign ways e.g., compression. - self.set_http_request_header("content-length", None); - } - - fn save_ratelimit_header(&mut self) { - self.ratelimit_selector = self - .get_http_request_header(RATELIMIT_SELECTOR_HEADER_KEY) - .and_then(|key| { - self.get_http_request_header(&key) - .map(|value| Header { key, value }) - }); - } - - fn send_server_error(&self, error: ServerError, override_status_code: Option) { - debug!("server error occurred: {}", error); - self.send_http_response( - override_status_code - .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR) - .as_u16() - .into(), - vec![], - Some(format!("{error}").as_bytes()), - ); - } - - fn embeddings_handler(&mut self, body: Vec, mut callout_context: StreamCallContext) { - let embedding_response: CreateEmbeddingResponse = match serde_json::from_slice(&body) { - Ok(embedding_response) => embedding_response, - Err(e) => { - return self.send_server_error(ServerError::Deserialization(e), None); - } - }; - - let prompt_embeddings_vector = &embedding_response.data[0].embedding; - - debug!( - "embedding model: {}, vector length: {:?}", - embedding_response.model, - prompt_embeddings_vector.len() - ); - - let prompt_target_names = self - .prompt_targets - .iter() - // exclude default target - .filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false)) - .map(|(name, _)| name.clone()) - .collect(); - - let similarity_scores: Vec<(String, f64)> = self - .prompt_targets - .iter() - // exclude default prompt target - .filter(|(_, prompt_target)| !prompt_target.default.unwrap_or(false)) - .map(|(prompt_name, _)| { - let pte = match self.embeddings_store().get(prompt_name) { - Some(embeddings) => embeddings, - None => { - warn!( - "embeddings not found for prompt target name: {}", - prompt_name - ); - return (prompt_name.clone(), f64::NAN); - } - }; - - let description_embeddings = match pte.get(&EmbeddingType::Description) { - Some(embeddings) => embeddings, - None => { - warn!( - "description embeddings not found for prompt target name: {}", - prompt_name - ); - return (prompt_name.clone(), f64::NAN); - } - }; - let similarity_score_description = - cos::cosine_similarity(&prompt_embeddings_vector, &description_embeddings); - (prompt_name.clone(), similarity_score_description) - }) - .collect(); - - debug!( - "similarity scores based on description embeddings match: {:?}", - similarity_scores - ); - - callout_context.similarity_scores = Some(similarity_scores); - - let zero_shot_classification_request = ZeroShotClassificationRequest { - // Need to clone into input because user_message is used below. - input: callout_context.user_message.as_ref().unwrap().clone(), - model: String::from(DEFAULT_INTENT_MODEL), - labels: prompt_target_names, - }; - - let json_data: String = match serde_json::to_string(&zero_shot_classification_request) { - Ok(json_data) => json_data, - Err(error) => { - return self.send_server_error(ServerError::Serialization(error), None); - } - }; - - let mut headers = vec![ - (ARCH_UPSTREAM_HOST_HEADER, MODEL_SERVER_NAME), - (":method", "POST"), - (":path", "/zeroshot"), - (":authority", MODEL_SERVER_NAME), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]; - - if self.request_id.is_some() { - headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap())); - } - - let call_args = CallArgs::new( - ARCH_INTERNAL_CLUSTER_NAME, - "/zeroshot", - headers, - Some(json_data.as_bytes()), - vec![], - Duration::from_secs(5), - ); - callout_context.response_handler_type = ResponseHandlerType::ZeroShotIntent; - - if let Err(e) = self.http_call(call_args, callout_context) { - self.send_server_error(ServerError::HttpDispatch(e), None); - } - } - - fn hallucination_classification_resp_handler( - &mut self, - body: Vec, - callout_context: StreamCallContext, - ) { - let hallucination_response: HallucinationClassificationResponse = - match serde_json::from_slice(&body) { - Ok(hallucination_response) => hallucination_response, - Err(e) => { - return self.send_server_error(ServerError::Deserialization(e), None); - } - }; - let mut keys_with_low_score: Vec = Vec::new(); - for (key, value) in &hallucination_response.params_scores { - if *value < DEFAULT_HALLUCINATED_THRESHOLD { - debug!( - "hallucination detected: score for {} : {} is less than threshold {}", - key, value, DEFAULT_HALLUCINATED_THRESHOLD - ); - keys_with_low_score.push(key.clone().to_string()); - } - } - - if !keys_with_low_score.is_empty() { - let response = - "It seems I’m missing some information. Could you provide the following details: " - .to_string() - + &keys_with_low_score.join(", ") - + " ?"; - let message = Message { - role: SYSTEM_ROLE.to_string(), - content: Some(response), - model: Some(ARCH_FC_MODEL_NAME.to_string()), - tool_calls: None, - }; - - let chat_completion_response = ChatCompletionsResponse { - choices: vec![Choice { - message, - index: 0, - finish_reason: "done".to_string(), - }], - usage: None, - model: ARCH_FC_MODEL_NAME.to_string(), - metadata: None, - }; - - debug!("hallucination response: {:?}", chat_completion_response); - self.send_http_response( - StatusCode::OK.as_u16().into(), - vec![("Powered-By", "Katanemo")], - Some( - serde_json::to_string(&chat_completion_response) - .unwrap() - .as_bytes(), - ), - ); - } else { - // not a hallucination, resume the flow - self.schedule_api_call_request(callout_context); - } - } - - fn zero_shot_intent_detection_resp_handler( - &mut self, - body: Vec, - mut callout_context: StreamCallContext, - ) { - let zeroshot_intent_response: ZeroShotClassificationResponse = - match serde_json::from_slice(&body) { - Ok(zeroshot_response) => zeroshot_response, - Err(e) => { - return self.send_server_error(ServerError::Deserialization(e), None); - } - }; - - debug!("zeroshot intent response: {:?}", zeroshot_intent_response); - - let desc_emb_similarity_map: HashMap = callout_context - .similarity_scores - .clone() - .unwrap() - .into_iter() - .collect(); - - let pred_class_desc_emb_similarity = desc_emb_similarity_map - .get(&zeroshot_intent_response.predicted_class) - .unwrap(); - - let prompt_target_similarity_score = zeroshot_intent_response.predicted_class_score * 0.7 - + pred_class_desc_emb_similarity * 0.3; - - debug!( - "similarity score: {:.3}, intent score: {:.3}, description embedding score: {:.3}, prompt: {}", - prompt_target_similarity_score, - zeroshot_intent_response.predicted_class_score, - pred_class_desc_emb_similarity, - callout_context.user_message.as_ref().unwrap() - ); - - let prompt_target_name = zeroshot_intent_response.predicted_class.clone(); - - // Check to see who responded to user message. This will help us identify if control should be passed to Arch FC or not. - // If the last message was from Arch FC, then Arch FC is handling the conversation (possibly for parameter collection). - let mut arch_assistant = false; - let messages = &callout_context.request_body.messages; - if messages.len() >= 2 { - let latest_assistant_message = &messages[messages.len() - 2]; - if let Some(model) = latest_assistant_message.model.as_ref() { - if model.contains(ARCH_MODEL_PREFIX) { - arch_assistant = true; - } - } - } else { - info!("no assistant message found, probably first interaction"); - } - - // get prompt target similarity thresold from overrides - let prompt_target_intent_matching_threshold = match self.overrides.as_ref() { - Some(overrides) => match overrides.prompt_target_intent_matching_threshold { - Some(threshold) => threshold, - None => DEFAULT_PROMPT_TARGET_THRESHOLD, - }, - None => DEFAULT_PROMPT_TARGET_THRESHOLD, - }; - - // check to ensure that the prompt target similarity score is above the threshold - if prompt_target_similarity_score < prompt_target_intent_matching_threshold - || arch_assistant - { - debug!("intent score is low or arch assistant is handling the conversation"); - // if arch fc responded to the user message, then we don't need to check the similarity score - // it may be that arch fc is handling the conversation for parameter collection - if arch_assistant { - info!("arch assistant is handling the conversation"); - } else { - debug!("checking for default prompt target"); - if let Some(default_prompt_target) = self - .prompt_targets - .values() - .find(|pt| pt.default.unwrap_or(false)) - { - debug!("default prompt target found"); - let endpoint = default_prompt_target.endpoint.clone().unwrap(); - let upstream_path: String = endpoint.path.unwrap_or(String::from("/")); - - let upstream_endpoint = endpoint.name; - let mut params = HashMap::new(); - params.insert( - ARCH_MESSAGES_KEY.to_string(), - callout_context.request_body.messages.clone(), - ); - let arch_messages_json = serde_json::to_string(¶ms).unwrap(); - debug!("no prompt target found with similarity score above threshold, using default prompt target"); - - let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string(); - - let mut headers = vec![ - (":method", "POST"), - (ARCH_UPSTREAM_HOST_HEADER, &upstream_endpoint), - (":path", &upstream_path), - (":authority", &upstream_endpoint), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()), - ]; - - if self.request_id.is_some() { - headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap())); - } - - let call_args = CallArgs::new( - ARCH_INTERNAL_CLUSTER_NAME, - &upstream_path, - headers, - Some(arch_messages_json.as_bytes()), - vec![], - Duration::from_secs(5), - ); - callout_context.response_handler_type = ResponseHandlerType::DefaultTarget; - callout_context.prompt_target_name = Some(default_prompt_target.name.clone()); - - if let Err(e) = self.http_call(call_args, callout_context) { - return self.send_server_error( - ServerError::HttpDispatch(e), - Some(StatusCode::BAD_REQUEST), - ); - } - } - - self.resume_http_request(); - return; - } - } - - let prompt_target = match self.prompt_targets.get(&prompt_target_name) { - Some(prompt_target) => prompt_target.clone(), - None => { - return self.send_server_error( - ServerError::LogicError(format!( - "Prompt target not found: {prompt_target_name}" - )), - None, - ); - } - }; - - info!("prompt_target name: {:?}", prompt_target_name); - let mut chat_completion_tools: Vec = Vec::new(); - for pt in self.prompt_targets.values() { - if pt.default.unwrap_or_default() { - continue; - } - // only extract entity names - let properties: HashMap = match pt.parameters { - // Clone is unavoidable here because we don't want to move the values out of the prompt target struct. - Some(ref entities) => { - let mut properties: HashMap = HashMap::new(); - for entity in entities.iter() { - let param = FunctionParameter { - parameter_type: ParameterType::from( - entity.parameter_type.clone().unwrap_or("str".to_string()), - ), - description: entity.description.clone(), - required: entity.required, - enum_values: entity.enum_values.clone(), - default: entity.default.clone(), - }; - properties.insert(entity.name.clone(), param); - } - properties - } - None => HashMap::new(), - }; - let tools_parameters = FunctionParameters { properties }; - - chat_completion_tools.push({ - ChatCompletionTool { - tool_type: ToolType::Function, - function: FunctionDefinition { - name: pt.name.clone(), - description: pt.description.clone(), - parameters: tools_parameters, - }, - } - }); - } - - // archfc handler needs state so it can expand tool calls - let mut metadata = HashMap::new(); - metadata.insert( - ARCH_STATE_HEADER.to_string(), - serde_json::to_string(&self.arch_state).unwrap(), - ); - - let chat_completions = ChatCompletionsRequest { - model: GPT_35_TURBO.to_string(), - messages: callout_context.request_body.messages.clone(), - tools: Some(chat_completion_tools), - stream: false, - stream_options: None, - metadata: Some(metadata), - }; - - let msg_body = match serde_json::to_string(&chat_completions) { - Ok(msg_body) => { - debug!("arch_fc request body content: {}", msg_body); - msg_body - } - Err(e) => { - return self.send_server_error(ServerError::Serialization(e), None); - } - }; - - let timeout_str = ARCH_FC_REQUEST_TIMEOUT_MS.to_string(); - - let mut headers = vec![ - (":method", "POST"), - (ARCH_UPSTREAM_HOST_HEADER, ARC_FC_CLUSTER), - (":path", "/v1/chat/completions"), - (":authority", ARC_FC_CLUSTER), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", timeout_str.as_str()), - ]; - - if self.request_id.is_some() { - headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap())); - } - - let call_args = CallArgs::new( - ARCH_INTERNAL_CLUSTER_NAME, - "/v1/chat/completions", - headers, - Some(msg_body.as_bytes()), - vec![], - Duration::from_secs(5), - ); - callout_context.response_handler_type = ResponseHandlerType::FunctionResolver; - callout_context.prompt_target_name = Some(prompt_target.name); - - if let Err(e) = self.http_call(call_args, callout_context) { - self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST)); - } - } - - fn function_resolver_handler(&mut self, body: Vec, mut callout_context: StreamCallContext) { - let body_str = String::from_utf8(body).unwrap(); - debug!("arch <= app response body: {}", body_str); - - let arch_fc_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) { - Ok(arch_fc_response) => arch_fc_response, - Err(e) => { - return self.send_server_error(ServerError::Deserialization(e), None); - } - }; - - let model_resp = &arch_fc_response.choices[0]; - - if model_resp.message.tool_calls.is_none() - || model_resp.message.tool_calls.as_ref().unwrap().is_empty() - { - // This means that Arch FC did not have enough information to resolve the function call - // Arch FC probably responded with a message asking for more information. - // Let's send the response back to the user to initalize lightweight dialog for parameter collection - - //TODO: add resolver name to the response so the client can send the response back to the correct resolver - - return self.send_http_response( - StatusCode::OK.as_u16().into(), - vec![("Powered-By", "Katanemo")], - Some(body_str.as_bytes()), - ); - } - - let tool_calls = model_resp.message.tool_calls.as_ref().unwrap(); - - // TODO CO: pass nli check - // If hallucination, pass chat template to check parameters - - // extract all tool names - let tool_names: Vec = tool_calls - .iter() - .map(|tool_call| tool_call.function.name.clone()) - .collect(); - - debug!( - "call context similarity score: {:?}", - callout_context.similarity_scores - ); - //HACK: for now we only support one tool call, we will support multiple tool calls in the future - let mut tool_params = tool_calls[0].function.arguments.clone(); - tool_params.insert( - String::from(ARCH_MESSAGES_KEY), - serde_yaml::to_value(&callout_context.request_body.messages).unwrap(), - ); - - let tools_call_name = tool_calls[0].function.name.clone(); - let tool_params_json_str = serde_json::to_string(&tool_params).unwrap(); - let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone(); - callout_context.tool_calls = Some(tool_calls.clone()); - - debug!( - "prompt_target_name: {}, tool_name(s): {:?}", - prompt_target.name, tool_names - ); - debug!("tool_params: {}", tool_params_json_str); - - if model_resp.message.tool_calls.is_some() - && !model_resp.message.tool_calls.as_ref().unwrap().is_empty() - { - use serde_json::Value; - let v: Value = serde_json::from_str(&tool_params_json_str).unwrap(); - let tool_params_dict: HashMap = match v.as_object() { - Some(obj) => obj - .iter() - .map(|(key, value)| { - // Convert each value to a string, regardless of its type - (key.clone(), value.to_string()) - }) - .collect(), - None => HashMap::new(), // Return an empty HashMap if v is not an object - }; - - let messages = &callout_context.request_body.messages; - let mut arch_assistant = false; - let mut user_messages = Vec::new(); - - if messages.len() >= 2 { - let latest_assistant_message = &messages[messages.len() - 2]; - if let Some(model) = latest_assistant_message.model.as_ref() { - if model.starts_with(ARCH_MODEL_PREFIX) { - arch_assistant = true; - } - } - } - if arch_assistant { - for message in messages.iter() { - if let Some(model) = message.model.as_ref() { - if !model.starts_with(ARCH_MODEL_PREFIX) { - break; - } - } - if message.role == "user" { - if let Some(content) = &message.content { - user_messages.push(content.clone()); - } - } - } - } else if let Some(user_message) = callout_context.user_message.as_ref() { - user_messages.push(user_message.clone()); - } - let user_messages_str = user_messages.join(", "); - debug!("user messages: {}", user_messages_str); - - let hallucination_classification_request = HallucinationClassificationRequest { - prompt: user_messages_str, - model: String::from(DEFAULT_INTENT_MODEL), - parameters: tool_params_dict, - }; - - let json_data: String = - match serde_json::to_string(&hallucination_classification_request) { - Ok(json_data) => json_data, - Err(error) => { - return self.send_server_error(ServerError::Serialization(error), None); - } - }; - - let mut headers = vec![ - (ARCH_UPSTREAM_HOST_HEADER, MODEL_SERVER_NAME), - (":method", "POST"), - (":path", "/hallucination"), - (":authority", MODEL_SERVER_NAME), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]; - - if self.request_id.is_some() { - headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap())); - } - - let call_args = CallArgs::new( - ARCH_INTERNAL_CLUSTER_NAME, - "/hallucination", - headers, - Some(json_data.as_bytes()), - vec![], - Duration::from_secs(5), - ); - callout_context.response_handler_type = ResponseHandlerType::HallucinationDetect; - - if let Err(e) = self.http_call(call_args, callout_context) { - self.send_server_error(ServerError::HttpDispatch(e), None); - } - } else { - self.schedule_api_call_request(callout_context); - } - } - - fn schedule_api_call_request(&mut self, mut callout_context: StreamCallContext) { - let tools_call_name = callout_context.tool_calls.as_ref().unwrap()[0] - .function - .name - .clone(); - - let prompt_target = self.prompt_targets.get(&tools_call_name).unwrap().clone(); - - //HACK: for now we only support one tool call, we will support multiple tool calls in the future - let mut tool_params = callout_context.tool_calls.as_ref().unwrap()[0] - .function - .arguments - .clone(); - tool_params.insert( - String::from(ARCH_MESSAGES_KEY), - serde_yaml::to_value(&callout_context.request_body.messages).unwrap(), - ); - - let tool_params_json_str = serde_json::to_string(&tool_params).unwrap(); - - let endpoint = prompt_target.endpoint.unwrap(); - let path: String = endpoint.path.unwrap_or(String::from("/")); - - let mut headers = vec![ - (ARCH_UPSTREAM_HOST_HEADER, endpoint.name.as_str()), - (":method", "POST"), - (":path", &path), - (":authority", endpoint.name.as_str()), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ]; - - if self.request_id.is_some() { - headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap())); - } - - let call_args = CallArgs::new( - ARCH_INTERNAL_CLUSTER_NAME, - &path, - headers, - Some(tool_params_json_str.as_bytes()), - vec![], - Duration::from_secs(5), - ); - callout_context.upstream_cluster = Some(endpoint.name.clone()); - callout_context.upstream_cluster_path = Some(path.clone()); - callout_context.response_handler_type = ResponseHandlerType::FunctionCall; - - if let Err(e) = self.http_call(call_args, callout_context) { - self.send_server_error(ServerError::HttpDispatch(e), Some(StatusCode::BAD_REQUEST)); - } - } - - fn function_call_response_handler( - &mut self, - body: Vec, - mut callout_context: StreamCallContext, - ) { - if let Some(http_status) = self.get_http_call_response_header(":status") { - if http_status != StatusCode::OK.as_str() { - return self.send_server_error( - ServerError::Upstream { - authority: callout_context.upstream_cluster.unwrap(), - path: callout_context.upstream_cluster_path.unwrap(), - status: http_status, - }, - None, - ); - } - } else { - warn!("http status code not found in api response"); - } - let app_function_call_response_str: String = String::from_utf8(body).unwrap(); - self.tool_call_response = Some(app_function_call_response_str.clone()); - debug!( - "arch <= app response body: {}", - app_function_call_response_str - ); - let prompt_target_name = callout_context.prompt_target_name.unwrap(); - let prompt_target = self - .prompt_targets - .get(&prompt_target_name) - .unwrap() - .clone(); - - let mut messages: Vec = Vec::new(); - - // add system prompt - let system_prompt = match prompt_target.system_prompt.as_ref() { - None => self.system_prompt.as_ref().clone(), - Some(system_prompt) => Some(system_prompt.clone()), - }; - if system_prompt.is_some() { - let system_prompt_message = Message { - role: SYSTEM_ROLE.to_string(), - content: system_prompt, - model: None, - tool_calls: None, - }; - messages.push(system_prompt_message); - } - - messages.append(callout_context.request_body.messages.as_mut()); - - let user_message = match messages.pop() { - Some(user_message) => user_message, - None => { - return self.send_server_error( - ServerError::NoMessagesFound { - why: "no user messages found".to_string(), - }, - None, - ); - } - }; - - let final_prompt = format!( - "{}\ncontext: {}", - user_message.content.unwrap(), - app_function_call_response_str - ); - - // add original user prompt - messages.push({ - Message { - role: USER_ROLE.to_string(), - content: Some(final_prompt), - model: None, - tool_calls: None, - } - }); - - let chat_completions_request: ChatCompletionsRequest = ChatCompletionsRequest { - model: callout_context.request_body.model, - messages, - tools: None, - stream: callout_context.request_body.stream, - stream_options: callout_context.request_body.stream_options, - metadata: None, - }; - - let json_string = match serde_json::to_string(&chat_completions_request) { - Ok(json_string) => json_string, - Err(e) => { - return self.send_server_error(ServerError::Serialization(e), None); - } - }; - debug!("arch => openai request body: {}", json_string); - - // Tokenize and Ratelimit. - if let Err(e) = self.enforce_ratelimits(&chat_completions_request.model, &json_string) { - self.send_server_error( - ServerError::ExceededRatelimit(e), - Some(StatusCode::TOO_MANY_REQUESTS), - ); - self.metrics.ratelimited_rq.increment(1); - return; - } - - self.set_http_request_body(0, self.request_body_size, &json_string.into_bytes()); - self.resume_http_request(); - } - - fn enforce_ratelimits( - &mut self, - model: &str, - json_string: &str, - ) -> Result<(), ratelimit::Error> { - if let Some(selector) = self.ratelimit_selector.take() { - // Tokenize and Ratelimit. - if let Ok(token_count) = tokenizer::token_count(model, json_string) { - ratelimit::ratelimits(None).read().unwrap().check_limit( - model.to_owned(), - selector, - NonZero::new(token_count as u32).unwrap(), - )?; - } - } - Ok(()) - } - - fn arch_guard_handler(&mut self, body: Vec, callout_context: StreamCallContext) { - debug!("response received for arch guard"); - let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap(); - debug!("prompt_guard_resp: {:?}", prompt_guard_resp); - - if prompt_guard_resp.jailbreak_verdict.unwrap_or_default() { - //TODO: handle other scenarios like forward to error target - let msg = self - .prompt_guards - .jailbreak_on_exception_message() - .unwrap_or("refrain from discussing jailbreaking."); - return self.send_server_error( - ServerError::Jailbreak(String::from(msg)), - Some(StatusCode::BAD_REQUEST), - ); - } - - self.get_embeddings(callout_context); - } - - fn get_embeddings(&mut self, callout_context: StreamCallContext) { - let user_message = callout_context.user_message.unwrap(); - let get_embeddings_input = CreateEmbeddingRequest { - // Need to clone into input because user_message is used below. - input: Box::new(CreateEmbeddingRequestInput::String(user_message.clone())), - model: String::from(DEFAULT_EMBEDDING_MODEL), - encoding_format: None, - dimensions: None, - user: None, - }; - - let json_data: String = match serde_json::to_string(&get_embeddings_input) { - Ok(json_data) => json_data, - Err(error) => { - return self.send_server_error(ServerError::Deserialization(error), None); - } - }; - - let mut headers = vec![ - (ARCH_UPSTREAM_HOST_HEADER, MODEL_SERVER_NAME), - (":method", "POST"), - (":path", "/embeddings"), - (":authority", MODEL_SERVER_NAME), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]; - if self.request_id.is_some() { - headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap())); - } - let call_args = CallArgs::new( - ARCH_INTERNAL_CLUSTER_NAME, - "/embeddings", - headers, - Some(json_data.as_bytes()), - vec![], - Duration::from_secs(5), - ); - let call_context = StreamCallContext { - response_handler_type: ResponseHandlerType::GetEmbeddings, - user_message: Some(user_message), - prompt_target_name: None, - request_body: callout_context.request_body, - similarity_scores: None, - upstream_cluster: None, - upstream_cluster_path: None, - tool_calls: None, - }; - - if let Err(e) = self.http_call(call_args, call_context) { - self.send_server_error(ServerError::HttpDispatch(e), None); - } - } - - fn default_target_handler(&self, body: Vec, callout_context: StreamCallContext) { - let prompt_target = self - .prompt_targets - .get(callout_context.prompt_target_name.as_ref().unwrap()) - .unwrap() - .clone(); - debug!( - "response received for default target: {}", - prompt_target.name - ); - // check if the default target should be dispatched to the LLM provider - if !prompt_target.auto_llm_dispatch_on_response.unwrap_or(false) { - let default_target_response_str = String::from_utf8(body).unwrap(); - debug!( - "sending response back to developer: {}", - default_target_response_str - ); - self.send_http_response( - StatusCode::OK.as_u16().into(), - vec![("Powered-By", "Katanemo")], - Some(default_target_response_str.as_bytes()), - ); - // self.resume_http_request(); - return; - } - debug!("default_target: sending api response to default llm"); - let chat_completions_resp: ChatCompletionsResponse = match serde_json::from_slice(&body) { - Ok(chat_completions_resp) => chat_completions_resp, - Err(e) => { - return self.send_server_error(ServerError::Deserialization(e), None); - } - }; - let api_resp = chat_completions_resp.choices[0] - .message - .content - .as_ref() - .unwrap(); - let mut messages = callout_context.request_body.messages; - - // add system prompt - match prompt_target.system_prompt.as_ref() { - None => {} - Some(system_prompt) => { - let system_prompt_message = Message { - role: SYSTEM_ROLE.to_string(), - content: Some(system_prompt.clone()), - model: None, - tool_calls: None, - }; - messages.push(system_prompt_message); - } - } - - messages.push(Message { - role: USER_ROLE.to_string(), - content: Some(api_resp.clone()), - model: None, - tool_calls: None, - }); - let chat_completion_request = ChatCompletionsRequest { - model: GPT_35_TURBO.to_string(), - messages, - tools: None, - stream: callout_context.request_body.stream, - stream_options: callout_context.request_body.stream_options, - metadata: None, - }; - let json_resp = serde_json::to_string(&chat_completion_request).unwrap(); - debug!("sending response back to default llm: {}", json_resp); - self.set_http_request_body(0, self.request_body_size, json_resp.as_bytes()); - self.resume_http_request(); - } -} - -// HttpContext is the trait that allows the Rust code to interact with HTTP objects. -impl HttpContext for StreamContext { - // Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto - // the lifecycle of the http request and response. - fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action { - self.select_llm_provider(); - self.add_routing_header(); - if let Err(error) = self.modify_auth_headers() { - self.send_server_error(error, Some(StatusCode::BAD_REQUEST)); - } - self.delete_content_length_header(); - self.save_ratelimit_header(); - - self.is_chat_completions_request = - self.get_http_request_header(":path").unwrap_or_default() == CHAT_COMPLETIONS_PATH; - - debug!( - "S[{}] req_headers={:?}", - self.context_id, - self.get_http_request_headers() - ); - - self.request_id = self.get_http_request_header(REQUEST_ID_HEADER); - - Action::Continue - } - - fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action { - // Let the client send the gateway all the data before sending to the LLM_provider. - // TODO: consider a streaming API. - if !end_of_stream { - return Action::Pause; - } - - if body_size == 0 { - return Action::Continue; - } - - self.request_body_size = body_size; - - // Deserialize body into spec. - // Currently OpenAI API. - let mut deserialized_body: ChatCompletionsRequest = - match self.get_http_request_body(0, body_size) { - Some(body_bytes) => match serde_json::from_slice(&body_bytes) { - Ok(deserialized) => deserialized, - Err(e) => { - self.send_server_error( - ServerError::Deserialization(e), - Some(StatusCode::BAD_REQUEST), - ); - return Action::Pause; - } - }, - None => { - self.send_server_error( - ServerError::LogicError(format!( - "Failed to obtain body bytes even though body_size is {}", - body_size - )), - None, - ); - return Action::Pause; - } - }; - self.is_chat_completions_request = true; - - if self.mode == GatewayMode::Llm { - debug!("llm gateway mode, skipping over all prompt targets"); - - // remove metadata from the request body - deserialized_body.metadata = None; - // delete model key from message array - for message in deserialized_body.messages.iter_mut() { - message.model = None; - } - deserialized_body - .model - .clone_from(&self.llm_provider.as_ref().unwrap().model); - let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap(); - - // enforce ratelimits - if let Err(e) = - self.enforce_ratelimits(&deserialized_body.model, &chat_completion_request_str) - { - self.send_server_error( - ServerError::ExceededRatelimit(e), - Some(StatusCode::TOO_MANY_REQUESTS), - ); - self.metrics.ratelimited_rq.increment(1); - return Action::Continue; - } - - debug!( - "arch => {:?}, body: {}", - deserialized_body.model, chat_completion_request_str - ); - self.set_http_request_body(0, body_size, chat_completion_request_str.as_bytes()); - return Action::Continue; - } - - self.arch_state = match deserialized_body.metadata { - Some(ref metadata) => { - if metadata.contains_key(ARCH_STATE_HEADER) { - let arch_state_str = metadata[ARCH_STATE_HEADER].clone(); - let arch_state: Vec = serde_json::from_str(&arch_state_str).unwrap(); - Some(arch_state) - } else { - None - } - } - None => None, - }; - - // Set the model based on the chosen LLM Provider - deserialized_body.model = String::from(&self.llm_provider().model); - - self.streaming_response = deserialized_body.stream; - if deserialized_body.stream && deserialized_body.stream_options.is_none() { - deserialized_body.stream_options = Some(StreamOptions { - include_usage: true, - }); - } - - let last_user_prompt = match deserialized_body - .messages - .iter() - .filter(|msg| msg.role == USER_ROLE) - .last() - { - Some(content) => content, - None => { - warn!("No messages in the request body"); - return Action::Continue; - } - }; - - self.user_prompt = Some(last_user_prompt.clone()); - - let user_message_str = self.user_prompt.as_ref().unwrap().content.clone(); - - let prompt_guard_jailbreak_task = self - .prompt_guards - .input_guards - .contains_key(&common::configuration::GuardType::Jailbreak); - - self.chat_completions_request = Some(deserialized_body); - - if !prompt_guard_jailbreak_task { - debug!("Missing input guard. Making inline call to retrieve"); - let callout_context = StreamCallContext { - response_handler_type: ResponseHandlerType::ArchGuard, - user_message: user_message_str.clone(), - prompt_target_name: None, - request_body: self.chat_completions_request.as_ref().unwrap().clone(), - similarity_scores: None, - upstream_cluster: None, - upstream_cluster_path: None, - tool_calls: None, - }; - self.get_embeddings(callout_context); - return Action::Pause; - } - - let get_prompt_guards_request = PromptGuardRequest { - input: self - .user_prompt - .as_ref() - .unwrap() - .content - .as_ref() - .unwrap() - .clone(), - task: PromptGuardTask::Jailbreak, - }; - - let json_data: String = match serde_json::to_string(&get_prompt_guards_request) { - Ok(json_data) => json_data, - Err(error) => { - self.send_server_error(ServerError::Serialization(error), None); - return Action::Pause; - } - }; - - let mut headers = vec![ - (ARCH_UPSTREAM_HOST_HEADER, MODEL_SERVER_NAME), - (":method", "POST"), - (":path", "/guard"), - (":authority", MODEL_SERVER_NAME), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]; - - if self.request_id.is_some() { - headers.push((REQUEST_ID_HEADER, self.request_id.as_ref().unwrap())); - } - - let call_args = CallArgs::new( - ARCH_INTERNAL_CLUSTER_NAME, - "/guard", - headers, - Some(json_data.as_bytes()), - vec![], - Duration::from_secs(5), - ); - let call_context = StreamCallContext { - response_handler_type: ResponseHandlerType::ArchGuard, - user_message: self.user_prompt.as_ref().unwrap().content.clone(), - prompt_target_name: None, - request_body: self.chat_completions_request.as_ref().unwrap().clone(), - similarity_scores: None, - upstream_cluster: None, - upstream_cluster_path: None, - tool_calls: None, - }; - - if let Err(e) = self.http_call(call_args, call_context) { - self.send_server_error(ServerError::HttpDispatch(e), None); - } - - Action::Pause - } - - fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action { - debug!( - "recv [S={}] bytes={} end_stream={}", - self.context_id, body_size, end_of_stream - ); - - if !self.is_chat_completions_request { - if let Some(body_str) = self - .get_http_response_body(0, body_size) - .and_then(|bytes| String::from_utf8(bytes).ok()) - { - debug!("recv [S={}] body_str={}", self.context_id, body_str); - } - return Action::Continue; - } - - if !end_of_stream { - return Action::Pause; - } - - let body = self - .get_http_response_body(0, body_size) - .expect("cant get response body"); - - if self.streaming_response { - let body_str = String::from_utf8(body).expect("body is not utf-8"); - debug!("streaming response"); - let chat_completions_data = match body_str.split_once("data: ") { - Some((_, chat_completions_data)) => chat_completions_data, - None => { - self.send_server_error( - ServerError::LogicError(String::from("parsing error in streaming data")), - None, - ); - return Action::Pause; - } - }; - - let chat_completions_chunk_response: ChatCompletionChunkResponse = - match serde_json::from_str(chat_completions_data) { - Ok(de) => de, - Err(_) => { - if chat_completions_data != "[NONE]" { - self.send_server_error( - ServerError::LogicError(String::from( - "error in streaming response", - )), - None, - ); - return Action::Continue; - } - return Action::Continue; - } - }; - - if let Some(content) = chat_completions_chunk_response - .choices - .first() - .unwrap() - .delta - .content - .as_ref() - { - let model = &chat_completions_chunk_response.model; - let token_count = tokenizer::token_count(model, content).unwrap_or(0); - self.response_tokens += token_count; - } - } else { - debug!("non streaming response"); - let chat_completions_response: ChatCompletionsResponse = - match serde_json::from_slice(&body) { - Ok(de) => de, - Err(e) => { - debug!("invalid response: {}", String::from_utf8_lossy(&body)); - self.send_server_error(ServerError::Deserialization(e), None); - return Action::Pause; - } - }; - - if chat_completions_response.usage.is_some() { - self.response_tokens += chat_completions_response - .usage - .as_ref() - .unwrap() - .completion_tokens; - } - - if let Some(tool_calls) = self.tool_calls.as_ref() { - if !tool_calls.is_empty() { - if self.arch_state.is_none() { - self.arch_state = Some(Vec::new()); - } - - // compute sha hash from message history - let mut hasher = Sha256::new(); - let prompts: Vec = self - .chat_completions_request - .as_ref() - .unwrap() - .messages - .iter() - .filter(|msg| msg.role == USER_ROLE) - .map(|msg| msg.content.clone().unwrap()) - .collect(); - let prompts_merged = prompts.join("#.#"); - hasher.update(prompts_merged.clone()); - let hash_key = hasher.finalize(); - // conver hash to hex string - let hash_key_str = format!("{:x}", hash_key); - debug!("hash key: {}, prompts: {}", hash_key_str, prompts_merged); - - // create new tool call state - let tool_call_state = ToolCallState { - key: hash_key_str, - message: self.user_prompt.clone(), - tool_call: tool_calls[0].function.clone(), - tool_response: self.tool_call_response.clone().unwrap(), - }; - - // push tool call state to arch state - self.arch_state - .as_mut() - .unwrap() - .push(ArchState::ToolCall(vec![tool_call_state])); - - let mut data: Value = serde_json::from_slice(&body).unwrap(); - // use serde::Value to manipulate the json object and ensure that we don't lose any data - if let Value::Object(ref mut map) = data { - // serialize arch state and add to metadata - let arch_state_str = serde_json::to_string(&self.arch_state).unwrap(); - debug!("arch_state: {}", arch_state_str); - let metadata = map - .entry("metadata") - .or_insert(Value::Object(serde_json::Map::new())); - metadata.as_object_mut().unwrap().insert( - ARCH_STATE_HEADER.to_string(), - serde_json::Value::String(arch_state_str), - ); - - let data_serialized = serde_json::to_string(&data).unwrap(); - debug!("arch => user: {}", data_serialized); - self.set_http_response_body(0, body_size, data_serialized.as_bytes()); - }; - } - } - } - - debug!( - "recv [S={}] total_tokens={} end_stream={}", - self.context_id, self.response_tokens, end_of_stream - ); - - // TODO:: ratelimit based on response tokens. - Action::Continue - } -} - -impl Context for StreamContext { - fn on_http_call_response( - &mut self, - token_id: u32, - _num_headers: usize, - body_size: usize, - _num_trailers: usize, - ) { - let callout_context = self - .callouts - .get_mut() - .remove(&token_id) - .expect("invalid token_id"); - self.metrics.active_http_calls.increment(-1); - - if let Some(body) = self.get_http_call_response_body(0, body_size) { - match callout_context.response_handler_type { - ResponseHandlerType::GetEmbeddings => { - self.embeddings_handler(body, callout_context) - } - ResponseHandlerType::ZeroShotIntent => { - self.zero_shot_intent_detection_resp_handler(body, callout_context) - } - ResponseHandlerType::HallucinationDetect => { - self.hallucination_classification_resp_handler(body, callout_context) - } - ResponseHandlerType::FunctionResolver => { - self.function_resolver_handler(body, callout_context) - } - ResponseHandlerType::FunctionCall => { - self.function_call_response_handler(body, callout_context) - } - ResponseHandlerType::ArchGuard => self.arch_guard_handler(body, callout_context), - ResponseHandlerType::DefaultTarget => { - self.default_target_handler(body, callout_context) - } - } - } else { - self.send_server_error( - ServerError::LogicError(String::from("No response body in inline HTTP request")), - None, - ); - } - } -} - -impl Client for StreamContext { - type CallContext = StreamCallContext; - - fn callouts(&self) -> &RefCell> { - &self.callouts - } - - fn active_http_calls(&self) -> &Gauge { - &self.metrics.active_http_calls - } -} diff --git a/crates/prompt_gateway/tests/integration.rs b/crates/prompt_gateway/tests/integration.rs index 2e9e984e..04168305 100644 --- a/crates/prompt_gateway/tests/integration.rs +++ b/crates/prompt_gateway/tests/integration.rs @@ -28,39 +28,11 @@ fn wasm_module() -> String { fn request_headers_expectations(module: &mut Tester, http_context: i32) { module .call_proxy_on_request_headers(http_context, 0, false) - .expect_get_header_map_value( - Some(MapType::HttpRequestHeaders), - Some("x-arch-llm-provider-hint"), - ) - .returning(Some("default")) - .expect_log(Some(LogLevel::Debug), None) - .expect_add_header_map_value( - Some(MapType::HttpRequestHeaders), - Some("x-arch-upstream"), - Some("arch_llm_listener"), - ) - .expect_add_header_map_value( - Some(MapType::HttpRequestHeaders), - Some("x-arch-llm-provider"), - Some("open-ai-gpt-4"), - ) - .expect_replace_header_map_value( - Some(MapType::HttpRequestHeaders), - Some("Authorization"), - Some("Bearer secret_key"), - ) .expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length")) - .expect_get_header_map_value( - Some(MapType::HttpRequestHeaders), - Some("x-arch-ratelimit-selector"), - ) - .returning(Some("selector-key")) - .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("selector-key")) - .returning(Some("selector-value")) - .expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders)) - .returning(None) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path")) .returning(Some("/v1/chat/completions")) + .expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders)) + .returning(None) .expect_log(Some(LogLevel::Debug), None) .expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-request-id")) .returning(None) @@ -102,6 +74,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) { .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .returning(Some(chat_completions_request_body)) // The actual call is not important in this test, we just need to grab the token_id + .expect_log(Some(LogLevel::Debug), None) .expect_http_call( Some("arch_internal"), Some(vec![ @@ -259,7 +232,6 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 { module .call_proxy_on_context_create(filter_context, 0) .expect_metric_creation(MetricType::Gauge, "active_http_calls") - .expect_metric_creation(MetricType::Counter, "ratelimited_rq") .execute_and_expect(ReturnType::None) .unwrap(); @@ -455,6 +427,7 @@ fn successful_request_to_open_ai_chat_completions() { .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .returning(Some(chat_completions_request_body)) .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) .expect_http_call(Some("arch_internal"), None, None, None, None) .returning(Some(4)) .expect_metric_increment("active_http_calls", 1) @@ -514,6 +487,7 @@ fn bad_request_to_open_ai_chat_completions() { .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody)) .returning(Some(incomplete_chat_completions_request_body)) .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) .expect_send_local_response( Some(StatusCode::BAD_REQUEST.as_u16().into()), None, @@ -526,146 +500,7 @@ fn bad_request_to_open_ai_chat_completions() { #[test] #[serial] -fn request_ratelimited() { - let args = tester::MockSettings { - wasm_path: wasm_module(), - quiet: false, - allow_unexpected: false, - }; - let mut module = tester::mock(args).unwrap(); - - module - .call_start() - .execute_and_expect(ReturnType::None) - .unwrap(); - - // Setup Filter - let filter_context = setup_filter(&mut module, default_config()); - - // Setup HTTP Stream - let http_context = 2; - - normal_flow(&mut module, filter_context, http_context); - - let arch_fc_resp = ChatCompletionsResponse { - usage: Some(Usage { - completion_tokens: 0, - }), - choices: vec![Choice { - finish_reason: "test".to_string(), - index: 0, - message: Message { - role: "system".to_string(), - content: None, - tool_calls: Some(vec![ToolCall { - id: String::from("test"), - tool_type: ToolType::Function, - function: FunctionCallDetail { - name: String::from("weather_forecast"), - arguments: HashMap::from([( - String::from("city"), - Value::String(String::from("seattle")), - )]), - }, - }]), - model: None, - }, - }], - model: String::from("test"), - metadata: None, - }; - - let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap(); - module - .call_proxy_on_http_call_response(http_context, 4, 0, arch_fc_resp_str.len() as i32, 0) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&arch_fc_resp_str)) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - ("x-arch-upstream", "model_server"), - (":method", "POST"), - (":path", "/hallucination"), - (":authority", "model_server"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ("x-envoy-upstream-rq-timeout-ms", "60000"), - ]), - None, - None, - None, - ) - .returning(Some(5)) - .expect_metric_increment("active_http_calls", 1) - .execute_and_expect(ReturnType::None) - .unwrap(); - - let hallucatination_body = HallucinationClassificationResponse { - params_scores: HashMap::from([("city".to_string(), 0.99)]), - model: "nli-model".to_string(), - }; - - let body_text = serde_json::to_string(&hallucatination_body).unwrap(); - - module - .call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&body_text)) - .expect_log(Some(LogLevel::Debug), None) - .expect_http_call( - Some("arch_internal"), - Some(vec![ - ("x-arch-upstream", "api_server"), - (":method", "POST"), - (":path", "/weather"), - (":authority", "api_server"), - ("content-type", "application/json"), - ("x-envoy-max-retries", "3"), - ]), - None, - None, - None, - ) - .returning(Some(6)) - .expect_metric_increment("active_http_calls", 1) - .execute_and_expect(ReturnType::None) - .unwrap(); - - let body_text = String::from("test body"); - module - .call_proxy_on_http_call_response(http_context, 6, 0, body_text.len() as i32, 0) - .expect_metric_increment("active_http_calls", -1) - .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody)) - .returning(Some(&body_text)) - .expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status")) - .returning(Some("200")) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_send_local_response( - Some(StatusCode::TOO_MANY_REQUESTS.as_u16().into()), - None, - None, - None, - ) - .expect_metric_increment("ratelimited_rq", 1) - .execute_and_expect(ReturnType::None) - .unwrap(); -} - -#[test] -#[serial] -fn request_not_ratelimited() { +fn request_to_llm_gateway() { let args = tester::MockSettings { wasm_path: wasm_module(), quiet: false, @@ -797,9 +632,44 @@ fn request_not_ratelimited() { .returning(Some("200")) .expect_log(Some(LogLevel::Debug), None) .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) - .expect_log(Some(LogLevel::Debug), None) .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None) .execute_and_expect(ReturnType::None) .unwrap(); + + let chat_completion_response = ChatCompletionsResponse { + usage: Some(Usage { + completion_tokens: 0, + }), + choices: vec![Choice { + finish_reason: "test".to_string(), + index: 0, + message: Message { + role: "assistant".to_string(), + content: Some("hello from fake llm gateway".to_string()), + model: None, + tool_calls: None, + }, + }], + model: String::from("test"), + metadata: None, + }; + + let chat_completion_response_str = serde_json::to_string(&chat_completion_response).unwrap(); + module + .call_proxy_on_response_body( + http_context, + chat_completion_response_str.len() as i32, + true, + ) + .expect_get_buffer_bytes(Some(BufferType::HttpResponseBody)) + .returning(Some(chat_completion_response_str.as_str())) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_log(Some(LogLevel::Debug), None) + .expect_set_buffer_bytes(Some(BufferType::HttpResponseBody), None) + .expect_log(Some(LogLevel::Debug), None) + .execute_and_expect(ReturnType::Action(Action::Continue)) + .unwrap(); } diff --git a/gateway.code-workspace b/gateway.code-workspace index ed15406b..cc1b4efc 100644 --- a/gateway.code-workspace +++ b/gateway.code-workspace @@ -14,7 +14,7 @@ }, { "name": "llm_gateway", - "path": "crates/prompt_gateway" + "path": "crates/llm_gateway" }, { "name": "arch/tools",