add more tests

This commit is contained in:
Adil Hafeez 2024-10-27 15:02:43 -07:00
parent f5ecf733ff
commit e35560623c
7 changed files with 171 additions and 49 deletions

View file

@ -12,6 +12,7 @@ LLM_GATEWAY_ENDPOINT = os.getenv(
def get_data_chunks(stream, n=1):
chunks = []
for chunk in stream.iter_lines():
print(chunk)
if chunk:
chunk = chunk.decode("utf-8")
chunk_data_id = chunk[0:6]

View file

@ -43,7 +43,7 @@ def test_prompt_gateway_param_gathering(stream):
response = requests.post(PROMPT_GATEWAY_ENDPOINT, json=body, stream=stream)
assert response.status_code == 200
if stream:
chunks = get_data_chunks(response)
chunks = get_data_chunks(response, n=3)
assert len(chunks) > 0
response_json = json.loads(chunks[0])
# if its streaming we return tool call and api call in first two chunks
@ -112,3 +112,33 @@ def test_prompt_gateway_param_tool_call(stream):
else:
response_json = response.json()
assert response_json.get("model").startswith("gpt-4o-mini")
@pytest.mark.parametrize("stream", [True, False])
def test_prompt_gateway_default_target(stream):
body = {
"messages": [
{
"role": "user",
"content": "hello, what can you do for me?",
},
],
"stream": stream,
}
response = requests.post(PROMPT_GATEWAY_ENDPOINT, json=body, stream=stream)
assert response.status_code == 200
if stream:
chunks = get_data_chunks(response, n=3)
assert len(chunks) > 0
response_json = json.loads(chunks[0])
assert response_json.get("model").startswith("api_server")
response_json = json.loads(chunks[1])
choices = response_json.get("choices", [])
assert len(choices) > 0
content = choices[0]["delta"]["content"]
assert (
content == "I can help you with weather forecast or insurance claim details"
)
else:
response_json = response.json()
assert response_json.get("model").startswith("api_server")