mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
update fix
This commit is contained in:
parent
d9c64738c7
commit
68c2243e83
6 changed files with 43 additions and 33 deletions
|
|
@ -30,8 +30,7 @@ def get_arch_messages(response_json):
|
|||
arch_messages = []
|
||||
if response_json and "metadata" in response_json:
|
||||
# load arch_state from metadata
|
||||
arch_state_str = response_json.get("metadata") or {}
|
||||
arch_state_str = arch_state_str.get(ARCH_STATE_HEADER, "{}")
|
||||
arch_state_str = response_json.get("metadata", {}).get(ARCH_STATE_HEADER, "{}")
|
||||
# parse arch_state into json object
|
||||
arch_state = json.loads(arch_state_str)
|
||||
# load messages from arch_state
|
||||
|
|
|
|||
|
|
@ -262,8 +262,8 @@ def test_prompt_gateway_default_target(stream):
|
|||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("prefill_enabled", [True, False])
|
||||
def test_prompt_gateway_arch_prefill(prefill_enabled):
|
||||
@pytest.mark.parametrize("stream", [True, False])
|
||||
def test_prompt_gateway_arch_prefill(stream):
|
||||
body = {
|
||||
"messages": [
|
||||
{
|
||||
|
|
@ -271,29 +271,44 @@ def test_prompt_gateway_arch_prefill(prefill_enabled):
|
|||
"content": "how is the weather",
|
||||
}
|
||||
],
|
||||
"prefill_enabled": prefill_enabled,
|
||||
"stream": stream,
|
||||
}
|
||||
response = requests.post(PROMPT_GATEWAY_ENDPOINT, json=body)
|
||||
assert response.status_code == 200
|
||||
response_json = response.json()
|
||||
assert response_json.get("model").startswith("Arch")
|
||||
choices = response_json.get("choices", [])
|
||||
assert len(choices) > 0
|
||||
if prefill_enabled:
|
||||
prefill_list = [
|
||||
"May",
|
||||
"Could",
|
||||
"Sure",
|
||||
"Definitely",
|
||||
"Certainly",
|
||||
"Of course",
|
||||
"Can",
|
||||
]
|
||||
assistant_message = choices[0]["message"]["content"]
|
||||
assert any(
|
||||
assistant_message.startswith(word) for word in prefill_list
|
||||
), f"Expected assistant message to start with one of {prefill_list}, but got '{assistant_message}'"
|
||||
|
||||
if stream:
|
||||
chunks = get_data_chunks(response, n=3)
|
||||
assert len(chunks) > 0
|
||||
response_json = json.loads(chunks[0])
|
||||
# make sure arch responded directly
|
||||
assert response_json.get("model").startswith("Arch")
|
||||
# and tool call is null
|
||||
choices = response_json.get("choices", [])
|
||||
assert len(choices) > 0
|
||||
tool_calls = choices[0].get("delta", {}).get("tool_calls", [])
|
||||
assert len(tool_calls) == 0
|
||||
response_json = json.loads(chunks[1])
|
||||
choices = response_json.get("choices", [])
|
||||
assert len(choices) > 0
|
||||
message = choices[0]["delta"]["content"]
|
||||
else:
|
||||
response_json = response.json()
|
||||
assert response_json.get("model").startswith("Arch")
|
||||
choices = response_json.get("choices", [])
|
||||
assert len(choices) > 0
|
||||
message = choices[0]["message"]["content"]
|
||||
assert "Could you provide the following details days" not in message
|
||||
|
||||
prefill_list = [
|
||||
"May",
|
||||
"Could",
|
||||
"Sure",
|
||||
"Definitely",
|
||||
"Certainly",
|
||||
"Of course",
|
||||
"Can",
|
||||
]
|
||||
|
||||
assert any(
|
||||
message.startswith(word) for word in prefill_list
|
||||
), f"Expected assistant message to start with one of {prefill_list}, but got '{assistant_message}'"
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ logger = utils.get_model_server_logger()
|
|||
|
||||
arch_function_hanlder = ArchFunctionHandler()
|
||||
prefill_list = ["May", "Could", "Sure", "Definitely", "Certainly", "Of course", "Can"]
|
||||
prefill_enabled = True
|
||||
arch_function_endpoint = "https://api.fc.archgw.com/v1"
|
||||
arch_function_client = utils.get_client(arch_function_endpoint)
|
||||
arch_function_generation_params = {
|
||||
|
|
|
|||
|
|
@ -64,9 +64,7 @@ def process_messages(history: list[Message]):
|
|||
return updated_history
|
||||
|
||||
|
||||
async def chat_completion(
|
||||
req: ChatMessage, res: Response, prefill_enabled: bool = True
|
||||
):
|
||||
async def chat_completion(req: ChatMessage, res: Response):
|
||||
logger.info("starting request")
|
||||
|
||||
tools_encoded = const.arch_function_hanlder._format_system(req.tools)
|
||||
|
|
@ -89,14 +87,14 @@ async def chat_completion(
|
|||
resp = const.arch_function_client.chat.completions.create(
|
||||
messages=messages,
|
||||
model=client_model_name,
|
||||
stream=prefill_enabled,
|
||||
stream=const.prefill_enabled,
|
||||
extra_body=const.arch_function_generation_params,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"model_server <= arch_function: error: {e}")
|
||||
raise
|
||||
|
||||
if prefill_enabled:
|
||||
if const.prefill_enabled:
|
||||
first_token_content = ""
|
||||
for token in resp:
|
||||
first_token_content = token.choices[
|
||||
|
|
|
|||
|
|
@ -216,10 +216,7 @@ async def hallucination(req: HallucinationRequest, res: Response):
|
|||
@app.post("/v1/chat/completions")
|
||||
async def chat_completion(req: ChatMessage, res: Response, request: Request):
|
||||
try:
|
||||
prefill_enabled = (
|
||||
request.query_params.get("prefill_enabled", "true").lower() == "true"
|
||||
)
|
||||
result = await arch_function_chat_completion(req, res, prefill_enabled)
|
||||
result = await arch_function_chat_completion(req, res)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error in chat_completion: {e}")
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ async def test_chat_completion(mock_hanlder, mock_client):
|
|||
mock_hanlder._format_system.return_value = "<formatted_tools>"
|
||||
|
||||
response = Response()
|
||||
chat_response = await chat_completion(request, response, prefill_enabled=True)
|
||||
chat_response = await chat_completion(request, response)
|
||||
|
||||
assert isinstance(chat_response, ChatCompletionResponse)
|
||||
assert chat_response.choices[0].message.content is not None
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue