mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
refine code
This commit is contained in:
parent
ba8bf01870
commit
0435b1321f
53 changed files with 118 additions and 289 deletions
|
|
@ -14,7 +14,7 @@ from pydantic import BaseModel, Field
|
|||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import (
|
||||
CodeSummarizeContext,
|
||||
CodingContext,
|
||||
|
|
@ -27,7 +27,7 @@ action_subclass_registry = {}
|
|||
|
||||
class Action(BaseModel):
|
||||
name: str = ""
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True)
|
||||
llm: BaseLLM = Field(default_factory=LLM, exclude=True)
|
||||
context: Union[dict, CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext, str, None] = ""
|
||||
prefix = "" # aask*时会加上prefix,作为system_message
|
||||
desc = "" # for skill manager
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from pydantic import BaseModel, create_model, root_validator, validator
|
|||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.llm import BaseGPTAPI
|
||||
from metagpt.llm import BaseLLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess
|
||||
from metagpt.utils.common import OutputParser, general_after_log
|
||||
|
|
@ -60,7 +60,7 @@ class ActionNode:
|
|||
|
||||
# Action Context
|
||||
context: str # all the context, including all necessary info
|
||||
llm: BaseGPTAPI # LLM with aask interface
|
||||
llm: BaseLLM # LLM with aask interface
|
||||
children: dict[str, "ActionNode"]
|
||||
|
||||
# Action Input
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from pydantic import Field
|
|||
from metagpt.actions.write_code import WriteCode
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.highlight import highlight
|
||||
|
|
@ -33,7 +33,7 @@ def run(*args) -> pd.DataFrame:
|
|||
class CloneFunction(WriteCode):
|
||||
name: str = "CloneFunction"
|
||||
context: list[Message] = []
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
def _save(self, code_path, code):
|
||||
if isinstance(code_path, str):
|
||||
|
|
|
|||
|
|
@ -15,7 +15,7 @@ from pydantic import Field
|
|||
from metagpt.actions.action import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO
|
||||
from metagpt.llm import LLM, BaseGPTAPI
|
||||
from metagpt.llm import LLM, BaseLLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import RunCodeContext, RunCodeResult
|
||||
from metagpt.utils.common import CodeParser
|
||||
|
|
@ -52,7 +52,7 @@ Now you should start rewriting the code:
|
|||
class DebugError(Action):
|
||||
name: str = "DebugError"
|
||||
context: RunCodeContext = Field(default_factory=RunCodeContext)
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
async def run(self, *args, **kwargs) -> str:
|
||||
output_doc = await FileRepository.get_file(
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ from metagpt.const import (
|
|||
)
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Document, Documents, Message
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from metagpt.utils.mermaid import mermaid_to_file
|
||||
|
|
@ -44,7 +44,7 @@ NEW_REQ_TEMPLATE = """
|
|||
class WriteDesign(Action):
|
||||
name: str = ""
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
desc: str = (
|
||||
"Based on the PRD, think about the system design, and design the corresponding APIs, "
|
||||
"data structures, library tables, processes, and paths. Please provide your design, feedback "
|
||||
|
|
|
|||
|
|
@ -12,13 +12,13 @@ from pydantic import Field
|
|||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
||||
|
||||
class DesignReview(Action):
|
||||
name: str = "DesignReview"
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
async def run(self, prd, api_design):
|
||||
prompt = (
|
||||
|
|
|
|||
|
|
@ -10,14 +10,14 @@ from pydantic import Field
|
|||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
class ExecuteTask(Action):
|
||||
name: str = "ExecuteTask"
|
||||
context: list[Message] = []
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
async def run(self, *args, **kwargs):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ from metagpt.prompts.invoice_ocr import (
|
|||
EXTRACT_OCR_MAIN_INFO_PROMPT,
|
||||
REPLY_OCR_QUESTION_PROMPT,
|
||||
)
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.utils.common import OutputParser
|
||||
from metagpt.utils.file import File
|
||||
|
||||
|
|
@ -42,7 +42,7 @@ class InvoiceOCR(Action):
|
|||
|
||||
name: str = "InvoiceOCR"
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
@staticmethod
|
||||
async def _check_file_type(file_path: Path) -> str:
|
||||
|
|
@ -132,7 +132,7 @@ class GenerateTable(Action):
|
|||
|
||||
name: str = "GenerateTable"
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
language: str = "ch"
|
||||
|
||||
async def run(self, ocr_results: list, filename: str, *args, **kwargs) -> dict[str, str]:
|
||||
|
|
@ -177,7 +177,7 @@ class ReplyQuestion(Action):
|
|||
|
||||
name: str = "ReplyQuestion"
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
language: str = "ch"
|
||||
|
||||
async def run(self, query: str, ocr_result: list, *args, **kwargs) -> str:
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from metagpt.actions import Action, ActionOutput
|
|||
from metagpt.config import CONFIG
|
||||
from metagpt.const import DOCS_FILE_REPO, REQUIREMENT_FILENAME
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Document
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
|
|
@ -28,7 +28,7 @@ class PrepareDocuments(Action):
|
|||
|
||||
name: str = "PrepareDocuments"
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
def _init_repo(self):
|
||||
"""Initialize the Git environment."""
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ from metagpt.const import (
|
|||
)
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Document, Documents
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
||||
|
|
@ -43,7 +43,7 @@ NEW_REQ_TEMPLATE = """
|
|||
class WriteTasks(Action):
|
||||
name: str = "CreateTasks"
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
async def run(self, with_messages, schema=CONFIG.prompt_schema):
|
||||
system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from metagpt.actions import Action
|
|||
from metagpt.config import CONFIG
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.tools.search_engine import SearchEngine
|
||||
from metagpt.tools.web_browser_engine import WebBrowserEngine, WebBrowserEngineType
|
||||
from metagpt.utils.common import OutputParser
|
||||
|
|
@ -82,7 +82,7 @@ class CollectLinks(Action):
|
|||
|
||||
name: str = "CollectLinks"
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
desc: str = "Collect links from a search engine."
|
||||
search_engine: SearchEngine = Field(default_factory=SearchEngine)
|
||||
rank_func: Union[Callable[[list[str]], None], None] = None
|
||||
|
|
@ -177,7 +177,7 @@ class WebBrowseAndSummarize(Action):
|
|||
|
||||
name: str = "WebBrowseAndSummarize"
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
desc: str = "Explore the web and provide summaries of articles and webpages."
|
||||
browse_func: Union[Callable[[list[str]], None], None] = None
|
||||
web_browser_engine: WebBrowserEngine = Field(
|
||||
|
|
@ -248,7 +248,7 @@ class ConductResearch(Action):
|
|||
|
||||
name: str = "ConductResearch"
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from pydantic import Field
|
|||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.llm import LLM, BaseGPTAPI
|
||||
from metagpt.llm import LLM, BaseLLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import RunCodeContext, RunCodeResult
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
|
|
@ -79,7 +79,7 @@ standard errors:
|
|||
class RunCode(Action):
|
||||
name: str = "RunCode"
|
||||
context: RunCodeContext = Field(default_factory=RunCodeContext)
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
@classmethod
|
||||
@handle_exception
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from metagpt.actions import Action
|
|||
from metagpt.config import CONFIG, Config
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message
|
||||
from metagpt.tools import SearchEngineType
|
||||
from metagpt.tools.search_engine import SearchEngine
|
||||
|
|
@ -109,7 +109,7 @@ You are a member of a professional butler team and will provide helpful suggesti
|
|||
class SearchAndSummarize(Action):
|
||||
name: str = ""
|
||||
content: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
config: None = Field(default_factory=Config)
|
||||
engine: Optional[SearchEngineType] = CONFIG.search_engine
|
||||
search_func: Optional[Any] = None
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
|
|||
from metagpt.actions.action import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
|
||||
from metagpt.llm import LLM, BaseGPTAPI
|
||||
from metagpt.llm import LLM, BaseLLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import CodeSummarizeContext
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
|
@ -95,7 +95,7 @@ flowchart TB
|
|||
class SummarizeCode(Action):
|
||||
name: str = "SummarizeCode"
|
||||
context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext)
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
@retry(stop=stop_after_attempt(2), wait=wait_random_exponential(min=1, max=60))
|
||||
async def summarize_code(self, prompt):
|
||||
|
|
|
|||
|
|
@ -31,7 +31,7 @@ from metagpt.const import (
|
|||
)
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import CodingContext, Document, RunCodeResult
|
||||
from metagpt.utils.common import CodeParser
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
|
@ -90,7 +90,7 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc
|
|||
class WriteCode(Action):
|
||||
name: str = "WriteCode"
|
||||
context: Document = Field(default_factory=Document)
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
async def write_code(self, prompt) -> str:
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from metagpt.actions.action import Action
|
|||
from metagpt.config import CONFIG
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import CodingContext
|
||||
from metagpt.utils.common import CodeParser
|
||||
|
||||
|
|
@ -123,7 +123,7 @@ REWRITE_CODE_TEMPLATE = """
|
|||
class WriteCodeReview(Action):
|
||||
name: str = "WriteCodeReview"
|
||||
context: CodingContext = Field(default_factory=CodingContext)
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, filename):
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ from pydantic import Field
|
|||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.utils.common import OutputParser
|
||||
from metagpt.utils.pycst import merge_docstring
|
||||
|
||||
|
|
@ -163,7 +163,7 @@ class WriteDocstring(Action):
|
|||
|
||||
desc: str = "Write docstring for code."
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ from metagpt.const import (
|
|||
)
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import BugFixContext, Document, Documents, Message
|
||||
from metagpt.utils.common import CodeParser
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
|
@ -67,7 +67,7 @@ NEW_REQ_TEMPLATE = """
|
|||
class WritePRD(Action):
|
||||
name: str = ""
|
||||
content: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
async def run(self, with_messages, schema=CONFIG.prompt_schema, *args, **kwargs) -> ActionOutput | Message:
|
||||
# Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are
|
||||
|
|
|
|||
|
|
@ -12,13 +12,13 @@ from pydantic import Field
|
|||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
||||
|
||||
class WritePRDReview(Action):
|
||||
name: str = ""
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
prd: Optional[str] = None
|
||||
desc: str = "Based on the PRD, conduct a PRD Review, providing clear and detailed feedback"
|
||||
prd_review_prompt_template: str = """
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from pydantic import Field
|
|||
from metagpt.actions import Action
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
||||
REVIEW = ActionNode(
|
||||
key="Review",
|
||||
|
|
@ -38,7 +38,7 @@ class WriteReview(Action):
|
|||
"""Write a review for the given context."""
|
||||
|
||||
name: str = "WriteReview"
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
async def run(self, context):
|
||||
return await WRITE_REVIEW_NODE.fill(context=context, llm=self.llm, schema="json")
|
||||
|
|
|
|||
|
|
@ -13,14 +13,14 @@ from metagpt.actions import Action
|
|||
from metagpt.config import CONFIG
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
||||
|
||||
class WriteTeachingPlanPart(Action):
|
||||
"""Write Teaching Plan Part"""
|
||||
|
||||
context: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
topic: str = ""
|
||||
language: str = "Chinese"
|
||||
rsp: Optional[str] = None
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from metagpt.config import CONFIG
|
|||
from metagpt.const import TEST_CODES_FILE_REPO
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Document, TestingContext
|
||||
from metagpt.utils.common import CodeParser
|
||||
|
||||
|
|
@ -45,7 +45,7 @@ you should correctly import the necessary classes based on these file locations!
|
|||
class WriteTest(Action):
|
||||
name: str = "WriteTest"
|
||||
context: Optional[TestingContext] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
async def write_code(self, prompt):
|
||||
code_rsp = await self._aask(prompt)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ from pydantic import Field
|
|||
from metagpt.actions import Action
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.prompts.tutorial_assistant import CONTENT_PROMPT, DIRECTORY_PROMPT
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.utils.common import OutputParser
|
||||
|
||||
|
||||
|
|
@ -27,7 +27,7 @@ class WriteDirectory(Action):
|
|||
"""
|
||||
|
||||
name: str = "WriteDirectory"
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
language: str = "Chinese"
|
||||
|
||||
async def run(self, topic: str, *args, **kwargs) -> Dict:
|
||||
|
|
@ -54,7 +54,7 @@ class WriteContent(Action):
|
|||
"""
|
||||
|
||||
name: str = "WriteContent"
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
directory: dict = dict()
|
||||
language: str = "Chinese"
|
||||
|
||||
|
|
|
|||
|
|
@ -9,14 +9,14 @@
|
|||
from typing import Optional
|
||||
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.human_provider import HumanProvider
|
||||
from metagpt.provider.llm_provider_registry import LLM_REGISTRY
|
||||
|
||||
_ = HumanProvider() # Avoid pre-commit error
|
||||
|
||||
|
||||
def LLM(provider: Optional[LLMProviderEnum] = None) -> BaseGPTAPI:
|
||||
def LLM(provider: Optional[LLMProviderEnum] = None) -> BaseLLM:
|
||||
"""get the default llm provider"""
|
||||
if provider is None:
|
||||
provider = CONFIG.get_default_llm_provider_enum()
|
||||
|
|
|
|||
|
|
@ -18,7 +18,7 @@ from metagpt.config import CONFIG
|
|||
from metagpt.const import DEFAULT_LANGUAGE, DEFAULT_MAX_TOKENS, DEFAULT_TOKEN_SIZE
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider import MetaGPTAPI
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message, SimpleMessage
|
||||
from metagpt.utils.redis import Redis
|
||||
|
||||
|
|
@ -31,7 +31,7 @@ class BrainMemory(BaseModel):
|
|||
is_dirty: bool = False
|
||||
last_talk: str = None
|
||||
cacheable: bool = True
|
||||
llm: Optional[BaseGPTAPI] = None
|
||||
llm: Optional[BaseLLM] = None
|
||||
|
||||
def add_talk(self, msg: Message):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -6,22 +6,22 @@
|
|||
@File : __init__.py
|
||||
"""
|
||||
|
||||
from metagpt.provider.fireworks_api import FireWorksGPTAPI
|
||||
from metagpt.provider.fireworks_api import FireworksLLM
|
||||
from metagpt.provider.google_gemini_api import GeminiGPTAPI
|
||||
from metagpt.provider.ollama_api import OllamaGPTAPI
|
||||
from metagpt.provider.ollama_api import OllamaLLM
|
||||
from metagpt.provider.open_llm_api import OpenLLMGPTAPI
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
|
||||
from metagpt.provider.azure_openai_api import AzureOpenAIGPTAPI
|
||||
from metagpt.provider.azure_openai_api import AzureOpenAILLM
|
||||
from metagpt.provider.metagpt_api import MetaGPTAPI
|
||||
|
||||
__all__ = [
|
||||
"FireWorksGPTAPI",
|
||||
"FireworksLLM",
|
||||
"GeminiGPTAPI",
|
||||
"OpenLLMGPTAPI",
|
||||
"OpenAIGPTAPI",
|
||||
"OpenAILLM",
|
||||
"ZhiPuAIGPTAPI",
|
||||
"AzureOpenAIGPTAPI",
|
||||
"AzureOpenAILLM",
|
||||
"MetaGPTAPI",
|
||||
"OllamaGPTAPI",
|
||||
"OllamaLLM",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -15,11 +15,11 @@ from openai._base_client import AsyncHttpxClientWrapper
|
|||
|
||||
from metagpt.config import LLMProviderEnum
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.AZURE_OPENAI)
|
||||
class AzureOpenAIGPTAPI(OpenAIGPTAPI):
|
||||
class AzureOpenAILLM(OpenAILLM):
|
||||
"""
|
||||
Check https://platform.openai.com/examples for examples
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -1,21 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/5/5 23:00
|
||||
@Author : alexanderwu
|
||||
@File : base_chatbot.py
|
||||
@Modified By: mashenquan, 2023/11/21. Add `timeout`.
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class BaseChatbot(ABC):
|
||||
"""Abstract GPT class"""
|
||||
|
||||
use_system_prompt: bool = True
|
||||
|
||||
@abstractmethod
|
||||
def ask(self, msg: str, timeout=3) -> str:
|
||||
"""Ask GPT a question and get an answer"""
|
||||
|
|
@ -3,19 +3,18 @@
|
|||
"""
|
||||
@Time : 2023/5/5 23:04
|
||||
@Author : alexanderwu
|
||||
@File : base_gpt_api.py
|
||||
@File : base_llm.py
|
||||
@Desc : mashenquan, 2023/8/22. + try catch
|
||||
"""
|
||||
import json
|
||||
from abc import abstractmethod
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from metagpt.provider.base_chatbot import BaseChatbot
|
||||
|
||||
class BaseLLM(ABC):
|
||||
"""LLM API abstract class, requiring all inheritors to provide a series of standard capabilities"""
|
||||
|
||||
class BaseGPTAPI(BaseChatbot):
|
||||
"""GPT API abstract class, requiring all inheritors to provide a series of standard capabilities"""
|
||||
|
||||
use_system_prompt: bool = True
|
||||
system_prompt = "You are a helpful assistant."
|
||||
|
||||
def _user_msg(self, msg: str) -> dict[str, str]:
|
||||
|
|
@ -33,11 +32,6 @@ class BaseGPTAPI(BaseChatbot):
|
|||
def _default_system_msg(self):
|
||||
return self._system_msg(self.system_prompt)
|
||||
|
||||
def ask(self, msg: str, timeout=3) -> str:
|
||||
message = [self._default_system_msg(), self._user_msg(msg)] if self.use_system_prompt else [self._user_msg(msg)]
|
||||
rsp = self.completion(message, timeout=timeout)
|
||||
return self.get_choice_text(rsp)
|
||||
|
||||
async def aask(
|
||||
self,
|
||||
msg: str,
|
||||
|
|
@ -54,7 +48,6 @@ class BaseGPTAPI(BaseChatbot):
|
|||
message.extend(format_msgs)
|
||||
message.append(self._user_msg(msg))
|
||||
rsp = await self.acompletion_text(message, stream=stream, timeout=timeout)
|
||||
# logger.debug(rsp)
|
||||
return rsp
|
||||
|
||||
def _extract_assistant_rsp(self, context):
|
||||
|
|
@ -75,15 +68,6 @@ class BaseGPTAPI(BaseChatbot):
|
|||
rsp_text = await self.aask_batch(msgs, timeout=timeout)
|
||||
return rsp_text
|
||||
|
||||
def completion(self, messages: list[dict], timeout=3) -> dict:
|
||||
"""All GPTAPIs are required to provide the standard OpenAI completion interface
|
||||
[
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "hello, show me python hello world code"},
|
||||
# {"role": "assistant", "content": ...}, # If there is an answer in the history, also include it
|
||||
]
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
"""Asynchronous version of completion
|
||||
|
|
@ -18,7 +18,7 @@ from tenacity import (
|
|||
from metagpt.config import CONFIG, Config, LLMProviderEnum
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter, log_and_reraise
|
||||
from metagpt.provider.openai_api import OpenAILLM, log_and_reraise
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
|
||||
MODEL_GRADE_TOKEN_COSTS = {
|
||||
|
|
@ -72,13 +72,12 @@ class FireworksCostManager(CostManager):
|
|||
|
||||
|
||||
@register_provider(LLMProviderEnum.FIREWORKS)
|
||||
class FireWorksGPTAPI(OpenAIGPTAPI):
|
||||
class FireworksLLM(OpenAILLM):
|
||||
def __init__(self):
|
||||
self.config: Config = CONFIG
|
||||
self.__init_fireworks()
|
||||
self.auto_max_tokens = False
|
||||
self._cost_manager = FireworksCostManager()
|
||||
RateLimiter.__init__(self, rpm=self.rpm)
|
||||
|
||||
def __init_fireworks(self):
|
||||
self.is_azure = False
|
||||
|
|
|
|||
|
|
@ -47,8 +47,7 @@ MAX_CONNECTION_RETRIES = 2
|
|||
# Has one attribute per thread, 'session'.
|
||||
_thread_context = threading.local()
|
||||
|
||||
LLM_LOG = os.environ.get("LLM_LOG")
|
||||
LLM_LOG = "debug"
|
||||
LLM_LOG = os.environ.get("LLM_LOG", "debug")
|
||||
|
||||
|
||||
class ApiType(Enum):
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from tenacity import (
|
|||
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import log_and_reraise
|
||||
|
||||
|
|
@ -42,7 +42,7 @@ class GeminiGenerativeModel(GenerativeModel):
|
|||
|
||||
|
||||
@register_provider(LLMProviderEnum.GEMINI)
|
||||
class GeminiGPTAPI(BaseGPTAPI):
|
||||
class GeminiGPTAPI(BaseLLM):
|
||||
"""
|
||||
Refs to `https://ai.google.dev/tutorials/python_quickstart`
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -6,10 +6,10 @@ Author: garylin2099
|
|||
from typing import Optional
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
||||
|
||||
class HumanProvider(BaseGPTAPI):
|
||||
class HumanProvider(BaseLLM):
|
||||
"""Humans provide themselves as a 'model', which actually takes in human input as its response.
|
||||
This enables replacing LLM anywhere in the framework with a human, thus introducing human interaction
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -6,11 +6,11 @@
|
|||
@Desc : MetaGPT LLM provider.
|
||||
"""
|
||||
from metagpt.config import LLMProviderEnum
|
||||
from metagpt.provider import OpenAIGPTAPI
|
||||
from metagpt.provider import OpenAILLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.METAGPT)
|
||||
class MetaGPTAPI(OpenAIGPTAPI):
|
||||
class MetaGPTAPI(OpenAILLM):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ from tenacity import (
|
|||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.const import LLM_API_TIMEOUT
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import log_and_reraise
|
||||
|
|
@ -39,7 +39,7 @@ class OllamaCostManager(CostManager):
|
|||
|
||||
|
||||
@register_provider(LLMProviderEnum.OLLAMA)
|
||||
class OllamaGPTAPI(BaseGPTAPI):
|
||||
class OllamaLLM(BaseLLM):
|
||||
"""
|
||||
Refs to `https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-chat-completion`
|
||||
"""
|
||||
|
|
@ -54,12 +54,8 @@ class OllamaGPTAPI(BaseGPTAPI):
|
|||
|
||||
def __init_ollama(self, config: CONFIG):
|
||||
assert config.ollama_api_base
|
||||
|
||||
self.model = config.ollama_api_model
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def _const_kwargs(self, messages: list[dict], stream: bool = False) -> dict:
|
||||
kwargs = {"model": self.model, "messages": messages, "options": {"temperature": 0.3}, "stream": stream}
|
||||
return kwargs
|
||||
|
|
@ -87,18 +83,6 @@ class OllamaGPTAPI(BaseGPTAPI):
|
|||
chunk = chunk.decode(encoding)
|
||||
return json.loads(chunk)
|
||||
|
||||
def completion(self, messages: list[dict]) -> dict:
|
||||
resp, _, _ = self.client.request(
|
||||
method=self.http_method,
|
||||
url=self.suffix_url,
|
||||
params=self._const_kwargs(messages),
|
||||
request_timeout=LLM_API_TIMEOUT,
|
||||
)
|
||||
resp = self._decode_and_load(resp)
|
||||
usage = self.get_usage(resp)
|
||||
self._update_costs(usage)
|
||||
return resp
|
||||
|
||||
async def _achat_completion(self, messages: list[dict]) -> dict:
|
||||
resp, _, _ = await self.client.arequest(
|
||||
method=self.http_method,
|
||||
|
|
@ -111,7 +95,7 @@ class OllamaGPTAPI(BaseGPTAPI):
|
|||
self._update_costs(usage)
|
||||
return resp
|
||||
|
||||
async def acompletion(self, messages: list[dict]) -> dict:
|
||||
async def acompletion(self, messages: list[dict], timeout=3) -> dict:
|
||||
return await self._achat_completion(messages)
|
||||
|
||||
async def _achat_completion_stream(self, messages: list[dict]) -> str:
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from openai.types import CompletionUsage
|
|||
from metagpt.config import CONFIG, Config, LLMProviderEnum
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI, RateLimiter
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
from metagpt.utils.cost_manager import CostManager, Costs
|
||||
from metagpt.utils.token_counter import count_message_tokens, count_string_tokens
|
||||
|
||||
|
|
@ -35,13 +35,12 @@ class OpenLLMCostManager(CostManager):
|
|||
|
||||
|
||||
@register_provider(LLMProviderEnum.OPEN_LLM)
|
||||
class OpenLLMGPTAPI(OpenAIGPTAPI):
|
||||
class OpenLLMGPTAPI(OpenAILLM):
|
||||
def __init__(self):
|
||||
self.config: Config = CONFIG
|
||||
self.__init_openllm()
|
||||
self.auto_max_tokens = False
|
||||
self._cost_manager = OpenLLMCostManager()
|
||||
RateLimiter.__init__(self, rpm=self.rpm)
|
||||
|
||||
def __init_openllm(self):
|
||||
self.is_azure = False
|
||||
|
|
|
|||
|
|
@ -3,15 +3,13 @@
|
|||
@Time : 2023/5/5 23:08
|
||||
@Author : alexanderwu
|
||||
@File : openai.py
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for business isolation;
|
||||
@Modified By: mashenquan, 2023/8/20. Remove global configuration `CONFIG`, enable configuration support for isolation;
|
||||
Change cost control from global to company level.
|
||||
@Modified By: mashenquan, 2023/11/21. Fix bug: ReadTimeout.
|
||||
@Modified By: mashenquan, 2023/12/1. Fix bug: Unclosed connection caused by openai 0.x.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import time
|
||||
from typing import AsyncIterator, Union
|
||||
|
||||
from openai import APIConnectionError, AsyncOpenAI, AsyncStream
|
||||
|
|
@ -28,7 +26,7 @@ from tenacity import (
|
|||
|
||||
from metagpt.config import CONFIG, Config, LLMProviderEnum
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA, GENERAL_TOOL_CHOICE
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.schema import Message
|
||||
|
|
@ -41,31 +39,6 @@ from metagpt.utils.token_counter import (
|
|||
)
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
"""Rate control class, each call goes through wait_if_needed, sleep if rate control is needed"""
|
||||
|
||||
def __init__(self, rpm):
|
||||
self.last_call_time = 0
|
||||
# Here 1.1 is used because even if the calls are made strictly according to time,
|
||||
# they will still be QOS'd; consider switching to simple error retry later
|
||||
self.interval = 1.1 * 60 / rpm
|
||||
self.rpm = rpm
|
||||
|
||||
def split_batches(self, batch):
|
||||
return [batch[i : i + self.rpm] for i in range(0, len(batch), self.rpm)]
|
||||
|
||||
async def wait_if_needed(self, num_requests):
|
||||
current_time = time.time()
|
||||
elapsed_time = current_time - self.last_call_time
|
||||
|
||||
if elapsed_time < self.interval * num_requests:
|
||||
remaining_time = self.interval * num_requests - elapsed_time
|
||||
logger.info(f"sleep {remaining_time}")
|
||||
await asyncio.sleep(remaining_time)
|
||||
|
||||
self.last_call_time = time.time()
|
||||
|
||||
|
||||
def log_and_reraise(retry_state):
|
||||
logger.error(f"Retry attempts exhausted. Last exception: {retry_state.outcome.exception()}")
|
||||
logger.warning(
|
||||
|
|
@ -78,7 +51,7 @@ See FAQ 5.8
|
|||
|
||||
|
||||
@register_provider(LLMProviderEnum.OPENAI)
|
||||
class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
||||
class OpenAILLM(BaseLLM):
|
||||
"""Check https://platform.openai.com/examples for examples"""
|
||||
|
||||
def __init__(self):
|
||||
|
|
@ -86,11 +59,8 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
self._init_openai()
|
||||
self._init_client()
|
||||
self.auto_max_tokens = False
|
||||
RateLimiter.__init__(self, rpm=self.rpm)
|
||||
super().__init__()
|
||||
|
||||
def _init_openai(self):
|
||||
self.rpm = int(self.config.openai_api_rpm)
|
||||
self.model = self.config.OPENAI_API_MODEL # Used in _calc_usage & _cons_kwargs
|
||||
|
||||
def _init_client(self):
|
||||
|
|
@ -211,7 +181,7 @@ class OpenAIGPTAPI(BaseGPTAPI, RateLimiter):
|
|||
Note: Keep kwargs consistent with https://platform.openai.com/docs/api-reference/chat/create
|
||||
|
||||
Examples:
|
||||
>>> llm = OpenAIGPTAPI()
|
||||
>>> llm = OpenAILLM()
|
||||
>>> msg = [{'role': 'user', 'content': "Write a python hello world code."}]
|
||||
>>> rsp = await llm.aask_code(msg)
|
||||
# -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
|
|
|
|||
|
|
@ -1,9 +1,7 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/7/21 11:15
|
||||
@Author : Leo Xiao
|
||||
@File : anthropic_api.py
|
||||
@File : spark_api.py
|
||||
"""
|
||||
import _thread as thread
|
||||
import base64
|
||||
|
|
@ -13,7 +11,6 @@ import hmac
|
|||
import json
|
||||
import ssl
|
||||
from time import mktime
|
||||
from typing import Optional
|
||||
from urllib.parse import urlencode, urlparse
|
||||
from wsgiref.handlers import format_date_time
|
||||
|
||||
|
|
@ -21,32 +18,15 @@ import websocket # 使用websocket_client
|
|||
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.SPARK)
|
||||
class SparkGPTAPI(BaseGPTAPI):
|
||||
class SparkLLM(BaseLLM):
|
||||
def __init__(self):
|
||||
logger.warning("当前方法无法支持异步运行。当你使用acompletion时,并不能并行访问。")
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def ask(self, msg: str) -> str:
|
||||
message = [self._default_system_msg(), self._user_msg(msg)]
|
||||
rsp = self.completion(message)
|
||||
return rsp
|
||||
|
||||
async def aask(self, msg: str, system_msgs: Optional[list[str]] = None, stream: bool = True) -> str:
|
||||
if system_msgs:
|
||||
message = self._system_msgs(system_msgs) + [self._user_msg(msg)]
|
||||
else:
|
||||
message = [self._default_system_msg(), self._user_msg(msg)]
|
||||
rsp = await self.acompletion(message)
|
||||
logger.debug(message)
|
||||
return rsp
|
||||
|
||||
def get_choice_text(self, rsp: dict) -> str:
|
||||
return rsp["payload"]["choices"]["text"][-1]["content"]
|
||||
|
||||
|
|
@ -56,15 +36,11 @@ class SparkGPTAPI(BaseGPTAPI):
|
|||
w = GetMessageFromWeb(messages)
|
||||
return w.run()
|
||||
|
||||
async def acompletion(self, messages: list[dict]):
|
||||
async def acompletion(self, messages: list[dict], timeout=3):
|
||||
# 不支持异步
|
||||
w = GetMessageFromWeb(messages)
|
||||
return w.run()
|
||||
|
||||
def completion(self, messages: list[dict]):
|
||||
w = GetMessageFromWeb(messages)
|
||||
return w.run()
|
||||
|
||||
|
||||
class GetMessageFromWeb:
|
||||
class WsParam:
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from tenacity import (
|
|||
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.logs import log_llm_stream, logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.provider.llm_provider_registry import register_provider
|
||||
from metagpt.provider.openai_api import log_and_reraise
|
||||
from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
|
||||
|
|
@ -31,7 +31,7 @@ class ZhiPuEvent(Enum):
|
|||
|
||||
|
||||
@register_provider(LLMProviderEnum.ZHIPUAI)
|
||||
class ZhiPuAIGPTAPI(BaseGPTAPI):
|
||||
class ZhiPuAIGPTAPI(BaseLLM):
|
||||
"""
|
||||
Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo`
|
||||
From now, there is only one model named `chatglm_turbo`
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ from metagpt.const import SERDESER_PATH
|
|||
from metagpt.llm import LLM, HumanProvider
|
||||
from metagpt.logs import logger
|
||||
from metagpt.memory import Memory
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message, MessageQueue
|
||||
from metagpt.utils.common import (
|
||||
any_to_name,
|
||||
|
|
@ -141,7 +141,7 @@ class Role(BaseModel):
|
|||
desc: str = ""
|
||||
is_human: bool = False
|
||||
|
||||
_llm: BaseGPTAPI = Field(default_factory=LLM) # Each role has its own LLM, use different system message
|
||||
_llm: BaseLLM = Field(default_factory=LLM) # Each role has its own LLM, use different system message
|
||||
_role_id: str = ""
|
||||
_states: list[str] = []
|
||||
_actions: list[Action] = []
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from metagpt.actions import UserRequirement
|
|||
from metagpt.actions.execute_task import ExecuteTask
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.make_sk_kernel import make_sk_kernel
|
||||
|
|
@ -44,7 +44,7 @@ class SkAgent(Role):
|
|||
plan: Any = None
|
||||
planner_cls: Any = None
|
||||
planner: Any = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
kernel: Kernel = Field(default_factory=Kernel)
|
||||
import_semantic_skill_from_directory: Type[Kernel.import_semantic_skill_from_directory] = None
|
||||
import_skill: Type[Kernel.import_skill] = None
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI as GPTAPI
|
||||
from metagpt.provider.openai_api import OpenAILLM as GPTAPI
|
||||
|
||||
ICL_SAMPLE = """Interface definition:
|
||||
```text
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import pytest
|
|||
|
||||
from metagpt.actions.write_code import WriteCode
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI as LLM
|
||||
from metagpt.provider.openai_api import OpenAILLM as LLM
|
||||
from metagpt.schema import CodingContext, Document
|
||||
from tests.metagpt.actions.mock_markdown import TASKS_2, WRITE_CODE_PROMPT_SAMPLE
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message
|
||||
|
||||
default_chat_resp = {
|
||||
|
|
@ -27,7 +27,7 @@ prompt_msg = "who are you"
|
|||
resp_content = default_chat_resp["choices"][0]["message"]["content"]
|
||||
|
||||
|
||||
class MockBaseGPTAPI(BaseGPTAPI):
|
||||
class MockBaseGPTAPI(BaseLLM):
|
||||
def completion(self, messages: list[dict], timeout=3):
|
||||
return default_chat_resp
|
||||
|
||||
|
|
|
|||
|
|
@ -13,7 +13,7 @@ from openai.types.completion_usage import CompletionUsage
|
|||
from metagpt.provider.fireworks_api import (
|
||||
MODEL_GRADE_TOKEN_COSTS,
|
||||
FireworksCostManager,
|
||||
FireWorksGPTAPI,
|
||||
FireworksLLM,
|
||||
)
|
||||
|
||||
resp_content = "I'm fireworks"
|
||||
|
|
@ -62,7 +62,7 @@ async def test_fireworks_acompletion(mocker):
|
|||
mocker.patch(
|
||||
"metagpt.provider.fireworks_api.FireWorksGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream
|
||||
)
|
||||
fireworks_gpt = FireWorksGPTAPI()
|
||||
fireworks_gpt = FireworksLLM()
|
||||
|
||||
resp = await fireworks_gpt.acompletion(messages, stream=False)
|
||||
assert resp.choices[0].message.content in resp_content
|
||||
|
|
|
|||
|
|
@ -35,16 +35,6 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
|
|||
return resp_content
|
||||
|
||||
|
||||
def test_gemini_completion(mocker):
|
||||
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.completion", mock_llm_completion)
|
||||
gemini_gpt = GeminiGPTAPI()
|
||||
resp = gemini_gpt.completion(messages)
|
||||
assert resp.text == resp_content
|
||||
|
||||
resp = gemini_gpt.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_acompletion)
|
||||
|
|
|
|||
|
|
@ -17,15 +17,6 @@ async def mock_llm_aask(msg: str, timeout: int = 3) -> str:
|
|||
return mock_llm_ask(msg)
|
||||
|
||||
|
||||
def test_human_provider(mocker):
|
||||
mocker.patch("metagpt.provider.human_provider.HumanProvider.ask", mock_llm_ask)
|
||||
human_provider = HumanProvider()
|
||||
|
||||
assert resp_content == human_provider.ask(None)
|
||||
|
||||
assert not human_provider.completion(messages=[])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_human_provider(mocker):
|
||||
mocker.patch("metagpt.provider.human_provider.HumanProvider.aask", mock_llm_aask)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.ollama_api import OllamaGPTAPI
|
||||
from metagpt.provider.ollama_api import OllamaLLM
|
||||
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
|
@ -28,22 +28,12 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
|
|||
return resp_content
|
||||
|
||||
|
||||
def test_gemini_completion(mocker):
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.completion", mock_llm_completion)
|
||||
ollama_gpt = OllamaGPTAPI()
|
||||
resp = ollama_gpt.completion(messages)
|
||||
assert resp["message"]["content"] == default_resp["message"]["content"]
|
||||
|
||||
resp = ollama_gpt.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.acompletion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream)
|
||||
ollama_gpt = OllamaGPTAPI()
|
||||
ollama_gpt = OllamaLLM()
|
||||
|
||||
resp = await ollama_gpt.acompletion(messages)
|
||||
assert resp["message"]["content"] == default_resp["message"]["content"]
|
||||
|
|
|
|||
|
|
@ -2,13 +2,13 @@ from unittest.mock import Mock
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
from metagpt.schema import UserMessage
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code():
|
||||
llm = OpenAIGPTAPI()
|
||||
llm = OpenAILLM()
|
||||
msg = [{"role": "user", "content": "Write a python hello world code."}]
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
|
|
@ -18,7 +18,7 @@ async def test_aask_code():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code_str():
|
||||
llm = OpenAIGPTAPI()
|
||||
llm = OpenAILLM()
|
||||
msg = "Write a python hello world code."
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
|
|
@ -28,7 +28,7 @@ async def test_aask_code_str():
|
|||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code_Message():
|
||||
llm = OpenAIGPTAPI()
|
||||
llm = OpenAILLM()
|
||||
msg = UserMessage("Write a python hello world code.")
|
||||
rsp = await llm.aask_code(msg) # -> {'language': 'python', 'code': "print('Hello, World!')"}
|
||||
assert "language" in rsp
|
||||
|
|
@ -84,7 +84,7 @@ class TestOpenAI:
|
|||
)
|
||||
|
||||
def test_make_client_kwargs_without_proxy(self, config):
|
||||
instance = OpenAIGPTAPI()
|
||||
instance = OpenAILLM()
|
||||
instance.config = config
|
||||
kwargs, async_kwargs = instance._make_client_kwargs()
|
||||
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
|
||||
|
|
@ -93,7 +93,7 @@ class TestOpenAI:
|
|||
assert "http_client" not in async_kwargs
|
||||
|
||||
def test_make_client_kwargs_without_proxy_azure(self, config_azure):
|
||||
instance = OpenAIGPTAPI()
|
||||
instance = OpenAILLM()
|
||||
instance.config = config_azure
|
||||
kwargs, async_kwargs = instance._make_client_kwargs()
|
||||
assert kwargs == {"api_key": "test_key", "base_url": "test_url"}
|
||||
|
|
@ -102,14 +102,14 @@ class TestOpenAI:
|
|||
assert "http_client" not in async_kwargs
|
||||
|
||||
def test_make_client_kwargs_with_proxy(self, config_proxy):
|
||||
instance = OpenAIGPTAPI()
|
||||
instance = OpenAILLM()
|
||||
instance.config = config_proxy
|
||||
kwargs, async_kwargs = instance._make_client_kwargs()
|
||||
assert "http_client" in kwargs
|
||||
assert "http_client" in async_kwargs
|
||||
|
||||
def test_make_client_kwargs_with_proxy_azure(self, config_azure_proxy):
|
||||
instance = OpenAIGPTAPI()
|
||||
instance = OpenAILLM()
|
||||
instance.config = config_azure_proxy
|
||||
kwargs, async_kwargs = instance._make_client_kwargs()
|
||||
assert "http_client" in kwargs
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.spark_api import SparkGPTAPI
|
||||
from metagpt.provider.spark_api import SparkLLM
|
||||
|
||||
prompt_msg = "who are you"
|
||||
resp_content = "I'm Spark"
|
||||
|
|
@ -18,24 +18,13 @@ async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False,
|
|||
return resp_content
|
||||
|
||||
|
||||
def test_spark_completion(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.completion", mock_llm_completion)
|
||||
spark_gpt = SparkGPTAPI()
|
||||
|
||||
resp = spark_gpt.completion([])
|
||||
assert resp == resp_content
|
||||
|
||||
resp = spark_gpt.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spark_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion_text", mock_llm_acompletion)
|
||||
spark_gpt = SparkGPTAPI()
|
||||
spark_gpt = SparkLLM()
|
||||
|
||||
resp = await spark_gpt.acompletion([], stream=False)
|
||||
resp = await spark_gpt.acompletion([])
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await spark_gpt.aask(prompt_msg, stream=False)
|
||||
|
|
|
|||
|
|
@ -28,18 +28,6 @@ async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
|
|||
return resp_content
|
||||
|
||||
|
||||
def test_zhipuai_completion(mocker):
|
||||
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.completion", mock_llm_completion)
|
||||
zhipu_gpt = ZhiPuAIGPTAPI()
|
||||
|
||||
resp = zhipu_gpt.completion(messages)
|
||||
assert resp["code"] == 200
|
||||
assert resp["data"]["choices"][0]["content"] == resp_content
|
||||
|
||||
resp = zhipu_gpt.ask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zhipuai_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion", mock_llm_acompletion)
|
||||
|
|
|
|||
|
|
@ -14,15 +14,6 @@ from metagpt.logs import logger
|
|||
|
||||
@pytest.mark.usefixtures("llm_api")
|
||||
class TestGPT:
|
||||
def test_llm_api_ask(self, llm_api):
|
||||
answer = llm_api.ask("hello chatgpt")
|
||||
logger.info(answer)
|
||||
assert len(answer) > 0
|
||||
|
||||
def test_gptapi_ask_batch(self, llm_api):
|
||||
answer = llm_api.ask_batch(["请扮演一个Google Python专家工程师,如果理解,回复明白", "写一个hello world"], timeout=60)
|
||||
assert len(answer) > 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_llm_api_aask(self, llm_api):
|
||||
answer = await llm_api.aask("hello chatgpt", stream=False)
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.openai_api import OpenAIGPTAPI as LLM
|
||||
from metagpt.provider.openai_api import OpenAILLM as LLM
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue