add prefill and test (#236)

* add prefill and test

* fix stream

* fix

* feedback

* address comments

* update

* add e2e test

* fix e2e test

* update fix

* fix

* address cmt

* address cmt
This commit is contained in:
CTran 2024-11-07 11:59:29 -08:00 committed by GitHub
parent f48489f7c0
commit fb67788be0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 216 additions and 19 deletions

View file

@ -8,6 +8,9 @@ from app.prompt_guard.model_handler import ArchGuardHanlder
logger = utils.get_model_server_logger()
arch_function_hanlder = ArchFunctionHandler()
# Candidate openers for the assistant reply; one is picked at random and
# appended as a partial assistant message when the model's first streamed
# token is not a tool call.
PREFILL_LIST = ["May", "Could", "Sure", "Definitely", "Certainly", "Of course", "Can"]
# Passed as the `stream` flag of the first completion request; when True the
# first token is inspected and the pre-fill flow can be engaged.
PREFILL_ENABLED = True
# First non-empty streamed token that indicates the model is emitting a tool call.
TOOL_CALL_TOKEN = "<tool_call>"
arch_function_endpoint = "https://api.fc.archgw.com/v1"
arch_function_client = utils.get_client(arch_function_endpoint)
arch_function_generation_params = {

View file

@ -1,21 +1,21 @@
import json
import hashlib
import app.commons.constants as const
import random
from fastapi import Response
from pydantic import BaseModel
from app.commons.utilities import get_model_server_logger
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
logger = get_model_server_logger()
# A single chat message (role, text content, tool calls, tool call id).
# NOTE(review): each field is listed twice in this view (the original
# declaration followed by its Optional[...] replacement), which looks like a
# diff rendering. If both declarations really coexist in the class body, the
# later annotation and default override the earlier one, so every field is
# optional with an empty default. Confirm against the real file.
class Message(BaseModel):
role: str
content: str = ""
tool_calls: List[Dict[str, Any]] = []
tool_call_id: str = ""
role: Optional[str] = ""
content: Optional[str] = ""
tool_calls: Optional[List[Dict[str, Any]]] = []
tool_call_id: Optional[str] = ""
class ChatMessage(BaseModel):
@ -23,6 +23,20 @@ class ChatMessage(BaseModel):
tools: List[Dict[str, Any]]
# One entry of a chat-completion response's `choices` list.
class Choice(BaseModel):
    # The message carried by this choice (required).
    message: Message
    # Defaults to "stop" when the caller does not set it.
    finish_reason: Optional[str] = "stop"
    # Defaults to 0 when the caller does not set it.
    index: Optional[int] = 0
# Response envelope returned by the function-calling chat completion handler.
# Only `choices` is required; the remaining fields fall back to the defaults below.
# NOTE(review): the OpenAI chat-completion schema reports `object` as
# "chat.completion" and `created` as an integer timestamp -- confirm that the
# values used here are intentional.
class ChatCompletionResponse(BaseModel):
    # The choices making up the response (required).
    choices: List[Choice]
    # Model name reported to the caller.
    model: Optional[str] = "Arch-Function"
    # Left as an empty string unless the caller provides a value.
    created: Optional[str] = ""
    # Left as an empty string unless the caller provides a value.
    id: Optional[str] = ""
    # Object-type tag for the payload.
    object: Optional[str] = "chat_completion"
def process_messages(history: list[Message]):
updated_history = []
for hist in history:
@ -67,30 +81,77 @@ async def chat_completion(req: ChatMessage, res: Response):
f"model_server => arch_function: {client_model_name}, messages: {json.dumps(messages)}"
)
# Retrieve the first token, handling the Stream object carefully
try:
resp = const.arch_function_client.chat.completions.create(
messages=messages,
model=client_model_name,
stream=False,
stream=const.PREFILL_ENABLED,
extra_body=const.arch_function_generation_params,
)
except Exception as e:
logger.error(f"model_server <= arch_function: error: {e}")
raise
tool_calls = const.arch_function_hanlder.extract_tool_calls(
resp.choices[0].message.content
)
if const.PREFILL_ENABLED:
first_token_content = ""
for token in resp:
first_token_content = token.choices[
0
].delta.content.strip() # Clean up the content
if first_token_content: # Break if it's non-empty
break
# Check if the first token requires tool call handling
if first_token_content != const.TOOL_CALL_TOKEN:
# Engage pre-filling response if no tool call is indicated
resp.close()
logger.info("Tool call is not found! Engage pre filling")
prefill_content = random.choice(const.PREFILL_LIST)
messages.append({"role": "assistant", "content": prefill_content})
# Send a new completion request with the updated messages
# the model will continue the final message in the chat instead of starting a new one
# disable add_generation_prompt which tells the template to add tokens that indicate the start of a bot response.
extra_body = {
**const.arch_function_generation_params,
"continue_final_message": True,
"add_generation_prompt": False,
}
pre_fill_resp = const.arch_function_client.chat.completions.create(
messages=messages,
model=client_model_name,
stream=False,
extra_body=extra_body,
)
full_response = pre_fill_resp.choices[0].message.content
else:
# Initialize full response and iterate over tokens to gather the full response
full_response = first_token_content
for token in resp:
if hasattr(token.choices[0].delta, "content"):
full_response += token.choices[0].delta.content
else:
logger.info("Stream is disabled, not engaging pre-filling")
full_response = resp.choices[0].message.content
tool_calls = const.arch_function_hanlder.extract_tool_calls(full_response)
if tool_calls:
resp.choices[0].message.tool_calls = tool_calls
resp.choices[0].message.content = None
message = Message(content="", tool_calls=tool_calls)
else:
message = Message(content=full_response, tool_calls=[])
choice = Choice(message=message)
chat_completion_response = ChatCompletionResponse(
choices=[choice], model=client_model_name
)
logger.info(
f"model_server <= arch_function: (tools): {json.dumps([tool_call['function'] for tool_call in tool_calls])}"
)
logger.info(
f"model_server <= arch_function: response body: {json.dumps(resp.to_dict())}"
f"model_server <= arch_function: response body: {json.dumps(chat_completion_response.dict())}"
)
return resp
return chat_completion_response

View file

@ -6,7 +6,7 @@ import app.prompt_guard.model_utils as guard_utils
from typing import List, Dict
from pydantic import BaseModel
from fastapi import FastAPI, Response, HTTPException
from fastapi import FastAPI, Response, HTTPException, Request
from app.function_calling.model_utils import ChatMessage
from app.commons.constants import embedding_model, zero_shot_model, arch_guard_handler
@ -214,6 +214,11 @@ async def hallucination(req: HallucinationRequest, res: Response):
@app.post("/v1/chat/completions")
async def chat_completion(req: ChatMessage, res: Response):
result = await arch_function_chat_completion(req, res)
return result
async def chat_completion(req: ChatMessage, res: Response, request: Request):
    """Handle a chat-completion request by delegating to the Arch-Function handler.

    Args:
        req: Parsed request body (messages and tools).
        res: Response object; its status code is set to 500 on failure.
        request: Raw incoming request (not used by this handler).

    Returns:
        The result of ``arch_function_chat_completion`` on success, otherwise
        the dict ``{"error": "Internal server error"}`` with a 500 status.
    """
    try:
        return await arch_function_chat_completion(req, res)
    except Exception as e:
        # Top-level boundary: the client only gets a generic message, so keep
        # the full traceback in the server log (logger.exception) instead of
        # just the exception's string form.
        logger.exception(f"Error in chat_completion: {e}")
        res.status_code = 500
        return {"error": "Internal server error"}

View file

@ -0,0 +1,90 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
import app.commons.constants as const
from fastapi import Response
from app.function_calling.model_utils import (
process_messages,
chat_completion,
Message,
ChatMessage,
Choice,
ChatCompletionResponse,
)
def sample_messages():
    """Build the three-message conversation used by the tests.

    The conversation is a user greeting, an assistant turn carrying one tool
    call, and a tool turn carrying the tool's reply. Every field is set
    explicitly, using empty values where there is nothing to say.
    """
    user_turn = Message(role="user", content="Hello!", tool_calls=[], tool_call_id="")
    assistant_turn = Message(
        role="assistant",
        content="",
        tool_calls=[{"function": {"name": "sample_tool"}}],
        tool_call_id="sample_id",
    )
    tool_turn = Message(
        role="tool", content="Response from tool", tool_calls=[], tool_call_id=""
    )
    return [user_turn, assistant_turn, tool_turn]
def sample_request(sample_messages):
    """Wrap *sample_messages* in a ChatMessage request that offers one sample tool."""
    tool_spec = {"name": "sample_tool", "description": "A sample tool"}
    return ChatMessage(messages=sample_messages, tools=[tool_spec])
@patch("app.commons.constants.arch_function_hanlder")
def test_process_messages(mock_hanlder):
    """process_messages should turn the sample conversation into three role/content dicts."""
    expected = [
        {"role": "user", "content": "Hello!"},
        {
            "role": "assistant",
            "content": '<tool_call>\n{"name": "sample_tool"}\n</tool_call>',
        },
        {
            "role": "user",
            "content": "<tool_response>\nResponse from tool\n</tool_response>",
        },
    ]

    processed = process_messages(sample_messages())

    assert len(processed) == 3
    # Compare entry by entry so a failure points at the offending message.
    for position, wanted in enumerate(expected):
        assert processed[position] == wanted
@patch("app.commons.constants.arch_function_client")
@patch("app.commons.constants.arch_function_hanlder")
@pytest.mark.asyncio
async def test_chat_completion(mock_hanlder, mock_client):
    """Exercise the pre-fill path of ``chat_completion``.

    The first (streamed) completion starts with a token that is not the tool
    call token, so the handler is expected to close the stream, append a
    pre-fill opener as a partial assistant message, and issue a second,
    non-streamed completion whose content becomes the final answer.
    """
    # Mock the model list return for client
    mock_client.models.list.return_value = MagicMock(
        data=[MagicMock(id="sample_model")]
    )
    request = sample_request(sample_messages())

    # chat_completion consumes the stream with a plain ``for`` loop, so the
    # mock has to be a synchronous iterable. Configuring ``__aiter__`` on an
    # AsyncMock would leave the synchronous iteration empty and the tokens
    # below would never be seen. An iterator (rather than a list) mimics a
    # real stream, which cannot be restarted.
    mock_stream = MagicMock()
    mock_stream.__iter__.return_value = iter(
        [
            MagicMock(choices=[MagicMock(delta=MagicMock(content="Hi there!"))]),
            MagicMock(choices=[MagicMock(delta=MagicMock(content=""))]),  # end of stream
        ]
    )

    # Separate, concrete response for the pre-fill request so the final
    # content can be asserted exactly instead of being an arbitrary mock.
    prefill_text = " I help you with that?"
    mock_prefill = MagicMock()
    mock_prefill.choices = [MagicMock(message=MagicMock(content=prefill_text))]
    mock_client.chat.completions.create.side_effect = [mock_stream, mock_prefill]

    # Mock the tool formatter and make tool-call extraction report "no tool calls"
    mock_hanlder._format_system.return_value = "<formatted_tools>"
    mock_hanlder.extract_tool_calls.return_value = []

    chat_response = await chat_completion(request, Response())

    assert isinstance(chat_response, ChatCompletionResponse)
    assert chat_response.choices[0].message.content == prefill_text
    assert chat_response.choices[0].message.tool_calls == []

    create_calls = mock_client.chat.completions.create.call_args_list
    assert len(create_calls) == 2

    first_call_args = create_calls[0][1]
    assert first_call_args["stream"] is True
    assert "model" in first_call_args
    assert first_call_args["messages"][0]["content"] == "<formatted_tools>"

    # Check that the arguments for the second call to 'create' include the pre-fill completion
    second_call_args = create_calls[1][1]
    assert second_call_args["stream"] is False
    assert "model" in second_call_args
    assert second_call_args["messages"][-1]["content"] in const.PREFILL_LIST