mirror of
https://github.com/katanemo/plano.git
synced 2026-04-26 17:26:26 +02:00
Integrate Arch-Function-Chat (#449)
This commit is contained in:
parent
f31aa59fac
commit
7d4b261a68
26 changed files with 558 additions and 603 deletions
|
|
@ -1,5 +1,5 @@
|
|||
import pytest
|
||||
|
||||
import time
|
||||
from src.commons.globals import handler_map
|
||||
from src.core.utils.model_utils import ChatMessage, Message
|
||||
|
||||
|
|
@ -37,26 +37,9 @@ get_weather_api = {
|
|||
# get_data class return request, intent, hallucination, parameter_gathering
|
||||
|
||||
|
||||
def get_hallucination_data_complex():
|
||||
def get_hallucination_data():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(role="user", content="How is the weather in Seattle?")
|
||||
message2 = Message(
|
||||
role="assistant", content="Can you specify the unit you want the weather in?"
|
||||
)
|
||||
message3 = Message(role="user", content="In celcius please!")
|
||||
|
||||
# Create a list of tools
|
||||
tools = [get_weather_api]
|
||||
|
||||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1, message2, message3], tools=tools)
|
||||
|
||||
return req, True, True, True
|
||||
|
||||
|
||||
def get_hallucination_data_medium():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(role="user", content="How is the weather in?")
|
||||
message1 = Message(role="user", content="How is the weather in Seattle in days?")
|
||||
|
||||
# Create a list of tools
|
||||
tools = [get_weather_api]
|
||||
|
|
@ -65,26 +48,10 @@ def get_hallucination_data_medium():
|
|||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
# first token will not be tool call
|
||||
return req, True, True, True
|
||||
return req, False, True
|
||||
|
||||
|
||||
def get_complete_data_2():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(
|
||||
role="user",
|
||||
content="what is the weather forecast for seattle in the next 10 days?",
|
||||
)
|
||||
|
||||
# Create a list of tools
|
||||
tools = [get_weather_api]
|
||||
|
||||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
return req, True, False, False
|
||||
|
||||
|
||||
def get_complete_data():
|
||||
def get_success_tool_call_data():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(role="user", content="How is the weather in Seattle in 7 days?")
|
||||
|
||||
|
|
@ -94,7 +61,7 @@ def get_complete_data():
|
|||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
return req, True, False, False
|
||||
return req, True, False
|
||||
|
||||
|
||||
def get_irrelevant_data():
|
||||
|
|
@ -107,7 +74,7 @@ def get_irrelevant_data():
|
|||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
return req, False, False, False
|
||||
return req, False, False
|
||||
|
||||
|
||||
def get_greeting_data():
|
||||
|
|
@ -120,38 +87,29 @@ def get_greeting_data():
|
|||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
return req, False, False, False
|
||||
return req, False, False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"get_data_func",
|
||||
[
|
||||
get_hallucination_data_complex,
|
||||
get_complete_data,
|
||||
get_hallucination_data,
|
||||
get_greeting_data,
|
||||
get_irrelevant_data,
|
||||
get_complete_data_2,
|
||||
get_success_tool_call_data,
|
||||
],
|
||||
)
|
||||
async def test_function_calling(get_data_func):
|
||||
req, intent, hallucination, parameter_gathering = get_data_func()
|
||||
req, intent, hallucination = get_data_func()
|
||||
handler_name = "Arch-Function"
|
||||
use_agent_orchestrator = False
|
||||
model_handler: ArchFunctionHandler = handler_map[handler_name]
|
||||
|
||||
intent_response = await handler_map["Arch-Intent"].chat_completion(req)
|
||||
start_time = time.perf_counter()
|
||||
final_response = await model_handler.chat_completion(req)
|
||||
latency = time.perf_counter() - start_time
|
||||
|
||||
assert handler_map["Arch-Intent"].detect_intent(intent_response) == intent
|
||||
assert intent == (len(final_response.choices[0].message.tool_calls) >= 1)
|
||||
|
||||
if intent:
|
||||
function_calling_response = await handler_map["Arch-Function"].chat_completion(
|
||||
req
|
||||
)
|
||||
assert (
|
||||
handler_map["Arch-Function"].hallucination_state.hallucination
|
||||
== hallucination
|
||||
)
|
||||
response_txt = function_calling_response.choices[0].message.content
|
||||
|
||||
if parameter_gathering:
|
||||
prefill_prefix = handler_map["Arch-Function"].prefill_prefix
|
||||
assert any(
|
||||
response_txt.startswith(prefix) for prefix in prefill_prefix
|
||||
), f"Response '{response_txt}' does not start with any of the prefixes: {prefill_prefix}"
|
||||
assert hallucination == model_handler.hallucination_state.hallucination
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
from src.commons.globals import handler_map
|
||||
from src.core.function_calling import Message
|
||||
from src.core.function_calling import ArchFunctionHandler, Message
|
||||
|
||||
|
||||
test_input_history = [
|
||||
|
|
@ -7,34 +7,19 @@ test_input_history = [
|
|||
{
|
||||
"role": "assistant",
|
||||
"model": "Arch-Function",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_3394",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "weather_forecast",
|
||||
"arguments": {"city": "Chicago", "days": 5},
|
||||
},
|
||||
}
|
||||
],
|
||||
"content": '```json\n{"tool_calls": [{"name": "get_current_weather", "arguments": {"days": 5, "location": "Chicago, Illinois"}}]}\n```',
|
||||
},
|
||||
{
|
||||
"role": "tool",
|
||||
"model": "Arch-Function",
|
||||
"content": '{"location":"Chicago%2C%20Illinois","temperature":[{"date":"2025-04-14","temperature":{"min":53,"max":65},"units":"Farenheit","query_time":"2025-04-14 17:01:52.432817+00:00"},{"date":"2025-04-15","temperature":{"min":85,"max":97},"units":"Farenheit","query_time":"2025-04-14 17:01:52.432830+00:00"},{"date":"2025-04-16","temperature":{"min":62,"max":78},"units":"Farenheit","query_time":"2025-04-14 17:01:52.432835+00:00"},{"date":"2025-04-17","temperature":{"min":89,"max":101},"units":"Farenheit","query_time":"2025-04-14 17:01:52.432839+00:00"},{"date":"2025-04-18","temperature":{"min":86,"max":104},"units":"Farenheit","query_time":"2025-04-14 17:01:52.432843+00:00"}],"units":"Farenheit"}',
|
||||
},
|
||||
{"role": "tool", "content": "--", "tool_call_id": "call_3394"},
|
||||
{"role": "assistant", "content": "--", "model": "gpt-3.5-turbo-0125"},
|
||||
{"role": "user", "content": "how is the weather in chicago for next 5 days?"},
|
||||
{
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"id": "call_5306",
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "weather_forecast",
|
||||
"arguments": {"city": "Chicago", "days": 5},
|
||||
},
|
||||
}
|
||||
],
|
||||
"model": "gpt-4o-2024-08-06",
|
||||
"content": '{"response": "Based on the forecast data you provided, here is the weather for the next 5 days in Chicago:\\n\\n- **April 14, 2025**: The temperature will range between 53\\u00b0F and 65\\u00b0F. \\n- **April 15, 2025**: The temperature will range between 85\\u00b0F and 97\\u00b0F.\\n- **April 16, 2025**: The temperature will range between 62\\u00b0F and 78\\u00b0F.\\n- **April 17, 2025**: The temperature will range between 89\\u00b0F and 101\\u00b0F.\\n- **April 18, 2025**: The temperature will range between 86\\u00b0F and 104\\u00b0F.\\n\\nPlease note that the temperatures are given in Fahrenheit."}',
|
||||
},
|
||||
{"role": "tool", "content": "--", "tool_call_id": "call_5306"},
|
||||
{"role": "user", "content": "what about seattle?"},
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -44,7 +29,8 @@ def test_update_fc_history():
|
|||
for h in test_input_history:
|
||||
message_history.append(Message(**h))
|
||||
|
||||
updated_history = handler_map["Arch-Function"]._process_messages(message_history)
|
||||
assert len(updated_history) == 7
|
||||
handler: ArchFunctionHandler = handler_map["Arch-Function"]
|
||||
updated_history = handler._process_messages(message_history)
|
||||
assert len(updated_history) == 5
|
||||
# ensure that tool role does not exist anymore
|
||||
assert all([h["role"] != "tool" for h in updated_history])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue