diff --git a/e2e_tests/.vscode/launch.json b/e2e_tests/.vscode/launch.json
new file mode 100644
index 00000000..6a211d8e
--- /dev/null
+++ b/e2e_tests/.vscode/launch.json
@@ -0,0 +1,15 @@
+{
+    // Use IntelliSense to learn about possible attributes.
+    // Hover to view descriptions of existing attributes.
+    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Python Debugger: Current File",
+            "type": "debugpy",
+            "request": "launch",
+            "program": "${file}",
+            "console": "integratedTerminal"
+        }
+    ]
+}
diff --git a/e2e_tests/common.py b/e2e_tests/common.py
index 7ccee7c4..1edb6517 100644
--- a/e2e_tests/common.py
+++ b/e2e_tests/common.py
@@ -10,6 +10,16 @@ LLM_GATEWAY_ENDPOINT = os.getenv(
 )
 ARCH_STATE_HEADER = "x-arch-state"
 
+PREFILL_LIST = [
+    "May",
+    "Could",
+    "Sure",
+    "Definitely",
+    "Certainly",
+    "Of course",
+    "Can",
+]
+
 
 def get_data_chunks(stream, n=1):
     chunks = []
diff --git a/e2e_tests/test_prompt_gateway.py b/e2e_tests/test_prompt_gateway.py
index 31f305d4..4a3e7eb7 100644
--- a/e2e_tests/test_prompt_gateway.py
+++ b/e2e_tests/test_prompt_gateway.py
@@ -3,7 +3,12 @@ import pytest
 import requests
 from deepdiff import DeepDiff
 
-from common import PROMPT_GATEWAY_ENDPOINT, get_arch_messages, get_data_chunks
+from common import (
+    PROMPT_GATEWAY_ENDPOINT,
+    PREFILL_LIST,
+    get_arch_messages,
+    get_data_chunks,
+)
 
 
 @pytest.mark.parametrize("stream", [True, False])
@@ -101,13 +106,21 @@ def test_prompt_gateway_arch_direct_response(stream):
         assert len(choices) > 0
         tool_calls = choices[0].get("delta", {}).get("tool_calls", [])
         assert len(tool_calls) == 0
+        response_json = json.loads(chunks[1])
+        choices = response_json.get("choices", [])
+        assert len(choices) > 0
+        message = choices[0]["delta"]["content"]
     else:
         response_json = response.json()
         assert response_json.get("model").startswith("Arch")
         choices = response_json.get("choices", [])
         assert len(choices) > 0
         message = choices[0]["message"]["content"]
 
+    assert "Could you provide the following details days" not in message
+    assert any(
+        message.startswith(word) for word in PREFILL_LIST
+    ), f"Expected assistant message to start with one of {PREFILL_LIST}, but got '{message}'"
+
 
 @pytest.mark.parametrize("stream", [True, False])
diff --git a/model_server/app/commons/constants.py b/model_server/app/commons/constants.py
index 67970bf9..d4e01d12 100644
--- a/model_server/app/commons/constants.py
+++ b/model_server/app/commons/constants.py
@@ -8,6 +8,9 @@ from app.prompt_guard.model_handler import ArchGuardHanlder
 logger = utils.get_model_server_logger()
 
 arch_function_hanlder = ArchFunctionHandler()
