// Mirror of https://github.com/katanemo/plano.git
// Synced 2026-04-26 01:06:25 +02:00 · 803 lines · 27 KiB · Rust
use http::StatusCode;
|
|
use proxy_wasm_test_framework::tester::{self, Tester};
|
|
use proxy_wasm_test_framework::types::{
|
|
Action, BufferType, LogLevel, MapType, MetricType, ReturnType,
|
|
};
|
|
use public_types::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage};
|
|
use public_types::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType};
|
|
use public_types::common_types::{HallucinationClassificationResponse, PromptGuardResponse};
|
|
use public_types::embeddings::{
|
|
create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage,
|
|
Embedding,
|
|
};
|
|
use public_types::{common_types::ZeroShotClassificationResponse, configuration::Configuration};
|
|
use serde_yaml::Value;
|
|
use serial_test::serial;
|
|
use std::collections::HashMap;
|
|
use std::path::Path;
|
|
|
|
/// Returns the relative path of the compiled gateway WASM module as a `String`.
///
/// # Panics
///
/// Panics with a build hint when the release artifact does not exist at
/// `target/wasm32-wasi/release/intelligent_prompt_gateway.wasm` (resolved
/// against the current working directory).
fn wasm_module() -> String {
    let wasm_file = Path::new("target/wasm32-wasi/release/intelligent_prompt_gateway.wasm");
    assert!(
        wasm_file.exists(),
        "Run `cargo build --release --target=wasm32-wasi` first"
    );
    wasm_file
        .to_str()
        .expect("path is built from a UTF-8 string literal")
        .to_string()
}
|
|
|
|
fn request_headers_expectations(module: &mut Tester, http_context: i32) {
|
|
module
|
|
.call_proxy_on_request_headers(http_context, 0, false)
|
|
.expect_get_header_map_value(
|
|
Some(MapType::HttpRequestHeaders),
|
|
Some("x-arch-llm-provider-hint"),
|
|
)
|
|
.returning(Some("default"))
|
|
.expect_log(Some(LogLevel::Debug), None)
|
|
.expect_add_header_map_value(
|
|
Some(MapType::HttpRequestHeaders),
|
|
Some("x-arch-upstream"),
|
|
Some("arch_llm_listener"),
|
|
)
|
|
.expect_add_header_map_value(
|
|
Some(MapType::HttpRequestHeaders),
|
|
Some("x-arch-llm-provider"),
|
|
Some("open-ai-gpt-4"),
|
|
)
|
|
.expect_replace_header_map_value(
|
|
Some(MapType::HttpRequestHeaders),
|
|
Some("Authorization"),
|
|
Some("Bearer secret_key"),
|
|
)
|
|
.expect_remove_header_map_value(Some(MapType::HttpRequestHeaders), Some("content-length"))
|
|
.expect_get_header_map_value(
|
|
Some(MapType::HttpRequestHeaders),
|
|
Some("x-arch-ratelimit-selector"),
|
|
)
|
|
.returning(Some("selector-key"))
|
|
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("selector-key"))
|
|
.returning(Some("selector-value"))
|
|
.expect_get_header_map_pairs(Some(MapType::HttpRequestHeaders))
|
|
.returning(None)
|
|
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some(":path"))
|
|
.returning(Some("/v1/chat/completions"))
|
|
.expect_log(Some(LogLevel::Debug), None)
|
|
.expect_get_header_map_value(Some(MapType::HttpRequestHeaders), Some("x-request-id"))
|
|
.returning(None)
|
|
.execute_and_expect(ReturnType::Action(Action::Continue))
|
|
.unwrap();
|
|
}
|
|
|
|
fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
|
module
|
|
.call_proxy_on_context_create(http_context, filter_context)
|
|
.expect_log(Some(LogLevel::Debug), None)
|
|
.execute_and_expect(ReturnType::None)
|
|
.unwrap();
|
|
|
|
request_headers_expectations(module, http_context);
|
|
|
|
// Request Body
|
|
let chat_completions_request_body = "\
|
|
{\
|
|
\"messages\": [\
|
|
{\
|
|
\"role\": \"system\",\
|
|
\"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
|
|
},\
|
|
{\
|
|
\"role\": \"user\",\
|
|
\"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
|
|
}\
|
|
],\
|
|
\"model\": \"gpt-4\"\
|
|
}";
|
|
|
|
module
|
|
.call_proxy_on_request_body(
|
|
http_context,
|
|
chat_completions_request_body.len() as i32,
|
|
true,
|
|
)
|
|
.expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
|
|
.returning(Some(chat_completions_request_body))
|
|
// The actual call is not important in this test, we just need to grab the token_id
|
|
.expect_http_call(
|
|
Some("arch_internal"),
|
|
Some(vec![
|
|
("x-arch-upstream", "model_server"),
|
|
(":method", "POST"),
|
|
(":path", "/guard"),
|
|
(":authority", "model_server"),
|
|
("content-type", "application/json"),
|
|
("x-envoy-max-retries", "3"),
|
|
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
|
]),
|
|
None,
|
|
None,
|
|
None,
|
|
)
|
|
.returning(Some(1))
|
|
.expect_log(Some(LogLevel::Debug), None)
|
|
.expect_metric_increment("active_http_calls", 1)
|
|
.execute_and_expect(ReturnType::Action(Action::Pause))
|
|
.unwrap();
|
|
|
|
let prompt_guard_response = PromptGuardResponse {
|
|
toxic_prob: None,
|
|
toxic_verdict: None,
|
|
jailbreak_prob: None,
|
|
jailbreak_verdict: None,
|
|
};
|
|
let prompt_guard_response_buffer = serde_json::to_string(&prompt_guard_response).unwrap();
|
|
module
|
|
.call_proxy_on_http_call_response(
|
|
http_context,
|
|
1,
|
|
0,
|
|
prompt_guard_response_buffer.len() as i32,
|
|
0,
|
|
)
|
|
.expect_metric_increment("active_http_calls", -1)
|
|
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
|
.returning(Some(&prompt_guard_response_buffer))
|
|
.expect_log(Some(LogLevel::Debug), None)
|
|
.expect_log(Some(LogLevel::Debug), None)
|
|
.expect_http_call(
|
|
Some("arch_internal"),
|
|
Some(vec![
|
|
("x-arch-upstream", "model_server"),
|
|
(":method", "POST"),
|
|
(":path", "/embeddings"),
|
|
(":authority", "model_server"),
|
|
("content-type", "application/json"),
|
|
("x-envoy-max-retries", "3"),
|
|
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
|
]),
|
|
None,
|
|
None,
|
|
None,
|
|
)
|
|
.returning(Some(2))
|
|
.expect_metric_increment("active_http_calls", 1)
|
|
.expect_log(Some(LogLevel::Debug), None)
|
|
.execute_and_expect(ReturnType::None)
|
|
.unwrap();
|
|
|
|
let embedding_response = CreateEmbeddingResponse {
|
|
data: vec![Embedding {
|
|
index: 0,
|
|
embedding: vec![],
|
|
object: embedding::Object::default(),
|
|
}],
|
|
model: String::from("test"),
|
|
object: create_embedding_response::Object::default(),
|
|
usage: Box::new(CreateEmbeddingResponseUsage::new(0, 0)),
|
|
};
|
|
let embeddings_response_buffer = serde_json::to_string(&embedding_response).unwrap();
|
|
module
|
|
.call_proxy_on_http_call_response(
|
|
http_context,
|
|
2,
|
|
0,
|
|
embeddings_response_buffer.len() as i32,
|
|
0,
|
|
)
|
|
.expect_metric_increment("active_http_calls", -1)
|
|
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
|
.returning(Some(&embeddings_response_buffer))
|
|
.expect_log(Some(LogLevel::Debug), None)
|
|
.expect_log(Some(LogLevel::Debug), None)
|
|
.expect_http_call(
|
|
Some("arch_internal"),
|
|
Some(vec![
|
|
("x-arch-upstream", "model_server"),
|
|
(":method", "POST"),
|
|
(":path", "/zeroshot"),
|
|
(":authority", "model_server"),
|
|
("content-type", "application/json"),
|
|
("x-envoy-max-retries", "3"),
|
|
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
|
]),
|
|
None,
|
|
None,
|
|
None,
|
|
)
|
|
.returning(Some(3))
|
|
.expect_metric_increment("active_http_calls", 1)
|
|
.expect_log(Some(LogLevel::Debug), None)
|
|
.execute_and_expect(ReturnType::None)
|
|
.unwrap();
|
|
|
|
let zero_shot_response = ZeroShotClassificationResponse {
|
|
predicted_class: "weather_forecast".to_string(),
|
|
predicted_class_score: 0.1,
|
|
scores: HashMap::new(),
|
|
model: "test-model".to_string(),
|
|
};
|
|
let zeroshot_intent_detection_buffer = serde_json::to_string(&zero_shot_response).unwrap();
|
|
module
|
|
.call_proxy_on_http_call_response(
|
|
http_context,
|
|
3,
|
|
0,
|
|
zeroshot_intent_detection_buffer.len() as i32,
|
|
0,
|
|
)
|
|
.expect_metric_increment("active_http_calls", -1)
|
|
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
|
.returning(Some(&zeroshot_intent_detection_buffer))
|
|
.expect_log(Some(LogLevel::Debug), None)
|
|
.expect_log(Some(LogLevel::Debug), None)
|
|
.expect_log(Some(LogLevel::Info), None)
|
|
.expect_http_call(
|
|
Some("arch_internal"),
|
|
Some(vec![
|
|
(":method", "POST"),
|
|
("x-arch-upstream", "arch_fc"),
|
|
(":path", "/v1/chat/completions"),
|
|
(":authority", "arch_fc"),
|
|
("content-type", "application/json"),
|
|
("x-envoy-max-retries", "3"),
|
|
("x-envoy-upstream-rq-timeout-ms", "120000"),
|
|
]),
|
|
None,
|
|
None,
|
|
None,
|
|
)
|
|
.returning(Some(4))
|
|
.expect_log(Some(LogLevel::Debug), None)
|
|
.expect_log(Some(LogLevel::Debug), None)
|
|
.expect_metric_increment("active_http_calls", 1)
|
|
.execute_and_expect(ReturnType::None)
|
|
.unwrap();
|
|
}
|
|
|
|
fn setup_filter(module: &mut Tester, config: &str) -> i32 {
|
|
let filter_context = 1;
|
|
|
|
module
|
|
.call_proxy_on_context_create(filter_context, 0)
|
|
.expect_metric_creation(MetricType::Gauge, "active_http_calls")
|
|
.expect_metric_creation(MetricType::Counter, "ratelimited_rq")
|
|
.execute_and_expect(ReturnType::None)
|
|
.unwrap();
|
|
|
|
module
|
|
.call_proxy_on_configure(filter_context, config.len() as i32)
|
|
.expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
|
|
.returning(Some(config))
|
|
.execute_and_expect(ReturnType::Bool(true))
|
|
.unwrap();
|
|
|
|
module
|
|
.call_proxy_on_tick(filter_context)
|
|
.expect_log(Some(LogLevel::Debug), None)
|
|
.expect_log(Some(LogLevel::Debug), None)
|
|
.expect_http_call(
|
|
Some("arch_internal"),
|
|
Some(vec![
|
|
("x-arch-upstream", "model_server"),
|
|
(":method", "POST"),
|
|
(":path", "/embeddings"),
|
|
(":authority", "model_server"),
|
|
("content-type", "application/json"),
|
|
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
|
]),
|
|
None,
|
|
None,
|
|
None,
|
|
)
|
|
.returning(Some(101))
|
|
.expect_metric_increment("active_http_calls", 1)
|
|
.expect_set_tick_period_millis(Some(0))
|
|
.execute_and_expect(ReturnType::None)
|
|
.unwrap();
|
|
|
|
let embedding_response = CreateEmbeddingResponse {
|
|
data: vec![Embedding {
|
|
embedding: vec![],
|
|
index: 0,
|
|
object: embedding::Object::default(),
|
|
}],
|
|
model: String::from("test"),
|
|
object: create_embedding_response::Object::default(),
|
|
usage: Box::new(CreateEmbeddingResponseUsage {
|
|
prompt_tokens: 0,
|
|
total_tokens: 0,
|
|
}),
|
|
};
|
|
let embedding_response_str = serde_json::to_string(&embedding_response).unwrap();
|
|
module
|
|
.call_proxy_on_http_call_response(
|
|
filter_context,
|
|
101,
|
|
0,
|
|
embedding_response_str.len() as i32,
|
|
0,
|
|
)
|
|
.expect_log(
|
|
Some(LogLevel::Debug),
|
|
Some(
|
|
format!(
|
|
"filter_context: on_http_call_response called with token_id: {:?}",
|
|
101
|
|
)
|
|
.as_str(),
|
|
),
|
|
)
|
|
.expect_metric_increment("active_http_calls", -1)
|
|
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
|
.returning(Some(&embedding_response_str))
|
|
.expect_log(Some(LogLevel::Debug), None)
|
|
.execute_and_expect(ReturnType::None)
|
|
.unwrap();
|
|
|
|
filter_context
|
|
}
|
|
|
|
/// Returns the YAML plugin configuration shared by the tests in this file.
///
/// The document declares one listener, one endpoint (`api_server`), one
/// default LLM provider (`open-ai-gpt-4`), a jailbreak input guard, a single
/// `weather_forecast` prompt target, and a `gpt-4` rate limit of 1 token per
/// minute keyed on the `selector-key` header.
///
/// YAML is indentation-sensitive, so the nesting below is part of the value:
/// the string must stay flush with column 0 and keep its two-space levels.
fn default_config() -> &'static str {
    r#"
version: "0.1-beta"

listener:
  address: 0.0.0.0
  port: 10000
  message_format: huggingface
  connect_timeout: 0.005s

endpoints:
  api_server:
    endpoint: api_server:80
    connect_timeout: 0.005s

llm_providers:
  - name: open-ai-gpt-4
    provider: openai
    access_key: secret_key
    model: gpt-4
    default: true

overrides:
  # confidence threshold for prompt target intent matching
  prompt_target_intent_matching_threshold: 0.6

system_prompt: |
  You are a helpful assistant.

prompt_guards:
  input_guards:
    jailbreak:
      on_exception:
        message: "Looks like you're curious about my abilities, but I can only provide assistance within my programmed parameters."

prompt_targets:
  - name: weather_forecast
    description: This function provides realtime weather forecast information for a given city.
    parameters:
      - name: city
        required: true
        description: The city for which the weather forecast is requested.
      - name: days
        description: The number of days for which the weather forecast is requested.
      - name: units
        description: The units in which the weather forecast is requested.
    endpoint:
      name: api_server
      path: /weather
    system_prompt: |
      You are a helpful weather forecaster. Use weater data that is provided to you. Please following following guidelines when responding to user queries:
      - Use farenheight for temperature
      - Use miles per hour for wind speed

ratelimits:
  - model: gpt-4
    selector:
      key: selector-key
      value: selector-value
    limit:
      tokens: 1
      unit: minute
"#
}
|
|
|
|
/// A well-formed chat-completions request passes the header stage and, once
/// its body is read, triggers a callout on `arch_internal` and pauses the
/// request.
#[test]
#[serial]
fn successful_request_to_open_ai_chat_completions() {
    let args = tester::MockSettings {
        wasm_path: wasm_module(),
        quiet: false,
        allow_unexpected: false,
    };
    let mut module = tester::mock(args).unwrap();

    module
        .call_start()
        .execute_and_expect(ReturnType::None)
        .unwrap();

    // Setup Filter
    let filter_context = setup_filter(&mut module, default_config());

    // Setup HTTP Stream
    let http_context = 2;

    module
        .call_proxy_on_context_create(http_context, filter_context)
        .expect_log(Some(LogLevel::Debug), None)
        .execute_and_expect(ReturnType::None)
        .unwrap();

    request_headers_expectations(&mut module, http_context);

    // Request Body
    let chat_completions_request_body = "\
        {\
            \"messages\": [\
            {\
                \"role\": \"system\",\
                \"content\": \"You are a poetic assistant, skilled in explaining complex programming concepts with creative flair.\"\
            },\
            {\
                \"role\": \"user\",\
                \"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
            }\
            ],\
            \"model\": \"gpt-4\"\
        }";

    module
        .call_proxy_on_request_body(
            http_context,
            chat_completions_request_body.len() as i32,
            true,
        )
        .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
        .returning(Some(chat_completions_request_body))
        .expect_log(Some(LogLevel::Debug), None)
        // Only the upstream cluster is pinned here; headers and body of the
        // callout are left unconstrained.
        .expect_http_call(Some("arch_internal"), None, None, None, None)
        .returning(Some(4))
        .expect_metric_increment("active_http_calls", 1)
        .execute_and_expect(ReturnType::Action(Action::Pause))
        .unwrap();
}
|
|
|
|
/// A request body that is not a valid chat-completions payload is answered
/// locally with `400 Bad Request` and the request is paused.
#[test]
#[serial]
fn bad_request_to_open_ai_chat_completions() {
    let args = tester::MockSettings {
        wasm_path: wasm_module(),
        quiet: false,
        allow_unexpected: false,
    };
    let mut module = tester::mock(args).unwrap();

    module
        .call_start()
        .execute_and_expect(ReturnType::None)
        .unwrap();

    // Setup Filter
    let filter_context = setup_filter(&mut module, default_config());

    // Setup HTTP Stream
    let http_context = 2;

    module
        .call_proxy_on_context_create(http_context, filter_context)
        .expect_log(Some(LogLevel::Debug), None)
        .execute_and_expect(ReturnType::None)
        .unwrap();

    request_headers_expectations(&mut module, http_context);

    // Request Body
    // Deliberately malformed: the system message has no `content` (and keeps
    // a trailing comma), and the top-level `model` field is absent.
    let incomplete_chat_completions_request_body = "\
        {\
            \"messages\": [\
            {\
                \"role\": \"system\",\
            },\
            {\
                \"role\": \"user\",\
                \"content\": \"Compose a poem that explains the concept of recursion in programming.\"\
            }\
            ]\
        }";

    module
        .call_proxy_on_request_body(
            http_context,
            incomplete_chat_completions_request_body.len() as i32,
            true,
        )
        .expect_get_buffer_bytes(Some(BufferType::HttpRequestBody))
        .returning(Some(incomplete_chat_completions_request_body))
        .expect_log(Some(LogLevel::Debug), None)
        .expect_send_local_response(
            Some(StatusCode::BAD_REQUEST.as_u16().into()),
            None,
            None,
            None,
        )
        .execute_and_expect(ReturnType::Action(Action::Pause))
        .unwrap();
}
|
|
|
|
/// With the default configuration (1 token per minute for `gpt-4`), a request
/// that completes the function-calling flow is rejected locally with
/// `429 Too Many Requests` and `ratelimited_rq` is incremented.
#[test]
#[serial]
fn request_ratelimited() {
    let args = tester::MockSettings {
        wasm_path: wasm_module(),
        quiet: false,
        allow_unexpected: false,
    };
    let mut module = tester::mock(args).unwrap();

    module
        .call_start()
        .execute_and_expect(ReturnType::None)
        .unwrap();

    // Setup Filter
    let filter_context = setup_filter(&mut module, default_config());

    // Setup HTTP Stream
    let http_context = 2;

    normal_flow(&mut module, filter_context, http_context);

    // Function-calling response selecting `weather_forecast(city = "seattle")`.
    let arch_fc_resp = ChatCompletionsResponse {
        usage: Some(Usage {
            completion_tokens: 0,
        }),
        choices: vec![Choice {
            finish_reason: "test".to_string(),
            index: 0,
            message: Message {
                role: "system".to_string(),
                content: None,
                tool_calls: Some(vec![ToolCall {
                    id: String::from("test"),
                    tool_type: ToolType::Function,
                    function: FunctionCallDetail {
                        name: String::from("weather_forecast"),
                        arguments: HashMap::from([(
                            String::from("city"),
                            Value::String(String::from("seattle")),
                        )]),
                    },
                }]),
                model: None,
            },
        }],
        model: String::from("test"),
        metadata: None,
    };

    let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
    // Answer callout 4 (left pending by `normal_flow`); the hallucination
    // check is expected next.
    module
        .call_proxy_on_http_call_response(http_context, 4, 0, arch_fc_resp_str.len() as i32, 0)
        .expect_metric_increment("active_http_calls", -1)
        .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
        .returning(Some(&arch_fc_resp_str))
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_http_call(
            Some("arch_internal"),
            Some(vec![
                ("x-arch-upstream", "model_server"),
                (":method", "POST"),
                (":path", "/hallucination"),
                (":authority", "model_server"),
                ("content-type", "application/json"),
                ("x-envoy-max-retries", "3"),
                ("x-envoy-upstream-rq-timeout-ms", "60000"),
            ]),
            None,
            None,
            None,
        )
        .returning(Some(5))
        .expect_metric_increment("active_http_calls", 1)
        .execute_and_expect(ReturnType::None)
        .unwrap();

    let hallucination_body = HallucinationClassificationResponse {
        params_scores: HashMap::from([("city".to_string(), 0.99)]),
        model: "nli-model".to_string(),
    };

    let body_text = serde_json::to_string(&hallucination_body).unwrap();

    // Answer callout 5; the prompt target's endpoint (`/weather` on
    // `api_server`) is expected to be called next.
    module
        .call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
        .expect_metric_increment("active_http_calls", -1)
        .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
        .returning(Some(&body_text))
        .expect_log(Some(LogLevel::Debug), None)
        .expect_http_call(
            Some("arch_internal"),
            Some(vec![
                ("x-arch-upstream", "api_server"),
                (":method", "POST"),
                (":path", "/weather"),
                (":authority", "api_server"),
                ("content-type", "application/json"),
                ("x-envoy-max-retries", "3"),
            ]),
            None,
            None,
            None,
        )
        .returning(Some(6))
        .expect_metric_increment("active_http_calls", 1)
        .execute_and_expect(ReturnType::None)
        .unwrap();

    // Answer callout 6 with status 200; the filter is expected to reject the
    // request with 429 and bump the rate-limit counter.
    let body_text = String::from("test body");
    module
        .call_proxy_on_http_call_response(http_context, 6, 0, body_text.len() as i32, 0)
        .expect_metric_increment("active_http_calls", -1)
        .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
        .returning(Some(&body_text))
        .expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status"))
        .returning(Some("200"))
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_send_local_response(
            Some(StatusCode::TOO_MANY_REQUESTS.as_u16().into()),
            None,
            None,
            None,
        )
        .expect_metric_increment("ratelimited_rq", 1)
        .execute_and_expect(ReturnType::None)
        .unwrap();
}
|
|
|
|
/// With the rate limit raised by 1000 tokens, the same flow as
/// `request_ratelimited` is not rejected: after the prompt target responds,
/// the filter rewrites the HTTP request body instead of sending a local
/// response.
#[test]
#[serial]
fn request_not_ratelimited() {
    let args = tester::MockSettings {
        wasm_path: wasm_module(),
        quiet: false,
        allow_unexpected: false,
    };
    let mut module = tester::mock(args).unwrap();

    module
        .call_start()
        .execute_and_expect(ReturnType::None)
        .unwrap();

    // Setup Filter
    // Parse the default YAML, raise the first rate limit, and hand the filter
    // the re-serialized (JSON) configuration.
    let mut config: Configuration = serde_yaml::from_str(default_config()).unwrap();
    config.ratelimits.as_mut().unwrap()[0].limit.tokens += 1000;
    let config_str = serde_json::to_string(&config).unwrap();

    let filter_context = setup_filter(&mut module, &config_str);

    // Setup HTTP Stream
    let http_context = 2;

    normal_flow(&mut module, filter_context, http_context);

    // Function-calling response selecting `weather_forecast(city = "seattle")`.
    let arch_fc_resp = ChatCompletionsResponse {
        usage: Some(Usage {
            completion_tokens: 0,
        }),
        choices: vec![Choice {
            finish_reason: "test".to_string(),
            index: 0,
            message: Message {
                role: "system".to_string(),
                content: None,
                tool_calls: Some(vec![ToolCall {
                    id: String::from("test"),
                    tool_type: ToolType::Function,
                    function: FunctionCallDetail {
                        name: String::from("weather_forecast"),
                        arguments: HashMap::from([(
                            String::from("city"),
                            Value::String(String::from("seattle")),
                        )]),
                    },
                }]),
                model: None,
            },
        }],
        model: String::from("test"),
        metadata: None,
    };

    let arch_fc_resp_str = serde_json::to_string(&arch_fc_resp).unwrap();
    // Answer callout 4 (left pending by `normal_flow`); the hallucination
    // check is expected next.
    module
        .call_proxy_on_http_call_response(http_context, 4, 0, arch_fc_resp_str.len() as i32, 0)
        .expect_metric_increment("active_http_calls", -1)
        .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
        .returning(Some(&arch_fc_resp_str))
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_http_call(
            Some("arch_internal"),
            Some(vec![
                ("x-arch-upstream", "model_server"),
                (":method", "POST"),
                (":path", "/hallucination"),
                (":authority", "model_server"),
                ("content-type", "application/json"),
                ("x-envoy-max-retries", "3"),
                ("x-envoy-upstream-rq-timeout-ms", "60000"),
            ]),
            None,
            None,
            None,
        )
        .returning(Some(5))
        .expect_metric_increment("active_http_calls", 1)
        .execute_and_expect(ReturnType::None)
        .unwrap();

    // hallucination should return that parameters were not hallucinated
    // prompt: str
    // parameters: dict
    // model: str

    let hallucination_body = HallucinationClassificationResponse {
        params_scores: HashMap::from([("city".to_string(), 0.99)]),
        model: "nli-model".to_string(),
    };

    let body_text = serde_json::to_string(&hallucination_body).unwrap();

    // Answer callout 5; the prompt target's endpoint (`/weather` on
    // `api_server`) is expected to be called next.
    module
        .call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
        .expect_metric_increment("active_http_calls", -1)
        .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
        .returning(Some(&body_text))
        .expect_log(Some(LogLevel::Debug), None)
        .expect_http_call(
            Some("arch_internal"),
            Some(vec![
                ("x-arch-upstream", "api_server"),
                (":method", "POST"),
                (":path", "/weather"),
                (":authority", "api_server"),
                ("content-type", "application/json"),
                ("x-envoy-max-retries", "3"),
            ]),
            None,
            None,
            None,
        )
        .returning(Some(6))
        .expect_metric_increment("active_http_calls", 1)
        .execute_and_expect(ReturnType::None)
        .unwrap();

    // Answer callout 6 with status 200; no local response is expected this
    // time, only a rewrite of the HTTP request body.
    let body_text = String::from("test body");
    module
        .call_proxy_on_http_call_response(http_context, 6, 0, body_text.len() as i32, 0)
        .expect_metric_increment("active_http_calls", -1)
        .expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
        .returning(Some(&body_text))
        .expect_get_header_map_value(Some(MapType::HttpCallResponseHeaders), Some(":status"))
        .returning(Some("200"))
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_log(Some(LogLevel::Debug), None)
        .expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
        .execute_and_expect(ReturnType::None)
        .unwrap();
}
|