mirror of
https://github.com/katanemo/plano.git
synced 2026-05-27 14:17:15 +02:00
Integrate Arch-Function-Chat (#449)
This commit is contained in:
parent
f31aa59fac
commit
7d4b261a68
26 changed files with 558 additions and 603 deletions
|
|
@ -2,6 +2,7 @@ import json
|
|||
import pytest
|
||||
import requests
|
||||
from deepdiff import DeepDiff
|
||||
import re
|
||||
|
||||
from common import (
|
||||
PROMPT_GATEWAY_ENDPOINT,
|
||||
|
|
@ -11,6 +12,15 @@ from common import (
|
|||
)
|
||||
|
||||
|
||||
def cleanup_tool_call(tool_call):
|
||||
pattern = r"```json\n(.*?)\n```"
|
||||
match = re.search(pattern, tool_call, re.DOTALL)
|
||||
if match:
|
||||
tool_call = match.group(1)
|
||||
|
||||
return tool_call.strip()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("stream", [True, False])
|
||||
def test_prompt_gateway(stream):
|
||||
expected_tool_call = {
|
||||
|
|
@ -42,9 +52,14 @@ def test_prompt_gateway(stream):
|
|||
assert "role" in choices[0]["delta"]
|
||||
role = choices[0]["delta"]["role"]
|
||||
assert role == "assistant"
|
||||
tool_calls = choices[0].get("delta", {}).get("tool_calls", [])
|
||||
print(f"choices: {choices}")
|
||||
tool_call_str = choices[0].get("delta", {}).get("content", "")
|
||||
print("tool_call_str: ", tool_call_str)
|
||||
cleaned_tool_call_str = cleanup_tool_call(tool_call_str)
|
||||
print("cleaned_tool_call_str: ", cleaned_tool_call_str)
|
||||
tool_calls = json.loads(cleaned_tool_call_str).get("tool_calls", [])
|
||||
assert len(tool_calls) > 0
|
||||
tool_call = tool_calls[0]["function"]
|
||||
tool_call = tool_calls[0]
|
||||
location = tool_call["arguments"]["location"]
|
||||
assert expected_tool_call["arguments"]["location"] in location.lower()
|
||||
del expected_tool_call["arguments"]["location"]
|
||||
|
|
@ -62,7 +77,7 @@ def test_prompt_gateway(stream):
|
|||
|
||||
# third..end chunk is summarization (role = assistant)
|
||||
response_json = json.loads(chunks[2])
|
||||
assert response_json.get("model").startswith("llama-3.2-3b-preview")
|
||||
assert response_json.get("model").startswith("gpt-4o")
|
||||
choices = response_json.get("choices", [])
|
||||
assert len(choices) > 0
|
||||
assert "role" in choices[0]["delta"]
|
||||
|
|
@ -71,18 +86,24 @@ def test_prompt_gateway(stream):
|
|||
|
||||
else:
|
||||
response_json = response.json()
|
||||
assert response_json.get("model").startswith("llama-3.2-3b-preview")
|
||||
assert response_json.get("model").startswith("gpt-4o")
|
||||
choices = response_json.get("choices", [])
|
||||
assert len(choices) > 0
|
||||
assert "role" in choices[0]["message"]
|
||||
assert choices[0]["message"]["role"] == "assistant"
|
||||
# now verify arch_messages (tool call and api response) that are sent as response metadata
|
||||
arch_messages = get_arch_messages(response_json)
|
||||
print("arch_messages: ", json.dumps(arch_messages))
|
||||
assert len(arch_messages) == 2
|
||||
tool_calls_message = arch_messages[0]
|
||||
tool_calls = tool_calls_message.get("tool_calls", [])
|
||||
assert len(tool_calls) > 0
|
||||
tool_call = tool_calls[0]["function"]
|
||||
print("tool_calls_message: ", tool_calls_message)
|
||||
tool_calls = tool_calls_message.get("content", [])
|
||||
cleaned_tool_call_str = cleanup_tool_call(tool_calls)
|
||||
cleaned_tool_call_json = json.loads(cleaned_tool_call_str)
|
||||
print("cleaned_tool_call_json: ", json.dumps(cleaned_tool_call_json))
|
||||
tool_calls_list = cleaned_tool_call_json.get("tool_calls", [])
|
||||
assert len(tool_calls_list) > 0
|
||||
tool_call = tool_calls_list[0]
|
||||
location = tool_call["arguments"]["location"]
|
||||
assert expected_tool_call["arguments"]["location"] in location.lower()
|
||||
del expected_tool_call["arguments"]["location"]
|
||||
|
|
@ -231,7 +252,7 @@ def test_prompt_gateway_param_tool_call(stream):
|
|||
|
||||
# third..end chunk is summarization (role = assistant)
|
||||
response_json = json.loads(chunks[2])
|
||||
assert response_json.get("model").startswith("llama-3.2-3b-preview")
|
||||
assert response_json.get("model").startswith("gpt-4o")
|
||||
choices = response_json.get("choices", [])
|
||||
assert len(choices) > 0
|
||||
assert "role" in choices[0]["delta"]
|
||||
|
|
@ -240,7 +261,7 @@ def test_prompt_gateway_param_tool_call(stream):
|
|||
|
||||
else:
|
||||
response_json = response.json()
|
||||
assert response_json.get("model").startswith("llama-3.2-3b-preview")
|
||||
assert response_json.get("model").startswith("gpt-4o")
|
||||
choices = response_json.get("choices", [])
|
||||
assert len(choices) > 0
|
||||
assert "role" in choices[0]["message"]
|
||||
|
|
@ -262,7 +283,7 @@ def test_prompt_gateway_default_target(stream):
|
|||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hello, what can you do for me?",
|
||||
"content": "hello",
|
||||
},
|
||||
],
|
||||
"stream": stream,
|
||||
|
|
@ -273,17 +294,20 @@ def test_prompt_gateway_default_target(stream):
|
|||
chunks = get_data_chunks(response, n=3)
|
||||
assert len(chunks) > 0
|
||||
response_json = json.loads(chunks[0])
|
||||
print("response_json chunks[0]: ", response_json)
|
||||
assert response_json.get("model").startswith("api_server")
|
||||
assert len(response_json.get("choices", [])) > 0
|
||||
assert response_json.get("choices")[0]["delta"]["role"] == "assistant"
|
||||
|
||||
response_json = json.loads(chunks[1])
|
||||
print("response_json chunks[1]: ", response_json)
|
||||
choices = response_json.get("choices", [])
|
||||
assert len(choices) > 0
|
||||
content = choices[0]["delta"]["content"]
|
||||
assert content == "I can help you with weather forecast"
|
||||
else:
|
||||
response_json = response.json()
|
||||
print("response_json: ", response_json)
|
||||
assert response_json.get("model").startswith("api_server")
|
||||
assert len(response_json.get("choices")) > 0
|
||||
assert response_json.get("choices")[0]["message"]["role"] == "assistant"
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue