From 255e2d3fa7607f796c46a3da63fb86a1bbfcfecd Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 28 Dec 2023 17:18:18 +0800 Subject: [PATCH 1/9] update provider uniform name and check tests --- metagpt/memory/brain_memory.py | 10 +++++----- metagpt/provider/__init__.py | 16 ++++++++-------- metagpt/provider/google_gemini_api.py | 2 +- metagpt/provider/metagpt_api.py | 2 +- metagpt/provider/open_llm_api.py | 2 +- metagpt/provider/zhipuai_api.py | 3 ++- tests/metagpt/provider/test_fireworks_api.py | 11 +++++++---- tests/metagpt/provider/test_google_gemini_api.py | 10 +++++----- tests/metagpt/provider/test_metagpt_llm_api.py | 4 ++-- tests/metagpt/provider/test_ollama_api.py | 6 +++--- tests/metagpt/provider/test_openai.py | 14 ++++---------- tests/metagpt/provider/test_spark_api.py | 4 ++-- tests/metagpt/provider/test_zhipuai_api.py | 14 +++++++------- 13 files changed, 48 insertions(+), 50 deletions(-) diff --git a/metagpt/memory/brain_memory.py b/metagpt/memory/brain_memory.py index b82ac1210..609344fc3 100644 --- a/metagpt/memory/brain_memory.py +++ b/metagpt/memory/brain_memory.py @@ -17,7 +17,7 @@ from pydantic import BaseModel, Field from metagpt.config import CONFIG from metagpt.const import DEFAULT_LANGUAGE, DEFAULT_MAX_TOKENS, DEFAULT_TOKEN_SIZE from metagpt.logs import logger -from metagpt.provider import MetaGPTAPI +from metagpt.provider import MetaGPTLLM from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Message, SimpleMessage from metagpt.utils.redis import Redis @@ -122,7 +122,7 @@ class BrainMemory(BaseModel): return v async def summarize(self, llm, max_words=200, keep_language: bool = False, limit: int = -1, **kwargs): - if isinstance(llm, MetaGPTAPI): + if isinstance(llm, MetaGPTLLM): return await self._metagpt_summarize(max_words=max_words) self.llm = llm @@ -175,7 +175,7 @@ class BrainMemory(BaseModel): async def get_title(self, llm, max_words=5, **kwargs) -> str: """Generate text title""" - if isinstance(llm, MetaGPTAPI): + if isinstance(llm, MetaGPTLLM): return self.history[0].content if self.history else "New" summary = await self.summarize(llm=llm, max_words=500) @@ -190,7 +190,7 @@ class BrainMemory(BaseModel): return response async def is_related(self, text1, text2, llm): - if isinstance(llm, MetaGPTAPI): + if isinstance(llm, MetaGPTLLM): return await self._metagpt_is_related(text1=text1, text2=text2, llm=llm) return await self._openai_is_related(text1=text1, text2=text2, llm=llm) @@ -212,7 +212,7 @@ class BrainMemory(BaseModel): return result async def rewrite(self, sentence: str, context: str, llm): - if isinstance(llm, MetaGPTAPI): + if isinstance(llm, MetaGPTLLM): return await self._metagpt_rewrite(sentence=sentence, context=context, llm=llm) return await self._openai_rewrite(sentence=sentence, context=context, llm=llm) diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py index 36d585c94..28157a4e2 100644 --- a/metagpt/provider/__init__.py +++ b/metagpt/provider/__init__.py @@ -7,21 +7,21 @@ """ from metagpt.provider.fireworks_api import FireworksLLM -from metagpt.provider.google_gemini_api import GeminiGPTAPI +from metagpt.provider.google_gemini_api import GeminiLLM from metagpt.provider.ollama_api import OllamaLLM -from metagpt.provider.open_llm_api import OpenLLMGPTAPI +from metagpt.provider.open_llm_api import OpenLLM from metagpt.provider.openai_api import OpenAILLM -from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI +from metagpt.provider.zhipuai_api import ZhiPuAILLM from metagpt.provider.azure_openai_api import AzureOpenAILLM -from metagpt.provider.metagpt_api import MetaGPTAPI +from metagpt.provider.metagpt_api import MetaGPTLLM __all__ = [ "FireworksLLM", - "GeminiGPTAPI", - "OpenLLMGPTAPI", + "GeminiLLM", + "OpenLLM", "OpenAILLM", - "ZhiPuAIGPTAPI", + "ZhiPuAILLM", "AzureOpenAILLM", - "MetaGPTAPI", + "MetaGPTLLM", "OllamaLLM", ] diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index 5683095c7..b9ee73a92 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -42,7 +42,7 @@ class GeminiGenerativeModel(GenerativeModel): @register_provider(LLMProviderEnum.GEMINI) -class GeminiGPTAPI(BaseLLM): +class GeminiLLM(BaseLLM): """ Refs to `https://ai.google.dev/tutorials/python_quickstart` """ diff --git a/metagpt/provider/metagpt_api.py b/metagpt/provider/metagpt_api.py index 2b7629895..69aa7f305 100644 --- a/metagpt/provider/metagpt_api.py +++ b/metagpt/provider/metagpt_api.py @@ -11,6 +11,6 @@ from metagpt.provider.llm_provider_registry import register_provider @register_provider(LLMProviderEnum.METAGPT) -class MetaGPTAPI(OpenAILLM): +class MetaGPTLLM(OpenAILLM): def __init__(self): super().__init__() diff --git a/metagpt/provider/open_llm_api.py b/metagpt/provider/open_llm_api.py index 976e95c57..6ccdb4da0 100644 --- a/metagpt/provider/open_llm_api.py +++ b/metagpt/provider/open_llm_api.py @@ -35,7 +35,7 @@ class OpenLLMCostManager(CostManager): @register_provider(LLMProviderEnum.OPEN_LLM) -class OpenLLMGPTAPI(OpenAILLM): +class OpenLLM(OpenAILLM): def __init__(self): self.config: Config = CONFIG self.__init_openllm() diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index df8c330b8..cdc9c63e6 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -5,6 +5,7 @@ import json from enum import Enum +import openai import zhipuai from requests import ConnectionError from tenacity import ( @@ -31,7 +32,7 @@ class ZhiPuEvent(Enum): @register_provider(LLMProviderEnum.ZHIPUAI) -class ZhiPuAIGPTAPI(BaseLLM): +class ZhiPuAILLM(BaseLLM): """ Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo` From now, there is only one model named `chatglm_turbo` diff --git a/tests/metagpt/provider/test_fireworks_api.py b/tests/metagpt/provider/test_fireworks_api.py index 00b3c716a..d9c946ef7 100644 --- a/tests/metagpt/provider/test_fireworks_api.py +++ b/tests/metagpt/provider/test_fireworks_api.py @@ -15,6 +15,9 @@ from metagpt.provider.fireworks_api import ( FireworksCostManager, FireworksLLM, ) +from metagpt.config import CONFIG + +CONFIG.fireworks_api_key = "xxx" resp_content = "I'm fireworks" default_resp = ChatCompletion( @@ -23,7 +26,7 @@ default_resp = ChatCompletion( object="chat.completion", created=1703300855, choices=[ - Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(role="assistant", content=resp_content)) + Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(role="assistant", content=resp_content), logprobs=None) ], usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202), ) @@ -57,10 +60,10 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: @pytest.mark.asyncio async def test_fireworks_acompletion(mocker): - mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.acompletion", mock_llm_acompletion) - mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI._achat_completion", mock_llm_acompletion) + mocker.patch("metagpt.provider.fireworks_api.FireworksLLM.acompletion", mock_llm_acompletion) + mocker.patch("metagpt.provider.fireworks_api.FireworksLLM._achat_completion", mock_llm_acompletion) mocker.patch( - "metagpt.provider.fireworks_api.FireWorksGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream + "metagpt.provider.fireworks_api.FireworksLLM._achat_completion_stream", mock_llm_achat_completion_stream ) fireworks_gpt = FireworksLLM() diff --git a/tests/metagpt/provider/test_google_gemini_api.py b/tests/metagpt/provider/test_google_gemini_api.py index 60f50c9ad..7e372634c 100644 --- a/tests/metagpt/provider/test_google_gemini_api.py +++ b/tests/metagpt/provider/test_google_gemini_api.py @@ -7,7 +7,7 @@ from dataclasses import dataclass import pytest -from metagpt.provider.google_gemini_api import GeminiGPTAPI +from metagpt.provider.google_gemini_api import GeminiLLM @dataclass @@ -37,12 +37,12 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: @pytest.mark.asyncio async def test_gemini_acompletion(mocker): - mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_acompletion) - mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI._achat_completion", mock_llm_acompletion) + mocker.patch("metagpt.provider.google_gemini_api.GeminiLLM.acompletion", mock_llm_acompletion) + mocker.patch("metagpt.provider.google_gemini_api.GeminiLLM._achat_completion", mock_llm_acompletion) mocker.patch( - "metagpt.provider.google_gemini_api.GeminiGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream + "metagpt.provider.google_gemini_api.GeminiLLM._achat_completion_stream", mock_llm_achat_completion_stream ) - gemini_gpt = GeminiGPTAPI() + gemini_gpt = GeminiLLM() resp = await gemini_gpt.acompletion(messages) assert resp.text == default_resp.text diff --git a/tests/metagpt/provider/test_metagpt_llm_api.py b/tests/metagpt/provider/test_metagpt_llm_api.py index f454b08a7..8fce6b6b0 100644 --- a/tests/metagpt/provider/test_metagpt_llm_api.py +++ b/tests/metagpt/provider/test_metagpt_llm_api.py @@ -5,11 +5,11 @@ @Author : mashenquan @File : test_metagpt_llm_api.py """ -from metagpt.provider.metagpt_api import MetaGPTAPI +from metagpt.provider.metagpt_api import MetaGPTLLM def test_metagpt(): - llm = MetaGPTAPI() + llm = MetaGPTLLM() assert llm diff --git a/tests/metagpt/provider/test_ollama_api.py b/tests/metagpt/provider/test_ollama_api.py index d19e23e17..ba019f295 100644 --- a/tests/metagpt/provider/test_ollama_api.py +++ b/tests/metagpt/provider/test_ollama_api.py @@ -30,9 +30,9 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: @pytest.mark.asyncio async def test_gemini_acompletion(mocker): - mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.acompletion", mock_llm_acompletion) - mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion", mock_llm_acompletion) - mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream) + mocker.patch("metagpt.provider.ollama_api.OllamaLLM.acompletion", mock_llm_acompletion) + mocker.patch("metagpt.provider.ollama_api.OllamaLLM._achat_completion", mock_llm_acompletion) + mocker.patch("metagpt.provider.ollama_api.OllamaLLM._achat_completion_stream", mock_llm_achat_completion_stream) ollama_gpt = OllamaLLM() resp = await ollama_gpt.acompletion(messages) diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py index 329edadff..cb86dfcf9 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -86,31 +86,25 @@ class TestOpenAI: def test_make_client_kwargs_without_proxy(self, config): instance = OpenAILLM() instance.config = config - kwargs, async_kwargs = instance._make_client_kwargs() + kwargs = instance._make_client_kwargs() assert kwargs == {"api_key": "test_key", "base_url": "test_url"} - assert async_kwargs == {"api_key": "test_key", "base_url": "test_url"} assert "http_client" not in kwargs - assert "http_client" not in async_kwargs def test_make_client_kwargs_without_proxy_azure(self, config_azure): instance = OpenAILLM() instance.config = config_azure - kwargs, async_kwargs = instance._make_client_kwargs() + kwargs = instance._make_client_kwargs() assert kwargs == {"api_key": "test_key", "base_url": "test_url"} - assert async_kwargs == {"api_key": "test_key", "base_url": "test_url"} assert "http_client" not in kwargs - assert "http_client" not in async_kwargs def test_make_client_kwargs_with_proxy(self, config_proxy): instance = OpenAILLM() instance.config = config_proxy - kwargs, async_kwargs = instance._make_client_kwargs() + kwargs = instance._make_client_kwargs() assert "http_client" in kwargs - assert "http_client" in async_kwargs def test_make_client_kwargs_with_proxy_azure(self, config_azure_proxy): instance = OpenAILLM() instance.config = config_azure_proxy - kwargs, async_kwargs = instance._make_client_kwargs() + kwargs = instance._make_client_kwargs() assert "http_client" in kwargs - assert "http_client" in async_kwargs diff --git a/tests/metagpt/provider/test_spark_api.py b/tests/metagpt/provider/test_spark_api.py index 6cc87741e..e62c287c0 100644 --- a/tests/metagpt/provider/test_spark_api.py +++ b/tests/metagpt/provider/test_spark_api.py @@ -20,8 +20,8 @@ async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, @pytest.mark.asyncio async def test_spark_acompletion(mocker): - mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion", mock_llm_acompletion) - mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion_text", mock_llm_acompletion) + mocker.patch("metagpt.provider.spark_api.SparkLLM.acompletion", mock_llm_acompletion) + mocker.patch("metagpt.provider.spark_api.SparkLLM.acompletion_text", mock_llm_acompletion) spark_gpt = SparkLLM() resp = await spark_gpt.acompletion([]) diff --git a/tests/metagpt/provider/test_zhipuai_api.py b/tests/metagpt/provider/test_zhipuai_api.py index 06f2cba62..29cfe2eb3 100644 --- a/tests/metagpt/provider/test_zhipuai_api.py +++ b/tests/metagpt/provider/test_zhipuai_api.py @@ -1,11 +1,11 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : the unittest of ZhiPuAIGPTAPI +# @Desc : the unittest of ZhiPuAILLM import pytest from metagpt.config import CONFIG -from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI +from metagpt.provider.zhipuai_api import ZhiPuAILLM CONFIG.zhipuai_api_key = "xxx" @@ -30,12 +30,12 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: @pytest.mark.asyncio async def test_zhipuai_acompletion(mocker): - mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion", mock_llm_acompletion) - mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI._achat_completion", mock_llm_acompletion) + mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAILLM.acompletion", mock_llm_acompletion) + mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAILLM._achat_completion", mock_llm_acompletion) mocker.patch( - "metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream + "metagpt.provider.zhipuai_api.ZhiPuAILLM._achat_completion_stream", mock_llm_achat_completion_stream ) - zhipu_gpt = ZhiPuAIGPTAPI() + zhipu_gpt = ZhiPuAILLM() resp = await zhipu_gpt.acompletion(messages) assert resp["data"]["choices"][0]["content"] == resp_content @@ -59,5 +59,5 @@ def test_zhipuai_proxy(mocker): from metagpt.config import CONFIG CONFIG.openai_proxy = "http://127.0.0.1:8080" - _ = ZhiPuAIGPTAPI() + _ = ZhiPuAILLM() assert openai.proxy == CONFIG.openai_proxy From 5fc8207950197618e039f5eb5968f9fe1a7b4382 Mon Sep 17 00:00:00 2001 From: better629 Date: Thu, 28 Dec 2023 17:18:28 +0800 Subject: [PATCH 2/9] update provider uniform name and check tests --- tests/metagpt/provider/test_fireworks_api.py | 9 +++++++-- tests/metagpt/provider/test_zhipuai_api.py | 4 +--- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/tests/metagpt/provider/test_fireworks_api.py b/tests/metagpt/provider/test_fireworks_api.py index d9c946ef7..496465e5f 100644 --- a/tests/metagpt/provider/test_fireworks_api.py +++ b/tests/metagpt/provider/test_fireworks_api.py @@ -10,12 +10,12 @@ from openai.types.chat.chat_completion import ( ) from openai.types.completion_usage import CompletionUsage +from metagpt.config import CONFIG from metagpt.provider.fireworks_api import ( MODEL_GRADE_TOKEN_COSTS, FireworksCostManager, FireworksLLM, ) -from metagpt.config import CONFIG CONFIG.fireworks_api_key = "xxx" @@ -26,7 +26,12 @@ default_resp = ChatCompletion( object="chat.completion", created=1703300855, choices=[ - Choice(finish_reason="stop", index=0, message=ChatCompletionMessage(role="assistant", content=resp_content), logprobs=None) + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage(role="assistant", content=resp_content), + logprobs=None, + ) ], usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202), ) diff --git a/tests/metagpt/provider/test_zhipuai_api.py b/tests/metagpt/provider/test_zhipuai_api.py index 29cfe2eb3..c1af2f0be 100644 --- a/tests/metagpt/provider/test_zhipuai_api.py +++ b/tests/metagpt/provider/test_zhipuai_api.py @@ -32,9 +32,7 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: async def test_zhipuai_acompletion(mocker): mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAILLM.acompletion", mock_llm_acompletion) mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAILLM._achat_completion", mock_llm_acompletion) - mocker.patch( - "metagpt.provider.zhipuai_api.ZhiPuAILLM._achat_completion_stream", mock_llm_achat_completion_stream - ) + mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAILLM._achat_completion_stream", mock_llm_achat_completion_stream) zhipu_gpt = ZhiPuAILLM() resp = await zhipu_gpt.acompletion(messages) From 0f047e5693ebe5f5f92f95c81cfbd4cf4cd9ad67 Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 29 Dec 2023 02:39:00 +0800 Subject: [PATCH 3/9] update provider unittests to update coverage rate --- metagpt/actions/action_node.py | 4 +- metagpt/provider/general_api_base.py | 2 +- metagpt/provider/google_gemini_api.py | 3 - metagpt/provider/open_llm_api.py | 1 - .../{postprecess => postprocess}/__init__.py | 0 .../base_postprocess_plugin.py} | 4 +- .../llm_output_postprocess.py} | 10 +- metagpt/provider/zhipuai/zhipu_model_api.py | 2 +- metagpt/provider/zhipuai_api.py | 3 - .../metagpt/provider/postprocess/__init__.py | 3 + .../test_base_postprocess_plugin.py | 38 ++++++++ .../test_llm_output_postprocess.py | 14 +++ tests/metagpt/provider/test_anthropic_api.py | 19 ++-- .../metagpt/provider/test_azure_openai_api.py | 15 +++ tests/metagpt/provider/test_fireworks_api.py | 58 +++++++++--- .../metagpt/provider/test_general_api_base.py | 84 +++++++++++++++++ .../provider/test_general_api_requestor.py | 15 ++- .../provider/test_google_gemini_api.py | 49 ++++++++-- tests/metagpt/provider/test_ollama_api.py | 31 +++++-- tests/metagpt/provider/test_open_llm_api.py | 93 +++++++++++++++++++ tests/metagpt/provider/test_openai.py | 5 + tests/metagpt/provider/test_spark_api.py | 19 ++-- tests/metagpt/provider/test_zhipuai_api.py | 52 +++++++++-- tests/metagpt/provider/zhipuai/__init__.py | 3 + .../provider/zhipuai/test_async_sse_client.py | 18 ++++ .../provider/zhipuai/test_zhipu_model_api.py | 40 ++++++++ 26 files changed, 509 insertions(+), 76 deletions(-) rename metagpt/provider/{postprecess => postprocess}/__init__.py (100%) rename metagpt/provider/{postprecess/base_postprecess_plugin.py => postprocess/base_postprocess_plugin.py} (98%) rename metagpt/provider/{postprecess/llm_output_postprecess.py => postprocess/llm_output_postprocess.py} (58%) create mode 100644 tests/metagpt/provider/postprocess/__init__.py create mode 100644 tests/metagpt/provider/postprocess/test_base_postprocess_plugin.py create mode 100644 tests/metagpt/provider/postprocess/test_llm_output_postprocess.py create mode 100644 tests/metagpt/provider/test_azure_openai_api.py create mode 100644 tests/metagpt/provider/test_general_api_base.py create mode 100644 tests/metagpt/provider/test_open_llm_api.py create mode 100644 tests/metagpt/provider/zhipuai/__init__.py create mode 100644 tests/metagpt/provider/zhipuai/test_async_sse_client.py create mode 100644 tests/metagpt/provider/zhipuai/test_zhipu_model_api.py diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 3389b8964..35f2b76f8 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -17,7 +17,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.config import CONFIG from metagpt.llm import BaseLLM from metagpt.logs import logger -from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess +from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess from metagpt.utils.common import OutputParser, general_after_log TAG = "CONTENT" @@ -275,7 +275,7 @@ class ActionNode: output_class = self.create_model_class(output_class_name, output_data_mapping) if schema == "json": - parsed_data = llm_output_postprecess( + parsed_data = llm_output_postprocess( output=content, schema=output_class.model_json_schema(), req_key=f"[/{TAG}]" ) else: # using markdown parser diff --git a/metagpt/provider/general_api_base.py b/metagpt/provider/general_api_base.py index 814be2f67..bbe03774c 100644 --- a/metagpt/provider/general_api_base.py +++ b/metagpt/provider/general_api_base.py @@ -100,7 +100,7 @@ def log_info(message, **params): def log_warn(message, **params): msg = logfmt(dict(message=message, **params)) print(msg, file=sys.stderr) - logger.warn(msg) + logger.warning(msg) def logfmt(props): diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index b9ee73a92..c99a14b38 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -79,9 +79,6 @@ class GeminiLLM(BaseLLM): except Exception as e: logger.error(f"google gemini updats costs failed! exp: {e}") - def close(self): - pass - def get_choice_text(self, resp: GenerateContentResponse) -> str: return resp.text diff --git a/metagpt/provider/open_llm_api.py b/metagpt/provider/open_llm_api.py index 6ccdb4da0..7f5870702 100644 --- a/metagpt/provider/open_llm_api.py +++ b/metagpt/provider/open_llm_api.py @@ -31,7 +31,6 @@ class OpenLLMCostManager(CostManager): f"Max budget: ${CONFIG.max_budget:.3f} | reference " f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}" ) - CONFIG.total_cost = self.total_cost @register_provider(LLMProviderEnum.OPEN_LLM) diff --git a/metagpt/provider/postprecess/__init__.py b/metagpt/provider/postprocess/__init__.py similarity index 100% rename from metagpt/provider/postprecess/__init__.py rename to metagpt/provider/postprocess/__init__.py diff --git a/metagpt/provider/postprecess/base_postprecess_plugin.py b/metagpt/provider/postprocess/base_postprocess_plugin.py similarity index 98% rename from metagpt/provider/postprecess/base_postprecess_plugin.py rename to metagpt/provider/postprocess/base_postprocess_plugin.py index 46646be91..48130ede8 100644 --- a/metagpt/provider/postprecess/base_postprecess_plugin.py +++ b/metagpt/provider/postprocess/base_postprocess_plugin.py @@ -12,8 +12,8 @@ from metagpt.utils.repair_llm_raw_output import ( ) -class BasePostPrecessPlugin(object): - model = None # the plugin of the `model`, use to judge in `llm_postprecess` +class BasePostProcessPlugin(object): + model = None # the plugin of the `model`, use to judge in `llm_postprocess` def run_repair_llm_output(self, output: str, schema: dict, req_key: str = "[/CONTENT]") -> Union[dict, list]: """ diff --git a/metagpt/provider/postprecess/llm_output_postprecess.py b/metagpt/provider/postprocess/llm_output_postprocess.py similarity index 58% rename from metagpt/provider/postprecess/llm_output_postprecess.py rename to metagpt/provider/postprocess/llm_output_postprocess.py index 85405543d..f898ba3d7 100644 --- a/metagpt/provider/postprecess/llm_output_postprecess.py +++ b/metagpt/provider/postprocess/llm_output_postprocess.py @@ -4,17 +4,17 @@ from typing import Union -from metagpt.provider.postprecess.base_postprecess_plugin import BasePostPrecessPlugin +from metagpt.provider.postprocess.base_postprocess_plugin import BasePostProcessPlugin -def llm_output_postprecess( +def llm_output_postprocess( output: str, schema: dict, req_key: str = "[/CONTENT]", model_name: str = None ) -> Union[dict, str]: """ - default use BasePostPrecessPlugin if there is not matched plugin. + default use BasePostProcessPlugin if there is not matched plugin. """ # TODO choose different model's plugin according to the model_name - postprecess_plugin = BasePostPrecessPlugin() + postprocess_plugin = BasePostProcessPlugin() - result = postprecess_plugin.run(output=output, schema=schema, req_key=req_key) + result = postprocess_plugin.run(output=output, schema=schema, req_key=req_key) return result diff --git a/metagpt/provider/zhipuai/zhipu_model_api.py b/metagpt/provider/zhipuai/zhipu_model_api.py index 19eb52530..72be0f333 100644 --- a/metagpt/provider/zhipuai/zhipu_model_api.py +++ b/metagpt/provider/zhipuai/zhipu_model_api.py @@ -33,7 +33,7 @@ class ZhiPuModelAPI(ModelAPI): zhipu_api_url: https://open.bigmodel.cn/api/paas/v3/model-api/{model}/{invoke_method} """ arr = zhipu_api_url.split("/api/") - # ("https://open.bigmodel.cn/api/" , "/paas/v3/model-api/chatglm_turbo/invoke") + # ("https://open.bigmodel.cn/api" , "/paas/v3/model-api/chatglm_turbo/invoke") return f"{arr[0]}/api", f"/{arr[1]}" @classmethod diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index cdc9c63e6..addbe58af 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -68,9 +68,6 @@ class ZhiPuAILLM(BaseLLM): except Exception as e: logger.error(f"zhipuai updats costs failed! exp: {e}") - def close(self): - pass - def get_choice_text(self, resp: dict) -> str: """get the first text of choice from llm response""" assist_msg = resp.get("data", {}).get("choices", [{"role": "error"}])[-1] diff --git a/tests/metagpt/provider/postprocess/__init__.py b/tests/metagpt/provider/postprocess/__init__.py new file mode 100644 index 000000000..2bcf8efd0 --- /dev/null +++ b/tests/metagpt/provider/postprocess/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/tests/metagpt/provider/postprocess/test_base_postprocess_plugin.py b/tests/metagpt/provider/postprocess/test_base_postprocess_plugin.py new file mode 100644 index 000000000..e63e4ecfe --- /dev/null +++ b/tests/metagpt/provider/postprocess/test_base_postprocess_plugin.py @@ -0,0 +1,38 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +import pytest + +from metagpt.provider.postprocess.base_postprocess_plugin import BasePostProcessPlugin + +raw_output = """ +[CONTENT] +{ +"Original Requirements": "xxx" +} +[/CONTENT] +""" +raw_schema = { + "title":"prd", + "type":"object", + "properties":{ + "Original Requirements":{ + "title":"Original Requirements", + "type":"string" + }, + }, + "required":[ + "Original Requirements", + ] + } + + +def test_llm_post_process_plugin(): + post_process_plugin = BasePostProcessPlugin() + + output = post_process_plugin.run( + output=raw_output, + schema=raw_schema + ) + assert "Original Requirements" in output diff --git a/tests/metagpt/provider/postprocess/test_llm_output_postprocess.py b/tests/metagpt/provider/postprocess/test_llm_output_postprocess.py new file mode 100644 index 000000000..3cb627216 --- /dev/null +++ b/tests/metagpt/provider/postprocess/test_llm_output_postprocess.py @@ -0,0 +1,14 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +import pytest + +from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess + +from tests.metagpt.provider.postprocess.test_base_postprocess_plugin import raw_output, raw_schema + + +def test_llm_output_postprocess(): + output = llm_output_postprocess(output=raw_output, schema=raw_schema) + assert "Original Requirements" in output diff --git a/tests/metagpt/provider/test_anthropic_api.py b/tests/metagpt/provider/test_anthropic_api.py index 4d3de5320..4410717a9 100644 --- a/tests/metagpt/provider/test_anthropic_api.py +++ b/tests/metagpt/provider/test_anthropic_api.py @@ -2,28 +2,33 @@ # -*- coding: utf-8 -*- # @Desc : the unittest of Claude2 -import pytest +import pytest +from anthropic.resources.completions import Completion + +from metagpt.config import CONFIG from metagpt.provider.anthropic_api import Claude2 +CONFIG.anthropic_api_key = "xxx" + prompt = "who are you" resp = "I'am Claude2" -def mock_llm_ask(self, msg: str) -> str: - return resp +def mock_anthropic_completions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion: + return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion") -async def mock_llm_aask(self, msg: str) -> str: - return resp +async def mock_anthropic_acompletions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion: + return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion") def test_claude2_ask(mocker): - mocker.patch("metagpt.provider.anthropic_api.Claude2.ask", mock_llm_ask) + mocker.patch("anthropic.resources.completions.Completions.create", mock_anthropic_completions_create) assert resp == Claude2().ask(prompt) @pytest.mark.asyncio async def test_claude2_aask(mocker): - mocker.patch("metagpt.provider.anthropic_api.Claude2.aask", mock_llm_aask) + mocker.patch("anthropic.resources.completions.AsyncCompletions.create", mock_anthropic_acompletions_create) assert resp == await Claude2().aask(prompt) diff --git a/tests/metagpt/provider/test_azure_openai_api.py b/tests/metagpt/provider/test_azure_openai_api.py new file mode 100644 index 000000000..208e3104a --- /dev/null +++ b/tests/metagpt/provider/test_azure_openai_api.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +import pytest + +from metagpt.provider.azure_openai_api import AzureOpenAILLM +from metagpt.config import CONFIG + +CONFIG.OPENAI_API_VERSION = "xx" +CONFIG.openai_proxy = "http://127.0.0.1:80" # fake value + + +def test_azure_openai_api(): + _ = AzureOpenAILLM() diff --git a/tests/metagpt/provider/test_fireworks_api.py b/tests/metagpt/provider/test_fireworks_api.py index 496465e5f..d48686eaa 100644 --- a/tests/metagpt/provider/test_fireworks_api.py +++ b/tests/metagpt/provider/test_fireworks_api.py @@ -8,6 +8,9 @@ from openai.types.chat.chat_completion import ( ChatCompletionMessage, Choice, ) +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk +from openai.types.chat.chat_completion_chunk import Choice as AChoice +from openai.types.chat.chat_completion_chunk import ChoiceDelta from openai.types.completion_usage import CompletionUsage from metagpt.config import CONFIG @@ -16,8 +19,11 @@ from metagpt.provider.fireworks_api import ( FireworksCostManager, FireworksLLM, ) +from metagpt.utils.cost_manager import Costs CONFIG.fireworks_api_key = "xxx" +CONFIG.max_budget = 10 +CONFIG.calc_usage = True resp_content = "I'm fireworks" default_resp = ChatCompletion( @@ -36,6 +42,22 @@ default_resp = ChatCompletion( usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202), ) +default_resp_chunk = ChatCompletionChunk( + id=default_resp.id, + model=default_resp.model, + object="chat.completion.chunk", + created=default_resp.created, + choices=[ + AChoice( + delta=ChoiceDelta(content=resp_content, role="assistant"), + finish_reason="stop", + index=0, + logprobs=None, + ) + ], + usage=dict(default_resp.usage), +) + prompt_msg = "who are you" messages = [{"role": "user", "content": prompt_msg}] @@ -50,29 +72,37 @@ def test_fireworks_costmanager(): assert MODEL_GRADE_TOKEN_COSTS["80"] == cost_manager.model_grade_token_costs("xxx-80b-chat") assert MODEL_GRADE_TOKEN_COSTS["mixtral-8x7b"] == cost_manager.model_grade_token_costs("mixtral-8x7b-chat") - -def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> ChatCompletion: - return default_resp + cost_manager.update_cost(prompt_tokens=500000, completion_tokens=500000, model="llama-v2-13b-chat") + assert cost_manager.total_cost == 0.5 -async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> ChatCompletion: - return default_resp +async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk: + if stream: + class Iterator(object): + async def __aiter__(self): + yield default_resp_chunk -async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: - return default_resp.choices[0].message.content + return Iterator() + else: + return default_resp @pytest.mark.asyncio async def test_fireworks_acompletion(mocker): - mocker.patch("metagpt.provider.fireworks_api.FireworksLLM.acompletion", mock_llm_acompletion) - mocker.patch("metagpt.provider.fireworks_api.FireworksLLM._achat_completion", mock_llm_acompletion) - mocker.patch( - "metagpt.provider.fireworks_api.FireworksLLM._achat_completion_stream", mock_llm_achat_completion_stream - ) - fireworks_gpt = FireworksLLM() + mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create) - resp = await fireworks_gpt.acompletion(messages, stream=False) + fireworks_gpt = FireworksLLM() + fireworks_gpt.model = "llama-v2-13b-chat" + + fireworks_gpt._update_costs( + usage=CompletionUsage(prompt_tokens=500000, completion_tokens=500000, total_tokens=1000000) + ) + assert fireworks_gpt.get_costs() == Costs( + total_prompt_tokens=500000, total_completion_tokens=500000, total_cost=0.5, total_budget=0 + ) + + resp = await fireworks_gpt.acompletion(messages) assert resp.choices[0].message.content in resp_content resp = await fireworks_gpt.aask(prompt_msg, stream=False) diff --git a/tests/metagpt/provider/test_general_api_base.py b/tests/metagpt/provider/test_general_api_base.py new file mode 100644 index 000000000..52ba32f01 --- /dev/null +++ b/tests/metagpt/provider/test_general_api_base.py @@ -0,0 +1,84 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +import pytest +import os +import requests +import aiohttp +from typing import Iterator, Tuple, Union, Generator, AsyncGenerator + +from openai import OpenAIError +from metagpt.provider.general_api_base import ApiType, log_debug, log_info, log_warn, OpenAIResponse, \ + _requests_proxies_arg, _aiohttp_proxies_arg, _make_session, parse_stream_helper, parse_stream, APIRequestor + + +def test_basic(): + _ = ApiType.from_str("azure") + _ = ApiType.from_str("azuread") + _ = ApiType.from_str("openai") + with pytest.raises(OpenAIError): + _ = ApiType.from_str("xx") + + os.environ.setdefault("LLM_LOG", "debug") + log_debug("debug") + log_warn("warn") + log_info("info") + + +def test_openai_response(): + resp = OpenAIResponse(data=[], headers={"retry-after": 3}) + assert resp.request_id is None + assert resp.retry_after == 3 + assert resp.operation_location is None + assert resp.organization is None + assert resp.response_ms is None + + +def test_proxy(): + assert _requests_proxies_arg(proxy=None) is None + + proxy = "127.0.0.1:80" + assert _requests_proxies_arg(proxy=proxy) == {"http": proxy, "https": proxy} + proxy_dict = {"http": proxy} + assert _requests_proxies_arg(proxy=proxy_dict) == proxy_dict + proxy_dict = {"https": proxy} + assert _requests_proxies_arg(proxy=proxy_dict) == proxy_dict + + assert _make_session() is not None + + +def test_parse_stream(): + assert parse_stream_helper(None) is None + assert parse_stream_helper(b"data: [DONE]") is None + assert parse_stream_helper(b"data: test") == "test" + assert parse_stream_helper(b"test") is None + for line in parse_stream([b"data: test"]): + assert line == "test" + + +api_requestor = APIRequestor(base_url="http://www.baidu.com") + + +def mock_interpret_response(self, result: requests.Response, stream: bool + ) -> Tuple[Union[bytes, Iterator[Generator]], bytes]: + return b"baidu", False + + +async def mock_interpret_async_response(self, result: aiohttp.ClientResponse, stream: bool + ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]: + return b"baidu", True + + +def test_api_requestor(mocker): + mocker.patch("metagpt.provider.general_api_base.APIRequestor._interpret_response", mock_interpret_response) + resp, _, _ = api_requestor.request(method="get", url="/s?wd=baidu") + + resp, _, _ = api_requestor.request(method="post", url="/s?wd=baidu") + + +@pytest.mark.asyncio +async def test_async_api_requestor(mocker): + mocker.patch("metagpt.provider.general_api_base.APIRequestor._interpret_async_response", mock_interpret_async_response) + resp, _, _ = await api_requestor.arequest(method="get", url="/s?wd=baidu") + resp, _, _ = await api_requestor.arequest(method="post", url="/s?wd=baidu") diff --git a/tests/metagpt/provider/test_general_api_requestor.py b/tests/metagpt/provider/test_general_api_requestor.py index 28130fa65..dcbcc0567 100644 --- a/tests/metagpt/provider/test_general_api_requestor.py +++ b/tests/metagpt/provider/test_general_api_requestor.py @@ -4,11 +4,24 @@ import pytest -from metagpt.provider.general_api_requestor import GeneralAPIRequestor +from metagpt.provider.general_api_requestor import ( + GeneralAPIRequestor, + parse_stream, + parse_stream_helper, +) api_requestor = GeneralAPIRequestor(base_url="http://www.baidu.com") +def test_parse_stream(): + assert parse_stream_helper(None) is None + assert parse_stream_helper(b"data: [DONE]") is None + assert parse_stream_helper(b"data: test") == b"test" + assert parse_stream_helper(b"test") is None + for line in parse_stream([b"data: test"]): + assert line == b"test" + + def test_api_requestor(): resp, _, _ = api_requestor.request(method="get", url="/s?wd=baidu") assert b"baidu" in resp diff --git a/tests/metagpt/provider/test_google_gemini_api.py b/tests/metagpt/provider/test_google_gemini_api.py index 7e372634c..ffd10df7f 100644 --- a/tests/metagpt/provider/test_google_gemini_api.py +++ b/tests/metagpt/provider/test_google_gemini_api.py @@ -6,9 +6,14 @@ from abc import ABC from dataclasses import dataclass import pytest +from google.ai import generativelanguage as glm +from google.generativeai.types import content_types +from metagpt.config import CONFIG from metagpt.provider.google_gemini_api import GeminiLLM +CONFIG.gemini_api_key = "xx" + @dataclass class MockGeminiResponse(ABC): @@ -21,29 +26,53 @@ resp_content = "I'm gemini from google" default_resp = MockGeminiResponse(text=resp_content) -def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> MockGeminiResponse: +def mock_gemini_count_tokens(self, contents: content_types.ContentsType) -> glm.CountTokensResponse: + return glm.CountTokensResponse(total_tokens=20) + + +async def mock_gemini_count_tokens_async(self, contents: content_types.ContentsType) -> glm.CountTokensResponse: + return glm.CountTokensResponse(total_tokens=20) + + +def mock_gemini_generate_content(self, **kwargs) -> MockGeminiResponse: return default_resp -async def mock_llm_acompletion( - self, messgaes: list[dict], stream: bool = False, timeout: int = 60 -) -> MockGeminiResponse: - return default_resp +async def mock_gemini_generate_content_async(self, stream: bool = False, **kwargs) -> MockGeminiResponse: + if stream: + class Iterator(object): + async def __aiter__(self): + yield default_resp -async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: - return resp_content + return Iterator() + else: + return default_resp @pytest.mark.asyncio async def test_gemini_acompletion(mocker): - mocker.patch("metagpt.provider.google_gemini_api.GeminiLLM.acompletion", mock_llm_acompletion) - mocker.patch("metagpt.provider.google_gemini_api.GeminiLLM._achat_completion", mock_llm_acompletion) + mocker.patch("metagpt.provider.google_gemini_api.GeminiGenerativeModel.count_tokens", mock_gemini_count_tokens) mocker.patch( - "metagpt.provider.google_gemini_api.GeminiLLM._achat_completion_stream", mock_llm_achat_completion_stream + "metagpt.provider.google_gemini_api.GeminiGenerativeModel.count_tokens_async", mock_gemini_count_tokens_async ) + mocker.patch("google.generativeai.generative_models.GenerativeModel.generate_content", mock_gemini_generate_content) + mocker.patch( + "google.generativeai.generative_models.GenerativeModel.generate_content_async", + mock_gemini_generate_content_async, + ) + gemini_gpt = GeminiLLM() + assert gemini_gpt._user_msg(prompt_msg) == {"role": "user", "parts": [prompt_msg]} + assert gemini_gpt._assistant_msg(prompt_msg) == {"role": "model", "parts": [prompt_msg]} + + usage = gemini_gpt.get_usage(messages, resp_content) + assert usage == {"prompt_tokens": 20, "completion_tokens": 20} + + resp = gemini_gpt.completion(messages) + assert resp == default_resp + resp = await gemini_gpt.acompletion(messages) assert resp.text == default_resp.text diff --git a/tests/metagpt/provider/test_ollama_api.py b/tests/metagpt/provider/test_ollama_api.py index ba019f295..1c604768e 100644 --- a/tests/metagpt/provider/test_ollama_api.py +++ b/tests/metagpt/provider/test_ollama_api.py @@ -2,6 +2,9 @@ # -*- coding: utf-8 -*- # @Desc : the unittest of ollama api +import json +from typing import Any, Tuple + import pytest from metagpt.config import CONFIG @@ -14,25 +17,33 @@ resp_content = "I'm ollama" default_resp = {"message": {"role": "assistant", "content": resp_content}} CONFIG.ollama_api_base = "http://xxx" +CONFIG.max_budget = 10 -def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> dict: - return default_resp +async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[Any, Any, bool]: + if stream: + class Iterator(object): + events = [ + b'{"message": {"role": "assistant", "content": "I\'m ollama"}, "done": false}', + b'{"prompt_eval_count": 20, "eval_count": 20, "done": true}', + ] -async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> dict: - return default_resp + async def __aiter__(self): + for event in self.events: + yield event - -async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: - return resp_content + return Iterator(), None, None + else: + raw_default_resp = default_resp.copy() + raw_default_resp.update({"prompt_eval_count": 20, "eval_count": 20}) + return json.dumps(raw_default_resp).encode(), None, None @pytest.mark.asyncio async def test_gemini_acompletion(mocker): - mocker.patch("metagpt.provider.ollama_api.OllamaLLM.acompletion", mock_llm_acompletion) - mocker.patch("metagpt.provider.ollama_api.OllamaLLM._achat_completion", mock_llm_acompletion) - mocker.patch("metagpt.provider.ollama_api.OllamaLLM._achat_completion_stream", mock_llm_achat_completion_stream) + mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_ollama_arequest) + ollama_gpt = OllamaLLM() resp = await ollama_gpt.acompletion(messages) diff --git a/tests/metagpt/provider/test_open_llm_api.py b/tests/metagpt/provider/test_open_llm_api.py new file mode 100644 index 000000000..bf094d54a --- /dev/null +++ b/tests/metagpt/provider/test_open_llm_api.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +import pytest + +from openai.types.chat.chat_completion import ( + ChatCompletion, + ChatCompletionMessage, + Choice, +) +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta, Choice as AChoice +from openai.types.completion_usage import CompletionUsage + +from metagpt.config import CONFIG +from metagpt.provider.open_llm_api import OpenLLM +from metagpt.utils.cost_manager import Costs + +CONFIG.max_budget = 10 +CONFIG.calc_usage = True + +resp_content = "I'm llama2" +default_resp = ChatCompletion( + id="cmpl-a6652c1bb181caae8dd19ad8", + model="llama-v2-13b-chat", + object="chat.completion", + created=1703302755, + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage(role="assistant", content=resp_content), + logprobs=None, + ) + ] +) + +default_resp_chunk = ChatCompletionChunk( + id=default_resp.id, + model=default_resp.model, + object="chat.completion.chunk", + created=default_resp.created, + choices=[ + AChoice( + delta=ChoiceDelta( + content=resp_content, + role="assistant" + ), + finish_reason="stop", + index=0, + logprobs=None, + ) + ] +) + +prompt_msg = "who are you" +messages = [{"role": "user", "content": prompt_msg}] + + +async def mock_openai_acompletions_create(self, stream: bool=False, **kwargs) -> ChatCompletionChunk: + if stream: + class Iterator(object): + async def __aiter__(self): + yield default_resp_chunk + return Iterator() + else: + return default_resp + + +@pytest.mark.asyncio +async def test_openllm_acompletion(mocker): + mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create) + + openllm_gpt = OpenLLM() + openllm_gpt.model = "llama-v2-13b-chat" + + openllm_gpt._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200)) + assert openllm_gpt.get_costs() == Costs(total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0) + + resp = await openllm_gpt.acompletion(messages) + assert resp.choices[0].message.content in resp_content + + resp = await openllm_gpt.aask(prompt_msg, stream=False) + assert resp == resp_content + + resp = await openllm_gpt.acompletion_text(messages, stream=False) + assert resp == resp_content + + resp = await openllm_gpt.acompletion_text(messages, stream=True) + assert resp == resp_content + + resp = await openllm_gpt.aask(prompt_msg) + assert resp == resp_content diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py index cb86dfcf9..ddc290731 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -2,9 +2,14 @@ from unittest.mock import Mock import pytest +from metagpt.config import CONFIG from metagpt.provider.openai_api import OpenAILLM from metagpt.schema import UserMessage +CONFIG.openai_proxy = None + +print("openai_api_key ", CONFIG.openai_api_key) + @pytest.mark.asyncio async def test_aask_code(): diff --git a/tests/metagpt/provider/test_spark_api.py b/tests/metagpt/provider/test_spark_api.py index e62c287c0..6d5a0e1f6 100644 --- a/tests/metagpt/provider/test_spark_api.py +++ b/tests/metagpt/provider/test_spark_api.py @@ -4,24 +4,31 @@ import pytest -from metagpt.provider.spark_api import SparkLLM +from metagpt.config import CONFIG +from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM + +CONFIG.spark_appid = "xxx" +CONFIG.spark_api_secret = "xxx" +CONFIG.spark_api_key = "xxx" +CONFIG.domain = "xxxxxx" +CONFIG.spark_url = "xxxx" prompt_msg = "who are you" resp_content = "I'm Spark" -def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> str: - return resp_content +def test_get_msg_from_web(): + get_msg_from_web = GetMessageFromWeb(text=prompt_msg) + assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "xxxxxx" -async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> str: +def mock_spark_get_msg_from_web_run(self) -> str: return resp_content @pytest.mark.asyncio async def test_spark_acompletion(mocker): - mocker.patch("metagpt.provider.spark_api.SparkLLM.acompletion", mock_llm_acompletion) - mocker.patch("metagpt.provider.spark_api.SparkLLM.acompletion_text", mock_llm_acompletion) + mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run) spark_gpt = SparkLLM() resp = await spark_gpt.acompletion([]) diff --git a/tests/metagpt/provider/test_zhipuai_api.py b/tests/metagpt/provider/test_zhipuai_api.py index c1af2f0be..826e706e8 100644 --- a/tests/metagpt/provider/test_zhipuai_api.py +++ b/tests/metagpt/provider/test_zhipuai_api.py @@ -3,36 +3,68 @@ # @Desc : the unittest of ZhiPuAILLM import pytest +from zhipuai.utils.sse_client import Event from metagpt.config import CONFIG from metagpt.provider.zhipuai_api import ZhiPuAILLM -CONFIG.zhipuai_api_key = "xxx" +CONFIG.zhipuai_api_key = "xxx.xxx" prompt_msg = "who are you" messages = [{"role": "user", "content": prompt_msg}] resp_content = "I'm chatglm-turbo" -default_resp = {"code": 200, "data": {"choices": [{"role": "assistant", "content": resp_content}]}} +default_resp = { + "code": 200, + "data": { + "choices": [{"role": "assistant", "content": resp_content}], + "usage": {"prompt_tokens": 20, "completion_tokens": 20}, + }, +} -def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> dict: +def mock_zhipuai_invoke(**kwargs) -> dict: return default_resp -async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> dict: +async def mock_zhipuai_ainvoke(**kwargs) -> dict: return default_resp -async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: - return resp_content +async def mock_zhipuai_asse_invoke(**kwargs): + class MockResponse(object): + async def _aread(self): + class Iterator(object): + events = [ + Event(id="xxx", event="add", data=resp_content, retry=0), + Event( + id="xxx", + event="finish", + data="", + meta='{"usage": {"completion_tokens": 20,"prompt_tokens": 20}}', + ), + ] + + async def __aiter__(self): + for event in self.events: + yield event + + async for chunk in Iterator(): + yield chunk + + async def async_events(self): + async for chunk in self._aread(): + yield chunk + + return MockResponse() @pytest.mark.asyncio async def test_zhipuai_acompletion(mocker): - mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAILLM.acompletion", mock_llm_acompletion) - mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAILLM._achat_completion", mock_llm_acompletion) - mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAILLM._achat_completion_stream", mock_llm_achat_completion_stream) + mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.invoke", mock_zhipuai_invoke) + mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.ainvoke", mock_zhipuai_ainvoke) + mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.asse_invoke", mock_zhipuai_asse_invoke) + zhipu_gpt = ZhiPuAILLM() resp = await zhipu_gpt.acompletion(messages) @@ -51,7 +83,7 @@ async def test_zhipuai_acompletion(mocker): assert resp == resp_content -def test_zhipuai_proxy(mocker): +def test_zhipuai_proxy(): import openai from metagpt.config import CONFIG diff --git a/tests/metagpt/provider/zhipuai/__init__.py b/tests/metagpt/provider/zhipuai/__init__.py new file mode 100644 index 000000000..2bcf8efd0 --- /dev/null +++ b/tests/metagpt/provider/zhipuai/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/tests/metagpt/provider/zhipuai/test_async_sse_client.py b/tests/metagpt/provider/zhipuai/test_async_sse_client.py new file mode 100644 index 000000000..af75e40df --- /dev/null +++ b/tests/metagpt/provider/zhipuai/test_async_sse_client.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +import pytest + +from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient + + +@pytest.mark.asyncio +async def test_async_sse_client(): + class Iterator(object): + async def __aiter__(self): + yield b"data: test_value" + + async_sse_client = AsyncSSEClient(event_source=Iterator()) + async for event in async_sse_client.async_events(): + assert event.data, "test_value" diff --git a/tests/metagpt/provider/zhipuai/test_zhipu_model_api.py b/tests/metagpt/provider/zhipuai/test_zhipu_model_api.py new file mode 100644 index 000000000..b3838e813 --- /dev/null +++ b/tests/metagpt/provider/zhipuai/test_zhipu_model_api.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +from typing import Any, Tuple +import pytest + +import zhipuai +from zhipuai.model_api.api import InvokeType +from zhipuai.utils.http_client import headers as zhipuai_default_headers + +from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI + + +api_key = "xxx.xxx" +zhipuai.api_key = api_key + +default_resp = {"result": "test response"} + + +async def mock_requestor_arequest(self, **kwargs) -> Tuple[Any, Any, str]: + return default_resp, None, None + + +@pytest.mark.asyncio +async def test_zhipu_model_api(mocker): + header = ZhiPuModelAPI.get_header() + zhipuai_default_headers.update({"Authorization": api_key}) + assert header == zhipuai_default_headers + + sse_header = ZhiPuModelAPI.get_sse_header() + assert len(sse_header["Authorization"]) == 191 + + url_prefix, url_suffix = ZhiPuModelAPI.split_zhipu_api_url(InvokeType.SYNC, kwargs={"model": "chatglm_turbo"}) + assert url_prefix == "https://open.bigmodel.cn/api" + assert url_suffix == "/paas/v3/model-api/chatglm_turbo/invoke" + + mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_requestor_arequest) + result = await ZhiPuModelAPI.arequest(InvokeType.SYNC, stream=False, method="get", headers={}, kwargs={"model": "chatglm_turbo"}) + assert result == default_resp From c8e351f3c863950d9d23b2f556d028227b53c2b1 Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 29 Dec 2023 02:45:54 +0800 Subject: [PATCH 4/9] format --- .../test_base_postprocess_plugin.py | 29 ++++++-------- .../test_llm_output_postprocess.py | 9 +++-- .../metagpt/provider/test_azure_openai_api.py | 5 +-- .../metagpt/provider/test_general_api_base.py | 39 +++++++++++++------ tests/metagpt/provider/test_open_llm_api.py | 24 ++++++------ .../provider/zhipuai/test_async_sse_client.py | 2 +- .../provider/zhipuai/test_zhipu_model_api.py | 7 ++-- 7 files changed, 63 insertions(+), 52 deletions(-) diff --git a/tests/metagpt/provider/postprocess/test_base_postprocess_plugin.py b/tests/metagpt/provider/postprocess/test_base_postprocess_plugin.py index e63e4ecfe..824bb88f3 100644 --- a/tests/metagpt/provider/postprocess/test_base_postprocess_plugin.py +++ b/tests/metagpt/provider/postprocess/test_base_postprocess_plugin.py @@ -1,8 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : +# @Desc : -import pytest from metagpt.provider.postprocess.base_postprocess_plugin import BasePostProcessPlugin @@ -14,25 +13,19 @@ raw_output = """ [/CONTENT] """ raw_schema = { - "title":"prd", - "type":"object", - "properties":{ - "Original Requirements":{ - "title":"Original Requirements", - "type":"string" - }, - }, - "required":[ - "Original Requirements", - ] - } + "title": "prd", + "type": "object", + "properties": { + "Original Requirements": {"title": "Original Requirements", "type": "string"}, + }, + "required": [ + "Original Requirements", + ], +} def test_llm_post_process_plugin(): post_process_plugin = BasePostProcessPlugin() - output = post_process_plugin.run( - output=raw_output, - schema=raw_schema - ) + output = post_process_plugin.run(output=raw_output, schema=raw_schema) assert "Original Requirements" in output diff --git a/tests/metagpt/provider/postprocess/test_llm_output_postprocess.py b/tests/metagpt/provider/postprocess/test_llm_output_postprocess.py index 3cb627216..40457b186 100644 --- a/tests/metagpt/provider/postprocess/test_llm_output_postprocess.py +++ b/tests/metagpt/provider/postprocess/test_llm_output_postprocess.py @@ -1,12 +1,13 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : +# @Desc : -import pytest from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess - -from tests.metagpt.provider.postprocess.test_base_postprocess_plugin import raw_output, raw_schema +from tests.metagpt.provider.postprocess.test_base_postprocess_plugin import ( + raw_output, + raw_schema, +) def test_llm_output_postprocess(): diff --git a/tests/metagpt/provider/test_azure_openai_api.py b/tests/metagpt/provider/test_azure_openai_api.py index 208e3104a..f36740e65 100644 --- a/tests/metagpt/provider/test_azure_openai_api.py +++ b/tests/metagpt/provider/test_azure_openai_api.py @@ -1,11 +1,10 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : +# @Desc : -import pytest -from metagpt.provider.azure_openai_api import AzureOpenAILLM from metagpt.config import CONFIG +from metagpt.provider.azure_openai_api import AzureOpenAILLM CONFIG.OPENAI_API_VERSION = "xx" CONFIG.openai_proxy = "http://127.0.0.1:80" # fake value diff --git a/tests/metagpt/provider/test_general_api_base.py b/tests/metagpt/provider/test_general_api_base.py index 52ba32f01..ae768ce95 100644 --- a/tests/metagpt/provider/test_general_api_base.py +++ b/tests/metagpt/provider/test_general_api_base.py @@ -1,16 +1,27 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : +# @Desc : -import pytest import os -import requests -import aiohttp -from typing import Iterator, Tuple, Union, Generator, AsyncGenerator +from typing import AsyncGenerator, Generator, Iterator, Tuple, Union +import aiohttp +import pytest +import requests from openai import OpenAIError -from metagpt.provider.general_api_base import ApiType, log_debug, log_info, log_warn, OpenAIResponse, \ - _requests_proxies_arg, _aiohttp_proxies_arg, _make_session, parse_stream_helper, parse_stream, APIRequestor + +from metagpt.provider.general_api_base import ( + APIRequestor, + ApiType, + OpenAIResponse, + _make_session, + _requests_proxies_arg, + log_debug, + log_info, + log_warn, + parse_stream, + parse_stream_helper, +) def test_basic(): @@ -60,13 +71,15 @@ def test_parse_stream(): api_requestor = APIRequestor(base_url="http://www.baidu.com") -def mock_interpret_response(self, result: requests.Response, stream: bool - ) -> Tuple[Union[bytes, Iterator[Generator]], bytes]: +def mock_interpret_response( + self, result: requests.Response, stream: bool +) -> Tuple[Union[bytes, Iterator[Generator]], bytes]: return b"baidu", False -async def mock_interpret_async_response(self, result: aiohttp.ClientResponse, stream: bool - ) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]: +async def mock_interpret_async_response( + self, result: aiohttp.ClientResponse, stream: bool +) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]: return b"baidu", True @@ -79,6 +92,8 @@ def test_api_requestor(mocker): @pytest.mark.asyncio async def test_async_api_requestor(mocker): - mocker.patch("metagpt.provider.general_api_base.APIRequestor._interpret_async_response", mock_interpret_async_response) + mocker.patch( + "metagpt.provider.general_api_base.APIRequestor._interpret_async_response", mock_interpret_async_response + ) resp, _, _ = await api_requestor.arequest(method="get", url="/s?wd=baidu") resp, _, _ = await api_requestor.arequest(method="post", url="/s?wd=baidu") diff --git a/tests/metagpt/provider/test_open_llm_api.py b/tests/metagpt/provider/test_open_llm_api.py index bf094d54a..85069c5e1 100644 --- a/tests/metagpt/provider/test_open_llm_api.py +++ b/tests/metagpt/provider/test_open_llm_api.py @@ -1,15 +1,16 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : +# @Desc : import pytest - from openai.types.chat.chat_completion import ( ChatCompletion, ChatCompletionMessage, Choice, ) -from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta, Choice as AChoice +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk +from openai.types.chat.chat_completion_chunk import Choice as AChoice +from openai.types.chat.chat_completion_chunk import ChoiceDelta from openai.types.completion_usage import CompletionUsage from metagpt.config import CONFIG @@ -32,7 +33,7 @@ default_resp = ChatCompletion( message=ChatCompletionMessage(role="assistant", content=resp_content), logprobs=None, ) - ] + ], ) default_resp_chunk = ChatCompletionChunk( @@ -42,26 +43,25 @@ default_resp_chunk = ChatCompletionChunk( created=default_resp.created, choices=[ AChoice( - delta=ChoiceDelta( - content=resp_content, - role="assistant" - ), + delta=ChoiceDelta(content=resp_content, role="assistant"), finish_reason="stop", index=0, logprobs=None, ) - ] + ], ) prompt_msg = "who are you" messages = [{"role": "user", "content": prompt_msg}] -async def mock_openai_acompletions_create(self, stream: bool=False, **kwargs) -> ChatCompletionChunk: +async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk: if stream: + class Iterator(object): async def __aiter__(self): yield default_resp_chunk + return Iterator() else: return default_resp @@ -75,7 +75,9 @@ async def test_openllm_acompletion(mocker): openllm_gpt.model = "llama-v2-13b-chat" openllm_gpt._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200)) - assert openllm_gpt.get_costs() == Costs(total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0) + assert openllm_gpt.get_costs() == Costs( + total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0 + ) resp = await openllm_gpt.acompletion(messages) assert resp.choices[0].message.content in resp_content diff --git a/tests/metagpt/provider/zhipuai/test_async_sse_client.py b/tests/metagpt/provider/zhipuai/test_async_sse_client.py index af75e40df..9e5bd5f2e 100644 --- a/tests/metagpt/provider/zhipuai/test_async_sse_client.py +++ b/tests/metagpt/provider/zhipuai/test_async_sse_client.py @@ -1,6 +1,6 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : +# @Desc : import pytest diff --git a/tests/metagpt/provider/zhipuai/test_zhipu_model_api.py b/tests/metagpt/provider/zhipuai/test_zhipu_model_api.py index b3838e813..83ae2de60 100644 --- a/tests/metagpt/provider/zhipuai/test_zhipu_model_api.py +++ b/tests/metagpt/provider/zhipuai/test_zhipu_model_api.py @@ -3,15 +3,14 @@ # @Desc : from typing import Any, Tuple -import pytest +import pytest import zhipuai from zhipuai.model_api.api import InvokeType from zhipuai.utils.http_client import headers as zhipuai_default_headers from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI - api_key = "xxx.xxx" zhipuai.api_key = api_key @@ -36,5 +35,7 @@ async def test_zhipu_model_api(mocker): assert url_suffix == "/paas/v3/model-api/chatglm_turbo/invoke" mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_requestor_arequest) - result = await ZhiPuModelAPI.arequest(InvokeType.SYNC, stream=False, method="get", headers={}, kwargs={"model": "chatglm_turbo"}) + result = await ZhiPuModelAPI.arequest( + InvokeType.SYNC, stream=False, method="get", headers={}, kwargs={"model": "chatglm_turbo"} + ) assert result == default_resp From edce4ac47afed589b81bd856b5cf3752ccd41329 Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 29 Dec 2023 03:10:23 +0800 Subject: [PATCH 5/9] fix memory unittest --- metagpt/memory/longterm_memory.py | 5 +++-- tests/metagpt/memory/test_longterm_memory.py | 4 ++++ tests/metagpt/memory/test_memory_storage.py | 4 ++-- 3 files changed, 9 insertions(+), 4 deletions(-) diff --git a/metagpt/memory/longterm_memory.py b/metagpt/memory/longterm_memory.py index 8da6ed84a..b54653970 100644 --- a/metagpt/memory/longterm_memory.py +++ b/metagpt/memory/longterm_memory.py @@ -12,6 +12,7 @@ from pydantic import ConfigDict, Field from metagpt.logs import logger from metagpt.memory import Memory from metagpt.memory.memory_storage import MemoryStorage +from metagpt.roles.role import RoleContext from metagpt.schema import Message @@ -25,10 +26,10 @@ class LongTermMemory(Memory): model_config = ConfigDict(arbitrary_types_allowed=True) memory_storage: MemoryStorage = Field(default_factory=MemoryStorage) - rc: Optional["RoleContext"] = None + rc: Optional[RoleContext] = None msg_from_recover: bool = False - def recover_memory(self, role_id: str, rc: "RoleContext"): + def recover_memory(self, role_id: str, rc: RoleContext): messages = self.memory_storage.recover_memory(role_id) self.rc = rc if not self.memory_storage.is_initialized: diff --git a/tests/metagpt/memory/test_longterm_memory.py b/tests/metagpt/memory/test_longterm_memory.py index ac33552b3..c915a6610 100644 --- a/tests/metagpt/memory/test_longterm_memory.py +++ b/tests/metagpt/memory/test_longterm_memory.py @@ -20,6 +20,10 @@ def test_ltm_search(): assert len(CONFIG.openai_api_key) > 20 role_id = "UTUserLtm(Product Manager)" + from metagpt.environment import Environment + + Environment + RoleContext.model_rebuild() rc = RoleContext(watch={"metagpt.actions.add_requirement.UserRequirement"}) ltm = LongTermMemory() ltm.recover_memory(role_id, rc) diff --git a/tests/metagpt/memory/test_memory_storage.py b/tests/metagpt/memory/test_memory_storage.py index f1cc12aac..0eb1069d5 100644 --- a/tests/metagpt/memory/test_memory_storage.py +++ b/tests/metagpt/memory/test_memory_storage.py @@ -24,7 +24,7 @@ def test_idea_message(): role_id = "UTUser1(Product Manager)" message = Message(role="User", content=idea, cause_by=UserRequirement) - shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/")) + shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"), ignore_errors=True) memory_storage: MemoryStorage = MemoryStorage() messages = memory_storage.recover_memory(role_id) @@ -58,7 +58,7 @@ def test_actionout_message(): content=content, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD ) # WritePRD as test action - shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/")) + shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"), ignore_errors=True) memory_storage: MemoryStorage = MemoryStorage() messages = memory_storage.recover_memory(role_id) From 311e48b6042dc0c525b3ba3f4dcf36d006bf95fd Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 29 Dec 2023 04:26:50 +0800 Subject: [PATCH 6/9] fix debate with send_to --- metagpt/roles/role.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 29f3b0595..81815e91b 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -394,7 +394,9 @@ class Role(SerializationMixin, is_polymorphic_base=True): old_messages = [] if ignore_memory else self.rc.memory.get() self.rc.memory.add_batch(news) # Filter out messages of interest. - self.rc.news = [n for n in news if n.cause_by in self.rc.watch and n not in old_messages] + self.rc.news = [ + n for n in news if (n.cause_by in self.rc.watch or self.name in n.send_to) and n not in old_messages + ] self.latest_observed_msg = self.rc.news[-1] if self.rc.news else None # record the latest observed msg # Design Rules: From e52957026be06c276251276196ffd8748e3a6efc Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 29 Dec 2023 04:27:44 +0800 Subject: [PATCH 7/9] update ser&deser unittest --- metagpt/actions/debug_error.py | 2 +- metagpt/actions/design_api.py | 2 +- metagpt/actions/design_api_review.py | 2 +- metagpt/actions/execute_task.py | 2 +- metagpt/actions/invoice_ocr.py | 2 +- metagpt/actions/prepare_documents.py | 2 +- metagpt/actions/project_management.py | 2 +- metagpt/actions/research.py | 2 +- metagpt/actions/run_code.py | 2 +- metagpt/actions/search_and_summarize.py | 2 +- metagpt/actions/summarize_code.py | 2 +- metagpt/actions/write_code.py | 2 +- metagpt/actions/write_code_review.py | 2 +- metagpt/actions/write_docstring.py | 2 +- metagpt/actions/write_prd.py | 2 +- metagpt/actions/write_prd_review.py | 2 +- metagpt/actions/write_review.py | 2 +- metagpt/actions/write_teaching_plan.py | 2 +- metagpt/actions/write_test.py | 2 +- metagpt/actions/write_tutorial.py | 2 +- metagpt/schema.py | 2 +- tests/metagpt/serialize_deserialize/test_action.py | 2 +- tests/metagpt/serialize_deserialize/test_environment.py | 3 ++- tests/metagpt/serialize_deserialize/test_write_code.py | 2 -- .../metagpt/serialize_deserialize/test_write_code_review.py | 2 -- tests/metagpt/serialize_deserialize/test_write_design.py | 3 --- tests/metagpt/serialize_deserialize/test_write_prd.py | 2 -- tests/metagpt/utils/test_serialize.py | 6 +++--- 28 files changed, 27 insertions(+), 35 deletions(-) diff --git a/metagpt/actions/debug_error.py b/metagpt/actions/debug_error.py index 1a7c3a7c8..710dff344 100644 --- a/metagpt/actions/debug_error.py +++ b/metagpt/actions/debug_error.py @@ -52,7 +52,7 @@ Now you should start rewriting the code: class DebugError(Action): name: str = "DebugError" context: RunCodeContext = Field(default_factory=RunCodeContext) - llm: BaseLLM = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM, exclude=True) async def run(self, *args, **kwargs) -> str: output_doc = await FileRepository.get_file( diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index 32e2a2a19..e8cf139e8 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -44,7 +44,7 @@ NEW_REQ_TEMPLATE = """ class WriteDesign(Action): name: str = "" context: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM, exclude=True) desc: str = ( "Based on the PRD, think about the system design, and design the corresponding APIs, " "data structures, library tables, processes, and paths. Please provide your design, feedback " diff --git a/metagpt/actions/design_api_review.py b/metagpt/actions/design_api_review.py index 6ea76e2fc..a9ae15ad8 100644 --- a/metagpt/actions/design_api_review.py +++ b/metagpt/actions/design_api_review.py @@ -18,7 +18,7 @@ from metagpt.provider.base_llm import BaseLLM class DesignReview(Action): name: str = "DesignReview" context: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM, exclude=True) async def run(self, prd, api_design): prompt = ( diff --git a/metagpt/actions/execute_task.py b/metagpt/actions/execute_task.py index 8577ee275..2c35b541d 100644 --- a/metagpt/actions/execute_task.py +++ b/metagpt/actions/execute_task.py @@ -17,7 +17,7 @@ from metagpt.schema import Message class ExecuteTask(Action): name: str = "ExecuteTask" context: list[Message] = [] - llm: BaseLLM = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM, exclude=True) async def run(self, *args, **kwargs): pass diff --git a/metagpt/actions/invoice_ocr.py b/metagpt/actions/invoice_ocr.py index 94288d5be..b9eb2c396 100644 --- a/metagpt/actions/invoice_ocr.py +++ b/metagpt/actions/invoice_ocr.py @@ -42,7 +42,7 @@ class InvoiceOCR(Action): name: str = "InvoiceOCR" context: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM, exclude=True) @staticmethod async def _check_file_type(file_path: Path) -> str: diff --git a/metagpt/actions/prepare_documents.py b/metagpt/actions/prepare_documents.py index 97d3828bf..5ca6877d4 100644 --- a/metagpt/actions/prepare_documents.py +++ b/metagpt/actions/prepare_documents.py @@ -28,7 +28,7 @@ class PrepareDocuments(Action): name: str = "PrepareDocuments" context: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM, exclude=True) def _init_repo(self): """Initialize the Git environment.""" diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index a53f13e4c..9c2ec8cda 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -43,7 +43,7 @@ NEW_REQ_TEMPLATE = """ class WriteTasks(Action): name: str = "CreateTasks" context: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM, exclude=True) async def run(self, with_messages, schema=CONFIG.prompt_schema): system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO) diff --git a/metagpt/actions/research.py b/metagpt/actions/research.py index a1535a723..41571f776 100644 --- a/metagpt/actions/research.py +++ b/metagpt/actions/research.py @@ -82,7 +82,7 @@ class CollectLinks(Action): name: str = "CollectLinks" context: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM, exclude=True) desc: str = "Collect links from a search engine." search_engine: SearchEngine = Field(default_factory=SearchEngine) diff --git a/metagpt/actions/run_code.py b/metagpt/actions/run_code.py index 22d345b85..5b9e26fa9 100644 --- a/metagpt/actions/run_code.py +++ b/metagpt/actions/run_code.py @@ -79,7 +79,7 @@ standard errors: class RunCode(Action): name: str = "RunCode" context: RunCodeContext = Field(default_factory=RunCodeContext) - llm: BaseLLM = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM, exclude=True) @classmethod @handle_exception diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index cd3ef7d77..e8276a79e 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -109,7 +109,7 @@ You are a member of a professional butler team and will provide helpful suggesti class SearchAndSummarize(Action): name: str = "" content: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM, exclude=True) config: None = Field(default_factory=Config) engine: Optional[SearchEngineType] = CONFIG.search_engine search_func: Optional[Any] = None diff --git a/metagpt/actions/summarize_code.py b/metagpt/actions/summarize_code.py index 4025e0964..5db7a7b0a 100644 --- a/metagpt/actions/summarize_code.py +++ b/metagpt/actions/summarize_code.py @@ -95,7 +95,7 @@ flowchart TB class SummarizeCode(Action): name: str = "SummarizeCode" context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext) - llm: BaseLLM = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM, exclude=True) @retry(stop=stop_after_attempt(2), wait=wait_random_exponential(min=1, max=60)) async def summarize_code(self, prompt): diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index e3086f03c..3e29a9494 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -90,7 +90,7 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc class WriteCode(Action): name: str = "WriteCode" context: Document = Field(default_factory=Document) - llm: BaseLLM = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM, exclude=True) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) async def write_code(self, prompt) -> str: diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index a8ed0fd01..138bde289 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -123,7 +123,7 @@ REWRITE_CODE_TEMPLATE = """ class WriteCodeReview(Action): name: str = "WriteCodeReview" context: CodingContext = Field(default_factory=CodingContext) - llm: BaseLLM = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM, exclude=True) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, filename): diff --git a/metagpt/actions/write_docstring.py b/metagpt/actions/write_docstring.py index 68856c360..462e2d077 100644 --- a/metagpt/actions/write_docstring.py +++ b/metagpt/actions/write_docstring.py @@ -163,7 +163,7 @@ class WriteDocstring(Action): desc: str = "Write docstring for code." context: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM, exclude=True) async def run( self, diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index 9cefb70d8..17b5573ae 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -68,7 +68,7 @@ NEW_REQ_TEMPLATE = """ class WritePRD(Action): name: str = "" content: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM, exclude=True) async def run(self, with_messages, schema=CONFIG.prompt_schema, *args, **kwargs) -> ActionOutput | Message: # Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are diff --git a/metagpt/actions/write_prd_review.py b/metagpt/actions/write_prd_review.py index 9199e7536..a332d24c3 100644 --- a/metagpt/actions/write_prd_review.py +++ b/metagpt/actions/write_prd_review.py @@ -18,7 +18,7 @@ from metagpt.provider.base_llm import BaseLLM class WritePRDReview(Action): name: str = "" context: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM, exclude=True) prd: Optional[str] = None desc: str = "Based on the PRD, conduct a PRD Review, providing clear and detailed feedback" diff --git a/metagpt/actions/write_review.py b/metagpt/actions/write_review.py index d116556ba..64b8450e9 100644 --- a/metagpt/actions/write_review.py +++ b/metagpt/actions/write_review.py @@ -38,7 +38,7 @@ class WriteReview(Action): """Write a review for the given context.""" name: str = "WriteReview" - llm: BaseLLM = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM, exclude=True) async def run(self, context): return await WRITE_REVIEW_NODE.fill(context=context, llm=self.llm, schema="json") diff --git a/metagpt/actions/write_teaching_plan.py b/metagpt/actions/write_teaching_plan.py index 888627294..dae553b79 100644 --- a/metagpt/actions/write_teaching_plan.py +++ b/metagpt/actions/write_teaching_plan.py @@ -20,7 +20,7 @@ class WriteTeachingPlanPart(Action): """Write Teaching Plan Part""" context: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM, exclude=True) topic: str = "" language: str = "Chinese" rsp: Optional[str] = None diff --git a/metagpt/actions/write_test.py b/metagpt/actions/write_test.py index 321d31420..5bff34017 100644 --- a/metagpt/actions/write_test.py +++ b/metagpt/actions/write_test.py @@ -45,7 +45,7 @@ you should correctly import the necessary classes based on these file locations! class WriteTest(Action): name: str = "WriteTest" context: Optional[TestingContext] = None - llm: BaseLLM = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM, exclude=True) async def write_code(self, prompt): code_rsp = await self._aask(prompt) diff --git a/metagpt/actions/write_tutorial.py b/metagpt/actions/write_tutorial.py index a2a324b41..67bc85eef 100644 --- a/metagpt/actions/write_tutorial.py +++ b/metagpt/actions/write_tutorial.py @@ -27,7 +27,7 @@ class WriteDirectory(Action): """ name: str = "WriteDirectory" - llm: BaseLLM = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM, exclude=True) language: str = "Chinese" async def run(self, topic: str, *args, **kwargs) -> Dict: diff --git a/metagpt/schema.py b/metagpt/schema.py index 41303ea46..5dde0ee46 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -174,7 +174,7 @@ class Message(BaseModel): role: str = "user" # system / user / assistant cause_by: str = Field(default="", validate_default=True) sent_from: str = Field(default="", validate_default=True) - send_to: set = Field(default={MESSAGE_ROUTE_TO_ALL}, validate_default=True) + send_to: set[str] = Field(default={MESSAGE_ROUTE_TO_ALL}, validate_default=True) @field_validator("id", mode="before") @classmethod diff --git a/tests/metagpt/serialize_deserialize/test_action.py b/tests/metagpt/serialize_deserialize/test_action.py index b3206696b..677988e2f 100644 --- a/tests/metagpt/serialize_deserialize/test_action.py +++ b/tests/metagpt/serialize_deserialize/test_action.py @@ -28,5 +28,5 @@ async def test_action_deserialize(): new_action = Action(**serialized_data) assert new_action.name == "" - assert new_action.llm == LLM() + assert isinstance(new_action.llm, type(LLM())) assert len(await new_action._aask("who are you")) > 0 diff --git a/tests/metagpt/serialize_deserialize/test_environment.py b/tests/metagpt/serialize_deserialize/test_environment.py index 557c3f4cd..5a68288a6 100644 --- a/tests/metagpt/serialize_deserialize/test_environment.py +++ b/tests/metagpt/serialize_deserialize/test_environment.py @@ -13,6 +13,7 @@ from metagpt.schema import Message from metagpt.utils.common import any_to_str from tests.metagpt.serialize_deserialize.test_serdeser_base import ( ActionOK, + ActionRaise, RoleC, serdeser_path, ) @@ -55,9 +56,9 @@ def test_environment_serdeser(): assert len(new_env.roles) == 1 assert list(new_env.roles.values())[0].states == list(environment.roles.values())[0].states - assert list(new_env.roles.values())[0].actions == list(environment.roles.values())[0].actions assert isinstance(list(environment.roles.values())[0].actions[0], ActionOK) assert type(list(new_env.roles.values())[0].actions[0]) == ActionOK + assert type(list(new_env.roles.values())[0].actions[1]) == ActionRaise def test_environment_serdeser_v2(): diff --git a/tests/metagpt/serialize_deserialize/test_write_code.py b/tests/metagpt/serialize_deserialize/test_write_code.py index 2fb669a6b..cb262bb45 100644 --- a/tests/metagpt/serialize_deserialize/test_write_code.py +++ b/tests/metagpt/serialize_deserialize/test_write_code.py @@ -6,7 +6,6 @@ import pytest from metagpt.actions import WriteCode -from metagpt.llm import LLM from metagpt.schema import CodingContext, Document @@ -28,5 +27,4 @@ async def test_write_code_deserialize(): new_action = WriteCode(**serialized_data) assert new_action.name == "WriteCode" - assert new_action.llm == LLM() await action.run() diff --git a/tests/metagpt/serialize_deserialize/test_write_code_review.py b/tests/metagpt/serialize_deserialize/test_write_code_review.py index e9ad4b858..991b3c13b 100644 --- a/tests/metagpt/serialize_deserialize/test_write_code_review.py +++ b/tests/metagpt/serialize_deserialize/test_write_code_review.py @@ -5,7 +5,6 @@ import pytest from metagpt.actions import WriteCodeReview -from metagpt.llm import LLM from metagpt.schema import CodingContext, Document @@ -28,5 +27,4 @@ def div(a: int, b: int = 0): new_action = WriteCodeReview(**serialized_data) assert new_action.name == "WriteCodeReview" - assert new_action.llm == LLM() await new_action.run() diff --git a/tests/metagpt/serialize_deserialize/test_write_design.py b/tests/metagpt/serialize_deserialize/test_write_design.py index d556c144d..a2fce8047 100644 --- a/tests/metagpt/serialize_deserialize/test_write_design.py +++ b/tests/metagpt/serialize_deserialize/test_write_design.py @@ -5,7 +5,6 @@ import pytest from metagpt.actions import WriteDesign, WriteTasks -from metagpt.llm import LLM def test_write_design_serialize(): @@ -28,7 +27,6 @@ async def test_write_design_deserialize(): serialized_data = action.model_dump() new_action = WriteDesign(**serialized_data) assert new_action.name == "" - assert new_action.llm == LLM() await new_action.run(with_messages="write a cli snake game") @@ -38,5 +36,4 @@ async def test_write_task_deserialize(): serialized_data = action.model_dump() new_action = WriteTasks(**serialized_data) assert new_action.name == "CreateTasks" - assert new_action.llm == LLM() await new_action.run(with_messages="write a cli snake game") diff --git a/tests/metagpt/serialize_deserialize/test_write_prd.py b/tests/metagpt/serialize_deserialize/test_write_prd.py index 79b9a8677..69238545f 100644 --- a/tests/metagpt/serialize_deserialize/test_write_prd.py +++ b/tests/metagpt/serialize_deserialize/test_write_prd.py @@ -6,7 +6,6 @@ import pytest from metagpt.actions import WritePRD -from metagpt.llm import LLM from metagpt.schema import Message @@ -23,6 +22,5 @@ async def test_action_deserialize(): serialized_data = action.model_dump() new_action = WritePRD(**serialized_data) assert new_action.name == "" - assert new_action.llm == LLM() action_output = await new_action.run(with_messages=Message(content="write a cli snake game")) assert len(action_output.content) > 0 diff --git a/tests/metagpt/utils/test_serialize.py b/tests/metagpt/utils/test_serialize.py index f027d53f8..0ba3a8d41 100644 --- a/tests/metagpt/utils/test_serialize.py +++ b/tests/metagpt/utils/test_serialize.py @@ -4,7 +4,7 @@ @Desc : the unittest of serialize """ -from typing import List, Tuple +from typing import List from metagpt.actions import WritePRD from metagpt.actions.action_node import ActionNode @@ -27,7 +27,7 @@ def test_actionoutout_schema_to_mapping(): "properties": {"field": {"title": "field", "type": "array", "items": {"type": "string"}}}, } mapping = actionoutout_schema_to_mapping(schema) - assert mapping["field"] == (List[str], ...) + assert mapping["field"] == (list[str], ...) schema = { "title": "test", @@ -46,7 +46,7 @@ def test_actionoutout_schema_to_mapping(): }, } mapping = actionoutout_schema_to_mapping(schema) - assert mapping["field"] == (List[Tuple[str, str]], ...) + assert mapping["field"] == (list[list[str]], ...) assert True, True From 65671a3bca0fffd5a3f8f2577cbe99a9254d3d67 Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 29 Dec 2023 09:22:31 +0800 Subject: [PATCH 8/9] no need to define new llm field in subclass again --- metagpt/actions/debug_error.py | 2 -- metagpt/actions/design_api.py | 5 ----- metagpt/actions/design_api_review.py | 5 ----- metagpt/actions/execute_task.py | 4 ---- metagpt/actions/invoice_ocr.py | 1 - metagpt/actions/prepare_documents.py | 5 ----- metagpt/actions/project_management.py | 5 ----- metagpt/actions/research.py | 1 - metagpt/actions/run_code.py | 2 -- metagpt/actions/search_and_summarize.py | 3 --- metagpt/actions/summarize_code.py | 2 -- metagpt/actions/write_code.py | 3 --- metagpt/actions/write_code_review.py | 3 --- metagpt/actions/write_docstring.py | 5 ----- metagpt/actions/write_prd.py | 5 ----- metagpt/actions/write_prd_review.py | 5 ----- metagpt/actions/write_review.py | 5 ----- metagpt/actions/write_teaching_plan.py | 5 ----- metagpt/actions/write_test.py | 5 ----- metagpt/actions/write_tutorial.py | 6 ------ 20 files changed, 77 deletions(-) diff --git a/metagpt/actions/debug_error.py b/metagpt/actions/debug_error.py index 710dff344..34f784072 100644 --- a/metagpt/actions/debug_error.py +++ b/metagpt/actions/debug_error.py @@ -15,7 +15,6 @@ from pydantic import Field from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO -from metagpt.llm import LLM, BaseLLM from metagpt.logs import logger from metagpt.schema import RunCodeContext, RunCodeResult from metagpt.utils.common import CodeParser @@ -52,7 +51,6 @@ Now you should start rewriting the code: class DebugError(Action): name: str = "DebugError" context: RunCodeContext = Field(default_factory=RunCodeContext) - llm: BaseLLM = Field(default_factory=LLM, exclude=True) async def run(self, *args, **kwargs) -> str: output_doc = await FileRepository.get_file( diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index e8cf139e8..03f3d7704 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -13,8 +13,6 @@ import json from pathlib import Path from typing import Optional -from pydantic import Field - from metagpt.actions import Action, ActionOutput from metagpt.actions.design_api_an import DESIGN_API_NODE from metagpt.config import CONFIG @@ -25,9 +23,7 @@ from metagpt.const import ( SYSTEM_DESIGN_FILE_REPO, SYSTEM_DESIGN_PDF_FILE_REPO, ) -from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Document, Documents, Message from metagpt.utils.file_repository import FileRepository from metagpt.utils.mermaid import mermaid_to_file @@ -44,7 +40,6 @@ NEW_REQ_TEMPLATE = """ class WriteDesign(Action): name: str = "" context: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM, exclude=True) desc: str = ( "Based on the PRD, think about the system design, and design the corresponding APIs, " "data structures, library tables, processes, and paths. Please provide your design, feedback " diff --git a/metagpt/actions/design_api_review.py b/metagpt/actions/design_api_review.py index a9ae15ad8..fb1b92d85 100644 --- a/metagpt/actions/design_api_review.py +++ b/metagpt/actions/design_api_review.py @@ -8,17 +8,12 @@ from typing import Optional -from pydantic import Field - from metagpt.actions.action import Action -from metagpt.llm import LLM -from metagpt.provider.base_llm import BaseLLM class DesignReview(Action): name: str = "DesignReview" context: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM, exclude=True) async def run(self, prd, api_design): prompt = ( diff --git a/metagpt/actions/execute_task.py b/metagpt/actions/execute_task.py index 2c35b541d..4ae4ee17b 100644 --- a/metagpt/actions/execute_task.py +++ b/metagpt/actions/execute_task.py @@ -6,18 +6,14 @@ @File : execute_task.py """ -from pydantic import Field from metagpt.actions import Action -from metagpt.llm import LLM -from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Message class ExecuteTask(Action): name: str = "ExecuteTask" context: list[Message] = [] - llm: BaseLLM = Field(default_factory=LLM, exclude=True) async def run(self, *args, **kwargs): pass diff --git a/metagpt/actions/invoice_ocr.py b/metagpt/actions/invoice_ocr.py index b9eb2c396..826d37ef7 100644 --- a/metagpt/actions/invoice_ocr.py +++ b/metagpt/actions/invoice_ocr.py @@ -42,7 +42,6 @@ class InvoiceOCR(Action): name: str = "InvoiceOCR" context: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM, exclude=True) @staticmethod async def _check_file_type(file_path: Path) -> str: diff --git a/metagpt/actions/prepare_documents.py b/metagpt/actions/prepare_documents.py index aa880b5be..a936ea655 100644 --- a/metagpt/actions/prepare_documents.py +++ b/metagpt/actions/prepare_documents.py @@ -11,13 +11,9 @@ import shutil from pathlib import Path from typing import Optional -from pydantic import Field - from metagpt.actions import Action, ActionOutput from metagpt.config import CONFIG from metagpt.const import DOCS_FILE_REPO, REQUIREMENT_FILENAME -from metagpt.llm import LLM -from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Document from metagpt.utils.file_repository import FileRepository from metagpt.utils.git_repository import GitRepository @@ -28,7 +24,6 @@ class PrepareDocuments(Action): name: str = "PrepareDocuments" context: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM, exclude=True) def _init_repo(self): """Initialize the Git environment.""" diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index 93c1a852d..b33f3426d 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -13,8 +13,6 @@ import json from typing import Optional -from pydantic import Field - from metagpt.actions import ActionOutput from metagpt.actions.action import Action from metagpt.actions.project_management_an import PM_NODE @@ -25,9 +23,7 @@ from metagpt.const import ( TASK_FILE_REPO, TASK_PDF_FILE_REPO, ) -from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Document, Documents from metagpt.utils.file_repository import FileRepository @@ -43,7 +39,6 @@ NEW_REQ_TEMPLATE = """ class WriteTasks(Action): name: str = "CreateTasks" context: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM, exclude=True) async def run(self, with_messages, schema=CONFIG.prompt_schema): system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO) diff --git a/metagpt/actions/research.py b/metagpt/actions/research.py index 875aa7192..90b08cb6a 100644 --- a/metagpt/actions/research.py +++ b/metagpt/actions/research.py @@ -82,7 +82,6 @@ class CollectLinks(Action): name: str = "CollectLinks" context: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM, exclude=True) desc: str = "Collect links from a search engine." search_engine: SearchEngine = Field(default_factory=SearchEngine) diff --git a/metagpt/actions/run_code.py b/metagpt/actions/run_code.py index 010cab4a8..30b06f1a6 100644 --- a/metagpt/actions/run_code.py +++ b/metagpt/actions/run_code.py @@ -22,7 +22,6 @@ from pydantic import Field from metagpt.actions.action import Action from metagpt.config import CONFIG -from metagpt.llm import LLM, BaseLLM from metagpt.logs import logger from metagpt.schema import RunCodeContext, RunCodeResult from metagpt.utils.exceptions import handle_exception @@ -79,7 +78,6 @@ standard errors: class RunCode(Action): name: str = "RunCode" context: RunCodeContext = Field(default_factory=RunCodeContext) - llm: BaseLLM = Field(default_factory=LLM, exclude=True) @classmethod async def run_text(cls, code) -> Tuple[str, str]: diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index e8276a79e..d2e361f73 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -12,9 +12,7 @@ from pydantic import Field, model_validator from metagpt.actions import Action from metagpt.config import CONFIG, Config -from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Message from metagpt.tools import SearchEngineType from metagpt.tools.search_engine import SearchEngine @@ -109,7 +107,6 @@ You are a member of a professional butler team and will provide helpful suggesti class SearchAndSummarize(Action): name: str = "" content: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM, exclude=True) config: None = Field(default_factory=Config) engine: Optional[SearchEngineType] = CONFIG.search_engine search_func: Optional[Any] = None diff --git a/metagpt/actions/summarize_code.py b/metagpt/actions/summarize_code.py index 5db7a7b0a..bdad546d7 100644 --- a/metagpt/actions/summarize_code.py +++ b/metagpt/actions/summarize_code.py @@ -13,7 +13,6 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO -from metagpt.llm import LLM, BaseLLM from metagpt.logs import logger from metagpt.schema import CodeSummarizeContext from metagpt.utils.file_repository import FileRepository @@ -95,7 +94,6 @@ flowchart TB class SummarizeCode(Action): name: str = "SummarizeCode" context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext) - llm: BaseLLM = Field(default_factory=LLM, exclude=True) @retry(stop=stop_after_attempt(2), wait=wait_random_exponential(min=1, max=60)) async def summarize_code(self, prompt): diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index 3e29a9494..25c4912c3 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -29,9 +29,7 @@ from metagpt.const import ( TASK_FILE_REPO, TEST_OUTPUTS_FILE_REPO, ) -from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_llm import BaseLLM from metagpt.schema import CodingContext, Document, RunCodeResult from metagpt.utils.common import CodeParser from metagpt.utils.file_repository import FileRepository @@ -90,7 +88,6 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc class WriteCode(Action): name: str = "WriteCode" context: Document = Field(default_factory=Document) - llm: BaseLLM = Field(default_factory=LLM, exclude=True) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) async def write_code(self, prompt) -> str: diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index 138bde289..a8c913573 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -14,9 +14,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions import WriteCode from metagpt.actions.action import Action from metagpt.config import CONFIG -from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_llm import BaseLLM from metagpt.schema import CodingContext from metagpt.utils.common import CodeParser @@ -123,7 +121,6 @@ REWRITE_CODE_TEMPLATE = """ class WriteCodeReview(Action): name: str = "WriteCodeReview" context: CodingContext = Field(default_factory=CodingContext) - llm: BaseLLM = Field(default_factory=LLM, exclude=True) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, filename): diff --git a/metagpt/actions/write_docstring.py b/metagpt/actions/write_docstring.py index f77226832..8b8335517 100644 --- a/metagpt/actions/write_docstring.py +++ b/metagpt/actions/write_docstring.py @@ -27,11 +27,7 @@ import ast from pathlib import Path from typing import Literal, Optional -from pydantic import Field - from metagpt.actions.action import Action -from metagpt.llm import LLM -from metagpt.provider.base_llm import BaseLLM from metagpt.utils.common import OutputParser, aread, awrite from metagpt.utils.pycst import merge_docstring @@ -166,7 +162,6 @@ class WriteDocstring(Action): desc: str = "Write docstring for code." context: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM, exclude=True) async def run( self, diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index df9b7549b..d51c0a7be 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -17,8 +17,6 @@ import json from pathlib import Path from typing import Optional -from pydantic import Field - from metagpt.actions import Action, ActionOutput from metagpt.actions.action_node import ActionNode from metagpt.actions.fix_bug import FixBug @@ -37,9 +35,7 @@ from metagpt.const import ( PRDS_FILE_REPO, REQUIREMENT_FILENAME, ) -from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_llm import BaseLLM from metagpt.schema import BugFixContext, Document, Documents, Message from metagpt.utils.common import CodeParser from metagpt.utils.file_repository import FileRepository @@ -68,7 +64,6 @@ NEW_REQ_TEMPLATE = """ class WritePRD(Action): name: str = "WritePRD" content: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM, exclude=True) async def run(self, with_messages, schema=CONFIG.prompt_schema, *args, **kwargs) -> ActionOutput | Message: # Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are diff --git a/metagpt/actions/write_prd_review.py b/metagpt/actions/write_prd_review.py index a332d24c3..2babe38db 100644 --- a/metagpt/actions/write_prd_review.py +++ b/metagpt/actions/write_prd_review.py @@ -8,17 +8,12 @@ from typing import Optional -from pydantic import Field - from metagpt.actions.action import Action -from metagpt.llm import LLM -from metagpt.provider.base_llm import BaseLLM class WritePRDReview(Action): name: str = "" context: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM, exclude=True) prd: Optional[str] = None desc: str = "Based on the PRD, conduct a PRD Review, providing clear and detailed feedback" diff --git a/metagpt/actions/write_review.py b/metagpt/actions/write_review.py index 64b8450e9..db8512946 100644 --- a/metagpt/actions/write_review.py +++ b/metagpt/actions/write_review.py @@ -6,12 +6,8 @@ """ from typing import List -from pydantic import Field - from metagpt.actions import Action from metagpt.actions.action_node import ActionNode -from metagpt.llm import LLM -from metagpt.provider.base_llm import BaseLLM REVIEW = ActionNode( key="Review", @@ -38,7 +34,6 @@ class WriteReview(Action): """Write a review for the given context.""" name: str = "WriteReview" - llm: BaseLLM = Field(default_factory=LLM, exclude=True) async def run(self, context): return await WRITE_REVIEW_NODE.fill(context=context, llm=self.llm, schema="json") diff --git a/metagpt/actions/write_teaching_plan.py b/metagpt/actions/write_teaching_plan.py index dae553b79..b824e055e 100644 --- a/metagpt/actions/write_teaching_plan.py +++ b/metagpt/actions/write_teaching_plan.py @@ -7,20 +7,15 @@ """ from typing import Optional -from pydantic import Field - from metagpt.actions import Action from metagpt.config import CONFIG -from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_llm import BaseLLM class WriteTeachingPlanPart(Action): """Write Teaching Plan Part""" context: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM, exclude=True) topic: str = "" language: str = "Chinese" rsp: Optional[str] = None diff --git a/metagpt/actions/write_test.py b/metagpt/actions/write_test.py index 5bff34017..0166f5417 100644 --- a/metagpt/actions/write_test.py +++ b/metagpt/actions/write_test.py @@ -10,14 +10,10 @@ from typing import Optional -from pydantic import Field - from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.const import TEST_CODES_FILE_REPO -from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Document, TestingContext from metagpt.utils.common import CodeParser @@ -45,7 +41,6 @@ you should correctly import the necessary classes based on these file locations! class WriteTest(Action): name: str = "WriteTest" context: Optional[TestingContext] = None - llm: BaseLLM = Field(default_factory=LLM, exclude=True) async def write_code(self, prompt): code_rsp = await self._aask(prompt) diff --git a/metagpt/actions/write_tutorial.py b/metagpt/actions/write_tutorial.py index 67bc85eef..184cd8573 100644 --- a/metagpt/actions/write_tutorial.py +++ b/metagpt/actions/write_tutorial.py @@ -9,12 +9,8 @@ from typing import Dict -from pydantic import Field - from metagpt.actions import Action -from metagpt.llm import LLM from metagpt.prompts.tutorial_assistant import CONTENT_PROMPT, DIRECTORY_PROMPT -from metagpt.provider.base_llm import BaseLLM from metagpt.utils.common import OutputParser @@ -27,7 +23,6 @@ class WriteDirectory(Action): """ name: str = "WriteDirectory" - llm: BaseLLM = Field(default_factory=LLM, exclude=True) language: str = "Chinese" async def run(self, topic: str, *args, **kwargs) -> Dict: @@ -54,7 +49,6 @@ class WriteContent(Action): """ name: str = "WriteContent" - llm: BaseLLM = Field(default_factory=LLM) directory: dict = dict() language: str = "Chinese" From 961fecf8c05bdf96d3078fa4e6112e4a3c0bbcff Mon Sep 17 00:00:00 2001 From: better629 Date: Fri, 29 Dec 2023 09:24:34 +0800 Subject: [PATCH 9/9] update write_prd ut --- tests/metagpt/serialize_deserialize/test_write_prd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/metagpt/serialize_deserialize/test_write_prd.py b/tests/metagpt/serialize_deserialize/test_write_prd.py index 69238545f..890e2438b 100644 --- a/tests/metagpt/serialize_deserialize/test_write_prd.py +++ b/tests/metagpt/serialize_deserialize/test_write_prd.py @@ -21,6 +21,6 @@ async def test_action_deserialize(): action = WritePRD() serialized_data = action.model_dump() new_action = WritePRD(**serialized_data) - assert new_action.name == "" + assert new_action.name == "WritePRD" action_output = await new_action.run(with_messages=Message(content="write a cli snake game")) assert len(action_output.content) > 0