diff --git a/demos/function_calling/Arch-Function-Calling-1.5B-Q4_K_M.model_file b/demos/function_calling/Arch-Function-Calling-1.5B-Q4_K_M.model_file
new file mode 100644
index 00000000..855ca44e
--- /dev/null
+++ b/demos/function_calling/Arch-Function-Calling-1.5B-Q4_K_M.model_file
@@ -0,0 +1,21 @@
+FROM Arch-Function-Calling-1.5B-Q4_K_M.gguf
+
+# Set parameters for response generation
+PARAMETER num_predict 1024
+PARAMETER temperature 0.001
+PARAMETER top_p 1.0
+PARAMETER top_k 16000
+PARAMETER repeat_penalty 1.0
+PARAMETER stop "<|im_end|>"
+
+# Set the random number seed to use for generation
+PARAMETER seed 42
+
+# Set the prompt template to be passed into the model
+TEMPLATE """
+{{- if .System }}<|im_start|>system
+{{ .System }}<|im_end|>
+{{ end }}{{ if .Prompt }}<|im_start|>user
+{{ .Prompt }}<|im_end|>
+{{ end }}<|im_start|>assistant
+{{ .Response }}<|im_end|>"""
diff --git a/demos/function_calling/Bolt-FC-1B-Q3_K_L.model_file b/demos/function_calling/Bolt-FC-1B-Q3_K_L.model_file
deleted file mode 100644
index d58a6a17..00000000
--- a/demos/function_calling/Bolt-FC-1B-Q3_K_L.model_file
+++ /dev/null
@@ -1,25 +0,0 @@
-FROM Bolt-Function-Calling-1B-Q3_K_L.gguf
-
-# Set the size of the context window used to generate the next token
-# PARAMETER num_ctx 16384
-PARAMETER num_ctx 4096
-
-# Set parameters for response generation
-PARAMETER num_predict 1024
-PARAMETER temperature 0.1
-PARAMETER top_p 0.5
-PARAMETER top_k 32022
-PARAMETER repeat_penalty 1.0
-PARAMETER stop "<|EOT|>"
-
-# Set the random number seed to use for generation
-PARAMETER seed 42
-
-# Set the prompt template to be passed into the model
-TEMPLATE """{{ if .System }}<|begin▁of▁sentence|>
-{{ .System }}
-{{ end }}{{ if .Prompt }}### Instruction:
-{{ .Prompt }}
-{{ end }}### Response:
-{{ .Response }}
-<|EOT|>"""
diff --git a/demos/function_calling/Bolt-FC-1B-Q4_K_M.model_file b/demos/function_calling/Bolt-FC-1B-Q4_K_M.model_file
deleted file mode 100644
index 1def85b1..00000000
--- a/demos/function_calling/Bolt-FC-1B-Q4_K_M.model_file
+++ /dev/null
@@ -1,24 +0,0 @@
-FROM Bolt-Function-Calling-1B-Q4_K_M.gguf
-
-# Set the size of the context window used to generate the next token
-PARAMETER num_ctx 4096
-
-# Set parameters for response generation
-PARAMETER num_predict 1024
-PARAMETER temperature 0.1
-PARAMETER top_p 0.5
-PARAMETER top_k 32022
-PARAMETER repeat_penalty 1.0
-PARAMETER stop "<|EOT|>"
-
-# Set the random number seed to use for generation
-PARAMETER seed 42
-
-# Set the prompt template to be passed into the model
-TEMPLATE """{{ if .System }}<|begin▁of▁sentence|>
-{{ .System }}
-{{ end }}{{ if .Prompt }}### Instruction:
-{{ .Prompt }}
-{{ end }}### Response:
-{{ .Response }}
-<|EOT|>"""
diff --git a/demos/function_calling/README.md b/demos/function_calling/README.md
index e44c6375..86005388 100644
--- a/demos/function_calling/README.md
+++ b/demos/function_calling/README.md
@@ -11,14 +11,14 @@ This demo shows how you can use intelligent prompt gateway to do function callin
    ```sh
    docker compose up
    ```
-1. Download Bolt-FC model. This demo assumes we have downloaded [Bolt-Function-Calling-1B:Q4_K_M](https://huggingface.co/katanemolabs/Bolt-Function-Calling-1B.gguf/blob/main/Bolt-Function-Calling-1B-Q4_K_M.gguf) to local folder.
+1. Download the Arch-FC model. This demo assumes we have downloaded [Arch-Function-Calling-1.5B:Q4_K_M](https://huggingface.co/katanemolabs/Arch-Function-Calling-1.5B.gguf/blob/main/Arch-Function-Calling-1.5B-Q4_K_M.gguf) to a local folder.
1. If running ollama natively, run
   ```sh
   ollama serve
   ```
2. Create the model file in the ollama repository
   ```sh
-  ollama create Bolt-Function-Calling-1B:Q4_K_M -f Bolt-FC-1B-Q4_K_M.model_file
+  ollama create Arch-Function-Calling-1.5B:Q4_K_M -f Arch-Function-Calling-1.5B-Q4_K_M.model_file
   ```
3. Navigate to http://localhost:18080/
4. You can type in queries like "how is the weather in Seattle"
diff --git a/demos/function_calling/docker-compose.yaml b/demos/function_calling/docker-compose.yaml
index ad994592..70f95c36 100644
--- a/demos/function_calling/docker-compose.yaml
+++ b/demos/function_calling/docker-compose.yaml
@@ -59,6 +59,7 @@ services:
       - OLLAMA_ENDPOINT=${OLLAMA_ENDPOINT:-host.docker.internal}
       # uncomment following line to use ollama endpoint that is hosted by docker
       # - OLLAMA_ENDPOINT=ollama
+      - OLLAMA_MODEL=Arch-Function-Calling-1.5B:Q4_K_M
   api_server:
     build:
diff --git a/envoyfilter/src/stream_context.rs b/envoyfilter/src/stream_context.rs
index 5d2bdb5c..09550934 100644
--- a/envoyfilter/src/stream_context.rs
+++ b/envoyfilter/src/stream_context.rs
@@ -15,13 +15,13 @@ use log::{debug, info, warn};
 use proxy_wasm::traits::*;
 use proxy_wasm::types::*;
 use public_types::common_types::open_ai::{
-    ChatCompletionChunkResponse, ChatCompletionsRequest, ChatCompletionsResponse, Message,
-    StreamOptions,
+    ChatCompletionChunkResponse, ChatCompletionTool, ChatCompletionsRequest,
+    ChatCompletionsResponse, FunctionDefinition, FunctionParameter, FunctionParameters, Message,
+    StreamOptions, ToolType,
 };
 use public_types::common_types::{
-    BoltFCToolsCall, EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
-    ToolParameter, ToolParameters, ToolsDefinition, ZeroShotClassificationRequest,
-    ZeroShotClassificationResponse,
+    EmbeddingType, PromptGuardRequest, PromptGuardResponse, PromptGuardTask,
+    ZeroShotClassificationRequest, ZeroShotClassificationResponse,
 };
 use public_types::configuration::{Overrides, PromptGuards, PromptTarget, PromptType};
 use public_types::embeddings::{
@@ -215,7 +215,8 @@ impl StreamContext {
         let json_data: String = match serde_json::to_string(&zero_shot_classification_request) {
             Ok(json_data) => json_data,
             Err(error) => {
-                panic!("Error serializing zero shot request: {}", error);
+                let error = format!("Error serializing zero shot request: {}", error);
+                return self.send_server_error(error, None);
             }
         };
@@ -235,10 +236,11 @@
         ) {
             Ok(token_id) => token_id,
             Err(e) => {
-                panic!(
-                    "Error dispatching embedding server HTTP call for zero-shot-intent-detection: {:?}",
-                    e
-                );
+                let error_msg = format!(
+                    "Error dispatching embedding server HTTP call for zero-shot-intent-detection: {:?}",
+                    e
+                );
+                return self.send_server_error(error_msg, None);
             }
         };
         debug!(
@@ -358,16 +360,15 @@
         match prompt_target.prompt_type {
             PromptType::FunctionResolver => {
-                let mut tools_definitions: Vec<ToolsDefinition> = Vec::new();
-
+                let mut chat_completion_tools: Vec<ChatCompletionTool> = Vec::new();
                 for pt in self.prompt_targets.read().unwrap().values() {
                     // only extract entity names
-                    let properties: HashMap<String, ToolParameter> = match pt.parameters {
+                    let properties: HashMap<String, FunctionParameter> = match pt.parameters {
                         // Clone is unavoidable here because we don't want to move the values out of the prompt target struct.
                        Some(ref entities) => {
-                            let mut properties: HashMap<String, ToolParameter> = HashMap::new();
+                            let mut properties: HashMap<String, FunctionParameter> = HashMap::new();
                             for entity in entities.iter() {
-                                let param = ToolParameter {
+                                let param = FunctionParameter {
                                     parameter_type: entity.parameter_type.clone(),
                                     description: entity.description.clone(),
                                     required: entity.required,
@@ -380,22 +381,24 @@
                             }
                             None => HashMap::new(),
                         };
-                    let tools_parameters = ToolParameters {
-                        parameters_type: "dict".to_string(),
-                        properties,
-                    };
+                    let tools_parameters = FunctionParameters { properties };
 
-                    tools_definitions.push(ToolsDefinition {
-                        name: pt.name.clone(),
-                        description: pt.description.clone(),
-                        parameters: tools_parameters,
+                    chat_completion_tools.push({
+                        ChatCompletionTool {
+                            tool_type: ToolType::Function,
+                            function: FunctionDefinition {
+                                name: pt.name.clone(),
+                                description: pt.description.clone(),
+                                parameters: tools_parameters,
+                            },
+                        }
                     });
                 }
 
                 let chat_completions = ChatCompletionsRequest {
                     model: GPT_35_TURBO.to_string(),
                     messages: callout_context.request_body.messages.clone(),
-                    tools: Some(tools_definitions),
+                    tools: Some(chat_completion_tools),
                     stream: false,
                     stream_options: None,
                 };
@@ -432,7 +435,9 @@
         ) {
             Ok(token_id) => token_id,
             Err(e) => {
-                panic!("Error dispatching HTTP call for function-call: {:?}", e);
+                let error_msg =
+                    format!("Error dispatching HTTP call for function-call: {:?}", e);
+                return self.send_server_error(error_msg, Some(StatusCode::BAD_REQUEST));
             }
         };
@@ -459,33 +464,35 @@
         let boltfc_response: ChatCompletionsResponse = serde_json::from_str(&body_str).unwrap();
 
-        let boltfc_response_str = boltfc_response.choices[0].message.content.as_ref().unwrap();
+        let model_resp = &boltfc_response.choices[0];
 
-        let tools_call_response: BoltFCToolsCall = match serde_json::from_str(boltfc_response_str) {
-            Ok(fc_resp) => fc_resp,
-            Err(e) => {
-                // This means that Bolt FC did not have enough information to resolve the function call
-                // Bolt FC probably responded with a message asking for more information.
-                // Let's send the response back to the user to initalize lightweight dialog for parameter collection
+        if model_resp.message.tool_calls.is_none() {
+            // This means that Bolt FC did not have enough information to resolve the function call
+            // Bolt FC probably responded with a message asking for more information.
+            // Let's send the response back to the user to initialize lightweight dialog for parameter collection
 
-                // add resolver name to the response so the client can send the response back to the correct resolver
-                info!("some requred parameters are missing, sending response from Bolt FC back to user for parameter collection: {}", e);
-                let bolt_fc_dialogue_message = serde_json::to_string(&boltfc_response).unwrap();
-                self.send_http_response(
-                    StatusCode::OK.as_u16().into(),
-                    vec![("Powered-By", "Katanemo")],
-                    Some(bolt_fc_dialogue_message.as_bytes()),
-                );
-                return;
-            }
-        };
+            //TODO: add resolver name to the response so the client can send the response back to the correct resolver
 
-        debug!("tool_call_details: {}", boltfc_response_str);
+            return self.send_http_response(
+                StatusCode::OK.as_u16().into(),
+                vec![("Powered-By", "Katanemo")],
+                Some(body_str.as_bytes()),
+            );
+        }
+
+        let tool_calls = model_resp.message.tool_calls.as_ref().unwrap();
+        if tool_calls.is_empty() {
+            return self.send_server_error(
+                "No tool calls found in function resolver response".to_string(),
+                Some(StatusCode::BAD_REQUEST),
+            );
+        }
+
+        debug!("tool_call_details: {:?}", tool_calls);
 
         // extract all tool names
-        let tool_names: Vec<String> = tools_call_response
-            .tool_calls
+        let tool_names: Vec<String> = tool_calls
             .iter()
-            .map(|tool_call| tool_call.name.clone())
+            .map(|tool_call| tool_call.function.name.clone())
             .collect();
 
         debug!(
@@ -493,8 +500,8 @@
             callout_context.similarity_scores
         );
         //HACK: for now we only support one tool call, we will support multiple tool calls in the future
-        let tool_params = &tools_call_response.tool_calls[0].arguments;
-        let tools_call_name = tools_call_response.tool_calls[0].name.clone();
+        let tool_params = &tool_calls[0].function.arguments;
+        let tools_call_name = tool_calls[0].function.name.clone();
         let tool_params_json_str = serde_json::to_string(&tool_params).unwrap();
 
         let prompt_target = self
@@ -581,6 +588,7 @@
                 role: SYSTEM_ROLE.to_string(),
                 content: Some(system_prompt.clone()),
                 model: None,
+                tool_calls: None,
             };
             messages.push(system_prompt_message);
         }
@@ -592,6 +600,7 @@
                 role: USER_ROLE.to_string(),
                 content: Some(body_str),
                 model: None,
+                tool_calls: None,
             }
         });
@@ -601,6 +610,7 @@
                 role: USER_ROLE.to_string(),
                 content: Some(callout_context.user_message.unwrap()),
                 model: None,
+                tool_calls: None,
             }
         });
@@ -707,7 +717,8 @@
         let json_data: String = match serde_json::to_string(&get_embeddings_input) {
             Ok(json_data) => json_data,
             Err(error) => {
-                panic!("Error serializing embeddings input: {}", error);
+                let error_msg = format!("Error serializing embeddings input: {}", error);
+                return self.send_server_error(error_msg, None);
             }
         };
@@ -727,10 +738,11 @@
         ) {
             Ok(token_id) => token_id,
             Err(e) => {
-                panic!(
+                let error_msg = format!(
                     "Error dispatching embedding server HTTP call for get-embeddings: {:?}",
                     e
                 );
+                return self.send_server_error(error_msg, None);
             }
         };
         debug!(
@@ -892,7 +904,9 @@
         let json_data: String = match serde_json::to_string(&get_prompt_guards_request) {
             Ok(json_data) => json_data,
             Err(error) => {
-                panic!("Error serializing embeddings input: {}", error);
+                let error_msg = format!("Error serializing prompt guard request: {}", error);
+                self.send_server_error(error_msg, None);
+                return Action::Pause;
             }
         };
@@ -912,10 +926,12 @@
         ) {
             Ok(token_id) => token_id,
             Err(e) => {
-                panic!(
"Error dispatching embedding server HTTP call for get-embeddings: {:?}", + let error_msg = format!( + "Error dispatching embedding server HTTP call for prompt-guard: {:?}", e ); + self.send_server_error(error_msg, None); + return Action::Pause; } }; diff --git a/envoyfilter/tests/integration.rs b/envoyfilter/tests/integration.rs index f45cde7c..ce02a203 100644 --- a/envoyfilter/tests/integration.rs +++ b/envoyfilter/tests/integration.rs @@ -3,15 +3,14 @@ use proxy_wasm_test_framework::tester::{self, Tester}; use proxy_wasm_test_framework::types::{ Action, BufferType, LogLevel, MapType, MetricType, ReturnType, }; -use public_types::common_types::{ - open_ai::{ChatCompletionsResponse, Choice, Message, Usage}, - BoltFCToolsCall, IntOrString, ToolCallDetail, -}; +use public_types::common_types::open_ai::{ChatCompletionsResponse, Choice, Message, Usage}; +use public_types::common_types::open_ai::{FunctionCallDetail, ToolCall, ToolType}; use public_types::embeddings::embedding::Object; use public_types::embeddings::{ create_embedding_response, CreateEmbeddingResponse, CreateEmbeddingResponseUsage, Embedding, }; use public_types::{common_types::ZeroShotClassificationResponse, configuration::Configuration}; +use serde_yaml::Value; use serial_test::serial; use std::collections::HashMap; use std::path::Path; @@ -403,18 +402,6 @@ fn request_ratelimited() { normal_flow(&mut module, filter_context, http_context); - let tool_call_detail = vec![ToolCallDetail { - name: String::from("weather_forecast"), - arguments: HashMap::from([( - String::from("city"), - IntOrString::Text(String::from("seattle")), - )]), - }]; - - let boltfc_tools_call = BoltFCToolsCall { - tool_calls: tool_call_detail, - }; - let bolt_fc_resp = ChatCompletionsResponse { usage: Usage { completion_tokens: 0, @@ -424,7 +411,18 @@ fn request_ratelimited() { index: 0, message: Message { role: "system".to_string(), - content: Some(serde_json::to_string(&boltfc_tools_call).unwrap()), + content: None, + tool_calls: Some(vec![ToolCall { + id: String::from("test"), + tool_type: ToolType::Function, + function: FunctionCallDetail { + name: String::from("weather_forecast"), + arguments: HashMap::from([( + String::from("city"), + Value::String(String::from("seattle")), + )]), + }, + }]), model: None, }, }], @@ -519,18 +517,6 @@ fn request_not_ratelimited() { normal_flow(&mut module, filter_context, http_context); - let tool_call_detail = vec![ToolCallDetail { - name: String::from("weather_forecast"), - arguments: HashMap::from([( - String::from("city"), - IntOrString::Text(String::from("seattle")), - )]), - }]; - - let boltfc_tools_call = BoltFCToolsCall { - tool_calls: tool_call_detail, - }; - let bolt_fc_resp = ChatCompletionsResponse { usage: Usage { completion_tokens: 0, @@ -540,7 +526,18 @@ fn request_not_ratelimited() { index: 0, message: Message { role: "system".to_string(), - content: Some(serde_json::to_string(&boltfc_tools_call).unwrap()), + content: None, + tool_calls: Some(vec![ToolCall { + id: String::from("test"), + tool_type: ToolType::Function, + function: FunctionCallDetail { + name: String::from("weather_forecast"), + arguments: HashMap::from([( + String::from("city"), + Value::String(String::from("seattle")), + )]), + }, + }]), model: None, }, }], diff --git a/function_resolver/app/arch_handler.py b/function_resolver/app/arch_handler.py index 77b0a65d..a92a62bc 100644 --- a/function_resolver/app/arch_handler.py +++ b/function_resolver/app/arch_handler.py @@ -33,7 +33,7 @@ class ArchHandler: def _format_system(self, tools: 
         def convert_tools(tools):
-            return "\n".join([json.dumps(tool) for tool in tools])
+            return "\n".join([json.dumps(tool["function"]) for tool in tools])
 
         tool_text = convert_tools(tools)
 
diff --git a/function_resolver/app/bolt_handler.py b/function_resolver/app/bolt_handler.py
deleted file mode 100644
index bd544803..00000000
--- a/function_resolver/app/bolt_handler.py
+++ /dev/null
@@ -1,225 +0,0 @@
-import json
-from typing import Any, Dict, List
-
-
-SYSTEM_PROMPT = """
-[BEGIN OF TASK INSTRUCTION]
-You are a function calling assistant with access to the following tools. You task is to assist users as best as you can.
-For each user query, you may need to call one or more functions to to better generate responses.
-If none of the functions are relevant, you should point it out.
-If the given query lacks the parameters required by the function, you should ask users for clarification.
-The users may execute functions and return results as `Observation` to you. In the case, you MUST generate responses by summarizing it.
-[END OF TASK INSTRUCTION]
-""".strip()
-
-TOOL_PROMPT = """
-[BEGIN OF AVAILABLE TOOLS]
-{tool_text}
-[END OF AVAILABLE TOOLS]
-""".strip()
-
-FORMAT_PROMPT = """
-[BEGIN OF FORMAT INSTRUCTION]
-You MUST use the following JSON format if using tools.
-The example format is as follows. DO NOT use this format if no function call is needed.
-```
-{
-    "tool_calls": [
-        {"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
-        ... (more tool calls as required)
-    ]
-}
-```
-[END OF FORMAT INSTRUCTION]
-""".strip()
-
-
-class BoltHandler:
-    def _format_system(self, tools: List[Dict[str, Any]]):
-        tool_text = self._format_tools(tools=tools)
-        return (
-            SYSTEM_PROMPT
-            + "\n\n"
-            + TOOL_PROMPT.format(tool_text=tool_text)
-            + "\n\n"
-            + FORMAT_PROMPT
-            + "\n"
-        )
-
-    def _format_tools(self, tools: List[Dict[str, Any]]):
-        TOOL_DESC = "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}"
-
-        tool_text = []
-        for tool in tools:
-            param_text = self.get_param_text(tool.parameters)
-            tool_text.append(
-                TOOL_DESC.format(
-                    name=tool.name, desc=tool.description, args=param_text
-                )
-            )
-
-        return "\n".join(tool_text)
-
-    def extract_tools(self, content, executable=False):
-        # retrieve `tool_calls` from model responses
-        try:
-            content_json = json.loads(content)
-        except Exception:
-            fixed_content = self.fix_json_string(content)
-            try:
-                content_json = json.loads(fixed_content)
-            except json.JSONDecodeError:
-                return content
-
-        if isinstance(content_json, list):
-            tool_calls = content_json
-        elif isinstance(content_json, dict):
-            tool_calls = content_json.get("tool_calls", [])
-        else:
-            tool_calls = []
-
-        if not isinstance(tool_calls, list):
-            return content
-
-        # process and extract tools from `tool_calls`
-        extracted = []
-
-        for tool_call in tool_calls:
-            if isinstance(tool_call, dict):
-                try:
-                    if not executable:
-                        extracted.append({tool_call["name"]: tool_call["arguments"]})
-                    else:
-                        name, arguments = (
-                            tool_call.get("name", ""),
-                            tool_call.get("arguments", {}),
-                        )
-
-                        for key, value in arguments.items():
-                            if value == "False" or value == "false":
-                                arguments[key] = False
-                            elif value == "True" or value == "true":
-                                arguments[key] = True
-
-                        args_str = ", ".join(
-                            [f"{key}={repr(value)}" for key, value in arguments.items()]
-                        )
-
-                        extracted.append(f"{name}({args_str})")
-
-                except Exception:
-                    continue
-
-        return extracted
-
-    def get_param_text(self, parameter_dict, prefix=""):
-        param_text = ""
-
-        for name, param in parameter_dict["properties"].items():
param in parameter_dict["properties"].items(): - param_type = param.get("type", "") - - required, default, param_format, properties, enum, items = ( - "", - "", - "", - "", - "", - "", - ) - - if name in parameter_dict.get("required", []): - required = ", required" - - required_param = parameter_dict.get("required", []) - - if isinstance(required_param, bool): - required = ", required" if required_param else "" - elif isinstance(required_param, list) and name in required_param: - required = ", required" - else: - required = ", optional" - - default_param = param.get("default", None) - if default_param: - default = f", default: {default_param}" - - format_in = param.get("format", None) - if format_in: - param_format = f", format: {format_in}" - - desc = param.get("description", "") - - if "properties" in param: - arg_properties = self.get_param_text(param, prefix + " ") - properties += "with the properties:\n{}".format(arg_properties) - - enum_param = param.get("enum", None) - if enum_param: - enum = "should be one of [{}]".format(", ".join(enum_param)) - - item_param = param.get("items", None) - if item_param: - item_type = item_param.get("type", None) - if item_type: - items += "each item should be the {} type ".format(item_type) - - item_properties = item_param.get("properties", None) - if item_properties: - item_properties = self.get_param_text(item_param, prefix + " ") - items += "with the properties:\n{}".format(item_properties) - - illustration = ", ".join( - [x for x in [desc, properties, enum, items] if len(x)] - ) - - param_text += ( - prefix - + "- {name} ({param_type}{required}{param_format}{default}): {illustration}\n".format( - name=name, - param_type=param_type, - required=required, - param_format=param_format, - default=default, - illustration=illustration, - ) - ) - - return param_text - - def fix_json_string(self, json_str): - # Remove any leading or trailing whitespace or newline characters - json_str = json_str.strip() - - # Stack to keep track of brackets - stack = [] - - # Clean string to collect valid characters - fixed_str = "" - - # Dictionary for matching brackets - matching_bracket = {")": "(", "}": "{", "]": "["} - - # Dictionary for the opposite of matching_bracket - opening_bracket = {v: k for k, v in matching_bracket.items()} - - for char in json_str: - if char in "{[(": - stack.append(char) - fixed_str += char - elif char in "}])": - if stack and stack[-1] == matching_bracket[char]: - stack.pop() - fixed_str += char - else: - # Ignore the unmatched closing brackets - continue - else: - fixed_str += char - - # If there are unmatched opening brackets left in the stack, add corresponding closing brackets - while stack: - unmatched_opening = stack.pop() - fixed_str += opening_bracket[unmatched_opening] - - # Attempt to parse the corrected string to ensure it’s valid JSON - return fixed_str diff --git a/function_resolver/app/common.py b/function_resolver/app/common.py index 3b44863e..10fdde51 100644 --- a/function_resolver/app/common.py +++ b/function_resolver/app/common.py @@ -1,14 +1,10 @@ +from typing import Any, Dict, List from pydantic import BaseModel -class Tool(BaseModel): - name: str - description: str - parameters: dict - class Message(BaseModel): role: str content: str class ChatMessage(BaseModel): messages: list[Message] - tools: list[Tool] + tools: List[Dict[str, Any]] diff --git a/function_resolver/app/main.py b/function_resolver/app/main.py index 4c1d028a..99699825 100644 --- a/function_resolver/app/main.py +++ b/function_resolver/app/main.py @@ 
@@ -1,6 +1,6 @@
 import json
+import random
 from fastapi import FastAPI, Response
-from bolt_handler import BoltHandler
 from arch_handler import ArchHandler
 from common import ChatMessage
 import logging
@@ -8,15 +8,15 @@ from openai import OpenAI
 import os
 
 ollama_endpoint = os.getenv("OLLAMA_ENDPOINT", "localhost")
-ollama_model = os.getenv("OLLAMA_MODEL", "Bolt-Function-Calling-1B:Q4_K_M")
+ollama_model = os.getenv("OLLAMA_MODEL", "Arch-Function-Calling-1.5B:Q4_K_M")
 
 logger = logging.getLogger('uvicorn.error')
 logger.info(f"using model: {ollama_model}")
 logger.info(f"using ollama endpoint: {ollama_endpoint}")
 
 app = FastAPI()
-bolt_handler = BoltHandler()
-arch_handler = ArchHandler()
+
+handler = ArchHandler()
 
 client = OpenAI(
     base_url='http://{}:11434/v1/'.format(ollama_endpoint),
@@ -35,10 +35,6 @@ async def healthz():
 
 @app.post("/v1/chat/completions")
 async def chat_completion(req: ChatMessage, res: Response):
     logger.info("starting request")
-    if ollama_model.startswith("Bolt"):
-        handler = bolt_handler
-    else:
-        handler = arch_handler
     tools_encoded = handler._format_system(req.tools)
     # append system prompt with tools to messages
     messages = [{"role": "system", "content": tools_encoded}]
@@ -46,5 +42,21 @@
         messages.append({"role": message.role, "content": message.content})
     logger.info(f"request model: {ollama_model}, messages: {json.dumps(messages)}")
     resp = client.chat.completions.create(messages=messages, model=ollama_model, stream=False)
-    logger.info(f"response: {resp.to_json()}")
+    tools = handler.extract_tools(resp.choices[0].message.content)
+    tool_calls = []
+    for tool in tools:
+        for tool_name, tool_args in tool.items():
+            tool_calls.append({
+                "id": f"call_{random.randint(1000, 10000)}",
+                "type": "function",
+                "function": {
+                    "name": tool_name,
+                    "arguments": tool_args
+                }
+            })
+    if tools:
+        resp.choices[0].message.tool_calls = tool_calls
+        resp.choices[0].message.content = None
+        logger.info(f"response (tools): {json.dumps(tools)}")
+    logger.info(f"response: {json.dumps(resp.to_dict())}")
     return resp
diff --git a/public_types/src/common_types.rs b/public_types/src/common_types.rs
index 82693ddc..385c7bef 100644
--- a/public_types/src/common_types.rs
+++ b/public_types/src/common_types.rs
@@ -33,58 +33,11 @@ pub struct SearchPointResult {
     pub payload: HashMap<String, String>,
 }
 
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct ToolParameter {
-    #[serde(rename = "type")]
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub parameter_type: Option<String>,
-    pub description: String,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub required: Option<bool>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    #[serde(rename = "enum")]
-    pub enum_values: Option<Vec<String>>,
-    #[serde(skip_serializing_if = "Option::is_none")]
-    pub default: Option<String>,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct ToolParameters {
-    #[serde(rename = "type")]
-    pub parameters_type: String,
-    pub properties: HashMap<String, ToolParameter>,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct ToolsDefinition {
-    pub name: String,
-    pub description: String,
-    pub parameters: ToolParameters,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-#[serde(untagged)]
-pub enum IntOrString {
-    Integer(i32),
-    Text(String),
-    Float(f64),
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct ToolCallDetail {
-    pub name: String,
-    pub arguments: HashMap<String, IntOrString>,
-}
-
-#[derive(Debug, Clone, Serialize, Deserialize)]
-pub struct BoltFCToolsCall {
-    pub tool_calls: Vec<ToolCallDetail>,
-}
-
 pub mod open_ai {
-    use serde::{Deserialize, Serialize};
+    use std::collections::HashMap;
 
-    use super::ToolsDefinition;
+    use serde::{Deserialize, Serialize};
+    use serde_yaml::Value;
 
     #[derive(Debug, Clone, Serialize, Deserialize)]
     pub struct ChatCompletionsRequest {
@@ -92,13 +45,52 @@
         pub model: String,
         pub messages: Vec<Message>,
         #[serde(skip_serializing_if = "Option::is_none")]
-        pub tools: Option<Vec<ToolsDefinition>>,
+        pub tools: Option<Vec<ChatCompletionTool>>,
         #[serde(default)]
         pub stream: bool,
         #[serde(skip_serializing_if = "Option::is_none")]
         pub stream_options: Option<StreamOptions>,
     }
 
+    #[derive(Debug, Clone, Serialize, Deserialize)]
+    pub enum ToolType {
+        #[serde(rename = "function")]
+        Function,
+    }
+    #[derive(Debug, Clone, Serialize, Deserialize)]
+    pub struct ChatCompletionTool {
+        #[serde(rename = "type")]
+        pub tool_type: ToolType,
+        pub function: FunctionDefinition,
+    }
+
+    #[derive(Debug, Clone, Serialize, Deserialize)]
+    pub struct FunctionDefinition {
+        pub name: String,
+        pub description: String,
+        pub parameters: FunctionParameters,
+    }
+
+    #[derive(Debug, Clone, Serialize, Deserialize)]
+    pub struct FunctionParameters {
+        pub properties: HashMap<String, FunctionParameter>,
+    }
+
+    #[derive(Debug, Clone, Serialize, Deserialize)]
+    pub struct FunctionParameter {
+        #[serde(rename = "type")]
+        #[serde(skip_serializing_if = "Option::is_none")]
+        pub parameter_type: Option<String>,
+        pub description: String,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        pub required: Option<bool>,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        #[serde(rename = "enum")]
+        pub enum_values: Option<Vec<String>>,
+        #[serde(skip_serializing_if = "Option::is_none")]
+        pub default: Option<String>,
+    }
+
     #[derive(Debug, Clone, Serialize, Deserialize)]
     pub struct StreamOptions {
         pub include_usage: bool,
@@ -110,6 +102,7 @@
         pub content: Option<String>,
         #[serde(skip_serializing_if = "Option::is_none")]
         pub model: Option<String>,
+        pub tool_calls: Option<Vec<ToolCall>>,
     }
 
     #[derive(Debug, Clone, Serialize, Deserialize)]
@@ -119,6 +112,20 @@
         pub message: Message,
     }
 
+    #[derive(Debug, Clone, Serialize, Deserialize)]
+    pub struct ToolCall {
+        pub id: String,
+        #[serde(rename = "type")]
+        pub tool_type: ToolType,
+        pub function: FunctionCallDetail,
+    }
+
+    #[derive(Debug, Clone, Serialize, Deserialize)]
+    pub struct FunctionCallDetail {
+        pub name: String,
+        pub arguments: HashMap<String, Value>,
+    }
+
     #[derive(Debug, Clone, Serialize, Deserialize)]
     pub struct ChatCompletionsResponse {
         pub usage: Usage,
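
Reviewer note: the net effect of this patch is to drop the custom `BoltFCToolsCall` blob (tool calls serialized as a JSON string inside `message.content`) in favor of an OpenAI-style `tool_calls` array on the message itself. Below is a minimal, self-contained sketch of the wire shape the new `open_ai` types (de)serialize. It is illustrative only: the structs are re-declared locally so the snippet compiles on its own (assuming `serde` with the derive feature and `serde_json` as dependencies), it uses `serde_json::Value` in place of the `serde_yaml::Value` the patch uses for arguments, and the `call_1234` id is a made-up example.

```rust
use std::collections::HashMap;

use serde::{Deserialize, Serialize};
use serde_json::Value;

// Mirrors open_ai::ToolType: the only tool kind supported is "function".
#[derive(Debug, Serialize, Deserialize)]
enum ToolType {
    #[serde(rename = "function")]
    Function,
}

// Mirrors open_ai::FunctionCallDetail: resolved function name plus its
// arguments as a loosely typed map.
#[derive(Debug, Serialize, Deserialize)]
struct FunctionCallDetail {
    name: String,
    arguments: HashMap<String, Value>,
}

// Mirrors open_ai::ToolCall: one entry in the message's `tool_calls` array.
#[derive(Debug, Serialize, Deserialize)]
struct ToolCall {
    id: String,
    #[serde(rename = "type")]
    tool_type: ToolType,
    function: FunctionCallDetail,
}

fn main() {
    // The shape the function resolver now emits (see main.py above),
    // instead of a BoltFCToolsCall JSON string embedded in `content`.
    let raw = r#"[
        {
            "id": "call_1234",
            "type": "function",
            "function": {
                "name": "weather_forecast",
                "arguments": { "city": "seattle" }
            }
        }
    ]"#;

    // Deserialize, inspect, and round-trip the tool calls.
    let tool_calls: Vec<ToolCall> = serde_json::from_str(raw).expect("valid tool_calls JSON");
    assert_eq!(tool_calls[0].function.name, "weather_forecast");
    println!("{}", serde_json::to_string_pretty(&tool_calls).unwrap());
}
```

One consequence of this schema, visible in `stream_context.rs` above: "model asked for more information" is now signaled by `tool_calls` being absent rather than by a JSON parse failure of `content`, which removes the old error-driven control flow.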