This commit is contained in:
cotran 2024-11-06 16:31:22 -08:00
parent 68c2243e83
commit dd07ba2cd0

View file

@ -6,6 +6,19 @@ from deepdiff import DeepDiff
from common import PROMPT_GATEWAY_ENDPOINT, get_arch_messages, get_data_chunks
def get_prefill_list():
prefill_list = [
"May",
"Could",
"Sure",
"Definitely",
"Certainly",
"Of course",
"Can",
]
return prefill_list
@pytest.mark.parametrize("stream", [True, False])
def test_prompt_gateway(stream):
expected_tool_call = {
@ -101,13 +114,22 @@ def test_prompt_gateway_arch_direct_response(stream):
assert len(choices) > 0
tool_calls = choices[0].get("delta", {}).get("tool_calls", [])
assert len(tool_calls) == 0
response_json = json.loads(chunks[1])
choices = response_json.get("choices", [])
assert len(choices) > 0
message = choices[0]["delta"]["content"]
else:
response_json = response.json()
assert response_json.get("model").startswith("Arch")
choices = response_json.get("choices", [])
assert len(choices) > 0
message = choices[0]["message"]["content"]
assert "Could you provide the following details days" not in message
prefill_list = get_prefill_list()
assert any(
message.startswith(word) for word in prefill_list
), f"Expected assistant message to start with one of {prefill_list}, but got '{assistant_message}'"
@pytest.mark.parametrize("stream", [True, False])
@ -260,55 +282,3 @@ def test_prompt_gateway_default_target(stream):
response_json.get("choices")[0]["message"]["content"]
== "I can help you with weather forecast or insurance claim details"
)
@pytest.mark.parametrize("stream", [True, False])
def test_prompt_gateway_arch_prefill(stream):
body = {
"messages": [
{
"role": "user",
"content": "how is the weather",
}
],
"stream": stream,
}
response = requests.post(PROMPT_GATEWAY_ENDPOINT, json=body)
assert response.status_code == 200
if stream:
chunks = get_data_chunks(response, n=3)
assert len(chunks) > 0
response_json = json.loads(chunks[0])
# make sure arch responded directly
assert response_json.get("model").startswith("Arch")
# and tool call is null
choices = response_json.get("choices", [])
assert len(choices) > 0
tool_calls = choices[0].get("delta", {}).get("tool_calls", [])
assert len(tool_calls) == 0
response_json = json.loads(chunks[1])
choices = response_json.get("choices", [])
assert len(choices) > 0
message = choices[0]["delta"]["content"]
else:
response_json = response.json()
assert response_json.get("model").startswith("Arch")
choices = response_json.get("choices", [])
assert len(choices) > 0
message = choices[0]["message"]["content"]
assert "Could you provide the following details days" not in message
prefill_list = [
"May",
"Could",
"Sure",
"Definitely",
"Certainly",
"Of course",
"Can",
]
assert any(
message.startswith(word) for word in prefill_list
), f"Expected assistant message to start with one of {prefill_list}, but got '{assistant_message}'"