diff --git a/model_server/app/commons/constants.py b/model_server/app/commons/constants.py index 4dcd24e2..1b9035e0 100644 --- a/model_server/app/commons/constants.py +++ b/model_server/app/commons/constants.py @@ -24,11 +24,7 @@ Provide your tool assessment for ONLY THE LAST USER MESSAGE in the above convers ARCH_INTENT_GENERATION_CONFIG = { - "generation_params": { - "stop_token_ids": [151645], - "max_tokens": 1, - "guided_choice": ["Yes", "No"], - } + "generation_params": {"max_tokens": 1, "stop_token_ids": [151645]} } diff --git a/model_server/app/model_handler/base_handler.py b/model_server/app/model_handler/base_handler.py index cfc7a0f8..f6b811da 100644 --- a/model_server/app/model_handler/base_handler.py +++ b/model_server/app/model_handler/base_handler.py @@ -28,8 +28,8 @@ class ChatCompletionResponse(BaseModel): id: Optional[int] = 0 object: Optional[str] = "chat_completion" created: Optional[str] = "" - model: str choices: List[Choice] + model: str class ArchBaseHandler: @@ -124,7 +124,7 @@ class ArchBaseHandler: if tools: processed_messages.append( - {"role": "system", "content": self._format_system(tools)} + {"role": "system", "content": self._format_system_prompt(tools)} ) for message in messages: diff --git a/model_server/app/model_handler/function_calling.py b/model_server/app/model_handler/function_calling.py index a3bb2dc3..897c785a 100644 --- a/model_server/app/model_handler/function_calling.py +++ b/model_server/app/model_handler/function_calling.py @@ -108,7 +108,9 @@ class ArchIntentHandler(ArchBaseHandler): extra_body=self.generation_params, ) - model_response = Message(content=model_response, tool_calls=[]) + model_response = Message( + content=model_response.choices[0].message.content, tool_calls=[] + ) chat_completion_response = ChatCompletionResponse( choices=[Choice(message=model_response)], model=self.model_name diff --git a/model_server/app/tests/test_function_calling.py b/model_server/app/tests/test_function_calling.py index 9f15507e..49b1ea07 100644 --- a/model_server/app/tests/test_function_calling.py +++ b/model_server/app/tests/test_function_calling.py @@ -86,7 +86,7 @@ async def test_chat_completion(mock_hanlder, mock_client): mock_client.chat.completions.create.return_value = mock_response # Mock the tool formatter - mock_hanlder._format_system.return_value = "" + mock_hanlder._format_system_prompt.return_value = "" response = Response() chat_response = await chat_completion(request, response)