mirror of
https://github.com/katanemo/plano.git
synced 2026-05-15 11:02:39 +02:00
Hallucination integration with rust (#122)
This commit is contained in:
parent
43dc2a0a73
commit
b1fa127704
9 changed files with 278 additions and 56 deletions
|
|
@ -5,7 +5,7 @@ use proxy_wasm_test_framework::types::{
|
|||
};
|
||||
use public_types::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage};
|
||||
use public_types::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType};
|
||||
use public_types::common_types::PromptGuardResponse;
|
||||
use public_types::common_types::{HallucinationClassificationResponse, PromptGuardResponse};
|
||||
use public_types::embeddings::{
|
||||
create_embedding_response, embedding, CreateEmbeddingResponse, CreateEmbeddingResponseUsage,
|
||||
Embedding,
|
||||
|
|
@ -534,9 +534,9 @@ fn request_ratelimited() {
|
|||
normal_flow(&mut module, filter_context, http_context);
|
||||
|
||||
let arch_fc_resp = ChatCompletionsResponse {
|
||||
usage: Usage {
|
||||
usage: Some(Usage {
|
||||
completion_tokens: 0,
|
||||
},
|
||||
}),
|
||||
choices: vec![Choice {
|
||||
finish_reason: "test".to_string(),
|
||||
index: 0,
|
||||
|
|
@ -572,6 +572,38 @@ fn request_ratelimited() {
|
|||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("model_server"),
|
||||
Some(vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/hallucination"),
|
||||
(":authority", "model_server"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(5))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let hallucatination_body = HallucinationClassificationResponse {
|
||||
params_scores: HashMap::from([("city".to_string(), 0.99)]),
|
||||
model: "nli-model".to_string(),
|
||||
};
|
||||
|
||||
let body_text = serde_json::to_string(&hallucatination_body).unwrap();
|
||||
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&body_text))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("api_server"),
|
||||
Some(vec![
|
||||
|
|
@ -585,14 +617,14 @@ fn request_ratelimited() {
|
|||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(5))
|
||||
.returning(Some(6))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let body_text = String::from("test body");
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
|
||||
.call_proxy_on_http_call_response(http_context, 6, 0, body_text.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&body_text))
|
||||
|
|
@ -642,9 +674,9 @@ fn request_not_ratelimited() {
|
|||
normal_flow(&mut module, filter_context, http_context);
|
||||
|
||||
let arch_fc_resp = ChatCompletionsResponse {
|
||||
usage: Usage {
|
||||
usage: Some(Usage {
|
||||
completion_tokens: 0,
|
||||
},
|
||||
}),
|
||||
choices: vec![Choice {
|
||||
finish_reason: "test".to_string(),
|
||||
index: 0,
|
||||
|
|
@ -680,6 +712,43 @@ fn request_not_ratelimited() {
|
|||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("model_server"),
|
||||
Some(vec![
|
||||
(":method", "POST"),
|
||||
(":path", "/hallucination"),
|
||||
(":authority", "model_server"),
|
||||
("content-type", "application/json"),
|
||||
("x-envoy-max-retries", "3"),
|
||||
("x-envoy-upstream-rq-timeout-ms", "60000"),
|
||||
]),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(5))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
// hallucination should return that parameters were not halliucinated
|
||||
// prompt: str
|
||||
// parameters: dict
|
||||
// model: str
|
||||
|
||||
let hallucatination_body = HallucinationClassificationResponse {
|
||||
params_scores: HashMap::from([("city".to_string(), 0.99)]),
|
||||
model: "nli-model".to_string(),
|
||||
};
|
||||
|
||||
let body_text = serde_json::to_string(&hallucatination_body).unwrap();
|
||||
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&body_text))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(
|
||||
Some("api_server"),
|
||||
Some(vec![
|
||||
|
|
@ -693,14 +762,14 @@ fn request_not_ratelimited() {
|
|||
None,
|
||||
None,
|
||||
)
|
||||
.returning(Some(5))
|
||||
.returning(Some(6))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let body_text = String::from("test body");
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 5, 0, body_text.len() as i32, 0)
|
||||
.call_proxy_on_http_call_response(http_context, 6, 0, body_text.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&body_text))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue