Address review comments

This commit is contained in:
cotran 2024-11-07 11:15:03 -08:00
parent dd07ba2cd0
commit 1f383eafc4
4 changed files with 27 additions and 23 deletions

View file

@ -10,6 +10,16 @@ LLM_GATEWAY_ENDPOINT = os.getenv(
)
ARCH_STATE_HEADER = "x-arch-state"
# Opening words/phrases that an assistant reply is expected to start with.
PREFILL_LIST = ["May", "Could", "Sure", "Definitely", "Certainly", "Of course", "Can"]
def get_data_chunks(stream, n=1):
chunks = []

View file

@ -3,20 +3,12 @@ import pytest
import requests
from deepdiff import DeepDiff
from common import PROMPT_GATEWAY_ENDPOINT, get_arch_messages, get_data_chunks
def get_prefill_list():
    """Return the opening words an assistant reply is expected to start with.

    A new list is built on every call, so callers may mutate the result
    without affecting later calls.
    """
    return [
        "May",
        "Could",
        "Sure",
        "Definitely",
        "Certainly",
        "Of course",
        "Can",
    ]
from common import (
PROMPT_GATEWAY_ENDPOINT,
PREFILL_LIST,
get_arch_messages,
get_data_chunks,
)
@pytest.mark.parametrize("stream", [True, False])
@ -126,10 +118,9 @@ def test_prompt_gateway_arch_direct_response(stream):
message = choices[0]["message"]["content"]
assert "Could you provide the following details days" not in message
prefill_list = get_prefill_list()
assert any(
message.startswith(word) for word in prefill_list
), f"Expected assistant message to start with one of {prefill_list}, but got '{assistant_message}'"
message.startswith(word) for word in PREFILL_LIST
), f"Expected assistant message to start with one of {PREFILL_LIST}, but got '{assistant_message}'"
@pytest.mark.parametrize("stream", [True, False])

View file

@ -8,8 +8,9 @@ from app.prompt_guard.model_handler import ArchGuardHanlder
# Module-level logger obtained from the project's utils helper.
logger = utils.get_model_server_logger()
arch_function_hanlder = ArchFunctionHandler()
# NOTE(review): the lowercase and upper-case names below hold identical values;
# this looks like a before/after pair from a rename — confirm which spelling
# the rest of the code base still reads.
prefill_list = ["May", "Could", "Sure", "Definitely", "Certainly", "Of course", "Can"]
prefill_enabled = True
# Candidate opening words: one is chosen at random and appended as an assistant
# message when the first streamed token is not a tool call.
PREFILL_LIST = ["May", "Could", "Sure", "Definitely", "Certainly", "Of course", "Can"]
# Gates the prefill path; the same flag is passed as `stream=` on the
# completion request so the first token can be inspected.
PREFILL_ENABLED = True
# First-token content that indicates the model is emitting a tool call.
TOOL_CALL_TOKEN = "<tool_call>"
# Base URL handed to utils.get_client to build the shared client below.
arch_function_endpoint = "https://api.fc.archgw.com/v1"
arch_function_client = utils.get_client(arch_function_endpoint)
arch_function_generation_params = {

View file

@ -87,14 +87,14 @@ async def chat_completion(req: ChatMessage, res: Response):
resp = const.arch_function_client.chat.completions.create(
messages=messages,
model=client_model_name,
stream=const.prefill_enabled,
stream=const.PREFILL_ENABLED,
extra_body=const.arch_function_generation_params,
)
except Exception as e:
logger.error(f"model_server <= arch_function: error: {e}")
raise
if const.prefill_enabled:
if const.PREFILL_ENABLED:
first_token_content = ""
for token in resp:
first_token_content = token.choices[
@ -104,14 +104,16 @@ async def chat_completion(req: ChatMessage, res: Response):
break
# Check if the first token requires tool call handling
if first_token_content != "<tool_call>":
if first_token_content != const.TOOL_CALL_TOKEN:
# Engage pre-filling response if no tool call is indicated
resp.close()
logger.info("Tool call is not found! Engage pre filling")
prefill_content = random.choice(const.prefill_list)
prefill_content = random.choice(const.PREFILL_LIST)
messages.append({"role": "assistant", "content": prefill_content})
# Send a new completion request with the updated messages.
# The model will continue the final message in the chat instead of starting a new one.
# Disable add_generation_prompt, which tells the template to add tokens that indicate the start of a bot response.
extra_body = {
**const.arch_function_generation_params,
"continue_final_message": True,