diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 733048b67..7fdc6ece0 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -10,7 +10,14 @@ import time from typing import NamedTuple, Union import httpx -from openai import APIConnectionError, AsyncOpenAI, AsyncStream, OpenAI +from openai import ( + APIConnectionError, + AsyncAzureOpenAI, + AsyncOpenAI, + AsyncStream, + AzureOpenAI, + OpenAI, +) from openai.types import CompletionUsage from openai.types.chat import ChatCompletion, ChatCompletionChunk from tenacity import ( @@ -26,7 +33,6 @@ from metagpt.logs import logger from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE from metagpt.schema import Message -from metagpt.utils.common import ensure_trailing_slash from metagpt.utils.singleton import Singleton from metagpt.utils.token_counter import ( TOKEN_COSTS, @@ -154,40 +160,49 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): RateLimiter.__init__(self, rpm=self.rpm) def __init_openai(self, config: Config): - client_kwargs, async_client_kwargs = self._make_client_kwargs(config) - - self.client = OpenAI(**client_kwargs) - self.async_client = AsyncOpenAI(**async_client_kwargs) - + self._make_client(config) self.rpm = int(config.get("RPM", 10)) - def _make_client_kwargs(self, config: Config) -> (dict, dict): - mapping = { - "api_key": "openai_api_key", - "base_url": "openai_base_url", - } - kwargs = {} - for key, attr in mapping.items(): - value = getattr(config, attr, None) - if value: - kwargs[key] = value + def _make_client(self, config: Config): + kwargs, async_kwargs = self._make_client_kwargs(config) - # OpenAI v1 requires the base_url to end with / - if config.openai_base_url: - kwargs["base_url"] = ensure_trailing_slash(config.openai_base_url) + if self._is_azure(config): + self.client = AzureOpenAI(**kwargs) + self.async_client = AsyncAzureOpenAI(**async_kwargs) + else: + self.client = OpenAI(**kwargs) + self.async_client = AsyncOpenAI(**async_kwargs) + + def _make_client_kwargs(self, config: Config) -> (dict, dict): + if self._is_azure(config): + kwargs = dict( + api_key=config.openai_api_key, + api_version=config.openai_api_version, + azure_endpoint=config.openai_base_url, + ) + else: + kwargs = dict(api_key=config.openai_api_key, base_url=config.openai_base_url) async_kwargs = kwargs.copy() - # Create http_client if proxy is specified + # to use proxy, openai v1 needs http_client + proxy_params = self._get_proxy_params(config) + if proxy_params: + kwargs["http_client"] = httpx.Client(**proxy_params) + async_kwargs["http_client"] = httpx.AsyncClient(**proxy_params) + + return kwargs, async_kwargs + + def _is_azure(self, config: Config) -> bool: + return config.openai_api_type == "azure" + + def _get_proxy_params(self, config: Config) -> dict: + params = {} if config.openai_proxy: params = {"proxies": config.openai_proxy} if config.openai_base_url: params["base_url"] = config.openai_base_url - - kwargs["http_client"] = httpx.Client(**params) - async_kwargs["http_client"] = httpx.AsyncClient(**params) - - return kwargs, async_kwargs + return params async def _achat_completion_stream(self, messages: list[dict]) -> str: response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create( @@ -230,9 +245,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): elif not CONFIG.deployment_name and not CONFIG.deployment_id: raise ValueError("You must specify `DEPLOYMENT_NAME` or `DEPLOYMENT_ID` parameter") kwargs_mode = ( - {"engine": CONFIG.deployment_name} - if CONFIG.deployment_name - else {"deployment_id": CONFIG.deployment_id} + {"model": CONFIG.deployment_name} if CONFIG.deployment_name else {"deployment_id": CONFIG.deployment_id} ) else: kwargs_mode = {"model": self.model} diff --git a/metagpt/utils/common.py b/metagpt/utils/common.py index c69a0fe10..f09666beb 100644 --- a/metagpt/utils/common.py +++ b/metagpt/utils/common.py @@ -305,9 +305,3 @@ def parse_recipient(text): pattern = r"## Send To:\s*([A-Za-z]+)\s*?" # hard code for now recipient = re.search(pattern, text) return recipient.group(1) if recipient else "" - - -def ensure_trailing_slash(url): - if not url: - return url - return url if url.endswith("/") else url + "/" diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py index 3e8dbf7e7..8d853f11c 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -1,5 +1,6 @@ +from unittest.mock import Mock + import pytest -from httpx import AsyncClient, Client from metagpt.provider.openai_api import OpenAIGPTAPI from metagpt.schema import UserMessage @@ -81,41 +82,64 @@ def test_ask_code_list_str(): assert len(rsp["code"]) > 0 -def test_make_client_kwargs(): - class Config: - openai_api_key = "test_key" - openai_base_url = "test_url" - openai_proxy = "http://test_proxy" +class TestOpenAI: + @pytest.fixture + def config(self): + return Mock(openai_api_key="test_key", openai_base_url="test_url", openai_proxy=None, openai_api_type="other") - config = Config() - obj = OpenAIGPTAPI() - kwargs, async_kwargs = obj._make_client_kwargs(config) + @pytest.fixture + def config_azure(self): + return Mock( + openai_api_key="test_key", + openai_api_version="test_version", + openai_base_url="test_url", + openai_proxy=None, + openai_api_type="azure", + ) - assert kwargs["api_key"] == "test_key" - assert kwargs["base_url"] == "test_url/" - assert isinstance(kwargs["http_client"], Client) - assert kwargs["http_client"].base_url == "test_url/" + @pytest.fixture + def config_proxy(self): + return Mock( + openai_api_key="test_key", + openai_base_url="test_url", + openai_proxy="http://proxy.com", + openai_api_type="other", + ) - assert async_kwargs["api_key"] == "test_key" - assert async_kwargs["base_url"] == "test_url/" - assert isinstance(async_kwargs["http_client"], AsyncClient) - assert async_kwargs["http_client"].base_url == "test_url/" + @pytest.fixture + def config_azure_proxy(self): + return Mock( + openai_api_key="test_key", + openai_api_version="test_version", + openai_base_url="test_url", + openai_proxy="http://proxy.com", + openai_api_type="azure", + ) + def test_make_client_kwargs_without_proxy(self, config): + instance = OpenAIGPTAPI() + kwargs, async_kwargs = instance._make_client_kwargs(config) + assert kwargs == {"api_key": "test_key", "base_url": "test_url"} + assert async_kwargs == {"api_key": "test_key", "base_url": "test_url"} + assert "http_client" not in kwargs + assert "http_client" not in async_kwargs -def test_make_client_kwargs_no_proxy(): - class Config: - openai_api_key = "test_key" - openai_base_url = "test_url" - openai_proxy = None + def test_make_client_kwargs_without_proxy_azure(self, config_azure): + instance = OpenAIGPTAPI() + kwargs, async_kwargs = instance._make_client_kwargs(config_azure) + assert kwargs == {"api_key": "test_key", "api_version": "test_version", "azure_endpoint": "test_url"} + assert async_kwargs == {"api_key": "test_key", "api_version": "test_version", "azure_endpoint": "test_url"} + assert "http_client" not in kwargs + assert "http_client" not in async_kwargs - config = Config() - obj = OpenAIGPTAPI() - kwargs, async_kwargs = obj._make_client_kwargs(config) + def test_make_client_kwargs_with_proxy(self, config_proxy): + instance = OpenAIGPTAPI() + kwargs, async_kwargs = instance._make_client_kwargs(config_proxy) + assert "http_client" in kwargs + assert "http_client" in async_kwargs - assert kwargs["api_key"] == "test_key" - assert kwargs["base_url"] == "test_url/" - assert "http_client" not in kwargs - - assert async_kwargs["api_key"] == "test_key" - assert async_kwargs["base_url"] == "test_url/" - assert "http_client" not in async_kwargs + def test_make_client_kwargs_with_proxy_azure(self, config_azure_proxy): + instance = OpenAIGPTAPI() + kwargs, async_kwargs = instance._make_client_kwargs(config_azure_proxy) + assert "http_client" in kwargs + assert "http_client" in async_kwargs