+PREFILL_LIST = ["May", "Could", "Sure", "Definitely", "Certainly", "Of course", "Can"]
+PREFILL_ENABLED = True
+TOOL_CALL_TOKEN = "<tool_call>"
 arch_function_endpoint = "https://api.fc.archgw.com/v1"
 arch_function_client = utils.get_client(arch_function_endpoint)
 arch_function_generation_params = {
diff --git a/model_server/app/function_calling/model_utils.py b/model_server/app/function_calling/model_utils.py
index acab10c5..6e7b926c 100644
--- a/model_server/app/function_calling/model_utils.py
+++ b/model_server/app/function_calling/model_utils.py
@@ -1,21 +1,21 @@
 import json
 import hashlib
 import app.commons.constants as const
-
+import random
 from fastapi import Response
 from pydantic import BaseModel
 from app.commons.utilities import get_model_server_logger
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 
 logger = get_model_server_logger()
 
 
 class Message(BaseModel):
-    role: str
-    content: str = ""
-    tool_calls: List[Dict[str, Any]] = []
-    tool_call_id: str = ""
+    role: Optional[str] = ""
+    content: Optional[str] = ""
+    tool_calls: Optional[List[Dict[str, Any]]] = []
+    tool_call_id: Optional[str] = ""
 
 
 class ChatMessage(BaseModel):
@@ -23,6 +23,20 @@ class ChatMessage(BaseModel):
     tools: List[Dict[str, Any]]
 
 
+class Choice(BaseModel):
+    message: Message
+    finish_reason: Optional[str] = "stop"
+    index: Optional[int] = 0
+
+
+class ChatCompletionResponse(BaseModel):
+    choices: List[Choice]
+    model: Optional[str] = "Arch-Function"
+    created: Optional[str] = ""
+    id: Optional[str] = ""
+    object: Optional[str] = "chat_completion"
+
+
 def process_messages(history: list[Message]):
     updated_history = []
     for hist in history:
@@ -67,30 +81,77 @@ async def chat_completion(req: ChatMessage, res: Response):
         f"model_server => arch_function: {client_model_name}, messages: {json.dumps(messages)}"
     )
 
+    # Retrieve the first token, handling the Stream object carefully
+
     try:
         resp = const.arch_function_client.chat.completions.create(
            messages=messages,
            model=client_model_name,
-            stream=False,
+            stream=const.PREFILL_ENABLED,
            extra_body=const.arch_function_generation_params,
        )
     except Exception as e:
         logger.error(f"model_server <= arch_function: error: {e}")
         raise
 
-    tool_calls = const.arch_function_hanlder.extract_tool_calls(
-        resp.choices[0].message.content
-    )
+    if const.PREFILL_ENABLED:
+        first_token_content = ""
+        for token in resp:
+            first_token_content = token.choices[
+                0
+            ].delta.content.strip()  # Clean up the content
+            if first_token_content:  # Break if it's non-empty
+                break
+
+        # Check if the first token requires tool call handling
+        if first_token_content != const.TOOL_CALL_TOKEN:
+            # Engage pre-filling response if no tool call is indicated
+            resp.close()
+            logger.info("Tool call is not found! Engage pre filling")
+            prefill_content = random.choice(const.PREFILL_LIST)
+            messages.append({"role": "assistant", "content": prefill_content})
+
+            # Send a new completion request with the updated messages
+            # the model will continue the final message in the chat instead of starting a new one
+            # disable add_generation_prompt which tells the template to add tokens that indicate the start of a bot response.
+            extra_body = {
+                **const.arch_function_generation_params,
+                "continue_final_message": True,
+                "add_generation_prompt": False,
+            }
+            pre_fill_resp = const.arch_function_client.chat.completions.create(
+                messages=messages,
+                model=client_model_name,
+                stream=False,
+                extra_body=extra_body,
+            )
+            full_response = pre_fill_resp.choices[0].message.content
+        else:
+            # Initialize full response and iterate over tokens to gather the full response
+            full_response = first_token_content
+            for token in resp:
+                if hasattr(token.choices[0].delta, "content"):
+                    full_response += token.choices[0].delta.content
+    else:
+        logger.info("Stream is disabled, not engaging pre-filling")
+        full_response = resp.choices[0].message.content
+
+    tool_calls = const.arch_function_hanlder.extract_tool_calls(full_response)
 
     if tool_calls:
-        resp.choices[0].message.tool_calls = tool_calls
-        resp.choices[0].message.content = None
+        message = Message(content="", tool_calls=tool_calls)
+    else:
+        message = Message(content=full_response, tool_calls=[])
+    choice = Choice(message=message)
+    chat_completion_response = ChatCompletionResponse(
+        choices=[choice], model=client_model_name
+    )
 
     logger.info(
         f"model_server <= arch_function: (tools): {json.dumps([tool_call['function'] for tool_call in tool_calls])}"
     )
     logger.info(
-        f"model_server <= arch_function: response body: {json.dumps(resp.to_dict())}"
+        f"model_server <= arch_function: response body: {json.dumps(chat_completion_response.dict())}"
     )
 
