mirror of
https://github.com/katanemo/plano.git
synced 2026-04-26 01:06:25 +02:00
Add function calling support using bolt-fc-1b (#35)
This commit is contained in:
parent
fdfad87347
commit
7b5203a2ce
39 changed files with 1763 additions and 416 deletions
|
|
@ -8,9 +8,14 @@ use proxy_wasm_test_framework::tester::{self, Tester};
|
|||
use proxy_wasm_test_framework::types::{
|
||||
Action, BufferType, LogLevel, MapType, MetricType, ReturnType,
|
||||
};
|
||||
use public_types::configuration::{self, Endpoint, PromptTarget};
|
||||
use public_types::{
|
||||
common_types::{self, NERResponse, SearchPointResult, SearchPointsResponse},
|
||||
common_types::{
|
||||
open_ai::Message, BoltFCResponse, BoltFCToolsCall, IntOrString, ToolCallDetail,
|
||||
},
|
||||
configuration::{self, Endpoint, PromptTarget},
|
||||
};
|
||||
use public_types::{
|
||||
common_types::{SearchPointResult, SearchPointsResponse},
|
||||
configuration::Configuration,
|
||||
};
|
||||
use serial_test::serial;
|
||||
|
|
@ -87,6 +92,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
|||
.expect_http_call(Some("embeddingserver"), None, None, None, None)
|
||||
.returning(Some(1))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::Action(Action::Pause))
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -120,12 +126,14 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
|||
|
||||
let prompt_target = PromptTarget {
|
||||
name: String::from("test-prompt-target"),
|
||||
prompt_type: String::from("test-prompt-type"),
|
||||
description: None,
|
||||
prompt_type: configuration::PromptType::FunctionResolver,
|
||||
few_shot_examples: vec![],
|
||||
entities: Some(vec![configuration::Entity {
|
||||
parameters: Some(vec![configuration::Parameter {
|
||||
name: String::from("test-entity"),
|
||||
parameter_type: Some(String::from("string")),
|
||||
description: String::from("test-description"),
|
||||
required: Some(true),
|
||||
description: None,
|
||||
}]),
|
||||
endpoint: Some(Endpoint {
|
||||
cluster: String::from("test-endpoint-cluster"),
|
||||
|
|
@ -159,33 +167,13 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
|
|||
.returning(Some(&search_points_response_buffer))
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_http_call(Some("nerhost"), None, None, None, None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(Some("bolt_fc_1b"), None, None, None, None)
|
||||
.returning(Some(3))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let ner_reponse = NERResponse {
|
||||
model: String::from("test-model"),
|
||||
data: vec![common_types::Entity {
|
||||
score: 0.7,
|
||||
text: String::from("test-text"),
|
||||
label: String::from("test-entity"),
|
||||
}],
|
||||
};
|
||||
let ner_response_buffer = serde_json::to_string(&ner_reponse).unwrap();
|
||||
let upstream_name = prompt_target.endpoint.unwrap().cluster.leak();
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 3, 0, ner_response_buffer.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&ner_response_buffer))
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_http_call(Some(upstream_name), None, None, None, None)
|
||||
.returning(Some(4))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
fn default_config() -> Configuration {
|
||||
|
|
@ -209,7 +197,7 @@ system_prompt: |
|
|||
- Use miles per hour for wind speed
|
||||
|
||||
prompt_targets:
|
||||
- type: context_resolver
|
||||
- type: function_resolver
|
||||
name: weather_forecast
|
||||
few_shot_examples:
|
||||
- what is the weather in New York?
|
||||
|
|
@ -221,7 +209,7 @@ prompt_targets:
|
|||
required: true
|
||||
description: "The location for which the weather is requested"
|
||||
|
||||
- type: context_resolver
|
||||
- type: function_resolver
|
||||
name: weather_forecast_2
|
||||
few_shot_examples:
|
||||
- what is the weather in New York?
|
||||
|
|
@ -327,6 +315,7 @@ fn successful_request_to_open_ai_chat_completions() {
|
|||
.returning(Some(chat_completions_request_body))
|
||||
// TODO: assert that the model field was added.
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::Action(Action::Pause))
|
||||
.unwrap();
|
||||
|
|
@ -460,14 +449,57 @@ fn request_ratelimited() {
|
|||
|
||||
normal_flow(&mut module, filter_context, http_context);
|
||||
|
||||
let test_body = "test body";
|
||||
let tool_call_detail = vec![ToolCallDetail {
|
||||
name: String::from("test-tool"),
|
||||
arguments: HashMap::from([(
|
||||
String::from("test-entity"),
|
||||
IntOrString::Text(String::from("test-value")),
|
||||
)]),
|
||||
}];
|
||||
|
||||
let boltfc_tools_call = BoltFCToolsCall {
|
||||
tool_calls: tool_call_detail,
|
||||
};
|
||||
|
||||
let bolt_fc_resp = BoltFCResponse {
|
||||
model: String::from("test"),
|
||||
message: Message {
|
||||
role: String::from("system"),
|
||||
content: Some(serde_json::to_string(&boltfc_tools_call).unwrap()),
|
||||
model: None,
|
||||
},
|
||||
done_reason: String::from("test"),
|
||||
done: true,
|
||||
resolver_name: None,
|
||||
};
|
||||
|
||||
let bolt_fc_resp_str = serde_json::to_string(&bolt_fc_resp).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 4, 0, test_body.len() as i32, 0)
|
||||
.call_proxy_on_http_call_response(http_context, 3, 0, bolt_fc_resp_str.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(test_body))
|
||||
.returning(Some(&bolt_fc_resp_str))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(Some("test-endpoint-cluster"), None, None, None, None)
|
||||
.returning(Some(4))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let body_text = String::from("test body");
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 4, 0, body_text.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&body_text))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_send_local_response(
|
||||
|
|
@ -522,18 +554,61 @@ fn request_not_ratelimited() {
|
|||
|
||||
normal_flow(&mut module, filter_context, http_context);
|
||||
|
||||
let test_body = "test body";
|
||||
let tool_call_detail = vec![ToolCallDetail {
|
||||
name: String::from("test-tool"),
|
||||
arguments: HashMap::from([(
|
||||
String::from("test-entity"),
|
||||
IntOrString::Text(String::from("test-value")),
|
||||
)]),
|
||||
}];
|
||||
|
||||
let boltfc_tools_call = BoltFCToolsCall {
|
||||
tool_calls: tool_call_detail,
|
||||
};
|
||||
|
||||
let bolt_fc_resp = BoltFCResponse {
|
||||
model: String::from("test"),
|
||||
message: Message {
|
||||
role: String::from("system"),
|
||||
content: Some(serde_json::to_string(&boltfc_tools_call).unwrap()),
|
||||
model: None,
|
||||
},
|
||||
done_reason: String::from("test"),
|
||||
done: true,
|
||||
resolver_name: None,
|
||||
};
|
||||
|
||||
let bolt_fc_resp_str = serde_json::to_string(&bolt_fc_resp).unwrap();
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 4, 0, test_body.len() as i32, 0)
|
||||
.call_proxy_on_http_call_response(http_context, 3, 0, bolt_fc_resp_str.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(test_body))
|
||||
.returning(Some(&bolt_fc_resp_str))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_http_call(Some("test-endpoint-cluster"), None, None, None, None)
|
||||
.returning(Some(4))
|
||||
.expect_metric_increment("active_http_calls", 1)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
|
||||
let body_text = String::from("test body");
|
||||
module
|
||||
.call_proxy_on_http_call_response(http_context, 4, 0, body_text.len() as i32, 0)
|
||||
.expect_metric_increment("active_http_calls", -1)
|
||||
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
|
||||
.returning(Some(&body_text))
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Info), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
|
||||
.expect_log(Some(LogLevel::Debug), None)
|
||||
.execute_and_expect(ReturnType::None)
|
||||
.unwrap();
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue