Clean up Embeddings Store (#121)

Signed-off-by: José Ulises Niño Rivera <junr03@users.noreply.github.com>
This commit is contained in:
José Ulises Niño Rivera 2024-10-04 19:33:52 -07:00 committed by GitHub
parent 10b5c5b42c
commit 2a9b9486f3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 366 additions and 238 deletions

View file

@ -6,9 +6,9 @@ use proxy_wasm_test_framework::types::{
use public_types::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage};
use public_types::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType};
use public_types::common_types::PromptGuardResponse;
use public_types::embeddings::embedding::Object;
use public_types::embeddings::{
create_embedding_response, CreateEmbeddingResponse, CreateEmbeddingResponseUsage, Embedding,
create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage,
Embedding,
};
use public_types::{common_types::ZeroShotClassificationResponse, configuration::Configuration};
use serde_yaml::Value;
@ -158,7 +158,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
data: vec![Embedding {
index: 0,
embedding: vec![],
object: Object::default(),
object: embedding::Object::default(),
}],
model: String::from("test"),
object: create_embedding_response::Object::default(),
@ -177,8 +177,6 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&embeddings_response_buffer))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Warn), None)
.expect_log(Some(LogLevel::Warn), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(
Some("model_server"),
@ -243,8 +241,130 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
.unwrap();
}
fn default_config() -> Configuration {
let config: &str = r#"
fn setup_filter(module: &mut Tester, config: &str) -> i32 {
let filter_context = 1;
module
.call_proxy_on_context_create(filter_context, 0)
.expect_metric_creation(MetricType::Gauge, "active_http_calls")
.expect_metric_creation(MetricType::Counter, "ratelimited_rq")
.execute_and_expect(ReturnType::None)
.unwrap();
module
.call_proxy_on_configure(filter_context, config.len() as i32)
.expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
.returning(Some(&config))
.execute_and_expect(ReturnType::Bool(true))
.unwrap();
module
.call_proxy_on_tick(filter_context)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(
Some("model_server"),
Some(vec![
(":method", "POST"),
(":path", "/embeddings"),
(":authority", "model_server"),
("content-type", "application/json"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
]),
None,
None,
None,
)
.returning(Some(101))
.expect_metric_increment("active_http_calls", 1)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(
Some("model_server"),
Some(vec![
(":method", "POST"),
(":path", "/embeddings"),
(":authority", "model_server"),
("content-type", "application/json"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
]),
None,
None,
None,
)
.returning(Some(102))
.expect_metric_increment("active_http_calls", 1)
.expect_set_tick_period_millis(Some(0))
.execute_and_expect(ReturnType::None)
.unwrap();
let embedding_response = CreateEmbeddingResponse {
data: vec![Embedding {
embedding: vec![],
index: 0,
object: embedding::Object::default(),
}],
model: String::from("test"),
object: create_embedding_response::Object::default(),
usage: Box::new(CreateEmbeddingResponseUsage {
prompt_tokens: 0,
total_tokens: 0,
}),
};
let embedding_response_str = serde_json::to_string(&embedding_response).unwrap();
module
.call_proxy_on_http_call_response(
filter_context,
101,
0,
embedding_response_str.len() as i32,
0,
)
.expect_log(
Some(LogLevel::Debug),
Some(
format!(
"filter_context: on_http_call_response called with token_id: {:?}",
101
)
.as_str(),
),
)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&embedding_response_str))
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None)
.unwrap();
module
.call_proxy_on_http_call_response(
filter_context,
102,
0,
embedding_response_str.len() as i32,
0,
)
.expect_log(
Some(LogLevel::Debug),
Some(
format!(
"filter_context: on_http_call_response called with token_id: {:?}",
102
)
.as_str(),
),
)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&embedding_response_str))
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None)
.unwrap();
filter_context
}
fn default_config() -> &'static str {
r#"
version: "0.1-beta"
listener:
@ -297,24 +417,6 @@ prompt_targets:
- Use farenheight for temperature
- Use miles per hour for wind speed
- name: insurance_claim_details
type: function_resolver
description: This function resolver provides insurance claim details for a given policy number.
parameters:
- name: policy_number
required: true
description: The policy number for which the insurance claim details are requested.
type: string
- name: include_expired
description: whether to include expired insurance claims in the response.
type: bool
required: true
endpoint:
name: api_server
path: /insurance_claim_details
system_prompt: |
You are a helpful insurance claim details provider. Use insurance claim data that is provided to you. Please following following guidelines when responding to user queries:
- Use policy number to retrieve insurance claim details
ratelimits:
- model: gpt-4
selector:
@ -323,8 +425,7 @@ ratelimits:
limit:
tokens: 1
unit: minute
"#;
serde_yaml::from_str(config).unwrap()
"#
}
#[test]
@ -343,22 +444,7 @@ fn successful_request_to_open_ai_chat_completions() {
.unwrap();
// Setup Filter
let filter_context = 1;
let config = serde_json::to_string(&default_config()).unwrap();
module
.call_proxy_on_context_create(filter_context, 0)
.expect_metric_creation(MetricType::Gauge, "active_http_calls")
.expect_metric_creation(MetricType::Counter, "ratelimited_rq")
.execute_and_expect(ReturnType::None)
.unwrap();
module
.call_proxy_on_configure(filter_context, config.len() as i32)
.expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
.returning(Some(&config))
.execute_and_expect(ReturnType::Bool(true))
.unwrap();
let filter_context = setup_filter(&mut module, default_config());
// Setup HTTP Stream
let http_context = 2;
@ -419,22 +505,7 @@ fn bad_request_to_open_ai_chat_completions() {
.unwrap();
// Setup Filter
let filter_context = 1;
let config = serde_json::to_string(&default_config()).unwrap();
module
.call_proxy_on_context_create(filter_context, 0)
.expect_metric_creation(MetricType::Gauge, "active_http_calls")
.expect_metric_creation(MetricType::Counter, "ratelimited_rq")
.execute_and_expect(ReturnType::None)
.unwrap();
module
.call_proxy_on_configure(filter_context, config.len() as i32)
.expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
.returning(Some(&config))
.execute_and_expect(ReturnType::Bool(true))
.unwrap();
let filter_context = setup_filter(&mut module, default_config());
// Setup HTTP Stream
let http_context = 2;
@ -496,21 +567,7 @@ fn request_ratelimited() {
.unwrap();
// Setup Filter
let filter_context = 1;
let config = serde_json::to_string(&default_config()).unwrap();
module
.call_proxy_on_context_create(filter_context, 0)
.expect_metric_creation(MetricType::Gauge, "active_http_calls")
.expect_metric_creation(MetricType::Counter, "ratelimited_rq")
.execute_and_expect(ReturnType::None)
.unwrap();
module
.call_proxy_on_configure(filter_context, config.len() as i32)
.expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
.returning(Some(&config))
.execute_and_expect(ReturnType::Bool(true))
.unwrap();
let filter_context = setup_filter(&mut module, default_config());
// Setup HTTP Stream
let http_context = 2;
@ -619,24 +676,11 @@ fn request_not_ratelimited() {
.unwrap();
// Setup Filter
let filter_context = 1;
let mut config = default_config();
let mut config: Configuration = serde_yaml::from_str(default_config()).unwrap();
config.ratelimits.as_mut().unwrap()[0].limit.tokens += 1000;
let config_str = serde_json::to_string(&config).unwrap();
module
.call_proxy_on_context_create(filter_context, 0)
.expect_metric_creation(MetricType::Gauge, "active_http_calls")
.expect_metric_creation(MetricType::Counter, "ratelimited_rq")
.execute_and_expect(ReturnType::None)
.unwrap();
module
.call_proxy_on_configure(filter_context, config_str.len() as i32)
.expect_get_buffer_bytes(Some(BufferType::PluginConfiguration))
.returning(Some(&config_str))
.execute_and_expect(ReturnType::Bool(true))
.unwrap();
let filter_context = setup_filter(&mut module, &config_str);
// Setup HTTP Stream
let http_context = 2;