diff --git a/metagpt/memory/brain_memory.py b/metagpt/memory/brain_memory.py index b82ac1210..609344fc3 100644 --- a/metagpt/memory/brain_memory.py +++ b/metagpt/memory/brain_memory.py @@ -17,7 +17,7 @@ from pydantic import BaseModel, Field from metagpt.config import CONFIG from metagpt.const import DEFAULT_LANGUAGE, DEFAULT_MAX_TOKENS, DEFAULT_TOKEN_SIZE from metagpt.logs import logger -from metagpt.provider import MetaGPTAPI +from metagpt.provider import MetaGPTLLM from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Message, SimpleMessage from metagpt.utils.redis import Redis @@ -122,7 +122,7 @@ class BrainMemory(BaseModel): return v async def summarize(self, llm, max_words=200, keep_language: bool = False, limit: int = -1, **kwargs): - if isinstance(llm, MetaGPTAPI): + if isinstance(llm, MetaGPTLLM): return await self._metagpt_summarize(max_words=max_words) self.llm = llm @@ -175,7 +175,7 @@ class BrainMemory(BaseModel): async def get_title(self, llm, max_words=5, **kwargs) -> str: """Generate text title""" - if isinstance(llm, MetaGPTAPI): + if isinstance(llm, MetaGPTLLM): return self.history[0].content if self.history else "New" summary = await self.summarize(llm=llm, max_words=500) @@ -190,7 +190,7 @@ class BrainMemory(BaseModel): return response async def is_related(self, text1, text2, llm): - if isinstance(llm, MetaGPTAPI): + if isinstance(llm, MetaGPTLLM): return await self._metagpt_is_related(text1=text1, text2=text2, llm=llm) return await self._openai_is_related(text1=text1, text2=text2, llm=llm) @@ -212,7 +212,7 @@ class BrainMemory(BaseModel): return result async def rewrite(self, sentence: str, context: str, llm): - if isinstance(llm, MetaGPTAPI): + if isinstance(llm, MetaGPTLLM): return await self._metagpt_rewrite(sentence=sentence, context=context, llm=llm) return await self._openai_rewrite(sentence=sentence, context=context, llm=llm) diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py index 36d585c94..28157a4e2 100644 --- a/metagpt/provider/__init__.py +++ b/metagpt/provider/__init__.py @@ -7,21 +7,21 @@ """ from metagpt.provider.fireworks_api import FireworksLLM -from metagpt.provider.google_gemini_api import GeminiGPTAPI +from metagpt.provider.google_gemini_api import GeminiLLM from metagpt.provider.ollama_api import OllamaLLM -from metagpt.provider.open_llm_api import OpenLLMGPTAPI +from metagpt.provider.open_llm_api import OpenLLM from metagpt.provider.openai_api import OpenAILLM -from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI +from metagpt.provider.zhipuai_api import ZhiPuAILLM from metagpt.provider.azure_openai_api import AzureOpenAILLM -from metagpt.provider.metagpt_api import MetaGPTAPI +from metagpt.provider.metagpt_api import MetaGPTLLM __all__ = [ "FireworksLLM", - "GeminiGPTAPI", - "OpenLLMGPTAPI", + "GeminiLLM", + "OpenLLM", "OpenAILLM", - "ZhiPuAIGPTAPI", + "ZhiPuAILLM", "AzureOpenAILLM", - "MetaGPTAPI", + "MetaGPTLLM", "OllamaLLM", ] diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index 5683095c7..b9ee73a92 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -42,7 +42,7 @@ class GeminiGenerativeModel(GenerativeModel): @register_provider(LLMProviderEnum.GEMINI) -class GeminiGPTAPI(BaseLLM): +class GeminiLLM(BaseLLM): """ Refs to `https://ai.google.dev/tutorials/python_quickstart` """ diff --git a/metagpt/provider/metagpt_api.py b/metagpt/provider/metagpt_api.py index 2b7629895..69aa7f305 100644 --- a/metagpt/provider/metagpt_api.py +++ b/metagpt/provider/metagpt_api.py @@ -11,6 +11,6 @@ from metagpt.provider.llm_provider_registry import register_provider @register_provider(LLMProviderEnum.METAGPT) -class MetaGPTAPI(OpenAILLM): +class MetaGPTLLM(OpenAILLM): def __init__(self): super().__init__() diff --git a/metagpt/provider/open_llm_api.py b/metagpt/provider/open_llm_api.py index 976e95c57..6ccdb4da0 100644 --- a/metagpt/provider/open_llm_api.py +++ b/metagpt/provider/open_llm_api.py @@ -35,7 +35,7 @@ class OpenLLMCostManager(CostManager): @register_provider(LLMProviderEnum.OPEN_LLM) -class OpenLLMGPTAPI(OpenAILLM): +class OpenLLM(OpenAILLM): def __init__(self): self.config: Config = CONFIG self.__init_openllm() diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index df8c330b8..cdc9c63e6 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -5,6 +5,7 @@ import json from enum import Enum +import openai import zhipuai from requests import ConnectionError from tenacity import ( @@ -31,7 +32,7 @@ class ZhiPuEvent(Enum): @register_provider(LLMProviderEnum.ZHIPUAI) -class ZhiPuAIGPTAPI(BaseLLM): +class ZhiPuAILLM(BaseLLM): """ Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo` From now, there is only one model named `chatglm_turbo` diff --git a/tests/metagpt/provider/test_fireworks_api.py b/tests/metagpt/provider/test_fireworks_api.py index 00b3c716a..d9c946ef7 100644 --- a/tests/metagpt/provider/test_fireworks_api.py +++ b/tests/metagpt/provider/test_fireworks_api.py @@ -15,6 +15,9 @@ from metagpt.provider.fireworks_api import ( FireworksCostManager, FireworksLLM, ) +from metagpt.config import CONFIG + +CONFIG.fireworks_api_key = "xxx" resp_content = "I'm fireworks" default_resp = ChatCompletion( @@ -23,7 +26,7 @@ default_resp = ChatCompletion( object="chat.completion", created=1703300855, choices=[ - Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(role="assistant", content=resp_content)) + Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(role="assistant", content=resp_content), logprobs=None) ], usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202), ) @@ -57,10 +60,10 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: @pytest.mark.asyncio async def test_fireworks_acompletion(mocker): - mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.acompletion", mock_llm_acompletion) - mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI._achat_completion", mock_llm_acompletion) + mocker.patch("metagpt.provider.fireworks_api.FireworksLLM.acompletion", mock_llm_acompletion) + mocker.patch("metagpt.provider.fireworks_api.FireworksLLM._achat_completion", mock_llm_acompletion) mocker.patch( - "metagpt.provider.fireworks_api.FireWorksGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream + "metagpt.provider.fireworks_api.FireworksLLM._achat_completion_stream", mock_llm_achat_completion_stream ) fireworks_gpt = FireworksLLM() diff --git a/tests/metagpt/provider/test_google_gemini_api.py b/tests/metagpt/provider/test_google_gemini_api.py index 60f50c9ad..7e372634c 100644 --- a/tests/metagpt/provider/test_google_gemini_api.py +++ b/tests/metagpt/provider/test_google_gemini_api.py @@ -7,7 +7,7 @@ from dataclasses import dataclass import pytest -from metagpt.provider.google_gemini_api import GeminiGPTAPI +from metagpt.provider.google_gemini_api import GeminiLLM @dataclass @@ -37,12 +37,12 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: @pytest.mark.asyncio async def test_gemini_acompletion(mocker): - mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_acompletion) - mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI._achat_completion", mock_llm_acompletion) + mocker.patch("metagpt.provider.google_gemini_api.GeminiLLM.acompletion", mock_llm_acompletion) + mocker.patch("metagpt.provider.google_gemini_api.GeminiLLM._achat_completion", mock_llm_acompletion) mocker.patch( - "metagpt.provider.google_gemini_api.GeminiGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream + "metagpt.provider.google_gemini_api.GeminiLLM._achat_completion_stream", mock_llm_achat_completion_stream ) - gemini_gpt = GeminiGPTAPI() + gemini_gpt = GeminiLLM() resp = await gemini_gpt.acompletion(messages) assert resp.text == default_resp.text diff --git a/tests/metagpt/provider/test_metagpt_llm_api.py b/tests/metagpt/provider/test_metagpt_llm_api.py index f454b08a7..8fce6b6b0 100644 --- a/tests/metagpt/provider/test_metagpt_llm_api.py +++ b/tests/metagpt/provider/test_metagpt_llm_api.py @@ -5,11 +5,11 @@ @Author : mashenquan @File : test_metagpt_llm_api.py """ -from metagpt.provider.metagpt_api import MetaGPTAPI +from metagpt.provider.metagpt_api import MetaGPTLLM def test_metagpt(): - llm = MetaGPTAPI() + llm = MetaGPTLLM() assert llm diff --git a/tests/metagpt/provider/test_ollama_api.py b/tests/metagpt/provider/test_ollama_api.py index d19e23e17..ba019f295 100644 --- a/tests/metagpt/provider/test_ollama_api.py +++ b/tests/metagpt/provider/test_ollama_api.py @@ -30,9 +30,9 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: @pytest.mark.asyncio async def test_gemini_acompletion(mocker): - mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.acompletion", mock_llm_acompletion) - mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion", mock_llm_acompletion) - mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream) + mocker.patch("metagpt.provider.ollama_api.OllamaLLM.acompletion", mock_llm_acompletion) + mocker.patch("metagpt.provider.ollama_api.OllamaLLM._achat_completion", mock_llm_acompletion) + mocker.patch("metagpt.provider.ollama_api.OllamaLLM._achat_completion_stream", mock_llm_achat_completion_stream) ollama_gpt = OllamaLLM() resp = await ollama_gpt.acompletion(messages) diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py index 329edadff..cb86dfcf9 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -86,31 +86,25 @@ class TestOpenAI: def test_make_client_kwargs_without_proxy(self, config): instance = OpenAILLM() instance.config = config - kwargs, async_kwargs = instance._make_client_kwargs() + kwargs = instance._make_client_kwargs() assert kwargs == {"api_key": "test_key", "base_url": "test_url"} - assert async_kwargs == {"api_key": "test_key", "base_url": "test_url"} assert "http_client" not in kwargs - assert "http_client" not in async_kwargs def test_make_client_kwargs_without_proxy_azure(self, config_azure): instance = OpenAILLM() instance.config = config_azure - kwargs, async_kwargs = instance._make_client_kwargs() + kwargs = instance._make_client_kwargs() assert kwargs == {"api_key": "test_key", "base_url": "test_url"} - assert async_kwargs == {"api_key": "test_key", "base_url": "test_url"} assert "http_client" not in kwargs - assert "http_client" not in async_kwargs def test_make_client_kwargs_with_proxy(self, config_proxy): instance = OpenAILLM() instance.config = config_proxy - kwargs, async_kwargs = instance._make_client_kwargs() + kwargs = instance._make_client_kwargs() assert "http_client" in kwargs - assert "http_client" in async_kwargs def test_make_client_kwargs_with_proxy_azure(self, config_azure_proxy): instance = OpenAILLM() instance.config = config_azure_proxy - kwargs, async_kwargs = instance._make_client_kwargs() + kwargs = instance._make_client_kwargs() assert "http_client" in kwargs - assert "http_client" in async_kwargs diff --git a/tests/metagpt/provider/test_spark_api.py b/tests/metagpt/provider/test_spark_api.py index 6cc87741e..e62c287c0 100644 --- a/tests/metagpt/provider/test_spark_api.py +++ b/tests/metagpt/provider/test_spark_api.py @@ -20,8 +20,8 @@ async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, @pytest.mark.asyncio async def test_spark_acompletion(mocker): - mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion", mock_llm_acompletion) - mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion_text", mock_llm_acompletion) + mocker.patch("metagpt.provider.spark_api.SparkLLM.acompletion", mock_llm_acompletion) + mocker.patch("metagpt.provider.spark_api.SparkLLM.acompletion_text", mock_llm_acompletion) spark_gpt = SparkLLM() resp = await spark_gpt.acompletion([]) diff --git a/tests/metagpt/provider/test_zhipuai_api.py b/tests/metagpt/provider/test_zhipuai_api.py index 06f2cba62..29cfe2eb3 100644 --- a/tests/metagpt/provider/test_zhipuai_api.py +++ b/tests/metagpt/provider/test_zhipuai_api.py @@ -1,11 +1,11 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : the unittest of ZhiPuAIGPTAPI +# @Desc : the unittest of ZhiPuAILLM import pytest from metagpt.config import CONFIG -from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI +from metagpt.provider.zhipuai_api import ZhiPuAILLM CONFIG.zhipuai_api_key = "xxx" @@ -30,12 +30,12 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: @pytest.mark.asyncio async def test_zhipuai_acompletion(mocker): - mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion", mock_llm_acompletion) - mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI._achat_completion", mock_llm_acompletion) + mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAILLM.acompletion", mock_llm_acompletion) + mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAILLM._achat_completion", mock_llm_acompletion) mocker.patch( - "metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream + "metagpt.provider.zhipuai_api.ZhiPuAILLM._achat_completion_stream", mock_llm_achat_completion_stream ) - zhipu_gpt = ZhiPuAIGPTAPI() + zhipu_gpt = ZhiPuAILLM() resp = await zhipu_gpt.acompletion(messages) assert resp["data"]["choices"][0]["content"] == resp_content @@ -59,5 +59,5 @@ def test_zhipuai_proxy(mocker): from metagpt.config import CONFIG CONFIG.openai_proxy = "http://127.0.0.1:8080" - _ = ZhiPuAIGPTAPI() + _ = ZhiPuAILLM() assert openai.proxy == CONFIG.openai_proxy