mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-11 15:15:18 +02:00
Merge pull request #649 from better629/dev
Feat provider increase coverage ratio and fix some problems.
This commit is contained in:
commit
235b5b5189
62 changed files with 558 additions and 235 deletions
|
|
@ -17,7 +17,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
|
|||
from metagpt.config import CONFIG
|
||||
from metagpt.llm import BaseLLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess
|
||||
from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
|
||||
from metagpt.utils.common import OutputParser, general_after_log
|
||||
|
||||
TAG = "CONTENT"
|
||||
|
|
@ -275,7 +275,7 @@ class ActionNode:
|
|||
output_class = self.create_model_class(output_class_name, output_data_mapping)
|
||||
|
||||
if schema == "json":
|
||||
parsed_data = llm_output_postprecess(
|
||||
parsed_data = llm_output_postprocess(
|
||||
output=content, schema=output_class.model_json_schema(), req_key=f"[/{TAG}]"
|
||||
)
|
||||
else: # using markdown parser
|
||||
|
|
|
|||
|
|
@ -15,7 +15,6 @@ from pydantic import Field
|
|||
from metagpt.actions.action import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import TEST_CODES_FILE_REPO, TEST_OUTPUTS_FILE_REPO
|
||||
from metagpt.llm import LLM, BaseLLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import RunCodeContext, RunCodeResult
|
||||
from metagpt.utils.common import CodeParser
|
||||
|
|
@ -52,7 +51,6 @@ Now you should start rewriting the code:
|
|||
class DebugError(Action):
|
||||
name: str = "DebugError"
|
||||
context: RunCodeContext = Field(default_factory=RunCodeContext)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
async def run(self, *args, **kwargs) -> str:
|
||||
output_doc = await FileRepository.get_file(
|
||||
|
|
|
|||
|
|
@ -13,8 +13,6 @@ import json
|
|||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions import Action, ActionOutput
|
||||
from metagpt.actions.design_api_an import DESIGN_API_NODE
|
||||
from metagpt.config import CONFIG
|
||||
|
|
@ -25,9 +23,7 @@ from metagpt.const import (
|
|||
SYSTEM_DESIGN_FILE_REPO,
|
||||
SYSTEM_DESIGN_PDF_FILE_REPO,
|
||||
)
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Document, Documents, Message
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from metagpt.utils.mermaid import mermaid_to_file
|
||||
|
|
@ -44,7 +40,6 @@ NEW_REQ_TEMPLATE = """
|
|||
class WriteDesign(Action):
|
||||
name: str = ""
|
||||
context: Optional[str] = None
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
desc: str = (
|
||||
"Based on the PRD, think about the system design, and design the corresponding APIs, "
|
||||
"data structures, library tables, processes, and paths. Please provide your design, feedback "
|
||||
|
|
|
|||
|
|
@ -8,17 +8,12 @@
|
|||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
||||
|
||||
class DesignReview(Action):
|
||||
name: str = "DesignReview"
|
||||
context: Optional[str] = None
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
async def run(self, prd, api_design):
|
||||
prompt = (
|
||||
|
|
|
|||
|
|
@ -6,18 +6,14 @@
|
|||
@File : execute_task.py
|
||||
"""
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
class ExecuteTask(Action):
|
||||
name: str = "ExecuteTask"
|
||||
context: list[Message] = []
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
async def run(self, *args, **kwargs):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -42,7 +42,6 @@ class InvoiceOCR(Action):
|
|||
|
||||
name: str = "InvoiceOCR"
|
||||
context: Optional[str] = None
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
@staticmethod
|
||||
async def _check_file_type(file_path: Path) -> str:
|
||||
|
|
|
|||
|
|
@ -11,13 +11,9 @@ import shutil
|
|||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions import Action, ActionOutput
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import DOCS_FILE_REPO, REQUIREMENT_FILENAME
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Document
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
from metagpt.utils.git_repository import GitRepository
|
||||
|
|
@ -28,7 +24,6 @@ class PrepareDocuments(Action):
|
|||
|
||||
name: str = "PrepareDocuments"
|
||||
context: Optional[str] = None
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
def _init_repo(self):
|
||||
"""Initialize the Git environment."""
|
||||
|
|
|
|||
|
|
@ -13,8 +13,6 @@
|
|||
import json
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions import ActionOutput
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.actions.project_management_an import PM_NODE
|
||||
|
|
@ -25,9 +23,7 @@ from metagpt.const import (
|
|||
TASK_FILE_REPO,
|
||||
TASK_PDF_FILE_REPO,
|
||||
)
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Document, Documents
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
||||
|
|
@ -43,7 +39,6 @@ NEW_REQ_TEMPLATE = """
|
|||
class WriteTasks(Action):
|
||||
name: str = "CreateTasks"
|
||||
context: Optional[str] = None
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
async def run(self, with_messages, schema=CONFIG.prompt_schema):
|
||||
system_design_file_repo = CONFIG.git_repo.new_file_repository(SYSTEM_DESIGN_FILE_REPO)
|
||||
|
|
|
|||
|
|
@ -82,7 +82,6 @@ class CollectLinks(Action):
|
|||
|
||||
name: str = "CollectLinks"
|
||||
context: Optional[str] = None
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
desc: str = "Collect links from a search engine."
|
||||
|
||||
search_engine: SearchEngine = Field(default_factory=SearchEngine)
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ from pydantic import Field
|
|||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.llm import LLM, BaseLLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import RunCodeContext, RunCodeResult
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
|
|
@ -79,7 +78,6 @@ standard errors:
|
|||
class RunCode(Action):
|
||||
name: str = "RunCode"
|
||||
context: RunCodeContext = Field(default_factory=RunCodeContext)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
@classmethod
|
||||
async def run_text(cls, code) -> Tuple[str, str]:
|
||||
|
|
|
|||
|
|
@ -12,9 +12,7 @@ from pydantic import Field, model_validator
|
|||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config import CONFIG, Config
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message
|
||||
from metagpt.tools import SearchEngineType
|
||||
from metagpt.tools.search_engine import SearchEngine
|
||||
|
|
@ -109,7 +107,6 @@ You are a member of a professional butler team and will provide helpful suggesti
|
|||
class SearchAndSummarize(Action):
|
||||
name: str = ""
|
||||
content: Optional[str] = None
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
config: None = Field(default_factory=Config)
|
||||
engine: Optional[SearchEngineType] = CONFIG.search_engine
|
||||
search_func: Optional[Any] = None
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
|
|||
from metagpt.actions.action import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
|
||||
from metagpt.llm import LLM, BaseLLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.schema import CodeSummarizeContext
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
|
@ -95,7 +94,6 @@ flowchart TB
|
|||
class SummarizeCode(Action):
|
||||
name: str = "SummarizeCode"
|
||||
context: CodeSummarizeContext = Field(default_factory=CodeSummarizeContext)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
@retry(stop=stop_after_attempt(2), wait=wait_random_exponential(min=1, max=60))
|
||||
async def summarize_code(self, prompt):
|
||||
|
|
|
|||
|
|
@ -29,9 +29,7 @@ from metagpt.const import (
|
|||
TASK_FILE_REPO,
|
||||
TEST_OUTPUTS_FILE_REPO,
|
||||
)
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import CodingContext, Document, RunCodeResult
|
||||
from metagpt.utils.common import CodeParser
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
|
@ -90,7 +88,6 @@ ATTENTION: Use '##' to SPLIT SECTIONS, not '#'. Output format carefully referenc
|
|||
class WriteCode(Action):
|
||||
name: str = "WriteCode"
|
||||
context: Document = Field(default_factory=Document)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
async def write_code(self, prompt) -> str:
|
||||
|
|
|
|||
|
|
@ -14,9 +14,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
|
|||
from metagpt.actions import WriteCode
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import CodingContext
|
||||
from metagpt.utils.common import CodeParser
|
||||
|
||||
|
|
@ -123,7 +121,6 @@ REWRITE_CODE_TEMPLATE = """
|
|||
class WriteCodeReview(Action):
|
||||
name: str = "WriteCodeReview"
|
||||
context: CodingContext = Field(default_factory=CodingContext)
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
|
||||
async def write_code_review_and_rewrite(self, context_prompt, cr_prompt, filename):
|
||||
|
|
|
|||
|
|
@ -27,11 +27,7 @@ import ast
|
|||
from pathlib import Path
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.utils.common import OutputParser, aread, awrite
|
||||
from metagpt.utils.pycst import merge_docstring
|
||||
|
||||
|
|
@ -166,7 +162,6 @@ class WriteDocstring(Action):
|
|||
|
||||
desc: str = "Write docstring for code."
|
||||
context: Optional[str] = None
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -17,8 +17,6 @@ import json
|
|||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions import Action, ActionOutput
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.fix_bug import FixBug
|
||||
|
|
@ -37,9 +35,7 @@ from metagpt.const import (
|
|||
PRDS_FILE_REPO,
|
||||
REQUIREMENT_FILENAME,
|
||||
)
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import BugFixContext, Document, Documents, Message
|
||||
from metagpt.utils.common import CodeParser
|
||||
from metagpt.utils.file_repository import FileRepository
|
||||
|
|
@ -68,7 +64,6 @@ NEW_REQ_TEMPLATE = """
|
|||
class WritePRD(Action):
|
||||
name: str = "WritePRD"
|
||||
content: Optional[str] = None
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
async def run(self, with_messages, schema=CONFIG.prompt_schema, *args, **kwargs) -> ActionOutput | Message:
|
||||
# Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are
|
||||
|
|
|
|||
|
|
@ -8,17 +8,12 @@
|
|||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
||||
|
||||
class WritePRDReview(Action):
|
||||
name: str = ""
|
||||
context: Optional[str] = None
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
prd: Optional[str] = None
|
||||
desc: str = "Based on the PRD, conduct a PRD Review, providing clear and detailed feedback"
|
||||
|
|
|
|||
|
|
@ -6,12 +6,8 @@
|
|||
"""
|
||||
from typing import List
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
||||
REVIEW = ActionNode(
|
||||
key="Review",
|
||||
|
|
@ -38,7 +34,6 @@ class WriteReview(Action):
|
|||
"""Write a review for the given context."""
|
||||
|
||||
name: str = "WriteReview"
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
async def run(self, context):
|
||||
return await WRITE_REVIEW_NODE.fill(context=context, llm=self.llm, schema="json")
|
||||
|
|
|
|||
|
|
@ -7,20 +7,15 @@
|
|||
"""
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
|
||||
|
||||
class WriteTeachingPlanPart(Action):
|
||||
"""Write Teaching Plan Part"""
|
||||
|
||||
context: Optional[str] = None
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
topic: str = ""
|
||||
language: str = "Chinese"
|
||||
rsp: Optional[str] = None
|
||||
|
|
|
|||
|
|
@ -10,14 +10,10 @@
|
|||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions.action import Action
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import TEST_CODES_FILE_REPO
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Document, TestingContext
|
||||
from metagpt.utils.common import CodeParser
|
||||
|
||||
|
|
@ -45,7 +41,6 @@ you should correctly import the necessary classes based on these file locations!
|
|||
class WriteTest(Action):
|
||||
name: str = "WriteTest"
|
||||
context: Optional[TestingContext] = None
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
async def write_code(self, prompt):
|
||||
code_rsp = await self._aask(prompt)
|
||||
|
|
|
|||
|
|
@ -9,12 +9,8 @@
|
|||
|
||||
from typing import Dict
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.prompts.tutorial_assistant import CONTENT_PROMPT, DIRECTORY_PROMPT
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.utils.common import OutputParser
|
||||
|
||||
|
||||
|
|
@ -27,7 +23,6 @@ class WriteDirectory(Action):
|
|||
"""
|
||||
|
||||
name: str = "WriteDirectory"
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
language: str = "Chinese"
|
||||
|
||||
async def run(self, topic: str, *args, **kwargs) -> Dict:
|
||||
|
|
@ -54,7 +49,6 @@ class WriteContent(Action):
|
|||
"""
|
||||
|
||||
name: str = "WriteContent"
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
directory: dict = dict()
|
||||
language: str = "Chinese"
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from pydantic import BaseModel, Field
|
|||
from metagpt.config import CONFIG
|
||||
from metagpt.const import DEFAULT_LANGUAGE, DEFAULT_MAX_TOKENS, DEFAULT_TOKEN_SIZE
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider import MetaGPTAPI
|
||||
from metagpt.provider import MetaGPTLLM
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message, SimpleMessage
|
||||
from metagpt.utils.redis import Redis
|
||||
|
|
@ -123,7 +123,7 @@ class BrainMemory(BaseModel):
|
|||
return v
|
||||
|
||||
async def summarize(self, llm, max_words=200, keep_language: bool = False, limit: int = -1, **kwargs):
|
||||
if isinstance(llm, MetaGPTAPI):
|
||||
if isinstance(llm, MetaGPTLLM):
|
||||
return await self._metagpt_summarize(max_words=max_words)
|
||||
|
||||
self.llm = llm
|
||||
|
|
@ -176,7 +176,7 @@ class BrainMemory(BaseModel):
|
|||
|
||||
async def get_title(self, llm, max_words=5, **kwargs) -> str:
|
||||
"""Generate text title"""
|
||||
if isinstance(llm, MetaGPTAPI):
|
||||
if isinstance(llm, MetaGPTLLM):
|
||||
return self.history[0].content if self.history else "New"
|
||||
|
||||
summary = await self.summarize(llm=llm, max_words=500)
|
||||
|
|
@ -191,7 +191,7 @@ class BrainMemory(BaseModel):
|
|||
return response
|
||||
|
||||
async def is_related(self, text1, text2, llm):
|
||||
if isinstance(llm, MetaGPTAPI):
|
||||
if isinstance(llm, MetaGPTLLM):
|
||||
return await self._metagpt_is_related(text1=text1, text2=text2, llm=llm)
|
||||
return await self._openai_is_related(text1=text1, text2=text2, llm=llm)
|
||||
|
||||
|
|
@ -213,7 +213,7 @@ class BrainMemory(BaseModel):
|
|||
return result
|
||||
|
||||
async def rewrite(self, sentence: str, context: str, llm):
|
||||
if isinstance(llm, MetaGPTAPI):
|
||||
if isinstance(llm, MetaGPTLLM):
|
||||
return await self._metagpt_rewrite(sentence=sentence, context=context, llm=llm)
|
||||
return await self._openai_rewrite(sentence=sentence, context=context, llm=llm)
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from pydantic import ConfigDict, Field
|
|||
from metagpt.logs import logger
|
||||
from metagpt.memory import Memory
|
||||
from metagpt.memory.memory_storage import MemoryStorage
|
||||
from metagpt.roles.role import RoleContext
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
|
|
@ -25,10 +26,10 @@ class LongTermMemory(Memory):
|
|||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
memory_storage: MemoryStorage = Field(default_factory=MemoryStorage)
|
||||
rc: Optional["RoleContext"] = None
|
||||
rc: Optional[RoleContext] = None
|
||||
msg_from_recover: bool = False
|
||||
|
||||
def recover_memory(self, role_id: str, rc: "RoleContext"):
|
||||
def recover_memory(self, role_id: str, rc: RoleContext):
|
||||
messages = self.memory_storage.recover_memory(role_id)
|
||||
self.rc = rc
|
||||
if not self.memory_storage.is_initialized:
|
||||
|
|
|
|||
|
|
@ -7,21 +7,21 @@
|
|||
"""
|
||||
|
||||
from metagpt.provider.fireworks_api import FireworksLLM
|
||||
from metagpt.provider.google_gemini_api import GeminiGPTAPI
|
||||
from metagpt.provider.google_gemini_api import GeminiLLM
|
||||
from metagpt.provider.ollama_api import OllamaLLM
|
||||
from metagpt.provider.open_llm_api import OpenLLMGPTAPI
|
||||
from metagpt.provider.open_llm_api import OpenLLM
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
|
||||
from metagpt.provider.zhipuai_api import ZhiPuAILLM
|
||||
from metagpt.provider.azure_openai_api import AzureOpenAILLM
|
||||
from metagpt.provider.metagpt_api import MetaGPTAPI
|
||||
from metagpt.provider.metagpt_api import MetaGPTLLM
|
||||
|
||||
__all__ = [
|
||||
"FireworksLLM",
|
||||
"GeminiGPTAPI",
|
||||
"OpenLLMGPTAPI",
|
||||
"GeminiLLM",
|
||||
"OpenLLM",
|
||||
"OpenAILLM",
|
||||
"ZhiPuAIGPTAPI",
|
||||
"ZhiPuAILLM",
|
||||
"AzureOpenAILLM",
|
||||
"MetaGPTAPI",
|
||||
"MetaGPTLLM",
|
||||
"OllamaLLM",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -100,7 +100,7 @@ def log_info(message, **params):
|
|||
def log_warn(message, **params):
|
||||
msg = logfmt(dict(message=message, **params))
|
||||
print(msg, file=sys.stderr)
|
||||
logger.warn(msg)
|
||||
logger.warning(msg)
|
||||
|
||||
|
||||
def logfmt(props):
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ class GeminiGenerativeModel(GenerativeModel):
|
|||
|
||||
|
||||
@register_provider(LLMProviderEnum.GEMINI)
|
||||
class GeminiGPTAPI(BaseLLM):
|
||||
class GeminiLLM(BaseLLM):
|
||||
"""
|
||||
Refs to `https://ai.google.dev/tutorials/python_quickstart`
|
||||
"""
|
||||
|
|
@ -79,9 +79,6 @@ class GeminiGPTAPI(BaseLLM):
|
|||
except Exception as e:
|
||||
logger.error(f"google gemini updats costs failed! exp: {e}")
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def get_choice_text(self, resp: GenerateContentResponse) -> str:
|
||||
return resp.text
|
||||
|
||||
|
|
|
|||
|
|
@ -11,6 +11,6 @@ from metagpt.provider.llm_provider_registry import register_provider
|
|||
|
||||
|
||||
@register_provider(LLMProviderEnum.METAGPT)
|
||||
class MetaGPTAPI(OpenAILLM):
|
||||
class MetaGPTLLM(OpenAILLM):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
|||
|
|
@ -31,11 +31,10 @@ class OpenLLMCostManager(CostManager):
|
|||
f"Max budget: ${CONFIG.max_budget:.3f} | reference "
|
||||
f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
|
||||
)
|
||||
CONFIG.total_cost = self.total_cost
|
||||
|
||||
|
||||
@register_provider(LLMProviderEnum.OPEN_LLM)
|
||||
class OpenLLMGPTAPI(OpenAILLM):
|
||||
class OpenLLM(OpenAILLM):
|
||||
def __init__(self):
|
||||
self.config: Config = CONFIG
|
||||
self.__init_openllm()
|
||||
|
|
|
|||
|
|
@ -12,8 +12,8 @@ from metagpt.utils.repair_llm_raw_output import (
|
|||
)
|
||||
|
||||
|
||||
class BasePostPrecessPlugin(object):
|
||||
model = None # the plugin of the `model`, use to judge in `llm_postprecess`
|
||||
class BasePostProcessPlugin(object):
|
||||
model = None # the plugin of the `model`, use to judge in `llm_postprocess`
|
||||
|
||||
def run_repair_llm_output(self, output: str, schema: dict, req_key: str = "[/CONTENT]") -> Union[dict, list]:
|
||||
"""
|
||||
|
|
@ -4,17 +4,17 @@
|
|||
|
||||
from typing import Union
|
||||
|
||||
from metagpt.provider.postprecess.base_postprecess_plugin import BasePostPrecessPlugin
|
||||
from metagpt.provider.postprocess.base_postprocess_plugin import BasePostProcessPlugin
|
||||
|
||||
|
||||
def llm_output_postprecess(
|
||||
def llm_output_postprocess(
|
||||
output: str, schema: dict, req_key: str = "[/CONTENT]", model_name: str = None
|
||||
) -> Union[dict, str]:
|
||||
"""
|
||||
default use BasePostPrecessPlugin if there is not matched plugin.
|
||||
default use BasePostProcessPlugin if there is not matched plugin.
|
||||
"""
|
||||
# TODO choose different model's plugin according to the model_name
|
||||
postprecess_plugin = BasePostPrecessPlugin()
|
||||
postprocess_plugin = BasePostProcessPlugin()
|
||||
|
||||
result = postprecess_plugin.run(output=output, schema=schema, req_key=req_key)
|
||||
result = postprocess_plugin.run(output=output, schema=schema, req_key=req_key)
|
||||
return result
|
||||
|
|
@ -33,7 +33,7 @@ class ZhiPuModelAPI(ModelAPI):
|
|||
zhipu_api_url: https://open.bigmodel.cn/api/paas/v3/model-api/{model}/{invoke_method}
|
||||
"""
|
||||
arr = zhipu_api_url.split("/api/")
|
||||
# ("https://open.bigmodel.cn/api/" , "/paas/v3/model-api/chatglm_turbo/invoke")
|
||||
# ("https://open.bigmodel.cn/api" , "/paas/v3/model-api/chatglm_turbo/invoke")
|
||||
return f"{arr[0]}/api", f"/{arr[1]}"
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@
|
|||
import json
|
||||
from enum import Enum
|
||||
|
||||
import openai
|
||||
import zhipuai
|
||||
from requests import ConnectionError
|
||||
from tenacity import (
|
||||
|
|
@ -31,7 +32,7 @@ class ZhiPuEvent(Enum):
|
|||
|
||||
|
||||
@register_provider(LLMProviderEnum.ZHIPUAI)
|
||||
class ZhiPuAIGPTAPI(BaseLLM):
|
||||
class ZhiPuAILLM(BaseLLM):
|
||||
"""
|
||||
Refs to `https://open.bigmodel.cn/dev/api#chatglm_turbo`
|
||||
From now, there is only one model named `chatglm_turbo`
|
||||
|
|
@ -67,9 +68,6 @@ class ZhiPuAIGPTAPI(BaseLLM):
|
|||
except Exception as e:
|
||||
logger.error(f"zhipuai updats costs failed! exp: {e}")
|
||||
|
||||
def close(self):
|
||||
pass
|
||||
|
||||
def get_choice_text(self, resp: dict) -> str:
|
||||
"""get the first text of choice from llm response"""
|
||||
assist_msg = resp.get("data", {}).get("choices", [{"role": "error"}])[-1]
|
||||
|
|
|
|||
|
|
@ -394,7 +394,9 @@ class Role(SerializationMixin, is_polymorphic_base=True):
|
|||
old_messages = [] if ignore_memory else self.rc.memory.get()
|
||||
self.rc.memory.add_batch(news)
|
||||
# Filter out messages of interest.
|
||||
self.rc.news = [n for n in news if n.cause_by in self.rc.watch and n not in old_messages]
|
||||
self.rc.news = [
|
||||
n for n in news if (n.cause_by in self.rc.watch or self.name in n.send_to) and n not in old_messages
|
||||
]
|
||||
self.latest_observed_msg = self.rc.news[-1] if self.rc.news else None # record the latest observed msg
|
||||
|
||||
# Design Rules:
|
||||
|
|
|
|||
|
|
@ -174,7 +174,7 @@ class Message(BaseModel):
|
|||
role: str = "user" # system / user / assistant
|
||||
cause_by: str = Field(default="", validate_default=True)
|
||||
sent_from: str = Field(default="", validate_default=True)
|
||||
send_to: set = Field(default={MESSAGE_ROUTE_TO_ALL}, validate_default=True)
|
||||
send_to: set[str] = Field(default={MESSAGE_ROUTE_TO_ALL}, validate_default=True)
|
||||
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -20,6 +20,10 @@ def test_ltm_search():
|
|||
assert len(CONFIG.openai_api_key) > 20
|
||||
|
||||
role_id = "UTUserLtm(Product Manager)"
|
||||
from metagpt.environment import Environment
|
||||
|
||||
Environment
|
||||
RoleContext.model_rebuild()
|
||||
rc = RoleContext(watch={"metagpt.actions.add_requirement.UserRequirement"})
|
||||
ltm = LongTermMemory()
|
||||
ltm.recover_memory(role_id, rc)
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ def test_idea_message():
|
|||
role_id = "UTUser1(Product Manager)"
|
||||
message = Message(role="User", content=idea, cause_by=UserRequirement)
|
||||
|
||||
shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"))
|
||||
shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"), ignore_errors=True)
|
||||
|
||||
memory_storage: MemoryStorage = MemoryStorage()
|
||||
messages = memory_storage.recover_memory(role_id)
|
||||
|
|
@ -58,7 +58,7 @@ def test_actionout_message():
|
|||
content=content, instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD
|
||||
) # WritePRD as test action
|
||||
|
||||
shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"))
|
||||
shutil.rmtree(Path(DATA_PATH / f"role_mem/{role_id}/"), ignore_errors=True)
|
||||
|
||||
memory_storage: MemoryStorage = MemoryStorage()
|
||||
messages = memory_storage.recover_memory(role_id)
|
||||
|
|
|
|||
3
tests/metagpt/provider/postprocess/__init__.py
Normal file
3
tests/metagpt/provider/postprocess/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
|
@ -0,0 +1,31 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
|
||||
from metagpt.provider.postprocess.base_postprocess_plugin import BasePostProcessPlugin
|
||||
|
||||
raw_output = """
|
||||
[CONTENT]
|
||||
{
|
||||
"Original Requirements": "xxx"
|
||||
}
|
||||
[/CONTENT]
|
||||
"""
|
||||
raw_schema = {
|
||||
"title": "prd",
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"Original Requirements": {"title": "Original Requirements", "type": "string"},
|
||||
},
|
||||
"required": [
|
||||
"Original Requirements",
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_llm_post_process_plugin():
|
||||
post_process_plugin = BasePostProcessPlugin()
|
||||
|
||||
output = post_process_plugin.run(output=raw_output, schema=raw_schema)
|
||||
assert "Original Requirements" in output
|
||||
|
|
@ -0,0 +1,15 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
|
||||
from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
|
||||
from tests.metagpt.provider.postprocess.test_base_postprocess_plugin import (
|
||||
raw_output,
|
||||
raw_schema,
|
||||
)
|
||||
|
||||
|
||||
def test_llm_output_postprocess():
|
||||
output = llm_output_postprocess(output=raw_output, schema=raw_schema)
|
||||
assert "Original Requirements" in output
|
||||
|
|
@ -2,28 +2,33 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of Claude2
|
||||
|
||||
import pytest
|
||||
|
||||
import pytest
|
||||
from anthropic.resources.completions import Completion
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.anthropic_api import Claude2
|
||||
|
||||
CONFIG.anthropic_api_key = "xxx"
|
||||
|
||||
prompt = "who are you"
|
||||
resp = "I'am Claude2"
|
||||
|
||||
|
||||
def mock_llm_ask(self, msg: str) -> str:
|
||||
return resp
|
||||
def mock_anthropic_completions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion:
|
||||
return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion")
|
||||
|
||||
|
||||
async def mock_llm_aask(self, msg: str) -> str:
|
||||
return resp
|
||||
async def mock_anthropic_acompletions_create(self, model: str, prompt: str, max_tokens_to_sample: int) -> Completion:
|
||||
return Completion(id="xx", completion=resp, model="claude-2", stop_reason="stop_sequence", type="completion")
|
||||
|
||||
|
||||
def test_claude2_ask(mocker):
|
||||
mocker.patch("metagpt.provider.anthropic_api.Claude2.ask", mock_llm_ask)
|
||||
mocker.patch("anthropic.resources.completions.Completions.create", mock_anthropic_completions_create)
|
||||
assert resp == Claude2().ask(prompt)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_claude2_aask(mocker):
|
||||
mocker.patch("metagpt.provider.anthropic_api.Claude2.aask", mock_llm_aask)
|
||||
mocker.patch("anthropic.resources.completions.AsyncCompletions.create", mock_anthropic_acompletions_create)
|
||||
assert resp == await Claude2().aask(prompt)
|
||||
|
|
|
|||
|
|
@ -1,20 +1,14 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/28
|
||||
@Author : mashenquan
|
||||
@File : test_azure_openai.py
|
||||
"""
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.llm import LLM
|
||||
# @Desc :
|
||||
|
||||
|
||||
def test_llm():
|
||||
# Prerequisites
|
||||
assert CONFIG.DEPLOYMENT_NAME and CONFIG.DEPLOYMENT_NAME != "YOUR_DEPLOYMENT_NAME"
|
||||
assert CONFIG.OPENAI_API_KEY and CONFIG.OPENAI_API_KEY != "YOUR_AZURE_API_KEY"
|
||||
assert CONFIG.OPENAI_API_VERSION
|
||||
assert CONFIG.OPENAI_BASE_URL
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.azure_openai_api import AzureOpenAILLM
|
||||
|
||||
llm = LLM(provider=LLMProviderEnum.AZURE_OPENAI)
|
||||
assert llm
|
||||
CONFIG.OPENAI_API_VERSION = "xx"
|
||||
CONFIG.openai_proxy = "http://127.0.0.1:80" # fake value
|
||||
|
||||
|
||||
def test_azure_openai_api():
|
||||
_ = AzureOpenAILLM()
|
||||
|
|
|
|||
|
|
@ -8,13 +8,22 @@ from openai.types.chat.chat_completion import (
|
|||
ChatCompletionMessage,
|
||||
Choice,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_chunk import Choice as AChoice
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.fireworks_api import (
|
||||
MODEL_GRADE_TOKEN_COSTS,
|
||||
FireworksCostManager,
|
||||
FireworksLLM,
|
||||
)
|
||||
from metagpt.utils.cost_manager import Costs
|
||||
|
||||
CONFIG.fireworks_api_key = "xxx"
|
||||
CONFIG.max_budget = 10
|
||||
CONFIG.calc_usage = True
|
||||
|
||||
resp_content = "I'm fireworks"
|
||||
default_resp = ChatCompletion(
|
||||
|
|
@ -25,14 +34,30 @@ default_resp = ChatCompletion(
|
|||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
logprobs=None,
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content=resp_content),
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
usage=CompletionUsage(completion_tokens=110, prompt_tokens=92, total_tokens=202),
|
||||
)
|
||||
|
||||
default_resp_chunk = ChatCompletionChunk(
|
||||
id=default_resp.id,
|
||||
model=default_resp.model,
|
||||
object="chat.completion.chunk",
|
||||
created=default_resp.created,
|
||||
choices=[
|
||||
AChoice(
|
||||
delta=ChoiceDelta(content=resp_content, role="assistant"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
usage=dict(default_resp.usage),
|
||||
)
|
||||
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
||||
|
|
@ -47,29 +72,37 @@ def test_fireworks_costmanager():
|
|||
assert MODEL_GRADE_TOKEN_COSTS["80"] == cost_manager.model_grade_token_costs("xxx-80b-chat")
|
||||
assert MODEL_GRADE_TOKEN_COSTS["mixtral-8x7b"] == cost_manager.model_grade_token_costs("mixtral-8x7b-chat")
|
||||
|
||||
|
||||
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> ChatCompletion:
|
||||
return default_resp
|
||||
cost_manager.update_cost(prompt_tokens=500000, completion_tokens=500000, model="llama-v2-13b-chat")
|
||||
assert cost_manager.total_cost == 0.5
|
||||
|
||||
|
||||
async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> ChatCompletion:
|
||||
return default_resp
|
||||
async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk:
|
||||
if stream:
|
||||
|
||||
class Iterator(object):
|
||||
async def __aiter__(self):
|
||||
yield default_resp_chunk
|
||||
|
||||
async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
|
||||
return default_resp.choices[0].message.content
|
||||
return Iterator()
|
||||
else:
|
||||
return default_resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fireworks_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.fireworks_api.FireworksLLM.acompletion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.fireworks_api.FireworksLLM._achat_completion", mock_llm_acompletion)
|
||||
mocker.patch(
|
||||
"metagpt.provider.fireworks_api.FireworksLLM._achat_completion_stream", mock_llm_achat_completion_stream
|
||||
)
|
||||
fireworks_gpt = FireworksLLM()
|
||||
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
|
||||
|
||||
resp = await fireworks_gpt.acompletion(messages, stream=False)
|
||||
fireworks_gpt = FireworksLLM()
|
||||
fireworks_gpt.model = "llama-v2-13b-chat"
|
||||
|
||||
fireworks_gpt._update_costs(
|
||||
usage=CompletionUsage(prompt_tokens=500000, completion_tokens=500000, total_tokens=1000000)
|
||||
)
|
||||
assert fireworks_gpt.get_costs() == Costs(
|
||||
total_prompt_tokens=500000, total_completion_tokens=500000, total_cost=0.5, total_budget=0
|
||||
)
|
||||
|
||||
resp = await fireworks_gpt.acompletion(messages)
|
||||
assert resp.choices[0].message.content in resp_content
|
||||
|
||||
resp = await fireworks_gpt.aask(prompt_msg, stream=False)
|
||||
|
|
|
|||
99
tests/metagpt/provider/test_general_api_base.py
Normal file
99
tests/metagpt/provider/test_general_api_base.py
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
import os
|
||||
from typing import AsyncGenerator, Generator, Iterator, Tuple, Union
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
import requests
|
||||
from openai import OpenAIError
|
||||
|
||||
from metagpt.provider.general_api_base import (
|
||||
APIRequestor,
|
||||
ApiType,
|
||||
OpenAIResponse,
|
||||
_make_session,
|
||||
_requests_proxies_arg,
|
||||
log_debug,
|
||||
log_info,
|
||||
log_warn,
|
||||
parse_stream,
|
||||
parse_stream_helper,
|
||||
)
|
||||
|
||||
|
||||
def test_basic():
|
||||
_ = ApiType.from_str("azure")
|
||||
_ = ApiType.from_str("azuread")
|
||||
_ = ApiType.from_str("openai")
|
||||
with pytest.raises(OpenAIError):
|
||||
_ = ApiType.from_str("xx")
|
||||
|
||||
os.environ.setdefault("LLM_LOG", "debug")
|
||||
log_debug("debug")
|
||||
log_warn("warn")
|
||||
log_info("info")
|
||||
|
||||
|
||||
def test_openai_response():
|
||||
resp = OpenAIResponse(data=[], headers={"retry-after": 3})
|
||||
assert resp.request_id is None
|
||||
assert resp.retry_after == 3
|
||||
assert resp.operation_location is None
|
||||
assert resp.organization is None
|
||||
assert resp.response_ms is None
|
||||
|
||||
|
||||
def test_proxy():
|
||||
assert _requests_proxies_arg(proxy=None) is None
|
||||
|
||||
proxy = "127.0.0.1:80"
|
||||
assert _requests_proxies_arg(proxy=proxy) == {"http": proxy, "https": proxy}
|
||||
proxy_dict = {"http": proxy}
|
||||
assert _requests_proxies_arg(proxy=proxy_dict) == proxy_dict
|
||||
proxy_dict = {"https": proxy}
|
||||
assert _requests_proxies_arg(proxy=proxy_dict) == proxy_dict
|
||||
|
||||
assert _make_session() is not None
|
||||
|
||||
|
||||
def test_parse_stream():
|
||||
assert parse_stream_helper(None) is None
|
||||
assert parse_stream_helper(b"data: [DONE]") is None
|
||||
assert parse_stream_helper(b"data: test") == "test"
|
||||
assert parse_stream_helper(b"test") is None
|
||||
for line in parse_stream([b"data: test"]):
|
||||
assert line == "test"
|
||||
|
||||
|
||||
api_requestor = APIRequestor(base_url="http://www.baidu.com")
|
||||
|
||||
|
||||
def mock_interpret_response(
|
||||
self, result: requests.Response, stream: bool
|
||||
) -> Tuple[Union[bytes, Iterator[Generator]], bytes]:
|
||||
return b"baidu", False
|
||||
|
||||
|
||||
async def mock_interpret_async_response(
|
||||
self, result: aiohttp.ClientResponse, stream: bool
|
||||
) -> Tuple[Union[OpenAIResponse, AsyncGenerator[OpenAIResponse, None]], bool]:
|
||||
return b"baidu", True
|
||||
|
||||
|
||||
def test_api_requestor(mocker):
|
||||
mocker.patch("metagpt.provider.general_api_base.APIRequestor._interpret_response", mock_interpret_response)
|
||||
resp, _, _ = api_requestor.request(method="get", url="/s?wd=baidu")
|
||||
|
||||
resp, _, _ = api_requestor.request(method="post", url="/s?wd=baidu")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_api_requestor(mocker):
|
||||
mocker.patch(
|
||||
"metagpt.provider.general_api_base.APIRequestor._interpret_async_response", mock_interpret_async_response
|
||||
)
|
||||
resp, _, _ = await api_requestor.arequest(method="get", url="/s?wd=baidu")
|
||||
resp, _, _ = await api_requestor.arequest(method="post", url="/s?wd=baidu")
|
||||
|
|
@ -4,11 +4,24 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.general_api_requestor import GeneralAPIRequestor
|
||||
from metagpt.provider.general_api_requestor import (
|
||||
GeneralAPIRequestor,
|
||||
parse_stream,
|
||||
parse_stream_helper,
|
||||
)
|
||||
|
||||
api_requestor = GeneralAPIRequestor(base_url="http://www.baidu.com")
|
||||
|
||||
|
||||
def test_parse_stream():
|
||||
assert parse_stream_helper(None) is None
|
||||
assert parse_stream_helper(b"data: [DONE]") is None
|
||||
assert parse_stream_helper(b"data: test") == b"test"
|
||||
assert parse_stream_helper(b"test") is None
|
||||
for line in parse_stream([b"data: test"]):
|
||||
assert line == b"test"
|
||||
|
||||
|
||||
def test_api_requestor():
|
||||
resp, _, _ = api_requestor.request(method="get", url="/s?wd=baidu")
|
||||
assert b"baidu" in resp
|
||||
|
|
|
|||
|
|
@ -6,8 +6,13 @@ from abc import ABC
|
|||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
from google.ai import generativelanguage as glm
|
||||
from google.generativeai.types import content_types
|
||||
|
||||
from metagpt.provider.google_gemini_api import GeminiGPTAPI
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.google_gemini_api import GeminiLLM
|
||||
|
||||
CONFIG.gemini_api_key = "xx"
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -21,28 +26,52 @@ resp_content = "I'm gemini from google"
|
|||
default_resp = MockGeminiResponse(text=resp_content)
|
||||
|
||||
|
||||
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> MockGeminiResponse:
|
||||
def mock_gemini_count_tokens(self, contents: content_types.ContentsType) -> glm.CountTokensResponse:
|
||||
return glm.CountTokensResponse(total_tokens=20)
|
||||
|
||||
|
||||
async def mock_gemini_count_tokens_async(self, contents: content_types.ContentsType) -> glm.CountTokensResponse:
|
||||
return glm.CountTokensResponse(total_tokens=20)
|
||||
|
||||
|
||||
def mock_gemini_generate_content(self, **kwargs) -> MockGeminiResponse:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_llm_acompletion(
|
||||
self, messgaes: list[dict], stream: bool = False, timeout: int = 60
|
||||
) -> MockGeminiResponse:
|
||||
return default_resp
|
||||
async def mock_gemini_generate_content_async(self, stream: bool = False, **kwargs) -> MockGeminiResponse:
|
||||
if stream:
|
||||
|
||||
class Iterator(object):
|
||||
async def __aiter__(self):
|
||||
yield default_resp
|
||||
|
||||
async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
|
||||
return resp_content
|
||||
return Iterator()
|
||||
else:
|
||||
return default_resp
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI.acompletion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.google_gemini_api.GeminiGPTAPI._achat_completion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.google_gemini_api.GeminiGenerativeModel.count_tokens", mock_gemini_count_tokens)
|
||||
mocker.patch(
|
||||
"metagpt.provider.google_gemini_api.GeminiGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream
|
||||
"metagpt.provider.google_gemini_api.GeminiGenerativeModel.count_tokens_async", mock_gemini_count_tokens_async
|
||||
)
|
||||
gemini_gpt = GeminiGPTAPI()
|
||||
mocker.patch("google.generativeai.generative_models.GenerativeModel.generate_content", mock_gemini_generate_content)
|
||||
mocker.patch(
|
||||
"google.generativeai.generative_models.GenerativeModel.generate_content_async",
|
||||
mock_gemini_generate_content_async,
|
||||
)
|
||||
|
||||
gemini_gpt = GeminiLLM()
|
||||
|
||||
assert gemini_gpt._user_msg(prompt_msg) == {"role": "user", "parts": [prompt_msg]}
|
||||
assert gemini_gpt._assistant_msg(prompt_msg) == {"role": "model", "parts": [prompt_msg]}
|
||||
|
||||
usage = gemini_gpt.get_usage(messages, resp_content)
|
||||
assert usage == {"prompt_tokens": 20, "completion_tokens": 20}
|
||||
|
||||
resp = gemini_gpt.completion(messages)
|
||||
assert resp == default_resp
|
||||
|
||||
resp = await gemini_gpt.acompletion(messages)
|
||||
assert resp.text == default_resp.text
|
||||
|
|
|
|||
|
|
@ -5,11 +5,11 @@
|
|||
@Author : mashenquan
|
||||
@File : test_metagpt_llm_api.py
|
||||
"""
|
||||
from metagpt.provider.metagpt_api import MetaGPTAPI
|
||||
from metagpt.provider.metagpt_api import MetaGPTLLM
|
||||
|
||||
|
||||
def test_metagpt():
|
||||
llm = MetaGPTAPI()
|
||||
llm = MetaGPTLLM()
|
||||
assert llm
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,6 +2,9 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of ollama api
|
||||
|
||||
import json
|
||||
from typing import Any, Tuple
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
|
|
@ -14,25 +17,33 @@ resp_content = "I'm ollama"
|
|||
default_resp = {"message": {"role": "assistant", "content": resp_content}}
|
||||
|
||||
CONFIG.ollama_api_base = "http://xxx"
|
||||
CONFIG.max_budget = 10
|
||||
|
||||
|
||||
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> dict:
|
||||
return default_resp
|
||||
async def mock_ollama_arequest(self, stream: bool = False, **kwargs) -> Tuple[Any, Any, bool]:
|
||||
if stream:
|
||||
|
||||
class Iterator(object):
|
||||
events = [
|
||||
b'{"message": {"role": "assistant", "content": "I\'m ollama"}, "done": false}',
|
||||
b'{"prompt_eval_count": 20, "eval_count": 20, "done": true}',
|
||||
]
|
||||
|
||||
async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> dict:
|
||||
return default_resp
|
||||
async def __aiter__(self):
|
||||
for event in self.events:
|
||||
yield event
|
||||
|
||||
|
||||
async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
|
||||
return resp_content
|
||||
return Iterator(), None, None
|
||||
else:
|
||||
raw_default_resp = default_resp.copy()
|
||||
raw_default_resp.update({"prompt_eval_count": 20, "eval_count": 20})
|
||||
return json.dumps(raw_default_resp).encode(), None, None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_gemini_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI.acompletion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.ollama_api.OllamaGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream)
|
||||
mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_ollama_arequest)
|
||||
|
||||
ollama_gpt = OllamaLLM()
|
||||
|
||||
resp = await ollama_gpt.acompletion(messages)
|
||||
|
|
|
|||
|
|
@ -1,25 +1,95 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/12/28
|
||||
@Author : mashenquan
|
||||
@File : test_open_llm_api.py
|
||||
"""
|
||||
from metagpt.config import CONFIG, LLMProviderEnum
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.open_llm_api import OpenLLMCostManager
|
||||
# @Desc :
|
||||
|
||||
import pytest
|
||||
from openai.types.chat.chat_completion import (
|
||||
ChatCompletion,
|
||||
ChatCompletionMessage,
|
||||
Choice,
|
||||
)
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_chunk import Choice as AChoice
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
from openai.types.completion_usage import CompletionUsage
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.open_llm_api import OpenLLM
|
||||
from metagpt.utils.cost_manager import Costs
|
||||
|
||||
CONFIG.max_budget = 10
|
||||
CONFIG.calc_usage = True
|
||||
|
||||
resp_content = "I'm llama2"
|
||||
default_resp = ChatCompletion(
|
||||
id="cmpl-a6652c1bb181caae8dd19ad8",
|
||||
model="llama-v2-13b-chat",
|
||||
object="chat.completion",
|
||||
created=1703302755,
|
||||
choices=[
|
||||
Choice(
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
message=ChatCompletionMessage(role="assistant", content=resp_content),
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
default_resp_chunk = ChatCompletionChunk(
|
||||
id=default_resp.id,
|
||||
model=default_resp.model,
|
||||
object="chat.completion.chunk",
|
||||
created=default_resp.created,
|
||||
choices=[
|
||||
AChoice(
|
||||
delta=ChoiceDelta(content=resp_content, role="assistant"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
logprobs=None,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
||||
|
||||
def test_llm():
|
||||
llm = LLM(provider=LLMProviderEnum.OPEN_LLM)
|
||||
assert llm
|
||||
async def mock_openai_acompletions_create(self, stream: bool = False, **kwargs) -> ChatCompletionChunk:
|
||||
if stream:
|
||||
|
||||
class Iterator(object):
|
||||
async def __aiter__(self):
|
||||
yield default_resp_chunk
|
||||
|
||||
return Iterator()
|
||||
else:
|
||||
return default_resp
|
||||
|
||||
|
||||
def test_cost():
|
||||
# Prerequisites
|
||||
CONFIG.max_budget = 10
|
||||
@pytest.mark.asyncio
|
||||
async def test_openllm_acompletion(mocker):
|
||||
mocker.patch("openai.resources.chat.completions.AsyncCompletions.create", mock_openai_acompletions_create)
|
||||
|
||||
cost = OpenLLMCostManager()
|
||||
cost.update_cost(prompt_tokens=10, completion_tokens=1, model="gpt-35-turbo")
|
||||
assert cost.get_total_prompt_tokens() > 0
|
||||
assert cost.get_total_completion_tokens() > 0
|
||||
openllm_gpt = OpenLLM()
|
||||
openllm_gpt.model = "llama-v2-13b-chat"
|
||||
|
||||
openllm_gpt._update_costs(usage=CompletionUsage(prompt_tokens=100, completion_tokens=100, total_tokens=200))
|
||||
assert openllm_gpt.get_costs() == Costs(
|
||||
total_prompt_tokens=100, total_completion_tokens=100, total_cost=0, total_budget=0
|
||||
)
|
||||
|
||||
resp = await openllm_gpt.acompletion(messages)
|
||||
assert resp.choices[0].message.content in resp_content
|
||||
|
||||
resp = await openllm_gpt.aask(prompt_msg, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await openllm_gpt.acompletion_text(messages, stream=False)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await openllm_gpt.acompletion_text(messages, stream=True)
|
||||
assert resp == resp_content
|
||||
|
||||
resp = await openllm_gpt.aask(prompt_msg)
|
||||
assert resp == resp_content
|
||||
|
|
|
|||
|
|
@ -2,9 +2,14 @@ from unittest.mock import Mock
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.openai_api import OpenAILLM
|
||||
from metagpt.schema import UserMessage
|
||||
|
||||
CONFIG.openai_proxy = None
|
||||
|
||||
print("openai_api_key ", CONFIG.openai_api_key)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_aask_code():
|
||||
|
|
|
|||
|
|
@ -4,24 +4,31 @@
|
|||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.spark_api import SparkLLM
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.spark_api import GetMessageFromWeb, SparkLLM
|
||||
|
||||
CONFIG.spark_appid = "xxx"
|
||||
CONFIG.spark_api_secret = "xxx"
|
||||
CONFIG.spark_api_key = "xxx"
|
||||
CONFIG.domain = "xxxxxx"
|
||||
CONFIG.spark_url = "xxxx"
|
||||
|
||||
prompt_msg = "who are you"
|
||||
resp_content = "I'm Spark"
|
||||
|
||||
|
||||
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> str:
|
||||
return resp_content
|
||||
def test_get_msg_from_web():
|
||||
get_msg_from_web = GetMessageFromWeb(text=prompt_msg)
|
||||
assert get_msg_from_web.gen_params()["parameter"]["chat"]["domain"] == "xxxxxx"
|
||||
|
||||
|
||||
async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> str:
|
||||
def mock_spark_get_msg_from_web_run(self) -> str:
|
||||
return resp_content
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spark_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.spark_api.SparkGPTAPI.acompletion_text", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.spark_api.GetMessageFromWeb.run", mock_spark_get_msg_from_web_run)
|
||||
spark_gpt = SparkLLM()
|
||||
|
||||
resp = await spark_gpt.acompletion([])
|
||||
|
|
|
|||
|
|
@ -1,41 +1,71 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc : the unittest of ZhiPuAIGPTAPI
|
||||
# @Desc : the unittest of ZhiPuAILLM
|
||||
|
||||
import pytest
|
||||
from zhipuai.utils.sse_client import Event
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
|
||||
from metagpt.provider.zhipuai_api import ZhiPuAILLM
|
||||
|
||||
CONFIG.zhipuai_api_key = "xxx"
|
||||
CONFIG.zhipuai_api_key = "xxx.xxx"
|
||||
|
||||
prompt_msg = "who are you"
|
||||
messages = [{"role": "user", "content": prompt_msg}]
|
||||
|
||||
resp_content = "I'm chatglm-turbo"
|
||||
default_resp = {"code": 200, "data": {"choices": [{"role": "assistant", "content": resp_content}]}}
|
||||
default_resp = {
|
||||
"code": 200,
|
||||
"data": {
|
||||
"choices": [{"role": "assistant", "content": resp_content}],
|
||||
"usage": {"prompt_tokens": 20, "completion_tokens": 20},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def mock_llm_completion(self, messages: list[dict], timeout: int = 60) -> dict:
|
||||
def mock_zhipuai_invoke(**kwargs) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_llm_acompletion(self, messgaes: list[dict], stream: bool = False, timeout: int = 60) -> dict:
|
||||
async def mock_zhipuai_ainvoke(**kwargs) -> dict:
|
||||
return default_resp
|
||||
|
||||
|
||||
async def mock_llm_achat_completion_stream(self, messgaes: list[dict]) -> str:
|
||||
return resp_content
|
||||
async def mock_zhipuai_asse_invoke(**kwargs):
|
||||
class MockResponse(object):
|
||||
async def _aread(self):
|
||||
class Iterator(object):
|
||||
events = [
|
||||
Event(id="xxx", event="add", data=resp_content, retry=0),
|
||||
Event(
|
||||
id="xxx",
|
||||
event="finish",
|
||||
data="",
|
||||
meta='{"usage": {"completion_tokens": 20,"prompt_tokens": 20}}',
|
||||
),
|
||||
]
|
||||
|
||||
async def __aiter__(self):
|
||||
for event in self.events:
|
||||
yield event
|
||||
|
||||
async for chunk in Iterator():
|
||||
yield chunk
|
||||
|
||||
async def async_events(self):
|
||||
async for chunk in self._aread():
|
||||
yield chunk
|
||||
|
||||
return MockResponse()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zhipuai_acompletion(mocker):
|
||||
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI.acompletion", mock_llm_acompletion)
|
||||
mocker.patch("metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI._achat_completion", mock_llm_acompletion)
|
||||
mocker.patch(
|
||||
"metagpt.provider.zhipuai_api.ZhiPuAIGPTAPI._achat_completion_stream", mock_llm_achat_completion_stream
|
||||
)
|
||||
zhipu_gpt = ZhiPuAIGPTAPI()
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.invoke", mock_zhipuai_invoke)
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.ainvoke", mock_zhipuai_ainvoke)
|
||||
mocker.patch("metagpt.provider.zhipuai.zhipu_model_api.ZhiPuModelAPI.asse_invoke", mock_zhipuai_asse_invoke)
|
||||
|
||||
zhipu_gpt = ZhiPuAILLM()
|
||||
|
||||
resp = await zhipu_gpt.acompletion(messages)
|
||||
assert resp["data"]["choices"][0]["content"] == resp_content
|
||||
|
|
@ -53,11 +83,11 @@ async def test_zhipuai_acompletion(mocker):
|
|||
assert resp == resp_content
|
||||
|
||||
|
||||
def test_zhipuai_proxy(mocker):
|
||||
def test_zhipuai_proxy():
|
||||
import openai
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
|
||||
CONFIG.openai_proxy = "http://127.0.0.1:8080"
|
||||
_ = ZhiPuAIGPTAPI()
|
||||
_ = ZhiPuAILLM()
|
||||
assert openai.proxy == CONFIG.openai_proxy
|
||||
|
|
|
|||
3
tests/metagpt/provider/zhipuai/__init__.py
Normal file
3
tests/metagpt/provider/zhipuai/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
18
tests/metagpt/provider/zhipuai/test_async_sse_client.py
Normal file
18
tests/metagpt/provider/zhipuai/test_async_sse_client.py
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.provider.zhipuai.async_sse_client import AsyncSSEClient
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_sse_client():
|
||||
class Iterator(object):
|
||||
async def __aiter__(self):
|
||||
yield b"data: test_value"
|
||||
|
||||
async_sse_client = AsyncSSEClient(event_source=Iterator())
|
||||
async for event in async_sse_client.async_events():
|
||||
assert event.data, "test_value"
|
||||
41
tests/metagpt/provider/zhipuai/test_zhipu_model_api.py
Normal file
41
tests/metagpt/provider/zhipuai/test_zhipu_model_api.py
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
# @Desc :
|
||||
|
||||
from typing import Any, Tuple
|
||||
|
||||
import pytest
|
||||
import zhipuai
|
||||
from zhipuai.model_api.api import InvokeType
|
||||
from zhipuai.utils.http_client import headers as zhipuai_default_headers
|
||||
|
||||
from metagpt.provider.zhipuai.zhipu_model_api import ZhiPuModelAPI
|
||||
|
||||
api_key = "xxx.xxx"
|
||||
zhipuai.api_key = api_key
|
||||
|
||||
default_resp = {"result": "test response"}
|
||||
|
||||
|
||||
async def mock_requestor_arequest(self, **kwargs) -> Tuple[Any, Any, str]:
|
||||
return default_resp, None, None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_zhipu_model_api(mocker):
|
||||
header = ZhiPuModelAPI.get_header()
|
||||
zhipuai_default_headers.update({"Authorization": api_key})
|
||||
assert header == zhipuai_default_headers
|
||||
|
||||
sse_header = ZhiPuModelAPI.get_sse_header()
|
||||
assert len(sse_header["Authorization"]) == 191
|
||||
|
||||
url_prefix, url_suffix = ZhiPuModelAPI.split_zhipu_api_url(InvokeType.SYNC, kwargs={"model": "chatglm_turbo"})
|
||||
assert url_prefix == "https://open.bigmodel.cn/api"
|
||||
assert url_suffix == "/paas/v3/model-api/chatglm_turbo/invoke"
|
||||
|
||||
mocker.patch("metagpt.provider.general_api_requestor.GeneralAPIRequestor.arequest", mock_requestor_arequest)
|
||||
result = await ZhiPuModelAPI.arequest(
|
||||
InvokeType.SYNC, stream=False, method="get", headers={}, kwargs={"model": "chatglm_turbo"}
|
||||
)
|
||||
assert result == default_resp
|
||||
|
|
@ -28,5 +28,5 @@ async def test_action_deserialize():
|
|||
new_action = Action(**serialized_data)
|
||||
|
||||
assert new_action.name == ""
|
||||
assert new_action.llm == LLM()
|
||||
assert isinstance(new_action.llm, type(LLM()))
|
||||
assert len(await new_action._aask("who are you")) > 0
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ from metagpt.schema import Message
|
|||
from metagpt.utils.common import any_to_str
|
||||
from tests.metagpt.serialize_deserialize.test_serdeser_base import (
|
||||
ActionOK,
|
||||
ActionRaise,
|
||||
RoleC,
|
||||
serdeser_path,
|
||||
)
|
||||
|
|
@ -55,9 +56,9 @@ def test_environment_serdeser():
|
|||
assert len(new_env.roles) == 1
|
||||
|
||||
assert list(new_env.roles.values())[0].states == list(environment.roles.values())[0].states
|
||||
assert list(new_env.roles.values())[0].actions == list(environment.roles.values())[0].actions
|
||||
assert isinstance(list(environment.roles.values())[0].actions[0], ActionOK)
|
||||
assert type(list(new_env.roles.values())[0].actions[0]) == ActionOK
|
||||
assert type(list(new_env.roles.values())[0].actions[1]) == ActionRaise
|
||||
|
||||
|
||||
def test_environment_serdeser_v2():
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions import WriteCode
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.schema import CodingContext, Document
|
||||
|
||||
|
||||
|
|
@ -28,5 +27,4 @@ async def test_write_code_deserialize():
|
|||
new_action = WriteCode(**serialized_data)
|
||||
|
||||
assert new_action.name == "WriteCode"
|
||||
assert new_action.llm == LLM()
|
||||
await action.run()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions import WriteCodeReview
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.schema import CodingContext, Document
|
||||
|
||||
|
||||
|
|
@ -28,5 +27,4 @@ def div(a: int, b: int = 0):
|
|||
new_action = WriteCodeReview(**serialized_data)
|
||||
|
||||
assert new_action.name == "WriteCodeReview"
|
||||
assert new_action.llm == LLM()
|
||||
await new_action.run()
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions import WriteDesign, WriteTasks
|
||||
from metagpt.llm import LLM
|
||||
|
||||
|
||||
def test_write_design_serialize():
|
||||
|
|
@ -28,7 +27,6 @@ async def test_write_design_deserialize():
|
|||
serialized_data = action.model_dump()
|
||||
new_action = WriteDesign(**serialized_data)
|
||||
assert new_action.name == ""
|
||||
assert new_action.llm == LLM()
|
||||
await new_action.run(with_messages="write a cli snake game")
|
||||
|
||||
|
||||
|
|
@ -38,5 +36,4 @@ async def test_write_task_deserialize():
|
|||
serialized_data = action.model_dump()
|
||||
new_action = WriteTasks(**serialized_data)
|
||||
assert new_action.name == "CreateTasks"
|
||||
assert new_action.llm == LLM()
|
||||
await new_action.run(with_messages="write a cli snake game")
|
||||
|
|
|
|||
|
|
@ -6,7 +6,6 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.actions import WritePRD
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
||||
|
|
@ -22,7 +21,6 @@ async def test_action_deserialize():
|
|||
action = WritePRD()
|
||||
serialized_data = action.model_dump()
|
||||
new_action = WritePRD(**serialized_data)
|
||||
assert new_action.name == ""
|
||||
assert new_action.llm == LLM()
|
||||
assert new_action.name == "WritePRD"
|
||||
action_output = await new_action.run(with_messages=Message(content="write a cli snake game"))
|
||||
assert len(action_output.content) > 0
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
@Desc : the unittest of serialize
|
||||
"""
|
||||
|
||||
from typing import List, Tuple
|
||||
from typing import List
|
||||
|
||||
from metagpt.actions import WritePRD
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
|
|
@ -27,7 +27,7 @@ def test_actionoutout_schema_to_mapping():
|
|||
"properties": {"field": {"title": "field", "type": "array", "items": {"type": "string"}}},
|
||||
}
|
||||
mapping = actionoutout_schema_to_mapping(schema)
|
||||
assert mapping["field"] == (List[str], ...)
|
||||
assert mapping["field"] == (list[str], ...)
|
||||
|
||||
schema = {
|
||||
"title": "test",
|
||||
|
|
@ -46,7 +46,7 @@ def test_actionoutout_schema_to_mapping():
|
|||
},
|
||||
}
|
||||
mapping = actionoutout_schema_to_mapping(schema)
|
||||
assert mapping["field"] == (List[Tuple[str, str]], ...)
|
||||
assert mapping["field"] == (list[list[str]], ...)
|
||||
|
||||
assert True, True
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue