Integrate Arch-Function-Calling-1.5B model (#85)

* add arch support

* add missing file

* e2e tests

* delete old files and fix response

* fmt
This commit is contained in:
Adil Hafeez 2024-09-25 23:30:50 -07:00 committed by GitHub
parent 9ea6bb0d73
commit 3511798fa8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
12 changed files with 203 additions and 427 deletions

View file

@ -0,0 +1,21 @@
# Ollama Modelfile for the Arch-Function-Calling-1.5B model (Q4_K_M GGUF weights).
FROM Arch-Function-Calling-1.5B-Q4_K_M.gguf
# Set parameters for response generation
# (up to 1024 generated tokens, near-zero temperature, generation stops at the <|im_end|> marker)
PARAMETER num_predict 1024
PARAMETER temperature 0.001
PARAMETER top_p 1.0
PARAMETER top_k 16000
PARAMETER repeat_penalty 1.0
PARAMETER stop "<|im_end|>"
# Set the random number seed to use for generation
PARAMETER seed 42
# Set the prompt template to be passed into the model
# Each turn is wrapped as <|im_start|>ROLE ... <|im_end|>; the system and user
# turns are emitted only when .System / .Prompt are set.
TEMPLATE """
{{- if .System }}<|im_start|>system
{{ .System }}<|im_end|>
{{ end }}{{ if .Prompt }}<|im_start|>user
{{ .Prompt }}<|im_end|>
{{ end }}<|im_start|>assistant
{{ .Response }}<|im_end|>"""

View file

@ -1,25 +0,0 @@
FROM Bolt-Function-Calling-1B-Q3_K_L.gguf
# Set the size of the context window used to generate the next token
# PARAMETER num_ctx 16384
PARAMETER num_ctx 4096
# Set parameters for response generation
PARAMETER num_predict 1024
PARAMETER temperature 0.1
PARAMETER top_p 0.5
PARAMETER top_k 32022
PARAMETER repeat_penalty 1.0
PARAMETER stop "<|EOT|>"
# Set the random number seed to use for generation
PARAMETER seed 42
# Set the prompt template to be passed into the model
TEMPLATE """{{ if .System }}<begin▁of▁sentence>
{{ .System }}
{{ end }}{{ if .Prompt }}### Instruction:
{{ .Prompt }}
{{ end }}### Response:
{{ .Response }}
<|EOT|>"""

View file

@ -1,24 +0,0 @@
FROM Bolt-Function-Calling-1B-Q4_K_M.gguf
# Set the size of the context window used to generate the next token
PARAMETER num_ctx 4096
# Set parameters for response generation
PARAMETER num_predict 1024
PARAMETER temperature 0.1
PARAMETER top_p 0.5
PARAMETER top_k 32022
PARAMETER repeat_penalty 1.0
PARAMETER stop "<|EOT|>"
# Set the random number seed to use for generation
PARAMETER seed 42
# Set the prompt template to be passed into the model
TEMPLATE """{{ if .System }}<begin▁of▁sentence>
{{ .System }}
{{ end }}{{ if .Prompt }}### Instruction:
{{ .Prompt }}
{{ end }}### Response:
{{ .Response }}
<|EOT|>"""

View file

@ -11,14 +11,14 @@ This demo shows how you can use intelligent prompt gateway to do function callin
```sh
docker compose up
```
1. Download Bolt-FC model. This demo assumes we have downloaded [Bolt-Function-Calling-1B:Q4_K_M](https://huggingface.co/katanemolabs/Bolt-Function-Calling-1B.gguf/blob/main/Bolt-Function-Calling-1B-Q4_K_M.gguf) to local folder.
1. Download the Arch-FC model. This demo assumes we have downloaded [Arch-Function-Calling-1.5B:Q4_K_M](https://huggingface.co/katanemolabs/Arch-Function-Calling-1.5B.gguf/blob/main/Arch-Function-Calling-1.5B-Q4_K_M.gguf) to a local folder.
1. If running ollama natively run
```sh
ollama serve
```
2. Create model file in ollama repository
```sh
ollama create Bolt-Function-Calling-1B:Q4_K_M -f Bolt-FC-1B-Q4_K_M.model_file
ollama create Arch-Function-Calling-1.5B:Q4_K_M -f Arch-Function-Calling-1.5B-Q4_K_M.model_file
```
3. Navigate to http://localhost:18080/
4. You can type in queries like "how is the weather in Seattle"

View file

@ -59,6 +59,7 @@ services:
- OLLAMA_ENDPOINT=${OLLAMA_ENDPOINT:-host.docker.internal}
# uncomment following line to use ollama endpoint that is hosted by docker
# - OLLAMA_ENDPOINT=ollama
- OLLAMA_MODEL=Arch-Function-Calling-1.5B:Q4_K_M
api_server:
build:

View file

@ -15,13 +15,13 @@ use log::{debug, info, warn};
use proxy_wasm::traits::*;
use proxy_wasm::types::*;
use public_types::common_types::open_ai::{
ChatCompletionChunkResponse, ChatCompletionsRequest, ChatCompletionsResponse, Message,
StreamOptions,
ChatCompletionChunkResponse, ChatCompletionTool, ChatCompletionsRequest,
ChatCompletionsResponse, FunctionDefinition, FunctionParameter, FunctionParameters, Message,
StreamOptions, ToolType,
};
use public_types::common_types::{
BoltFCToolsCall, EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
ToolParameter, ToolParameters, ToolsDefinition, ZeroShotClassificationRequest,
ZeroShotClassificationResponse,
EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
ZeroShotClassificationRequest, ZeroShotClassificationResponse,
};
use public_types::configuration::{Overrides, PromptGuards, PromptTarget, PromptType};
use public_types::embeddings::{
@ -215,7 +215,8 @@ impl StreamContext {
let json_data: String = match serde_json::to_string(&zero_shot_classification_request) {
Ok(json_data) => json_data,
Err(error) => {
panic!("Error serializing zero shot request: {}", error);
let error = format!("Error serializing zero shot request: {}", error);
return self.send_server_error(error, None);
}
};
@ -235,10 +236,11 @@ impl StreamContext {
) {
Ok(token_id) => token_id,
Err(e) => {
panic!(
"Error dispatching embedding server HTTP call for zero-shot-intent-detection: {:?}",
e
);
let error_msg = format!(
"Error dispatching embedding server HTTP call for zero-shot-intent-detection: {:?}",
e
);
return self.send_server_error(error_msg, None);
}
};
debug!(
@ -358,16 +360,15 @@ impl StreamContext {
match prompt_target.prompt_type {
PromptType::FunctionResolver => {
let mut tools_definitions: Vec<ToolsDefinition> = Vec::new();
let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
for pt in self.prompt_targets.read().unwrap().values() {
// only extract entity names
let properties: HashMap<String, ToolParameter> = match pt.parameters {
let properties: HashMap<String, FunctionParameter> = match pt.parameters {
// Clone is unavoidable here because we don't want to move the values out of the prompt target struct.
Some(ref entities) => {
let mut properties: HashMap<String, ToolParameter> = HashMap::new();
let mut properties: HashMap<String, FunctionParameter> = HashMap::new();
for entity in entities.iter() {
let param = ToolParameter {
let param = FunctionParameter {
parameter_type: entity.parameter_type.clone(),
description: entity.description.clone(),
required: entity.required,
@ -380,22 +381,24 @@ impl StreamContext {
}
None => HashMap::new(),
};
let tools_parameters = ToolParameters {
parameters_type: "dict".to_string(),
properties,
};
let tools_parameters = FunctionParameters { properties };
tools_definitions.push(ToolsDefinition {
name: pt.name.clone(),
description: pt.description.clone(),
parameters: tools_parameters,
chat_completion_tools.push({
ChatCompletionTool {
tool_type: ToolType::Function,
function: FunctionDefinition {
name: pt.name.clone(),
description: pt.description.clone(),
parameters: tools_parameters,
},
}
});
}
let chat_completions = ChatCompletionsRequest {
model: GPT_35_TURBO.to_string(),
messages: callout_context.request_body.messages.clone(),
tools: Some(tools_definitions),
tools: Some(chat_completion_tools),
stream: false,
stream_options: None,
};
@ -432,7 +435,9 @@ impl StreamContext {
) {
Ok(token_id) => token_id,
Err(e) => {
panic!("Error dispatching HTTP call for function-call: {:?}", e);
let error_msg =
format!("Error dispatching HTTP call for function-call: {:?}", e);
return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
}
};
@ -459,33 +464,35 @@ impl StreamContext {
let boltfc_response: ChatCompletionsResponse = serde_json::from_str(&body_str).unwrap();
let boltfc_response_str = boltfc_response.choices[0].message.content.as_ref().unwrap();
let model_resp = &boltfc_response.choices[0];
let tools_call_response: BoltFCToolsCall = match serde_json::from_str(boltfc_response_str) {
Ok(fc_resp) => fc_resp,
Err(e) => {
// This means that Bolt FC did not have enough information to resolve the function call
// Bolt FC probably responded with a message asking for more information.
// Let's send the response back to the user to initalize lightweight dialog for parameter collection
if model_resp.message.tool_calls.is_none() {
// This means that Bolt FC did not have enough information to resolve the function call
// Bolt FC probably responded with a message asking for more information.
// Let's send the response back to the user to initialize a lightweight dialog for parameter collection
// add resolver name to the response so the client can send the response back to the correct resolver
info!("some requred parameters are missing, sending response from Bolt FC back to user for parameter collection: {}", e);
let bolt_fc_dialogue_message = serde_json::to_string(&boltfc_response).unwrap();
self.send_http_response(
StatusCode::OK.as_u16().into(),
vec![("Powered-By", "Katanemo")],
Some(bolt_fc_dialogue_message.as_bytes()),
);
return;
}
};
//TODO: add resolver name to the response so the client can send the response back to the correct resolver
debug!("tool_call_details: {}", boltfc_response_str);
return self.send_http_response(
StatusCode::OK.as_u16().into(),
vec![("Powered-By", "Katanemo")],
Some(body_str.as_bytes()),
);
}
let tool_calls = model_resp.message.tool_calls.as_ref().unwrap();
if tool_calls.is_empty() {
return self.send_server_error(
"No tool calls found in function resolver response".to_string(),
Some(StatusCode::BAD_REQUEST),
);
}
debug!("tool_call_details: {:?}", tool_calls);
// extract all tool names
let tool_names: Vec<String> = tools_call_response
.tool_calls
let tool_names: Vec<String> = tool_calls
.iter()
.map(|tool_call| tool_call.name.clone())
.map(|tool_call| tool_call.function.name.clone())
.collect();
debug!(
@ -493,8 +500,8 @@ impl StreamContext {
callout_context.similarity_scores
);
//HACK: for now we only support one tool call, we will support multiple tool calls in the future
let tool_params = &tools_call_response.tool_calls[0].arguments;
let tools_call_name = tools_call_response.tool_calls[0].name.clone();
let tool_params = &tool_calls[0].function.arguments;
let tools_call_name = tool_calls[0].function.name.clone();
let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
let prompt_target = self
@ -581,6 +588,7 @@ impl StreamContext {
role: SYSTEM_ROLE.to_string(),
content: Some(system_prompt.clone()),
model: None,
tool_calls: None,
};
messages.push(system_prompt_message);
}
@ -592,6 +600,7 @@ impl StreamContext {
role: USER_ROLE.to_string(),
content: Some(body_str),
model: None,
tool_calls: None,
}
});
@ -601,6 +610,7 @@ impl StreamContext {
role: USER_ROLE.to_string(),
content: Some(callout_context.user_message.unwrap()),
model: None,
tool_calls: None,
}
});
@ -707,7 +717,8 @@ impl StreamContext {
let json_data: String = match serde_json::to_string(&get_embeddings_input) {
Ok(json_data) => json_data,
Err(error) => {
panic!("Error serializing embeddings input: {}", error);
let error_msg = format!("Error serializing embeddings input: {}", error);
return self.send_server_error(error_msg, None);
}
};
@ -727,10 +738,11 @@ impl StreamContext {
) {
Ok(token_id) => token_id,
Err(e) => {
panic!(
let error_msg = format!(
"Error dispatching embedding server HTTP call for get-embeddings: {:?}",
e
);
return self.send_server_error(error_msg, None);
}
};
debug!(
@ -892,7 +904,9 @@ impl HttpContext for StreamContext {
let json_data: String = match serde_json::to_string(&get_prompt_guards_request) {
Ok(json_data) => json_data,
Err(error) => {
panic!("Error serializing embeddings input: {}", error);
let error_msg = format!("Error serializing prompt guard request: {}", error);
self.send_server_error(error_msg, None);
return Action::Pause;
}
};
@ -912,10 +926,12 @@ impl HttpContext for StreamContext {
) {
Ok(token_id) => token_id,
Err(e) => {
panic!(
"Error dispatching embedding server HTTP call for get-embeddings: {:?}",
let error_msg = format!(
"Error dispatching embedding server HTTP call for prompt-guard: {:?}",
e
);
self.send_server_error(error_msg, None);
return Action::Pause;
}
};

View file

@ -3,15 +3,14 @@ use proxy_wasm_test_framework::tester::{self, Tester};
use proxy_wasm_test_framework::types::{
Action, BufferType, LogLevel, MapType, MetricType, ReturnType,
};
use public_types::common_types::{
open_ai::{ChatCompletionsResponse, Choice, Message, Usage},
BoltFCToolsCall, IntOrString, ToolCallDetail,
};
use public_types::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage};
use public_types::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType};
use public_types::embeddings::embedding::Object;
use public_types::embeddings::{
create_embedding_response, CreateEmbeddingResponse, CreateEmbeddingResponseUsage, Embedding,
};
use public_types::{common_types::ZeroShotClassificationResponse, configuration::Configuration};
use serde_yaml::Value;
use serial_test::serial;
use std::collections::HashMap;
use std::path::Path;
@ -403,18 +402,6 @@ fn request_ratelimited() {
normal_flow(&mut module, filter_context, http_context);
let tool_call_detail = vec![ToolCallDetail {
name: String::from("weather_forecast"),
arguments: HashMap::from([(
String::from("city"),
IntOrString::Text(String::from("seattle")),
)]),
}];
let boltfc_tools_call = BoltFCToolsCall {
tool_calls: tool_call_detail,
};
let bolt_fc_resp = ChatCompletionsResponse {
usage: Usage {
completion_tokens: 0,
@ -424,7 +411,18 @@ fn request_ratelimited() {
index: 0,
message: Message {
role: "system".to_string(),
content: Some(serde_json::to_string(&boltfc_tools_call).unwrap()),
content: None,
tool_calls: Some(vec![ToolCall {
id: String::from("test"),
tool_type: ToolType::Function,
function: FunctionCallDetail {
name: String::from("weather_forecast"),
arguments: HashMap::from([(
String::from("city"),
Value::String(String::from("seattle")),
)]),
},
}]),
model: None,
},
}],
@ -519,18 +517,6 @@ fn request_not_ratelimited() {
normal_flow(&mut module, filter_context, http_context);
let tool_call_detail = vec![ToolCallDetail {
name: String::from("weather_forecast"),
arguments: HashMap::from([(
String::from("city"),
IntOrString::Text(String::from("seattle")),
)]),
}];
let boltfc_tools_call = BoltFCToolsCall {
tool_calls: tool_call_detail,
};
let bolt_fc_resp = ChatCompletionsResponse {
usage: Usage {
completion_tokens: 0,
@ -540,7 +526,18 @@ fn request_not_ratelimited() {
index: 0,
message: Message {
role: "system".to_string(),
content: Some(serde_json::to_string(&boltfc_tools_call).unwrap()),
content: None,
tool_calls: Some(vec![ToolCall {
id: String::from("test"),
tool_type: ToolType::Function,
function: FunctionCallDetail {
name: String::from("weather_forecast"),
arguments: HashMap::from([(
String::from("city"),
Value::String(String::from("seattle")),
)]),
},
}]),
model: None,
},
}],

View file

@ -33,7 +33,7 @@ class ArchHandler:
def _format_system(self, tools: List[Dict[str, Any]]):
def convert_tools(tools):
return "\n".join([json.dumps(tool) for tool in tools])
return "\n".join([json.dumps(tool["function"]) for tool in tools])
tool_text = convert_tools(tools)

View file

@ -1,225 +0,0 @@
import json
from typing import Any, Dict, List
SYSTEM_PROMPT = """
[BEGIN OF TASK INSTRUCTION]
You are a function calling assistant with access to the following tools. You task is to assist users as best as you can.
For each user query, you may need to call one or more functions to to better generate responses.
If none of the functions are relevant, you should point it out.
If the given query lacks the parameters required by the function, you should ask users for clarification.
The users may execute functions and return results as `Observation` to you. In the case, you MUST generate responses by summarizing it.
[END OF TASK INSTRUCTION]
""".strip()
TOOL_PROMPT = """
[BEGIN OF AVAILABLE TOOLS]
{tool_text}
[END OF AVAILABLE TOOLS]
""".strip()
FORMAT_PROMPT = """
[BEGIN OF FORMAT INSTRUCTION]
You MUST use the following JSON format if using tools.
The example format is as follows. DO NOT use this format if no function call is needed.
```
{
"tool_calls": [
{"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
... (more tool calls as required)
]
}
```
[END OF FORMAT INSTRUCTION]
""".strip()
class BoltHandler:
def _format_system(self, tools: List[Dict[str, Any]]):
tool_text = self._format_tools(tools=tools)
return (
SYSTEM_PROMPT
+ "\n\n"
+ TOOL_PROMPT.format(tool_text=tool_text)
+ "\n\n"
+ FORMAT_PROMPT
+ "\n"
)
def _format_tools(self, tools: List[Dict[str, Any]]):
TOOL_DESC = "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}"
tool_text = []
for tool in tools:
param_text = self.get_param_text(tool.parameters)
tool_text.append(
TOOL_DESC.format(
name=tool.name, desc=tool.description, args=param_text
)
)
return "\n".join(tool_text)
def extract_tools(self, content, executable=False):
# retrieve `tool_calls` from model responses
try:
content_json = json.loads(content)
except Exception:
fixed_content = self.fix_json_string(content)
try:
content_json = json.loads(fixed_content)
except json.JSONDecodeError:
return content
if isinstance(content_json, list):
tool_calls = content_json
elif isinstance(content_json, dict):
tool_calls = content_json.get("tool_calls", [])
else:
tool_calls = []
if not isinstance(tool_calls, list):
return content
# process and extract tools from `tool_calls`
extracted = []
for tool_call in tool_calls:
if isinstance(tool_call, dict):
try:
if not executable:
extracted.append({tool_call["name"]: tool_call["arguments"]})
else:
name, arguments = (
tool_call.get("name", ""),
tool_call.get("arguments", {}),
)
for key, value in arguments.items():
if value == "False" or value == "false":
arguments[key] = False
elif value == "True" or value == "true":
arguments[key] = True
args_str = ", ".join(
[f"{key}={repr(value)}" for key, value in arguments.items()]
)
extracted.append(f"{name}({args_str})")
except Exception:
continue
return extracted
def get_param_text(self, parameter_dict, prefix=""):
param_text = ""
for name, param in parameter_dict["properties"].items():
param_type = param.get("type", "")
required, default, param_format, properties, enum, items = (
"",
"",
"",
"",
"",
"",
)
if name in parameter_dict.get("required", []):
required = ", required"
required_param = parameter_dict.get("required", [])
if isinstance(required_param, bool):
required = ", required" if required_param else ""
elif isinstance(required_param, list) and name in required_param:
required = ", required"
else:
required = ", optional"
default_param = param.get("default", None)
if default_param:
default = f", default: {default_param}"
format_in = param.get("format", None)
if format_in:
param_format = f", format: {format_in}"
desc = param.get("description", "")
if "properties" in param:
arg_properties = self.get_param_text(param, prefix + " ")
properties += "with the properties:\n{}".format(arg_properties)
enum_param = param.get("enum", None)
if enum_param:
enum = "should be one of [{}]".format(", ".join(enum_param))
item_param = param.get("items", None)
if item_param:
item_type = item_param.get("type", None)
if item_type:
items += "each item should be the {} type ".format(item_type)
item_properties = item_param.get("properties", None)
if item_properties:
item_properties = self.get_param_text(item_param, prefix + " ")
items += "with the properties:\n{}".format(item_properties)
illustration = ", ".join(
[x for x in [desc, properties, enum, items] if len(x)]
)
param_text += (
prefix
+ "- {name} ({param_type}{required}{param_format}{default}): {illustration}\n".format(
name=name,
param_type=param_type,
required=required,
param_format=param_format,
default=default,
illustration=illustration,
)
)
return param_text
def fix_json_string(self, json_str):
# Remove any leading or trailing whitespace or newline characters
json_str = json_str.strip()
# Stack to keep track of brackets
stack = []
# Clean string to collect valid characters
fixed_str = ""
# Dictionary for matching brackets
matching_bracket = {")": "(", "}": "{", "]": "["}
# Dictionary for the opposite of matching_bracket
opening_bracket = {v: k for k, v in matching_bracket.items()}
for char in json_str:
if char in "{[(":
stack.append(char)
fixed_str += char
elif char in "}])":
if stack and stack[-1] == matching_bracket[char]:
stack.pop()
fixed_str += char
else:
# Ignore the unmatched closing brackets
continue
else:
fixed_str += char
# If there are unmatched opening brackets left in the stack, add corresponding closing brackets
while stack:
unmatched_opening = stack.pop()
fixed_str += opening_bracket[unmatched_opening]
# Attempt to parse the corrected string to ensure its valid JSON
return fixed_str

View file

@ -1,14 +1,10 @@
from typing import Any, Dict, List
from pydantic import BaseModel
class Tool(BaseModel):
name: str
description: str
parameters: dict
class Message(BaseModel):
role: str
content: str
class ChatMessage(BaseModel):
messages: list[Message]
tools: list[Tool]
tools: List[Dict[str, Any]]

View file

@ -1,6 +1,6 @@
import json
import random
from fastapi import FastAPI, Response
from bolt_handler import BoltHandler
from arch_handler import ArchHandler
from common import ChatMessage
import logging
@ -8,15 +8,15 @@ from openai import OpenAI
import os
ollama_endpoint = os.getenv("OLLAMA_ENDPOINT", "localhost")
ollama_model = os.getenv("OLLAMA_MODEL", "Bolt-Function-Calling-1B:Q4_K_M")
# Model name used for every completion request. The default must match the
# `name:tag` the model is registered under with `ollama create`
# (Arch-Function-Calling-1.5B:Q4_K_M), not the hyphenated GGUF file name.
ollama_model = os.getenv("OLLAMA_MODEL", "Arch-Function-Calling-1.5B:Q4_K_M")
logger = logging.getLogger('uvicorn.error')
logger.info(f"using model: {ollama_model}")
logger.info(f"using ollama endpoint: {ollama_endpoint}")
app = FastAPI()
bolt_handler = BoltHandler()
arch_handler = ArchHandler()
handler = ArchHandler()
client = OpenAI(
base_url='http://{}:11434/v1/'.format(ollama_endpoint),
@ -35,10 +35,6 @@ async def healthz():
@app.post("/v1/chat/completions")
async def chat_completion(req: ChatMessage, res: Response):
logger.info("starting request")
if ollama_model.startswith("Bolt"):
handler = bolt_handler
else:
handler = arch_handler
tools_encoded = handler._format_system(req.tools)
# append system prompt with tools to messages
messages = [{"role": "system", "content": tools_encoded}]
@ -46,5 +42,21 @@ async def chat_completion(req: ChatMessage, res: Response):
messages.append({"role": message.role, "content": message.content})
logger.info(f"request model: {ollama_model}, messages: {json.dumps(messages)}")
resp = client.chat.completions.create(messages=messages, model=ollama_model, stream=False)
logger.info(f"response: {resp.to_json()}")
tools = handler.extract_tools(resp.choices[0].message.content)
tool_calls = []
for tool in tools:
for tool_name, tool_args in tool.items():
tool_calls.append({
"id": f"call_{random.randint(1000, 10000)}",
"type": "function",
"function": {
"name": tool_name,
"arguments": tool_args
}
})
if tools:
resp.choices[0].message.tool_calls = tool_calls
resp.choices[0].message.content = None
logger.info(f"response (tools): {json.dumps(tools)}")
logger.info(f"response: {json.dumps(resp.to_dict())}")
return resp

View file

@ -33,58 +33,11 @@ pub struct SearchPointResult {
pub payload: HashMap<String, String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolParameter {
#[serde(rename = "type")]
#[serde(skip_serializing_if = "Option::is_none")]
pub parameter_type: Option<String>,
pub description: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub required: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "enum")]
pub enum_values: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub default: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolParameters {
#[serde(rename = "type")]
pub parameters_type: String,
pub properties: HashMap<String, ToolParameter>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolsDefinition {
pub name: String,
pub description: String,
pub parameters: ToolParameters,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum IntOrString {
Integer(i32),
Text(String),
Float(f64),
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallDetail {
pub name: String,
pub arguments: HashMap<String, IntOrString>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BoltFCToolsCall {
pub tool_calls: Vec<ToolCallDetail>,
}
pub mod open_ai {
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use super::ToolsDefinition;
use serde::{Deserialize, Serialize};
use serde_yaml::Value;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionsRequest {
@ -92,13 +45,52 @@ pub mod open_ai {
pub model: String,
pub messages: Vec<Message>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tools: Option<Vec<ToolsDefinition>>,
pub tools: Option<Vec<ChatCompletionTool>>,
#[serde(default)]
pub stream: bool,
#[serde(skip_serializing_if = "Option::is_none")]
pub stream_options: Option<StreamOptions>,
}
/// Kind of tool carried by a [`ChatCompletionTool`] or a [`ToolCall`].
///
/// Only one kind is defined here: a callable function.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ToolType {
/// Serialized as the string `"function"`.
#[serde(rename = "function")]
Function,
}
/// One entry of the optional `tools` list in a [`ChatCompletionsRequest`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionTool {
/// Tool kind; appears under the JSON key `type`.
#[serde(rename = "type")]
pub tool_type: ToolType,
/// Definition of the function this tool exposes.
pub function: FunctionDefinition,
}
/// Describes a function offered to the model: its name, a description,
/// and the parameters it accepts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionDefinition {
/// Function name.
pub name: String,
/// Free-text description of the function.
pub description: String,
/// Parameters the function accepts.
pub parameters: FunctionParameters,
}
/// Parameter set of a [`FunctionDefinition`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionParameters {
/// Individual parameter descriptions keyed by a string
/// (presumably the parameter name; confirm at the construction site).
pub properties: HashMap<String, FunctionParameter>,
}
/// Description of a single function parameter.
///
/// Every `Option` field is left out of the serialized output when it is `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionParameter {
/// Parameter type name; appears under the JSON key `type`.
#[serde(rename = "type")]
#[serde(skip_serializing_if = "Option::is_none")]
pub parameter_type: Option<String>,
/// Free-text description of the parameter (always serialized).
pub description: String,
/// Whether the parameter is required, when specified.
#[serde(skip_serializing_if = "Option::is_none")]
pub required: Option<bool>,
/// Allowed values, when specified; appears under the JSON key `enum`.
#[serde(skip_serializing_if = "Option::is_none")]
#[serde(rename = "enum")]
pub enum_values: Option<Vec<String>>,
/// Default value, when specified.
#[serde(skip_serializing_if = "Option::is_none")]
pub default: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamOptions {
pub include_usage: bool,
@ -110,6 +102,7 @@ pub mod open_ai {
pub content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
pub tool_calls: Option<Vec<ToolCall>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -119,6 +112,20 @@ pub mod open_ai {
pub message: Message,
}
/// A tool invocation attached to a [`Message`] through its `tool_calls` field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
/// Identifier of this call.
pub id: String,
/// Tool kind; appears under the JSON key `type`.
#[serde(rename = "type")]
pub tool_type: ToolType,
/// Name and arguments of the function being called.
pub function: FunctionCallDetail,
}
/// The function named by a [`ToolCall`] together with its arguments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionCallDetail {
/// Name of the function to call.
pub name: String,
/// Argument values keyed by a string (presumably the argument name; confirm
/// against the producer); each value is held as a dynamically typed `Value`.
pub arguments: HashMap<String, Value>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatCompletionsResponse {
pub usage: Usage,