diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index c8c901eb0..576990a83 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -14,7 +14,7 @@ from pydantic import BaseModel, Field from metagpt.actions.action_node import ActionNode from metagpt.llm import LLM -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.schema import ( CodeSummarizeContext, CodingContext, @@ -27,7 +27,7 @@ action_subclass_registry = {} class Action(BaseModel): name: str = "" - llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True) + llm: BaseLLM = Field(default_factory=LLM, exclude=True) context: Union[dict, CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext, str, None] = "" prefix = "" # aask*时会加上prefix,作为system_message desc = "" # for skill manager diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 63f46ad45..b554f15dd 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -15,7 +15,7 @@ from pydantic import BaseModel, create_model, root_validator, validator from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.config import CONFIG -from metagpt.llm import BaseGPTAPI +from metagpt.llm import BaseLLM from metagpt.logs import logger from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess from metagpt.utils.common import OutputParser, general_after_log @@ -60,7 +60,7 @@ class ActionNode: # Action Context context: str # all the context, including all necessary info - llm: BaseGPTAPI # LLM with aask interface + llm: BaseLLM # LLM with aask interface children: dict[str, "ActionNode"] # Action Input diff --git a/metagpt/actions/clone_function.py b/metagpt/actions/clone_function.py index 429f04286..7053df97b 100644 --- a/metagpt/actions/clone_function.py +++ b/metagpt/actions/clone_function.py @@ -5,7 +5,7 @@ from pydantic import Field from metagpt.actions.write_code import WriteCode from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Message from metagpt.utils.exceptions import handle_exception from metagpt.utils.highlight import highlight @@ -33,7 +33,7 @@ def run(*args) -> pd.DataFrame: class CloneFunction(WriteCode): name: str = "CloneFunction" context: list[Message] = [] - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) def _save(self, code_path, code): if isinstance(code_path, str): diff --git a/metagpt/actions/debug_error.py b/metagpt/actions/debug_error.py index 9dc6862f9..1a7c3a7c8 100644 --- a/metagpt/actions/debug_error.py +++ b/metagpt/actions/debug_error.py @@ -15,7 +15,7 @@ from pydantic import Field from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO -from metagpt.llm import LLM, BaseGPTAPI +from metagpt.llm import LLM, BaseLLM from metagpt.logs import logger from metagpt.schema import RunCodeContext, RunCodeResult from metagpt.utils.common import CodeParser @@ -52,7 +52,7 @@ Now you should start rewriting the code: class DebugError(Action): name: str = "DebugError" context: RunCodeContext = Field(default_factory=RunCodeContext) - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) async def run(self, *args, **kwargs) -> str: output_doc = await FileRepository.get_file( diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index 055365421..8535d63b1 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -27,7 +27,7 @@ from metagpt.const import ( ) from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Document, Documents, Message from metagpt.utils.file_repository import FileRepository from metagpt.utils.mermaid import mermaid_to_file @@ -44,7 +44,7 @@ NEW_REQ_TEMPLATE = """ class WriteDesign(Action): name: str = "" context: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) desc: str = ( "Based on the PRD, think about the system design, and design the corresponding APIs, " "data structures, library tables, processes, and paths. Please provide your design, feedback " diff --git a/metagpt/actions/design_api_review.py b/metagpt/actions/design_api_review.py index 0ff522fe8..6ea76e2fc 100644 --- a/metagpt/actions/design_api_review.py +++ b/metagpt/actions/design_api_review.py @@ -12,13 +12,13 @@ from pydantic import Field from metagpt.actions.action import Action from metagpt.llm import LLM -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM class DesignReview(Action): name: str = "DesignReview" context: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) async def run(self, prd, api_design): prompt = ( diff --git a/metagpt/actions/execute_task.py b/metagpt/actions/execute_task.py index b11f361b0..8577ee275 100644 --- a/metagpt/actions/execute_task.py +++ b/metagpt/actions/execute_task.py @@ -10,14 +10,14 @@ from pydantic import Field from metagpt.actions import Action from metagpt.llm import LLM -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Message class ExecuteTask(Action): name: str = "ExecuteTask" context: list[Message] = [] - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) async def run(self, *args, **kwargs): pass diff --git a/metagpt/actions/invoice_ocr.py b/metagpt/actions/invoice_ocr.py index 87f81371e..94288d5be 100644 --- a/metagpt/actions/invoice_ocr.py +++ b/metagpt/actions/invoice_ocr.py @@ -26,7 +26,7 @@ from metagpt.prompts.invoice_ocr import ( EXTRACT_OCR_MAIN_INFO_PROMPT, REPLY_OCR_QUESTION_PROMPT, ) -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.utils.common import OutputParser from metagpt.utils.file import File @@ -42,7 +42,7 @@ class InvoiceOCR(Action): name: str = "InvoiceOCR" context: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) @staticmethod async def _check_file_type(file_path: Path) -> str: @@ -132,7 +132,7 @@ class GenerateTable(Action): name: str = "GenerateTable" context: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) language: str = "ch" async def run(self, ocr_results: list, filename: str, *args, **kwargs) -> dict[str, str]: @@ -177,7 +177,7 @@ class ReplyQuestion(Action): name: str = "ReplyQuestion" context: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) language: str = "ch" async def run(self, query: str, ocr_result: list, *args, **kwargs) -> str: diff --git a/metagpt/actions/prepare_documents.py b/metagpt/actions/prepare_documents.py index 696dc9a89..ad82e56dc 100644 --- a/metagpt/actions/prepare_documents.py +++ b/metagpt/actions/prepare_documents.py @@ -17,7 +17,7 @@ from metagpt.actions import Action, ActionOutput from metagpt.config import CONFIG from metagpt.const import DOCS_FILE_REPO, REQUIREMENT_FILENAME from metagpt.llm import LLM -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Document from metagpt.utils.file_repository import FileRepository from metagpt.utils.git_repository import GitRepository @@ -28,7 +28,7 @@ class PrepareDocuments(Action): name: str = "PrepareDocuments" context: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) def _init_repo(self): """Initialize the Git environment.""" diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index 095881e60..7eda89130 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -27,7 +27,7 @@ from metagpt.const import ( ) from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Document, Documents from metagpt.utils.file_repository import FileRepository @@ -43,7 +43,7 @@ NEW_REQ_TEMPLATE = """ class WriteTasks(Action): name: str = "CreateTasks" context: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) async def run(self, with_messages, schema=CONFIG.prompt_schema): system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO) diff --git a/metagpt/actions/research.py b/metagpt/actions/research.py index c47a77bdd..a6cc7cc22 100644 --- a/metagpt/actions/research.py +++ b/metagpt/actions/research.py @@ -11,7 +11,7 @@ from metagpt.actions import Action from metagpt.config import CONFIG from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.tools.search_engine import SearchEngine from metagpt.tools.web_browser_engine import WebBrowserEngine, WebBrowserEngineType from metagpt.utils.common import OutputParser @@ -82,7 +82,7 @@ class CollectLinks(Action): name: str = "CollectLinks" context: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) desc: str = "Collect links from a search engine." search_engine: SearchEngine = Field(default_factory=SearchEngine) rank_func: Union[Callable[[list[str]], None], None] = None @@ -177,7 +177,7 @@ class WebBrowseAndSummarize(Action): name: str = "WebBrowseAndSummarize" context: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) desc: str = "Explore the web and provide summaries of articles and webpages." browse_func: Union[Callable[[list[str]], None], None] = None web_browser_engine: WebBrowserEngine = Field( @@ -248,7 +248,7 @@ class ConductResearch(Action): name: str = "ConductResearch" context: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) def __init__(self, **kwargs): super().__init__(**kwargs) diff --git a/metagpt/actions/run_code.py b/metagpt/actions/run_code.py index bca9b337d..22d345b85 100644 --- a/metagpt/actions/run_code.py +++ b/metagpt/actions/run_code.py @@ -22,7 +22,7 @@ from pydantic import Field from metagpt.actions.action import Action from metagpt.config import CONFIG -from metagpt.llm import LLM, BaseGPTAPI +from metagpt.llm import LLM, BaseLLM from metagpt.logs import logger from metagpt.schema import RunCodeContext, RunCodeResult from metagpt.utils.exceptions import handle_exception @@ -79,7 +79,7 @@ standard errors: class RunCode(Action): name: str = "RunCode" context: RunCodeContext = Field(default_factory=RunCodeContext) - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) @classmethod @handle_exception diff --git a/metagpt/actions/search_and_summarize.py b/metagpt/actions/search_and_summarize.py index 9fd392a5c..615576d76 100644 --- a/metagpt/actions/search_and_summarize.py +++ b/metagpt/actions/search_and_summarize.py @@ -14,7 +14,7 @@ from metagpt.actions import Action from metagpt.config import CONFIG, Config from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Message from metagpt.tools import SearchEngineType from metagpt.tools.search_engine import SearchEngine @@ -109,7 +109,7 @@ You are a member of a professional butler team and will provide helpful suggesti class SearchAndSummarize(Action): name: str = "" content: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) config: None = Field(default_factory=Config) engine: Optional[SearchEngineType] = CONFIG.search_engine search_func: Optional[Any] = None diff --git a/metagpt/actions/summarize_code.py b/metagpt/actions/summarize_code.py index 2d1cd4d3d..4025e0964 100644 --- a/metagpt/actions/summarize_code.py +++ b/metagpt/actions/summarize_code.py @@ -13,7 +13,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO -from metagpt.llm import LLM, BaseGPTAPI +from metagpt.llm import LLM, BaseLLM from metagpt.logs import logger from metagpt.schema import CodeSummarizeContext from metagpt.utils.file_repository import FileRepository @@ -95,7 +95,7 @@ flowchart TB class SummarizeCode(Action): name: str = "SummarizeCode" context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext) - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) @retry(stop=stop_after_attempt(2), wait=wait_random_exponential(min=1, max=60)) async def summarize_code(self, prompt): diff --git a/metagpt/actions/write_code.py b/metagpt/actions/write_code.py index 4d0690e0f..e3086f03c 100644 --- a/metagpt/actions/write_code.py +++ b/metagpt/actions/write_code.py @@ -31,7 +31,7 @@ from metagpt.const import ( ) from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.schema import CodingContext, Document, RunCodeResult from metagpt.utils.common import CodeParser from metagpt.utils.file_repository import FileRepository @@ -90,7 +90,7 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc class WriteCode(Action): name: str = "WriteCode" context: Document = Field(default_factory=Document) - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) async def write_code(self, prompt) -> str: diff --git a/metagpt/actions/write_code_review.py b/metagpt/actions/write_code_review.py index b0e7904e3..a8ed0fd01 100644 --- a/metagpt/actions/write_code_review.py +++ b/metagpt/actions/write_code_review.py @@ -16,7 +16,7 @@ from metagpt.actions.action import Action from metagpt.config import CONFIG from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.schema import CodingContext from metagpt.utils.common import CodeParser @@ -123,7 +123,7 @@ REWRITE_CODE_TEMPLATE = """ class WriteCodeReview(Action): name: str = "WriteCodeReview" context: CodingContext = Field(default_factory=CodingContext) - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) @retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6)) async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, filename): diff --git a/metagpt/actions/write_docstring.py b/metagpt/actions/write_docstring.py index 1c27a9433..68856c360 100644 --- a/metagpt/actions/write_docstring.py +++ b/metagpt/actions/write_docstring.py @@ -28,7 +28,7 @@ from pydantic import Field from metagpt.actions.action import Action from metagpt.llm import LLM -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.utils.common import OutputParser from metagpt.utils.pycst import merge_docstring @@ -163,7 +163,7 @@ class WriteDocstring(Action): desc: str = "Write docstring for code." context: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) async def run( self, diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index 47e02b699..5b1108244 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -38,7 +38,7 @@ from metagpt.const import ( ) from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.schema import BugFixContext, Document, Documents, Message from metagpt.utils.common import CodeParser from metagpt.utils.file_repository import FileRepository @@ -67,7 +67,7 @@ NEW_REQ_TEMPLATE = """ class WritePRD(Action): name: str = "" content: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) async def run(self, with_messages, schema=CONFIG.prompt_schema, *args, **kwargs) -> ActionOutput | Message: # Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are diff --git a/metagpt/actions/write_prd_review.py b/metagpt/actions/write_prd_review.py index 6ed73b6a2..0241f192f 100644 --- a/metagpt/actions/write_prd_review.py +++ b/metagpt/actions/write_prd_review.py @@ -12,13 +12,13 @@ from pydantic import Field from metagpt.actions.action import Action from metagpt.llm import LLM -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM class WritePRDReview(Action): name: str = "" context: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) prd: Optional[str] = None desc: str = "Based on the PRD, conduct a PRD Review, providing clear and detailed feedback" prd_review_prompt_template: str = """ diff --git a/metagpt/actions/write_review.py b/metagpt/actions/write_review.py index 646f44aeb..d116556ba 100644 --- a/metagpt/actions/write_review.py +++ b/metagpt/actions/write_review.py @@ -11,7 +11,7 @@ from pydantic import Field from metagpt.actions import Action from metagpt.actions.action_node import ActionNode from metagpt.llm import LLM -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM REVIEW = ActionNode( key="Review", @@ -38,7 +38,7 @@ class WriteReview(Action): """Write a review for the given context.""" name: str = "WriteReview" - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) async def run(self, context): return await WRITE_REVIEW_NODE.fill(context=context, llm=self.llm, schema="json") diff --git a/metagpt/actions/write_teaching_plan.py b/metagpt/actions/write_teaching_plan.py index d889fdbe3..888627294 100644 --- a/metagpt/actions/write_teaching_plan.py +++ b/metagpt/actions/write_teaching_plan.py @@ -13,14 +13,14 @@ from metagpt.actions import Action from metagpt.config import CONFIG from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM class WriteTeachingPlanPart(Action): """Write Teaching Plan Part""" context: Optional[str] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) topic: str = "" language: str = "Chinese" rsp: Optional[str] = None diff --git a/metagpt/actions/write_test.py b/metagpt/actions/write_test.py index 850606ca8..321d31420 100644 --- a/metagpt/actions/write_test.py +++ b/metagpt/actions/write_test.py @@ -17,7 +17,7 @@ from metagpt.config import CONFIG from metagpt.const import TEST_CODES_FILE_REPO from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Document, TestingContext from metagpt.utils.common import CodeParser @@ -45,7 +45,7 @@ you should correctly import the necessary classes based on these file locations! class WriteTest(Action): name: str = "WriteTest" context: Optional[TestingContext] = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) async def write_code(self, prompt): code_rsp = await self._aask(prompt) diff --git a/metagpt/actions/write_tutorial.py b/metagpt/actions/write_tutorial.py index f33a6b114..a2a324b41 100644 --- a/metagpt/actions/write_tutorial.py +++ b/metagpt/actions/write_tutorial.py @@ -14,7 +14,7 @@ from pydantic import Field from metagpt.actions import Action from metagpt.llm import LLM from metagpt.prompts.tutorial_assistant import CONTENT_PROMPT, DIRECTORY_PROMPT -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.utils.common import OutputParser @@ -27,7 +27,7 @@ class WriteDirectory(Action): """ name: str = "WriteDirectory" - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) language: str = "Chinese" async def run(self, topic: str, *args, **kwargs) -> Dict: @@ -54,7 +54,7 @@ class WriteContent(Action): """ name: str = "WriteContent" - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) directory: dict = dict() language: str = "Chinese" diff --git a/metagpt/llm.py b/metagpt/llm.py index f1cb98dae..76dd5a0f8 100644 --- a/metagpt/llm.py +++ b/metagpt/llm.py @@ -9,14 +9,14 @@ from typing import Optional from metagpt.config import CONFIG, LLMProviderEnum -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.provider.human_provider import HumanProvider from metagpt.provider.llm_provider_registry import LLM_REGISTRY _ = HumanProvider() # Avoid pre-commit error -def LLM(provider: Optional[LLMProviderEnum] = None) -> BaseGPTAPI: +def LLM(provider: Optional[LLMProviderEnum] = None) -> BaseLLM: """get the default llm provider""" if provider is None: provider = CONFIG.get_default_llm_provider_enum() diff --git a/metagpt/memory/brain_memory.py b/metagpt/memory/brain_memory.py index 8b47ba79a..0833d71a1 100644 --- a/metagpt/memory/brain_memory.py +++ b/metagpt/memory/brain_memory.py @@ -10,14 +10,15 @@ """ import json import re -from typing import Dict, List +from typing import Dict, List, Optional from pydantic import BaseModel, Field from metagpt.config import CONFIG -from metagpt.const import DEFAULT_LANGUAGE +from metagpt.const import DEFAULT_LANGUAGE, DEFAULT_MAX_TOKENS, DEFAULT_TOKEN_SIZE from metagpt.logs import logger from metagpt.provider import MetaGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Message, SimpleMessage from metagpt.utils.redis import Redis @@ -30,6 +31,7 @@ class BrainMemory(BaseModel): is_dirty: bool = False last_talk: str = None cacheable: bool = True + llm: Optional[BaseLLM] = None def add_talk(self, msg: Message): """ @@ -120,6 +122,7 @@ class BrainMemory(BaseModel): if isinstance(llm, MetaGPTAPI): return await self._metagpt_summarize(max_words=max_words) + self.llm = llm return await self._openai_summarize(llm=llm, max_words=max_words, keep_language=keep_language, limit=limit) async def _openai_summarize(self, llm, max_words=200, keep_language: bool = False, limit: int = -1): @@ -131,7 +134,7 @@ class BrainMemory(BaseModel): text_length = len(text) if limit > 0 and text_length < limit: return text - summary = await llm.summarize(text=text, max_words=max_words, keep_language=keep_language, limit=limit) + summary = await self._summarize(text=text, max_words=max_words, keep_language=keep_language, limit=limit) if summary: await self.set_history_summary(history_summary=summary, redis_key=CONFIG.REDIS_KEY, redis_conf=CONFIG.REDIS) return summary @@ -251,3 +254,74 @@ class BrainMemory(BaseModel): texts.append(t) return "\n".join(texts) + + async def _summarize(self, text: str, max_words=200, keep_language: bool = False, limit: int = -1) -> str: + max_token_count = DEFAULT_MAX_TOKENS + max_count = 100 + text_length = len(text) + if limit > 0 and text_length < limit: + return text + summary = "" + while max_count > 0: + if text_length < max_token_count: + summary = await self._get_summary(text=text, max_words=max_words, keep_language=keep_language) + break + + padding_size = 20 if max_token_count > 20 else 0 + text_windows = self.split_texts(text, window_size=max_token_count - padding_size) + part_max_words = min(int(max_words / len(text_windows)) + 1, 100) + summaries = [] + for ws in text_windows: + response = await self._get_summary(text=ws, max_words=part_max_words, keep_language=keep_language) + summaries.append(response) + if len(summaries) == 1: + summary = summaries[0] + break + + # Merged and retry + text = "\n".join(summaries) + text_length = len(text) + + max_count -= 1 # safeguard + return summary + + async def _get_summary(self, text: str, max_words=20, keep_language: bool = False): + """Generate text summary""" + if len(text) < max_words: + return text + if keep_language: + command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly." + else: + command = f"Translate the above content into a summary of less than {max_words} words." + msg = text + "\n\n" + command + logger.debug(f"summary ask:{msg}") + response = await self.llm.aask(msg=msg, system_msgs=[]) + logger.debug(f"summary rsp: {response}") + return response + + @staticmethod + def split_texts(text: str, window_size) -> List[str]: + """Splitting long text into sliding windows text""" + if window_size <= 0: + window_size = DEFAULT_TOKEN_SIZE + total_len = len(text) + if total_len <= window_size: + return [text] + + padding_size = 20 if window_size > 20 else 0 + windows = [] + idx = 0 + data_len = window_size - padding_size + while idx < total_len: + if window_size + idx > total_len: # 不足一个滑窗 + windows.append(text[idx:]) + break + # 每个窗口少算padding_size自然就可实现滑窗功能, 比如: [1, 2, 3, 4, 5, 6, 7, ....] + # window_size=3, padding_size=1: + # [1, 2, 3], [3, 4, 5], [5, 6, 7], .... + # idx=2, | idx=5 | idx=8 | ... + w = text[idx : idx + window_size] + windows.append(w) + idx += data_len + + return windows diff --git a/metagpt/provider/__init__.py b/metagpt/provider/__init__.py index 769c8e7b8..36d585c94 100644 --- a/metagpt/provider/__init__.py +++ b/metagpt/provider/__init__.py @@ -6,22 +6,22 @@ @File : __init__.py """ -from metagpt.provider.fireworks_api import FireWorksGPTAPI +from metagpt.provider.fireworks_api import FireworksLLM from metagpt.provider.google_gemini_api import GeminiGPTAPI -from metagpt.provider.ollama_api import OllamaGPTAPI +from metagpt.provider.ollama_api import OllamaLLM from metagpt.provider.open_llm_api import OpenLLMGPTAPI -from metagpt.provider.openai_api import OpenAIGPTAPI +from metagpt.provider.openai_api import OpenAILLM from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI -from metagpt.provider.azure_openai_api import AzureOpenAIGPTAPI +from metagpt.provider.azure_openai_api import AzureOpenAILLM from metagpt.provider.metagpt_api import MetaGPTAPI __all__ = [ - "FireWorksGPTAPI", + "FireworksLLM", "GeminiGPTAPI", "OpenLLMGPTAPI", - "OpenAIGPTAPI", + "OpenAILLM", "ZhiPuAIGPTAPI", - "AzureOpenAIGPTAPI", + "AzureOpenAILLM", "MetaGPTAPI", - "OllamaGPTAPI", + "OllamaLLM", ] diff --git a/metagpt/provider/azure_openai_api.py b/metagpt/provider/azure_openai_api.py index ca0696830..b59326c7f 100644 --- a/metagpt/provider/azure_openai_api.py +++ b/metagpt/provider/azure_openai_api.py @@ -10,60 +10,36 @@ """ -from openai import AsyncAzureOpenAI, AzureOpenAI -from openai._base_client import AsyncHttpxClientWrapper, SyncHttpxClientWrapper +from openai import AsyncAzureOpenAI +from openai._base_client import AsyncHttpxClientWrapper -from metagpt.config import CONFIG, Config, LLMProviderEnum +from metagpt.config import LLMProviderEnum from metagpt.provider.llm_provider_registry import register_provider -from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter +from metagpt.provider.openai_api import OpenAILLM @register_provider(LLMProviderEnum.AZURE_OPENAI) -class AzureOpenAIGPTAPI(OpenAIGPTAPI): +class AzureOpenAILLM(OpenAILLM): """ Check https://platform.openai.com/examples for examples """ - def __init__(self): - self.config: Config = CONFIG - self._init_openai() - self.auto_max_tokens = False - RateLimiter.__init__(self, rpm=self.rpm) - - def _make_client(self): - kwargs, async_kwargs = self._make_client_kwargs() + def _init_client(self): + kwargs = self._make_client_kwargs() # https://learn.microsoft.com/zh-cn/azure/ai-services/openai/how-to/migration?tabs=python-new%2Cdalle-fix - self.client = AzureOpenAI(**kwargs) - self.async_client = AsyncAzureOpenAI(**async_kwargs) + self.async_client = AsyncAzureOpenAI(**kwargs) self.model = self.config.DEPLOYMENT_NAME # Used in _calc_usage & _cons_kwargs - def _make_client_kwargs(self) -> (dict, dict): + def _make_client_kwargs(self) -> dict: kwargs = dict( api_key=self.config.OPENAI_API_KEY, api_version=self.config.OPENAI_API_VERSION, azure_endpoint=self.config.OPENAI_BASE_URL, ) - async_kwargs = kwargs.copy() # to use proxy, openai v1 needs http_client proxy_params = self._get_proxy_params() if proxy_params: - kwargs["http_client"] = SyncHttpxClientWrapper(**proxy_params) - async_kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params) - - return kwargs, async_kwargs - - def _cons_kwargs(self, messages: list[dict], timeout=3, **configs) -> dict: - kwargs = { - "messages": messages, - "max_tokens": self.get_max_tokens(messages), - "n": 1, - "stop": None, - "temperature": 0.3, - "model": self.model, - } - if configs: - kwargs.update(configs) - kwargs["timeout"] = max(CONFIG.timeout, timeout) + kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params) return kwargs diff --git a/metagpt/provider/base_chatbot.py b/metagpt/provider/base_chatbot.py deleted file mode 100644 index 535130de7..000000000 --- a/metagpt/provider/base_chatbot.py +++ /dev/null @@ -1,30 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/5/5 23:00 -@Author : alexanderwu -@File : base_chatbot.py -@Modified By: mashenquan, 2023/11/21. Add `timeout`. -""" -from abc import ABC, abstractmethod -from dataclasses import dataclass - - -@dataclass -class BaseChatbot(ABC): - """Abstract GPT class""" - - mode: str = "API" - use_system_prompt: bool = True - - @abstractmethod - def ask(self, msg: str, timeout=3) -> str: - """Ask GPT a question and get an answer""" - - @abstractmethod - def ask_batch(self, msgs: list, timeout=3) -> str: - """Ask GPT multiple questions and get a series of answers""" - - @abstractmethod - def ask_code(self, msgs: list, timeout=3) -> str: - """Ask GPT multiple questions and get a piece of code""" diff --git a/metagpt/provider/base_gpt_api.py b/metagpt/provider/base_llm.py similarity index 68% rename from metagpt/provider/base_gpt_api.py rename to metagpt/provider/base_llm.py index a5541324f..4d00adbc7 100644 --- a/metagpt/provider/base_gpt_api.py +++ b/metagpt/provider/base_llm.py @@ -3,19 +3,18 @@ """ @Time : 2023/5/5 23:04 @Author : alexanderwu -@File : base_gpt_api.py +@File : base_llm.py @Desc : mashenquan, 2023/8/22. + try catch """ import json -from abc import abstractmethod +from abc import ABC, abstractmethod from typing import Optional -from metagpt.provider.base_chatbot import BaseChatbot +class BaseLLM(ABC): + """LLM API abstract class, requiring all inheritors to provide a series of standard capabilities""" -class BaseGPTAPI(BaseChatbot): - """GPT API abstract class, requiring all inheritors to provide a series of standard capabilities""" - + use_system_prompt: bool = True system_prompt = "You are a helpful assistant." def _user_msg(self, msg: str) -> dict[str, str]: @@ -33,17 +32,11 @@ class BaseGPTAPI(BaseChatbot): def _default_system_msg(self): return self._system_msg(self.system_prompt) - def ask(self, msg: str, timeout=3) -> str: - message = [self._default_system_msg(), self._user_msg(msg)] if self.use_system_prompt else [self._user_msg(msg)] - rsp = self.completion(message, timeout=timeout) - return self.get_choice_text(rsp) - async def aask( self, msg: str, system_msgs: Optional[list[str]] = None, format_msgs: Optional[list[dict[str, str]]] = None, - generator: bool = False, timeout=3, stream=True, ) -> str: @@ -54,23 +47,12 @@ class BaseGPTAPI(BaseChatbot): if format_msgs: message.extend(format_msgs) message.append(self._user_msg(msg)) - rsp = await self.acompletion_text(message, stream=stream, generator=generator, timeout=timeout) - # logger.debug(rsp) + rsp = await self.acompletion_text(message, stream=stream, timeout=timeout) return rsp def _extract_assistant_rsp(self, context): return "\n".join([i["content"] for i in context if i["role"] == "assistant"]) - def ask_batch(self, msgs: list, timeout=3) -> str: - context = [] - for msg in msgs: - umsg = self._user_msg(msg) - context.append(umsg) - rsp = self.completion(context, timeout=timeout) - rsp_text = self.get_choice_text(rsp) - context.append(self._assistant_msg(rsp_text)) - return self._extract_assistant_rsp(context) - async def aask_batch(self, msgs: list, timeout=3) -> str: """Sequential questioning""" context = [] @@ -81,26 +63,11 @@ class BaseGPTAPI(BaseChatbot): context.append(self._assistant_msg(rsp_text)) return self._extract_assistant_rsp(context) - def ask_code(self, msgs: list[str], timeout=3) -> str: - """FIXME: No code segment filtering has been done here, and all results are actually displayed""" - rsp_text = self.ask_batch(msgs, timeout=timeout) - return rsp_text - async def aask_code(self, msgs: list[str], timeout=3) -> str: """FIXME: No code segment filtering has been done here, and all results are actually displayed""" rsp_text = await self.aask_batch(msgs, timeout=timeout) return rsp_text - @abstractmethod - def completion(self, messages: list[dict], timeout=3): - """All GPTAPIs are required to provide the standard OpenAI completion interface - [ - {"role": "system", "content": "You are a helpful assistant."}, - {"role": "user", "content": "hello, show me python hello world code"}, - # {"role": "assistant", "content": ...}, # If there is an answer in the history, also include it - ] - """ - @abstractmethod async def acompletion(self, messages: list[dict], timeout=3): """Asynchronous version of completion @@ -113,7 +80,7 @@ class BaseGPTAPI(BaseChatbot): """ @abstractmethod - async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str: + async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str: """Asynchronous version of completion. Return str. Support stream-print""" def get_choice_text(self, rsp: dict) -> str: @@ -159,16 +126,3 @@ class BaseGPTAPI(BaseChatbot): {'language': 'python', 'code': "print('Hello, World!')"} """ return json.loads(self.get_choice_function(rsp)["arguments"]) - - def messages_to_prompt(self, messages: list[dict]): - """[{"role": "user", "content": msg}] to user: etc.""" - return "\n".join([f"{i.role}: {i.content}" for i in messages]) - - def messages_to_dict(self, messages): - """objects to [{"role": "user", "content": msg}] etc.""" - return [i.to_dict() for i in messages] - - @abstractmethod - async def close(self): - """Close connection""" - pass diff --git a/metagpt/provider/fireworks_api.py b/metagpt/provider/fireworks_api.py index 55b1b6c28..5fe86fc1c 100644 --- a/metagpt/provider/fireworks_api.py +++ b/metagpt/provider/fireworks_api.py @@ -18,7 +18,7 @@ from tenacity import ( from metagpt.config import CONFIG, Config, LLMProviderEnum from metagpt.logs import logger from metagpt.provider.llm_provider_registry import register_provider -from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter, log_and_reraise +from metagpt.provider.openai_api import OpenAILLM, log_and_reraise from metagpt.utils.cost_manager import CostManager, Costs MODEL_GRADE_TOKEN_COSTS = { @@ -72,18 +72,17 @@ class FireworksCostManager(CostManager): @register_provider(LLMProviderEnum.FIREWORKS) -class FireWorksGPTAPI(OpenAIGPTAPI): +class FireworksLLM(OpenAILLM): def __init__(self): self.config: Config = CONFIG self.__init_fireworks() self.auto_max_tokens = False self._cost_manager = FireworksCostManager() - RateLimiter.__init__(self, rpm=self.rpm) def __init_fireworks(self): self.is_azure = False self.rpm = int(self.config.get("RPM", 10)) - self._make_client() + self._init_client() self.model = self.config.fireworks_api_model # `self.model` should after `_make_client` to rewrite it def _make_client_kwargs(self) -> (dict, dict): @@ -103,7 +102,7 @@ class FireWorksGPTAPI(OpenAIGPTAPI): return self._cost_manager.get_costs() async def _achat_completion_stream(self, messages: list[dict]) -> str: - response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create( + response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create( **self._cons_kwargs(messages), stream=True ) @@ -133,9 +132,7 @@ class FireWorksGPTAPI(OpenAIGPTAPI): retry=retry_if_exception_type(APIConnectionError), retry_error_callback=log_and_reraise, ) - async def acompletion_text( - self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3 - ) -> str: + async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str: """when streaming, print each token in place.""" if stream: return await self._achat_completion_stream(messages) diff --git a/metagpt/provider/general_api_base.py b/metagpt/provider/general_api_base.py index 015e34aeb..814be2f67 100644 --- a/metagpt/provider/general_api_base.py +++ b/metagpt/provider/general_api_base.py @@ -47,8 +47,7 @@ MAX_CONNECTION_RETRIES = 2 # Has one attribute per thread, 'session'. _thread_context = threading.local() -LLM_LOG = os.environ.get("LLM_LOG") -LLM_LOG = "debug" +LLM_LOG = os.environ.get("LLM_LOG", "debug") class ApiType(Enum): diff --git a/metagpt/provider/google_gemini_api.py b/metagpt/provider/google_gemini_api.py index eace329aa..5683095c7 100644 --- a/metagpt/provider/google_gemini_api.py +++ b/metagpt/provider/google_gemini_api.py @@ -21,7 +21,7 @@ from tenacity import ( from metagpt.config import CONFIG, LLMProviderEnum from metagpt.logs import log_llm_stream, logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import log_and_reraise @@ -42,7 +42,7 @@ class GeminiGenerativeModel(GenerativeModel): @register_provider(LLMProviderEnum.GEMINI) -class GeminiGPTAPI(BaseGPTAPI): +class GeminiGPTAPI(BaseLLM): """ Refs to `https://ai.google.dev/tutorials/python_quickstart` """ @@ -136,9 +136,7 @@ class GeminiGPTAPI(BaseGPTAPI): retry=retry_if_exception_type(ConnectionError), retry_error_callback=log_and_reraise, ) - async def acompletion_text( - self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3 - ) -> str: + async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str: """response in async with stream or non-stream mode""" if stream: return await self._achat_completion_stream(messages) diff --git a/metagpt/provider/human_provider.py b/metagpt/provider/human_provider.py index 5850dd8dc..59d236a3a 100644 --- a/metagpt/provider/human_provider.py +++ b/metagpt/provider/human_provider.py @@ -6,10 +6,10 @@ Author: garylin2099 from typing import Optional from metagpt.logs import logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM -class HumanProvider(BaseGPTAPI): +class HumanProvider(BaseLLM): """Humans provide themselves as a 'model', which actually takes in human input as its response. This enables replacing LLM anywhere in the framework with a human, thus introducing human interaction """ @@ -31,10 +31,6 @@ class HumanProvider(BaseGPTAPI): ) -> str: return self.ask(msg, timeout=timeout) - def completion(self, messages: list[dict], timeout=3): - """dummy implementation of abstract method in base""" - return [] - async def acompletion(self, messages: list[dict], timeout=3): """dummy implementation of abstract method in base""" return [] @@ -42,7 +38,3 @@ class HumanProvider(BaseGPTAPI): async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str: """dummy implementation of abstract method in base""" return "" - - async def close(self): - """Close connection""" - pass diff --git a/metagpt/provider/metagpt_api.py b/metagpt/provider/metagpt_api.py index 7bc48b7ad..2b7629895 100644 --- a/metagpt/provider/metagpt_api.py +++ b/metagpt/provider/metagpt_api.py @@ -6,11 +6,11 @@ @Desc : MetaGPT LLM provider. """ from metagpt.config import LLMProviderEnum -from metagpt.provider import OpenAIGPTAPI +from metagpt.provider import OpenAILLM from metagpt.provider.llm_provider_registry import register_provider @register_provider(LLMProviderEnum.METAGPT) -class MetaGPTAPI(OpenAIGPTAPI): +class MetaGPTAPI(OpenAILLM): def __init__(self): super().__init__() diff --git a/metagpt/provider/ollama_api.py b/metagpt/provider/ollama_api.py index 90a50a154..95b944bf3 100644 --- a/metagpt/provider/ollama_api.py +++ b/metagpt/provider/ollama_api.py @@ -16,7 +16,7 @@ from tenacity import ( from metagpt.config import CONFIG, LLMProviderEnum from metagpt.const import LLM_API_TIMEOUT from metagpt.logs import log_llm_stream, logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.provider.general_api_requestor import GeneralAPIRequestor from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import log_and_reraise @@ -39,7 +39,7 @@ class OllamaCostManager(CostManager): @register_provider(LLMProviderEnum.OLLAMA) -class OllamaGPTAPI(BaseGPTAPI): +class OllamaLLM(BaseLLM): """ Refs to `https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-chat-completion` """ @@ -54,12 +54,8 @@ class OllamaGPTAPI(BaseGPTAPI): def __init_ollama(self, config: CONFIG): assert config.ollama_api_base - self.model = config.ollama_api_model - def close(self): - pass - def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict: kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream} return kwargs @@ -87,18 +83,6 @@ class OllamaGPTAPI(BaseGPTAPI): chunk = chunk.decode(encoding) return json.loads(chunk) - def completion(self, messages: list[dict]) -> dict: - resp, _, _ = self.client.request( - method=self.http_method, - url=self.suffix_url, - params=self._const_kwargs(messages), - request_timeout=LLM_API_TIMEOUT, - ) - resp = self._decode_and_load(resp) - usage = self.get_usage(resp) - self._update_costs(usage) - return resp - async def _achat_completion(self, messages: list[dict]) -> dict: resp, _, _ = await self.client.arequest( method=self.http_method, @@ -111,7 +95,7 @@ class OllamaGPTAPI(BaseGPTAPI): self._update_costs(usage) return resp - async def acompletion(self, messages: list[dict]) -> dict: + async def acompletion(self, messages: list[dict], timeout=3) -> dict: return await self._achat_completion(messages) async def _achat_completion_stream(self, messages: list[dict]) -> str: @@ -147,9 +131,7 @@ class OllamaGPTAPI(BaseGPTAPI): retry=retry_if_exception_type(ConnectionError), retry_error_callback=log_and_reraise, ) - async def acompletion_text( - self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3 - ) -> str: + async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str: """response in async with stream or non-stream mode""" if stream: return await self._achat_completion_stream(messages) diff --git a/metagpt/provider/open_llm_api.py b/metagpt/provider/open_llm_api.py index dd1491780..2893f5b30 100644 --- a/metagpt/provider/open_llm_api.py +++ b/metagpt/provider/open_llm_api.py @@ -7,7 +7,7 @@ from openai.types import CompletionUsage from metagpt.config import CONFIG, Config, LLMProviderEnum from metagpt.logs import logger from metagpt.provider.llm_provider_registry import register_provider -from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter +from metagpt.provider.openai_api import OpenAILLM from metagpt.utils.cost_manager import CostManager, Costs from metagpt.utils.token_counter import count_message_tokens, count_string_tokens @@ -35,18 +35,17 @@ class OpenLLMCostManager(CostManager): @register_provider(LLMProviderEnum.OPEN_LLM) -class OpenLLMGPTAPI(OpenAIGPTAPI): +class OpenLLMGPTAPI(OpenAILLM): def __init__(self): self.config: Config = CONFIG self.__init_openllm() self.auto_max_tokens = False self._cost_manager = OpenLLMCostManager() - RateLimiter.__init__(self, rpm=self.rpm) def __init_openllm(self): self.is_azure = False self.rpm = int(self.config.get("RPM", 10)) - self._make_client() + self._init_client() self.model = self.config.open_llm_api_model # `self.model` should after `_make_client` to rewrite it def _make_client_kwargs(self) -> (dict, dict): diff --git a/metagpt/provider/openai_api.py b/metagpt/provider/openai_api.py index 195d2ea16..64adbb1c0 100644 --- a/metagpt/provider/openai_api.py +++ b/metagpt/provider/openai_api.py @@ -3,20 +3,17 @@ @Time : 2023/5/5 23:08 @Author : alexanderwu @File : openai.py -@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation; +@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for isolation; Change cost control from global to company level. @Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout. @Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x. """ -import asyncio import json -import time -from typing import List, Union +from typing import AsyncIterator, Union -import openai -from openai import APIConnectionError, AsyncOpenAI, AsyncStream, OpenAI -from openai._base_client import AsyncHttpxClientWrapper, SyncHttpxClientWrapper +from openai import APIConnectionError, AsyncOpenAI, AsyncStream +from openai._base_client import AsyncHttpxClientWrapper from openai.types import CompletionUsage from openai.types.chat import ChatCompletion, ChatCompletionChunk from tenacity import ( @@ -28,9 +25,8 @@ from tenacity import ( ) from metagpt.config import CONFIG, Config, LLMProviderEnum -from metagpt.const import DEFAULT_MAX_TOKENS, DEFAULT_TOKEN_SIZE from metagpt.logs import log_llm_stream, logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE from metagpt.provider.llm_provider_registry import register_provider from metagpt.schema import Message @@ -43,31 +39,6 @@ from metagpt.utils.token_counter import ( ) -class RateLimiter: - """Rate control class, each call goes through wait_if_needed, sleep if rate control is needed""" - - def __init__(self, rpm): - self.last_call_time = 0 - # Here 1.1 is used because even if the calls are made strictly according to time, - # they will still be QOS'd; consider switching to simple error retry later - self.interval = 1.1 * 60 / rpm - self.rpm = rpm - - def split_batches(self, batch): - return [batch[i : i + self.rpm] for i in range(0, len(batch), self.rpm)] - - async def wait_if_needed(self, num_requests): - current_time = time.time() - elapsed_time = current_time - self.last_call_time - - if elapsed_time < self.interval * num_requests: - remaining_time = self.interval * num_requests - elapsed_time - logger.info(f"sleep {remaining_time}") - await asyncio.sleep(remaining_time) - - self.last_call_time = time.time() - - def log_and_reraise(retry_state): logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}") logger.warning( @@ -80,39 +51,31 @@ See FAQ 5.8 @register_provider(LLMProviderEnum.OPENAI) -class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): - """ - Check https://platform.openai.com/examples for examples - """ +class OpenAILLM(BaseLLM): + """Check https://platform.openai.com/examples for examples""" def __init__(self): self.config: Config = CONFIG self._init_openai() + self._init_client() self.auto_max_tokens = False - RateLimiter.__init__(self, rpm=self.rpm) def _init_openai(self): - self.rpm = int(self.config.RPM or 10) - self._make_client() - - def _make_client(self): - kwargs, async_kwargs = self._make_client_kwargs() - # https://github.com/openai/openai-python#async-usage - self.client = OpenAI(**kwargs) - self.async_client = AsyncOpenAI(**async_kwargs) self.model = self.config.OPENAI_API_MODEL # Used in _calc_usage & _cons_kwargs - def _make_client_kwargs(self) -> (dict, dict): - kwargs = dict(api_key=self.config.OPENAI_API_KEY, base_url=self.config.OPENAI_BASE_URL) - async_kwargs = kwargs.copy() + def _init_client(self): + """https://github.com/openai/openai-python#async-usage""" + kwargs = self._make_client_kwargs() + self.aclient = AsyncOpenAI(**kwargs) + + def _make_client_kwargs(self) -> dict: + kwargs = {"api_key": self.config.OPENAI_API_KEY, "base_url": self.config.OPENAI_BASE_URL} # to use proxy, openai v1 needs http_client - proxy_params = self._get_proxy_params() - if proxy_params: - kwargs["http_client"] = SyncHttpxClientWrapper(**proxy_params) - async_kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params) + if proxy_params := self._get_proxy_params(): + kwargs["http_client"] = AsyncHttpxClientWrapper(**proxy_params) - return kwargs, async_kwargs + return kwargs def _get_proxy_params(self) -> dict: params = {} @@ -123,8 +86,8 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): return params - async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> str: - response: AsyncStream[ChatCompletionChunk] = await self.async_client.chat.completions.create( + async def _achat_completion_stream(self, messages: list[dict], timeout=3) -> AsyncIterator[str]: + response: AsyncStream[ChatCompletionChunk] = await self.aclient.chat.completions.create( **self._cons_kwargs(messages, timeout=timeout), stream=True ) @@ -132,35 +95,26 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): chunk_message = chunk.choices[0].delta.content or "" if chunk.choices else "" # extract the message yield chunk_message - def _cons_kwargs(self, messages: list[dict], timeout=3, **configs) -> dict: + def _cons_kwargs(self, messages: list[dict], timeout=3, **extra_kwargs) -> dict: kwargs = { "messages": messages, - "max_tokens": self.get_max_tokens(messages), + "max_tokens": self._get_max_tokens(messages), "n": 1, "stop": None, "temperature": 0.3, "model": self.model, + "timeout": max(CONFIG.timeout, timeout), } - if configs: - kwargs.update(configs) - kwargs["timeout"] = max(CONFIG.timeout, timeout) - + if extra_kwargs: + kwargs.update(extra_kwargs) return kwargs async def _achat_completion(self, messages: list[dict], timeout=3) -> ChatCompletion: kwargs = self._cons_kwargs(messages, timeout=timeout) - rsp: ChatCompletion = await self.async_client.chat.completions.create(**kwargs) + rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs) self._update_costs(rsp.usage) return rsp - def _chat_completion(self, messages: list[dict], timeout=3) -> ChatCompletion: - rsp: ChatCompletion = self.client.chat.completions.create(**self._cons_kwargs(messages, timeout=timeout)) - self._update_costs(rsp.usage) - return rsp - - def completion(self, messages: list[dict], timeout=3) -> ChatCompletion: - return self._chat_completion(messages, timeout=timeout) - async def acompletion(self, messages: list[dict], timeout=3) -> ChatCompletion: return await self._achat_completion(messages, timeout=timeout) @@ -171,12 +125,10 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): retry=retry_if_exception_type(APIConnectionError), retry_error_callback=log_and_reraise, ) - async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str: + async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str: """when streaming, print each token in place.""" if stream: resp = self._achat_completion_stream(messages, timeout=timeout) - if generator: - return resp collected_messages = [] async for i in resp: @@ -192,9 +144,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): return self.get_choice_text(rsp) def _func_configs(self, messages: list[dict], timeout=3, **kwargs) -> dict: - """ - Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create - """ + """Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create""" if "tools" not in kwargs: configs = { "tools": [{"type": "function", "function": GENERAL_FUNCTION_SCHEMA}], @@ -204,14 +154,9 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): return self._cons_kwargs(messages=messages, timeout=timeout, **kwargs) - def _chat_completion_function(self, messages: list[dict], timeout=3, **kwargs) -> ChatCompletion: - rsp: ChatCompletion = self.client.chat.completions.create(**self._func_configs(messages, **kwargs)) - self._update_costs(rsp.usage) - return rsp - async def _achat_completion_function(self, messages: list[dict], timeout=3, **chat_configs) -> ChatCompletion: kwargs = self._func_configs(messages=messages, timeout=timeout, **chat_configs) - rsp: ChatCompletion = await self.async_client.chat.completions.create(**kwargs) + rsp: ChatCompletion = await self.aclient.chat.completions.create(**kwargs) self._update_costs(rsp.usage) return rsp @@ -231,56 +176,28 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): ) return messages - def ask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict: - """Use function of tools to ask a code. - - Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create - - Examples: - - >>> llm = OpenAIGPTAPI() - >>> llm.ask_code("Write a python hello world code.") - {'language': 'python', 'code': "print('Hello, World!')"} - >>> msg = [{'role': 'user', 'content': "Write a python hello world code."}] - >>> llm.ask_code(msg) - {'language': 'python', 'code': "print('Hello, World!')"} - """ - messages = self._process_message(messages) - rsp = self._chat_completion_function(messages, **kwargs) - return self.get_choice_function_arguments(rsp) - async def aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict: """Use function of tools to ask a code. - - Note: Keep kwargs consistent with the parameters in the https://platform.openai.com/docs/api-reference/chat/create + Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create Examples: - - >>> llm = OpenAIGPTAPI() - >>> rsp = await llm.ask_code("Write a python hello world code.") - >>> rsp - {'language': 'python', 'code': "print('Hello, World!')"} + >>> llm = OpenAILLM() >>> msg = [{'role': 'user', 'content': "Write a python hello world code."}] - >>> rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} + >>> rsp = await llm.aask_code(msg) + # -> {'language': 'python', 'code': "print('Hello, World!')"} """ messages = self._process_message(messages) - try: - rsp = await self._achat_completion_function(messages, **kwargs) - return self.get_choice_function_arguments(rsp) - except openai.BadRequestError as e: - logger.error(f"API TYPE:{CONFIG.OPENAI_API_TYPE}, err:{e}") - raise e + rsp = await self._achat_completion_function(messages, **kwargs) + return self.get_choice_function_arguments(rsp) + @handle_exception def get_choice_function_arguments(self, rsp: ChatCompletion) -> dict: """Required to provide the first function arguments of choice. :return dict: return the first function arguments of choice, for example, {'language': 'python', 'code': "print('Hello, World!')"} """ - try: - return json.loads(rsp.choices[0].message.tool_calls[0].function.arguments) - except json.JSONDecodeError: - return {} + return json.loads(rsp.choices[0].message.tool_calls[0].function.arguments) def get_choice_text(self, rsp: ChatCompletion) -> str: """Required to provide the first text of choice""" @@ -295,134 +212,24 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter): usage.prompt_tokens = count_message_tokens(messages, self.model) usage.completion_tokens = count_string_tokens(rsp, self.model) except Exception as e: - logger.error(f"usage calculation failed!: {e}") + logger.error(f"usage calculation failed: {e}") return usage - async def acompletion_batch(self, batch: list[list[dict]], timeout=3) -> list[ChatCompletion]: - """Return full JSON""" - split_batches = self.split_batches(batch) - all_results = [] - - for small_batch in split_batches: - logger.info(small_batch) - await self.wait_if_needed(len(small_batch)) - - future = [self.acompletion(prompt, timeout=timeout) for prompt in small_batch] - results = await asyncio.gather(*future) - logger.info(results) - all_results.extend(results) - - return all_results - - async def acompletion_batch_text(self, batch: list[list[dict]], timeout=3) -> list[str]: - """Only return plain text""" - raw_results = await self.acompletion_batch(batch, timeout=timeout) - results = [] - for idx, raw_result in enumerate(raw_results, start=1): - result = self.get_choice_text(raw_result) - results.append(result) - logger.info(f"Result of task {idx}: {result}") - return results - + @handle_exception def _update_costs(self, usage: CompletionUsage): if CONFIG.calc_usage and usage: - try: - CONFIG.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model) - except Exception as e: - logger.error(f"updating costs failed!, exp: {e}") + CONFIG.cost_manager.update_cost(usage.prompt_tokens, usage.completion_tokens, self.model) def get_costs(self) -> Costs: return CONFIG.cost_manager.get_costs() - def get_max_tokens(self, messages: list[dict]): + def _get_max_tokens(self, messages: list[dict]): if not self.auto_max_tokens: return CONFIG.max_tokens_rsp return get_max_completion_tokens(messages, self.model, CONFIG.max_tokens_rsp) - def moderation(self, content: Union[str, list[str]]): - return self.client.moderations.create(input=content) - @handle_exception async def amoderation(self, content: Union[str, list[str]]): - return await self.async_client.moderations.create(input=content) - - async def close(self): - """Close connection""" - if self.client: - self.client.close() - self.client = None - if self.async_client: - await self.async_client.close() - self.async_client = None - - async def summarize(self, text: str, max_words=200, keep_language: bool = False, limit: int = -1) -> str: - max_token_count = DEFAULT_MAX_TOKENS - max_count = 100 - text_length = len(text) - if limit > 0 and text_length < limit: - return text - summary = "" - while max_count > 0: - if text_length < max_token_count: - summary = await self._get_summary(text=text, max_words=max_words, keep_language=keep_language) - break - - padding_size = 20 if max_token_count > 20 else 0 - text_windows = self.split_texts(text, window_size=max_token_count - padding_size) - part_max_words = min(int(max_words / len(text_windows)) + 1, 100) - summaries = [] - for ws in text_windows: - response = await self._get_summary(text=ws, max_words=part_max_words, keep_language=keep_language) - summaries.append(response) - if len(summaries) == 1: - summary = summaries[0] - break - - # Merged and retry - text = "\n".join(summaries) - text_length = len(text) - - max_count -= 1 # safeguard - return summary - - async def _get_summary(self, text: str, max_words=20, keep_language: bool = False): - """Generate text summary""" - if len(text) < max_words: - return text - if keep_language: - command = f".Translate the above content into a summary of less than {max_words} words in language of the content strictly." - else: - command = f"Translate the above content into a summary of less than {max_words} words." - msg = text + "\n\n" + command - logger.debug(f"summary ask:{msg}") - response = await self.aask(msg=msg, system_msgs=[]) - logger.debug(f"summary rsp: {response}") - return response - - @staticmethod - def split_texts(text: str, window_size) -> List[str]: - """Splitting long text into sliding windows text""" - if window_size <= 0: - window_size = DEFAULT_TOKEN_SIZE - total_len = len(text) - if total_len <= window_size: - return [text] - - padding_size = 20 if window_size > 20 else 0 - windows = [] - idx = 0 - data_len = window_size - padding_size - while idx < total_len: - if window_size + idx > total_len: # 不足一个滑窗 - windows.append(text[idx:]) - break - # 每个窗口少算padding_size自然就可实现滑窗功能, 比如: [1, 2, 3, 4, 5, 6, 7, ....] - # window_size=3, padding_size=1: - # [1, 2, 3], [3, 4, 5], [5, 6, 7], .... - # idx=2, | idx=5 | idx=8 | ... - w = text[idx : idx + window_size] - windows.append(w) - idx += data_len - - return windows + """Moderate content.""" + return await self.aclient.moderations.create(input=content) diff --git a/metagpt/provider/spark_api.py b/metagpt/provider/spark_api.py index 70076bc86..ce889529a 100644 --- a/metagpt/provider/spark_api.py +++ b/metagpt/provider/spark_api.py @@ -1,9 +1,7 @@ #!/usr/bin/env python # -*- coding: utf-8 -*- """ -@Time : 2023/7/21 11:15 -@Author : Leo Xiao -@File : anthropic_api.py +@File : spark_api.py """ import _thread as thread import base64 @@ -13,7 +11,6 @@ import hmac import json import ssl from time import mktime -from typing import Optional from urllib.parse import urlencode, urlparse from wsgiref.handlers import format_date_time @@ -21,52 +18,29 @@ import websocket # 使用websocket_client from metagpt.config import CONFIG, LLMProviderEnum from metagpt.logs import logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.provider.llm_provider_registry import register_provider @register_provider(LLMProviderEnum.SPARK) -class SparkGPTAPI(BaseGPTAPI): +class SparkLLM(BaseLLM): def __init__(self): logger.warning("当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。") - def close(self): - pass - - def ask(self, msg: str) -> str: - message = [self._default_system_msg(), self._user_msg(msg)] - rsp = self.completion(message) - return rsp - - async def aask(self, msg: str, system_msgs: Optional[list[str]] = None, stream: bool = True) -> str: - if system_msgs: - message = self._system_msgs(system_msgs) + [self._user_msg(msg)] - else: - message = [self._default_system_msg(), self._user_msg(msg)] - rsp = await self.acompletion(message) - logger.debug(message) - return rsp - def get_choice_text(self, rsp: dict) -> str: return rsp["payload"]["choices"]["text"][-1]["content"] - async def acompletion_text( - self, messages: list[dict], stream=False, generator: bool = False, timeout: int = 3 - ) -> str: + async def acompletion_text(self, messages: list[dict], stream=False, timeout: int = 3) -> str: # 不支持 logger.error("该功能禁用。") w = GetMessageFromWeb(messages) return w.run() - async def acompletion(self, messages: list[dict]): + async def acompletion(self, messages: list[dict], timeout=3): # 不支持异步 w = GetMessageFromWeb(messages) return w.run() - def completion(self, messages: list[dict]): - w = GetMessageFromWeb(messages) - return w.run() - class GetMessageFromWeb: class WsParam: diff --git a/metagpt/provider/zhipuai_api.py b/metagpt/provider/zhipuai_api.py index 8d57cd444..e4b066a0c 100644 --- a/metagpt/provider/zhipuai_api.py +++ b/metagpt/provider/zhipuai_api.py @@ -17,7 +17,7 @@ from tenacity import ( from metagpt.config import CONFIG, LLMProviderEnum from metagpt.logs import log_llm_stream, logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.provider.llm_provider_registry import register_provider from metagpt.provider.openai_api import log_and_reraise from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI @@ -31,7 +31,7 @@ class ZhiPuEvent(Enum): @register_provider(LLMProviderEnum.ZHIPUAI) -class ZhiPuAIGPTAPI(BaseGPTAPI): +class ZhiPuAIGPTAPI(BaseLLM): """ Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo` From now, there is only one model named `chatglm_turbo` @@ -131,7 +131,7 @@ class ZhiPuAIGPTAPI(BaseGPTAPI): retry=retry_if_exception_type(ConnectionError), retry_error_callback=log_and_reraise, ) - async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str: + async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str: """response in async with stream or non-stream mode""" if stream: return await self._achat_completion_stream(messages) diff --git a/metagpt/roles/role.py b/metagpt/roles/role.py index 3e5f268f8..d6e874ffe 100644 --- a/metagpt/roles/role.py +++ b/metagpt/roles/role.py @@ -36,7 +36,7 @@ from metagpt.const import SERDESER_PATH from metagpt.llm import LLM, HumanProvider from metagpt.logs import logger from metagpt.memory import Memory -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Message, MessageQueue from metagpt.utils.common import ( any_to_name, @@ -141,7 +141,7 @@ class Role(BaseModel): desc: str = "" is_human: bool = False - _llm: BaseGPTAPI = Field(default_factory=LLM) # Each role has its own LLM, use different system message + _llm: BaseLLM = Field(default_factory=LLM) # Each role has its own LLM, use different system message _role_id: str = "" _states: list[str] = [] _actions: list[Action] = [] diff --git a/metagpt/roles/sk_agent.py b/metagpt/roles/sk_agent.py index 6063205bd..d982ebb68 100644 --- a/metagpt/roles/sk_agent.py +++ b/metagpt/roles/sk_agent.py @@ -19,7 +19,7 @@ from metagpt.actions import UserRequirement from metagpt.actions.execute_task import ExecuteTask from metagpt.llm import LLM from metagpt.logs import logger -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.roles import Role from metagpt.schema import Message from metagpt.utils.make_sk_kernel import make_sk_kernel @@ -44,7 +44,7 @@ class SkAgent(Role): plan: Any = None planner_cls: Any = None planner: Any = None - llm: BaseGPTAPI = Field(default_factory=LLM) + llm: BaseLLM = Field(default_factory=LLM) kernel: Kernel = Field(default_factory=Kernel) import_semantic_skill_from_directory: Type[Kernel.import_semantic_skill_from_directory] = None import_skill: Type[Kernel.import_skill] = None diff --git a/metagpt/tools/openai_text_to_image.py b/metagpt/tools/openai_text_to_image.py index fcfa86c7d..bd6078245 100644 --- a/metagpt/tools/openai_text_to_image.py +++ b/metagpt/tools/openai_text_to_image.py @@ -21,10 +21,6 @@ class OpenAIText2Image: """ self._llm = LLM() - def __del__(self): - if self._llm: - self._llm.close() - async def text_2_image(self, text, size_type="1024x1024"): """Text to image diff --git a/metagpt/tools/ut_writer.py b/metagpt/tools/ut_writer.py index 64423dfb1..d6d190ad7 100644 --- a/metagpt/tools/ut_writer.py +++ b/metagpt/tools/ut_writer.py @@ -4,7 +4,7 @@ import json from pathlib import Path -from metagpt.provider.openai_api import OpenAIGPTAPI as GPTAPI +from metagpt.provider.openai_api import OpenAILLM as GPTAPI ICL_SAMPLE = """Interface definition: ```text @@ -278,11 +278,11 @@ class UTGenerator: question += self.build_api_doc(node, path, method) self.ask_gpt_and_save(question, tag, summary) - def gpt_msgs_to_code(self, messages: list) -> str: + async def gpt_msgs_to_code(self, messages: list) -> str: """Choose based on different calling methods""" result = "" if self.chatgpt_method == "API": - result = GPTAPI().ask_code(msgs=messages) + result = await GPTAPI().aask_code(msgs=messages) return result diff --git a/tests/metagpt/actions/test_design_api.py b/tests/metagpt/actions/test_design_api.py index fe98b9120..8d4720570 100644 --- a/tests/metagpt/actions/test_design_api.py +++ b/tests/metagpt/actions/test_design_api.py @@ -22,9 +22,9 @@ async def test_design_api(): for prd in inputs: await FileRepository.save_file("new_prd.txt", content=prd, relative_path=PRDS_FILE_REPO) - design_api = WriteDesign("design_api") + design_api = WriteDesign() - result = await design_api.run([Message(content=prd, instruct_content=None)]) + result = await design_api.run(Message(content=prd, instruct_content=None)) logger.info(result) assert result diff --git a/tests/metagpt/actions/test_project_management.py b/tests/metagpt/actions/test_project_management.py index 13e6d2247..88263ff29 100644 --- a/tests/metagpt/actions/test_project_management.py +++ b/tests/metagpt/actions/test_project_management.py @@ -6,10 +6,26 @@ @File : test_project_management.py """ +import pytest -class TestCreateProjectPlan: - pass +from metagpt.actions.project_management import WriteTasks +from metagpt.config import CONFIG +from metagpt.const import PRDS_FILE_REPO, SYSTEM_DESIGN_FILE_REPO +from metagpt.logs import logger +from metagpt.schema import Message +from metagpt.utils.file_repository import FileRepository +from tests.metagpt.actions.mock_json import DESIGN, PRD -class TestAssignTasks: - pass +@pytest.mark.asyncio +async def test_design_api(): + await FileRepository.save_file("1.txt", content=str(PRD), relative_path=PRDS_FILE_REPO) + await FileRepository.save_file("1.txt", content=str(DESIGN), relative_path=SYSTEM_DESIGN_FILE_REPO) + logger.info(CONFIG.git_repo) + + action = WriteTasks() + + result = await action.run(Message(content="", instruct_content=None)) + logger.info(result) + + assert result diff --git a/tests/metagpt/actions/test_write_code.py b/tests/metagpt/actions/test_write_code.py index ba7cb6f2d..40a3b44ed 100644 --- a/tests/metagpt/actions/test_write_code.py +++ b/tests/metagpt/actions/test_write_code.py @@ -10,7 +10,7 @@ import pytest from metagpt.actions.write_code import WriteCode from metagpt.logs import logger -from metagpt.provider.openai_api import OpenAIGPTAPI as LLM +from metagpt.provider.openai_api import OpenAILLM as LLM from metagpt.schema import CodingContext, Document from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPLE diff --git a/tests/metagpt/provider/test_base_gpt_api.py b/tests/metagpt/provider/test_base_gpt_api.py index aaa7b64ff..be2c0ea7a 100644 --- a/tests/metagpt/provider/test_base_gpt_api.py +++ b/tests/metagpt/provider/test_base_gpt_api.py @@ -8,7 +8,7 @@ import pytest -from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.provider.base_llm import BaseLLM from metagpt.schema import Message default_chat_resp = { @@ -27,14 +27,14 @@ prompt_msg = "who are you" resp_content = default_chat_resp["choices"][0]["message"]["content"] -class MockBaseGPTAPI(BaseGPTAPI): +class MockBaseGPTAPI(BaseLLM): def completion(self, messages: list[dict], timeout=3): return default_chat_resp async def acompletion(self, messages: list[dict], timeout=3): return default_chat_resp - async def acompletion_text(self, messages: list[dict], stream=False, generator: bool = False, timeout=3) -> str: + async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str: return resp_content async def close(self): @@ -47,11 +47,6 @@ def test_base_gpt_api(): assert "user" in str(message) base_gpt_api = MockBaseGPTAPI() - msg_prompt = base_gpt_api.messages_to_prompt([message]) - assert msg_prompt == "user: hello" - - msg_dict = base_gpt_api.messages_to_dict([message]) - assert msg_dict == [{"role": "user", "content": "hello"}] openai_funccall_resp = { "choices": [ @@ -87,14 +82,14 @@ def test_base_gpt_api(): choice_text = base_gpt_api.get_choice_text(openai_funccall_resp) assert choice_text == openai_funccall_resp["choices"][0]["message"]["content"] - resp = base_gpt_api.ask(prompt_msg) - assert resp == resp_content + # resp = base_gpt_api.ask(prompt_msg) + # assert resp == resp_content - resp = base_gpt_api.ask_batch([prompt_msg]) - assert resp == resp_content + # resp = base_gpt_api.ask_batch([prompt_msg]) + # assert resp == resp_content - resp = base_gpt_api.ask_code([prompt_msg]) - assert resp == resp_content + # resp = base_gpt_api.ask_code([prompt_msg]) + # assert resp == resp_content @pytest.mark.asyncio diff --git a/tests/metagpt/provider/test_fireworks_api.py b/tests/metagpt/provider/test_fireworks_api.py index caf8b9f45..00b3c716a 100644 --- a/tests/metagpt/provider/test_fireworks_api.py +++ b/tests/metagpt/provider/test_fireworks_api.py @@ -13,7 +13,7 @@ from openai.types.completion_usage import CompletionUsage from metagpt.provider.fireworks_api import ( MODEL_GRADE_TOKEN_COSTS, FireworksCostManager, - FireWorksGPTAPI, + FireworksLLM, ) resp_content = "I'm fireworks" @@ -55,17 +55,6 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: return default_resp.choices[0].message.content -def test_fireworks_completion(mocker): - mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.completion", mock_llm_completion) - fireworks_gpt = FireWorksGPTAPI() - - resp = fireworks_gpt.completion(messages) - assert resp.choices[0].message.content == resp_content - - resp = fireworks_gpt.ask(prompt_msg) - assert resp == resp_content - - @pytest.mark.asyncio async def test_fireworks_acompletion(mocker): mocker.patch("metagpt.provider.fireworks_api.FireWorksGPTAPI.acompletion", mock_llm_acompletion) @@ -73,7 +62,7 @@ async def test_fireworks_acompletion(mocker): mocker.patch( "metagpt.provider.fireworks_api.FireWorksGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream ) - fireworks_gpt = FireWorksGPTAPI() + fireworks_gpt = FireworksLLM() resp = await fireworks_gpt.acompletion(messages, stream=False) assert resp.choices[0].message.content in resp_content diff --git a/tests/metagpt/provider/test_google_gemini_api.py b/tests/metagpt/provider/test_google_gemini_api.py index aec7b8520..60f50c9ad 100644 --- a/tests/metagpt/provider/test_google_gemini_api.py +++ b/tests/metagpt/provider/test_google_gemini_api.py @@ -35,16 +35,6 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: return resp_content -def test_gemini_completion(mocker): - mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.completion", mock_llm_completion) - gemini_gpt = GeminiGPTAPI() - resp = gemini_gpt.completion(messages) - assert resp.text == resp_content - - resp = gemini_gpt.ask(prompt_msg) - assert resp == resp_content - - @pytest.mark.asyncio async def test_gemini_acompletion(mocker): mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_acompletion) diff --git a/tests/metagpt/provider/test_human_provider.py b/tests/metagpt/provider/test_human_provider.py index caab9f15f..8ba532781 100644 --- a/tests/metagpt/provider/test_human_provider.py +++ b/tests/metagpt/provider/test_human_provider.py @@ -17,15 +17,6 @@ async def mock_llm_aask(msg: str, timeout: int = 3) -> str: return mock_llm_ask(msg) -def test_human_provider(mocker): - mocker.patch("metagpt.provider.human_provider.HumanProvider.ask", mock_llm_ask) - human_provider = HumanProvider() - - assert resp_content == human_provider.ask(None) - - assert not human_provider.completion(messages=[]) - - @pytest.mark.asyncio async def test_async_human_provider(mocker): mocker.patch("metagpt.provider.human_provider.HumanProvider.aask", mock_llm_aask) diff --git a/tests/metagpt/provider/test_ollama_api.py b/tests/metagpt/provider/test_ollama_api.py index d552d9f9e..d19e23e17 100644 --- a/tests/metagpt/provider/test_ollama_api.py +++ b/tests/metagpt/provider/test_ollama_api.py @@ -5,7 +5,7 @@ import pytest from metagpt.config import CONFIG -from metagpt.provider.ollama_api import OllamaGPTAPI +from metagpt.provider.ollama_api import OllamaLLM prompt_msg = "who are you" messages = [{"role": "user", "content": prompt_msg}] @@ -28,22 +28,12 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: return resp_content -def test_gemini_completion(mocker): - mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.completion", mock_llm_completion) - ollama_gpt = OllamaGPTAPI() - resp = ollama_gpt.completion(messages) - assert resp["message"]["content"] == default_resp["message"]["content"] - - resp = ollama_gpt.ask(prompt_msg) - assert resp == resp_content - - @pytest.mark.asyncio async def test_gemini_acompletion(mocker): mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.acompletion", mock_llm_acompletion) mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion", mock_llm_acompletion) mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream) - ollama_gpt = OllamaGPTAPI() + ollama_gpt = OllamaLLM() resp = await ollama_gpt.acompletion(messages) assert resp["message"]["content"] == default_resp["message"]["content"] diff --git a/tests/metagpt/provider/test_openai.py b/tests/metagpt/provider/test_openai.py index 1f25951b1..329edadff 100644 --- a/tests/metagpt/provider/test_openai.py +++ b/tests/metagpt/provider/test_openai.py @@ -2,13 +2,13 @@ from unittest.mock import Mock import pytest -from metagpt.provider.openai_api import OpenAIGPTAPI +from metagpt.provider.openai_api import OpenAILLM from metagpt.schema import UserMessage @pytest.mark.asyncio async def test_aask_code(): - llm = OpenAIGPTAPI() + llm = OpenAILLM() msg = [{"role": "user", "content": "Write a python hello world code."}] rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} assert "language" in rsp @@ -18,7 +18,7 @@ async def test_aask_code(): @pytest.mark.asyncio async def test_aask_code_str(): - llm = OpenAIGPTAPI() + llm = OpenAILLM() msg = "Write a python hello world code." rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} assert "language" in rsp @@ -28,7 +28,7 @@ async def test_aask_code_str(): @pytest.mark.asyncio async def test_aask_code_Message(): - llm = OpenAIGPTAPI() + llm = OpenAILLM() msg = UserMessage("Write a python hello world code.") rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} assert "language" in rsp @@ -36,52 +36,6 @@ async def test_aask_code_Message(): assert len(rsp["code"]) > 0 -def test_ask_code(): - llm = OpenAIGPTAPI() - msg = [{"role": "user", "content": "Write a python hello world code."}] - rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} - assert "language" in rsp - assert "code" in rsp - assert len(rsp["code"]) > 0 - - -def test_ask_code_str(): - llm = OpenAIGPTAPI() - msg = "Write a python hello world code." - rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} - assert "language" in rsp - assert "code" in rsp - assert len(rsp["code"]) > 0 - - -def test_ask_code_Message(): - llm = OpenAIGPTAPI() - msg = UserMessage("Write a python hello world code.") - rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"} - assert "language" in rsp - assert "code" in rsp - assert len(rsp["code"]) > 0 - - -def test_ask_code_list_Message(): - llm = OpenAIGPTAPI() - msg = [UserMessage("a=[1,2,5,10,-10]"), UserMessage("写出求a中最大值的代码python")] - rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'} - assert "language" in rsp - assert "code" in rsp - assert len(rsp["code"]) > 0 - - -def test_ask_code_list_str(): - llm = OpenAIGPTAPI() - msg = ["a=[1,2,5,10,-10]", "写出求a中最大值的代码python"] - rsp = llm.ask_code(msg) # -> {'language': 'python', 'code': 'max_value = max(a)\nmax_value'} - print(rsp) - assert "language" in rsp - assert "code" in rsp - assert len(rsp["code"]) > 0 - - class TestOpenAI: @pytest.fixture def config(self): @@ -130,7 +84,7 @@ class TestOpenAI: ) def test_make_client_kwargs_without_proxy(self, config): - instance = OpenAIGPTAPI() + instance = OpenAILLM() instance.config = config kwargs, async_kwargs = instance._make_client_kwargs() assert kwargs == {"api_key": "test_key", "base_url": "test_url"} @@ -139,7 +93,7 @@ class TestOpenAI: assert "http_client" not in async_kwargs def test_make_client_kwargs_without_proxy_azure(self, config_azure): - instance = OpenAIGPTAPI() + instance = OpenAILLM() instance.config = config_azure kwargs, async_kwargs = instance._make_client_kwargs() assert kwargs == {"api_key": "test_key", "base_url": "test_url"} @@ -148,14 +102,14 @@ class TestOpenAI: assert "http_client" not in async_kwargs def test_make_client_kwargs_with_proxy(self, config_proxy): - instance = OpenAIGPTAPI() + instance = OpenAILLM() instance.config = config_proxy kwargs, async_kwargs = instance._make_client_kwargs() assert "http_client" in kwargs assert "http_client" in async_kwargs def test_make_client_kwargs_with_proxy_azure(self, config_azure_proxy): - instance = OpenAIGPTAPI() + instance = OpenAILLM() instance.config = config_azure_proxy kwargs, async_kwargs = instance._make_client_kwargs() assert "http_client" in kwargs diff --git a/tests/metagpt/provider/test_spark_api.py b/tests/metagpt/provider/test_spark_api.py index 61ae8cbec..6cc87741e 100644 --- a/tests/metagpt/provider/test_spark_api.py +++ b/tests/metagpt/provider/test_spark_api.py @@ -4,7 +4,7 @@ import pytest -from metagpt.provider.spark_api import SparkGPTAPI +from metagpt.provider.spark_api import SparkLLM prompt_msg = "who are you" resp_content = "I'm Spark" @@ -18,24 +18,13 @@ async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, return resp_content -def test_spark_completion(mocker): - mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.completion", mock_llm_completion) - spark_gpt = SparkGPTAPI() - - resp = spark_gpt.completion([]) - assert resp == resp_content - - resp = spark_gpt.ask(prompt_msg) - assert resp == resp_content - - @pytest.mark.asyncio async def test_spark_acompletion(mocker): mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion", mock_llm_acompletion) mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion_text", mock_llm_acompletion) - spark_gpt = SparkGPTAPI() + spark_gpt = SparkLLM() - resp = await spark_gpt.acompletion([], stream=False) + resp = await spark_gpt.acompletion([]) assert resp == resp_content resp = await spark_gpt.aask(prompt_msg, stream=False) diff --git a/tests/metagpt/provider/test_zhipuai_api.py b/tests/metagpt/provider/test_zhipuai_api.py index ec02e1b47..d9cd23281 100644 --- a/tests/metagpt/provider/test_zhipuai_api.py +++ b/tests/metagpt/provider/test_zhipuai_api.py @@ -28,18 +28,6 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str: return resp_content -def test_zhipuai_completion(mocker): - mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.completion", mock_llm_completion) - zhipu_gpt = ZhiPuAIGPTAPI() - - resp = zhipu_gpt.completion(messages) - assert resp["code"] == 200 - assert resp["data"]["choices"][0]["content"] == resp_content - - resp = zhipu_gpt.ask(prompt_msg) - assert resp == resp_content - - @pytest.mark.asyncio async def test_zhipuai_acompletion(mocker): mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion", mock_llm_acompletion) diff --git a/tests/metagpt/test_gpt.py b/tests/metagpt/test_gpt.py index 1884dd54b..2b19f173d 100644 --- a/tests/metagpt/test_gpt.py +++ b/tests/metagpt/test_gpt.py @@ -14,23 +14,6 @@ from metagpt.logs import logger @pytest.mark.usefixtures("llm_api") class TestGPT: - def test_llm_api_ask(self, llm_api): - answer = llm_api.ask("hello chatgpt") - logger.info(answer) - assert len(answer) > 0 - - def test_gptapi_ask_batch(self, llm_api): - answer = llm_api.ask_batch(["请扮演一个Google Python专家工程师,如果理解,回复明白", "写一个hello world"], timeout=60) - assert len(answer) > 0 - - def test_llm_api_ask_code(self, llm_api): - try: - answer = llm_api.ask_code(["请扮演一个Google Python专家工程师,如果理解,回复明白", "写一个hello world"]) - logger.info(answer) - assert len(answer) > 0 - except openai.BadRequestError: - assert CONFIG.OPENAI_API_TYPE == "azure" - @pytest.mark.asyncio async def test_llm_api_aask(self, llm_api): answer = await llm_api.aask("hello chatgpt", stream=False) diff --git a/tests/metagpt/test_llm.py b/tests/metagpt/test_llm.py index 31e6c2b24..247f043e2 100644 --- a/tests/metagpt/test_llm.py +++ b/tests/metagpt/test_llm.py @@ -9,7 +9,7 @@ import pytest -from metagpt.provider.openai_api import OpenAIGPTAPI as LLM +from metagpt.provider.openai_api import OpenAILLM as LLM @pytest.fixture() @@ -23,18 +23,11 @@ async def test_llm_aask(llm): assert len(rsp) > 0 -@pytest.mark.asyncio -async def test_llm_aask_batch(llm): - assert len(await llm.aask_batch(["hi", "write python hello world."])) > 0 - - @pytest.mark.asyncio async def test_llm_acompletion(llm): hello_msg = [{"role": "user", "content": "hello"}] rsp = await llm.acompletion(hello_msg) assert len(rsp.choices[0].message.content) > 0 - assert len(await llm.acompletion_batch([hello_msg])) > 0 - assert len(await llm.acompletion_batch_text([hello_msg])) > 0 if __name__ == "__main__": diff --git a/tests/metagpt/utils/test_custom_aio_session.py b/tests/metagpt/utils/test_custom_aio_session.py deleted file mode 100644 index e2876e4b8..000000000 --- a/tests/metagpt/utils/test_custom_aio_session.py +++ /dev/null @@ -1,21 +0,0 @@ -#!/usr/bin/env python -# -*- coding: utf-8 -*- -""" -@Time : 2023/5/7 17:23 -@Author : alexanderwu -@File : test_custom_aio_session.py -""" -from metagpt.logs import logger -from metagpt.provider.openai_api import OpenAIGPTAPI - - -async def try_hello(api): - batch = [[{"role": "user", "content": "hello"}]] - results = await api.acompletion_batch_text(batch) - return results - - -async def aask_batch(api: OpenAIGPTAPI): - results = await api.aask_batch(["hi", "write python hello world."]) - logger.info(results) - return results