-    return resp
+    return chat_completion_response
diff --git a/model_server/app/main.py b/model_server/app/main.py
index a8d312d7..801bc36d 100644
--- a/model_server/app/main.py
+++ b/model_server/app/main.py
@@ -6,7 +6,7 @@ import app.prompt_guard.model_utils as guard_utils
 
 from typing import List, Dict
 from pydantic import BaseModel
-from fastapi import FastAPI, Response, HTTPException
+from fastapi import FastAPI, Response, HTTPException, Request
 from app.function_calling.model_utils import ChatMessage
 from app.commons.constants import embedding_model, zero_shot_model, arch_guard_handler
 
@@ -214,6 +214,11 @@ async def hallucination(req: HallucinationRequest, res: Response):
 
 
 @app.post("/v1/chat/completions")
-async def chat_completion(req: ChatMessage, res: Response):
-    result = await arch_function_chat_completion(req, res)
-    return result
+async def chat_completion(req: ChatMessage, res: Response, request: Request):
+    try:
+        result = await arch_function_chat_completion(req, res)
+        return result
+    except Exception as e:
+        logger.error(f"Error in chat_completion: {e}")
+        res.status_code = 500
+        return {"error": "Internal server error"}
diff --git a/model_server/app/tests/test_function_calling.py b/model_server/app/tests/test_function_calling.py
new file mode 100644
index 00000000..251007d3
--- /dev/null
+++ b/model_server/app/tests/test_function_calling.py
@@ -0,0 +1,90 @@
+import pytest
+from unittest.mock import AsyncMock, MagicMock, patch
+import app.commons.constants as const
+from fastapi import Response
+from app.function_calling.model_utils import (
+    process_messages,
+    chat_completion,
+    Message,
+    ChatMessage,
+    Choice,
+    ChatCompletionResponse,
+)
+
+
+def sample_messages():
+    # Ensure fields are explicitly set with valid data or empty values
+    return [
+        Message(role="user", content="Hello!", tool_calls=[], tool_call_id=""),
+        Message(
+            role="assistant",
+            content="",
+            tool_calls=[{"function": {"name": "sample_tool"}}],
+            tool_call_id="sample_id",
+        ),
+        Message(
+            role="tool", content="Response from tool", tool_calls=[], tool_call_id=""
tool_call_id="" + ), + ] + + +def sample_request(sample_messages): + return ChatMessage( + messages=sample_messages, + tools=[{"name": "sample_tool", "description": "A sample tool"}], + ) + + +@patch("app.commons.constants.arch_function_hanlder") +def test_process_messages(mock_hanlder): + messages = sample_messages() + processed = process_messages(messages) + + assert len(processed) == 3 + assert processed[0] == {"role": "user", "content": "Hello!"} + assert processed[1] == { + "role": "assistant", + "content": '\n{"name": "sample_tool"}\n', + } + assert processed[2] == { + "role": "user", + "content": "\nResponse from tool\n", + } + + +@patch("app.commons.constants.arch_function_client") +@patch("app.commons.constants.arch_function_hanlder") +@pytest.mark.asyncio +async def test_chat_completion(mock_hanlder, mock_client): + # Mock the model list return for client + mock_client.models.list.return_value = MagicMock( + data=[MagicMock(id="sample_model")] + ) + request = sample_request(sample_messages()) + # Simulate stream response as list of tokens + mock_response = AsyncMock() + mock_response.__aiter__.return_value = [ + MagicMock(choices=[MagicMock(delta=MagicMock(content="Hi there!"))]), + MagicMock(choices=[MagicMock(delta=MagicMock(content=""))]), # end of stream + ] + mock_client.chat.completions.create.return_value = mock_response + + # Mock the tool formatter + mock_hanlder._format_system.return_value = "" + + response = Response() + chat_response = await chat_completion(request, response) + + assert isinstance(chat_response, ChatCompletionResponse) + assert chat_response.choices[0].message.content is not None + + first_call_args = mock_client.chat.completions.create.call_args_list[0][1] + assert first_call_args["stream"] == True + assert "model" in first_call_args + assert first_call_args["messages"][0]["content"] == "" + + # Check that the arguments for the second call to 'create' include the pre-fill completion + second_call_args = mock_client.chat.completions.create.call_args_list[1][1] + assert second_call_args["stream"] == False + assert "model" in second_call_args + assert second_call_args["messages"][-1]["content"] in const.PREFILL_LIST