diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 8488fbe4c..4c06d0d1d 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -17,7 +17,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.config import CONFIG from metagpt.llm import BaseLLM from metagpt.logs import logger -from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess +from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess from metagpt.utils.common import OutputParser, general_after_log TAG = "CONTENT" @@ -275,7 +275,7 @@ class ActionNode: output_class = self.create_model_class(output_class_name, output_data_mapping) if schema == "json": - parsed_data = llm_output_postprecess( + parsed_data = llm_output_postprocess( output=content, schema=output_class.model_json_schema(), req_key=f"[/{TAG}]" ) else: # using markdown parser diff --git a/metagpt/actions/debug_error.py b/metagpt/actions/debug_error.py index 1a7c3a7c8..34f784072 100644 --- a/metagpt/actions/debug_error.py +++ b/metagpt/actions/debug_error.py @@ -15,7 +15,6 @@ from pydantic import Field from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO -from metagpt.llm import LLM, BaseLLM from metagpt.logs import logger from metagpt.schema import RunCodeContext, RunCodeResult from metagpt.utils.common import CodeParser @@ -52,7 +51,6 @@ Now you should start rewriting the code: class DebugError(Action): name: str = "DebugError" context: RunCodeContext = Field(default_factory=RunCodeContext) - llm: BaseLLM = Field(default_factory=LLM) async def run(self, *args, **kwargs) -> str: output_doc = await FileRepository.get_file( diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index 32e2a2a19..03f3d7704 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -13,8 +13,6 @@ import json from pathlib import Path from typing import Optional -from pydantic import Field - from metagpt.actions import Action, ActionOutput from metagpt.actions.design_api_an import DESIGN_API_NODE from metagpt.config import CONFIG @@ -25,9 +23,7 @@ from metagpt.const import ( SYSTEM_DESIGN_FILE_REPO, SYSTEM_DESIGN_PDF_FILE_REPO, ) -from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Document, Documents, Message from metagpt.utils.file_repository import FileRepository from metagpt.utils.mermaid import mermaid_to_file @@ -44,7 +40,6 @@ NEW_REQ_TEMPLATE = """ class WriteDesign(Action): name: str = "" context: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM) desc: str = ( "Based on the PRD, think about the system design, and design the corresponding APIs, " "data structures, library tables, processes, and paths. Please provide your design, feedback " diff --git a/metagpt/actions/design_api_review.py b/metagpt/actions/design_api_review.py index 6ea76e2fc..fb1b92d85 100644 --- a/metagpt/actions/design_api_review.py +++ b/metagpt/actions/design_api_review.py @@ -8,17 +8,12 @@ from typing import Optional -from pydantic import Field - from metagpt.actions.action import Action -from metagpt.llm import LLM -from metagpt.provider.base_llm import BaseLLM class DesignReview(Action): name: str = "DesignReview" context: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM) async def run(self, prd, api_design): prompt = ( diff --git a/metagpt/actions/execute_task.py b/metagpt/actions/execute_task.py index 8577ee275..4ae4ee17b 100644 --- a/metagpt/actions/execute_task.py +++ b/metagpt/actions/execute_task.py @@ -6,18 +6,14 @@ @File : execute_task.py """ -from pydantic import Field from metagpt.actions import Action -from metagpt.llm import LLM -from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Message class ExecuteTask(Action): name: str = "ExecuteTask" context: list[Message] = [] - llm: BaseLLM = Field(default_factory=LLM) async def run(self, *args, **kwargs): pass diff --git a/metagpt/actions/invoice_ocr.py b/metagpt/actions/invoice_ocr.py index 94288d5be..826d37ef7 100644 --- a/metagpt/actions/invoice_ocr.py +++ b/metagpt/actions/invoice_ocr.py @@ -42,7 +42,6 @@ class InvoiceOCR(Action): name: str = "InvoiceOCR" context: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM) @staticmethod async def _check_file_type(file_path: Path) -> str: diff --git a/metagpt/actions/prepare_documents.py b/metagpt/actions/prepare_documents.py index c0aa9d9d6..a936ea655 100644 --- a/metagpt/actions/prepare_documents.py +++ b/metagpt/actions/prepare_documents.py @@ -11,13 +11,9 @@ import shutil from pathlib import Path from typing import Optional -from pydantic import Field - from metagpt.actions import Action, ActionOutput from metagpt.config import CONFIG from metagpt.const import DOCS_FILE_REPO, REQUIREMENT_FILENAME -from metagpt.llm import LLM -from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Document from metagpt.utils.file_repository import FileRepository from metagpt.utils.git_repository import GitRepository @@ -28,7 +24,6 @@ class PrepareDocuments(Action): name: str = "PrepareDocuments" context: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM) def _init_repo(self): """Initialize the Git environment.""" diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index a32bf6151..b33f3426d 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -13,8 +13,6 @@ import json from typing import Optional -from pydantic import Field - from metagpt.actions import ActionOutput from metagpt.actions.action import Action from metagpt.actions.project_management_an import PM_NODE @@ -25,9 +23,7 @@ from metagpt.const import ( TASK_FILE_REPO, TASK_PDF_FILE_REPO, ) -from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Document, Documents from metagpt.utils.file_repository import FileRepository @@ -43,7 +39,6 @@ NEW_REQ_TEMPLATE = """ class WriteTasks(Action): name: str = "CreateTasks" context: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM) async def run(self, with_messages, schema=CONFIG.prompt_schema): system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO) diff --git a/metagpt/actions/research.py b/metagpt/actions/research.py index 2e042ef83..90b08cb6a 100644 --- a/metagpt/actions/research.py +++ b/metagpt/actions/research.py @@ -82,7 +82,6 @@ class CollectLinks(Action): name: str = "CollectLinks" context: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM) desc: str = "Collect links from a search engine." search_engine: SearchEngine = Field(default_factory=SearchEngine) diff --git a/metagpt/actions/run_code.py b/metagpt/actions/run_code.py index d22aa47ce..30b06f1a6 100644 --- a/metagpt/actions/run_code.py +++ b/metagpt/actions/run_code.py @@ -22,7 +22,6 @@ from pydantic import Field from metagpt.actions.action import Action from metagpt.config import CONFIG -from metagpt.llm import LLM, BaseLLM from metagpt.logs import logger from metagpt.schema import RunCodeContext, RunCodeResult from metagpt.utils.exceptions import handle_exception @@ -79,7 +78,6 @@ standard errors: class RunCode(Action): name: str = "RunCode" context: RunCodeContext = Field(default_factory=RunCodeContext) - llm: BaseLLM = Field(default_factory=LLM) @classmethod async def run_text(cls, code) -> Tuple[str, str]: diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index cd3ef7d77..d2e361f73 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -12,9 +12,7 @@ from pydantic import Field, model_validator from metagpt.actions import Action from metagpt.config import CONFIG, Config -from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Message from metagpt.tools import SearchEngineType from metagpt.tools.search_engine import SearchEngine @@ -109,7 +107,6 @@ You are a member of a professional butler team and will provide helpful suggesti class SearchAndSummarize(Action): name: str = "" content: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM) config: None = Field(default_factory=Config) engine: Optional[SearchEngineType] = CONFIG.search_engine search_func: Optional[Any] = None diff --git a/metagpt/actions/summarize_code.py b/metagpt/actions/summarize_code.py index 4025e0964..bdad546d7 100644 --- a/metagpt/actions/summarize_code.py +++ b/metagpt/actions/summarize_code.py @@ -13,7 +13,6 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO -from metagpt.llm import LLM, BaseLLM from metagpt.logs import logger from metagpt.schema import CodeSummarizeContext from metagpt.utils.file_repository import FileRepository @@ -95,7 +94,6 @@ flowchart TB class SummarizeCode(Action): name: str = "SummarizeCode" context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext) - llm: BaseLLM = Field(default_factory=LLM) @retry(stop=stop_after_attempt(2), wait=wait_random_exponential(min=1, max=60)) async def summarize_code(self, prompt): diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index e3086f03c..25c4912c3 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -29,9 +29,7 @@ from metagpt.const import ( TASK_FILE_REPO, TEST_OUTPUTS_FILE_REPO, ) -from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_llm import BaseLLM from metagpt.schema import CodingContext, Document, RunCodeResult from metagpt.utils.common import CodeParser from metagpt.utils.file_repository import FileRepository @@ -90,7 +88,6 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc class WriteCode(Action): name: str = "WriteCode" context: Document = Field(default_factory=Document) - llm: BaseLLM = Field(default_factory=LLM) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) async def write_code(self, prompt) -> str: diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index a8ed0fd01..a8c913573 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -14,9 +14,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions import WriteCode from metagpt.actions.action import Action from metagpt.config import CONFIG -from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_llm import BaseLLM from metagpt.schema import CodingContext from metagpt.utils.common import CodeParser @@ -123,7 +121,6 @@ REWRITE_CODE_TEMPLATE = """ class WriteCodeReview(Action): name: str = "WriteCodeReview" context: CodingContext = Field(default_factory=CodingContext) - llm: BaseLLM = Field(default_factory=LLM) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, filename): diff --git a/metagpt/actions/write_docstring.py b/metagpt/actions/write_docstring.py index 728b49fab..8b8335517 100644 --- a/metagpt/actions/write_docstring.py +++ b/metagpt/actions/write_docstring.py @@ -27,11 +27,7 @@ import ast from pathlib import Path from typing import Literal, Optional -from pydantic import Field - from metagpt.actions.action import Action -from metagpt.llm import LLM -from metagpt.provider.base_llm import BaseLLM from metagpt.utils.common import OutputParser, aread, awrite from metagpt.utils.pycst import merge_docstring @@ -166,7 +162,6 @@ class WriteDocstring(Action): desc: str = "Write docstring for code." context: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM) async def run( self, diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index 8e4229991..d51c0a7be 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -17,8 +17,6 @@ import json from pathlib import Path from typing import Optional -from pydantic import Field - from metagpt.actions import Action, ActionOutput from metagpt.actions.action_node import ActionNode from metagpt.actions.fix_bug import FixBug @@ -37,9 +35,7 @@ from metagpt.const import ( PRDS_FILE_REPO, REQUIREMENT_FILENAME, ) -from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_llm import BaseLLM from metagpt.schema import BugFixContext, Document, Documents, Message from metagpt.utils.common import CodeParser from metagpt.utils.file_repository import FileRepository @@ -68,7 +64,6 @@ NEW_REQ_TEMPLATE = """ class WritePRD(Action): name: str = "WritePRD" content: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM) async def run(self, with_messages, schema=CONFIG.prompt_schema, *args, **kwargs) -> ActionOutput | Message: # Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are diff --git a/metagpt/actions/write_prd_review.py b/metagpt/actions/write_prd_review.py index 9199e7536..2babe38db 100644 --- a/metagpt/actions/write_prd_review.py +++ b/metagpt/actions/write_prd_review.py @@ -8,17 +8,12 @@ from typing import Optional -from pydantic import Field - from metagpt.actions.action import Action -from metagpt.llm import LLM -from metagpt.provider.base_llm import BaseLLM class WritePRDReview(Action): name: str = "" context: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM) prd: Optional[str] = None desc: str = "Based on the PRD, conduct a PRD Review, providing clear and detailed feedback" diff --git a/metagpt/actions/write_review.py b/metagpt/actions/write_review.py index d116556ba..db8512946 100644 --- a/metagpt/actions/write_review.py +++ b/metagpt/actions/write_review.py @@ -6,12 +6,8 @@ """ from typing import List -from pydantic import Field - from metagpt.actions import Action from metagpt.actions.action_node import ActionNode -from metagpt.llm import LLM -from metagpt.provider.base_llm import BaseLLM REVIEW = ActionNode( key="Review", @@ -38,7 +34,6 @@ class WriteReview(Action): """Write a review for the given context.""" name: str = "WriteReview" - llm: BaseLLM = Field(default_factory=LLM) async def run(self, context): return await WRITE_REVIEW_NODE.fill(context=context, llm=self.llm, schema="json") diff --git a/metagpt/actions/write_teaching_plan.py b/metagpt/actions/write_teaching_plan.py index 888627294..b824e055e 100644 --- a/metagpt/actions/write_teaching_plan.py +++ b/metagpt/actions/write_teaching_plan.py @@ -7,20 +7,15 @@ """ from typing import Optional -from pydantic import Field - from metagpt.actions import Action from metagpt.config import CONFIG -from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_llm import BaseLLM class WriteTeachingPlanPart(Action): """Write Teaching Plan Part""" context: Optional[str] = None - llm: BaseLLM = Field(default_factory=LLM) topic: str = "" language: str = "Chinese" rsp: Optional[str] = None diff --git a/metagpt/actions/write_test.py b/metagpt/actions/write_test.py index 321d31420..0166f5417 100644 --- a/metagpt/actions/write_test.py +++ b/metagpt/actions/write_test.py @@ -10,14 +10,10 @@ from typing import Optional -from pydantic import Field - from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.const import TEST_CODES_FILE_REPO -from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Document, TestingContext from metagpt.utils.common import CodeParser @@ -45,7 +41,6 @@ you should correctly import the necessary classes based on these file locations! class WriteTest(Action): name: str = "WriteTest" context: Optional[TestingContext] = None - llm: BaseLLM = Field(default_factory=LLM) async def write_code(self, prompt): code_rsp = await self._aask(prompt) diff --git a/metagpt/actions/write_tutorial.py b/metagpt/actions/write_tutorial.py index a2a324b41..184cd8573 100644 --- a/metagpt/actions/write_tutorial.py +++ b/metagpt/actions/write_tutorial.py @@ -9,12 +9,8 @@ from typing import Dict -from pydantic import Field - from metagpt.actions import Action -from metagpt.llm import LLM from metagpt.prompts.tutorial_assistant import CONTENT_PROMPT, DIRECTORY_PROMPT -from metagpt.provider.base_llm import BaseLLM from metagpt.utils.common import OutputParser @@ -27,7 +23,6 @@ class WriteDirectory(Action): """ name: str = "WriteDirectory" - llm: BaseLLM = Field(default_factory=LLM) language: str = "Chinese" async def run(self, topic: str, *args, **kwargs) -> Dict: @@ -54,7 +49,6 @@ class WriteContent(Action): """ name: str = "WriteContent" - llm: BaseLLM = Field(default_factory=LLM) directory: dict = dict() language: str = "Chinese" diff --git a/metagpt/memory/brain_memory.py b/metagpt/memory/brain_memory.py index 44451a9e6..fe6bf991d 100644 --- a/metagpt/memory/brain_memory.py +++ b/metagpt/memory/brain_memory.py @@ -17,7 +17,7 @@ from pydantic import BaseModel, Field from metagpt.config import CONFIG from metagpt.const import DEFAULT_LANGUAGE, DEFAULT_MAX_TOKENS, DEFAULT_TOKEN_SIZE from metagpt.logs import logger -from metagpt.provider import MetaGPTAPI +from metagpt.provider import MetaGPTLLM from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Message, SimpleMessage from metagpt.utils.redis import Redis @@ -123,7 +123,7 @@ class BrainMemory(BaseModel): return v async def summarize(self, llm, max_words=200, keep_language: bool = False, limit: int = -1, **kwargs): - if isinstance(llm, MetaGPTAPI): + if isinstance(llm, MetaGPTLLM): return await self._metagpt_summarize(max_words=max_words) self.llm = llm @@ -176,7 +176,7 @@ class BrainMemory(BaseModel): async def get_title(self, llm, max_words=5, **kwargs) -> str: """Generate text title""" - if isinstance(llm, MetaGPTAPI): + if isinstance(llm, MetaGPTLLM): return self.history[0].content if self.history else "New" summary = await self.summarize(llm=llm, max_words=500) @@ -191,7 +191,7 @@ class BrainMemory(BaseModel): return response async def is_related(self, text1, text2, llm): - if isinstance(llm, MetaGPTAPI): + if isinstance(llm, MetaGPTLLM): return await self._metagpt_is_related(text1=text1, text2=text2, llm=llm) return await self._openai_is_related(text1=text1, text2=text2, llm=llm) @@ -213,7 +213,7 @@ class BrainMemory(BaseModel): return result async def rewrite(self, sentence: str, context: str, llm): - if isinstance(llm, MetaGPTAPI): + if isinstance(llm, MetaGPTLLM): return await self._metagpt_rewrite(sentence=sentence, context=context, llm=llm) return await self._openai_rewrite(sentence=sentence, context=context, llm=llm) diff --git a/metagpt/memory/longterm_memory.py b/metagpt/memory/longterm_memory.py index 8da6ed84a..b54653970 100644 --- a/metagpt/memory/longterm_memory.py +++ b/metagpt/memory/longterm_memory.py @@ -12,6 +12,7 @@ from pydantic import ConfigDict, Field from metagpt.logs import logger from metagpt.memory import Memory from metagpt.memory.memory_storage import MemoryStorage +from metagpt.roles.role import RoleContext from metagpt.schema import Message @@ -25,10 +26,10 @@ class LongTermMemory(Memory): model_config = ConfigDict(arbitrary_types_allowed=True) memory_storage: MemoryStorage = Field(default_factory=MemoryStorage) - rc: Optional["RoleContext"] = None + rc: Optional[RoleContext] = None msg_from_recover: bool = False - def recover_memory(self, role_id: str, rc: "RoleContext"): + def recover_memory(self, role_id: str, rc: RoleContext): messages = self.memory_storage.recover_memory(role_id) self.rc = rc if not self.memory_storage.is_initialized: diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py index 36d585c94..28157a4e2 100644 --- a/metagpt/provider/__init__.py +++ b/metagpt/provider/__init__.py @@ -7,21 +7,21 @@ """ from metagpt.provider.fireworks_api import FireworksLLM -from metagpt.provider.google_gemini_api import GeminiGPTAPI +from metagpt.provider.google_gemini_api import GeminiLLM from metagpt.provider.ollama_api import OllamaLLM -from metagpt.provider.open_llm_api import OpenLLMGPTAPI +from metagpt.provider.open_llm_api import OpenLLM from metagpt.provider.openai_api import OpenAILLM -from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI +from metagpt.provider.zhipuai_api import ZhiPuAILLM from metagpt.provider.azure_openai_api import AzureOpenAILLM -from metagpt.provider.metagpt_api import MetaGPTAPI +from metagpt.provider.metagpt_api import MetaGPTLLM __all__ = [ "FireworksLLM", - "GeminiGPTAPI", - "OpenLLMGPTAPI", + "GeminiLLM", + "OpenLLM", "OpenAILLM", - "ZhiPuAIGPTAPI", + "ZhiPuAILLM", "AzureOpenAILLM", - "MetaGPTAPI", + "MetaGPTLLM", "OllamaLLM", ] diff --git a/metagpt/provider/general_api_base.py b/metagpt/provider/general_api_base.py index 814be2f67..bbe03774c 100644 --- a/metagpt/provider/general_api_base.py +++ b/metagpt/provider/general_api_base.py @@ -100,7 +100,7 @@ def log_info(message, **params): def log_warn(message, **params): msg = logfmt(dict(message=message, **params)) print(msg, file=sys.stderr) - logger.warn(msg) + logger.warning(msg) def logfmt(props): diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index f862e8084..795687773 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -42,7 +42,7 @@ class GeminiGenerativeModel(GenerativeModel): @register_provider(LLMProviderEnum.GEMINI) -class GeminiGPTAPI(BaseLLM): +class GeminiLLM(BaseLLM): """ Refs to `https://ai.google.dev/tutorials/python_quickstart` """ @@ -79,9 +79,6 @@ class GeminiGPTAPI(BaseLLM): except Exception as e: logger.error(f"google gemini updats costs failed! exp: {e}") - def close(self): - pass - def get_choice_text(self, resp: GenerateContentResponse) -> str: return resp.text diff --git a/metagpt/provider/metagpt_api.py b/metagpt/provider/metagpt_api.py index 2b7629895..69aa7f305 100644 --- a/metagpt/provider/metagpt_api.py +++ b/metagpt/provider/metagpt_api.py @@ -11,6 +11,6 @@ from metagpt.provider.llm_provider_registry import register_provider @register_provider(LLMProviderEnum.METAGPT) -class MetaGPTAPI(OpenAILLM): +class MetaGPTLLM(OpenAILLM): def __init__(self): super().__init__() diff --git a/metagpt/provider/open_llm_api.py b/metagpt/provider/open_llm_api.py index 976e95c57..7f5870702 100644 --- a/metagpt/provider/open_llm_api.py +++ b/metagpt/provider/open_llm_api.py @@ -31,11 +31,10 @@ class OpenLLMCostManager(CostManager): f"Max budget: ${CONFIG.max_budget:.3f} | reference " f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}" ) - CONFIG.total_cost = self.total_cost @register_provider(LLMProviderEnum.OPEN_LLM) -class OpenLLMGPTAPI(OpenAILLM): +class OpenLLM(OpenAILLM): def __init__(self): self.config: Config = CONFIG self.__init_openllm() diff --git a/metagpt/provider/postprecess/__init__.py b/metagpt/provider/postprocess/__init__.py similarity index 100% rename from metagpt/provider/postprecess/__init__.py rename to metagpt/provider/postprocess/__init__.py diff --git a/metagpt/provider/postprecess/base_postprecess_plugin.py b/metagpt/provider/postprocess/base_postprocess_plugin.py similarity index 98% rename from metagpt/provider/postprecess/base_postprecess_plugin.py rename to metagpt/provider/postprocess/base_postprocess_plugin.py index 46646be91..48130ede8 100644 --- a/metagpt/provider/postprecess/base_postprecess_plugin.py +++ b/metagpt/provider/postprocess/base_postprocess_plugin.py @@ -12,8 +12,8 @@ from metagpt.utils.repair_llm_raw_output import ( ) -class BasePostPrecessPlugin(object): - model = None # the plugin of the `model`, use to judge in `llm_postprecess` +class BasePostProcessPlugin(object): + model = None # the plugin of the `model`, use to judge in `llm_postprocess` def run_repair_llm_output(self, output: str, schema: dict, req_key: str = "[/CONTENT]") -> Union[dict, list]: """ diff --git a/metagpt/provider/postprecess/llm_output_postprecess.py b/metagpt/provider/postprocess/llm_output_postprocess.py similarity index 58% rename from metagpt/provider/postprecess/llm_output_postprecess.py rename to metagpt/provider/postprocess/llm_output_postprocess.py index 85405543d..f898ba3d7 100644 --- a/metagpt/provider/postprecess/llm_output_postprecess.py +++ b/metagpt/provider/postprocess/llm_output_postprocess.py @@ -4,17 +4,17 @@ from typing import Union -from metagpt.provider.postprecess.base_postprecess_plugin import BasePostPrecessPlugin +from metagpt.provider.postprocess.base_postprocess_plugin import BasePostProcessPlugin -def llm_output_postprecess( +def llm_output_postprocess( output: str, schema: dict, req_key: str = "[/CONTENT]", model_name: str = None ) -> Union[dict, str]: """ - default use BasePostPrecessPlugin if there is not matched plugin. + default use BasePostProcessPlugin if there is not matched plugin. """ # TODO choose different model's plugin according to the model_name - postprecess_plugin = BasePostPrecessPlugin() + postprocess_plugin = BasePostProcessPlugin() - result = postprecess_plugin.run(output=output, schema=schema, req_key=req_key) + result = postprocess_plugin.run(output=output, schema=schema, req_key=req_key) return result diff --git a/metagpt/provider/zhipuai/zhipu_model_api.py b/metagpt/provider/zhipuai/zhipu_model_api.py index 19eb52530..72be0f333 100644 --- a/metagpt/provider/zhipuai/zhipu_model_api.py +++ b/metagpt/provider/zhipuai/zhipu_model_api.py @@ -33,7 +33,7 @@ class ZhiPuModelAPI(ModelAPI): zhipu_api_url: https://open.bigmodel.cn/api/paas/v3/model-api/{model}/{invoke_method} """ arr = zhipu_api_url.split("/api/") - # ("https://open.bigmodel.cn/api/" , "/paas/v3/model-api/chatglm_turbo/invoke") + # ("https://open.bigmodel.cn/api" , "/paas/v3/model-api/chatglm_turbo/invoke") return f"{arr[0]}/api", f"/{arr[1]}" @classmethod diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index df8c330b8..addbe58af 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -5,6 +5,7 @@ import json from enum import Enum +import openai import zhipuai from requests import ConnectionError from tenacity import ( @@ -31,7 +32,7 @@ class ZhiPuEvent(Enum): @register_provider(LLMProviderEnum.ZHIPUAI) -class ZhiPuAIGPTAPI(BaseLLM): +class ZhiPuAILLM(BaseLLM): """ Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo` From now, there is only one model named `chatglm_turbo` @@ -67,9 +68,6 @@ class ZhiPuAIGPTAPI(BaseLLM): except Exception as e: logger.error(f"zhipuai updats costs failed! exp: {e}") - def close(self): - pass - def get_choice_text(self, resp: dict) -> str: """get the first text of choice from llm response""" assist_msg = resp.get("data", {}).get("choices", [{"role": "error"}])[-1] diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 29f3b0595..81815e91b 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -394,7 +394,9 @@ class Role(SerializationMixin, is_polymorphic_base=True): old_messages = [] if ignore_memory else self.rc.memory.get() self.rc.memory.add_batch(news) # Filter out messages of interest. - self.rc.news = [n for n in news if n.cause_by in self.rc.watch and n not in old_messages] + self.rc.news = [ + n for n in news if (n.cause_by in self.rc.watch or self.name in n.send_to) and n not in old_messages + ] self.latest_observed_msg = self.rc.news[-1] if self.rc.news else None # record the latest observed msg # Design Rules: diff --git a/metagpt/schema.py b/metagpt/schema.py index 41303ea46..5dde0ee46 100644 --- a/metagpt/schema.py +++ b/metagpt/schema.py @@ -174,7 +174,7 @@ class Message(BaseModel): role: str = "user" # system / user / assistant cause_by: str = Field(default="", validate_default=True) sent_from: str = Field(default="", validate_default=True) - send_to: set = Field(default={MESSAGE_ROUTE_TO_ALL}, validate_default=True) + send_to: set[str] = Field(default={MESSAGE_ROUTE_TO_ALL}, validate_default=True) @field_validator("id", mode="before") @classmethod diff --git a/tests/metagpt/memory/test_longterm_memory.py b/tests/metagpt/memory/test_longterm_memory.py index ac33552b3..c915a6610 100644 --- a/tests/metagpt/memory/test_longterm_memory.py +++ b/tests/metagpt/memory/test_longterm_memory.py @@ -20,6 +20,10 @@ def test_ltm_search(): assert len(CONFIG.openai_api_key) > 20 role_id = "UTUserLtm(Product Manager)" + from metagpt.environment import Environment + + Environment + RoleContext.model_rebuild() rc = RoleContext(watch={"metagpt.actions.add_requirement.UserRequirement"}) ltm = LongTermMemory() ltm.recover_memory(role_id, rc) diff --git a/tests/metagpt/memory/test_memory_storage.py b/tests/metagpt/memory/test_memory_storage.py index f1cc12aac..0eb1069d5 100644 --- a/tests/metagpt/memory/test_memory_storage.py +++ b/tests/metagpt/memory/test_memory_storage.py @@ -24,7 +24,7 @@ def test_idea_message(): role_id = "UTUser1(Product Manager)" message = Message(role="User", content=idea, cause_by=UserRequirement) - shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/")) + shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"), ignore_errors=True) memory_storage: MemoryStorage = MemoryStorage() messages = memory_storage.recover_memory(role_id) @@ -58,7 +58,7 @@ def test_actionout_message(): content=content, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD ) # WritePRD as test action - shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/")) + shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"), ignore_errors=True) memory_storage: MemoryStorage = MemoryStorage() messages = memory_storage.recover_memory(role_id) diff --git a/tests/metagpt/provider/postprocess/__init__.py b/tests/metagpt/provider/postprocess/__init__.py new file mode 100644 index 000000000..2bcf8efd0 --- /dev/null +++ b/tests/metagpt/provider/postprocess/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/tests/metagpt/provider/postprocess/test_base_postprocess_plugin.py b/tests/metagpt/provider/postprocess/test_base_postprocess_plugin.py new file mode 100644 index 000000000..824bb88f3 --- /dev/null +++ b/tests/metagpt/provider/postprocess/test_base_postprocess_plugin.py @@ -0,0 +1,31 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + + +from metagpt.provider.postprocess.base_postprocess_plugin import BasePostProcessPlugin + +raw_output = """ +[CONTENT] +{ +"Original Requirements": "xxx" +} +[/CONTENT] +""" +raw_schema = { + "title": "prd", + "type": "object", + "properties": { + "Original Requirements": {"title": "Original Requirements", "type": "string"}, + }, + "required": [ + "Original Requirements", + ], +} + + +def test_llm_post_process_plugin(): + post_process_plugin = BasePostProcessPlugin() + + output = post_process_plugin.run(output=raw_output, schema=raw_schema) + assert "Original Requirements" in output diff --git a/tests/metagpt/provider/postprocess/test_llm_output_postprocess.py b/tests/metagpt/provider/postprocess/test_llm_output_postprocess.py new file mode 100644 index 000000000..40457b186 --- /dev/null +++ b/tests/metagpt/provider/postprocess/test_llm_output_postprocess.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + + +from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess +from tests.metagpt.provider.postprocess.test_base_postprocess_plugin import ( + raw_output, + raw_schema, +) + + +def test_llm_output_postprocess(): + output = llm_output_postprocess(output=raw_output, schema=raw_schema) + assert "Original Requirements" in output diff --git a/tests/metagpt/provider/test_anthropic_api.py b/tests/metagpt/provider/test_anthropic_api.py index 4d3de5320..4410717a9 100644 --- a/tests/metagpt/provider/test_anthropic_api.py +++ b/tests/metagpt/provider/test_anthropic_api.py @@ -2,28 +2,33 @@ # -*- coding: utf-8 -*- # @Desc : the unittest of Claude2 -import pytest +import pytest +from anthropic.resources.completions import Completion + +from metagpt.config import CONFIG from metagpt.provider.anthropic_api import Claude2 +CONFIG.anthropic_api_key = "xxx" + prompt = "who are you" resp = "I'am Claude2" -def mock_llm_ask(self, msg: str) -> str: - return resp +def mock_anthropic_completions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion: + return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion") -async def mock_llm_aask(self, msg: str) -> str: - return resp +async def mock_anthropic_acompletions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion: + return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion") def test_claude2_ask(mocker): - mocker.patch("metagpt.provider.anthropic_api.Claude2.ask", mock_llm_ask) + mocker.patch("anthropic.resources.completions.Completions.create", mock_anthropic_completions_create) assert resp == Claude2().ask(prompt) @pytest.mark.asyncio async def test_claude2_aask(mocker): - mocker.patch("metagpt.provider.anthropic_api.Claude2.aask", mock_llm_aask) + mocker.patch("anthropic.resources.completions.AsyncCompletions.create", mock_anthropic_acompletions_create) assert resp == await Claude2().aask(prompt) diff --git a/tests/metagpt/provider/test_azure_openai_api.py b/tests/metagpt/provider/test_azure_openai_api.py index a1f1effeb..f36740e65 100644 --- a/tests/metagpt/provider/test_azure_openai_api.py +++ b/tests/metagpt/provider/test_azure_openai_api.py @@ -1,20 +1,14 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -""" -@Time : 2023/12/28 -@Author : mashenquan -@File : test_azure_openai.py -""" -from metagpt.config import CONFIG, LLMProviderEnum -from metagpt.llm import LLM +# @Desc : -def test_llm(): - # Prerequisites - assert CONFIG.DEPLOYMENT_NAME and CONFIG.DEPLOYMENT_NAME != "YOUR_DEPLOYMENT_NAME" - assert CONFIG.OPENAI_API_KEY and CONFIG.OPENAI_API_KEY != "YOUR_AZURE_API_KEY" - assert CONFIG.OPENAI_API_VERSION - assert CONFIG.OPENAI_BASE_URL +from metagpt.config import CONFIG +from metagpt.provider.azure_openai_api import AzureOpenAILLM - llm = LLM(provider=LLMProviderEnum.AZURE_OPENAI) - assert llm +CONFIG.OPENAI_API_VERSION = "xx" +CONFIG.openai_proxy = "http://127.0.0.1:80" # fake value + + +def test_azure_openai_api(): + _ = AzureOpenAILLM() diff --git a/tests/metagpt/provider/test_fireworks_api.py b/tests/metagpt/provider/test_fireworks_api.py index b7f728e73..d48686eaa 100644 --- a/tests/metagpt/provider/test_fireworks_api.py +++ b/tests/metagpt/provider/test_fireworks_api.py @@ -8,13 +8,22 @@ from openai.types.chat.chat_completion import ( ChatCompletionMessage, Choice, ) +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk +from openai.types.chat.chat_completion_chunk import Choice as AChoice +from openai.types.chat.chat_completion_chunk import ChoiceDelta from openai.types.completion_usage import CompletionUsage +from metagpt.config import CONFIG from metagpt.provider.fireworks_api import ( MODEL_GRADE_TOKEN_COSTS, FireworksCostManager, FireworksLLM, ) +from metagpt.utils.cost_manager import Costs + +CONFIG.fireworks_api_key = "xxx" +CONFIG.max_budget = 10 +CONFIG.calc_usage = True resp_content = "I'm fireworks" default_resp = ChatCompletion( @@ -25,14 +34,30 @@ default_resp = ChatCompletion( choices=[ Choice( finish_reason="stop", - logprobs=None, index=0, message=ChatCompletionMessage(role="assistant", content=resp_content), + logprobs=None, ) ], usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202), ) +default_resp_chunk = ChatCompletionChunk( + id=default_resp.id, + model=default_resp.model, + object="chat.completion.chunk", + created=default_resp.created, + choices=[ + AChoice( + delta=ChoiceDelta(content=resp_content, role="assistant"), + finish_reason="stop", + index=0, + logprobs=None, + ) + ], + usage=dict(default_resp.usage), +) + prompt_msg = "who are you" messages = [{"role": "user", "content": prompt_msg}] @@ -47,29 +72,37 @@ def test_fireworks_costmanager(): assert MODEL_GRADE_TOKEN_COSTS["80"] == cost_manager.model_grade_token_costs("xxx-80b-chat") assert MODEL_GRADE_TOKEN_COSTS["mixtral-8x7b"] == cost_manager.model_grade_token_costs("mixtral-8x7b-chat") - -def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> ChatCompletion: - return default_resp + cost_manager.update_cost(prompt_tokens=500000, completion_tokens=500000, model="llama-v2-13b-chat") + assert cost_manager.total_cost == 0.5 -async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> ChatCompletion: - return default_resp +async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk: + if stream: + class Iterator(object): + async def __aiter__(self): + yield default_resp_chunk -async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: - return default_resp.choices[0].message.content + return Iterator() + else: + return default_resp @pytest.mark.asyncio async def test_fireworks_acompletion(mocker): - mocker.patch("metagpt.provider.fireworks_api.FireworksLLM.acompletion", mock_llm_acompletion) - mocker.patch("metagpt.provider.fireworks_api.FireworksLLM._achat_completion", mock_llm_acompletion) - mocker.patch( - "metagpt.provider.fireworks_api.FireworksLLM._achat_completion_stream", mock_llm_achat_completion_stream - ) - fireworks_gpt = FireworksLLM() + mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create) - resp = await fireworks_gpt.acompletion(messages, stream=False) + fireworks_gpt = FireworksLLM() + fireworks_gpt.model = "llama-v2-13b-chat" + + fireworks_gpt._update_costs( + usage=CompletionUsage(prompt_tokens=500000, completion_tokens=500000, total_tokens=1000000) + ) + assert fireworks_gpt.get_costs() == Costs( + total_prompt_tokens=500000, total_completion_tokens=500000, total_cost=0.5, total_budget=0 + ) + + resp = await fireworks_gpt.acompletion(messages) assert resp.choices[0].message.content in resp_content resp = await fireworks_gpt.aask(prompt_msg, stream=False) diff --git a/tests/metagpt/provider/test_general_api_base.py b/tests/metagpt/provider/test_general_api_base.py new file mode 100644 index 000000000..ae768ce95 --- /dev/null +++ b/tests/metagpt/provider/test_general_api_base.py @@ -0,0 +1,99 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +import os +from typing import AsyncGenerator, Generator, Iterator, Tuple, Union + +import aiohttp +import pytest +import requests +from openai import OpenAIError + +from metagpt.provider.general_api_base import ( + APIRequestor, + ApiType, + OpenAIResponse, + _make_session, + _requests_proxies_arg, + log_debug, + log_info, + log_warn, + parse_stream, + parse_stream_helper, +) + + +def test_basic(): + _ = ApiType.from_str("azure") + _ = ApiType.from_str("azuread") + _ = ApiType.from_str("openai") + with pytest.raises(OpenAIError): + _ = ApiType.from_str("xx") + + os.environ.setdefault("LLM_LOG", "debug") + log_debug("debug") + log_warn("warn") + log_info("info") + + +def test_openai_response(): + resp = OpenAIResponse(data=[], headers={"retry-after": 3}) + assert resp.request_id is None + assert resp.retry_after == 3 + assert resp.operation_location is None + assert resp.organization is None + assert resp.response_ms is None + + +def test_proxy(): + assert _requests_proxies_arg(proxy=None) is None + + proxy = "127.0.0.1:80" + assert _requests_proxies_arg(proxy=proxy) == {"http": proxy, "https": proxy} + proxy_dict = {"http": proxy} + assert _requests_proxies_arg(proxy=proxy_dict) == proxy_dict + proxy_dict = {"https": proxy} + assert _requests_proxies_arg(proxy=proxy_dict) == proxy_dict + + assert _make_session() is not None + + +def test_parse_stream(): + assert parse_stream_helper(None) is None + assert parse_stream_helper(b"data: [DONE]") is None + assert parse_stream_helper(b"data: test") == "test" + assert parse_stream_helper(b"test") is None + for line in parse_stream([b"data: test"]): + assert line == "test" + + +api_requestor = APIRequestor(base_url="http://www.baidu.com") + + +def mock_interpret_response( + self, result: requests.Response, stream: bool +) -> Tuple[Union[bytes, Iterator[Generator]], bytes]: + return b"baidu", False + + +async def mock_interpret_async_response( + self, result: aiohttp.ClientResponse, stream: bool +) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]: + return b"baidu", True + + +def test_api_requestor(mocker): + mocker.patch("metagpt.provider.general_api_base.APIRequestor._interpret_response", mock_interpret_response) + resp, _, _ = api_requestor.request(method="get", url="/s?wd=baidu") + + resp, _, _ = api_requestor.request(method="post", url="/s?wd=baidu") + + +@pytest.mark.asyncio +async def test_async_api_requestor(mocker): + mocker.patch( + "metagpt.provider.general_api_base.APIRequestor._interpret_async_response", mock_interpret_async_response + ) + resp, _, _ = await api_requestor.arequest(method="get", url="/s?wd=baidu") + resp, _, _ = await api_requestor.arequest(method="post", url="/s?wd=baidu") diff --git a/tests/metagpt/provider/test_general_api_requestor.py b/tests/metagpt/provider/test_general_api_requestor.py index 28130fa65..dcbcc0567 100644 --- a/tests/metagpt/provider/test_general_api_requestor.py +++ b/tests/metagpt/provider/test_general_api_requestor.py @@ -4,11 +4,24 @@ import pytest -from metagpt.provider.general_api_requestor import GeneralAPIRequestor +from metagpt.provider.general_api_requestor import ( + GeneralAPIRequestor, + parse_stream, + parse_stream_helper, +) api_requestor = GeneralAPIRequestor(base_url="http://www.baidu.com") +def test_parse_stream(): + assert parse_stream_helper(None) is None + assert parse_stream_helper(b"data: [DONE]") is None + assert parse_stream_helper(b"data: test") == b"test" + assert parse_stream_helper(b"test") is None + for line in parse_stream([b"data: test"]): + assert line == b"test" + + def test_api_requestor(): resp, _, _ = api_requestor.request(method="get", url="/s?wd=baidu") assert b"baidu" in resp diff --git a/tests/metagpt/provider/test_google_gemini_api.py b/tests/metagpt/provider/test_google_gemini_api.py index 60f50c9ad..ffd10df7f 100644 --- a/tests/metagpt/provider/test_google_gemini_api.py +++ b/tests/metagpt/provider/test_google_gemini_api.py @@ -6,8 +6,13 @@ from abc import ABC from dataclasses import dataclass import pytest +from google.ai import generativelanguage as glm +from google.generativeai.types import content_types -from metagpt.provider.google_gemini_api import GeminiGPTAPI +from metagpt.config import CONFIG +from metagpt.provider.google_gemini_api import GeminiLLM + +CONFIG.gemini_api_key = "xx" @dataclass @@ -21,28 +26,52 @@ resp_content = "I'm gemini from google" default_resp = MockGeminiResponse(text=resp_content) -def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> MockGeminiResponse: +def mock_gemini_count_tokens(self, contents: content_types.ContentsType) -> glm.CountTokensResponse: + return glm.CountTokensResponse(total_tokens=20) + + +async def mock_gemini_count_tokens_async(self, contents: content_types.ContentsType) -> glm.CountTokensResponse: + return glm.CountTokensResponse(total_tokens=20) + + +def mock_gemini_generate_content(self, **kwargs) -> MockGeminiResponse: return default_resp -async def mock_llm_acompletion( - self, messgaes: list[dict], stream: bool = False, timeout: int = 60 -) -> MockGeminiResponse: - return default_resp +async def mock_gemini_generate_content_async(self, stream: bool = False, **kwargs) -> MockGeminiResponse: + if stream: + class Iterator(object): + async def __aiter__(self): + yield default_resp -async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: - return resp_content + return Iterator() + else: + return default_resp @pytest.mark.asyncio async def test_gemini_acompletion(mocker): - mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_acompletion) - mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI._achat_completion", mock_llm_acompletion) + mocker.patch("metagpt.provider.google_gemini_api.GeminiGenerativeModel.count_tokens", mock_gemini_count_tokens) mocker.patch( - "metagpt.provider.google_gemini_api.GeminiGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream + "metagpt.provider.google_gemini_api.GeminiGenerativeModel.count_tokens_async", mock_gemini_count_tokens_async ) - gemini_gpt = GeminiGPTAPI() + mocker.patch("google.generativeai.generative_models.GenerativeModel.generate_content", mock_gemini_generate_content) + mocker.patch( + "google.generativeai.generative_models.GenerativeModel.generate_content_async", + mock_gemini_generate_content_async, + ) + + gemini_gpt = GeminiLLM() + + assert gemini_gpt._user_msg(prompt_msg) == {"role": "user", "parts": [prompt_msg]} + assert gemini_gpt._assistant_msg(prompt_msg) == {"role": "model", "parts": [prompt_msg]} + + usage = gemini_gpt.get_usage(messages, resp_content) + assert usage == {"prompt_tokens": 20, "completion_tokens": 20} + + resp = gemini_gpt.completion(messages) + assert resp == default_resp resp = await gemini_gpt.acompletion(messages) assert resp.text == default_resp.text diff --git a/tests/metagpt/provider/test_metagpt_llm_api.py b/tests/metagpt/provider/test_metagpt_llm_api.py index f454b08a7..8fce6b6b0 100644 --- a/tests/metagpt/provider/test_metagpt_llm_api.py +++ b/tests/metagpt/provider/test_metagpt_llm_api.py @@ -5,11 +5,11 @@ @Author : mashenquan @File : test_metagpt_llm_api.py """ -from metagpt.provider.metagpt_api import MetaGPTAPI +from metagpt.provider.metagpt_api import MetaGPTLLM def test_metagpt(): - llm = MetaGPTAPI() + llm = MetaGPTLLM() assert llm diff --git a/tests/metagpt/provider/test_ollama_api.py b/tests/metagpt/provider/test_ollama_api.py index d19e23e17..1c604768e 100644 --- a/tests/metagpt/provider/test_ollama_api.py +++ b/tests/metagpt/provider/test_ollama_api.py @@ -2,6 +2,9 @@ # -*- coding: utf-8 -*- # @Desc : the unittest of ollama api +import json +from typing import Any, Tuple + import pytest from metagpt.config import CONFIG @@ -14,25 +17,33 @@ resp_content = "I'm ollama" default_resp = {"message": {"role": "assistant", "content": resp_content}} CONFIG.ollama_api_base = "http://xxx" +CONFIG.max_budget = 10 -def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> dict: - return default_resp +async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[Any, Any, bool]: + if stream: + class Iterator(object): + events = [ + b'{"message": {"role": "assistant", "content": "I\'m ollama"}, "done": false}', + b'{"prompt_eval_count": 20, "eval_count": 20, "done": true}', + ] -async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> dict: - return default_resp + async def __aiter__(self): + for event in self.events: + yield event - -async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: - return resp_content + return Iterator(), None, None + else: + raw_default_resp = default_resp.copy() + raw_default_resp.update({"prompt_eval_count": 20, "eval_count": 20}) + return json.dumps(raw_default_resp).encode(), None, None @pytest.mark.asyncio async def test_gemini_acompletion(mocker): - mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.acompletion", mock_llm_acompletion) - mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion", mock_llm_acompletion) - mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream) + mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_ollama_arequest) + ollama_gpt = OllamaLLM() resp = await ollama_gpt.acompletion(messages) diff --git a/tests/metagpt/provider/test_open_llm_api.py b/tests/metagpt/provider/test_open_llm_api.py index b8be68504..85069c5e1 100644 --- a/tests/metagpt/provider/test_open_llm_api.py +++ b/tests/metagpt/provider/test_open_llm_api.py @@ -1,25 +1,95 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -""" -@Time : 2023/12/28 -@Author : mashenquan -@File : test_open_llm_api.py -""" -from metagpt.config import CONFIG, LLMProviderEnum -from metagpt.llm import LLM -from metagpt.provider.open_llm_api import OpenLLMCostManager +# @Desc : + +import pytest +from openai.types.chat.chat_completion import ( + ChatCompletion, + ChatCompletionMessage, + Choice, +) +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk +from openai.types.chat.chat_completion_chunk import Choice as AChoice +from openai.types.chat.chat_completion_chunk import ChoiceDelta +from openai.types.completion_usage import CompletionUsage + +from metagpt.config import CONFIG +from metagpt.provider.open_llm_api import OpenLLM +from metagpt.utils.cost_manager import Costs + +CONFIG.max_budget = 10 +CONFIG.calc_usage = True + +resp_content = "I'm llama2" +default_resp = ChatCompletion( + id="cmpl-a6652c1bb181caae8dd19ad8", + model="llama-v2-13b-chat", + object="chat.completion", + created=1703302755, + choices=[ + Choice( + finish_reason="stop", + index=0, + message=ChatCompletionMessage(role="assistant", content=resp_content), + logprobs=None, + ) + ], +) + +default_resp_chunk = ChatCompletionChunk( + id=default_resp.id, + model=default_resp.model, + object="chat.completion.chunk", + created=default_resp.created, + choices=[ + AChoice( + delta=ChoiceDelta(content=resp_content, role="assistant"), + finish_reason="stop", + index=0, + logprobs=None, + ) + ], +) + +prompt_msg = "who are you" +messages = [{"role": "user", "content": prompt_msg}] -def test_llm(): - llm = LLM(provider=LLMProviderEnum.OPEN_LLM) - assert llm +async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk: + if stream: + + class Iterator(object): + async def __aiter__(self): + yield default_resp_chunk + + return Iterator() + else: + return default_resp -def test_cost(): - # Prerequisites - CONFIG.max_budget = 10 +@pytest.mark.asyncio +async def test_openllm_acompletion(mocker): + mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create) - cost = OpenLLMCostManager() - cost.update_cost(prompt_tokens=10, completion_tokens=1, model="gpt-35-turbo") - assert cost.get_total_prompt_tokens() > 0 - assert cost.get_total_completion_tokens() > 0 + openllm_gpt = OpenLLM() + openllm_gpt.model = "llama-v2-13b-chat" + + openllm_gpt._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200)) + assert openllm_gpt.get_costs() == Costs( + total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0 + ) + + resp = await openllm_gpt.acompletion(messages) + assert resp.choices[0].message.content in resp_content + + resp = await openllm_gpt.aask(prompt_msg, stream=False) + assert resp == resp_content + + resp = await openllm_gpt.acompletion_text(messages, stream=False) + assert resp == resp_content + + resp = await openllm_gpt.acompletion_text(messages, stream=True) + assert resp == resp_content + + resp = await openllm_gpt.aask(prompt_msg) + assert resp == resp_content diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py index cb86dfcf9..ddc290731 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -2,9 +2,14 @@ from unittest.mock import Mock import pytest +from metagpt.config import CONFIG from metagpt.provider.openai_api import OpenAILLM from metagpt.schema import UserMessage +CONFIG.openai_proxy = None + +print("openai_api_key ", CONFIG.openai_api_key) + @pytest.mark.asyncio async def test_aask_code(): diff --git a/tests/metagpt/provider/test_spark_api.py b/tests/metagpt/provider/test_spark_api.py index 6cc87741e..6d5a0e1f6 100644 --- a/tests/metagpt/provider/test_spark_api.py +++ b/tests/metagpt/provider/test_spark_api.py @@ -4,24 +4,31 @@ import pytest -from metagpt.provider.spark_api import SparkLLM +from metagpt.config import CONFIG +from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM + +CONFIG.spark_appid = "xxx" +CONFIG.spark_api_secret = "xxx" +CONFIG.spark_api_key = "xxx" +CONFIG.domain = "xxxxxx" +CONFIG.spark_url = "xxxx" prompt_msg = "who are you" resp_content = "I'm Spark" -def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> str: - return resp_content +def test_get_msg_from_web(): + get_msg_from_web = GetMessageFromWeb(text=prompt_msg) + assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "xxxxxx" -async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> str: +def mock_spark_get_msg_from_web_run(self) -> str: return resp_content @pytest.mark.asyncio async def test_spark_acompletion(mocker): - mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion", mock_llm_acompletion) - mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion_text", mock_llm_acompletion) + mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run) spark_gpt = SparkLLM() resp = await spark_gpt.acompletion([]) diff --git a/tests/metagpt/provider/test_zhipuai_api.py b/tests/metagpt/provider/test_zhipuai_api.py index 06f2cba62..826e706e8 100644 --- a/tests/metagpt/provider/test_zhipuai_api.py +++ b/tests/metagpt/provider/test_zhipuai_api.py @@ -1,41 +1,71 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- -# @Desc : the unittest of ZhiPuAIGPTAPI +# @Desc : the unittest of ZhiPuAILLM import pytest +from zhipuai.utils.sse_client import Event from metagpt.config import CONFIG -from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI +from metagpt.provider.zhipuai_api import ZhiPuAILLM -CONFIG.zhipuai_api_key = "xxx" +CONFIG.zhipuai_api_key = "xxx.xxx" prompt_msg = "who are you" messages = [{"role": "user", "content": prompt_msg}] resp_content = "I'm chatglm-turbo" -default_resp = {"code": 200, "data": {"choices": [{"role": "assistant", "content": resp_content}]}} +default_resp = { + "code": 200, + "data": { + "choices": [{"role": "assistant", "content": resp_content}], + "usage": {"prompt_tokens": 20, "completion_tokens": 20}, + }, +} -def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> dict: +def mock_zhipuai_invoke(**kwargs) -> dict: return default_resp -async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> dict: +async def mock_zhipuai_ainvoke(**kwargs) -> dict: return default_resp -async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: - return resp_content +async def mock_zhipuai_asse_invoke(**kwargs): + class MockResponse(object): + async def _aread(self): + class Iterator(object): + events = [ + Event(id="xxx", event="add", data=resp_content, retry=0), + Event( + id="xxx", + event="finish", + data="", + meta='{"usage": {"completion_tokens": 20,"prompt_tokens": 20}}', + ), + ] + + async def __aiter__(self): + for event in self.events: + yield event + + async for chunk in Iterator(): + yield chunk + + async def async_events(self): + async for chunk in self._aread(): + yield chunk + + return MockResponse() @pytest.mark.asyncio async def test_zhipuai_acompletion(mocker): - mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion", mock_llm_acompletion) - mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI._achat_completion", mock_llm_acompletion) - mocker.patch( - "metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream - ) - zhipu_gpt = ZhiPuAIGPTAPI() + mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.invoke", mock_zhipuai_invoke) + mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.ainvoke", mock_zhipuai_ainvoke) + mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.asse_invoke", mock_zhipuai_asse_invoke) + + zhipu_gpt = ZhiPuAILLM() resp = await zhipu_gpt.acompletion(messages) assert resp["data"]["choices"][0]["content"] == resp_content @@ -53,11 +83,11 @@ async def test_zhipuai_acompletion(mocker): assert resp == resp_content -def test_zhipuai_proxy(mocker): +def test_zhipuai_proxy(): import openai from metagpt.config import CONFIG CONFIG.openai_proxy = "http://127.0.0.1:8080" - _ = ZhiPuAIGPTAPI() + _ = ZhiPuAILLM() assert openai.proxy == CONFIG.openai_proxy diff --git a/tests/metagpt/provider/zhipuai/__init__.py b/tests/metagpt/provider/zhipuai/__init__.py new file mode 100644 index 000000000..2bcf8efd0 --- /dev/null +++ b/tests/metagpt/provider/zhipuai/__init__.py @@ -0,0 +1,3 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : diff --git a/tests/metagpt/provider/zhipuai/test_async_sse_client.py b/tests/metagpt/provider/zhipuai/test_async_sse_client.py new file mode 100644 index 000000000..9e5bd5f2e --- /dev/null +++ b/tests/metagpt/provider/zhipuai/test_async_sse_client.py @@ -0,0 +1,18 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +import pytest + +from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient + + +@pytest.mark.asyncio +async def test_async_sse_client(): + class Iterator(object): + async def __aiter__(self): + yield b"data: test_value" + + async_sse_client = AsyncSSEClient(event_source=Iterator()) + async for event in async_sse_client.async_events(): + assert event.data, "test_value" diff --git a/tests/metagpt/provider/zhipuai/test_zhipu_model_api.py b/tests/metagpt/provider/zhipuai/test_zhipu_model_api.py new file mode 100644 index 000000000..83ae2de60 --- /dev/null +++ b/tests/metagpt/provider/zhipuai/test_zhipu_model_api.py @@ -0,0 +1,41 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +# @Desc : + +from typing import Any, Tuple + +import pytest +import zhipuai +from zhipuai.model_api.api import InvokeType +from zhipuai.utils.http_client import headers as zhipuai_default_headers + +from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI + +api_key = "xxx.xxx" +zhipuai.api_key = api_key + +default_resp = {"result": "test response"} + + +async def mock_requestor_arequest(self, **kwargs) -> Tuple[Any, Any, str]: + return default_resp, None, None + + +@pytest.mark.asyncio +async def test_zhipu_model_api(mocker): + header = ZhiPuModelAPI.get_header() + zhipuai_default_headers.update({"Authorization": api_key}) + assert header == zhipuai_default_headers + + sse_header = ZhiPuModelAPI.get_sse_header() + assert len(sse_header["Authorization"]) == 191 + + url_prefix, url_suffix = ZhiPuModelAPI.split_zhipu_api_url(InvokeType.SYNC, kwargs={"model": "chatglm_turbo"}) + assert url_prefix == "https://open.bigmodel.cn/api" + assert url_suffix == "/paas/v3/model-api/chatglm_turbo/invoke" + + mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_requestor_arequest) + result = await ZhiPuModelAPI.arequest( + InvokeType.SYNC, stream=False, method="get", headers={}, kwargs={"model": "chatglm_turbo"} + ) + assert result == default_resp diff --git a/tests/metagpt/serialize_deserialize/test_action.py b/tests/metagpt/serialize_deserialize/test_action.py index b3206696b..677988e2f 100644 --- a/tests/metagpt/serialize_deserialize/test_action.py +++ b/tests/metagpt/serialize_deserialize/test_action.py @@ -28,5 +28,5 @@ async def test_action_deserialize(): new_action = Action(**serialized_data) assert new_action.name == "" - assert new_action.llm == LLM() + assert isinstance(new_action.llm, type(LLM())) assert len(await new_action._aask("who are you")) > 0 diff --git a/tests/metagpt/serialize_deserialize/test_environment.py b/tests/metagpt/serialize_deserialize/test_environment.py index 557c3f4cd..5a68288a6 100644 --- a/tests/metagpt/serialize_deserialize/test_environment.py +++ b/tests/metagpt/serialize_deserialize/test_environment.py @@ -13,6 +13,7 @@ from metagpt.schema import Message from metagpt.utils.common import any_to_str from tests.metagpt.serialize_deserialize.test_serdeser_base import ( ActionOK, + ActionRaise, RoleC, serdeser_path, ) @@ -55,9 +56,9 @@ def test_environment_serdeser(): assert len(new_env.roles) == 1 assert list(new_env.roles.values())[0].states == list(environment.roles.values())[0].states - assert list(new_env.roles.values())[0].actions == list(environment.roles.values())[0].actions assert isinstance(list(environment.roles.values())[0].actions[0], ActionOK) assert type(list(new_env.roles.values())[0].actions[0]) == ActionOK + assert type(list(new_env.roles.values())[0].actions[1]) == ActionRaise def test_environment_serdeser_v2(): diff --git a/tests/metagpt/serialize_deserialize/test_write_code.py b/tests/metagpt/serialize_deserialize/test_write_code.py index 2fb669a6b..cb262bb45 100644 --- a/tests/metagpt/serialize_deserialize/test_write_code.py +++ b/tests/metagpt/serialize_deserialize/test_write_code.py @@ -6,7 +6,6 @@ import pytest from metagpt.actions import WriteCode -from metagpt.llm import LLM from metagpt.schema import CodingContext, Document @@ -28,5 +27,4 @@ async def test_write_code_deserialize(): new_action = WriteCode(**serialized_data) assert new_action.name == "WriteCode" - assert new_action.llm == LLM() await action.run() diff --git a/tests/metagpt/serialize_deserialize/test_write_code_review.py b/tests/metagpt/serialize_deserialize/test_write_code_review.py index e9ad4b858..991b3c13b 100644 --- a/tests/metagpt/serialize_deserialize/test_write_code_review.py +++ b/tests/metagpt/serialize_deserialize/test_write_code_review.py @@ -5,7 +5,6 @@ import pytest from metagpt.actions import WriteCodeReview -from metagpt.llm import LLM from metagpt.schema import CodingContext, Document @@ -28,5 +27,4 @@ def div(a: int, b: int = 0): new_action = WriteCodeReview(**serialized_data) assert new_action.name == "WriteCodeReview" - assert new_action.llm == LLM() await new_action.run() diff --git a/tests/metagpt/serialize_deserialize/test_write_design.py b/tests/metagpt/serialize_deserialize/test_write_design.py index d556c144d..a2fce8047 100644 --- a/tests/metagpt/serialize_deserialize/test_write_design.py +++ b/tests/metagpt/serialize_deserialize/test_write_design.py @@ -5,7 +5,6 @@ import pytest from metagpt.actions import WriteDesign, WriteTasks -from metagpt.llm import LLM def test_write_design_serialize(): @@ -28,7 +27,6 @@ async def test_write_design_deserialize(): serialized_data = action.model_dump() new_action = WriteDesign(**serialized_data) assert new_action.name == "" - assert new_action.llm == LLM() await new_action.run(with_messages="write a cli snake game") @@ -38,5 +36,4 @@ async def test_write_task_deserialize(): serialized_data = action.model_dump() new_action = WriteTasks(**serialized_data) assert new_action.name == "CreateTasks" - assert new_action.llm == LLM() await new_action.run(with_messages="write a cli snake game") diff --git a/tests/metagpt/serialize_deserialize/test_write_prd.py b/tests/metagpt/serialize_deserialize/test_write_prd.py index 79b9a8677..890e2438b 100644 --- a/tests/metagpt/serialize_deserialize/test_write_prd.py +++ b/tests/metagpt/serialize_deserialize/test_write_prd.py @@ -6,7 +6,6 @@ import pytest from metagpt.actions import WritePRD -from metagpt.llm import LLM from metagpt.schema import Message @@ -22,7 +21,6 @@ async def test_action_deserialize(): action = WritePRD() serialized_data = action.model_dump() new_action = WritePRD(**serialized_data) - assert new_action.name == "" - assert new_action.llm == LLM() + assert new_action.name == "WritePRD" action_output = await new_action.run(with_messages=Message(content="write a cli snake game")) assert len(action_output.content) > 0 diff --git a/tests/metagpt/utils/test_serialize.py b/tests/metagpt/utils/test_serialize.py index f027d53f8..0ba3a8d41 100644 --- a/tests/metagpt/utils/test_serialize.py +++ b/tests/metagpt/utils/test_serialize.py @@ -4,7 +4,7 @@ @Desc : the unittest of serialize """ -from typing import List, Tuple +from typing import List from metagpt.actions import WritePRD from metagpt.actions.action_node import ActionNode @@ -27,7 +27,7 @@ def test_actionoutout_schema_to_mapping(): "properties": {"field": {"title": "field", "type": "array", "items": {"type": "string"}}}, } mapping = actionoutout_schema_to_mapping(schema) - assert mapping["field"] == (List[str], ...) + assert mapping["field"] == (list[str], ...) schema = { "title": "test", @@ -46,7 +46,7 @@ def test_actionoutout_schema_to_mapping(): }, } mapping = actionoutout_schema_to_mapping(schema) - assert mapping["field"] == (List[Tuple[str, str]], ...) + assert mapping["field"] == (list[list[str]], ...) assert True, True