Merge branch 'shuguang/main' of https://github.com/katanemo/arch into cotran/intent

This commit is contained in:
cotran 2024-12-06 14:34:55 -08:00
commit f7d69d52a7
5 changed files with 37 additions and 21 deletions

View file

@ -69,6 +69,8 @@ async def function_calling(req: ChatMessage, res: Response):
# logger.error(f"Error in chat_completion from `Arch-Function`: {e}")
res.status_code = 500
return {"error": f"[Arch-Function] - {e}"}
# [TODO] Review: define the behavior if `Arch-Intent` doesn't detect an intent
# else:
except Exception as e:
# [TODO] Review: update how to collect debugging outputs

View file

@ -106,7 +106,7 @@ class ArchBaseHandler:
self,
messages: List[Message],
tools: List[Dict[str, Any]] = None,
extra_instructions: str = None,
extra_instruction: str = None,
):
"""
Processes a list of messages and formats them appropriately.
@ -114,7 +114,7 @@ class ArchBaseHandler:
Args:
messages (List[Message]): A list of message objects.
tools (List[Dict[str, Any]], optional): A list of tools to include in the system prompt.
extra_instructions (str, optional): Additional instructions to append to the last user message.
extra_instruction (str, optional): Additional instructions to append to the last user message.
Returns:
List[Dict[str, Any]]: A list of processed message dictionaries.
@ -148,8 +148,8 @@ class ArchBaseHandler:
assert processed_messages[-1]["role"] == "user"
if extra_instructions:
processed_messages[-1]["content"] += extra_instructions
if extra_instruction:
processed_messages[-1]["content"] += extra_instruction
return processed_messages

View file

@ -99,20 +99,24 @@ class ArchIntentHandler(ArchBaseHandler):
Currently only supports vLLM inference
"""
messages = self._process_messages(
req.messages, req.tools, self.extra_instruction
)
# In the case that no tools are available, simply return `No` to avoid making a call
if len(req.tools) == 0:
model_response = Message(content="No", tool_calls=[])
else:
messages = self._process_messages(
req.messages, req.tools, self.extra_instruction
)
model_response = self.client.chat.completions.create(
messages=messages,
model=self.model_name,
stream=False,
extra_body=self.generation_params,
)
model_response = self.client.chat.completions.create(
messages=messages,
model=self.model_name,
stream=False,
extra_body=self.generation_params,
)
model_response = Message(
content=model_response.choices[0].message.content, tool_calls=[]
)
model_response = Message(
content=model_response.choices[0].message.content, tool_calls=[]
)
chat_completion_response = ChatCompletionResponse(
choices=[Choice(message=model_response)], model=self.model_name

View file

@ -105,7 +105,7 @@ class ArchGuardHanlder:
sentence = None
return GuardResponse(
prob=prob.item(), verdict=verdict, sentence=sentence, latency=latency
prob=[prob.item()], verdict=verdict, sentence=[sentence], latency=latency
)
def predict(self, req: GuardRequest, max_num_words=300) -> GuardResponse:
@ -138,9 +138,9 @@ class ArchGuardHanlder:
chunk_result = self._predict_text(req.task, chunk)
if chunk_result.verdict:
prob.append(chunk_result.prob)
prob.append(chunk_result.prob[0])
verdict = True
sentence.append(chunk_result.sentence)
sentence.append(chunk_result.sentence[0])
latency += chunk_result.latency
return GuardResponse(