mirror of
https://github.com/katanemo/plano.git
synced 2026-06-17 15:25:17 +02:00
add e2e test
This commit is contained in:
parent
0d9cbdebda
commit
e74a3e1e38
5 changed files with 61 additions and 4 deletions
15
e2e_tests/.vscode/launch.json
vendored
Normal file
15
e2e_tests/.vscode/launch.json
vendored
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"name": "Python Debugger: Current File",
|
||||
"type": "debugpy",
|
||||
"request": "launch",
|
||||
"program": "${file}",
|
||||
"console": "integratedTerminal"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
|
@ -30,7 +30,8 @@ def get_arch_messages(response_json):
|
|||
arch_messages = []
|
||||
if response_json and "metadata" in response_json:
|
||||
# load arch_state from metadata
|
||||
arch_state_str = response_json.get("metadata", {}).get(ARCH_STATE_HEADER, "{}")
|
||||
arch_state_str = response_json.get("metadata") or {}
|
||||
arch_state_str = arch_state_str.get(ARCH_STATE_HEADER, "{}")
|
||||
# parse arch_state into json object
|
||||
arch_state = json.loads(arch_state_str)
|
||||
# load messages from arch_state
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@ import json
|
|||
import pytest
|
||||
import requests
|
||||
from deepdiff import DeepDiff
|
||||
import model_server.app.commons.constants as const
|
||||
|
||||
from common import PROMPT_GATEWAY_ENDPOINT, get_arch_messages, get_data_chunks
|
||||
|
||||
|
|
@ -260,3 +261,39 @@ def test_prompt_gateway_default_target(stream):
|
|||
response_json.get("choices")[0]["message"]["content"]
|
||||
== "I can help you with weather forecast or insurance claim details"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("prefill_enabled", [True, False])
def test_prompt_gateway_arch_prefill(prefill_enabled):
    """Check the gateway's reply to an under-specified weather question.

    A single user message ("how is the weather") is posted to the prompt
    gateway with the ``prefill_enabled`` flag carried in the JSON body.

    * With the flag on, the response is read through ``get_data_chunks`` and
      the first chunk must come from a model whose name starts with "Arch",
      carry no tool calls in its delta, and have delta content that is one of
      the entries in ``const.prefill_list``.
    * With the flag off, the response is parsed as a single JSON document
      whose model name starts with "Arch" and whose first choice's message
      must not contain the "Could you provide the following details days"
      phrase.
    """
    user_turn = {
        "role": "user",
        "content": "how is the weather",
    }
    payload = {
        "messages": [user_turn],
        "prefill_enabled": prefill_enabled,
    }

    response = requests.post(PROMPT_GATEWAY_ENDPOINT, json=payload)
    assert response.status_code == 200

    if not prefill_enabled:
        # Non-prefill path: the whole body is one JSON document.
        reply = response.json()
        assert reply.get("model").startswith("Arch")
        choices = reply.get("choices", [])
        assert len(choices) > 0
        message = choices[0]["message"]["content"]
        assert "Could you provide the following details days" not in message
        return

    # Prefill path: only the first data chunk is inspected.
    chunks = get_data_chunks(response, n=3)
    assert len(chunks) > 0
    first_chunk = json.loads(chunks[0])

    # make sure arch responded directly
    assert first_chunk.get("model").startswith("Arch")

    # and tool call is null
    choices = first_chunk.get("choices", [])
    assert len(choices) > 0
    delta = choices[0].get("delta", {})
    tool_calls = delta.get("tool_calls", [])
    assert len(tool_calls) == 0

    assistant_message = delta.get("content", "")
    assert assistant_message in const.prefill_list
|
||||
|
|
|
|||
|
|
@ -133,6 +133,7 @@ async def chat_completion(
|
|||
if hasattr(token.choices[0].delta, "content"):
|
||||
full_response += token.choices[0].delta.content
|
||||
else:
|
||||
logger.info("Stream is disabled, not engaging pre-filling")
|
||||
full_response = resp.choices[0].message.content
|
||||
|
||||
tool_calls = const.arch_function_hanlder.extract_tool_calls(full_response)
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ import app.prompt_guard.model_utils as guard_utils
|
|||
|
||||
from typing import List, Dict
|
||||
from pydantic import BaseModel
|
||||
from fastapi import FastAPI, Response, HTTPException
|
||||
from fastapi import FastAPI, Response, HTTPException, Request
|
||||
from app.function_calling.model_utils import ChatMessage
|
||||
|
||||
from app.commons.constants import embedding_model, zero_shot_model, arch_guard_handler
|
||||
|
|
@ -214,9 +214,12 @@ async def hallucination(req: HallucinationRequest, res: Response):
|
|||
|
||||
|
||||
@app.post("/v1/chat/completions")
|
||||
async def chat_completion(req: ChatMessage, res: Response):
|
||||
async def chat_completion(req: ChatMessage, res: Response, request: Request):
|
||||
try:
|
||||
result = await arch_function_chat_completion(req, res)
|
||||
prefill_enabled = (
|
||||
request.query_params.get("prefill_enabled", "true").lower() == "true"
|
||||
)
|
||||
result = await arch_function_chat_completion(req, res, prefill_enabled)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error in chat_completion: {e}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue