mirror of
https://github.com/katanemo/plano.git
synced 2026-05-21 13:55:15 +02:00
Integrate Arch-Function-Calling-1.5B model (#85)
* add arch support * add missing file * e2e tests * delete old files and fix response * fmt
This commit is contained in:
parent
9ea6bb0d73
commit
3511798fa8
12 changed files with 203 additions and 427 deletions
|
|
@ -15,13 +15,13 @@ use log::{debug, info, warn};
|
|||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use public_types::common_types::open_ai::{
|
||||
ChatCompletionChunkResponse, ChatCompletionsRequest, ChatCompletionsResponse, Message,
|
||||
StreamOptions,
|
||||
ChatCompletionChunkResponse, ChatCompletionTool, ChatCompletionsRequest,
|
||||
ChatCompletionsResponse, FunctionDefinition, FunctionParameter, FunctionParameters, Message,
|
||||
StreamOptions, ToolType,
|
||||
};
|
||||
use public_types::common_types::{
|
||||
BoltFCToolsCall, EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
|
||||
ToolParameter, ToolParameters, ToolsDefinition, ZeroShotClassificationRequest,
|
||||
ZeroShotClassificationResponse,
|
||||
EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
|
||||
ZeroShotClassificationRequest, ZeroShotClassificationResponse,
|
||||
};
|
||||
use public_types::configuration::{Overrides, PromptGuards, PromptTarget, PromptType};
|
||||
use public_types::embeddings::{
|
||||
|
|
@ -215,7 +215,8 @@ impl StreamContext {
|
|||
let json_data: String = match serde_json::to_string(&zero_shot_classification_request) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
panic!("Error serializing zero shot request: {}", error);
|
||||
let error = format!("Error serializing zero shot request: {}", error);
|
||||
return self.send_server_error(error, None);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -235,10 +236,11 @@ impl StreamContext {
|
|||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!(
|
||||
"Error dispatching embedding server HTTP call for zero-shot-intent-detection: {:?}",
|
||||
e
|
||||
);
|
||||
let error_msg = format!(
|
||||
"Error dispatching embedding server HTTP call for zero-shot-intent-detection: {:?}",
|
||||
e
|
||||
);
|
||||
return self.send_server_error(error_msg, None);
|
||||
}
|
||||
};
|
||||
debug!(
|
||||
|
|
@ -358,16 +360,15 @@ impl StreamContext {
|
|||
|
||||
match prompt_target.prompt_type {
|
||||
PromptType::FunctionResolver => {
|
||||
let mut tools_definitions: Vec<ToolsDefinition> = Vec::new();
|
||||
|
||||
let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
|
||||
for pt in self.prompt_targets.read().unwrap().values() {
|
||||
// only extract entity names
|
||||
let properties: HashMap<String, ToolParameter> = match pt.parameters {
|
||||
let properties: HashMap<String, FunctionParameter> = match pt.parameters {
|
||||
// Clone is unavoidable here because we don't want to move the values out of the prompt target struct.
|
||||
Some(ref entities) => {
|
||||
let mut properties: HashMap<String, ToolParameter> = HashMap::new();
|
||||
let mut properties: HashMap<String, FunctionParameter> = HashMap::new();
|
||||
for entity in entities.iter() {
|
||||
let param = ToolParameter {
|
||||
let param = FunctionParameter {
|
||||
parameter_type: entity.parameter_type.clone(),
|
||||
description: entity.description.clone(),
|
||||
required: entity.required,
|
||||
|
|
@ -380,22 +381,24 @@ impl StreamContext {
|
|||
}
|
||||
None => HashMap::new(),
|
||||
};
|
||||
let tools_parameters = ToolParameters {
|
||||
parameters_type: "dict".to_string(),
|
||||
properties,
|
||||
};
|
||||
let tools_parameters = FunctionParameters { properties };
|
||||
|
||||
tools_definitions.push(ToolsDefinition {
|
||||
name: pt.name.clone(),
|
||||
description: pt.description.clone(),
|
||||
parameters: tools_parameters,
|
||||
chat_completion_tools.push({
|
||||
ChatCompletionTool {
|
||||
tool_type: ToolType::Function,
|
||||
function: FunctionDefinition {
|
||||
name: pt.name.clone(),
|
||||
description: pt.description.clone(),
|
||||
parameters: tools_parameters,
|
||||
},
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let chat_completions = ChatCompletionsRequest {
|
||||
model: GPT_35_TURBO.to_string(),
|
||||
messages: callout_context.request_body.messages.clone(),
|
||||
tools: Some(tools_definitions),
|
||||
tools: Some(chat_completion_tools),
|
||||
stream: false,
|
||||
stream_options: None,
|
||||
};
|
||||
|
|
@ -432,7 +435,9 @@ impl StreamContext {
|
|||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call for function-call: {:?}", e);
|
||||
let error_msg =
|
||||
format!("Error dispatching HTTP call for function-call: {:?}", e);
|
||||
return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -459,33 +464,35 @@ impl StreamContext {
|
|||
|
||||
let boltfc_response: ChatCompletionsResponse = serde_json::from_str(&body_str).unwrap();
|
||||
|
||||
let boltfc_response_str = boltfc_response.choices[0].message.content.as_ref().unwrap();
|
||||
let model_resp = &boltfc_response.choices[0];
|
||||
|
||||
let tools_call_response: BoltFCToolsCall = match serde_json::from_str(boltfc_response_str) {
|
||||
Ok(fc_resp) => fc_resp,
|
||||
Err(e) => {
|
||||
// This means that Bolt FC did not have enough information to resolve the function call
|
||||
// Bolt FC probably responded with a message asking for more information.
|
||||
// Let's send the response back to the user to initalize lightweight dialog for parameter collection
|
||||
if model_resp.message.tool_calls.is_none() {
|
||||
// This means that Bolt FC did not have enough information to resolve the function call
|
||||
// Bolt FC probably responded with a message asking for more information.
|
||||
// Let's send the response back to the user to initalize lightweight dialog for parameter collection
|
||||
|
||||
// add resolver name to the response so the client can send the response back to the correct resolver
|
||||
info!("some requred parameters are missing, sending response from Bolt FC back to user for parameter collection: {}", e);
|
||||
let bolt_fc_dialogue_message = serde_json::to_string(&boltfc_response).unwrap();
|
||||
self.send_http_response(
|
||||
StatusCode::OK.as_u16().into(),
|
||||
vec![("Powered-By", "Katanemo")],
|
||||
Some(bolt_fc_dialogue_message.as_bytes()),
|
||||
);
|
||||
return;
|
||||
}
|
||||
};
|
||||
//TODO: add resolver name to the response so the client can send the response back to the correct resolver
|
||||
|
||||
debug!("tool_call_details: {}", boltfc_response_str);
|
||||
return self.send_http_response(
|
||||
StatusCode::OK.as_u16().into(),
|
||||
vec![("Powered-By", "Katanemo")],
|
||||
Some(body_str.as_bytes()),
|
||||
);
|
||||
}
|
||||
|
||||
let tool_calls = model_resp.message.tool_calls.as_ref().unwrap();
|
||||
if tool_calls.is_empty() {
|
||||
return self.send_server_error(
|
||||
"No tool calls found in function resolver response".to_string(),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
}
|
||||
|
||||
debug!("tool_call_details: {:?}", tool_calls);
|
||||
// extract all tool names
|
||||
let tool_names: Vec<String> = tools_call_response
|
||||
.tool_calls
|
||||
let tool_names: Vec<String> = tool_calls
|
||||
.iter()
|
||||
.map(|tool_call| tool_call.name.clone())
|
||||
.map(|tool_call| tool_call.function.name.clone())
|
||||
.collect();
|
||||
|
||||
debug!(
|
||||
|
|
@ -493,8 +500,8 @@ impl StreamContext {
|
|||
callout_context.similarity_scores
|
||||
);
|
||||
//HACK: for now we only support one tool call, we will support multiple tool calls in the future
|
||||
let tool_params = &tools_call_response.tool_calls[0].arguments;
|
||||
let tools_call_name = tools_call_response.tool_calls[0].name.clone();
|
||||
let tool_params = &tool_calls[0].function.arguments;
|
||||
let tools_call_name = tool_calls[0].function.name.clone();
|
||||
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
|
||||
|
||||
let prompt_target = self
|
||||
|
|
@ -581,6 +588,7 @@ impl StreamContext {
|
|||
role: SYSTEM_ROLE.to_string(),
|
||||
content: Some(system_prompt.clone()),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
};
|
||||
messages.push(system_prompt_message);
|
||||
}
|
||||
|
|
@ -592,6 +600,7 @@ impl StreamContext {
|
|||
role: USER_ROLE.to_string(),
|
||||
content: Some(body_str),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
}
|
||||
});
|
||||
|
||||
|
|
@ -601,6 +610,7 @@ impl StreamContext {
|
|||
role: USER_ROLE.to_string(),
|
||||
content: Some(callout_context.user_message.unwrap()),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
}
|
||||
});
|
||||
|
||||
|
|
@ -707,7 +717,8 @@ impl StreamContext {
|
|||
let json_data: String = match serde_json::to_string(&get_embeddings_input) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
panic!("Error serializing embeddings input: {}", error);
|
||||
let error_msg = format!("Error serializing embeddings input: {}", error);
|
||||
return self.send_server_error(error_msg, None);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -727,10 +738,11 @@ impl StreamContext {
|
|||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!(
|
||||
let error_msg = format!(
|
||||
"Error dispatching embedding server HTTP call for get-embeddings: {:?}",
|
||||
e
|
||||
);
|
||||
return self.send_server_error(error_msg, None);
|
||||
}
|
||||
};
|
||||
debug!(
|
||||
|
|
@ -892,7 +904,9 @@ impl HttpContext for StreamContext {
|
|||
let json_data: String = match serde_json::to_string(&get_prompt_guards_request) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
panic!("Error serializing embeddings input: {}", error);
|
||||
let error_msg = format!("Error serializing prompt guard request: {}", error);
|
||||
self.send_server_error(error_msg, None);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -912,10 +926,12 @@ impl HttpContext for StreamContext {
|
|||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!(
|
||||
"Error dispatching embedding server HTTP call for get-embeddings: {:?}",
|
||||
let error_msg = format!(
|
||||
"Error dispatching embedding server HTTP call for prompt-guard: {:?}",
|
||||
e
|
||||
);
|
||||
self.send_server_error(error_msg, None);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -3,15 +3,14 @@ use proxy_wasm_test_framework::tester::{self, Tester};
|
|||
use proxy_wasm_test_framework::types::{
|
||||
Action, BufferType, LogLevel, MapType, MetricType, ReturnType,
|
||||
};
|
||||
use public_types::common_types::{
|
||||
open_ai::{ChatCompletionsResponse, Choice, Message, Usage},
|
||||
BoltFCToolsCall, IntOrString, ToolCallDetail,
|
||||
};
|
||||
use public_types::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage};
|
||||
use public_types::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType};
|
||||
use public_types::embeddings::embedding::Object;
|
||||
use public_types::embeddings::{
|
||||
create_embedding_response, CreateEmbeddingResponse, CreateEmbeddingResponseUsage, Embedding,
|
||||
};
|
||||
use public_types::{common_types::ZeroShotClassificationResponse, configuration::Configuration};
|
||||
use serde_yaml::Value;
|
||||
use serial_test::serial;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
|
|
@ -403,18 +402,6 @@ fn request_ratelimited() {
|
|||
|
||||
normal_flow(&mut module, filter_context, http_context);
|
||||
|
||||
let tool_call_detail = vec![ToolCallDetail {
|
||||
name: String::from("weather_forecast"),
|
||||
arguments: HashMap::from([(
|
||||
String::from("city"),
|
||||
IntOrString::Text(String::from("seattle")),
|
||||
)]),
|
||||
}];
|
||||
|
||||
let boltfc_tools_call = BoltFCToolsCall {
|
||||
tool_calls: tool_call_detail,
|
||||
};
|
||||
|
||||
let bolt_fc_resp = ChatCompletionsResponse {
|
||||
usage: Usage {
|
||||
completion_tokens: 0,
|
||||
|
|
@ -424,7 +411,18 @@ fn request_ratelimited() {
|
|||
index: 0,
|
||||
message: Message {
|
||||
role: "system".to_string(),
|
||||
content: Some(serde_json::to_string(&boltfc_tools_call).unwrap()),
|
||||
content: None,
|
||||
tool_calls: Some(vec![ToolCall {
|
||||
id: String::from("test"),
|
||||
tool_type: ToolType::Function,
|
||||
function: FunctionCallDetail {
|
||||
name: String::from("weather_forecast"),
|
||||
arguments: HashMap::from([(
|
||||
String::from("city"),
|
||||
Value::String(String::from("seattle")),
|
||||
)]),
|
||||
},
|
||||
}]),
|
||||
model: None,
|
||||
},
|
||||
}],
|
||||
|
|
@ -519,18 +517,6 @@ fn request_not_ratelimited() {
|
|||
|
||||
normal_flow(&mut module, filter_context, http_context);
|
||||
|
||||
let tool_call_detail = vec![ToolCallDetail {
|
||||
name: String::from("weather_forecast"),
|
||||
arguments: HashMap::from([(
|
||||
String::from("city"),
|
||||
IntOrString::Text(String::from("seattle")),
|
||||
)]),
|
||||
}];
|
||||
|
||||
let boltfc_tools_call = BoltFCToolsCall {
|
||||
tool_calls: tool_call_detail,
|
||||
};
|
||||
|
||||
let bolt_fc_resp = ChatCompletionsResponse {
|
||||
usage: Usage {
|
||||
completion_tokens: 0,
|
||||
|
|
@ -540,7 +526,18 @@ fn request_not_ratelimited() {
|
|||
index: 0,
|
||||
message: Message {
|
||||
role: "system".to_string(),
|
||||
content: Some(serde_json::to_string(&boltfc_tools_call).unwrap()),
|
||||
content: None,
|
||||
tool_calls: Some(vec![ToolCall {
|
||||
id: String::from("test"),
|
||||
tool_type: ToolType::Function,
|
||||
function: FunctionCallDetail {
|
||||
name: String::from("weather_forecast"),
|
||||
arguments: HashMap::from([(
|
||||
String::from("city"),
|
||||
Value::String(String::from("seattle")),
|
||||
)]),
|
||||
},
|
||||
}]),
|
||||
model: None,
|
||||
},
|
||||
}],
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue