mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
fix bug and test
This commit is contained in:
parent
cc0845bce4
commit
afe7cc9e9e
3 changed files with 36 additions and 79 deletions
|
|
@ -416,15 +416,14 @@ class ArchFunctionHandler(ArchBaseHandler):
|
|||
has_tool_calls, has_hallucination = None, False
|
||||
for _ in self.hallucination_state:
|
||||
# check if the first token is <tool_call>
|
||||
if len(self.hallucination_state.tokens) > 5 and has_tool_calls is None:
|
||||
content = "".join(self.hallucination_state.tokens)
|
||||
if "tool_calls" in content:
|
||||
logger.info(
|
||||
f"[Content]: {content}"
|
||||
)
|
||||
has_tool_calls = True
|
||||
else:
|
||||
has_tool_calls = False
|
||||
content = "".join(self.hallucination_state.tokens)
|
||||
if "tool_calls" in content:
|
||||
logger.info(
|
||||
f"[Content]: {content}"
|
||||
)
|
||||
has_tool_calls = True
|
||||
else:
|
||||
has_tool_calls = False
|
||||
|
||||
|
||||
# if the model is hallucinating, start parameter gathering
|
||||
|
|
|
|||
|
|
@ -36,8 +36,8 @@ class MaskToken(Enum):
|
|||
|
||||
|
||||
HALLUCINATION_THRESHOLD_DICT = {
|
||||
"entropy": 0.28,
|
||||
"varentropy": 1.4,
|
||||
"entropy": 0.35,
|
||||
"varentropy": 1.1,
|
||||
"probability": 0.8,
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -37,24 +37,7 @@ get_weather_api = {
|
|||
# get_data class return request, intent, hallucination, parameter_gathering
|
||||
|
||||
|
||||
def get_hallucination_data_complex():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(role="user", content="How is the weather in Seattle?")
|
||||
message2 = Message(
|
||||
role="assistant", content="Can you specify the unit you want the weather in?"
|
||||
)
|
||||
message3 = Message(role="user", content="In celcius please!")
|
||||
|
||||
# Create a list of tools
|
||||
tools = [get_weather_api]
|
||||
|
||||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1, message2, message3], tools=tools)
|
||||
|
||||
return req, True, True, True
|
||||
|
||||
|
||||
def get_hallucination_data_medium():
|
||||
def get_hallucination_data():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(role="user", content="How is the weather in?")
|
||||
|
||||
|
|
@ -65,26 +48,10 @@ def get_hallucination_data_medium():
|
|||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
# first token will not be tool call
|
||||
return req, True, True, True
|
||||
return req, False, True
|
||||
|
||||
|
||||
def get_complete_data_2():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(
|
||||
role="user",
|
||||
content="what is the weather forecast for seattle in the next 10 days?",
|
||||
)
|
||||
|
||||
# Create a list of tools
|
||||
tools = [get_weather_api]
|
||||
|
||||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
return req, True, False, False
|
||||
|
||||
|
||||
def get_complete_data():
|
||||
def get_success_tool_call_data():
|
||||
# Create instances of the Message class
|
||||
message1 = Message(role="user", content="How is the weather in Seattle in 7 days?")
|
||||
|
||||
|
|
@ -94,7 +61,7 @@ def get_complete_data():
|
|||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
return req, True, False, False
|
||||
return req, True, False
|
||||
|
||||
|
||||
def get_irrelevant_data():
|
||||
|
|
@ -107,7 +74,7 @@ def get_irrelevant_data():
|
|||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
return req, False, False, False
|
||||
return req, False, False
|
||||
|
||||
|
||||
def get_greeting_data():
|
||||
|
|
@ -120,38 +87,29 @@ def get_greeting_data():
|
|||
# Create an instance of the ChatMessage class
|
||||
req = ChatMessage(messages=[message1], tools=tools)
|
||||
|
||||
return req, False, False, False
|
||||
return req, False, False
|
||||
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# @pytest.mark.parametrize(
|
||||
# "get_data_func",
|
||||
# [
|
||||
# get_hallucination_data_complex,
|
||||
# get_complete_data,
|
||||
# get_irrelevant_data,
|
||||
# get_complete_data_2,
|
||||
# ],
|
||||
# )
|
||||
# async def test_function_calling(get_data_func):
|
||||
# req, intent, hallucination, parameter_gathering = get_data_func()
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize(
|
||||
"get_data_func",
|
||||
[
|
||||
get_hallucination_data,
|
||||
get_greeting_data,
|
||||
get_irrelevant_data,
|
||||
get_success_tool_call_data,
|
||||
],
|
||||
)
|
||||
async def test_function_calling(get_data_func):
|
||||
req, intent, hallucination, parameter_gathering = get_data_func()
|
||||
handler_name = "Arch-Function"
|
||||
use_agent_orchestrator = False
|
||||
model_handler: ArchFunctionHandler = handler_map[handler_name]
|
||||
|
||||
# intent_response = await handler_map["Arch-Intent"].chat_completion(req)
|
||||
start_time = time.perf_counter()
|
||||
final_response = await model_handler.chat_completion(req)
|
||||
latency = time.perf_counter() - start_time
|
||||
|
||||
# assert handler_map["Arch-Intent"].detect_intent(intent_response) == intent
|
||||
assert intent == (len(final_response.choices[0].message.tool_calls)>=1)
|
||||
|
||||
# if intent:
|
||||
# function_calling_response = await handler_map["Arch-Function"].chat_completion(
|
||||
# req
|
||||
# )
|
||||
# assert (
|
||||
# handler_map["Arch-Function"].hallucination_state.hallucination
|
||||
# == hallucination
|
||||
# )
|
||||
# response_txt = function_calling_response.choices[0].message.content
|
||||
|
||||
# if parameter_gathering:
|
||||
# prefill_prefix = handler_map["Arch-Function"].prefill_prefix
|
||||
# assert any(
|
||||
# response_txt.startswith(prefix) for prefix in prefill_prefix
|
||||
# ), f"Response '{response_txt}' does not start with any of the prefixes: {prefill_prefix}"
|
||||
assert hallucination == model_handler.hallucination_state.hallucination
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue