Hallucination integration with rust (#122)

This commit is contained in:
Co Tran 2024-10-07 18:38:55 -07:00 committed by GitHub
parent 43dc2a0a73
commit b1fa127704
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 278 additions and 56 deletions

View file

@ -5,7 +5,7 @@ use proxy_wasm_test_framework::types::{
};
use public_types::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage};
use public_types::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType};
use public_types::common_types::PromptGuardResponse;
use public_types::common_types::{HallucinationClassificationResponse, PromptGuardResponse};
use public_types::embeddings::{
create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage,
Embedding,
@ -534,9 +534,9 @@ fn request_ratelimited() {
normal_flow(&mut module, filter_context, http_context);
let arch_fc_resp = ChatCompletionsResponse {
usage: Usage {
usage: Some(Usage {
completion_tokens: 0,
},
}),
choices: vec![Choice {
finish_reason: "test".to_string(),
index: 0,
@ -572,6 +572,38 @@ fn request_ratelimited() {
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(
Some("model_server"),
Some(vec![
(":method", "POST"),
(":path", "/hallucination"),
(":authority", "model_server"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
]),
None,
None,
None,
)
.returning(Some(5))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
let hallucatination_body = HallucinationClassificationResponse {
params_scores: HashMap::from([("city".to_string(), 0.99)]),
model: "nli-model".to_string(),
};
let body_text = serde_json::to_string(&hallucatination_body).unwrap();
module
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text))
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(
Some("api_server"),
Some(vec![
@ -585,14 +617,14 @@ fn request_ratelimited() {
None,
None,
)
.returning(Some(5))
.returning(Some(6))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
let body_text = String::from("test body");
module
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
.call_proxy_on_http_call_response(http_context, 6, 0, body_text.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text))
@ -642,9 +674,9 @@ fn request_not_ratelimited() {
normal_flow(&mut module, filter_context, http_context);
let arch_fc_resp = ChatCompletionsResponse {
usage: Usage {
usage: Some(Usage {
completion_tokens: 0,
},
}),
choices: vec![Choice {
finish_reason: "test".to_string(),
index: 0,
@ -680,6 +712,43 @@ fn request_not_ratelimited() {
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(
Some("model_server"),
Some(vec![
(":method", "POST"),
(":path", "/hallucination"),
(":authority", "model_server"),
("content-type", "application/json"),
("x-envoy-max-retries", "3"),
("x-envoy-upstream-rq-timeout-ms", "60000"),
]),
None,
None,
None,
)
.returning(Some(5))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
// hallucination should return that parameters were not halliucinated
// prompt: str
// parameters: dict
// model: str
let hallucatination_body = HallucinationClassificationResponse {
params_scores: HashMap::from([("city".to_string(), 0.99)]),
model: "nli-model".to_string(),
};
let body_text = serde_json::to_string(&hallucatination_body).unwrap();
module
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text))
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(
Some("api_server"),
Some(vec![
@ -693,14 +762,14 @@ fn request_not_ratelimited() {
None,
None,
)
.returning(Some(5))
.returning(Some(6))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
let body_text = String::from("test body");
module
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
.call_proxy_on_http_call_response(http_context, 6, 0, body_text.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text))