Integrate Arch-Function-Calling-1.5B model (#85)

* add arch support

* add missing file

* e2e tests

* delete old files and fix response

* fmt
This commit is contained in:
Adil Hafeez 2024-09-25 23:30:50 -07:00 committed by GitHub
parent 9ea6bb0d73
commit 3511798fa8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 203 additions and 427 deletions

View file

@ -15,13 +15,13 @@ use log::{debug, info, warn};
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use public_types::common_types::open_ai::{
ChatCompletionChunkResponse, ChatCompletionsRequest, ChatCompletionsResponse, Message,
StreamOptions,
ChatCompletionChunkResponse, ChatCompletionTool, ChatCompletionsRequest,
ChatCompletionsResponse, FunctionDefinition, FunctionParameter, FunctionParameters, Message,
StreamOptions, ToolType,
};
use public_types::common_types::{
BoltFCToolsCall, EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
ToolParameter, ToolParameters, ToolsDefinition, ZeroShotClassificationRequest,
ZeroShotClassificationResponse,
EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
ZeroShotClassificationRequest, ZeroShotClassificationResponse,
};
use public_types::configuration::{Overrides, PromptGuards, PromptTarget, PromptType};
use public_types::embeddings::{
@ -215,7 +215,8 @@ impl StreamContext {
let json_data: String = match serde_json::to_string(&zero_shot_classification_request) {
Ok(json_data) => json_data,
Err(error) => {
panic!("Error serializing zero shot request: {}", error);
let error = format!("Error serializing zero shot request: {}", error);
return self.send_server_error(error, None);
}
};
@ -235,10 +236,11 @@ impl StreamContext {
) {
Ok(token_id) => token_id,
Err(e) => {
panic!(
"Error dispatching embedding server HTTP call for zero-shot-intent-detection: {:?}",
e
);
let error_msg = format!(
"Error dispatching embedding server HTTP call for zero-shot-intent-detection: {:?}",
e
);
return self.send_server_error(error_msg, None);
}
};
debug!(
@ -358,16 +360,15 @@ impl StreamContext {
match prompt_target.prompt_type {
PromptType::FunctionResolver => {
let mut tools_definitions: Vec<ToolsDefinition> = Vec::new();
let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
for pt in self.prompt_targets.read().unwrap().values() {
// only extract entity names
let properties: HashMap<String, ToolParameter> = match pt.parameters {
let properties: HashMap<String, FunctionParameter> = match pt.parameters {
// Clone is unavoidable here because we don't want to move the values out of the prompt target struct.
Some(ref entities) => {
let mut properties: HashMap<String, ToolParameter> = HashMap::new();
let mut properties: HashMap<String, FunctionParameter> = HashMap::new();
for entity in entities.iter() {
let param = ToolParameter {
let param = FunctionParameter {
parameter_type: entity.parameter_type.clone(),
description: entity.description.clone(),
required: entity.required,
@ -380,22 +381,24 @@ impl StreamContext {
}
None => HashMap::new(),
};
let tools_parameters = ToolParameters {
parameters_type: "dict".to_string(),
properties,
};
let tools_parameters = FunctionParameters { properties };
tools_definitions.push(ToolsDefinition {
name: pt.name.clone(),
description: pt.description.clone(),
parameters: tools_parameters,
chat_completion_tools.push({
ChatCompletionTool {
tool_type: ToolType::Function,
function: FunctionDefinition {
name: pt.name.clone(),
description: pt.description.clone(),
parameters: tools_parameters,
},
}
});
}
let chat_completions = ChatCompletionsRequest {
model: GPT_35_TURBO.to_string(),
messages: callout_context.request_body.messages.clone(),
tools: Some(tools_definitions),
tools: Some(chat_completion_tools),
stream: false,
stream_options: None,
};
@ -432,7 +435,9 @@ impl StreamContext {
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching HTTP call for function-call: {:?}", e);
let error_msg =
format!("Error dispatching HTTP call for function-call: {:?}", e);
return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
}
};
@ -459,33 +464,35 @@ impl StreamContext {
let boltfc_response: ChatCompletionsResponse = serde_json::from_str(&body_str).unwrap();
let boltfc_response_str = boltfc_response.choices[0].message.content.as_ref().unwrap();
let model_resp = &boltfc_response.choices[0];
let tools_call_response: BoltFCToolsCall = match serde_json::from_str(boltfc_response_str) {
Ok(fc_resp) => fc_resp,
Err(e) => {
// This means that Bolt FC did not have enough information to resolve the function call
// Bolt FC probably responded with a message asking for more information.
// Let's send the response back to the user to initalize lightweight dialog for parameter collection
if model_resp.message.tool_calls.is_none() {
// This means that Bolt FC did not have enough information to resolve the function call
// Bolt FC probably responded with a message asking for more information.
// Let's send the response back to the user to initalize lightweight dialog for parameter collection
// add resolver name to the response so the client can send the response back to the correct resolver
info!("some requred parameters are missing, sending response from Bolt FC back to user for parameter collection: {}", e);
let bolt_fc_dialogue_message = serde_json::to_string(&boltfc_response).unwrap();
self.send_http_response(
StatusCode::OK.as_u16().into(),
vec![("Powered-By", "Katanemo")],
Some(bolt_fc_dialogue_message.as_bytes()),
);
return;
}
};
//TODO: add resolver name to the response so the client can send the response back to the correct resolver
debug!("tool_call_details: {}", boltfc_response_str);
return self.send_http_response(
StatusCode::OK.as_u16().into(),
vec![("Powered-By", "Katanemo")],
Some(body_str.as_bytes()),
);
}
let tool_calls = model_resp.message.tool_calls.as_ref().unwrap();
if tool_calls.is_empty() {
return self.send_server_error(
"No tool calls found in function resolver response".to_string(),
Some(StatusCode::BAD_REQUEST),
);
}
debug!("tool_call_details: {:?}", tool_calls);
// extract all tool names
let tool_names: Vec<String> = tools_call_response
.tool_calls
let tool_names: Vec<String> = tool_calls
.iter()
.map(|tool_call| tool_call.name.clone())
.map(|tool_call| tool_call.function.name.clone())
.collect();
debug!(
@ -493,8 +500,8 @@ impl StreamContext {
callout_context.similarity_scores
);
//HACK: for now we only support one tool call, we will support multiple tool calls in the future
let tool_params = &tools_call_response.tool_calls[0].arguments;
let tools_call_name = tools_call_response.tool_calls[0].name.clone();
let tool_params = &tool_calls[0].function.arguments;
let tools_call_name = tool_calls[0].function.name.clone();
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
let prompt_target = self
@ -581,6 +588,7 @@ impl StreamContext {
role: SYSTEM_ROLE.to_string(),
content: Some(system_prompt.clone()),
model: None,
tool_calls: None,
};
messages.push(system_prompt_message);
}
@ -592,6 +600,7 @@ impl StreamContext {
role: USER_ROLE.to_string(),
content: Some(body_str),
model: None,
tool_calls: None,
}
});
@ -601,6 +610,7 @@ impl StreamContext {
role: USER_ROLE.to_string(),
content: Some(callout_context.user_message.unwrap()),
model: None,
tool_calls: None,
}
});
@ -707,7 +717,8 @@ impl StreamContext {
let json_data: String = match serde_json::to_string(&get_embeddings_input) {
Ok(json_data) => json_data,
Err(error) => {
panic!("Error serializing embeddings input: {}", error);
let error_msg = format!("Error serializing embeddings input: {}", error);
return self.send_server_error(error_msg, None);
}
};
@ -727,10 +738,11 @@ impl StreamContext {
) {
Ok(token_id) => token_id,
Err(e) => {
panic!(
let error_msg = format!(
"Error dispatching embedding server HTTP call for get-embeddings: {:?}",
e
);
return self.send_server_error(error_msg, None);
}
};
debug!(
@ -892,7 +904,9 @@ impl HttpContext for StreamContext {
let json_data: String = match serde_json::to_string(&get_prompt_guards_request) {
Ok(json_data) => json_data,
Err(error) => {
panic!("Error serializing embeddings input: {}", error);
let error_msg = format!("Error serializing prompt guard request: {}", error);
self.send_server_error(error_msg, None);
return Action::Pause;
}
};
@ -912,10 +926,12 @@ impl HttpContext for StreamContext {
) {
Ok(token_id) => token_id,
Err(e) => {
panic!(
"Error dispatching embedding server HTTP call for get-embeddings: {:?}",
let error_msg = format!(
"Error dispatching embedding server HTTP call for prompt-guard: {:?}",
e
);
self.send_server_error(error_msg, None);
return Action::Pause;
}
};

View file

@ -3,15 +3,14 @@ use proxy_wasm_test_framework::tester::{self, Tester};
use proxy_wasm_test_framework::types::{
Action, BufferType, LogLevel, MapType, MetricType, ReturnType,
};
use public_types::common_types::{
open_ai::{ChatCompletionsResponse, Choice, Message, Usage},
BoltFCToolsCall, IntOrString, ToolCallDetail,
};
use public_types::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage};
use public_types::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType};
use public_types::embeddings::embedding::Object;
use public_types::embeddings::{
create_embedding_response, CreateEmbeddingResponse, CreateEmbeddingResponseUsage, Embedding,
};
use public_types::{common_types::ZeroShotClassificationResponse, configuration::Configuration};
use serde_yaml::Value;
use serial_test::serial;
use std::collections::HashMap;
use std::path::Path;
@ -403,18 +402,6 @@ fn request_ratelimited() {
normal_flow(&mut module, filter_context, http_context);
let tool_call_detail = vec![ToolCallDetail {
name: String::from("weather_forecast"),
arguments: HashMap::from([(
String::from("city"),
IntOrString::Text(String::from("seattle")),
)]),
}];
let boltfc_tools_call = BoltFCToolsCall {
tool_calls: tool_call_detail,
};
let bolt_fc_resp = ChatCompletionsResponse {
usage: Usage {
completion_tokens: 0,
@ -424,7 +411,18 @@ fn request_ratelimited() {
index: 0,
message: Message {
role: "system".to_string(),
content: Some(serde_json::to_string(&boltfc_tools_call).unwrap()),
content: None,
tool_calls: Some(vec![ToolCall {
id: String::from("test"),
tool_type: ToolType::Function,
function: FunctionCallDetail {
name: String::from("weather_forecast"),
arguments: HashMap::from([(
String::from("city"),
Value::String(String::from("seattle")),
)]),
},
}]),
model: None,
},
}],
@ -519,18 +517,6 @@ fn request_not_ratelimited() {
normal_flow(&mut module, filter_context, http_context);
let tool_call_detail = vec![ToolCallDetail {
name: String::from("weather_forecast"),
arguments: HashMap::from([(
String::from("city"),
IntOrString::Text(String::from("seattle")),
)]),
}];
let boltfc_tools_call = BoltFCToolsCall {
tool_calls: tool_call_detail,
};
let bolt_fc_resp = ChatCompletionsResponse {
usage: Usage {
completion_tokens: 0,
@ -540,7 +526,18 @@ fn request_not_ratelimited() {
index: 0,
message: Message {
role: "system".to_string(),
content: Some(serde_json::to_string(&boltfc_tools_call).unwrap()),
content: None,
tool_calls: Some(vec![ToolCall {
id: String::from("test"),
tool_type: ToolType::Function,
function: FunctionCallDetail {
name: String::from("weather_forecast"),
arguments: HashMap::from([(
String::from("city"),
Value::String(String::from("seattle")),
)]),
},
}]),
model: None,
},
}],