diff --git a/e2e_tests/test_prompt_gateway.py b/e2e_tests/test_prompt_gateway.py index d614a089..e1ff8a9c 100644 --- a/e2e_tests/test_prompt_gateway.py +++ b/e2e_tests/test_prompt_gateway.py @@ -6,6 +6,19 @@ from deepdiff import DeepDiff from common import PROMPT_GATEWAY_ENDPOINT, get_arch_messages, get_data_chunks +def get_prefill_list(): + prefill_list = [ + "May", + "Could", + "Sure", + "Definitely", + "Certainly", + "Of course", + "Can", + ] + return prefill_list + + @pytest.mark.parametrize("stream", [True, False]) def test_prompt_gateway(stream): expected_tool_call = { @@ -101,13 +114,22 @@ def test_prompt_gateway_arch_direct_response(stream): assert len(choices) > 0 tool_calls = choices[0].get("delta", {}).get("tool_calls", []) assert len(tool_calls) == 0 + response_json = json.loads(chunks[1]) + choices = response_json.get("choices", []) + assert len(choices) > 0 + message = choices[0]["delta"]["content"] else: response_json = response.json() assert response_json.get("model").startswith("Arch") choices = response_json.get("choices", []) assert len(choices) > 0 message = choices[0]["message"]["content"] + assert "Could you provide the following details days" not in message + prefill_list = get_prefill_list() + assert any( + message.startswith(word) for word in prefill_list + ), f"Expected assistant message to start with one of {prefill_list}, but got '{assistant_message}'" @pytest.mark.parametrize("stream", [True, False]) @@ -260,55 +282,3 @@ def test_prompt_gateway_default_target(stream): response_json.get("choices")[0]["message"]["content"] == "I can help you with weather forecast or insurance claim details" ) - - -@pytest.mark.parametrize("stream", [True, False]) -def test_prompt_gateway_arch_prefill(stream): - body = { - "messages": [ - { - "role": "user", - "content": "how is the weather", - } - ], - "stream": stream, - } - response = requests.post(PROMPT_GATEWAY_ENDPOINT, json=body) - assert response.status_code == 200 - - if stream: - chunks = get_data_chunks(response, n=3) - assert len(chunks) > 0 - response_json = json.loads(chunks[0]) - # make sure arch responded directly - assert response_json.get("model").startswith("Arch") - # and tool call is null - choices = response_json.get("choices", []) - assert len(choices) > 0 - tool_calls = choices[0].get("delta", {}).get("tool_calls", []) - assert len(tool_calls) == 0 - response_json = json.loads(chunks[1]) - choices = response_json.get("choices", []) - assert len(choices) > 0 - message = choices[0]["delta"]["content"] - else: - response_json = response.json() - assert response_json.get("model").startswith("Arch") - choices = response_json.get("choices", []) - assert len(choices) > 0 - message = choices[0]["message"]["content"] - assert "Could you provide the following details days" not in message - - prefill_list = [ - "May", - "Could", - "Sure", - "Definitely", - "Certainly", - "Of course", - "Can", - ] - - assert any( - message.startswith(word) for word in prefill_list - ), f"Expected assistant message to start with one of {prefill_list}, but got '{assistant_message}'"