Split arch wasm filter code into prompt and llm gateway filters (#190)

This commit is contained in:
Adil Hafeez 2024-10-17 10:16:40 -07:00 committed by GitHub
parent 8e54ac20d8
commit 21e7fe2cef
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 696 additions and 2801 deletions

View file

@ -1,13 +1,13 @@
use filter_context::FilterContext;
use prompt_filter_context::PromptGatewayFilterContext;
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
mod filter_context;
mod stream_context;
mod prompt_filter_context;
mod prompt_stream_context;
proxy_wasm::main! {{
proxy_wasm::set_log_level(LogLevel::Trace);
proxy_wasm::set_root_context(|_| -> Box<dyn RootContext> {
Box::new(FilterContext::new())
Box::new(PromptGatewayFilterContext::new())
});
}}

View file

@ -1,4 +1,4 @@
use crate::stream_context::StreamContext;
use crate::prompt_stream_context::PromptStreamContext;
use common::common_types::EmbeddingType;
use common::configuration::{Configuration, GatewayMode, Overrides, PromptGuards, PromptTarget};
use common::consts::ARCH_INTERNAL_CLUSTER_NAME;
@ -11,8 +11,6 @@ use common::embeddings::{
use common::http::CallArgs;
use common::http::Client;
use common::llm_providers::LlmProviders;
use common::ratelimit;
use common::stats::Counter;
use common::stats::Gauge;
use common::stats::IncrementingMetric;
use log::debug;
@ -27,14 +25,12 @@ use std::time::Duration;
#[derive(Copy, Clone, Debug)]
pub struct WasmMetrics {
pub active_http_calls: Gauge,
pub ratelimited_rq: Counter,
}
impl WasmMetrics {
fn new() -> WasmMetrics {
WasmMetrics {
active_http_calls: Gauge::new(String::from("active_http_calls")),
ratelimited_rq: Counter::new(String::from("ratelimited_rq")),
}
}
}
@ -49,7 +45,7 @@ pub struct FilterCallContext {
}
#[derive(Debug)]
pub struct FilterContext {
pub struct PromptGatewayFilterContext {
metrics: Rc<WasmMetrics>,
// callouts stores token_id to request mapping that we use during #on_http_call_response to match the response to the request.
callouts: RefCell<HashMap<u32, FilterCallContext>>,
@ -63,9 +59,9 @@ pub struct FilterContext {
temp_embeddings_store: EmbeddingsStore,
}
impl FilterContext {
pub fn new() -> FilterContext {
FilterContext {
impl PromptGatewayFilterContext {
pub fn new() -> PromptGatewayFilterContext {
PromptGatewayFilterContext {
callouts: RefCell::new(HashMap::new()),
metrics: Rc::new(WasmMetrics::new()),
system_prompt: Rc::new(None),
@ -121,7 +117,7 @@ impl FilterContext {
Duration::from_secs(60),
);
let call_context = crate::filter_context::FilterCallContext {
let call_context = crate::prompt_filter_context::FilterCallContext {
prompt_target_name: String::from(prompt_target_name),
embedding_type,
};
@ -198,7 +194,7 @@ impl FilterContext {
}
}
impl Client for FilterContext {
impl Client for PromptGatewayFilterContext {
type CallContext = FilterCallContext;
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {
@ -210,7 +206,7 @@ impl Client for FilterContext {
}
}
impl Context for FilterContext {
impl Context for PromptGatewayFilterContext {
fn on_http_call_response(
&mut self,
token_id: u32,
@ -239,7 +235,7 @@ impl Context for FilterContext {
}
// RootContext allows the Rust code to reach into the Envoy Config
impl RootContext for FilterContext {
impl RootContext for PromptGatewayFilterContext {
fn on_configure(&mut self, _: usize) -> bool {
let config_bytes = self
.get_plugin_configuration()
@ -260,8 +256,6 @@ impl RootContext for FilterContext {
self.prompt_targets = Rc::new(prompt_targets);
self.mode = config.mode.unwrap_or_default();
ratelimit::ratelimits(Some(config.ratelimits.unwrap_or_default()));
if let Some(prompt_guards) = config.prompt_guards {
self.prompt_guards = Rc::new(prompt_guards)
}
@ -285,20 +279,14 @@ impl RootContext for FilterContext {
GatewayMode::Llm => None,
GatewayMode::Prompt => Some(Rc::clone(self.embeddings_store.as_ref().unwrap())),
};
Some(Box::new(StreamContext::new(
Some(Box::new(PromptStreamContext::new(
context_id,
Rc::clone(&self.metrics),
Rc::clone(&self.system_prompt),
Rc::clone(&self.prompt_targets),
Rc::clone(&self.prompt_guards),
Rc::clone(&self.overrides),
Rc::clone(
self.llm_providers
.as_ref()
.expect("LLM Providers must exist when Streams are being created"),
),
embedding_store,
self.mode.clone(),
)))
}

View file

@ -1,33 +1,28 @@
use crate::filter_context::{EmbeddingsStore, WasmMetrics};
use crate::prompt_filter_context::{EmbeddingsStore, WasmMetrics};
use acap::cos;
use common::common_types::open_ai::{
ArchState, ChatCompletionChunkResponse, ChatCompletionTool, ChatCompletionsRequest,
ChatCompletionsResponse, Choice, FunctionDefinition, FunctionParameter, FunctionParameters,
Message, ParameterType, StreamOptions, ToolCall, ToolCallState, ToolType,
ArchState, ChatCompletionTool, ChatCompletionsRequest, ChatCompletionsResponse, Choice,
FunctionDefinition, FunctionParameter, FunctionParameters, Message, ParameterType,
StreamOptions, ToolCall, ToolCallState, ToolType,
};
use common::common_types::{
EmbeddingType, HallucinationClassificationRequest, HallucinationClassificationResponse,
PromptGuardRequest, PromptGuardResponse, PromptGuardTask, ZeroShotClassificationRequest,
ZeroShotClassificationResponse,
};
use common::configuration::{GatewayMode, LlmProvider};
use common::configuration::{Overrides, PromptGuards, PromptTarget};
use common::consts::{
ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME,
ARCH_LLM_UPSTREAM_LISTENER, ARCH_MESSAGES_KEY, ARCH_MODEL_PREFIX, ARCH_PROVIDER_HINT_HEADER,
ARCH_ROUTING_HEADER, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ARC_FC_CLUSTER,
ARCH_FC_MODEL_NAME, ARCH_FC_REQUEST_TIMEOUT_MS, ARCH_INTERNAL_CLUSTER_NAME, ARCH_MESSAGES_KEY,
ARCH_MODEL_PREFIX, ARCH_STATE_HEADER, ARCH_UPSTREAM_HOST_HEADER, ARC_FC_CLUSTER,
CHAT_COMPLETIONS_PATH, DEFAULT_EMBEDDING_MODEL, DEFAULT_HALLUCINATED_THRESHOLD,
DEFAULT_INTENT_MODEL, DEFAULT_PROMPT_TARGET_THRESHOLD, GPT_35_TURBO, MODEL_SERVER_NAME,
RATELIMIT_SELECTOR_HEADER_KEY, REQUEST_ID_HEADER, SYSTEM_ROLE, USER_ROLE,
REQUEST_ID_HEADER, SYSTEM_ROLE, USER_ROLE,
};
use common::embeddings::{
CreateEmbeddingRequest, CreateEmbeddingRequestInput, CreateEmbeddingResponse,
};
use common::http::{CallArgs, Client, ClientError};
use common::llm_providers::LlmProviders;
use common::ratelimit::Header;
use common::stats::Gauge;
use common::{ratelimit, routing, tokenizer};
use http::StatusCode;
use log::{debug, info, warn};
use proxy_wasm::traits::*;
@ -36,7 +31,6 @@ use serde_json::Value;
use sha2::{Digest, Sha256};
use std::cell::RefCell;
use std::collections::HashMap;
use std::num::NonZero;
use std::rc::Rc;
use std::time::Duration;
@ -81,17 +75,13 @@ pub enum ServerError {
path: String,
status: String,
},
#[error(transparent)]
ExceededRatelimit(ratelimit::Error),
#[error("jailbreak detected: {0}")]
Jailbreak(String),
#[error("{why}")]
BadRequest { why: String },
#[error("{why}")]
NoMessagesFound { why: String },
}
pub struct StreamContext {
pub struct PromptStreamContext {
context_id: u32,
metrics: Rc<WasmMetrics>,
system_prompt: Rc<Option<String>>,
@ -103,20 +93,16 @@ pub struct StreamContext {
tool_call_response: Option<String>,
arch_state: Option<Vec<ArchState>>,
request_body_size: usize,
ratelimit_selector: Option<Header>,
streaming_response: bool,
user_prompt: Option<Message>,
response_tokens: usize,
is_chat_completions_request: bool,
chat_completions_request: Option<ChatCompletionsRequest>,
prompt_guards: Rc<PromptGuards>,
llm_providers: Rc<LlmProviders>,
llm_provider: Option<Rc<LlmProvider>>,
request_id: Option<String>,
mode: GatewayMode,
}
impl StreamContext {
impl PromptStreamContext {
#[allow(clippy::too_many_arguments)]
pub fn new(
context_id: u32,
@ -125,11 +111,9 @@ impl StreamContext {
prompt_targets: Rc<HashMap<String, PromptTarget>>,
prompt_guards: Rc<PromptGuards>,
overrides: Rc<Option<Overrides>>,
llm_providers: Rc<LlmProviders>,
embeddings_store: Option<Rc<EmbeddingsStore>>,
mode: GatewayMode,
) -> Self {
StreamContext {
PromptStreamContext {
context_id,
metrics,
system_prompt,
@ -141,75 +125,21 @@ impl StreamContext {
tool_call_response: None,
arch_state: None,
request_body_size: 0,
ratelimit_selector: None,
streaming_response: false,
user_prompt: None,
response_tokens: 0,
is_chat_completions_request: false,
llm_providers,
llm_provider: None,
prompt_guards,
overrides,
request_id: None,
mode,
}
}
fn llm_provider(&self) -> &LlmProvider {
self.llm_provider
.as_ref()
.expect("the provider should be set when asked for it")
}
fn embeddings_store(&self) -> &EmbeddingsStore {
self.embeddings_store
.as_ref()
.expect("embeddings store is not set")
}
fn select_llm_provider(&mut self) {
let provider_hint = self
.get_http_request_header(ARCH_PROVIDER_HINT_HEADER)
.map(|provider_name| provider_name.into());
debug!("llm provider hint: {:?}", provider_hint);
self.llm_provider = Some(routing::get_llm_provider(
&self.llm_providers,
provider_hint,
));
debug!("selected llm: {}", self.llm_provider.as_ref().unwrap().name);
}
fn add_routing_header(&mut self) {
match self.mode {
GatewayMode::Prompt => {
// in prompt gateway mode, we need to route to llm upstream listener
self.add_http_request_header(ARCH_UPSTREAM_HOST_HEADER, ARCH_LLM_UPSTREAM_LISTENER);
}
_ => {
self.add_http_request_header(ARCH_ROUTING_HEADER, &self.llm_provider().name);
}
}
}
fn modify_auth_headers(&mut self) -> Result<(), ServerError> {
let llm_provider_api_key_value =
self.llm_provider()
.access_key
.as_ref()
.ok_or(ServerError::BadRequest {
why: format!(
"No access key configured for selected LLM Provider \"{}\"",
self.llm_provider()
),
})?;
let authorization_header_value = format!("Bearer {}", llm_provider_api_key_value);
self.set_http_request_header("Authorization", Some(&authorization_header_value));
Ok(())
}
fn delete_content_length_header(&mut self) {
// Remove the Content-Length header because further body manipulations in the gateway logic will invalidate it.
// Server's generally throw away requests whose body length do not match the Content-Length header.
@ -218,15 +148,6 @@ impl StreamContext {
self.set_http_request_header("content-length", None);
}
fn save_ratelimit_header(&mut self) {
self.ratelimit_selector = self
.get_http_request_header(RATELIMIT_SELECTOR_HEADER_KEY)
.and_then(|key| {
self.get_http_request_header(&key)
.map(|value| Header { key, value })
});
}
fn send_server_error(&self, error: ServerError, override_status_code: Option<StatusCode>) {
debug!("server error occurred: {}", error);
self.send_http_response(
@ -682,6 +603,7 @@ impl StreamContext {
}
let tool_calls = model_resp.message.tool_calls.as_ref().unwrap();
self.tool_calls = Some(tool_calls.clone());
// TODO CO: pass nli check
// If hallucination, pass chat template to check parameters
@ -954,40 +876,12 @@ impl StreamContext {
return self.send_server_error(ServerError::Serialization(e), None);
}
};
debug!("arch => openai request body: {}", json_string);
// Tokenize and Ratelimit.
if let Err(e) = self.enforce_ratelimits(&chat_completions_request.model, &json_string) {
self.send_server_error(
ServerError::ExceededRatelimit(e),
Some(StatusCode::TOO_MANY_REQUESTS),
);
self.metrics.ratelimited_rq.increment(1);
return;
}
debug!("arch => upstream llm request body: {}", json_string);
self.set_http_request_body(0, self.request_body_size, &json_string.into_bytes());
self.resume_http_request();
}
fn enforce_ratelimits(
&mut self,
model: &str,
json_string: &str,
) -> Result<(), ratelimit::Error> {
if let Some(selector) = self.ratelimit_selector.take() {
// Tokenize and Ratelimit.
if let Ok(token_count) = tokenizer::token_count(model, json_string) {
ratelimit::ratelimits(None).read().unwrap().check_limit(
model.to_owned(),
selector,
NonZero::new(token_count as u32).unwrap(),
)?;
}
}
Ok(())
}
fn arch_guard_handler(&mut self, body: Vec<u8>, callout_context: StreamCallContext) {
debug!("response received for arch guard");
let prompt_guard_resp: PromptGuardResponse = serde_json::from_slice(&body).unwrap();
@ -1137,23 +1031,17 @@ impl StreamContext {
}
// HttpContext is the trait that allows the Rust code to interact with HTTP objects.
impl HttpContext for StreamContext {
impl HttpContext for PromptStreamContext {
// Envoy's HTTP model is event driven. The WASM ABI has given implementors events to hook onto
// the lifecycle of the http request and response.
fn on_http_request_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
self.select_llm_provider();
self.add_routing_header();
if let Err(error) = self.modify_auth_headers() {
self.send_server_error(error, Some(StatusCode::BAD_REQUEST));
}
self.delete_content_length_header();
self.save_ratelimit_header();
self.is_chat_completions_request =
self.get_http_request_header(":path").unwrap_or_default() == CHAT_COMPLETIONS_PATH;
debug!(
"S[{}] req_headers={:?}",
"on_http_request_headers S[{}] req_headers={:?}",
self.context_id,
self.get_http_request_headers()
);
@ -1176,6 +1064,11 @@ impl HttpContext for StreamContext {
self.request_body_size = body_size;
debug!(
"on_http_request_body S[{}] body_size={}",
self.context_id, body_size
);
// Deserialize body into spec.
// Currently OpenAI API.
let mut deserialized_body: ChatCompletionsRequest =
@ -1203,40 +1096,6 @@ impl HttpContext for StreamContext {
};
self.is_chat_completions_request = true;
if self.mode == GatewayMode::Llm {
debug!("llm gateway mode, skipping over all prompt targets");
// remove metadata from the request body
deserialized_body.metadata = None;
// delete model key from message array
for message in deserialized_body.messages.iter_mut() {
message.model = None;
}
deserialized_body
.model
.clone_from(&self.llm_provider.as_ref().unwrap().model);
let chat_completion_request_str = serde_json::to_string(&deserialized_body).unwrap();
// enforce ratelimits
if let Err(e) =
self.enforce_ratelimits(&deserialized_body.model, &chat_completion_request_str)
{
self.send_server_error(
ServerError::ExceededRatelimit(e),
Some(StatusCode::TOO_MANY_REQUESTS),
);
self.metrics.ratelimited_rq.increment(1);
return Action::Continue;
}
debug!(
"arch => {:?}, body: {}",
deserialized_body.model, chat_completion_request_str
);
self.set_http_request_body(0, body_size, chat_completion_request_str.as_bytes());
return Action::Continue;
}
self.arch_state = match deserialized_body.metadata {
Some(ref metadata) => {
if metadata.contains_key(ARCH_STATE_HEADER) {
@ -1250,9 +1109,6 @@ impl HttpContext for StreamContext {
None => None,
};
// Set the model based on the chosen LLM Provider
deserialized_body.model = String::from(&self.llm_provider().model);
self.streaming_response = deserialized_body.stream;
if deserialized_body.stream && deserialized_body.stream_options.is_none() {
deserialized_body.stream_options = Some(StreamOptions {
@ -1285,7 +1141,7 @@ impl HttpContext for StreamContext {
self.chat_completions_request = Some(deserialized_body);
if !prompt_guard_jailbreak_task {
debug!("Missing input guard. Making inline call to retrieve");
debug!("Missing input guard. Making inline call to retrieve embeddings");
let callout_context = StreamCallContext {
response_handler_type: ResponseHandlerType::ArchGuard,
user_message: user_message_str.clone(),
@ -1360,6 +1216,18 @@ impl HttpContext for StreamContext {
Action::Pause
}
fn on_http_response_headers(&mut self, _num_headers: usize, _end_of_stream: bool) -> Action {
debug!(
"on_http_response_headers recv [S={}] headers={:?}",
self.context_id,
self.get_http_response_headers()
);
// delete content-lenght header let envoy calculate it, because we modify the response body
// that would result in a different content-length
self.set_http_response_header("content-length", None);
Action::Continue
}
fn on_http_response_body(&mut self, body_size: usize, end_of_stream: bool) -> Action {
debug!(
"recv [S={}] bytes={} end_stream={}",
@ -1385,48 +1253,7 @@ impl HttpContext for StreamContext {
.expect("cant get response body");
if self.streaming_response {
let body_str = String::from_utf8(body).expect("body is not utf-8");
debug!("streaming response");
let chat_completions_data = match body_str.split_once("data: ") {
Some((_, chat_completions_data)) => chat_completions_data,
None => {
self.send_server_error(
ServerError::LogicError(String::from("parsing error in streaming data")),
None,
);
return Action::Pause;
}
};
let chat_completions_chunk_response: ChatCompletionChunkResponse =
match serde_json::from_str(chat_completions_data) {
Ok(de) => de,
Err(_) => {
if chat_completions_data != "[NONE]" {
self.send_server_error(
ServerError::LogicError(String::from(
"error in streaming response",
)),
None,
);
return Action::Continue;
}
return Action::Continue;
}
};
if let Some(content) = chat_completions_chunk_response
.choices
.first()
.unwrap()
.delta
.content
.as_ref()
{
let model = &chat_completions_chunk_response.model;
let token_count = tokenizer::token_count(model, content).unwrap_or(0);
self.response_tokens += token_count;
}
} else {
debug!("non streaming response");
let chat_completions_response: ChatCompletionsResponse =
@ -1494,6 +1321,9 @@ impl HttpContext for StreamContext {
let metadata = map
.entry("metadata")
.or_insert(Value::Object(serde_json::Map::new()));
if metadata == &Value::Null {
*metadata = Value::Object(serde_json::Map::new());
}
metadata.as_object_mut().unwrap().insert(
ARCH_STATE_HEADER.to_string(),
serde_json::Value::String(arch_state_str),
@ -1512,12 +1342,11 @@ impl HttpContext for StreamContext {
self.context_id, self.response_tokens, end_of_stream
);
// TODO:: ratelimit based on response tokens.
Action::Continue
}
}
impl Context for StreamContext {
impl Context for PromptStreamContext {
fn on_http_call_response(
&mut self,
token_id: u32,
@ -1563,7 +1392,7 @@ impl Context for StreamContext {
}
}
impl Client for StreamContext {
impl Client for PromptStreamContext {
type CallContext = StreamCallContext;
fn callouts(&self) -> &RefCell<HashMap<u32, Self::CallContext>> {

View file

@ -28,39 +28,11 @@ fn wasm_module() -> String {
fn request_headers_expectations(module: &mut Tester, http_context: i32) {
module
.call_proxy_on_request_headers(http_context, 0, false)
.expect_get_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("x-arch-llm-provider-hint"),
)
.returning(Some("default"))
.expect_log(Some(LogLevel::Debug), None)
.expect_add_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("x-arch-upstream"),
Some("arch_llm_listener"),
)
.expect_add_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("x-arch-llm-provider"),
Some("open-ai-gpt-4"),
)
.expect_replace_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("Authorization"),
Some("Bearer secret_key"),
)
.expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
.expect_get_header_map_value(
Some(MapType::HttpRequestHeaders),
Some("x-arch-ratelimit-selector"),
)
.returning(Some("selector-key"))
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("selector-key"))
.returning(Some("selector-value"))
.expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders))
.returning(None)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
.returning(Some("/v1/chat/completions"))
.expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders))
.returning(None)
.expect_log(Some(LogLevel::Debug), None)
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-request-id"))
.returning(None)
@ -102,6 +74,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
// The actual call is not important in this test, we just need to grab the token_id
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
@ -259,7 +232,6 @@ fn setup_filter(module: &mut Tester, config: &str) -> i32 {
module
.call_proxy_on_context_create(filter_context, 0)
.expect_metric_creation(MetricType::Gauge, "active_http_calls")
.expect_metric_creation(MetricType::Counter, "ratelimited_rq")
.execute_and_expect(ReturnType::None)
.unwrap();
@ -455,6 +427,7 @@ fn successful_request_to_open_ai_chat_completions() {
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(chat_completions_request_body))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(Some("arch_internal"), None, None, None, None)
.returning(Some(4))
.expect_metric_increment("active_http_calls", 1)
@ -514,6 +487,7 @@ fn bad_request_to_open_ai_chat_completions() {
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
.returning(Some(incomplete_chat_completions_request_body))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_send_local_response(
Some(StatusCode::BAD_REQUEST.as_u16().into()),
None,
@ -526,146 +500,7 @@ fn bad_request_to_open_ai_chat_completions() {
#[test]
#[serial]
fn request_ratelimited() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
allow_unexpected: false,
};
let mut module = tester::mock(args).unwrap();
module
.call_start()
.execute_and_expect(ReturnType::None)
.unwrap();
// Setup Filter
let filter_context = setup_filter(&mut module, default_config());
// Setup HTTP Stream
let http_context = 2;
normal_flow(&mut module, filter_context, http_context);
let arch_fc_resp = ChatCompletionsResponse {
usage: Some(Usage {
completion_tokens: 0,
}),
choices: vec![Choice {
finish_reason: "test".to_string(),
index: 0,
message: Message {
role: "system".to_string(),
content: None,
tool_calls: Some(vec![ToolCall {
id: String::from("test"),
tool_type: ToolType::Function,
function: FunctionCallDetail {
name: String::from("weather_forecast"),
arguments: HashMap::from([(
String::from("city"),
Value::String(String::from("seattle")),
)]),
},
}]),
model: None,
},
}],
model: String::from("test"),
metadata: None,
};
let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
module
.call_proxy_on_http_call_response(http_context, 4, 0, arch_fc_resp_str.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&arch_fc_resp_str))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
("x-arch-upstream", "model_server"),
(":method", "POST"),
(":path", "/hallucination"),
(":authority", "model_server"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
]),
None,
None,
None,
)
.returning(Some(5))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
let hallucatination_body = HallucinationClassificationResponse {
params_scores: HashMap::from([("city".to_string(), 0.99)]),
model: "nli-model".to_string(),
};
let body_text = serde_json::to_string(&hallucatination_body).unwrap();
module
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text))
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(
Some("arch_internal"),
Some(vec![
("x-arch-upstream", "api_server"),
(":method", "POST"),
(":path", "/weather"),
(":authority", "api_server"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
]),
None,
None,
None,
)
.returning(Some(6))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
let body_text = String::from("test body");
module
.call_proxy_on_http_call_response(http_context, 6, 0, body_text.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text))
.expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status"))
.returning(Some("200"))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_send_local_response(
Some(StatusCode::TOO_MANY_REQUESTS.as_u16().into()),
None,
None,
None,
)
.expect_metric_increment("ratelimited_rq", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
}
#[test]
#[serial]
fn request_not_ratelimited() {
fn request_to_llm_gateway() {
let args = tester::MockSettings {
wasm_path: wasm_module(),
quiet: false,
@ -797,9 +632,44 @@ fn request_not_ratelimited() {
.returning(Some("200"))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.execute_and_expect(ReturnType::None)
.unwrap();
let chat_completion_response = ChatCompletionsResponse {
usage: Some(Usage {
completion_tokens: 0,
}),
choices: vec![Choice {
finish_reason: "test".to_string(),
index: 0,
message: Message {
role: "assistant".to_string(),
content: Some("hello from fake llm gateway".to_string()),
model: None,
tool_calls: None,
},
}],
model: String::from("test"),
metadata: None,
};
let chat_completion_response_str = serde_json::to_string(&chat_completion_response).unwrap();
module
.call_proxy_on_response_body(
http_context,
chat_completion_response_str.len() as i32,
true,
)
.expect_get_buffer_bytes(Some(BufferType::HttpResponseBody))
.returning(Some(chat_completion_response_str.as_str()))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_set_buffer_bytes(Some(BufferType::HttpResponseBody), None)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::Action(Action::Continue))
.unwrap();
}