Improve prompt target intent matching (#51)

This commit is contained in:
Adil Hafeez 2024-09-16 19:20:07 -07:00 committed by GitHub
parent 8565462ec4
commit 9e50957f22
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
14 changed files with 461 additions and 415 deletions

View file

@ -8,16 +8,10 @@ use proxy_wasm_test_framework::tester::{self, Tester};
use proxy_wasm_test_framework::types::{
Action, BufferType, LogLevel, MapType, MetricType, ReturnType,
};
use public_types::{
common_types::{
open_ai::Message, BoltFCResponse, BoltFCToolsCall, IntOrString, ToolCallDetail,
},
configuration::{self, Endpoint, PromptTarget},
};
use public_types::{
common_types::{SearchPointResult, SearchPointsResponse},
configuration::Configuration,
use public_types::common_types::{
open_ai::Message, BoltFCResponse, BoltFCToolsCall, IntOrString, ToolCallDetail,
};
use public_types::{common_types::ZeroShotClassificationResponse, configuration::Configuration};
use serial_test::serial;
use std::collections::HashMap;
use std::path::Path;
@ -118,59 +112,39 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&embeddings_response_buffer))
.expect_http_call(Some("qdrant"), None, None, None, None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(Some("embeddingserver"), None, None, None, None)
.returning(Some(2))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
let prompt_target = PromptTarget {
name: String::from("test-prompt-target"),
description: None,
prompt_type: configuration::PromptType::FunctionResolver,
few_shot_examples: vec![],
parameters: Some(vec![configuration::Parameter {
name: String::from("test-entity"),
parameter_type: Some(String::from("string")),
description: String::from("test-description"),
required: Some(true),
}]),
endpoint: Some(Endpoint {
cluster: String::from("test-endpoint-cluster"),
path: None,
method: None,
}),
system_prompt: None,
let zero_shot_response = ZeroShotClassificationResponse {
predicted_class: "weather_forecast".to_string(),
predicted_class_score: 0.1,
scores: HashMap::new(),
model: "test-model".to_string(),
};
let prompt_target_str = serde_json::to_string(&prompt_target).unwrap();
let search_points_response = SearchPointsResponse {
status: String::new(),
time: 0.0,
result: vec![SearchPointResult {
id: String::new(),
version: 0,
score: 0.7,
payload: HashMap::from([(String::from("prompt-target"), prompt_target_str)]),
}],
};
let search_points_response_buffer = serde_json::to_string(&search_points_response).unwrap();
let zeroshot_intent_detection_buffer = serde_json::to_string(&zero_shot_response).unwrap();
module
.call_proxy_on_http_call_response(
http_context,
2,
0,
search_points_response_buffer.len() as i32,
zeroshot_intent_detection_buffer.len() as i32,
0,
)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&search_points_response_buffer))
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Info), None)
.returning(Some(&zeroshot_intent_detection_buffer))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), None)
.expect_http_call(Some("bolt_fc_1b"), None, None, None, None)
.returning(Some(3))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
@ -199,20 +173,22 @@ system_prompt: |
prompt_targets:
- type: function_resolver
name: weather_forecast
few_shot_examples:
- what is the weather in New York?
description: This resolver provides weather forecast information.
endpoint:
cluster: weatherhost
path: /weather
entities:
- name: location
parameters:
- name: city
required: true
description: "The location for which the weather is requested"
description: The city for which the weather forecast is requested.
- name: days
description: The number of days for which the weather forecast is requested.
- name: units
description: The units in which the weather forecast is requested.
- type: function_resolver
name: weather_forecast_2
few_shot_examples:
- what is the weather in New York?
description: This resolver provides weather forecast information.
endpoint:
cluster: weatherhost
path: /weather
@ -450,10 +426,10 @@ fn request_ratelimited() {
normal_flow(&mut module, filter_context, http_context);
let tool_call_detail = vec![ToolCallDetail {
name: String::from("test-tool"),
name: String::from("weather_forecast"),
arguments: HashMap::from([(
String::from("test-entity"),
IntOrString::Text(String::from("test-value")),
String::from("city"),
IntOrString::Text(String::from("seattle")),
)]),
}];
@ -485,7 +461,7 @@ fn request_ratelimited() {
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(Some("test-endpoint-cluster"), None, None, None, None)
.expect_http_call(Some("weatherhost"), None, None, None, None)
.returning(Some(4))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
@ -555,10 +531,10 @@ fn request_not_ratelimited() {
normal_flow(&mut module, filter_context, http_context);
let tool_call_detail = vec![ToolCallDetail {
name: String::from("test-tool"),
name: String::from("weather_forecast"),
arguments: HashMap::from([(
String::from("test-entity"),
IntOrString::Text(String::from("test-value")),
String::from("city"),
IntOrString::Text(String::from("seattle")),
)]),
}];
@ -590,7 +566,7 @@ fn request_not_ratelimited() {
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(Some("test-endpoint-cluster"), None, None, None, None)
.expect_http_call(Some("weatherhost"), None, None, None, None)
.returning(Some(4))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
@ -608,7 +584,7 @@ fn request_not_ratelimited() {
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.expect_log(Some(LogLevel::Debug), None)
// .expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None)
.unwrap();
}