integrate hallucination

This commit is contained in:
cotran 2024-12-06 15:50:03 -08:00
parent f7d69d52a7
commit 5e164e8e3c
3 changed files with 594 additions and 25 deletions

View file

@ -12,6 +12,7 @@ from app.model_handler.base_handler import (
ChatCompletionResponse,
ArchBaseHandler,
)
from app.function_calling.hallucination_handler import HallucinationStateHandler
SUPPORT_DATA_TYPES = ["int", "float", "bool", "str", "list", "tuple", "set", "dict"]
@ -342,6 +343,36 @@ class ArchFunctionHandler(ArchBaseHandler):
return is_valid, error_tool_call, error_message
def _prefill_response(self, messages: List[Dict[str, str]]):
"""
Prefills the response with the tool call prefix.
Args:
messages (List[Dict[str, str]]): A list of messages.
tools (List[Dict[str, Any]]): A list of tools.
Returns:
List[Dict[str, str]]: A list of messages with the prefill prefix.
"""
messages.append(
{
"role": "assistant",
"content": random.choice(self.prefill_prefix),
}
)
prefill_response = self.client.chat.completions.create(
messages=messages,
model=self.model_name,
stream=False,
extra_body={
**self.generation_params,
**self.prefill_params,
},
)
return prefill_response
@override
async def chat_completion(
self, req: ChatMessage, enable_prefilling=True
@ -392,23 +423,7 @@ class ArchFunctionHandler(ArchBaseHandler):
# start parameter gathering if the model is not generating a tool call
if has_tool_call is False:
messages.append(
{
"role": "assistant",
"content": random.choice(self.prefill_prefix),
}
)
prefill_response = self.client.chat.completions.create(
messages=messages,
model=self.model_name,
stream=False,
extra_body={
**self.generation_params,
**self.prefill_params,
},
)
prefill_response = self._prefill_response(messages)
model_response = prefill_response.choices[0].message.content
else:
model_response = response.choices[0].message.content