diff --git a/tests/e2e/test_model_alias_routing.py b/tests/e2e/test_model_alias_routing.py index c285bda8..5b1d3719 100644 --- a/tests/e2e/test_model_alias_routing.py +++ b/tests/e2e/test_model_alias_routing.py @@ -634,3 +634,150 @@ def test_anthropic_client_streaming_with_bedrock(): # Verify final message structure assert final_message is not None assert final_message.content and len(final_message.content) > 0 + + +def test_openai_client_streaming_with_bedrock(): + """Test OpenAI client using 'coding-model' alias (maps to Bedrock) with streaming""" + logger.info( + "Testing OpenAI client with 'coding-model' alias -> Bedrock (streaming)" + ) + + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI( + api_key="test-key", + base_url=f"{base_url}/v1", + ) + + stream = client.chat.completions.create( + model="coding-model", # This should resolve to us.amazon.nova-premier-v1:0 + max_tokens=500, + messages=[ + { + "role": "user", + "content": "Write a short 4-line sonnet about coding.", + } + ], + stream=True, + ) + + content_chunks = [] + for chunk in stream: + if chunk.choices and len(chunk.choices) > 0: + delta = chunk.choices[0].delta + if delta.content: + content_chunks.append(delta.content) + + full_content = "".join(content_chunks) + logger.info(f"Streaming response from coding-model: {full_content}") + + # Should get a text response + assert len(full_content) > 0, "Expected text response from streaming" + + +def test_openai_client_streaming_with_bedrock_and_tools(): + """Test OpenAI client using 'coding-model' alias (maps to Bedrock) with streaming and tools""" + logger.info( + "Testing OpenAI client with 'coding-model' alias -> Bedrock with tools (streaming)" + ) + + base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "") + client = openai.OpenAI( + api_key="test-key", + base_url=f"{base_url}/v1", + ) + + stream = client.chat.completions.create( + model="coding-model", # This should resolve to us.amazon.nova-premier-v1:0 + max_tokens=1000, + messages=[ + { + "role": "user", + "content": "I need to write a Python function that calculates the factorial of a number. Can you help me write and run it?. You should use the tool to run the code.", + } + ], + tools=[ + { + "type": "function", + "function": { + "name": "run_python_code", + "description": "Execute Python code and return the result", + "parameters": { + "type": "object", + "properties": { + "code": { + "type": "string", + "description": "Python code to execute", + } + }, + "required": ["code"], + }, + }, + } + ], + tool_choice="auto", + stream=True, + ) + + content_chunks = [] + tool_calls = [] + chunk_count = 0 + + for chunk in stream: + chunk_count += 1 + if chunk.choices and len(chunk.choices) > 0: + delta = chunk.choices[0].delta + + # Log what we see in each chunk + has_content = delta.content is not None + has_tool_calls = delta.tool_calls is not None + + if ( + chunk_count % 50 == 0 or has_tool_calls + ): # Log every 50th chunk or any chunk with tool calls + logger.info( + f"Chunk {chunk_count}: content={has_content}, tool_calls={has_tool_calls}" + ) + if has_tool_calls: + logger.info(f" Tool calls in chunk: {delta.tool_calls}") + + # Collect text content + if delta.content: + content_chunks.append(delta.content) + + # Collect tool calls + if delta.tool_calls: + for tool_call in delta.tool_calls: + # Extend or create tool call entries + while len(tool_calls) <= tool_call.index: + tool_calls.append( + { + "id": "", + "type": "function", + "function": {"name": "", "arguments": ""}, + } + ) + + if tool_call.id: + tool_calls[tool_call.index]["id"] = tool_call.id + if tool_call.function: + if tool_call.function.name: + tool_calls[tool_call.index]["function"][ + "name" + ] = tool_call.function.name + if tool_call.function.arguments: + tool_calls[tool_call.index]["function"][ + "arguments" + ] += tool_call.function.arguments + + full_content = "".join(content_chunks) + logger.info(f"Streaming response from coding-model with tools: {full_content}") + logger.info(f"Tool calls collected: {len(tool_calls)}") + + if tool_calls: + for i, tc in enumerate(tool_calls): + logger.info(f" Tool call {i}: {tc['function']['name']}") + + # Should get either text response or tool calls for coding assistance + assert ( + full_content or len(tool_calls) > 0 + ), f"Expected text or tool calls. Got text_len={len(full_content)}, tools={len(tool_calls)}"