add prefill and test (#236)

* add prefill and test

* fix stream

* fix

* feedback

* address comments

* update

* add e2e test

* fix e2e test

* update fix

* fix

* address cmt

* address cmt
This commit is contained in:
CTran 2024-11-07 11:59:29 -08:00 committed by GitHub
parent f48489f7c0
commit fb67788be0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 216 additions and 19 deletions

15
e2e_tests/.vscode/launch.json vendored Normal file
View file

@ -0,0 +1,15 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python Debugger: Current File",
"type": "debugpy",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal"
}
]
}

View file

@ -10,6 +10,16 @@ LLM_GATEWAY_ENDPOINT = os.getenv(
)
ARCH_STATE_HEADER = "x-arch-state"
PREFILL_LIST = [
"May",
"Could",
"Sure",
"Definitely",
"Certainly",
"Of course",
"Can",
]
def get_data_chunks(stream, n=1):
chunks = []

View file

@ -3,7 +3,12 @@ import pytest
import requests
from deepdiff import DeepDiff
from common import PROMPT_GATEWAY_ENDPOINT, get_arch_messages, get_data_chunks
from common import (
PROMPT_GATEWAY_ENDPOINT,
PREFILL_LIST,
get_arch_messages,
get_data_chunks,
)
@pytest.mark.parametrize("stream", [True, False])
@ -101,13 +106,21 @@ def test_prompt_gateway_arch_direct_response(stream):
assert len(choices) > 0
tool_calls = choices[0].get("delta", {}).get("tool_calls", [])
assert len(tool_calls) == 0
response_json = json.loads(chunks[1])
choices = response_json.get("choices", [])
assert len(choices) > 0
message = choices[0]["delta"]["content"]
else:
response_json = response.json()
assert response_json.get("model").startswith("Arch")
choices = response_json.get("choices", [])
assert len(choices) > 0
message = choices[0]["message"]["content"]
assert "Could you provide the following details days" not in message
assert any(
message.startswith(word) for word in PREFILL_LIST
), f"Expected assistant message to start with one of {PREFILL_LIST}, but got '{assistant_message}'"
@pytest.mark.parametrize("stream", [True, False])