Add function calling support using bolt-fc-1b (#35)

This commit is contained in:
Adil Hafeez 2024-09-10 14:24:46 -07:00 committed by GitHub
parent fdfad87347
commit 7b5203a2ce
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
39 changed files with 1763 additions and 416 deletions

View file

@ -8,9 +8,14 @@ use proxy_wasm_test_framework::tester::{self, Tester};
use proxy_wasm_test_framework::types::{
Action, BufferType, LogLevel, MapType, MetricType, ReturnType,
};
use public_types::configuration::{self, Endpoint, PromptTarget};
use public_types::{
common_types::{self, NERResponse, SearchPointResult, SearchPointsResponse},
common_types::{
open_ai::Message, BoltFCResponse, BoltFCToolsCall, IntOrString, ToolCallDetail,
},
configuration::{self, Endpoint, PromptTarget},
};
use public_types::{
common_types::{SearchPointResult, SearchPointsResponse},
configuration::Configuration,
};
use serial_test::serial;
@ -87,6 +92,7 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
.expect_http_call(Some("embeddingserver"), None, None, None, None)
.returning(Some(1))
.expect_metric_increment("active_http_calls", 1)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap();
@ -120,12 +126,14 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
let prompt_target = PromptTarget {
name: String::from("test-prompt-target"),
prompt_type: String::from("test-prompt-type"),
description: None,
prompt_type: configuration::PromptType::FunctionResolver,
few_shot_examples: vec![],
entities: Some(vec![configuration::Entity {
parameters: Some(vec![configuration::Parameter {
name: String::from("test-entity"),
parameter_type: Some(String::from("string")),
description: String::from("test-description"),
required: Some(true),
description: None,
}]),
endpoint: Some(Endpoint {
cluster: String::from("test-endpoint-cluster"),
@ -159,33 +167,13 @@ fn normal_flow(module: &mut Tester, filter_context: i32, http_context: i32) {
.returning(Some(&search_points_response_buffer))
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Info), None)
.expect_http_call(Some("nerhost"), None, None, None, None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(Some("bolt_fc_1b"), None, None, None, None)
.returning(Some(3))
.expect_log(Some(LogLevel::Debug), None)
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
let ner_reponse = NERResponse {
model: String::from("test-model"),
data: vec![common_types::Entity {
score: 0.7,
text: String::from("test-text"),
label: String::from("test-entity"),
}],
};
let ner_response_buffer = serde_json::to_string(&ner_reponse).unwrap();
let upstream_name = prompt_target.endpoint.unwrap().cluster.leak();
module
.call_proxy_on_http_call_response(http_context, 3, 0, ner_response_buffer.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&ner_response_buffer))
.expect_log(Some(LogLevel::Info), None)
.expect_http_call(Some(upstream_name), None, None, None, None)
.returning(Some(4))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap()
}
fn default_config() -> Configuration {
@ -209,7 +197,7 @@ system_prompt: |
- Use miles per hour for wind speed
prompt_targets:
- type: context_resolver
- type: function_resolver
name: weather_forecast
few_shot_examples:
- what is the weather in New York?
@ -221,7 +209,7 @@ prompt_targets:
required: true
description: "The location for which the weather is requested"
- type: context_resolver
- type: function_resolver
name: weather_forecast_2
few_shot_examples:
- what is the weather in New York?
@ -327,6 +315,7 @@ fn successful_request_to_open_ai_chat_completions() {
.returning(Some(chat_completions_request_body))
// TODO: assert that the model field was added.
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::Action(Action::Pause))
.unwrap();
@ -460,14 +449,57 @@ fn request_ratelimited() {
normal_flow(&mut module, filter_context, http_context);
let test_body = "test body";
let tool_call_detail = vec![ToolCallDetail {
name: String::from("test-tool"),
arguments: HashMap::from([(
String::from("test-entity"),
IntOrString::Text(String::from("test-value")),
)]),
}];
let boltfc_tools_call = BoltFCToolsCall {
tool_calls: tool_call_detail,
};
let bolt_fc_resp = BoltFCResponse {
model: String::from("test"),
message: Message {
role: String::from("system"),
content: Some(serde_json::to_string(&boltfc_tools_call).unwrap()),
model: None,
},
done_reason: String::from("test"),
done: true,
resolver_name: None,
};
let bolt_fc_resp_str = serde_json::to_string(&bolt_fc_resp).unwrap();
module
.call_proxy_on_http_call_response(http_context, 4, 0, test_body.len() as i32, 0)
.call_proxy_on_http_call_response(http_context, 3, 0, bolt_fc_resp_str.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(test_body))
.returning(Some(&bolt_fc_resp_str))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(Some("test-endpoint-cluster"), None, None, None, None)
.returning(Some(4))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
let body_text = String::from("test body");
module
.call_proxy_on_http_call_response(http_context, 4, 0, body_text.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_send_local_response(
@ -522,18 +554,61 @@ fn request_not_ratelimited() {
normal_flow(&mut module, filter_context, http_context);
let test_body = "test body";
let tool_call_detail = vec![ToolCallDetail {
name: String::from("test-tool"),
arguments: HashMap::from([(
String::from("test-entity"),
IntOrString::Text(String::from("test-value")),
)]),
}];
let boltfc_tools_call = BoltFCToolsCall {
tool_calls: tool_call_detail,
};
let bolt_fc_resp = BoltFCResponse {
model: String::from("test"),
message: Message {
role: String::from("system"),
content: Some(serde_json::to_string(&boltfc_tools_call).unwrap()),
model: None,
},
done_reason: String::from("test"),
done: true,
resolver_name: None,
};
let bolt_fc_resp_str = serde_json::to_string(&bolt_fc_resp).unwrap();
module
.call_proxy_on_http_call_response(http_context, 4, 0, test_body.len() as i32, 0)
.call_proxy_on_http_call_response(http_context, 3, 0, bolt_fc_resp_str.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(test_body))
.returning(Some(&bolt_fc_resp_str))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_http_call(Some("test-endpoint-cluster"), None, None, None, None)
.returning(Some(4))
.expect_metric_increment("active_http_calls", 1)
.execute_and_expect(ReturnType::None)
.unwrap();
let body_text = String::from("test body");
module
.call_proxy_on_http_call_response(http_context, 4, 0, body_text.len() as i32, 0)
.expect_metric_increment("active_http_calls", -1)
.expect_get_buffer_bytes(Some(BufferType::HttpCallResponseBody))
.returning(Some(&body_text))
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Info), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_log(Some(LogLevel::Debug), None)
.expect_set_buffer_bytes(Some(BufferType::HttpRequestBody), None)
.expect_log(Some(LogLevel::Debug), None)
.execute_and_expect(ReturnType::None)
.unwrap();
}