Integrate Arch-Function-Chat (#449)

This commit is contained in:
Shuguang Chen 2025-04-15 14:39:12 -07:00 committed by GitHub
parent f31aa59fac
commit 7d4b261a68
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
26 changed files with 558 additions and 603 deletions

View file

@ -2,6 +2,7 @@ import json
import pytest
import requests
from deepdiff import DeepDiff
import re
from common import (
PROMPT_GATEWAY_ENDPOINT,
@ -11,6 +12,15 @@ from common import (
)
def cleanup_tool_call(tool_call):
pattern = r"```json\n(.*?)\n```"
match = re.search(pattern, tool_call, re.DOTALL)
if match:
tool_call = match.group(1)
return tool_call.strip()
@pytest.mark.parametrize("stream", [True, False])
def test_prompt_gateway(stream):
expected_tool_call = {
@ -42,9 +52,14 @@ def test_prompt_gateway(stream):
assert "role" in choices[0]["delta"]
role = choices[0]["delta"]["role"]
assert role == "assistant"
tool_calls = choices[0].get("delta", {}).get("tool_calls", [])
print(f"choices: {choices}")
tool_call_str = choices[0].get("delta", {}).get("content", "")
print("tool_call_str: ", tool_call_str)
cleaned_tool_call_str = cleanup_tool_call(tool_call_str)
print("cleaned_tool_call_str: ", cleaned_tool_call_str)
tool_calls = json.loads(cleaned_tool_call_str).get("tool_calls", [])
assert len(tool_calls) > 0
tool_call = tool_calls[0]["function"]
tool_call = tool_calls[0]
location = tool_call["arguments"]["location"]
assert expected_tool_call["arguments"]["location"] in location.lower()
del expected_tool_call["arguments"]["location"]
@ -62,7 +77,7 @@ def test_prompt_gateway(stream):
# third..end chunk is summarization (role = assistant)
response_json = json.loads(chunks[2])
assert response_json.get("model").startswith("llama-3.2-3b-preview")
assert response_json.get("model").startswith("gpt-4o")
choices = response_json.get("choices", [])
assert len(choices) > 0
assert "role" in choices[0]["delta"]
@ -71,18 +86,24 @@ def test_prompt_gateway(stream):
else:
response_json = response.json()
assert response_json.get("model").startswith("llama-3.2-3b-preview")
assert response_json.get("model").startswith("gpt-4o")
choices = response_json.get("choices", [])
assert len(choices) > 0
assert "role" in choices[0]["message"]
assert choices[0]["message"]["role"] == "assistant"
# now verify arch_messages (tool call and api response) that are sent as response metadata
arch_messages = get_arch_messages(response_json)
print("arch_messages: ", json.dumps(arch_messages))
assert len(arch_messages) == 2
tool_calls_message = arch_messages[0]
tool_calls = tool_calls_message.get("tool_calls", [])
assert len(tool_calls) > 0
tool_call = tool_calls[0]["function"]
print("tool_calls_message: ", tool_calls_message)
tool_calls = tool_calls_message.get("content", [])
cleaned_tool_call_str = cleanup_tool_call(tool_calls)
cleaned_tool_call_json = json.loads(cleaned_tool_call_str)
print("cleaned_tool_call_json: ", json.dumps(cleaned_tool_call_json))
tool_calls_list = cleaned_tool_call_json.get("tool_calls", [])
assert len(tool_calls_list) > 0
tool_call = tool_calls_list[0]
location = tool_call["arguments"]["location"]
assert expected_tool_call["arguments"]["location"] in location.lower()
del expected_tool_call["arguments"]["location"]
@ -231,7 +252,7 @@ def test_prompt_gateway_param_tool_call(stream):
# third..end chunk is summarization (role = assistant)
response_json = json.loads(chunks[2])
assert response_json.get("model").startswith("llama-3.2-3b-preview")
assert response_json.get("model").startswith("gpt-4o")
choices = response_json.get("choices", [])
assert len(choices) > 0
assert "role" in choices[0]["delta"]
@ -240,7 +261,7 @@ def test_prompt_gateway_param_tool_call(stream):
else:
response_json = response.json()
assert response_json.get("model").startswith("llama-3.2-3b-preview")
assert response_json.get("model").startswith("gpt-4o")
choices = response_json.get("choices", [])
assert len(choices) > 0
assert "role" in choices[0]["message"]
@ -262,7 +283,7 @@ def test_prompt_gateway_default_target(stream):
"messages": [
{
"role": "user",
"content": "hello, what can you do for me?",
"content": "hello",
},
],
"stream": stream,
@ -273,17 +294,20 @@ def test_prompt_gateway_default_target(stream):
chunks = get_data_chunks(response, n=3)
assert len(chunks) > 0
response_json = json.loads(chunks[0])
print("response_json chunks[0]: ", response_json)
assert response_json.get("model").startswith("api_server")
assert len(response_json.get("choices", [])) > 0
assert response_json.get("choices")[0]["delta"]["role"] == "assistant"
response_json = json.loads(chunks[1])
print("response_json chunks[1]: ", response_json)
choices = response_json.get("choices", [])
assert len(choices) > 0
content = choices[0]["delta"]["content"]
assert content == "I can help you with weather forecast"
else:
response_json = response.json()
print("response_json: ", response_json)
assert response_json.get("model").startswith("api_server")
assert len(response_json.get("choices")) > 0
assert response_json.get("choices")[0]["message"]["role"] == "assistant"