From 41361915a12b236d82980299310e555021d56a7a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8E=98=E6=9D=83=20=E9=A9=AC?= Date: Mon, 18 Dec 2023 11:31:08 +0800 Subject: [PATCH] feat: upgrade openai 1.x --- metagpt/learn/text_to_speech.py | 2 +- metagpt/memory/brain_memory.py | 2 +- metagpt/provider/fireworks_api.py | 6 ++++-- metagpt/provider/open_llm_api.py | 9 ++++----- metagpt/provider/zhipuai_api.py | 4 ++-- metagpt/tools/openai_text_to_image.py | 18 +++++++----------- tests/metagpt/test_environment.py | 4 ++++ tests/metagpt/test_gpt.py | 4 ++-- tests/metagpt/test_llm.py | 4 ++-- tests/metagpt/test_startup.py | 4 ++++ tests/metagpt/test_subscription.py | 4 ++++ 11 files changed, 35 insertions(+), 26 deletions(-) diff --git a/metagpt/learn/text_to_speech.py b/metagpt/learn/text_to_speech.py index 972515599..72958b8c7 100644 --- a/metagpt/learn/text_to_speech.py +++ b/metagpt/learn/text_to_speech.py @@ -66,7 +66,7 @@ async def text_to_speech( return f"[{text}]({url})" return audio_declaration + base64_data if base64_data else base64_data - raise openai.error.InvalidRequestError( + raise openai.InvalidRequestError( message="AZURE_TTS_SUBSCRIPTION_KEY, AZURE_TTS_REGION, IFLYTEK_APP_ID, IFLYTEK_API_KEY, IFLYTEK_API_SECRET error", param={}, ) diff --git a/metagpt/memory/brain_memory.py b/metagpt/memory/brain_memory.py index decbb6a8b..034bcfa56 100644 --- a/metagpt/memory/brain_memory.py +++ b/metagpt/memory/brain_memory.py @@ -171,7 +171,7 @@ class BrainMemory(pydantic.BaseModel): if summary: await self.set_history_summary(history_summary=summary, redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS) return summary - raise openai.error.InvalidRequestError(message="text too long", param=None) + raise ValueError("text too long") async def _metagpt_summarize(self, max_words=200, **kwargs): if not self.history: diff --git a/metagpt/provider/fireworks_api.py b/metagpt/provider/fireworks_api.py index 5dc68ad35..6625cda97 100644 ---
a/metagpt/provider/fireworks_api.py +++ b/metagpt/provider/fireworks_api.py @@ -19,6 +19,8 @@ class FireWorksGPTAPI(OpenAIGPTAPI): RateLimiter.__init__(self, rpm=self.rpm) def __init_fireworks(self, config: "Config"): - openai.api_key = config.fireworks_api_key - openai.api_base = config.fireworks_api_base + # TODO: The 'openai.api_base' option isn't read in the client API. You will need to pass it when you + # instantiate the client, e.g. 'OpenAI(api_base=config.fireworks_api_base)' + # openai.api_key = config.fireworks_api_key + # openai.api_base = config.fireworks_api_base self.rpm = int(config.get("RPM", 10)) diff --git a/metagpt/provider/open_llm_api.py b/metagpt/provider/open_llm_api.py index 97e4c9f67..cd30c4a58 100644 --- a/metagpt/provider/open_llm_api.py +++ b/metagpt/provider/open_llm_api.py @@ -2,8 +2,6 @@ # -*- coding: utf-8 -*- # @Desc : self-host open llm model with openai-compatible interface -import openai - from metagpt.config import CONFIG from metagpt.logs import logger from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter @@ -35,13 +33,14 @@ class OpenLLMCostManager(CostManager): class OpenLLMGPTAPI(OpenAIGPTAPI): def __init__(self): self.__init_openllm(CONFIG) - self.llm = openai self.model = CONFIG.open_llm_api_model self.auto_max_tokens = False self._cost_manager = OpenLLMCostManager() RateLimiter.__init__(self, rpm=self.rpm) def __init_openllm(self, config: "Config"): - openai.api_key = "sk-xx" # self-host api doesn't need api-key, use the default value - openai.api_base = config.open_llm_api_base + # TODO: The 'openai.api_base' option isn't read in the client API. You will need to pass it when you + # instantiate the client, e.g. 
'OpenAI(api_base=config.open_llm_api_base)' + # openai.api_key = "sk-xx" # self-host api doesn't need api-key, use the default value + # openai.api_base = config.open_llm_api_base self.rpm = int(config.get("RPM", 10)) diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index 82513f83c..ff8e5531e 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -5,7 +5,6 @@ import json from enum import Enum -import openai import zhipuai from requests import ConnectionError from tenacity import ( @@ -48,7 +47,8 @@ class ZhiPuAIGPTAPI(BaseGPTAPI): def __init_zhipuai(self, config: CONFIG): assert config.zhipuai_api_key zhipuai.api_key = config.zhipuai_api_key - openai.api_key = zhipuai.api_key # due to use openai sdk, set the api_key but it will't be used. + # due to use openai sdk, set the api_key but it won't be used. + # openai.api_key = zhipuai.api_key # due to use openai sdk, set the api_key but it won't be used. def _const_kwargs(self, messages: list[dict]) -> dict: kwargs = {"model": self.model, "prompt": messages, "temperature": 0.3} diff --git a/metagpt/tools/openai_text_to_image.py b/metagpt/tools/openai_text_to_image.py index 6025f04ba..80de04e45 100644 --- a/metagpt/tools/openai_text_to_image.py +++ b/metagpt/tools/openai_text_to_image.py @@ -10,8 +10,8 @@ import asyncio import base64 import aiohttp -import openai import requests +from openai import AsyncOpenAI from metagpt.config import CONFIG, Config from metagpt.logs import logger @@ -23,6 +23,11 @@ class OpenAIText2Image: :param openai_api_key: OpenAI API key, For more details, checkout: `https://platform.openai.com/account/api-keys` """ self.openai_api_key = openai_api_key if openai_api_key else CONFIG.OPENAI_API_KEY + self._client = AsyncOpenAI(api_key=self.openai_api_key, base_url=CONFIG.openai_api_base) + + def __del__(self): + if self._client: + self._client.close() async def text_2_image(self, text, size_type="1024x1024"): """Text to image @@ -32,16 
+37,7 @@ class OpenAIText2Image: :return: The image data is returned in Base64 encoding. """ try: - result = await openai.Image.acreate( - api_key=CONFIG.OPENAI_API_KEY, - api_base=CONFIG.OPENAI_API_BASE, - api_type=None, - api_version=None, - organization=None, - prompt=text, - n=1, - size=size_type, - ) + result = await self._client.images.generate(prompt=text, n=1, size=size_type) except Exception as e: logger.error(f"An error occurred:{e}") return "" diff --git a/tests/metagpt/test_environment.py b/tests/metagpt/test_environment.py index fd731cf9e..bc88eb742 100644 --- a/tests/metagpt/test_environment.py +++ b/tests/metagpt/test_environment.py @@ -56,3 +56,7 @@ async def test_publish_and_process_message(env: Environment): await env.run(k=2) logger.info(f"{env.history=}") assert len(env.history) > 10 + + +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/test_gpt.py b/tests/metagpt/test_gpt.py index dda5e6252..daafeb708 100644 --- a/tests/metagpt/test_gpt.py +++ b/tests/metagpt/test_gpt.py @@ -54,5 +54,5 @@ class TestGPT: assert costs.total_cost > 0 -# if __name__ == "__main__": -# pytest.main([__file__, "-s"]) +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/test_llm.py b/tests/metagpt/test_llm.py index f2d4371d5..d972e55c0 100644 --- a/tests/metagpt/test_llm.py +++ b/tests/metagpt/test_llm.py @@ -35,5 +35,5 @@ async def test_llm_acompletion(llm): assert len(await llm.acompletion_batch_text([hello_msg])) > 0 -# if __name__ == "__main__": -# pytest.main([__file__, "-s"]) +if __name__ == "__main__": + pytest.main([__file__, "-s"]) diff --git a/tests/metagpt/test_startup.py b/tests/metagpt/test_startup.py index c34fd2c31..c8d4d5d29 100644 --- a/tests/metagpt/test_startup.py +++ b/tests/metagpt/test_startup.py @@ -26,3 +26,7 @@ async def test_team(): # def test_startup(): # args = ["Make a 2048 game"] # result = runner.invoke(app, args) + + +if __name__ == "__main__": + pytest.main([__file__, 
"-s"]) diff --git a/tests/metagpt/test_subscription.py b/tests/metagpt/test_subscription.py index 2e898424d..1399df7fe 100644 --- a/tests/metagpt/test_subscription.py +++ b/tests/metagpt/test_subscription.py @@ -100,3 +100,7 @@ async def test_subscription_run_error(loguru_caplog): logs = "".join(loguru_caplog.messages) assert "run error" in logs assert "has completed" in logs + + +if __name__ == "__main__": + pytest.main([__file__, "-s"])