diff --git a/metagpt/provider/azure_openai_api.py b/metagpt/provider/azure_openai_api.py
index ca0696830..6a267b7ee 100644
--- a/metagpt/provider/azure_openai_api.py
+++ b/metagpt/provider/azure_openai_api.py
@@ -10,12 +10,12 @@
 """
-from openai import AsyncAzureOpenAI, AzureOpenAI
-from openai._base_client import AsyncHttpxClientWrapper, SyncHttpxClientWrapper
+from openai import AsyncAzureOpenAI
+from openai._base_client import AsyncHttpxClientWrapper
 
-from metagpt.config import CONFIG, Config, LLMProviderEnum
+from metagpt.config import LLMProviderEnum
 from metagpt.provider.llm_provider_registry import register_provider
-from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter
+from metagpt.provider.openai_api import OpenAIGPTAPI
 
 
 @register_provider(LLMProviderEnum.AZURE_OPENAI)
@@ -24,46 +24,22 @@ class AzureOpenAIGPTAPI(OpenAIGPTAPI):
     Check https://platform.openai.com/examples for examples
     """
 
-    def __init__(self):
-        self.config: Config = CONFIG
-        self._init_openai()
-        self.auto_max_tokens = False
-        RateLimiter.__init__(self, rpm=self.rpm)
-
-    def _make_client(self):
-        kwargs, async_kwargs = self._make_client_kwargs()
+    def _init_client(self):
+        kwargs = self._make_client_kwargs()
         # https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/migration?tabs=python-new%2Cdalle-fix
-        self.client = AzureOpenAI(**kwargs)
-        self.async_client = AsyncAzureOpenAI(**async_kwargs)
+        self.aclient = AsyncAzureOpenAI(**kwargs)
         self.model = self.config.DEPLOYMENT_NAME  # Used in _calc_usage & _cons_kwargs
 
-    def _make_client_kwargs(self) -> (dict, dict):
+    def _make_client_kwargs(self) -> dict:
         kwargs = dict(
             api_key=self.config.OPENAI_API_KEY,
             api_version=self.config.OPENAI_API_VERSION,
             azure_endpoint=self.config.OPENAI_BASE_URL,
         )
-        async_kwargs = kwargs.copy()
 
         # to use proxy, openai v1 needs http_client
         proxy_params = self._get_proxy_params()
         if proxy_params:
-            kwargs["http_client"] = SyncHttpxClientWrapper(**proxy_params)
-            async_kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
-
-        return kwargs, async_kwargs
-
-    def _cons_kwargs(self, messages: list[dict], timeout=3, **configs) -> dict:
-        kwargs = {
-            "messages": messages,
-            "max_tokens": self.get_max_tokens(messages),
-            "n": 1,
-            "stop": None,
-            "temperature": 0.3,
-            "model": self.model,
-        }
-        if configs:
-            kwargs.update(configs)
-        kwargs["timeout"] = max(CONFIG.timeout, timeout)
+            kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
 
         return kwargs
diff --git a/metagpt/provider/base_gpt_api.py b/metagpt/provider/base_gpt_api.py
index c7417af90..90cf59fd4 100644
--- a/metagpt/provider/base_gpt_api.py
+++ b/metagpt/provider/base_gpt_api.py
@@ -112,7 +112,7 @@ class BaseGPTAPI(BaseChatbot):
     """
 
     @abstractmethod
-    async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
        """Asynchronous version of completion. Return str. Support stream-print"""
 
    def get_choice_text(self, rsp: dict) -> str:
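Note: the Azure provider drops its synchronous surface entirely. The `AzureOpenAI` and `SyncHttpxClientWrapper` imports go away, `__init__` and the `_cons_kwargs` override are now inherited from `OpenAIGPTAPI`, and `_make_client_kwargs` returns a single kwargs dict for one async client (stored as `aclient`, matching the base class). A minimal sketch of the resulting construction pattern, standalone rather than MetaGPT code; the endpoint, key, version, and deployment values are illustrative placeholders:

```python
# Sketch only: async-only Azure client construction, mirroring _init_client.
import asyncio

from openai import AsyncAzureOpenAI


async def main() -> None:
    aclient = AsyncAzureOpenAI(
        api_key="<azure-key>",                               # placeholder
        api_version="2023-07-01-preview",                    # assumed version string
        azure_endpoint="https://<resource>.openai.azure.com",  # placeholder
    )
    rsp = await aclient.chat.completions.create(
        model="<deployment-name>",  # Azure routes by deployment name, not model name
        messages=[{"role": "user", "content": "ping"}],
    )
    print(rsp.choices[0].message.content)


asyncio.run(main())
```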
diff --git a/metagpt/provider/fireworks_api.py b/metagpt/provider/fireworks_api.py
index 55b1b6c28..e42088213 100644
--- a/metagpt/provider/fireworks_api.py
+++ b/metagpt/provider/fireworks_api.py
@@ -83,7 +83,7 @@ class FireWorksGPTAPI(OpenAIGPTAPI):
     def __init_fireworks(self):
         self.is_azure = False
         self.rpm = int(self.config.get("RPM", 10))
-        self._make_client()
+        self._init_client()
         self.model = self.config.fireworks_api_model  # `self.model` should after `_make_client` to rewrite it
 
     def _make_client_kwargs(self) -> (dict, dict):
@@ -103,7 +103,7 @@ class FireWorksGPTAPI(OpenAIGPTAPI):
         return self._cost_manager.get_costs()
 
     async def _achat_completion_stream(self, messages: list[dict]) -> str:
-        response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create(
+        response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
             **self._cons_kwargs(messages), stream=True
         )
@@ -133,9 +133,7 @@ class FireWorksGPTAPI(OpenAIGPTAPI):
         retry=retry_if_exception_type(APIConnectionError),
         retry_error_callback=log_and_reraise,
     )
-    async def acompletion_text(
-        self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
-    ) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
         """when streaming, print each token in place."""
         if stream:
             return await self._achat_completion_stream(messages)
diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py
index eace329aa..ca2133cfa 100644
--- a/metagpt/provider/google_gemini_api.py
+++ b/metagpt/provider/google_gemini_api.py
@@ -136,9 +136,7 @@ class GeminiGPTAPI(BaseGPTAPI):
         retry=retry_if_exception_type(ConnectionError),
         retry_error_callback=log_and_reraise,
     )
-    async def acompletion_text(
-        self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
-    ) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
         """response in async with stream or non-stream mode"""
         if stream:
             return await self._achat_completion_stream(messages)
diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py
index 90a50a154..0d6d51e04 100644
--- a/metagpt/provider/ollama_api.py
+++ b/metagpt/provider/ollama_api.py
@@ -147,9 +147,7 @@ class OllamaGPTAPI(BaseGPTAPI):
         retry=retry_if_exception_type(ConnectionError),
         retry_error_callback=log_and_reraise,
     )
-    async def acompletion_text(
-        self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
-    ) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
         """response in async with stream or non-stream mode"""
         if stream:
             return await self._achat_completion_stream(messages)
diff --git a/metagpt/provider/open_llm_api.py b/metagpt/provider/open_llm_api.py
index dd1491780..21efb6677 100644
--- a/metagpt/provider/open_llm_api.py
+++ b/metagpt/provider/open_llm_api.py
@@ -46,7 +46,7 @@ class OpenLLMGPTAPI(OpenAIGPTAPI):
     def __init_openllm(self):
         self.is_azure = False
         self.rpm = int(self.config.get("RPM", 10))
-        self._make_client()
+        self._init_client()
         self.model = self.config.open_llm_api_model  # `self.model` should after `_make_client` to rewrite it
 
     def _make_client_kwargs(self) -> (dict, dict):
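Across the Fireworks, Gemini, Ollama, and OpenLLM providers the unused `generator` flag disappears from `acompletion_text`, leaving `stream` as the only mode switch; the tenacity retry wrappers around these methods are untouched. Callers end up with this shape (a usage sketch, where `llm` stands in for any provider instance):

```python
# Hedged usage sketch: the two remaining modes of acompletion_text.
async def demo(llm, messages: list[dict]) -> None:
    text = await llm.acompletion_text(messages, stream=False)      # one awaited reply
    streamed = await llm.acompletion_text(messages, stream=True)   # prints tokens in place
    assert isinstance(text, str) and isinstance(streamed, str)
```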
diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py
index b72eff0dc..ea58f690b 100644
--- a/metagpt/provider/openai_api.py
+++ b/metagpt/provider/openai_api.py
@@ -14,9 +14,8 @@
 import json
 import time
 from typing import AsyncIterator, Union
 
-import openai
-from openai import APIConnectionError, AsyncOpenAI, AsyncStream, OpenAI
-from openai._base_client import AsyncHttpxClientWrapper, SyncHttpxClientWrapper
+from openai import APIConnectionError, AsyncOpenAI, AsyncStream
+from openai._base_client import AsyncHttpxClientWrapper
 from openai.types import CompletionUsage
 from openai.types.chat import ChatCompletion, ChatCompletionChunk
 from tenacity import (
@@ -80,9 +79,7 @@ See FAQ 5.8
 
 @register_provider(LLMProviderEnum.OPENAI)
 class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
-    """
-    Check https://platform.openai.com/examples for examples
-    """
+    """Check https://platform.openai.com/examples for examples"""
 
     def __init__(self):
         self.config: Config = CONFIG
@@ -91,27 +88,23 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
         RateLimiter.__init__(self, rpm=self.rpm)
 
     def _init_openai(self):
-        self.rpm = int(self.config.RPM or 10)
-        self._make_client()
+        self.rpm = int(self.config.openai_api_rpm)
+        self._init_client()
 
-    def _make_client(self):
-        kwargs, async_kwargs = self._make_client_kwargs()
+    def _init_client(self):
+        kwargs = self._make_client_kwargs()
         # https://github.com/openai/openai-python#async-usage
-        self.client = OpenAI(**kwargs)
-        self.async_client = AsyncOpenAI(**async_kwargs)
+        self.aclient = AsyncOpenAI(**kwargs)
         self.model = self.config.OPENAI_API_MODEL  # Used in _calc_usage & _cons_kwargs
 
-    def _make_client_kwargs(self) -> (dict, dict):
-        kwargs = dict(api_key=self.config.OPENAI_API_KEY, base_url=self.config.OPENAI_BASE_URL)
-        async_kwargs = kwargs.copy()
+    def _make_client_kwargs(self) -> dict:
+        kwargs = {"api_key": self.config.OPENAI_API_KEY, "base_url": self.config.OPENAI_BASE_URL}
 
         # to use proxy, openai v1 needs http_client
-        proxy_params = self._get_proxy_params()
-        if proxy_params:
-            kwargs["http_client"] = SyncHttpxClientWrapper(**proxy_params)
-            async_kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
+        if proxy_params := self._get_proxy_params():
+            kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
 
-        return kwargs, async_kwargs
+        return kwargs
 
     def _get_proxy_params(self) -> dict:
         params = {}
@@ -123,7 +116,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
         return params
 
     async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> AsyncIterator[str]:
-        response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create(
+        response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create(
             **self._cons_kwargs(messages, timeout=timeout), stream=True
         )
@@ -148,18 +141,10 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
 
     async def _achat_completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
         kwargs = self._cons_kwargs(messages, timeout=timeout)
-        rsp: ChatCompletion = await self.async_client.chat.completions.create(**kwargs)
+        rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
         self._update_costs(rsp.usage)
         return rsp
 
-    def _chat_completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
-        rsp: ChatCompletion = self.client.chat.completions.create(**self._cons_kwargs(messages, timeout=timeout))
-        self._update_costs(rsp.usage)
-        return rsp
-
-    def completion(self, messages: list[dict], timeout=3) -> ChatCompletion:
-        return self._chat_completion(messages, timeout=timeout)
-
     async def acompletion(self, messages: list[dict], timeout=3) -> ChatCompletion:
         return await self._achat_completion(messages, timeout=timeout)
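The base provider now builds only `AsyncOpenAI`, renamed to `aclient`, and proxy support is applied with a walrus expression: `if proxy_params := self._get_proxy_params():` assigns and tests in one step. Roughly, the wiring looks like this outside MetaGPT (the proxy URL and key are placeholders; `AsyncHttpxClientWrapper` is the openai SDK's thin `httpx.AsyncClient` subclass, so it accepts httpx client kwargs):

```python
# Sketch of proxy wiring for the async-only client; all values are illustrative.
from openai import AsyncOpenAI
from openai._base_client import AsyncHttpxClientWrapper

kwargs = {"api_key": "sk-...", "base_url": "https://api.openai.com/v1"}
if proxy_params := {"proxies": "http://127.0.0.1:7890"}:  # e.g. taken from CONFIG
    kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params)
aclient = AsyncOpenAI(**kwargs)
```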
@@ -199,14 +184,9 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
 
         return self._cons_kwargs(messages=messages, timeout=timeout, **kwargs)
 
-    def _chat_completion_function(self, messages: list[dict], timeout=3, **kwargs) -> ChatCompletion:
-        rsp: ChatCompletion = self.client.chat.completions.create(**self._func_configs(messages, **kwargs))
-        self._update_costs(rsp.usage)
-        return rsp
-
     async def _achat_completion_function(self, messages: list[dict], timeout=3, **chat_configs) -> ChatCompletion:
         kwargs = self._func_configs(messages=messages, timeout=timeout, **chat_configs)
-        rsp: ChatCompletion = await self.async_client.chat.completions.create(**kwargs)
+        rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs)
         self._update_costs(rsp.usage)
         return rsp
 
@@ -226,56 +206,28 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
         )
         return messages
 
-    def ask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
-        """Use function of tools to ask a code.
-
-        Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create
-
-        Examples:
-
-        >>> llm = OpenAIGPTAPI()
-        >>> llm.ask_code("Write a python hello world code.")
-        {'language': 'python', 'code': "print('Hello, World!')"}
-        >>> msg = [{'role': 'user', 'content': "Write a python hello world code."}]
-        >>> llm.ask_code(msg)
-        {'language': 'python', 'code': "print('Hello, World!')"}
-        """
-        messages = self._process_message(messages)
-        rsp = self._chat_completion_function(messages, **kwargs)
-        return self.get_choice_function_arguments(rsp)
-
     async def aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
         """Use function of tools to ask a code.
-
-        Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create
+        Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create
 
         Examples:
-        >>> llm = OpenAIGPTAPI()
-        >>> rsp = await llm.ask_code("Write a python hello world code.")
-        >>> rsp
-        {'language': 'python', 'code': "print('Hello, World!')"}
         >>> msg = [{'role': 'user', 'content': "Write a python hello world code."}]
-        >>> rsp = await llm.aask_code(msg)  # -> {'language': 'python', 'code': "print('Hello, World!')"}
+        >>> rsp = await llm.aask_code(msg)
+        # -> {'language': 'python', 'code': "print('Hello, World!')"}
         """
         messages = self._process_message(messages)
-        try:
-            rsp = await self._achat_completion_function(messages, **kwargs)
-            return self.get_choice_function_arguments(rsp)
-        except openai.BadRequestError as e:
-            logger.error(f"API TYPE:{CONFIG.OPENAI_API_TYPE}, err:{e}")
-            raise e
+        rsp = await self._achat_completion_function(messages, **kwargs)
+        return self.get_choice_function_arguments(rsp)
 
+    @handle_exception
     def get_choice_function_arguments(self, rsp: ChatCompletion) -> dict:
         """Required to provide the first function arguments of choice.
 
         :return dict: return the first function arguments of choice, for example,
             {'language': 'python', 'code': "print('Hello, World!')"}
         """
-        try:
-            return json.loads(rsp.choices[0].message.tool_calls[0].function.arguments)
-        except json.JSONDecodeError:
-            return {}
+        return json.loads(rsp.choices[0].message.tool_calls[0].function.arguments)
 
     def get_choice_text(self, rsp: ChatCompletion) -> str:
         """Required to provide the first text of choice"""
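Inline `try/except` blocks give way to the `@handle_exception` decorator, so JSON-parse failures in `get_choice_function_arguments` and API errors around `aask_code` are handled in one place. For illustration only, a stand-in decorator with the behavior the call sites appear to rely on (log and swallow) might look like the sketch below; this is NOT MetaGPT's actual `handle_exception`, which lives in its utils package and also has to cope with coroutines:

```python
# Hypothetical stand-in decorator, shown only to make the decorated sites concrete.
import functools
import json


def handle_exception(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except Exception as exc:  # assumed behavior: log and return None, don't raise
            print(f"{func.__name__} failed: {exc!r}")
            return None
    return wrapper


@handle_exception
def parse_tool_arguments(raw: str) -> dict:
    return json.loads(raw)  # malformed JSON now yields None plus a log line
```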
@@ -320,12 +272,10 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
                 logger.info(f"Result of task {idx}: {result}")
         return results
 
+    @handle_exception
     def _update_costs(self, usage: CompletionUsage):
         if CONFIG.calc_usage and usage:
-            try:
-                CONFIG.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
-            except Exception as e:
-                logger.error(f"updating costs failed!, exp: {e}")
+            CONFIG.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model)
 
     def get_costs(self) -> Costs:
         return CONFIG.cost_manager.get_costs()
@@ -335,18 +285,6 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
             return CONFIG.max_tokens_rsp
         return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp)
 
-    def moderation(self, content: Union[str, list[str]]):
-        return self.client.moderations.create(input=content)
-
     @handle_exception
     async def amoderation(self, content: Union[str, list[str]]):
-        return await self.async_client.moderations.create(input=content)
-
-    async def close(self):
-        """Close connection"""
-        if self.client:
-            self.client.close()
-            self.client = None
-        if self.async_client:
-            await self.async_client.close()
-            self.async_client = None
+        return await self.aclient.moderations.create(input=content)
diff --git a/metagpt/provider/spark_api.py b/metagpt/provider/spark_api.py
index 70076bc86..4ec7be8cf 100644
--- a/metagpt/provider/spark_api.py
+++ b/metagpt/provider/spark_api.py
@@ -50,9 +50,7 @@ class SparkGPTAPI(BaseGPTAPI):
     def get_choice_text(self, rsp: dict) -> str:
         return rsp["payload"]["choices"]["text"][-1]["content"]
 
-    async def acompletion_text(
-        self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3
-    ) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str:
         # 不支持
         logger.error("该功能禁用。")
         w = GetMessageFromWeb(messages)
diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py
index 8d57cd444..533ce5719 100644
--- a/metagpt/provider/zhipuai_api.py
+++ b/metagpt/provider/zhipuai_api.py
@@ -131,7 +131,7 @@ class ZhiPuAIGPTAPI(BaseGPTAPI):
         retry=retry_if_exception_type(ConnectionError),
         retry_error_callback=log_and_reraise,
     )
-    async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
         """response in async with stream or non-stream mode"""
         if stream:
             return await self._achat_completion_stream(messages)
diff --git a/metagpt/tools/openai_text_to_image.py b/metagpt/tools/openai_text_to_image.py
index 71381d8f2..b76385b13 100644
--- a/metagpt/tools/openai_text_to_image.py
+++ b/metagpt/tools/openai_text_to_image.py
@@ -25,10 +25,6 @@ class OpenAIText2Image:
         self._llm = LLM()
         self._client = self._llm.async_client
 
-    def __del__(self):
-        if self._llm:
-            self._llm.close()
-
     async def text_2_image(self, text, size_type="1024x1024"):
         """Text to image
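Cost tracking and moderation follow the same pattern: `_update_costs` trades its local `try/except` for `@handle_exception`, the sync `moderation` and the explicit `close()` teardown are deleted, and `OpenAIText2Image` loses the `__del__` that called the now-removed `close()` (a call that was never awaited anyway). The surviving moderation path, as a standalone sketch against the openai SDK directly (the key is a placeholder):

```python
# Sketch: async moderation with explicit client shutdown; no sync twin remains.
import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    aclient = AsyncOpenAI(api_key="sk-...")  # placeholder key
    rsp = await aclient.moderations.create(input="text to screen")
    print(rsp.results[0].flagged)
    await aclient.close()  # AsyncOpenAI still exposes close() for callers that need it


asyncio.run(main())
```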
diff --git a/metagpt/tools/ut_writer.py b/metagpt/tools/ut_writer.py
index 64423dfb1..8f827986c 100644
--- a/metagpt/tools/ut_writer.py
+++ b/metagpt/tools/ut_writer.py
@@ -278,11 +278,11 @@ class UTGenerator:
             question += self.build_api_doc(node, path, method)
             self.ask_gpt_and_save(question, tag, summary)
 
-    def gpt_msgs_to_code(self, messages: list) -> str:
+    async def gpt_msgs_to_code(self, messages: list) -> str:
         """Choose based on different calling methods"""
         result = ""
         if self.chatgpt_method == "API":
-            result = GPTAPI().ask_code(msgs=messages)
+            result = await GPTAPI().aask_code(msgs=messages)
 
         return result
diff --git a/tests/metagpt/provider/test_base_gpt_api.py b/tests/metagpt/provider/test_base_gpt_api.py
index aaa7b64ff..8628608a9 100644
--- a/tests/metagpt/provider/test_base_gpt_api.py
+++ b/tests/metagpt/provider/test_base_gpt_api.py
@@ -34,7 +34,7 @@ class MockBaseGPTAPI(BaseGPTAPI):
     async def acompletion(self, messages: list[dict], timeout=3):
         return default_chat_resp
 
-    async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str:
+    async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
         return resp_content
 
     async def close(self):
@@ -87,14 +87,14 @@ def test_base_gpt_api():
     choice_text = base_gpt_api.get_choice_text(openai_funccall_resp)
     assert choice_text == openai_funccall_resp["choices"][0]["message"]["content"]
 
-    resp = base_gpt_api.ask(prompt_msg)
-    assert resp == resp_content
+    # resp = base_gpt_api.ask(prompt_msg)
+    # assert resp == resp_content
 
-    resp = base_gpt_api.ask_batch([prompt_msg])
-    assert resp == resp_content
+    # resp = base_gpt_api.ask_batch([prompt_msg])
+    # assert resp == resp_content
 
-    resp = base_gpt_api.ask_code([prompt_msg])
-    assert resp == resp_content
+    # resp = base_gpt_api.ask_code([prompt_msg])
+    # assert resp == resp_content
 
 
 @pytest.mark.asyncio
diff --git a/tests/metagpt/provider/test_fireworks_api.py b/tests/metagpt/provider/test_fireworks_api.py
index caf8b9f45..4d92c5f45 100644
--- a/tests/metagpt/provider/test_fireworks_api.py
+++ b/tests/metagpt/provider/test_fireworks_api.py
@@ -55,17 +55,6 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
     return default_resp.choices[0].message.content
 
 
-def test_fireworks_completion(mocker):
-    mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.completion", mock_llm_completion)
-    fireworks_gpt = FireWorksGPTAPI()
-
-    resp = fireworks_gpt.completion(messages)
-    assert resp.choices[0].message.content == resp_content
-
-    resp = fireworks_gpt.ask(prompt_msg)
-    assert resp == resp_content
-
-
 @pytest.mark.asyncio
 async def test_fireworks_acompletion(mocker):
     mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.acompletion", mock_llm_acompletion)
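`UTGenerator.gpt_msgs_to_code` becomes a coroutine so it can ride the async-only `aask_code`, and the sync completion tests are either deleted or commented out. Previously synchronous call sites now need an event loop; a hedged migration sketch (the `llm` construction and message plumbing are elided):

```python
# Sketch: driving the now-async code path from synchronous code.
import asyncio


async def ask(llm) -> dict:
    return await llm.aask_code("Write a python hello world code.")
    # -> {'language': 'python', 'code': "print('Hello, World!')"}


# from sync code: rsp = asyncio.run(ask(OpenAIGPTAPI()))
```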
diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py
index 1f25951b1..0736b1d4a 100644
--- a/tests/metagpt/provider/test_openai.py
+++ b/tests/metagpt/provider/test_openai.py
@@ -36,52 +36,6 @@ async def test_aask_code_Message():
     assert len(rsp["code"]) > 0
 
 
-def test_ask_code():
-    llm = OpenAIGPTAPI()
-    msg = [{"role": "user", "content": "Write a python hello world code."}]
-    rsp = llm.ask_code(msg)  # -> {'language': 'python', 'code': "print('Hello, World!')"}
-    assert "language" in rsp
-    assert "code" in rsp
-    assert len(rsp["code"]) > 0
-
-
-def test_ask_code_str():
-    llm = OpenAIGPTAPI()
-    msg = "Write a python hello world code."
-    rsp = llm.ask_code(msg)  # -> {'language': 'python', 'code': "print('Hello, World!')"}
-    assert "language" in rsp
-    assert "code" in rsp
-    assert len(rsp["code"]) > 0
-
-
-def test_ask_code_Message():
-    llm = OpenAIGPTAPI()
-    msg = UserMessage("Write a python hello world code.")
-    rsp = llm.ask_code(msg)  # -> {'language': 'python', 'code': "print('Hello, World!')"}
-    assert "language" in rsp
-    assert "code" in rsp
-    assert len(rsp["code"]) > 0
-
-
-def test_ask_code_list_Message():
-    llm = OpenAIGPTAPI()
-    msg = [UserMessage("a=[1,2,5,10,-10]"), UserMessage("写出求a中最大值的代码python")]
-    rsp = llm.ask_code(msg)  # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'}
-    assert "language" in rsp
-    assert "code" in rsp
-    assert len(rsp["code"]) > 0
-
-
-def test_ask_code_list_str():
-    llm = OpenAIGPTAPI()
-    msg = ["a=[1,2,5,10,-10]", "写出求a中最大值的代码python"]
-    rsp = llm.ask_code(msg)  # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'}
-    print(rsp)
-    assert "language" in rsp
-    assert "code" in rsp
-    assert len(rsp["code"]) > 0
-
-
 class TestOpenAI:
     @pytest.fixture
     def config(self):
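The deleted `test_ask_code*` functions already have async counterparts in this file (`test_aask_code_Message` appears in the hunk header above); remaining coverage follows the pytest-asyncio shape:

```python
# Sketch of the surviving test style, matching the async tests kept in the file.
import pytest


@pytest.mark.asyncio
async def test_aask_code():
    llm = OpenAIGPTAPI()
    rsp = await llm.aask_code("Write a python hello world code.")
    assert "language" in rsp and "code" in rsp and len(rsp["code"]) > 0
```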