mirror of
https://github.com/katanemo/plano.git
synced 2026-06-08 14:55:14 +02:00
removing model_server python module to brightstaff (function calling) (#615)
* adding function_calling functionality via rust * fixed rendered YAML file * removed model_server from envoy.template and forwarding traffic to bright_staff * fixed bugs in function_calling.rs that were breaking tests. All good now * updating e2e test to clean up disk usage * removing Arch* models to be used as a default model if one is not specified * if the user sets arch-function base_url we should honor it * fixing demos as we needed to pin to a particular version of huggingface_hub else the chatbot ui wouldn't build * adding a constant for Arch-Function model name * fixing some edge cases with calls made to Arch-Function * fixed JSON parsing issues in function_calling.rs * fixed bug where the raw response from Arch-Function was re-encoded * removed debug from supervisord.conf * commenting out disk cleanup * adding back disk space --------- Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-288.local> Co-authored-by: Salman Paracha <salmanparacha@MacBook-Pro-342.local>
This commit is contained in:
parent
126b029345
commit
88c2bd1851
40 changed files with 2517 additions and 1356 deletions
|
|
@ -22,6 +22,28 @@ from common import (
|
|||
)
|
||||
|
||||
|
||||
def normalize_tool_call_arguments(tool_call):
|
||||
"""
|
||||
Normalize tool call arguments to ensure they are always a dict.
|
||||
|
||||
According to OpenAI API spec, the 'arguments' field should be a JSON string,
|
||||
but for easier testing we parse it into a dict here.
|
||||
|
||||
Args:
|
||||
tool_call: A tool call dict that may have 'arguments' as either a string or dict
|
||||
|
||||
Returns:
|
||||
A tool call dict with 'arguments' guaranteed to be a dict
|
||||
"""
|
||||
if "arguments" in tool_call and isinstance(tool_call["arguments"], str):
|
||||
try:
|
||||
tool_call["arguments"] = json.loads(tool_call["arguments"])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# If parsing fails, keep it as is
|
||||
pass
|
||||
return tool_call
|
||||
|
||||
|
||||
def test_prompt_gateway(httpserver: HTTPServer):
|
||||
simple_fixture = TEST_CASE_FIXTURES["SIMPLE"]
|
||||
input = simple_fixture["input"]
|
||||
|
|
@ -67,7 +89,7 @@ def test_prompt_gateway(httpserver: HTTPServer):
|
|||
tool_calls_message = arch_messages[0]
|
||||
tool_calls = tool_calls_message.get("tool_calls", [])
|
||||
assert len(tool_calls) > 0
|
||||
tool_call = tool_calls[0]["function"]
|
||||
tool_call = normalize_tool_call_arguments(tool_calls[0]["function"])
|
||||
diff = DeepDiff(tool_call, expected_tool_call, ignore_string_case=True)
|
||||
assert not diff
|
||||
|
||||
|
|
|
|||
|
|
@ -24,6 +24,28 @@ def cleanup_tool_call(tool_call):
|
|||
return tool_call.strip()
|
||||
|
||||
|
||||
def normalize_tool_call_arguments(tool_call):
|
||||
"""
|
||||
Normalize tool call arguments to ensure they are always a dict.
|
||||
|
||||
According to OpenAI API spec, the 'arguments' field should be a JSON string,
|
||||
but for easier testing we parse it into a dict here.
|
||||
|
||||
Args:
|
||||
tool_call: A tool call dict that may have 'arguments' as either a string or dict
|
||||
|
||||
Returns:
|
||||
A tool call dict with 'arguments' guaranteed to be a dict
|
||||
"""
|
||||
if "arguments" in tool_call and isinstance(tool_call["arguments"], str):
|
||||
try:
|
||||
tool_call["arguments"] = json.loads(tool_call["arguments"])
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# If parsing fails, keep it as is
|
||||
pass
|
||||
return tool_call
|
||||
|
||||
|
||||
@pytest.mark.parametrize("stream", [True, False])
|
||||
def test_prompt_gateway(stream):
|
||||
expected_tool_call = {
|
||||
|
|
@ -62,7 +84,7 @@ def test_prompt_gateway(stream):
|
|||
print("cleaned_tool_call_str: ", cleaned_tool_call_str)
|
||||
tool_calls = json.loads(cleaned_tool_call_str).get("tool_calls", [])
|
||||
assert len(tool_calls) > 0
|
||||
tool_call = tool_calls[0]
|
||||
tool_call = normalize_tool_call_arguments(tool_calls[0])
|
||||
location = tool_call["arguments"]["location"]
|
||||
assert expected_tool_call["arguments"]["location"] in location.lower()
|
||||
del expected_tool_call["arguments"]["location"]
|
||||
|
|
@ -106,7 +128,7 @@ def test_prompt_gateway(stream):
|
|||
print("cleaned_tool_call_json: ", json.dumps(cleaned_tool_call_json))
|
||||
tool_calls_list = cleaned_tool_call_json.get("tool_calls", [])
|
||||
assert len(tool_calls_list) > 0
|
||||
tool_call = tool_calls_list[0]
|
||||
tool_call = normalize_tool_call_arguments(tool_calls_list[0])
|
||||
location = tool_call["arguments"]["location"]
|
||||
assert expected_tool_call["arguments"]["location"] in location.lower()
|
||||
del expected_tool_call["arguments"]["location"]
|
||||
|
|
@ -241,7 +263,7 @@ def test_prompt_gateway_param_tool_call(stream):
|
|||
assert role == "assistant"
|
||||
tool_calls = choices[0].get("delta", {}).get("tool_calls", [])
|
||||
assert len(tool_calls) > 0
|
||||
tool_call = tool_calls[0]["function"]
|
||||
tool_call = normalize_tool_call_arguments(tool_calls[0]["function"])
|
||||
diff = DeepDiff(tool_call, expected_tool_call, ignore_string_case=True)
|
||||
assert not diff
|
||||
|
||||
|
|
@ -275,7 +297,7 @@ def test_prompt_gateway_param_tool_call(stream):
|
|||
tool_calls_message = arch_messages[0]
|
||||
tool_calls = tool_calls_message.get("tool_calls", [])
|
||||
assert len(tool_calls) > 0
|
||||
tool_call = tool_calls[0]["function"]
|
||||
tool_call = normalize_tool_call_arguments(tool_calls[0]["function"])
|
||||
diff = DeepDiff(tool_call, expected_tool_call, ignore_string_case=True)
|
||||
assert not diff
|
||||
|
||||
|
|
|
|||
|
|
@ -1,44 +0,0 @@
|
|||
import os
|
||||
import pytest
|
||||
import requests
|
||||
import logging
|
||||
import yaml
|
||||
|
||||
pytestmark = pytest.mark.skip(
|
||||
reason="Skipping entire test file as hallucination is not enabled for archfc 1.1 yet"
|
||||
)
|
||||
|
||||
MODEL_SERVER_ENDPOINT = os.getenv(
|
||||
"MODEL_SERVER_ENDPOINT", "http://localhost:51000/function_calling"
|
||||
)
|
||||
|
||||
# Load test data from YAML file
|
||||
script_dir = os.path.dirname(__file__)
|
||||
|
||||
# Construct the full path to the YAML file
|
||||
yaml_file_path = os.path.join(script_dir, "test_hallucination_data.yaml")
|
||||
|
||||
# Load test data from YAML file
|
||||
with open(yaml_file_path, "r") as file:
|
||||
test_data_yaml = yaml.safe_load(file)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"test_data",
|
||||
[
|
||||
pytest.param(test_case, id=test_case["id"])
|
||||
for test_case in test_data_yaml["test_cases"]
|
||||
],
|
||||
)
|
||||
def test_model_server(test_data):
|
||||
input = test_data["input"]
|
||||
expected = test_data["expected"]
|
||||
|
||||
response = requests.post(MODEL_SERVER_ENDPOINT, json=input)
|
||||
assert response.status_code == 200
|
||||
assert response.headers["content-type"] == "application/json"
|
||||
|
||||
response_json = response.json()
|
||||
assert response_json
|
||||
metadata = response_json.get("metadata", {})
|
||||
assert (metadata["hallucination"].lower() == "true") == expected[0]["hallucination"]
|
||||
|
|
@ -1,257 +0,0 @@
|
|||
test_cases:
|
||||
- id: "[WEATHER AGENT] - single turn, single tool, prompt prefilling"
|
||||
input:
|
||||
messages:
|
||||
- role: "user"
|
||||
content: "what is the weather forecast for seattle?"
|
||||
tools:
|
||||
- type: "function"
|
||||
function:
|
||||
name: "get_current_weather"
|
||||
description: "Get current weather at a location."
|
||||
parameters:
|
||||
type: "object"
|
||||
properties:
|
||||
location:
|
||||
type: "string"
|
||||
description: "The location to get the weather for"
|
||||
format: "City, State"
|
||||
days:
|
||||
type: "integer"
|
||||
description: "The number of days for the request."
|
||||
required:
|
||||
- location
|
||||
- days
|
||||
expected:
|
||||
- type: "metadata"
|
||||
hallucination: false
|
||||
|
||||
- id: "[WEATHER AGENT] - single turn, single tool, hallucination"
|
||||
input:
|
||||
messages:
|
||||
- role: "user"
|
||||
content: "what is the weather in Seattle in days?"
|
||||
tools:
|
||||
- type: "function"
|
||||
function:
|
||||
name: "get_current_weather"
|
||||
description: "Get current weather at a location."
|
||||
parameters:
|
||||
type: "object"
|
||||
properties:
|
||||
location:
|
||||
type: "str"
|
||||
description: "The location to get the weather for"
|
||||
format: "City, State"
|
||||
days:
|
||||
type: "int"
|
||||
description: "the number of days for the request."
|
||||
required: ["location", "days"]
|
||||
expected:
|
||||
- type: "metadata"
|
||||
hallucination: true
|
||||
|
||||
- id: "[WEATHER AGENT] - multi turn, single tool, all params passed"
|
||||
input:
|
||||
messages:
|
||||
- role: "user"
|
||||
content: "how is the weather in chicago for next 5 days?"
|
||||
- role: "assistant"
|
||||
content: "Can you tell me your location and how many days you want?"
|
||||
- role: "user"
|
||||
content: "Seattle"
|
||||
- role: "assistant"
|
||||
content: "Can you please provide me the days for the weather forecast?"
|
||||
- role: "user"
|
||||
content: "5 days"
|
||||
tools:
|
||||
- type: "function"
|
||||
function:
|
||||
name: "get_current_weather"
|
||||
description: "Get current weather at a location."
|
||||
parameters:
|
||||
type: "object"
|
||||
properties:
|
||||
location:
|
||||
type: "str"
|
||||
description: "The location to get the weather for"
|
||||
format: "City, State"
|
||||
days:
|
||||
type: "int"
|
||||
description: "the number of days for the request."
|
||||
required: ["location", "days"]
|
||||
expected:
|
||||
- type: "metadata"
|
||||
hallucination: false
|
||||
|
||||
- id: "[WEATHER AGENT] - multi turn, single tool, clarification"
|
||||
input:
|
||||
messages:
|
||||
- role: "user"
|
||||
content: "how is the weather for next 5 days?"
|
||||
- role: "assistant"
|
||||
content: "Can you tell me your location and how many days you want?"
|
||||
- role: "user"
|
||||
content: "Seattle"
|
||||
- role: "assistant"
|
||||
content: "Can you please provide me the days for the weather forecast?"
|
||||
- role: "user"
|
||||
content: "Sorry, the location is actually los angeles in 5 days"
|
||||
tools:
|
||||
- type: "function"
|
||||
function:
|
||||
name: "get_current_weather"
|
||||
description: "Get current weather at a location."
|
||||
parameters:
|
||||
type: "object"
|
||||
properties:
|
||||
location:
|
||||
type: "str"
|
||||
description: "The location to get the weather for"
|
||||
format: "City, State"
|
||||
days:
|
||||
type: "int"
|
||||
description: "the number of days for the request."
|
||||
required: ["location", "days"]
|
||||
expected:
|
||||
- type: "metadata"
|
||||
hallucination: false
|
||||
|
||||
- id: "[SALE AGENT] - single turn, single tool, hallucination region"
|
||||
input:
|
||||
messages:
|
||||
- role: "user"
|
||||
content: "get me sales opportunities of tech"
|
||||
tools:
|
||||
- type: "function"
|
||||
function:
|
||||
name: "sales_opportunity"
|
||||
description: "Retrieve potential sales opportunities based for a particular industry type in a region."
|
||||
parameters:
|
||||
type: "object"
|
||||
properties:
|
||||
region:
|
||||
type: "str"
|
||||
description: "Geographical region to identify sales opportunities."
|
||||
industry:
|
||||
type: "str"
|
||||
description: "Industry type."
|
||||
max_results:
|
||||
type: "int"
|
||||
description: "Maximum number of sales opportunities to retrieve."
|
||||
default: 20
|
||||
required: ["region", "industry"]
|
||||
expected:
|
||||
- type: "metadata"
|
||||
hallucination: true
|
||||
|
||||
- id: "[SALE AGENT] - single turn, single tool, hallucination industry"
|
||||
input:
|
||||
messages:
|
||||
- role: "user"
|
||||
content: "get me sales opportunities in NA"
|
||||
tools:
|
||||
- type: "function"
|
||||
function:
|
||||
name: "sales_opportunity"
|
||||
description: "Retrieve potential sales opportunities based for a particular industry type in a region."
|
||||
parameters:
|
||||
type: "object"
|
||||
properties:
|
||||
region:
|
||||
type: "str"
|
||||
description: "Geographical region to identify sales opportunities."
|
||||
industry:
|
||||
type: "str"
|
||||
description: "Industry type."
|
||||
max_results:
|
||||
type: "int"
|
||||
description: "Maximum number of sales opportunities to retrieve."
|
||||
default: 20
|
||||
required: ["region", "industry"]
|
||||
expected:
|
||||
- type: "metadata"
|
||||
hallucination: true
|
||||
|
||||
- id: "[PRODUCT AGENT] - single turn, single tool, hallucination industry"
|
||||
input:
|
||||
messages:
|
||||
- role: "user"
|
||||
content: "get me sales opportunities in NA"
|
||||
tools:
|
||||
- type: "function"
|
||||
function:
|
||||
name: "product_recommendation"
|
||||
description: "Place an order for an iphone with user_id 195 and location is 1600 pensylvania ave"
|
||||
parameters:
|
||||
type: "object"
|
||||
properties:
|
||||
user_id:
|
||||
type: "str"
|
||||
description: "Unique identifier for the user."
|
||||
category:
|
||||
type: "str"
|
||||
description: "Product category for recommendations."
|
||||
max_results:
|
||||
type: "int"
|
||||
description: "Maximum number of recommended products to show."
|
||||
default: 10
|
||||
required: ["user_id", "category"]
|
||||
- type: "function"
|
||||
function:
|
||||
name: "place_order"
|
||||
description: "Place and pay for an order for one or more products to ship to the an address."
|
||||
parameters:
|
||||
type: "object"
|
||||
properties:
|
||||
user_id:
|
||||
type: "str"
|
||||
description: "Unique identifier for the user placing the order."
|
||||
product_ids:
|
||||
type: "array"
|
||||
description: "List of product IDs to include in the order."
|
||||
shipping_address:
|
||||
type: "str"
|
||||
description: "Shipping address for the order."
|
||||
payment_method:
|
||||
type: "str"
|
||||
description: "Payment method for the order."
|
||||
required: ["user_id", "product_ids", "shipping_address", "payment_method"]
|
||||
- type: "function"
|
||||
function:
|
||||
name: "sales_opportunity"
|
||||
description: "Retrieve potential sales opportunities based for a particular industry type in a region."
|
||||
parameters:
|
||||
type: "object"
|
||||
properties:
|
||||
region:
|
||||
type: "str"
|
||||
description: "Geographical region to identify sales opportunities."
|
||||
industry:
|
||||
type: "str"
|
||||
description: "Industry type."
|
||||
max_results:
|
||||
type: "int"
|
||||
description: "Maximum number of sales opportunities to retrieve."
|
||||
default: 20
|
||||
required: ["region", "industry"]
|
||||
- type: "function"
|
||||
function:
|
||||
name: "query_database"
|
||||
description: "Perform a database query to retrieve or update information."
|
||||
parameters:
|
||||
type: "object"
|
||||
properties:
|
||||
query:
|
||||
type: "str"
|
||||
description: "SQL query string to execute against the database."
|
||||
parameters:
|
||||
type: "array"
|
||||
description: "List of parameters to safely inject into the SQL query (to prevent SQL injection)."
|
||||
operation:
|
||||
type: "str"
|
||||
description: "Type of operation."
|
||||
required: ["query", "operation"]
|
||||
expected:
|
||||
- type: "metadata"
|
||||
hallucination: true
|
||||
|
|
@ -10,7 +10,7 @@ pytestmark = pytest.mark.skip(
|
|||
)
|
||||
|
||||
MODEL_SERVER_ENDPOINT = os.getenv(
|
||||
"MODEL_SERVER_ENDPOINT", "http://localhost:51000/function_calling"
|
||||
"MODEL_SERVER_ENDPOINT", "http://localhost:12000/function_calling"
|
||||
)
|
||||
|
||||
# Load test data from YAML file
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
@model_server_endpoint = http://localhost:51000
|
||||
@model_server_endpoint = http://localhost:12000
|
||||
@archfc_endpoint = https://archfc.katanemo.dev
|
||||
|
||||
### talk to function calling endpoint
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
@model_server_endpoint = http://localhost:51000
|
||||
@model_server_endpoint = http://localhost:12000
|
||||
@archfc_endpoint = https://archfc.katanemo.dev
|
||||
|
||||
### multi turn conversation with intent, except parameter gathering
|
||||
|
|
@ -54,26 +54,8 @@ Content-Type: application/json
|
|||
}
|
||||
]
|
||||
}
|
||||
### talk to Arch-Intent directly for completion
|
||||
POST https://archfc.katanemo.dev/v1/chat/completions HTTP/1.1
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "Arch-Intent",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant.\n\nYou task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below.\n\n<tools>\n{\"index\": \"T0\", \"type\": \"function\", \"function\": {\"name\": \"weather_forecast\", \"parameters\": {\"type\": \"object\", \"properties\": {\"city\": {\"type\": \"str\"}, \"days\": {\"type\": \"int\"}}, \"required\": [\"city\", \"days\"]}}}\n</tools>\n\nProvide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation:\n- First line must read 'Yes' or 'No'.\n- If yes, a second line must include a comma-separated list of tool indexes.\n"
|
||||
},
|
||||
{ "role": "user", "content": "hi" }
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
|
||||
|
||||
|
||||
### multi turn conversation with correct parameters
|
||||
|
||||
### multi turn conversation with intent, except parameter gathering
|
||||
POST {{model_server_endpoint}}/function_calling HTTP/1.1
|
||||
Content-Type: application/json
|
||||
|
||||
|
|
@ -125,21 +107,6 @@ Content-Type: application/json
|
|||
}
|
||||
]
|
||||
}
|
||||
### talk to Arch-Intent directly for completion, expect No
|
||||
POST https://archfc.katanemo.dev/v1/chat/completions HTTP/1.1
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "Arch-Intent",
|
||||
"messages": [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You are a helpful assistant.\n\nYou task is to check if there are any tools that can be used to help the last user message in conversations according to the available tools listed below.\n\n<tools>\n{\"index\": \"T0\", \"type\": \"function\", \"function\": {\"name\": \"weather_forecast\", \"parameters\": {\"type\": \"object\", \"properties\": {\"city\": {\"type\": \"str\"}, \"days\": {\"type\": \"int\"}}, \"required\": [\"city\", \"days\"]}}}\n</tools>\n\nProvide your tool assessment for ONLY THE LAST USER MESSAGE in the above conversation:\n- First line must read 'Yes' or 'No'.\n- If yes, a second line must include a comma-separated list of tool indexes.\n"
|
||||
},
|
||||
{ "role": "user", "content": "what is your name" }
|
||||
],
|
||||
"stream": false
|
||||
}
|
||||
|
||||
### multi turn conversation with correct parameters
|
||||
POST {{model_server_endpoint}}/function_calling HTTP/1.1
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
@model_server_endpoint = http://localhost:51000
|
||||
@model_server_endpoint = http://localhost:12000
|
||||
@archfc_endpoint = https://archfc.katanemo.dev
|
||||
|
||||
### single turn function calling all parameters insurance agent summary
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue