add prefill and test

This commit is contained in:
cotran 2024-10-30 17:00:30 -07:00
parent bb9a774a72
commit 5919f8b9b9
3 changed files with 154 additions and 17 deletions

View file

@ -17,23 +17,23 @@ overrides:
llm_providers:
- name: gpt-4o-mini
access_key: $OPENAI_API_KEY
access_key: OPENAI_API_KEY
provider: openai
model: gpt-4o-mini
default: true
- name: gpt-3.5-turbo-0125
access_key: $OPENAI_API_KEY
access_key: OPENAI_API_KEY
provider: openai
model: gpt-3.5-turbo-0125
- name: gpt-4o
access_key: $OPENAI_API_KEY
access_key: OPENAI_API_KEY
provider: openai
model: gpt-4o
- name: ministral-3b
access_key: $MISTRAL_API_KEY
access_key: MISTRAL_API_KEY
provider: mistral
model: ministral-3b-latest

View file

@ -5,17 +5,17 @@ import app.commons.constants as const
from fastapi import Response
from pydantic import BaseModel
from app.commons.utilities import get_model_server_logger
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
logger = get_model_server_logger()
class Message(BaseModel):
role: str
content: str = ""
tool_calls: List[Dict[str, Any]] = []
tool_call_id: str = ""
role: Optional[str] = ""
content: Optional[str] = ""
tool_calls: Optional[List[Dict[str, Any]]] = []
tool_call_id: Optional[str] = ""
class ChatMessage(BaseModel):
@ -23,6 +23,14 @@ class ChatMessage(BaseModel):
tools: List[Dict[str, Any]]
class Choice(BaseModel):
message: Message
class ChatCompletionResponse(BaseModel):
choices: List[Choice]
def process_messages(history: list[Message]):
updated_history = []
for hist in history:
@ -70,23 +78,63 @@ async def chat_completion(req: ChatMessage, res: Response):
resp = const.arch_function_client.chat.completions.create(
messages=messages,
model=client_model_name,
stream=False,
stream=True,
extra_body=const.arch_function_generation_params,
)
tool_calls = const.arch_function_hanlder.extract_tool_calls(
resp.choices[0].message.content
)
# Retrieve the first token, handling the Stream object carefully
first_token_content = ""
try:
while True:
first_token = next(resp) # Synchronously retrieve tokens
first_token_content = first_token.choices[
0
].delta.content.strip() # Clean up the content
if first_token_content: # Break if it's non-empty
break
except StopIteration:
print("No non-empty tokens found.")
return None
# Check if the first token requires tool call handling
if first_token_content != "<tool_call>":
# Engage pre-filling response if no tool call is indicated
logger.info("Tool call is not found! Engage pre filling")
messages.append({"role": "assistant", "content": "Sure!"})
# Send a new completion request with the updated messages
pre_fill_resp = const.arch_function_client.chat.completions.create(
messages=messages,
model=client_model_name,
stream=False,
extra_body=const.arch_function_generation_params,
)
full_response = pre_fill_resp.choices[0].message.content
else:
# Initialize full response and iterate over tokens to gather the full response
full_response = "<tool_call>"
try:
while True:
token = next(resp) # Retrieve each token synchronously
if hasattr(token.choices[0].delta, "content"):
full_response += token.choices[0].delta.content
except StopIteration:
pass # End of stream
tool_calls = const.arch_function_hanlder.extract_tool_calls(full_response)
if tool_calls:
resp.choices[0].message.tool_calls = tool_calls
resp.choices[0].message.content = None
message = Message(content="", tool_calls=tool_calls)
else:
message = Message(content=full_response, tool_calls=[])
choice = Choice(message=message)
chat_completion_response = ChatCompletionResponse(choices=[choice])
logger.info(
f"model_server <= arch_function: (tools): {json.dumps([tool_call['function'] for tool_call in tool_calls])}"
)
logger.info(
f"model_server <= arch_function: response body: {json.dumps(resp.to_dict())}"
f"model_server <= arch_function: response body: {json.dumps(chat_completion_response.dict())}"
)
return resp
return chat_completion_response

View file

@ -0,0 +1,89 @@
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi import Response
from app.function_calling.model_utils import (
process_messages,
chat_completion,
Message,
ChatMessage,
Choice,
ChatCompletionResponse,
)
def sample_messages():
# Ensure fields are explicitly set with valid data or empty values
return [
Message(role="user", content="Hello!", tool_calls=[], tool_call_id=""),
Message(
role="assistant",
content="",
tool_calls=[{"function": {"name": "sample_tool"}}],
tool_call_id="sample_id",
),
Message(
role="tool", content="Response from tool", tool_calls=[], tool_call_id=""
),
]
def sample_request(sample_messages):
return ChatMessage(
messages=sample_messages,
tools=[{"name": "sample_tool", "description": "A sample tool"}],
)
@patch("app.commons.constants.arch_function_hanlder")
def test_process_messages(mock_hanlder):
messages = sample_messages()
processed = process_messages(messages)
assert len(processed) == 3
assert processed[0] == {"role": "user", "content": "Hello!"}
assert processed[1] == {
"role": "assistant",
"content": '<tool_call>\n{"name": "sample_tool"}\n</tool_call>',
}
assert processed[2] == {
"role": "user",
"content": "<tool_response>\nResponse from tool\n</tool_response>",
}
@patch("app.commons.constants.arch_function_client")
@patch("app.commons.constants.arch_function_hanlder")
@pytest.mark.asyncio
async def test_chat_completion(mock_hanlder, mock_client):
# Mock the model list return for client
mock_client.models.list.return_value = MagicMock(
data=[MagicMock(id="sample_model")]
)
request = sample_request(sample_messages())
# Simulate stream response as list of tokens
mock_response = AsyncMock()
mock_response.__aiter__.return_value = [
MagicMock(choices=[MagicMock(delta=MagicMock(content="Hi there!"))]),
MagicMock(choices=[MagicMock(delta=MagicMock(content=""))]), # end of stream
]
mock_client.chat.completions.create.return_value = mock_response
# Mock the tool formatter
mock_hanlder._format_system.return_value = "<formatted_tools>"
response = Response()
chat_response = await chat_completion(request, response)
assert isinstance(chat_response, ChatCompletionResponse)
assert chat_response.choices[0].message.content is not None
first_call_args = mock_client.chat.completions.create.call_args_list[0][1]
assert first_call_args["stream"] == True
assert "model" in first_call_args
assert first_call_args["messages"][0]["content"] == "<formatted_tools>"
# Check that the arguments for the second call to 'create' include the pre-fill completion
second_call_args = mock_client.chat.completions.create.call_args_list[1][1]
assert second_call_args["stream"] == False
assert "model" in second_call_args
assert second_call_args["messages"][-1]["content"] == "Sure!"