mirror of
https://github.com/katanemo/plano.git
synced 2026-06-05 14:45:15 +02:00
Integrate Arch-Function-Calling-1.5B model (#85)
* add arch support * add missing file * e2e tests * delete old files and fix response * fmt
This commit is contained in:
parent
9ea6bb0d73
commit
3511798fa8
12 changed files with 203 additions and 427 deletions
|
|
@ -0,0 +1,21 @@
|
|||
FROM Arch-Function-Calling-1.5B-Q4_K_M.gguf
|
||||
|
||||
# Set parameters for response generation
|
||||
PARAMETER num_predict 1024
|
||||
PARAMETER temperature 0.001
|
||||
PARAMETER top_p 1.0
|
||||
PARAMETER top_k 16000
|
||||
PARAMETER repeat_penalty 1.0
|
||||
PARAMETER stop "<|im_end|>"
|
||||
|
||||
# Set the random number seed to use for generation
|
||||
PARAMETER seed 42
|
||||
|
||||
# Set the prompt template to be passed into the model
|
||||
TEMPLATE """
|
||||
{{- if .System }}<|im_start|>system
|
||||
{{ .System }}<|im_end|>
|
||||
{{ end }}{{ if .Prompt }}<|im_start|>user
|
||||
{{ .Prompt }}<|im_end|>
|
||||
{{ end }}<|im_start|>assistant
|
||||
{{ .Response }}<|im_end|>"""
|
||||
|
|
@ -1,25 +0,0 @@
|
|||
FROM Bolt-Function-Calling-1B-Q3_K_L.gguf
|
||||
|
||||
# Set the size of the context window used to generate the next token
|
||||
# PARAMETER num_ctx 16384
|
||||
PARAMETER num_ctx 4096
|
||||
|
||||
# Set parameters for response generation
|
||||
PARAMETER num_predict 1024
|
||||
PARAMETER temperature 0.1
|
||||
PARAMETER top_p 0.5
|
||||
PARAMETER top_k 32022
|
||||
PARAMETER repeat_penalty 1.0
|
||||
PARAMETER stop "<|EOT|>"
|
||||
|
||||
# Set the random number seed to use for generation
|
||||
PARAMETER seed 42
|
||||
|
||||
# Set the prompt template to be passed into the model
|
||||
TEMPLATE """{{ if .System }}<|begin▁of▁sentence|>
|
||||
{{ .System }}
|
||||
{{ end }}{{ if .Prompt }}### Instruction:
|
||||
{{ .Prompt }}
|
||||
{{ end }}### Response:
|
||||
{{ .Response }}
|
||||
<|EOT|>"""
|
||||
|
|
@ -1,24 +0,0 @@
|
|||
FROM Bolt-Function-Calling-1B-Q4_K_M.gguf
|
||||
|
||||
# Set the size of the context window used to generate the next token
|
||||
PARAMETER num_ctx 4096
|
||||
|
||||
# Set parameters for response generation
|
||||
PARAMETER num_predict 1024
|
||||
PARAMETER temperature 0.1
|
||||
PARAMETER top_p 0.5
|
||||
PARAMETER top_k 32022
|
||||
PARAMETER repeat_penalty 1.0
|
||||
PARAMETER stop "<|EOT|>"
|
||||
|
||||
# Set the random number seed to use for generation
|
||||
PARAMETER seed 42
|
||||
|
||||
# Set the prompt template to be passed into the model
|
||||
TEMPLATE """{{ if .System }}<|begin▁of▁sentence|>
|
||||
{{ .System }}
|
||||
{{ end }}{{ if .Prompt }}### Instruction:
|
||||
{{ .Prompt }}
|
||||
{{ end }}### Response:
|
||||
{{ .Response }}
|
||||
<|EOT|>"""
|
||||
|
|
@ -11,14 +11,14 @@ This demo shows how you can use intelligent prompt gateway to do function callin
|
|||
```sh
|
||||
docker compose up
|
||||
```
|
||||
1. Download Bolt-FC model. This demo assumes we have downloaded [Bolt-Function-Calling-1B:Q4_K_M](https://huggingface.co/katanemolabs/Bolt-Function-Calling-1B.gguf/blob/main/Bolt-Function-Calling-1B-Q4_K_M.gguf) to local folder.
|
||||
1. Download the Arch-FC model. This demo assumes we have downloaded [Arch-Function-Calling-1.5B:Q4_K_M](https://huggingface.co/katanemolabs/Arch-Function-Calling-1.5B.gguf/blob/main/Arch-Function-Calling-1.5B-Q4_K_M.gguf) to a local folder.
|
||||
1. If running ollama natively, run
|
||||
```sh
|
||||
ollama serve
|
||||
```
|
||||
2. Create model file in ollama repository
|
||||
```sh
|
||||
ollama create Bolt-Function-Calling-1B:Q4_K_M -f Bolt-FC-1B-Q4_K_M.model_file
|
||||
ollama create Arch-Function-Calling-1.5B:Q4_K_M -f Arch-Function-Calling-1.5B-Q4_K_M.model_file
|
||||
```
|
||||
3. Navigate to http://localhost:18080/
|
||||
4. You can type in queries like "how is the weather in Seattle"
|
||||
|
|
|
|||
|
|
@ -59,6 +59,7 @@ services:
|
|||
- OLLAMA_ENDPOINT=${OLLAMA_ENDPOINT:-host.docker.internal}
|
||||
# uncomment following line to use ollama endpoint that is hosted by docker
|
||||
# - OLLAMA_ENDPOINT=ollama
|
||||
- OLLAMA_MODEL=Arch-Function-Calling-1.5B:Q4_K_M
|
||||
|
||||
api_server:
|
||||
build:
|
||||
|
|
|
|||
|
|
@ -15,13 +15,13 @@ use log::{debug, info, warn};
|
|||
use proxy_wasm::traits::*;
|
||||
use proxy_wasm::types::*;
|
||||
use public_types::common_types::open_ai::{
|
||||
ChatCompletionChunkResponse, ChatCompletionsRequest, ChatCompletionsResponse, Message,
|
||||
StreamOptions,
|
||||
ChatCompletionChunkResponse, ChatCompletionTool, ChatCompletionsRequest,
|
||||
ChatCompletionsResponse, FunctionDefinition, FunctionParameter, FunctionParameters, Message,
|
||||
StreamOptions, ToolType,
|
||||
};
|
||||
use public_types::common_types::{
|
||||
BoltFCToolsCall, EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
|
||||
ToolParameter, ToolParameters, ToolsDefinition, ZeroShotClassificationRequest,
|
||||
ZeroShotClassificationResponse,
|
||||
EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
|
||||
ZeroShotClassificationRequest, ZeroShotClassificationResponse,
|
||||
};
|
||||
use public_types::configuration::{Overrides, PromptGuards, PromptTarget, PromptType};
|
||||
use public_types::embeddings::{
|
||||
|
|
@ -215,7 +215,8 @@ impl StreamContext {
|
|||
let json_data: String = match serde_json::to_string(&zero_shot_classification_request) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
panic!("Error serializing zero shot request: {}", error);
|
||||
let error = format!("Error serializing zero shot request: {}", error);
|
||||
return self.send_server_error(error, None);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -235,10 +236,11 @@ impl StreamContext {
|
|||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!(
|
||||
"Error dispatching embedding server HTTP call for zero-shot-intent-detection: {:?}",
|
||||
e
|
||||
);
|
||||
let error_msg = format!(
|
||||
"Error dispatching embedding server HTTP call for zero-shot-intent-detection: {:?}",
|
||||
e
|
||||
);
|
||||
return self.send_server_error(error_msg, None);
|
||||
}
|
||||
};
|
||||
debug!(
|
||||
|
|
@ -358,16 +360,15 @@ impl StreamContext {
|
|||
|
||||
match prompt_target.prompt_type {
|
||||
PromptType::FunctionResolver => {
|
||||
let mut tools_definitions: Vec<ToolsDefinition> = Vec::new();
|
||||
|
||||
let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
|
||||
for pt in self.prompt_targets.read().unwrap().values() {
|
||||
// only extract entity names
|
||||
let properties: HashMap<String, ToolParameter> = match pt.parameters {
|
||||
let properties: HashMap<String, FunctionParameter> = match pt.parameters {
|
||||
// Clone is unavoidable here because we don't want to move the values out of the prompt target struct.
|
||||
Some(ref entities) => {
|
||||
let mut properties: HashMap<String, ToolParameter> = HashMap::new();
|
||||
let mut properties: HashMap<String, FunctionParameter> = HashMap::new();
|
||||
for entity in entities.iter() {
|
||||
let param = ToolParameter {
|
||||
let param = FunctionParameter {
|
||||
parameter_type: entity.parameter_type.clone(),
|
||||
description: entity.description.clone(),
|
||||
required: entity.required,
|
||||
|
|
@ -380,22 +381,24 @@ impl StreamContext {
|
|||
}
|
||||
None => HashMap::new(),
|
||||
};
|
||||
let tools_parameters = ToolParameters {
|
||||
parameters_type: "dict".to_string(),
|
||||
properties,
|
||||
};
|
||||
let tools_parameters = FunctionParameters { properties };
|
||||
|
||||
tools_definitions.push(ToolsDefinition {
|
||||
name: pt.name.clone(),
|
||||
description: pt.description.clone(),
|
||||
parameters: tools_parameters,
|
||||
chat_completion_tools.push({
|
||||
ChatCompletionTool {
|
||||
tool_type: ToolType::Function,
|
||||
function: FunctionDefinition {
|
||||
name: pt.name.clone(),
|
||||
description: pt.description.clone(),
|
||||
parameters: tools_parameters,
|
||||
},
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let chat_completions = ChatCompletionsRequest {
|
||||
model: GPT_35_TURBO.to_string(),
|
||||
messages: callout_context.request_body.messages.clone(),
|
||||
tools: Some(tools_definitions),
|
||||
tools: Some(chat_completion_tools),
|
||||
stream: false,
|
||||
stream_options: None,
|
||||
};
|
||||
|
|
@ -432,7 +435,9 @@ impl StreamContext {
|
|||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!("Error dispatching HTTP call for function-call: {:?}", e);
|
||||
let error_msg =
|
||||
format!("Error dispatching HTTP call for function-call: {:?}", e);
|
||||
return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -459,33 +464,35 @@ impl StreamContext {
|
|||
|
||||
let boltfc_response: ChatCompletionsResponse = serde_json::from_str(&body_str).unwrap();
|
||||
|
||||
let boltfc_response_str = boltfc_response.choices[0].message.content.as_ref().unwrap();
|
||||
let model_resp = &boltfc_response.choices[0];
|
||||
|
||||
let tools_call_response: BoltFCToolsCall = match serde_json::from_str(boltfc_response_str) {
|
||||
Ok(fc_resp) => fc_resp,
|
||||
Err(e) => {
|
||||
// This means that Bolt FC did not have enough information to resolve the function call
|
||||
// Bolt FC probably responded with a message asking for more information.
|
||||
// Let's send the response back to the user to initialize lightweight dialog for parameter collection
|
||||
if model_resp.message.tool_calls.is_none() {
|
||||
// This means that Bolt FC did not have enough information to resolve the function call
|
||||
// Bolt FC probably responded with a message asking for more information.
|
||||
// Let's send the response back to the user to initialize lightweight dialog for parameter collection
|
||||
|
||||
// add resolver name to the response so the client can send the response back to the correct resolver
|
||||
info!("some requred parameters are missing, sending response from Bolt FC back to user for parameter collection: {}", e);
|
||||
let bolt_fc_dialogue_message = serde_json::to_string(&boltfc_response).unwrap();
|
||||
self.send_http_response(
|
||||
StatusCode::OK.as_u16().into(),
|
||||
vec![("Powered-By", "Katanemo")],
|
||||
Some(bolt_fc_dialogue_message.as_bytes()),
|
||||
);
|
||||
return;
|
||||
}
|
||||
};
|
||||
//TODO: add resolver name to the response so the client can send the response back to the correct resolver
|
||||
|
||||
debug!("tool_call_details: {}", boltfc_response_str);
|
||||
return self.send_http_response(
|
||||
StatusCode::OK.as_u16().into(),
|
||||
vec![("Powered-By", "Katanemo")],
|
||||
Some(body_str.as_bytes()),
|
||||
);
|
||||
}
|
||||
|
||||
let tool_calls = model_resp.message.tool_calls.as_ref().unwrap();
|
||||
if tool_calls.is_empty() {
|
||||
return self.send_server_error(
|
||||
"No tool calls found in function resolver response".to_string(),
|
||||
Some(StatusCode::BAD_REQUEST),
|
||||
);
|
||||
}
|
||||
|
||||
debug!("tool_call_details: {:?}", tool_calls);
|
||||
// extract all tool names
|
||||
let tool_names: Vec<String> = tools_call_response
|
||||
.tool_calls
|
||||
let tool_names: Vec<String> = tool_calls
|
||||
.iter()
|
||||
.map(|tool_call| tool_call.name.clone())
|
||||
.map(|tool_call| tool_call.function.name.clone())
|
||||
.collect();
|
||||
|
||||
debug!(
|
||||
|
|
@ -493,8 +500,8 @@ impl StreamContext {
|
|||
callout_context.similarity_scores
|
||||
);
|
||||
//HACK: for now we only support one tool call, we will support multiple tool calls in the future
|
||||
let tool_params = &tools_call_response.tool_calls[0].arguments;
|
||||
let tools_call_name = tools_call_response.tool_calls[0].name.clone();
|
||||
let tool_params = &tool_calls[0].function.arguments;
|
||||
let tools_call_name = tool_calls[0].function.name.clone();
|
||||
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
|
||||
|
||||
let prompt_target = self
|
||||
|
|
@ -581,6 +588,7 @@ impl StreamContext {
|
|||
role: SYSTEM_ROLE.to_string(),
|
||||
content: Some(system_prompt.clone()),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
};
|
||||
messages.push(system_prompt_message);
|
||||
}
|
||||
|
|
@ -592,6 +600,7 @@ impl StreamContext {
|
|||
role: USER_ROLE.to_string(),
|
||||
content: Some(body_str),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
}
|
||||
});
|
||||
|
||||
|
|
@ -601,6 +610,7 @@ impl StreamContext {
|
|||
role: USER_ROLE.to_string(),
|
||||
content: Some(callout_context.user_message.unwrap()),
|
||||
model: None,
|
||||
tool_calls: None,
|
||||
}
|
||||
});
|
||||
|
||||
|
|
@ -707,7 +717,8 @@ impl StreamContext {
|
|||
let json_data: String = match serde_json::to_string(&get_embeddings_input) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
panic!("Error serializing embeddings input: {}", error);
|
||||
let error_msg = format!("Error serializing embeddings input: {}", error);
|
||||
return self.send_server_error(error_msg, None);
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -727,10 +738,11 @@ impl StreamContext {
|
|||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!(
|
||||
let error_msg = format!(
|
||||
"Error dispatching embedding server HTTP call for get-embeddings: {:?}",
|
||||
e
|
||||
);
|
||||
return self.send_server_error(error_msg, None);
|
||||
}
|
||||
};
|
||||
debug!(
|
||||
|
|
@ -892,7 +904,9 @@ impl HttpContext for StreamContext {
|
|||
let json_data: String = match serde_json::to_string(&get_prompt_guards_request) {
|
||||
Ok(json_data) => json_data,
|
||||
Err(error) => {
|
||||
panic!("Error serializing embeddings input: {}", error);
|
||||
let error_msg = format!("Error serializing prompt guard request: {}", error);
|
||||
self.send_server_error(error_msg, None);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -912,10 +926,12 @@ impl HttpContext for StreamContext {
|
|||
) {
|
||||
Ok(token_id) => token_id,
|
||||
Err(e) => {
|
||||
panic!(
|
||||
"Error dispatching embedding server HTTP call for get-embeddings: {:?}",
|
||||
let error_msg = format!(
|
||||
"Error dispatching embedding server HTTP call for prompt-guard: {:?}",
|
||||
e
|
||||
);
|
||||
self.send_server_error(error_msg, None);
|
||||
return Action::Pause;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
|
|||
|
|
@ -3,15 +3,14 @@ use proxy_wasm_test_framework::tester::{self, Tester};
|
|||
use proxy_wasm_test_framework::types::{
|
||||
Action, BufferType, LogLevel, MapType, MetricType, ReturnType,
|
||||
};
|
||||
use public_types::common_types::{
|
||||
open_ai::{ChatCompletionsResponse, Choice, Message, Usage},
|
||||
BoltFCToolsCall, IntOrString, ToolCallDetail,
|
||||
};
|
||||
use public_types::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage};
|
||||
use public_types::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType};
|
||||
use public_types::embeddings::embedding::Object;
|
||||
use public_types::embeddings::{
|
||||
create_embedding_response, CreateEmbeddingResponse, CreateEmbeddingResponseUsage, Embedding,
|
||||
};
|
||||
use public_types::{common_types::ZeroShotClassificationResponse, configuration::Configuration};
|
||||
use serde_yaml::Value;
|
||||
use serial_test::serial;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
|
|
@ -403,18 +402,6 @@ fn request_ratelimited() {
|
|||
|
||||
normal_flow(&mut module, filter_context, http_context);
|
||||
|
||||
let tool_call_detail = vec![ToolCallDetail {
|
||||
name: String::from("weather_forecast"),
|
||||
arguments: HashMap::from([(
|
||||
String::from("city"),
|
||||
IntOrString::Text(String::from("seattle")),
|
||||
)]),
|
||||
}];
|
||||
|
||||
let boltfc_tools_call = BoltFCToolsCall {
|
||||
tool_calls: tool_call_detail,
|
||||
};
|
||||
|
||||
let bolt_fc_resp = ChatCompletionsResponse {
|
||||
usage: Usage {
|
||||
completion_tokens: 0,
|
||||
|
|
@ -424,7 +411,18 @@ fn request_ratelimited() {
|
|||
index: 0,
|
||||
message: Message {
|
||||
role: "system".to_string(),
|
||||
content: Some(serde_json::to_string(&boltfc_tools_call).unwrap()),
|
||||
content: None,
|
||||
tool_calls: Some(vec![ToolCall {
|
||||
id: String::from("test"),
|
||||
tool_type: ToolType::Function,
|
||||
function: FunctionCallDetail {
|
||||
name: String::from("weather_forecast"),
|
||||
arguments: HashMap::from([(
|
||||
String::from("city"),
|
||||
Value::String(String::from("seattle")),
|
||||
)]),
|
||||
},
|
||||
}]),
|
||||
model: None,
|
||||
},
|
||||
}],
|
||||
|
|
@ -519,18 +517,6 @@ fn request_not_ratelimited() {
|
|||
|
||||
normal_flow(&mut module, filter_context, http_context);
|
||||
|
||||
let tool_call_detail = vec![ToolCallDetail {
|
||||
name: String::from("weather_forecast"),
|
||||
arguments: HashMap::from([(
|
||||
String::from("city"),
|
||||
IntOrString::Text(String::from("seattle")),
|
||||
)]),
|
||||
}];
|
||||
|
||||
let boltfc_tools_call = BoltFCToolsCall {
|
||||
tool_calls: tool_call_detail,
|
||||
};
|
||||
|
||||
let bolt_fc_resp = ChatCompletionsResponse {
|
||||
usage: Usage {
|
||||
completion_tokens: 0,
|
||||
|
|
@ -540,7 +526,18 @@ fn request_not_ratelimited() {
|
|||
index: 0,
|
||||
message: Message {
|
||||
role: "system".to_string(),
|
||||
content: Some(serde_json::to_string(&boltfc_tools_call).unwrap()),
|
||||
content: None,
|
||||
tool_calls: Some(vec![ToolCall {
|
||||
id: String::from("test"),
|
||||
tool_type: ToolType::Function,
|
||||
function: FunctionCallDetail {
|
||||
name: String::from("weather_forecast"),
|
||||
arguments: HashMap::from([(
|
||||
String::from("city"),
|
||||
Value::String(String::from("seattle")),
|
||||
)]),
|
||||
},
|
||||
}]),
|
||||
model: None,
|
||||
},
|
||||
}],
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ class ArchHandler:
|
|||
|
||||
def _format_system(self, tools: List[Dict[str, Any]]):
|
||||
def convert_tools(tools):
|
||||
return "\n".join([json.dumps(tool) for tool in tools])
|
||||
return "\n".join([json.dumps(tool["function"]) for tool in tools])
|
||||
|
||||
tool_text = convert_tools(tools)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,225 +0,0 @@
|
|||
import json
|
||||
from typing import Any, Dict, List
|
||||
|
||||
|
||||
SYSTEM_PROMPT = """
|
||||
[BEGIN OF TASK INSTRUCTION]
|
||||
You are a function calling assistant with access to the following tools. You task is to assist users as best as you can.
|
||||
For each user query, you may need to call one or more functions to to better generate responses.
|
||||
If none of the functions are relevant, you should point it out.
|
||||
If the given query lacks the parameters required by the function, you should ask users for clarification.
|
||||
The users may execute functions and return results as `Observation` to you. In the case, you MUST generate responses by summarizing it.
|
||||
[END OF TASK INSTRUCTION]
|
||||
""".strip()
|
||||
|
||||
TOOL_PROMPT = """
|
||||
[BEGIN OF AVAILABLE TOOLS]
|
||||
{tool_text}
|
||||
[END OF AVAILABLE TOOLS]
|
||||
""".strip()
|
||||
|
||||
FORMAT_PROMPT = """
|
||||
[BEGIN OF FORMAT INSTRUCTION]
|
||||
You MUST use the following JSON format if using tools.
|
||||
The example format is as follows. DO NOT use this format if no function call is needed.
|
||||
```
|
||||
{
|
||||
"tool_calls": [
|
||||
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
|
||||
... (more tool calls as required)
|
||||
]
|
||||
}
|
||||
```
|
||||
[END OF FORMAT INSTRUCTION]
|
||||
""".strip()
|
||||
|
||||
|
||||
class BoltHandler:
|
||||
def _format_system(self, tools: List[Dict[str, Any]]):
|
||||
tool_text = self._format_tools(tools=tools)
|
||||
return (
|
||||
SYSTEM_PROMPT
|
||||
+ "\n\n"
|
||||
+ TOOL_PROMPT.format(tool_text=tool_text)
|
||||
+ "\n\n"
|
||||
+ FORMAT_PROMPT
|
||||
+ "\n"
|
||||
)
|
||||
|
||||
def _format_tools(self, tools: List[Dict[str, Any]]):
|
||||
TOOL_DESC = "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}"
|
||||
|
||||
tool_text = []
|
||||
for tool in tools:
|
||||
param_text = self.get_param_text(tool.parameters)
|
||||
tool_text.append(
|
||||
TOOL_DESC.format(
|
||||
name=tool.name, desc=tool.description, args=param_text
|
||||
)
|
||||
)
|
||||
|
||||
return "\n".join(tool_text)
|
||||
|
||||
def extract_tools(self, content, executable=False):
|
||||
# retrieve `tool_calls` from model responses
|
||||
try:
|
||||
content_json = json.loads(content)
|
||||
except Exception:
|
||||
fixed_content = self.fix_json_string(content)
|
||||
try:
|
||||
content_json = json.loads(fixed_content)
|
||||
except json.JSONDecodeError:
|
||||
return content
|
||||
|
||||
if isinstance(content_json, list):
|
||||
tool_calls = content_json
|
||||
elif isinstance(content_json, dict):
|
||||
tool_calls = content_json.get("tool_calls", [])
|
||||
else:
|
||||
tool_calls = []
|
||||
|
||||
if not isinstance(tool_calls, list):
|
||||
return content
|
||||
|
||||
# process and extract tools from `tool_calls`
|
||||
extracted = []
|
||||
|
||||
for tool_call in tool_calls:
|
||||
if isinstance(tool_call, dict):
|
||||
try:
|
||||
if not executable:
|
||||
extracted.append({tool_call["name"]: tool_call["arguments"]})
|
||||
else:
|
||||
name, arguments = (
|
||||
tool_call.get("name", ""),
|
||||
tool_call.get("arguments", {}),
|
||||
)
|
||||
|
||||
for key, value in arguments.items():
|
||||
if value == "False" or value == "false":
|
||||
arguments[key] = False
|
||||
elif value == "True" or value == "true":
|
||||
arguments[key] = True
|
||||
|
||||
args_str = ", ".join(
|
||||
[f"{key}={repr(value)}" for key, value in arguments.items()]
|
||||
)
|
||||
|
||||
extracted.append(f"{name}({args_str})")
|
||||
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return extracted
|
||||
|
||||
def get_param_text(self, parameter_dict, prefix=""):
|
||||
param_text = ""
|
||||
|
||||
for name, param in parameter_dict["properties"].items():
|
||||
param_type = param.get("type", "")
|
||||
|
||||
required, default, param_format, properties, enum, items = (
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
"",
|
||||
)
|
||||
|
||||
if name in parameter_dict.get("required", []):
|
||||
required = ", required"
|
||||
|
||||
required_param = parameter_dict.get("required", [])
|
||||
|
||||
if isinstance(required_param, bool):
|
||||
required = ", required" if required_param else ""
|
||||
elif isinstance(required_param, list) and name in required_param:
|
||||
required = ", required"
|
||||
else:
|
||||
required = ", optional"
|
||||
|
||||
default_param = param.get("default", None)
|
||||
if default_param:
|
||||
default = f", default: {default_param}"
|
||||
|
||||
format_in = param.get("format", None)
|
||||
if format_in:
|
||||
param_format = f", format: {format_in}"
|
||||
|
||||
desc = param.get("description", "")
|
||||
|
||||
if "properties" in param:
|
||||
arg_properties = self.get_param_text(param, prefix + " ")
|
||||
properties += "with the properties:\n{}".format(arg_properties)
|
||||
|
||||
enum_param = param.get("enum", None)
|
||||
if enum_param:
|
||||
enum = "should be one of [{}]".format(", ".join(enum_param))
|
||||
|
||||
item_param = param.get("items", None)
|
||||
if item_param:
|
||||
item_type = item_param.get("type", None)
|
||||
if item_type:
|
||||
items += "each item should be the {} type ".format(item_type)
|
||||
|
||||
item_properties = item_param.get("properties", None)
|
||||
if item_properties:
|
||||
item_properties = self.get_param_text(item_param, prefix + " ")
|
||||
items += "with the properties:\n{}".format(item_properties)
|
||||
|
||||
illustration = ", ".join(
|
||||
[x for x in [desc, properties, enum, items] if len(x)]
|
||||
)
|
||||
|
||||
param_text += (
|
||||
prefix
|
||||
+ "- {name} ({param_type}{required}{param_format}{default}): {illustration}\n".format(
|
||||
name=name,
|
||||
param_type=param_type,
|
||||
required=required,
|
||||
param_format=param_format,
|
||||
default=default,
|
||||
illustration=illustration,
|
||||
)
|
||||
)
|
||||
|
||||
return param_text
|
||||
|
||||
def fix_json_string(self, json_str):
|
||||
# Remove any leading or trailing whitespace or newline characters
|
||||
json_str = json_str.strip()
|
||||
|
||||
# Stack to keep track of brackets
|
||||
stack = []
|
||||
|
||||
# Clean string to collect valid characters
|
||||
fixed_str = ""
|
||||
|
||||
# Dictionary for matching brackets
|
||||
matching_bracket = {")": "(", "}": "{", "]": "["}
|
||||
|
||||
# Dictionary for the opposite of matching_bracket
|
||||
opening_bracket = {v: k for k, v in matching_bracket.items()}
|
||||
|
||||
for char in json_str:
|
||||
if char in "{[(":
|
||||
stack.append(char)
|
||||
fixed_str += char
|
||||
elif char in "}])":
|
||||
if stack and stack[-1] == matching_bracket[char]:
|
||||
stack.pop()
|
||||
fixed_str += char
|
||||
else:
|
||||
# Ignore the unmatched closing brackets
|
||||
continue
|
||||
else:
|
||||
fixed_str += char
|
||||
|
||||
# If there are unmatched opening brackets left in the stack, add corresponding closing brackets
|
||||
while stack:
|
||||
unmatched_opening = stack.pop()
|
||||
fixed_str += opening_bracket[unmatched_opening]
|
||||
|
||||
# Attempt to parse the corrected string to ensure it’s valid JSON
|
||||
return fixed_str
|
||||
|
|
@ -1,14 +1,10 @@
|
|||
from typing import Any, Dict, List
|
||||
from pydantic import BaseModel
|
||||
|
||||
class Tool(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
parameters: dict
|
||||
|
||||
class Message(BaseModel):
|
||||
role: str
|
||||
content: str
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
messages: list[Message]
|
||||
tools: list[Tool]
|
||||
tools: List[Dict[str, Any]]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import json
|
||||
import random
|
||||
from fastapi import FastAPI, Response
|
||||
from bolt_handler import BoltHandler
|
||||
from arch_handler import ArchHandler
|
||||
from common import ChatMessage
|
||||
import logging
|
||||
|
|
@ -8,15 +8,15 @@ from openai import OpenAI
|
|||
import os
|
||||
|
||||
ollama_endpoint = os.getenv("OLLAMA_ENDPOINT", "localhost")
|
||||
ollama_model = os.getenv("OLLAMA_MODEL", "Bolt-Function-Calling-1B:Q4_K_M")
|
||||
# Name of the Ollama model to query; overridable through the OLLAMA_MODEL
# environment variable. The default uses Ollama's `name:tag` form (colon before
# the quantization tag) so it matches the model created with
# `ollama create Arch-Function-Calling-1.5B:Q4_K_M`.
ollama_model = os.getenv("OLLAMA_MODEL", "Arch-Function-Calling-1.5B:Q4_K_M")
|
||||
logger = logging.getLogger('uvicorn.error')
|
||||
|
||||
logger.info(f"using model: {ollama_model}")
|
||||
logger.info(f"using ollama endpoint: {ollama_endpoint}")
|
||||
|
||||
app = FastAPI()
|
||||
bolt_handler = BoltHandler()
|
||||
arch_handler = ArchHandler()
|
||||
|
||||
handler = ArchHandler()
|
||||
|
||||
client = OpenAI(
|
||||
base_url='http://{}:11434/v1/'.format(ollama_endpoint),
|
||||
|
|
@ -35,10 +35,6 @@ async def healthz():
|
|||
@app.post("/v1/chat/completions")
|
||||
async def chat_completion(req: ChatMessage, res: Response):
|
||||
logger.info("starting request")
|
||||
if ollama_model.startswith("Bolt"):
|
||||
handler = bolt_handler
|
||||
else:
|
||||
handler = arch_handler
|
||||
tools_encoded = handler._format_system(req.tools)
|
||||
# append system prompt with tools to messages
|
||||
messages = [{"role": "system", "content": tools_encoded}]
|
||||
|
|
@ -46,5 +42,21 @@ async def chat_completion(req: ChatMessage, res: Response):
|
|||
messages.append({"role": message.role, "content": message.content})
|
||||
logger.info(f"request model: {ollama_model}, messages: {json.dumps(messages)}")
|
||||
resp = client.chat.completions.create(messages=messages, model=ollama_model, stream=False)
|
||||
logger.info(f"response: {resp.to_json()}")
|
||||
tools = handler.extract_tools(resp.choices[0].message.content)
|
||||
tool_calls = []
|
||||
for tool in tools:
|
||||
for tool_name, tool_args in tool.items():
|
||||
tool_calls.append({
|
||||
"id": f"call_{random.randint(1000, 10000)}",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_name,
|
||||
"arguments": tool_args
|
||||
}
|
||||
})
|
||||
if tools:
|
||||
resp.choices[0].message.tool_calls = tool_calls
|
||||
resp.choices[0].message.content = None
|
||||
logger.info(f"response (tools): {json.dumps(tools)}")
|
||||
logger.info(f"response: {json.dumps(resp.to_dict())}")
|
||||
return resp
|
||||
|
|
|
|||
|
|
@ -33,58 +33,11 @@ pub struct SearchPointResult {
|
|||
pub payload: HashMap<String, String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolParameter {
|
||||
#[serde(rename = "type")]
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub parameter_type: Option<String>,
|
||||
pub description: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub required: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[serde(rename = "enum")]
|
||||
pub enum_values: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub default: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolParameters {
|
||||
#[serde(rename = "type")]
|
||||
pub parameters_type: String,
|
||||
pub properties: HashMap<String, ToolParameter>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolsDefinition {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: ToolParameters,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum IntOrString {
|
||||
Integer(i32),
|
||||
Text(String),
|
||||
Float(f64),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCallDetail {
|
||||
pub name: String,
|
||||
pub arguments: HashMap<String, IntOrString>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BoltFCToolsCall {
|
||||
pub tool_calls: Vec<ToolCallDetail>,
|
||||
}
|
||||
|
||||
pub mod open_ai {
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
use super::ToolsDefinition;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_yaml::Value;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionsRequest {
|
||||
|
|
@ -92,13 +45,52 @@ pub mod open_ai {
|
|||
pub model: String,
|
||||
pub messages: Vec<Message>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tools: Option<Vec<ToolsDefinition>>,
|
||||
pub tools: Option<Vec<ChatCompletionTool>>,
|
||||
#[serde(default)]
|
||||
pub stream: bool,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub stream_options: Option<StreamOptions>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub enum ToolType {
|
||||
#[serde(rename = "function")]
|
||||
Function,
|
||||
}
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionTool {
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: ToolType,
|
||||
pub function: FunctionDefinition,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FunctionDefinition {
|
||||
pub name: String,
|
||||
pub description: String,
|
||||
pub parameters: FunctionParameters,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FunctionParameters {
|
||||
pub properties: HashMap<String, FunctionParameter>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FunctionParameter {
|
||||
#[serde(rename = "type")]
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub parameter_type: Option<String>,
|
||||
pub description: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub required: Option<bool>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
#[serde(rename = "enum")]
|
||||
pub enum_values: Option<Vec<String>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub default: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct StreamOptions {
|
||||
pub include_usage: bool,
|
||||
|
|
@ -110,6 +102,7 @@ pub mod open_ai {
|
|||
pub content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model: Option<String>,
|
||||
pub tool_calls: Option<Vec<ToolCall>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -119,6 +112,20 @@ pub mod open_ai {
|
|||
pub message: Message,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCall {
|
||||
pub id: String,
|
||||
#[serde(rename = "type")]
|
||||
pub tool_type: ToolType,
|
||||
pub function: FunctionCallDetail,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct FunctionCallDetail {
|
||||
pub name: String,
|
||||
pub arguments: HashMap<String, Value>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatCompletionsResponse {
|
||||
pub usage: Usage,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue