diff --git a/demos/function_calling/docker-compose.yaml b/demos/function_calling/docker-compose.yaml
index 2cda9d57..1476b967 100644
--- a/demos/function_calling/docker-compose.yaml
+++ b/demos/function_calling/docker-compose.yaml
@@ -61,6 +61,7 @@ services:
       # uncomment following line to use ollama endpoint that is hosted by docker
       # - OLLAMA_ENDPOINT=ollama
       - OLLAMA_MODEL=Arch-Function-Calling-1.5B:Q4_K_M
+      # - OLLAMA_MODEL=Bolt-Function-Calling-1B:Q4_K_M
 
   api_server:
     build:
diff --git a/envoyfilter/src/filter_context.rs b/envoyfilter/src/filter_context.rs
index 07d8afa3..d9310fd9 100644
--- a/envoyfilter/src/filter_context.rs
+++ b/envoyfilter/src/filter_context.rs
@@ -195,7 +195,10 @@ impl Context for FilterContext {
         body_size: usize,
         _num_trailers: usize,
     ) {
-        debug!("on_http_call_response called with token_id: {:?}", token_id);
+        debug!(
+            "filter_context: on_http_call_response called with token_id: {:?}",
+            token_id
+        );
         let callout_data = self.callouts.remove(&token_id).expect("invalid token_id");
 
         self.metrics
diff --git a/envoyfilter/src/stream_context.rs b/envoyfilter/src/stream_context.rs
index 09550934..c7b6ae22 100644
--- a/envoyfilter/src/stream_context.rs
+++ b/envoyfilter/src/stream_context.rs
@@ -244,7 +244,7 @@ impl StreamContext {
             }
         };
         debug!(
-            "dispatched HTTP call to embedding server for zero-shot-intent-detection token_id={}",
+            "dispatched call to model_server/zeroshot token_id={}",
             token_id
         );
 
@@ -462,7 +462,18 @@ impl StreamContext {
         let body_str = String::from_utf8(body).unwrap();
         debug!("function_resolver response str: {}", body_str);
 
-        let boltfc_response: ChatCompletionsResponse = serde_json::from_str(&body_str).unwrap();
+        let boltfc_response: ChatCompletionsResponse = match serde_json::from_str(&body_str) {
+            Ok(boltfc_response) => boltfc_response,
+            Err(e) => {
+                return self.send_server_error(
+                    format!(
+                        "Error deserializing function resolver response into ChatCompletion: {:?}",
+                        e
+                    ),
+                    None,
+                );
+            }
+        };
 
         let model_resp = &boltfc_response.choices[0];
 
@@ -738,15 +749,12 @@ impl StreamContext {
         ) {
             Ok(token_id) => token_id,
             Err(e) => {
-                let error_msg = format!(
-                    "Error dispatching embedding server HTTP call for get-embeddings: {:?}",
-                    e
-                );
+                let error_msg = format!("Error dispatching call to model_server/embeddings: {:?}", e);
                 return self.send_server_error(error_msg, None);
             }
         };
         debug!(
-            "dispatched HTTP call to embedding server token_id={}",
+            "dispatched call to model_server/embeddings token_id={}",
             token_id
         );
 
diff --git a/function_resolver/app/bolt_handler.py b/function_resolver/app/bolt_handler.py
new file mode 100644
index 00000000..7331764e
--- /dev/null
+++ b/function_resolver/app/bolt_handler.py
@@ -0,0 +1,226 @@
+import json
+from typing import Any, Dict, List
+
+
+SYSTEM_PROMPT = """
+[BEGIN OF TASK INSTRUCTION]
+You are a function calling assistant with access to the following tools. Your task is to assist users as best you can.
+For each user query, you may need to call one or more functions to better generate responses.
+If none of the functions are relevant, you should point that out.
+If the given query lacks the parameters required by the function, you should ask users for clarification.
+Users may execute functions and return the results to you as `Observation`. In that case, you MUST generate responses by summarizing them.
+[END OF TASK INSTRUCTION]
+""".strip()
+
+TOOL_PROMPT = """
+[BEGIN OF AVAILABLE TOOLS]
+{tool_text}
+[END OF AVAILABLE TOOLS]
+""".strip()
+
+FORMAT_PROMPT = """
+[BEGIN OF FORMAT INSTRUCTION]
+You MUST use the following JSON format if using tools.
+The example format is as follows. DO NOT use this format if no function call is needed.
+```
+{
+    "tool_calls": [
+        {"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
+        ... (more tool calls as required)
+    ]
+}
+```
+[END OF FORMAT INSTRUCTION]
+""".strip()
+
+
+class BoltHandler:
+    def _format_system(self, tools: List[Dict[str, Any]]):
+        tool_text = self._format_tools(tools=tools)
+        return (
+            SYSTEM_PROMPT
+            + "\n\n"
+            + TOOL_PROMPT.format(tool_text=tool_text)
+            + "\n\n"
+            + FORMAT_PROMPT
+            + "\n"
+        )
+
+    def _format_tools(self, tools: List[Dict[str, Any]]):
+        TOOL_DESC = "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}"
+
+        tool_text = []
+        for fn in tools:
+            tool = fn["function"]
+            param_text = self.get_param_text(tool["parameters"])
+            tool_text.append(
+                TOOL_DESC.format(
+                    name=tool["name"], desc=tool["description"], args=param_text
+                )
+            )
+
+        return "\n".join(tool_text)
+
+    def extract_tools(self, content, executable=False):
+        extracted_tools = []
+        # retrieve `tool_calls` from model responses
+        try:
+            content_json = json.loads(content)
+        except Exception:
+            fixed_content = self.fix_json_string(content)
+            try:
+                content_json = json.loads(fixed_content)
+            except json.JSONDecodeError:
+                return extracted_tools
+
+        if isinstance(content_json, list):
+            tool_calls = content_json
+        elif isinstance(content_json, dict):
+            tool_calls = content_json.get("tool_calls", [])
+        else:
+            tool_calls = []
+
+        if not isinstance(tool_calls, list):
+            return extracted_tools
+
+        # process and extract tools from `tool_calls`
+
+        for tool_call in tool_calls:
+            if isinstance(tool_call, dict):
+                try:
+                    if not executable:
+                        extracted_tools.append({tool_call["name"]: tool_call["arguments"]})
+                    else:
+                        name, arguments = (
+                            tool_call.get("name", ""),
+                            tool_call.get("arguments", {}),
+                        )
+
+                        for key, value in arguments.items():
+                            if value in ("False", "false"):
+                                arguments[key] = False
+                            elif value in ("True", "true"):
+                                arguments[key] = True
+
+                        args_str = ", ".join(
+                            [f"{key}={repr(value)}" for key, value in arguments.items()]
+                        )
+
+                        extracted_tools.append(f"{name}({args_str})")
+
+                except Exception:
+                    continue
+
+        return extracted_tools
+
+    def get_param_text(self, parameter_dict, prefix=""):
+        param_text = ""
+
+        for name, param in parameter_dict["properties"].items():
+            param_type = param.get("type", "")
+
+            required, default, param_format, properties, enum, items = (
+                "",
+                "",
+                "",
+                "",
+                "",
+                "",
+            )
+
+            if name in parameter_dict.get("required", []):
+                required = ", required"
+
+            required_param = parameter_dict.get("required", [])
+
+            if isinstance(required_param, bool):
+                required = ", required" if required_param else ""
+            elif isinstance(required_param, list) and name in required_param:
+                required = ", required"
+            else:
+                required = ", optional"
+
+            default_param = param.get("default", None)
+            if default_param:
+                default = f", default: {default_param}"
+
+            format_in = param.get("format", None)
+            if format_in:
+                param_format = f", format: {format_in}"
+
+            desc = param.get("description", "")
+
+            if "properties" in param:
+                arg_properties = self.get_param_text(param, prefix + "    ")
+                properties += "with the properties:\n{}".format(arg_properties)
+
+            enum_param = param.get("enum", None)
+            if enum_param:
+                enum = "should be one of [{}]".format(", ".join(enum_param))
+
+            item_param = param.get("items", None)
+            if item_param:
+                item_type = item_param.get("type", None)
+                if item_type:
+                    items += "each item should be the {} type ".format(item_type)
+
+                item_properties = item_param.get("properties", None)
+                if item_properties:
+                    item_properties = self.get_param_text(item_param, prefix + "    ")
+                    items += "with the properties:\n{}".format(item_properties)
+
+            illustration = ", ".join(
+                [x for x in [desc, properties, enum, items] if len(x)]
+            )
+
+            param_text += (
+                prefix
+                + "- {name} ({param_type}{required}{param_format}{default}): {illustration}\n".format(
+                    name=name,
+                    param_type=param_type,
+                    required=required,
+                    param_format=param_format,
+                    default=default,
+                    illustration=illustration,
+                )
+            )
+
+        return param_text
+
+    def fix_json_string(self, json_str):
+        # Remove any leading or trailing whitespace or newline characters
+        json_str = json_str.strip()
+
+        # Stack to keep track of brackets
+        stack = []
+
+        # Clean string to collect valid characters
+        fixed_str = ""
+
+        # Dictionary for matching brackets
+        matching_bracket = {")": "(", "}": "{", "]": "["}
+
+        # Dictionary for the opposite of matching_bracket
+        opening_bracket = {v: k for k, v in matching_bracket.items()}
+
+        for char in json_str:
+            if char in "{[(":
+                stack.append(char)
+                fixed_str += char
+            elif char in "}])":
+                if stack and stack[-1] == matching_bracket[char]:
+                    stack.pop()
+                    fixed_str += char
+                else:
+                    # Ignore the unmatched closing brackets
+                    continue
+            else:
+                fixed_str += char
+
+        # If there are unmatched opening brackets left in the stack, add corresponding closing brackets
+        while stack:
+            unmatched_opening = stack.pop()
+            fixed_str += opening_bracket[unmatched_opening]
+
+        # Return the repaired string; callers re-attempt json.loads on it to check validity
+        return fixed_str
diff --git a/function_resolver/app/main.py b/function_resolver/app/main.py
index 99699825..4da5465f 100644
--- a/function_resolver/app/main.py
+++ b/function_resolver/app/main.py
@@ -2,6 +2,7 @@ import json
 import random
 from fastapi import FastAPI, Response
 from arch_handler import ArchHandler
+from bolt_handler import BoltHandler
 from common import ChatMessage
 import logging
 from openai import OpenAI
@@ -11,13 +12,17 @@ ollama_endpoint = os.getenv("OLLAMA_ENDPOINT", "localhost")
 ollama_model = os.getenv("OLLAMA_MODEL", "Arch-Function-Calling-1.5B-Q4_K_M")
 logger = logging.getLogger('uvicorn.error')
+handler = None
+if ollama_model.startswith("Arch"):
+    handler = ArchHandler()
+else:
+    handler = BoltHandler()
+
 logger.info(f"using model: {ollama_model}")
 logger.info(f"using ollama endpoint: {ollama_endpoint}")
 
 app = FastAPI()
 
-handler = ArchHandler()
-
 client = OpenAI(
     base_url='http://{}:11434/v1/'.format(ollama_endpoint),
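For reference, a minimal sketch of how the new BoltHandler is driven by the function_resolver flow above. The get_weather tool definition and the model output string are made-up examples, not part of this change; only BoltHandler itself comes from the patch.

from bolt_handler import BoltHandler

handler = BoltHandler()

# Hypothetical tool definition in the OpenAI function-calling format.
tools = [
    {
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {"type": "string", "description": "City name"}
                },
                "required": ["city"],
            },
        }
    }
]

# System prompt assembled for the Ollama-hosted model
# (SYSTEM + AVAILABLE TOOLS + FORMAT sections).
system_prompt = handler._format_system(tools)

# Hypothetical model completion; extract_tools parses the tool_calls JSON out of it.
content = '{"tool_calls": [{"name": "get_weather", "arguments": {"city": "Seattle"}}]}'
print(handler.extract_tools(content))  # -> [{'get_weather': {'city': 'Seattle'}}]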