mirror of
https://github.com/katanemo/plano.git
synced 2026-04-28 10:26:36 +02:00
Improve prompt target intent matching (#51)
This commit is contained in:
parent
8565462ec4
commit
9e50957f22
14 changed files with 461 additions and 415 deletions
|
|
@ -8,16 +8,10 @@ use proxy_wasm_test_framework::tester::{self, Tester};
|
|||
use proxy_wasm_test_framework::types::{
|
||||
Action, BufferType, LogLevel, MapType, MetricType, ReturnType,
|
||||
};
|
||||
use public_types::{
|
||||
common_types::{
|
||||
open_ai::Message, BoltFCResponse, BoltFCToolsCall, IntOrString, ToolCallDetail,
|
||||
},
|
||||
configuration::{self, Endpoint, PromptTarget},
|
||||
};
|
||||
use public_types::{
|
||||
common_types::{SearchPointResult, SearchPointsResponse},
|
||||
configuration::Configuration,
|
||||
use public_types::common_types::{
|
||||
open_ai::Message, BoltFCResponse, BoltFCToolsCall, IntOrString, ToolCallDetail,
|
||||
};
|
||||
use public_types::{common_types::ZeroShotClassificationResponse, configuration::Configuration};
|
||||
use serial_test::serial;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
|
|
@ -118,59 +112,39 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
|||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&embeddings_response_buffer))
|
||||
.expect_http_call(Some("qdrant"), None, None, None, None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(Some("embeddingserver"), None, None, None, None)
|
||||
.returning(Some(2))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let prompt_target = PromptTarget {
|
||||
name: String::from("test-prompt-target"),
|
||||
description: None,
|
||||
prompt_type: configuration::PromptType::FunctionResolver,
|
||||
few_shot_examples: vec![],
|
||||
parameters: Some(vec![configuration::Parameter {
|
||||
name: String::from("test-entity"),
|
||||
parameter_type: Some(String::from("string")),
|
||||
description: String::from("test-description"),
|
||||
required: Some(true),
|
||||
}]),
|
||||
endpoint: Some(Endpoint {
|
||||
cluster: String::from("test-endpoint-cluster"),
|
||||
path: None,
|
||||
method: None,
|
||||
}),
|
||||
system_prompt: None,
|
||||
let zero_shot_response = ZeroShotClassificationResponse {
|
||||
predicted_class: "weather_forecast".to_string(),
|
||||
predicted_class_score: 0.1,
|
||||
scores: HashMap::new(),
|
||||
model: "test-model".to_string(),
|
||||
};
|
||||
let prompt_target_str = serde_json::to_string(&prompt_target).unwrap();
|
||||
let search_points_response = SearchPointsResponse {
|
||||
status: String::new(),
|
||||
time: 0.0,
|
||||
result: vec![SearchPointResult {
|
||||
id: String::new(),
|
||||
version: 0,
|
||||
score: 0.7,
|
||||
payload: HashMap::from([(String::from("prompt-target"), prompt_target_str)]),
|
||||
}],
|
||||
};
|
||||
let search_points_response_buffer = serde_json::to_string(&search_points_response).unwrap();
|
||||
let zeroshot_intent_detection_buffer = serde_json::to_string(&zero_shot_response).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(
|
||||
http_context,
|
||||
2,
|
||||
0,
|
||||
search_points_response_buffer.len() as i32,
|
||||
zeroshot_intent_detection_buffer.len() as i32,
|
||||
0,
|
||||
)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&search_points_response_buffer))
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.returning(Some(&zeroshot_intent_detection_buffer))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_http_call(Some("bolt_fc_1b"), None, None, None, None)
|
||||
.returning(Some(3))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
|
@ -199,20 +173,22 @@ system_prompt: |
|
|||
prompt_targets:
|
||||
- type: function_resolver
|
||||
name: weather_forecast
|
||||
few_shot_examples:
|
||||
- what is the weather in New York?
|
||||
description: This resolver provides weather forecast information.
|
||||
endpoint:
|
||||
cluster: weatherhost
|
||||
path: /weather
|
||||
entities:
|
||||
- name: location
|
||||
parameters:
|
||||
- name: city
|
||||
required: true
|
||||
description: "The location for which the weather is requested"
|
||||
description: The city for which the weather forecast is requested.
|
||||
- name: days
|
||||
description: The number of days for which the weather forecast is requested.
|
||||
- name: units
|
||||
description: The units in which the weather forecast is requested.
|
||||
|
||||
- type: function_resolver
|
||||
name: weather_forecast_2
|
||||
few_shot_examples:
|
||||
- what is the weather in New York?
|
||||
description: This resolver provides weather forecast information.
|
||||
endpoint:
|
||||
cluster: weatherhost
|
||||
path: /weather
|
||||
|
|
@ -450,10 +426,10 @@ fn request_ratelimited() {
|
|||
normal_flow(&mut module, filter_context, http_context);
|
||||
|
||||
let tool_call_detail = vec![ToolCallDetail {
|
||||
name: String::from("test-tool"),
|
||||
name: String::from("weather_forecast"),
|
||||
arguments: HashMap::from([(
|
||||
String::from("test-entity"),
|
||||
IntOrString::Text(String::from("test-value")),
|
||||
String::from("city"),
|
||||
IntOrString::Text(String::from("seattle")),
|
||||
)]),
|
||||
}];
|
||||
|
||||
|
|
@ -485,7 +461,7 @@ fn request_ratelimited() {
|
|||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(Some("test-endpoint-cluster"), None, None, None, None)
|
||||
.expect_http_call(Some("weatherhost"), None, None, None, None)
|
||||
.returning(Some(4))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
|
|
@ -555,10 +531,10 @@ fn request_not_ratelimited() {
|
|||
normal_flow(&mut module, filter_context, http_context);
|
||||
|
||||
let tool_call_detail = vec![ToolCallDetail {
|
||||
name: String::from("test-tool"),
|
||||
name: String::from("weather_forecast"),
|
||||
arguments: HashMap::from([(
|
||||
String::from("test-entity"),
|
||||
IntOrString::Text(String::from("test-value")),
|
||||
String::from("city"),
|
||||
IntOrString::Text(String::from("seattle")),
|
||||
)]),
|
||||
}];
|
||||
|
||||
|
|
@ -590,7 +566,7 @@ fn request_not_ratelimited() {
|
|||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(Some("test-endpoint-cluster"), None, None, None, None)
|
||||
.expect_http_call(Some("weatherhost"), None, None, None, None)
|
||||
.returning(Some(4))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
|
|
@ -608,7 +584,7 @@ fn request_not_ratelimited() {
|
|||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
// .expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue