diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index c8c901eb0..576990a83 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -14,7 +14,7 @@ from pydantic import BaseModel, Field from metagpt.actions.action_node import ActionNode from metagpt.llm import LLM -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.schema import ( CodeSummarizeContext, CodingContext, @@ -27,7 +27,7 @@ action_subclass_registry = {} class Action(BaseModel): name: str = "" - llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True) + llm: BaseLLM = Field(default_factory=LLM, exclude=True) context: Union[dict, CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext, str, None] = "" prefix = "" # aask*时会加上prefix,作为system_message desc = "" # for skill manager diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 63f46ad45..b554f15dd 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -15,7 +15,7 @@ from pydantic import BaseModel, create_model, root_validator, validator from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.config import CONFIG -from metagpt.llm import BaseGPTAPI +from metagpt.llm import BaseLLM from metagpt.logs import logger from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess from metagpt.utils.common import OutputParser, general_after_log @@ -60,7 +60,7 @@ class ActionNode: # Action Context context: str # all the context, including all necessary info - llm: BaseGPTAPI # LLM with aask interface + llm: BaseLLM # LLM with aask interface children: dict[str, "ActionNode"] # Action Input diff --git a/metagpt/actions/clone_function.py b/metagpt/actions/clone_function.py index 429f04286..7053df97b 100644 --- a/metagpt/actions/clone_function.py +++ b/metagpt/actions/clone_function.py @@ -5,7 +5,7 @@ from pydantic import Field from metagpt.actions.write_code import WriteCode from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Message from metagpt.utils.exceptions import handle_exception from metagpt.utils.highlight import highlight @@ -33,7 +33,7 @@ def run(*args) -> pd.DataFrame: class CloneFunction(WriteCode): name: str = "CloneFunction" context: list[Message] = [] - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) def _save(self, code_path, code): if isinstance(code_path, str): diff --git a/metagpt/actions/debug_error.py b/metagpt/actions/debug_error.py index 9dc6862f9..1a7c3a7c8 100644 --- a/metagpt/actions/debug_error.py +++ b/metagpt/actions/debug_error.py @@ -15,7 +15,7 @@ from pydantic import Field from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO -from metagpt.llm import LLM, BaseGPTAPI +from metagpt.llm import LLM, BaseLLM from metagpt.logs import logger from metagpt.schema import RunCodeContext, RunCodeResult from metagpt.utils.common import CodeParser @@ -52,7 +52,7 @@ Now you should start rewriting the code: class DebugError(Action): name: str = "DebugError" context: RunCodeContext = Field(default_factory=RunCodeContext) - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) async def run(self, *args, **kwargs) -> str: output_doc = await FileRepository.get_file( diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index 055365421..8535d63b1 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -27,7 +27,7 @@ from metagpt.const import ( ) from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Document, Documents, Message from metagpt.utils.file_repository import FileRepository from metagpt.utils.mermaid import mermaid_to_file @@ -44,7 +44,7 @@ NEW_REQ_TEMPLATE = """ class WriteDesign(Action): name: str = "" context: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) desc: str = ( "Based on the PRD, think about the system design, and design the corresponding APIs, " "data structures, library tables, processes, and paths. Please provide your design, feedback " diff --git a/metagpt/actions/design_api_review.py b/metagpt/actions/design_api_review.py index 0ff522fe8..6ea76e2fc 100644 --- a/metagpt/actions/design_api_review.py +++ b/metagpt/actions/design_api_review.py @@ -12,13 +12,13 @@ from pydantic import Field from metagpt.actions.action import Action from metagpt.llm import LLM -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM class DesignReview(Action): name: str = "DesignReview" context: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) async def run(self, prd, api_design): prompt = ( diff --git a/metagpt/actions/execute_task.py b/metagpt/actions/execute_task.py index b11f361b0..8577ee275 100644 --- a/metagpt/actions/execute_task.py +++ b/metagpt/actions/execute_task.py @@ -10,14 +10,14 @@ from pydantic import Field from metagpt.actions import Action from metagpt.llm import LLM -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Message class ExecuteTask(Action): name: str = "ExecuteTask" context: list[Message] = [] - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) async def run(self, *args, **kwargs): pass diff --git a/metagpt/actions/invoice_ocr.py b/metagpt/actions/invoice_ocr.py index 87f81371e..94288d5be 100644 --- a/metagpt/actions/invoice_ocr.py +++ b/metagpt/actions/invoice_ocr.py @@ -26,7 +26,7 @@ from metagpt.prompts.invoice_ocr import ( EXTRACT_OCR_MAIN_INFO_PROMPT, REPLY_OCR_QUESTION_PROMPT, ) -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.utils.common import OutputParser from metagpt.utils.file import File @@ -42,7 +42,7 @@ class InvoiceOCR(Action): name: str = "InvoiceOCR" context: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) @staticmethod async def _check_file_type(file_path: Path) -> str: @@ -132,7 +132,7 @@ class GenerateTable(Action): name: str = "GenerateTable" context: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) language: str = "ch" async def run(self, ocr_results: list, filename: str, *args, **kwargs) -> dict[str, str]: @@ -177,7 +177,7 @@ class ReplyQuestion(Action): name: str = "ReplyQuestion" context: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) language: str = "ch" async def run(self, query: str, ocr_result: list, *args, **kwargs) -> str: diff --git a/metagpt/actions/prepare_documents.py b/metagpt/actions/prepare_documents.py index 696dc9a89..ad82e56dc 100644 --- a/metagpt/actions/prepare_documents.py +++ b/metagpt/actions/prepare_documents.py @@ -17,7 +17,7 @@ from metagpt.actions import Action, ActionOutput from metagpt.config import CONFIG from metagpt.const import DOCS_FILE_REPO, REQUIREMENT_FILENAME from metagpt.llm import LLM -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Document from metagpt.utils.file_repository import FileRepository from metagpt.utils.git_repository import GitRepository @@ -28,7 +28,7 @@ class PrepareDocuments(Action): name: str = "PrepareDocuments" context: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) def _init_repo(self): """Initialize the Git environment.""" diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index 095881e60..7eda89130 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -27,7 +27,7 @@ from metagpt.const import ( ) from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Document, Documents from metagpt.utils.file_repository import FileRepository @@ -43,7 +43,7 @@ NEW_REQ_TEMPLATE = """ class WriteTasks(Action): name: str = "CreateTasks" context: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) async def run(self, with_messages, schema=CONFIG.prompt_schema): system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO) diff --git a/metagpt/actions/research.py b/metagpt/actions/research.py index c47a77bdd..a6cc7cc22 100644 --- a/metagpt/actions/research.py +++ b/metagpt/actions/research.py @@ -11,7 +11,7 @@ from metagpt.actions import Action from metagpt.config import CONFIG from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.tools.search_engine import SearchEngine from metagpt.tools.web_browser_engine import WebBrowserEngine, WebBrowserEngineType from metagpt.utils.common import OutputParser @@ -82,7 +82,7 @@ class CollectLinks(Action): name: str = "CollectLinks" context: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) desc: str = "Collect links from a search engine." search_engine: SearchEngine = Field(default_factory=SearchEngine) rank_func: Union[Callable[[list[str]], None], None] = None @@ -177,7 +177,7 @@ class WebBrowseAndSummarize(Action): name: str = "WebBrowseAndSummarize" context: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) desc: str = "Explore the web and provide summaries of articles and webpages." browse_func: Union[Callable[[list[str]], None], None] = None web_browser_engine: WebBrowserEngine = Field( @@ -248,7 +248,7 @@ class ConductResearch(Action): name: str = "ConductResearch" context: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/metagpt/actions/run_code.py b/metagpt/actions/run_code.py index bca9b337d..22d345b85 100644 --- a/metagpt/actions/run_code.py +++ b/metagpt/actions/run_code.py @@ -22,7 +22,7 @@ from pydantic import Field from metagpt.actions.action import Action from metagpt.config import CONFIG -from metagpt.llm import LLM, BaseGPTAPI +from metagpt.llm import LLM, BaseLLM from metagpt.logs import logger from metagpt.schema import RunCodeContext, RunCodeResult from metagpt.utils.exceptions import handle_exception @@ -79,7 +79,7 @@ standard errors: class RunCode(Action): name: str = "RunCode" context: RunCodeContext = Field(default_factory=RunCodeContext) - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) @classmethod @handle_exception diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index 9fd392a5c..615576d76 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -14,7 +14,7 @@ from metagpt.actions import Action from metagpt.config import CONFIG, Config from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Message from metagpt.tools import SearchEngineType from metagpt.tools.search_engine import SearchEngine @@ -109,7 +109,7 @@ You are a member of a professional butler team and will provide helpful suggesti class SearchAndSummarize(Action): name: str = "" content: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) config: None = Field(default_factory=Config) engine: Optional[SearchEngineType] = CONFIG.search_engine search_func: Optional[Any] = None diff --git a/metagpt/actions/summarize_code.py b/metagpt/actions/summarize_code.py index 2d1cd4d3d..4025e0964 100644 --- a/metagpt/actions/summarize_code.py +++ b/metagpt/actions/summarize_code.py @@ -13,7 +13,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO -from metagpt.llm import LLM, BaseGPTAPI +from metagpt.llm import LLM, BaseLLM from metagpt.logs import logger from metagpt.schema import CodeSummarizeContext from metagpt.utils.file_repository import FileRepository @@ -95,7 +95,7 @@ flowchart TB class SummarizeCode(Action): name: str = "SummarizeCode" context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext) - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) @retry(stop=stop_after_attempt(2), wait=wait_random_exponential(min=1, max=60)) async def summarize_code(self, prompt): diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index 4d0690e0f..e3086f03c 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -31,7 +31,7 @@ from metagpt.const import ( ) from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.schema import CodingContext, Document, RunCodeResult from metagpt.utils.common import CodeParser from metagpt.utils.file_repository import FileRepository @@ -90,7 +90,7 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc class WriteCode(Action): name: str = "WriteCode" context: Document = Field(default_factory=Document) - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) async def write_code(self, prompt) -> str: diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index b0e7904e3..a8ed0fd01 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -16,7 +16,7 @@ from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.schema import CodingContext from metagpt.utils.common import CodeParser @@ -123,7 +123,7 @@ REWRITE_CODE_TEMPLATE = """ class WriteCodeReview(Action): name: str = "WriteCodeReview" context: CodingContext = Field(default_factory=CodingContext) - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, filename): diff --git a/metagpt/actions/write_docstring.py b/metagpt/actions/write_docstring.py index 1c27a9433..68856c360 100644 --- a/metagpt/actions/write_docstring.py +++ b/metagpt/actions/write_docstring.py @@ -28,7 +28,7 @@ from pydantic import Field from metagpt.actions.action import Action from metagpt.llm import LLM -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.utils.common import OutputParser from metagpt.utils.pycst import merge_docstring @@ -163,7 +163,7 @@ class WriteDocstring(Action): desc: str = "Write docstring for code." context: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) async def run( self, diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index 47e02b699..5b1108244 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -38,7 +38,7 @@ from metagpt.const import ( ) from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.schema import BugFixContext, Document, Documents, Message from metagpt.utils.common import CodeParser from metagpt.utils.file_repository import FileRepository @@ -67,7 +67,7 @@ NEW_REQ_TEMPLATE = """ class WritePRD(Action): name: str = "" content: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) async def run(self, with_messages, schema=CONFIG.prompt_schema, *args, **kwargs) -> ActionOutput | Message: # Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are diff --git a/metagpt/actions/write_prd_review.py b/metagpt/actions/write_prd_review.py index 6ed73b6a2..0241f192f 100644 --- a/metagpt/actions/write_prd_review.py +++ b/metagpt/actions/write_prd_review.py @@ -12,13 +12,13 @@ from pydantic import Field from metagpt.actions.action import Action from metagpt.llm import LLM -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM class WritePRDReview(Action): name: str = "" context: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) prd: Optional[str] = None desc: str = "Based on the PRD, conduct a PRD Review, providing clear and detailed feedback" prd_review_prompt_template: str = """ diff --git a/metagpt/actions/write_review.py b/metagpt/actions/write_review.py index 646f44aeb..d116556ba 100644 --- a/metagpt/actions/write_review.py +++ b/metagpt/actions/write_review.py @@ -11,7 +11,7 @@ from pydantic import Field from metagpt.actions import Action from metagpt.actions.action_node import ActionNode from metagpt.llm import LLM -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM REVIEW = ActionNode( key="Review", @@ -38,7 +38,7 @@ class WriteReview(Action): """Write a review for the given context.""" name: str = "WriteReview" - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) async def run(self, context): return await WRITE_REVIEW_NODE.fill(context=context, llm=self.llm, schema="json") diff --git a/metagpt/actions/write_teaching_plan.py b/metagpt/actions/write_teaching_plan.py index d889fdbe3..888627294 100644 --- a/metagpt/actions/write_teaching_plan.py +++ b/metagpt/actions/write_teaching_plan.py @@ -13,14 +13,14 @@ from metagpt.actions import Action from metagpt.config import CONFIG from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM class WriteTeachingPlanPart(Action): """Write Teaching Plan Part""" context: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) topic: str = "" language: str = "Chinese" rsp: Optional[str] = None diff --git a/metagpt/actions/write_test.py b/metagpt/actions/write_test.py index 850606ca8..321d31420 100644 --- a/metagpt/actions/write_test.py +++ b/metagpt/actions/write_test.py @@ -17,7 +17,7 @@ from metagpt.config import CONFIG from metagpt.const import TEST_CODES_FILE_REPO from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Document, TestingContext from metagpt.utils.common import CodeParser @@ -45,7 +45,7 @@ you should correctly import the necessary classes based on these file locations! class WriteTest(Action): name: str = "WriteTest" context: Optional[TestingContext] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) async def write_code(self, prompt): code_rsp = await self._aask(prompt) diff --git a/metagpt/actions/write_tutorial.py b/metagpt/actions/write_tutorial.py index f33a6b114..a2a324b41 100644 --- a/metagpt/actions/write_tutorial.py +++ b/metagpt/actions/write_tutorial.py @@ -14,7 +14,7 @@ from pydantic import Field from metagpt.actions import Action from metagpt.llm import LLM from metagpt.prompts.tutorial_assistant import CONTENT_PROMPT, DIRECTORY_PROMPT -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.utils.common import OutputParser @@ -27,7 +27,7 @@ class WriteDirectory(Action): """ name: str = "WriteDirectory" - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) language: str = "Chinese" async def run(self, topic: str, *args, **kwargs) -> Dict: @@ -54,7 +54,7 @@ class WriteContent(Action): """ name: str = "WriteContent" - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) directory: dict = dict() language: str = "Chinese" diff --git a/metagpt/llm.py b/metagpt/llm.py index f1cb98dae..76dd5a0f8 100644 --- a/metagpt/llm.py +++ b/metagpt/llm.py @@ -9,14 +9,14 @@ from typing import Optional from metagpt.config import CONFIG, LLMProviderEnum -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.provider.human_provider import HumanProvider from metagpt.provider.llm_provider_registry import LLM_REGISTRY _ = HumanProvider() # Avoid pre-commit error -def LLM(provider: Optional[LLMProviderEnum] = None) -> BaseGPTAPI: +def LLM(provider: Optional[LLMProviderEnum] = None) -> BaseLLM: """get the default llm provider""" if provider is None: provider = CONFIG.get_default_llm_provider_enum() diff --git a/metagpt/memory/brain_memory.py b/metagpt/memory/brain_memory.py index 347e3e0fb..0833d71a1 100644 --- a/metagpt/memory/brain_memory.py +++ b/metagpt/memory/brain_memory.py @@ -18,7 +18,7 @@ from metagpt.config import CONFIG from metagpt.const import DEFAULT_LANGUAGE, DEFAULT_MAX_TOKENS, DEFAULT_TOKEN_SIZE from metagpt.logs import logger from metagpt.provider import MetaGPTAPI -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Message, SimpleMessage from metagpt.utils.redis import Redis @@ -31,7 +31,7 @@ class BrainMemory(BaseModel): is_dirty: bool = False last_talk: str = None cacheable: bool = True - llm: Optional[BaseGPTAPI] = None + llm: Optional[BaseLLM] = None def add_talk(self, msg: Message): """ diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py index 769c8e7b8..36d585c94 100644 --- a/metagpt/provider/__init__.py +++ b/metagpt/provider/__init__.py @@ -6,22 +6,22 @@ @File : __init__.py """ -from metagpt.provider.fireworks_api import FireWorksGPTAPI +from metagpt.provider.fireworks_api import FireworksLLM from metagpt.provider.google_gemini_api import GeminiGPTAPI -from metagpt.provider.ollama_api import OllamaGPTAPI +from metagpt.provider.ollama_api import OllamaLLM from metagpt.provider.open_llm_api import OpenLLMGPTAPI -from metagpt.provider.openai_api import OpenAIGPTAPI +from metagpt.provider.openai_api import OpenAILLM from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI -from metagpt.provider.azure_openai_api import AzureOpenAIGPTAPI +from metagpt.provider.azure_openai_api import AzureOpenAILLM from metagpt.provider.metagpt_api import MetaGPTAPI __all__ = [ - "FireWorksGPTAPI", + "FireworksLLM", "GeminiGPTAPI", "OpenLLMGPTAPI", - "OpenAIGPTAPI", + "OpenAILLM", "ZhiPuAIGPTAPI", - "AzureOpenAIGPTAPI", + "AzureOpenAILLM", "MetaGPTAPI", - "OllamaGPTAPI", + "OllamaLLM", ] diff --git a/metagpt/provider/azure_openai_api.py b/metagpt/provider/azure_openai_api.py index 6a267b7ee..b59326c7f 100644 --- a/metagpt/provider/azure_openai_api.py +++ b/metagpt/provider/azure_openai_api.py @@ -15,11 +15,11 @@ from openai._base_client import AsyncHttpxClientWrapper from metagpt.config import LLMProviderEnum from metagpt.provider.llm_provider_registry import register_provider -from metagpt.provider.openai_api import OpenAIGPTAPI +from metagpt.provider.openai_api import OpenAILLM @register_provider(LLMProviderEnum.AZURE_OPENAI) -class AzureOpenAIGPTAPI(OpenAIGPTAPI): +class AzureOpenAILLM(OpenAILLM): """ Check https://platform.openai.com/examples for examples """ diff --git a/metagpt/provider/base_chatbot.py b/metagpt/provider/base_chatbot.py deleted file mode 100644 index 8d490f1a6..000000000 --- a/metagpt/provider/base_chatbot.py +++ /dev/null @@ -1,21 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/5/5 23:00 -@Author : alexanderwu -@File : base_chatbot.py -@Modified By: mashenquan, 2023/11/21. Add `timeout`. -""" -from abc import ABC, abstractmethod -from dataclasses import dataclass - - -@dataclass -class BaseChatbot(ABC): - """Abstract GPT class""" - - use_system_prompt: bool = True - - @abstractmethod - def ask(self, msg: str, timeout=3) -> str: - """Ask GPT a question and get an answer""" diff --git a/metagpt/provider/base_gpt_api.py b/metagpt/provider/base_llm.py similarity index 83% rename from metagpt/provider/base_gpt_api.py rename to metagpt/provider/base_llm.py index e6b180eaa..4d00adbc7 100644 --- a/metagpt/provider/base_gpt_api.py +++ b/metagpt/provider/base_llm.py @@ -3,19 +3,18 @@ """ @Time : 2023/5/5 23:04 @Author : alexanderwu -@File : base_gpt_api.py +@File : base_llm.py @Desc : mashenquan, 2023/8/22. + try catch """ import json -from abc import abstractmethod +from abc import ABC, abstractmethod from typing import Optional -from metagpt.provider.base_chatbot import BaseChatbot +class BaseLLM(ABC): + """LLM API abstract class, requiring all inheritors to provide a series of standard capabilities""" -class BaseGPTAPI(BaseChatbot): - """GPT API abstract class, requiring all inheritors to provide a series of standard capabilities""" - + use_system_prompt: bool = True system_prompt = "You are a helpful assistant." def _user_msg(self, msg: str) -> dict[str, str]: @@ -33,11 +32,6 @@ class BaseGPTAPI(BaseChatbot): def _default_system_msg(self): return self._system_msg(self.system_prompt) - def ask(self, msg: str, timeout=3) -> str: - message = [self._default_system_msg(), self._user_msg(msg)] if self.use_system_prompt else [self._user_msg(msg)] - rsp = self.completion(message, timeout=timeout) - return self.get_choice_text(rsp) - async def aask( self, msg: str, @@ -54,7 +48,6 @@ class BaseGPTAPI(BaseChatbot): message.extend(format_msgs) message.append(self._user_msg(msg)) rsp = await self.acompletion_text(message, stream=stream, timeout=timeout) - # logger.debug(rsp) return rsp def _extract_assistant_rsp(self, context): @@ -75,15 +68,6 @@ class BaseGPTAPI(BaseChatbot): rsp_text = await self.aask_batch(msgs, timeout=timeout) return rsp_text - def completion(self, messages: list[dict], timeout=3) -> dict: - """All GPTAPIs are required to provide the standard OpenAI completion interface - [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "hello, show me python hello world code"}, - # {"role": "assistant", "content": ...}, # If there is an answer in the history, also include it - ] - """ - @abstractmethod async def acompletion(self, messages: list[dict], timeout=3): """Asynchronous version of completion diff --git a/metagpt/provider/fireworks_api.py b/metagpt/provider/fireworks_api.py index e42088213..5fe86fc1c 100644 --- a/metagpt/provider/fireworks_api.py +++ b/metagpt/provider/fireworks_api.py @@ -18,7 +18,7 @@ from tenacity import ( from metagpt.config import CONFIG, Config, LLMProviderEnum from metagpt.logs import logger from metagpt.provider.llm_provider_registry import register_provider -from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter, log_and_reraise +from metagpt.provider.openai_api import OpenAILLM, log_and_reraise from metagpt.utils.cost_manager import CostManager, Costs MODEL_GRADE_TOKEN_COSTS = { @@ -72,13 +72,12 @@ class FireworksCostManager(CostManager): @register_provider(LLMProviderEnum.FIREWORKS) -class FireWorksGPTAPI(OpenAIGPTAPI): +class FireworksLLM(OpenAILLM): def __init__(self): self.config: Config = CONFIG self.__init_fireworks() self.auto_max_tokens = False self._cost_manager = FireworksCostManager() - RateLimiter.__init__(self, rpm=self.rpm) def __init_fireworks(self): self.is_azure = False diff --git a/metagpt/provider/general_api_base.py b/metagpt/provider/general_api_base.py index 015e34aeb..814be2f67 100644 --- a/metagpt/provider/general_api_base.py +++ b/metagpt/provider/general_api_base.py @@ -47,8 +47,7 @@ MAX_CONNECTION_RETRIES = 2 # Has one attribute per thread, 'session'. _thread_context = threading.local() -LLM_LOG = os.environ.get("LLM_LOG") -LLM_LOG = "debug" +LLM_LOG = os.environ.get("LLM_LOG", "debug") class ApiType(Enum): diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index ca2133cfa..5683095c7 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -21,7 +21,7 @@ from tenacity import ( from metagpt.config import CONFIG, LLMProviderEnum from metagpt.logs import log_llm_stream, logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import log_and_reraise @@ -42,7 +42,7 @@ class GeminiGenerativeModel(GenerativeModel): @register_provider(LLMProviderEnum.GEMINI) -class GeminiGPTAPI(BaseGPTAPI): +class GeminiGPTAPI(BaseLLM): """ Refs to `https://ai.google.dev/tutorials/python_quickstart` """ diff --git a/metagpt/provider/human_provider.py b/metagpt/provider/human_provider.py index a90c78192..59d236a3a 100644 --- a/metagpt/provider/human_provider.py +++ b/metagpt/provider/human_provider.py @@ -6,10 +6,10 @@ Author: garylin2099 from typing import Optional from metagpt.logs import logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM -class HumanProvider(BaseGPTAPI): +class HumanProvider(BaseLLM): """Humans provide themselves as a 'model', which actually takes in human input as its response. This enables replacing LLM anywhere in the framework with a human, thus introducing human interaction """ diff --git a/metagpt/provider/metagpt_api.py b/metagpt/provider/metagpt_api.py index 7bc48b7ad..2b7629895 100644 --- a/metagpt/provider/metagpt_api.py +++ b/metagpt/provider/metagpt_api.py @@ -6,11 +6,11 @@ @Desc : MetaGPT LLM provider. """ from metagpt.config import LLMProviderEnum -from metagpt.provider import OpenAIGPTAPI +from metagpt.provider import OpenAILLM from metagpt.provider.llm_provider_registry import register_provider @register_provider(LLMProviderEnum.METAGPT) -class MetaGPTAPI(OpenAIGPTAPI): +class MetaGPTAPI(OpenAILLM): def __init__(self): super().__init__() diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index 0d6d51e04..95b944bf3 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -16,7 +16,7 @@ from tenacity import ( from metagpt.config import CONFIG, LLMProviderEnum from metagpt.const import LLM_API_TIMEOUT from metagpt.logs import log_llm_stream, logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.provider.general_api_requestor import GeneralAPIRequestor from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import log_and_reraise @@ -39,7 +39,7 @@ class OllamaCostManager(CostManager): @register_provider(LLMProviderEnum.OLLAMA) -class OllamaGPTAPI(BaseGPTAPI): +class OllamaLLM(BaseLLM): """ Refs to `https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-chat-completion` """ @@ -54,12 +54,8 @@ class OllamaGPTAPI(BaseGPTAPI): def __init_ollama(self, config: CONFIG): assert config.ollama_api_base - self.model = config.ollama_api_model - def close(self): - pass - def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream} return kwargs @@ -87,18 +83,6 @@ class OllamaGPTAPI(BaseGPTAPI): chunk = chunk.decode(encoding) return json.loads(chunk) - def completion(self, messages: list[dict]) -> dict: - resp, _, _ = self.client.request( - method=self.http_method, - url=self.suffix_url, - params=self._const_kwargs(messages), - request_timeout=LLM_API_TIMEOUT, - ) - resp = self._decode_and_load(resp) - usage = self.get_usage(resp) - self._update_costs(usage) - return resp - async def _achat_completion(self, messages: list[dict]) -> dict: resp, _, _ = await self.client.arequest( method=self.http_method, @@ -111,7 +95,7 @@ class OllamaGPTAPI(BaseGPTAPI): self._update_costs(usage) return resp - async def acompletion(self, messages: list[dict]) -> dict: + async def acompletion(self, messages: list[dict], timeout=3) -> dict: return await self._achat_completion(messages) async def _achat_completion_stream(self, messages: list[dict]) -> str: diff --git a/metagpt/provider/open_llm_api.py b/metagpt/provider/open_llm_api.py index 21efb6677..2893f5b30 100644 --- a/metagpt/provider/open_llm_api.py +++ b/metagpt/provider/open_llm_api.py @@ -7,7 +7,7 @@ from openai.types import CompletionUsage from metagpt.config import CONFIG, Config, LLMProviderEnum from metagpt.logs import logger from metagpt.provider.llm_provider_registry import register_provider -from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter +from metagpt.provider.openai_api import OpenAILLM from metagpt.utils.cost_manager import CostManager, Costs from metagpt.utils.token_counter import count_message_tokens, count_string_tokens @@ -35,13 +35,12 @@ class OpenLLMCostManager(CostManager): @register_provider(LLMProviderEnum.OPEN_LLM) -class OpenLLMGPTAPI(OpenAIGPTAPI): +class OpenLLMGPTAPI(OpenAILLM): def __init__(self): self.config: Config = CONFIG self.__init_openllm() self.auto_max_tokens = False self._cost_manager = OpenLLMCostManager() - RateLimiter.__init__(self, rpm=self.rpm) def __init_openllm(self): self.is_azure = False diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index bfd6c7917..64adbb1c0 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -3,15 +3,13 @@ @Time : 2023/5/5 23:08 @Author : alexanderwu @File : openai.py -@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation; +@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for isolation; Change cost control from global to company level. @Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout. @Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x. """ -import asyncio import json -import time from typing import AsyncIterator, Union from openai import APIConnectionError, AsyncOpenAI, AsyncStream @@ -28,7 +26,7 @@ from tenacity import ( from metagpt.config import CONFIG, Config, LLMProviderEnum from metagpt.logs import log_llm_stream, logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE from metagpt.provider.llm_provider_registry import register_provider from metagpt.schema import Message @@ -41,31 +39,6 @@ from metagpt.utils.token_counter import ( ) -class RateLimiter: - """Rate control class, each call goes through wait_if_needed, sleep if rate control is needed""" - - def __init__(self, rpm): - self.last_call_time = 0 - # Here 1.1 is used because even if the calls are made strictly according to time, - # they will still be QOS'd; consider switching to simple error retry later - self.interval = 1.1 * 60 / rpm - self.rpm = rpm - - def split_batches(self, batch): - return [batch[i : i + self.rpm] for i in range(0, len(batch), self.rpm)] - - async def wait_if_needed(self, num_requests): - current_time = time.time() - elapsed_time = current_time - self.last_call_time - - if elapsed_time < self.interval * num_requests: - remaining_time = self.interval * num_requests - elapsed_time - logger.info(f"sleep {remaining_time}") - await asyncio.sleep(remaining_time) - - self.last_call_time = time.time() - - def log_and_reraise(retry_state): logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}") logger.warning( @@ -78,7 +51,7 @@ See FAQ 5.8 @register_provider(LLMProviderEnum.OPENAI) -class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): +class OpenAILLM(BaseLLM): """Check https://platform.openai.com/examples for examples""" def __init__(self): @@ -86,11 +59,8 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): self._init_openai() self._init_client() self.auto_max_tokens = False - RateLimiter.__init__(self, rpm=self.rpm) - super().__init__() def _init_openai(self): - self.rpm = int(self.config.openai_api_rpm) self.model = self.config.OPENAI_API_MODEL # Used in _calc_usage & _cons_kwargs def _init_client(self): @@ -211,7 +181,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create Examples: - >>> llm = OpenAIGPTAPI() + >>> llm = OpenAILLM() >>> msg = [{'role': 'user', 'content': "Write a python hello world code."}] >>> rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} diff --git a/metagpt/provider/spark_api.py b/metagpt/provider/spark_api.py index 4ec7be8cf..ce889529a 100644 --- a/metagpt/provider/spark_api.py +++ b/metagpt/provider/spark_api.py @@ -1,9 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- """ -@Time : 2023/7/21 11:15 -@Author : Leo Xiao -@File : anthropic_api.py +@File : spark_api.py """ import _thread as thread import base64 @@ -13,7 +11,6 @@ import hmac import json import ssl from time import mktime -from typing import Optional from urllib.parse import urlencode, urlparse from wsgiref.handlers import format_date_time @@ -21,32 +18,15 @@ import websocket # 使用websocket_client from metagpt.config import CONFIG, LLMProviderEnum from metagpt.logs import logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.provider.llm_provider_registry import register_provider @register_provider(LLMProviderEnum.SPARK) -class SparkGPTAPI(BaseGPTAPI): +class SparkLLM(BaseLLM): def __init__(self): logger.warning("当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。") - def close(self): - pass - - def ask(self, msg: str) -> str: - message = [self._default_system_msg(), self._user_msg(msg)] - rsp = self.completion(message) - return rsp - - async def aask(self, msg: str, system_msgs: Optional[list[str]] = None, stream: bool = True) -> str: - if system_msgs: - message = self._system_msgs(system_msgs) + [self._user_msg(msg)] - else: - message = [self._default_system_msg(), self._user_msg(msg)] - rsp = await self.acompletion(message) - logger.debug(message) - return rsp - def get_choice_text(self, rsp: dict) -> str: return rsp["payload"]["choices"]["text"][-1]["content"] @@ -56,15 +36,11 @@ class SparkGPTAPI(BaseGPTAPI): w = GetMessageFromWeb(messages) return w.run() - async def acompletion(self, messages: list[dict]): + async def acompletion(self, messages: list[dict], timeout=3): # 不支持异步 w = GetMessageFromWeb(messages) return w.run() - def completion(self, messages: list[dict]): - w = GetMessageFromWeb(messages) - return w.run() - class GetMessageFromWeb: class WsParam: diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index 533ce5719..e4b066a0c 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -17,7 +17,7 @@ from tenacity import ( from metagpt.config import CONFIG, LLMProviderEnum from metagpt.logs import log_llm_stream, logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import log_and_reraise from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI @@ -31,7 +31,7 @@ class ZhiPuEvent(Enum): @register_provider(LLMProviderEnum.ZHIPUAI) -class ZhiPuAIGPTAPI(BaseGPTAPI): +class ZhiPuAIGPTAPI(BaseLLM): """ Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo` From now, there is only one model named `chatglm_turbo` diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 3e5f268f8..d6e874ffe 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -36,7 +36,7 @@ from metagpt.const import SERDESER_PATH from metagpt.llm import LLM, HumanProvider from metagpt.logs import logger from metagpt.memory import Memory -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Message, MessageQueue from metagpt.utils.common import ( any_to_name, @@ -141,7 +141,7 @@ class Role(BaseModel): desc: str = "" is_human: bool = False - _llm: BaseGPTAPI = Field(default_factory=LLM) # Each role has its own LLM, use different system message + _llm: BaseLLM = Field(default_factory=LLM) # Each role has its own LLM, use different system message _role_id: str = "" _states: list[str] = [] _actions: list[Action] = [] diff --git a/metagpt/roles/sk_agent.py b/metagpt/roles/sk_agent.py index 6063205bd..d982ebb68 100644 --- a/metagpt/roles/sk_agent.py +++ b/metagpt/roles/sk_agent.py @@ -19,7 +19,7 @@ from metagpt.actions import UserRequirement from metagpt.actions.execute_task import ExecuteTask from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.roles import Role from metagpt.schema import Message from metagpt.utils.make_sk_kernel import make_sk_kernel @@ -44,7 +44,7 @@ class SkAgent(Role): plan: Any = None planner_cls: Any = None planner: Any = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) kernel: Kernel = Field(default_factory=Kernel) import_semantic_skill_from_directory: Type[Kernel.import_semantic_skill_from_directory] = None import_skill: Type[Kernel.import_skill] = None diff --git a/metagpt/tools/ut_writer.py b/metagpt/tools/ut_writer.py index 8f827986c..d6d190ad7 100644 --- a/metagpt/tools/ut_writer.py +++ b/metagpt/tools/ut_writer.py @@ -4,7 +4,7 @@ import json from pathlib import Path -from metagpt.provider.openai_api import OpenAIGPTAPI as GPTAPI +from metagpt.provider.openai_api import OpenAILLM as GPTAPI ICL_SAMPLE = """Interface definition: ```text diff --git a/tests/metagpt/actions/test_write_code.py b/tests/metagpt/actions/test_write_code.py index ba7cb6f2d..40a3b44ed 100644 --- a/tests/metagpt/actions/test_write_code.py +++ b/tests/metagpt/actions/test_write_code.py @@ -10,7 +10,7 @@ import pytest from metagpt.actions.write_code import WriteCode from metagpt.logs import logger -from metagpt.provider.openai_api import OpenAIGPTAPI as LLM +from metagpt.provider.openai_api import OpenAILLM as LLM from metagpt.schema import CodingContext, Document from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPLE diff --git a/tests/metagpt/provider/test_base_gpt_api.py b/tests/metagpt/provider/test_base_gpt_api.py index 0bee0ce75..be2c0ea7a 100644 --- a/tests/metagpt/provider/test_base_gpt_api.py +++ b/tests/metagpt/provider/test_base_gpt_api.py @@ -8,7 +8,7 @@ import pytest -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Message default_chat_resp = { @@ -27,7 +27,7 @@ prompt_msg = "who are you" resp_content = default_chat_resp["choices"][0]["message"]["content"] -class MockBaseGPTAPI(BaseGPTAPI): +class MockBaseGPTAPI(BaseLLM): def completion(self, messages: list[dict], timeout=3): return default_chat_resp diff --git a/tests/metagpt/provider/test_fireworks_api.py b/tests/metagpt/provider/test_fireworks_api.py index 4d92c5f45..00b3c716a 100644 --- a/tests/metagpt/provider/test_fireworks_api.py +++ b/tests/metagpt/provider/test_fireworks_api.py @@ -13,7 +13,7 @@ from openai.types.completion_usage import CompletionUsage from metagpt.provider.fireworks_api import ( MODEL_GRADE_TOKEN_COSTS, FireworksCostManager, - FireWorksGPTAPI, + FireworksLLM, ) resp_content = "I'm fireworks" @@ -62,7 +62,7 @@ async def test_fireworks_acompletion(mocker): mocker.patch( "metagpt.provider.fireworks_api.FireWorksGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream ) - fireworks_gpt = FireWorksGPTAPI() + fireworks_gpt = FireworksLLM() resp = await fireworks_gpt.acompletion(messages, stream=False) assert resp.choices[0].message.content in resp_content diff --git a/tests/metagpt/provider/test_google_gemini_api.py b/tests/metagpt/provider/test_google_gemini_api.py index aec7b8520..60f50c9ad 100644 --- a/tests/metagpt/provider/test_google_gemini_api.py +++ b/tests/metagpt/provider/test_google_gemini_api.py @@ -35,16 +35,6 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: return resp_content -def test_gemini_completion(mocker): - mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.completion", mock_llm_completion) - gemini_gpt = GeminiGPTAPI() - resp = gemini_gpt.completion(messages) - assert resp.text == resp_content - - resp = gemini_gpt.ask(prompt_msg) - assert resp == resp_content - - @pytest.mark.asyncio async def test_gemini_acompletion(mocker): mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_acompletion) diff --git a/tests/metagpt/provider/test_human_provider.py b/tests/metagpt/provider/test_human_provider.py index caab9f15f..8ba532781 100644 --- a/tests/metagpt/provider/test_human_provider.py +++ b/tests/metagpt/provider/test_human_provider.py @@ -17,15 +17,6 @@ async def mock_llm_aask(msg: str, timeout: int = 3) -> str: return mock_llm_ask(msg) -def test_human_provider(mocker): - mocker.patch("metagpt.provider.human_provider.HumanProvider.ask", mock_llm_ask) - human_provider = HumanProvider() - - assert resp_content == human_provider.ask(None) - - assert not human_provider.completion(messages=[]) - - @pytest.mark.asyncio async def test_async_human_provider(mocker): mocker.patch("metagpt.provider.human_provider.HumanProvider.aask", mock_llm_aask) diff --git a/tests/metagpt/provider/test_ollama_api.py b/tests/metagpt/provider/test_ollama_api.py index d552d9f9e..d19e23e17 100644 --- a/tests/metagpt/provider/test_ollama_api.py +++ b/tests/metagpt/provider/test_ollama_api.py @@ -5,7 +5,7 @@ import pytest from metagpt.config import CONFIG -from metagpt.provider.ollama_api import OllamaGPTAPI +from metagpt.provider.ollama_api import OllamaLLM prompt_msg = "who are you" messages = [{"role": "user", "content": prompt_msg}] @@ -28,22 +28,12 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: return resp_content -def test_gemini_completion(mocker): - mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.completion", mock_llm_completion) - ollama_gpt = OllamaGPTAPI() - resp = ollama_gpt.completion(messages) - assert resp["message"]["content"] == default_resp["message"]["content"] - - resp = ollama_gpt.ask(prompt_msg) - assert resp == resp_content - - @pytest.mark.asyncio async def test_gemini_acompletion(mocker): mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.acompletion", mock_llm_acompletion) mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion", mock_llm_acompletion) mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream) - ollama_gpt = OllamaGPTAPI() + ollama_gpt = OllamaLLM() resp = await ollama_gpt.acompletion(messages) assert resp["message"]["content"] == default_resp["message"]["content"] diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py index 0736b1d4a..329edadff 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -2,13 +2,13 @@ from unittest.mock import Mock import pytest -from metagpt.provider.openai_api import OpenAIGPTAPI +from metagpt.provider.openai_api import OpenAILLM from metagpt.schema import UserMessage @pytest.mark.asyncio async def test_aask_code(): - llm = OpenAIGPTAPI() + llm = OpenAILLM() msg = [{"role": "user", "content": "Write a python hello world code."}] rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} assert "language" in rsp @@ -18,7 +18,7 @@ async def test_aask_code(): @pytest.mark.asyncio async def test_aask_code_str(): - llm = OpenAIGPTAPI() + llm = OpenAILLM() msg = "Write a python hello world code." rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} assert "language" in rsp @@ -28,7 +28,7 @@ async def test_aask_code_str(): @pytest.mark.asyncio async def test_aask_code_Message(): - llm = OpenAIGPTAPI() + llm = OpenAILLM() msg = UserMessage("Write a python hello world code.") rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} assert "language" in rsp @@ -84,7 +84,7 @@ class TestOpenAI: ) def test_make_client_kwargs_without_proxy(self, config): - instance = OpenAIGPTAPI() + instance = OpenAILLM() instance.config = config kwargs, async_kwargs = instance._make_client_kwargs() assert kwargs == {"api_key": "test_key", "base_url": "test_url"} @@ -93,7 +93,7 @@ class TestOpenAI: assert "http_client" not in async_kwargs def test_make_client_kwargs_without_proxy_azure(self, config_azure): - instance = OpenAIGPTAPI() + instance = OpenAILLM() instance.config = config_azure kwargs, async_kwargs = instance._make_client_kwargs() assert kwargs == {"api_key": "test_key", "base_url": "test_url"} @@ -102,14 +102,14 @@ class TestOpenAI: assert "http_client" not in async_kwargs def test_make_client_kwargs_with_proxy(self, config_proxy): - instance = OpenAIGPTAPI() + instance = OpenAILLM() instance.config = config_proxy kwargs, async_kwargs = instance._make_client_kwargs() assert "http_client" in kwargs assert "http_client" in async_kwargs def test_make_client_kwargs_with_proxy_azure(self, config_azure_proxy): - instance = OpenAIGPTAPI() + instance = OpenAILLM() instance.config = config_azure_proxy kwargs, async_kwargs = instance._make_client_kwargs() assert "http_client" in kwargs diff --git a/tests/metagpt/provider/test_spark_api.py b/tests/metagpt/provider/test_spark_api.py index 61ae8cbec..6cc87741e 100644 --- a/tests/metagpt/provider/test_spark_api.py +++ b/tests/metagpt/provider/test_spark_api.py @@ -4,7 +4,7 @@ import pytest -from metagpt.provider.spark_api import SparkGPTAPI +from metagpt.provider.spark_api import SparkLLM prompt_msg = "who are you" resp_content = "I'm Spark" @@ -18,24 +18,13 @@ async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, return resp_content -def test_spark_completion(mocker): - mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.completion", mock_llm_completion) - spark_gpt = SparkGPTAPI() - - resp = spark_gpt.completion([]) - assert resp == resp_content - - resp = spark_gpt.ask(prompt_msg) - assert resp == resp_content - - @pytest.mark.asyncio async def test_spark_acompletion(mocker): mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion", mock_llm_acompletion) mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion_text", mock_llm_acompletion) - spark_gpt = SparkGPTAPI() + spark_gpt = SparkLLM() - resp = await spark_gpt.acompletion([], stream=False) + resp = await spark_gpt.acompletion([]) assert resp == resp_content resp = await spark_gpt.aask(prompt_msg, stream=False) diff --git a/tests/metagpt/provider/test_zhipuai_api.py b/tests/metagpt/provider/test_zhipuai_api.py index ec02e1b47..d9cd23281 100644 --- a/tests/metagpt/provider/test_zhipuai_api.py +++ b/tests/metagpt/provider/test_zhipuai_api.py @@ -28,18 +28,6 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: return resp_content -def test_zhipuai_completion(mocker): - mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.completion", mock_llm_completion) - zhipu_gpt = ZhiPuAIGPTAPI() - - resp = zhipu_gpt.completion(messages) - assert resp["code"] == 200 - assert resp["data"]["choices"][0]["content"] == resp_content - - resp = zhipu_gpt.ask(prompt_msg) - assert resp == resp_content - - @pytest.mark.asyncio async def test_zhipuai_acompletion(mocker): mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion", mock_llm_acompletion) diff --git a/tests/metagpt/test_gpt.py b/tests/metagpt/test_gpt.py index caa1eb277..2b19f173d 100644 --- a/tests/metagpt/test_gpt.py +++ b/tests/metagpt/test_gpt.py @@ -14,15 +14,6 @@ from metagpt.logs import logger @pytest.mark.usefixtures("llm_api") class TestGPT: - def test_llm_api_ask(self, llm_api): - answer = llm_api.ask("hello chatgpt") - logger.info(answer) - assert len(answer) > 0 - - def test_gptapi_ask_batch(self, llm_api): - answer = llm_api.ask_batch(["请扮演一个Google Python专家工程师,如果理解,回复明白", "写一个hello world"], timeout=60) - assert len(answer) > 0 - @pytest.mark.asyncio async def test_llm_api_aask(self, llm_api): answer = await llm_api.aask("hello chatgpt", stream=False) diff --git a/tests/metagpt/test_llm.py b/tests/metagpt/test_llm.py index bc685ed8b..247f043e2 100644 --- a/tests/metagpt/test_llm.py +++ b/tests/metagpt/test_llm.py @@ -9,7 +9,7 @@ import pytest -from metagpt.provider.openai_api import OpenAIGPTAPI as LLM +from metagpt.provider.openai_api import OpenAILLM as LLM @pytest.fixture()