From 02efe6d0954d6330dc76ca9f0e6d0a4c799c4de3 Mon Sep 17 00:00:00 2001 From: Adil Hafeez Date: Mon, 14 Apr 2025 15:29:23 -0700 Subject: [PATCH] fix e2e tests --- .../weather_forecast/arch_config.yaml | 2 +- demos/samples_python/weather_forecast/main.py | 5 +---- tests/e2e/test_prompt_gateway.py | 13 ++++++++----- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/demos/samples_python/weather_forecast/arch_config.yaml b/demos/samples_python/weather_forecast/arch_config.yaml index db18eb85..b6463594 100644 --- a/demos/samples_python/weather_forecast/arch_config.yaml +++ b/demos/samples_python/weather_forecast/arch_config.yaml @@ -22,12 +22,12 @@ llm_providers: provider_interface: openai model: llama-3.2-3b-preview base_url: https://api.groq.com - default: true - name: gpt-4o access_key: $OPENAI_API_KEY provider_interface: openai model: gpt-4o + default: true system_prompt: | You are a helpful assistant. diff --git a/demos/samples_python/weather_forecast/main.py b/demos/samples_python/weather_forecast/main.py index 3be2f4da..84ee75e6 100644 --- a/demos/samples_python/weather_forecast/main.py +++ b/demos/samples_python/weather_forecast/main.py @@ -73,7 +73,7 @@ async def weather(req: WeatherRequest, res: Response): class DefaultTargetRequest(BaseModel): - messages: list + messages: list = [] @app.post("/default_target") @@ -86,12 +86,9 @@ async def default_target(req: DefaultTargetRequest, res: Response): "role": "assistant", "content": "I can help you with weather forecast", }, - "finish_reason": "completed", - "index": 0, } ], "model": "api_server", - "usage": {"completion_tokens": 0}, } logger.info(f"sending response: {json.dumps(resp)}") return resp diff --git a/tests/e2e/test_prompt_gateway.py b/tests/e2e/test_prompt_gateway.py index 87459f57..e6a10f3a 100644 --- a/tests/e2e/test_prompt_gateway.py +++ b/tests/e2e/test_prompt_gateway.py @@ -77,7 +77,7 @@ def test_prompt_gateway(stream): # third..end chunk is summarization (role = assistant) response_json = json.loads(chunks[2]) - assert response_json.get("model").startswith("llama-3.2-3b-preview") + assert response_json.get("model").startswith("gpt-4o") choices = response_json.get("choices", []) assert len(choices) > 0 assert "role" in choices[0]["delta"] @@ -86,7 +86,7 @@ def test_prompt_gateway(stream): else: response_json = response.json() - assert response_json.get("model").startswith("llama-3.2-3b-preview") + assert response_json.get("model").startswith("gpt-4o") choices = response_json.get("choices", []) assert len(choices) > 0 assert "role" in choices[0]["message"] @@ -252,7 +252,7 @@ def test_prompt_gateway_param_tool_call(stream): # third..end chunk is summarization (role = assistant) response_json = json.loads(chunks[2]) - assert response_json.get("model").startswith("llama-3.2-3b-preview") + assert response_json.get("model").startswith("gpt-4o") choices = response_json.get("choices", []) assert len(choices) > 0 assert "role" in choices[0]["delta"] @@ -261,7 +261,7 @@ def test_prompt_gateway_param_tool_call(stream): else: response_json = response.json() - assert response_json.get("model").startswith("llama-3.2-3b-preview") + assert response_json.get("model").startswith("gpt-4o") choices = response_json.get("choices", []) assert len(choices) > 0 assert "role" in choices[0]["message"] @@ -283,7 +283,7 @@ def test_prompt_gateway_default_target(stream): "messages": [ { "role": "user", - "content": "hello, what can you do for me?", + "content": "hello", }, ], "stream": stream, @@ -294,17 +294,20 @@ def test_prompt_gateway_default_target(stream): chunks = get_data_chunks(response, n=3) assert len(chunks) > 0 response_json = json.loads(chunks[0]) + print("response_json chunks[0]: ", response_json) assert response_json.get("model").startswith("api_server") assert len(response_json.get("choices", [])) > 0 assert response_json.get("choices")[0]["delta"]["role"] == "assistant" response_json = json.loads(chunks[1]) + print("response_json chunks[1]: ", response_json) choices = response_json.get("choices", []) assert len(choices) > 0 content = choices[0]["delta"]["content"] assert content == "I can help you with weather forecast" else: response_json = response.json() + print("response_json: ", response_json) assert response_json.get("model").startswith("api_server") assert len(response_json.get("choices")) > 0 assert response_json.get("choices")[0]["message"]["role"] == "assistant"