mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-05-04 21:32:38 +02:00
Merge branch 'mgx_ops' of https://gitlab.deepwisdomai.com/pub/MetaGPT into check_role_zero
This commit is contained in:
commit
e52672aea5
20 changed files with 311 additions and 70 deletions
25
examples/serialize_model.py
Normal file
25
examples/serialize_model.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
from metagpt.environment.mgx.mgx_env import MGXEnv
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
def main():
|
||||
"""Demonstrates serialization and deserialization using SerializationMixin.
|
||||
|
||||
This example creates an instance of MGXEnv, serializes it to a file,
|
||||
and then deserializes it back to an instance.
|
||||
|
||||
If executed correctly, the following log messages will be output:
|
||||
MGXEnv serialization successful. File saved at: /.../workspace/storage/MGXEnv.json
|
||||
MGXEnv deserialization successful. Instance created from file: /.../workspace/storage/MGXEnv.json
|
||||
The instance is MGXEnv()
|
||||
"""
|
||||
|
||||
env = MGXEnv()
|
||||
env.serialize()
|
||||
|
||||
env: MGXEnv = MGXEnv.deserialize()
|
||||
logger.info(f"The instance is {repr(env)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -4,7 +4,7 @@ from __future__ import annotations
|
|||
|
||||
import json
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
from pydantic import Field, PrivateAttr, model_validator
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize
|
||||
|
|
@ -90,6 +90,8 @@ class SearchEnhancedQA(Action):
|
|||
description="Maximum number of search results (links) to collect using the collect_links_action. This controls the number of potential sources for answering the question.",
|
||||
)
|
||||
|
||||
_reporter: ThoughtReporter = PrivateAttr(ThoughtReporter())
|
||||
|
||||
@model_validator(mode="after")
|
||||
def initialize(self):
|
||||
if self.web_browse_and_summarize_action is None:
|
||||
|
|
@ -118,13 +120,14 @@ class SearchEnhancedQA(Action):
|
|||
Raises:
|
||||
ValueError: If the query is invalid.
|
||||
"""
|
||||
async with self._reporter:
|
||||
await self._reporter.async_report({"type": "search", "stage": "init"})
|
||||
self._validate_query(query)
|
||||
|
||||
self._validate_query(query)
|
||||
processed_query = await self._process_query(query, rewrite_query)
|
||||
context = await self._build_context(processed_query)
|
||||
|
||||
processed_query = await self._process_query(query, rewrite_query)
|
||||
context = await self._build_context(processed_query)
|
||||
|
||||
return await self._generate_answer(processed_query, context)
|
||||
return await self._generate_answer(processed_query, context)
|
||||
|
||||
def _validate_query(self, query: str) -> None:
|
||||
"""Validate the input query.
|
||||
|
|
@ -203,6 +206,7 @@ class SearchEnhancedQA(Action):
|
|||
"""
|
||||
|
||||
relevant_urls = await self._collect_relevant_links(query)
|
||||
await self._reporter.async_report({"type": "search", "stage": "searching", "urls": relevant_urls})
|
||||
if not relevant_urls:
|
||||
logger.warning(f"No relevant URLs found for query: {query}")
|
||||
return []
|
||||
|
|
@ -245,10 +249,12 @@ class SearchEnhancedQA(Action):
|
|||
contents = await self._fetch_web_contents(urls)
|
||||
|
||||
summaries = {}
|
||||
await self._reporter.async_report(
|
||||
{"type": "search", "stage": "browsing", "pages": [i.model_dump() for i in contents]}
|
||||
)
|
||||
for content in contents:
|
||||
url = content.url
|
||||
inner_text = content.inner_text.replace("\n", "")
|
||||
|
||||
if self.web_browse_and_summarize_action._is_content_invalid(inner_text):
|
||||
logger.warning(f"Invalid content detected for URL {url}: {inner_text[:10]}...")
|
||||
continue
|
||||
|
|
@ -276,8 +282,7 @@ class SearchEnhancedQA(Action):
|
|||
|
||||
system_prompt = SEARCH_ENHANCED_QA_SYSTEM_PROMPT.format(context=context)
|
||||
|
||||
async with ThoughtReporter(enable_llm_stream=True) as reporter:
|
||||
await reporter.async_report({"type": "quick"})
|
||||
async with ThoughtReporter(uuid=self._reporter.uuid, enable_llm_stream=True) as reporter:
|
||||
await reporter.async_report({"type": "search", "stage": "answer"})
|
||||
rsp = await self._aask(query, [system_prompt])
|
||||
|
||||
return rsp
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
"""
|
||||
from typing import Callable, Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from metagpt.tools import SearchEngineType
|
||||
from metagpt.utils.yaml_model import YamlModel
|
||||
|
|
@ -16,10 +16,11 @@ from metagpt.utils.yaml_model import YamlModel
|
|||
class SearchConfig(YamlModel):
|
||||
"""Config for Search"""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
api_type: SearchEngineType = SearchEngineType.DUCK_DUCK_GO
|
||||
api_key: str = ""
|
||||
cse_id: str = "" # for google
|
||||
discovery_service_url: str = "" # for google
|
||||
search_func: Optional[Callable] = None
|
||||
params: dict = Field(
|
||||
default_factory=lambda: {
|
||||
|
|
|
|||
|
|
@ -10,11 +10,11 @@ from metagpt.const import AGENT
|
|||
from metagpt.environment.base_env import Environment
|
||||
from metagpt.logs import get_human_input
|
||||
from metagpt.roles import Architect, ProductManager, ProjectManager, Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.schema import Message, SerializationMixin
|
||||
from metagpt.utils.common import any_to_str, any_to_str_set
|
||||
|
||||
|
||||
class MGXEnv(Environment):
|
||||
class MGXEnv(Environment, SerializationMixin):
|
||||
"""MGX Environment"""
|
||||
|
||||
# If True, fixed software sop bypassing TL is allowed, otherwise, TL will fully take over the routing
|
||||
|
|
|
|||
|
|
@ -208,6 +208,6 @@ class CodeReview(Action):
|
|||
comments = await self.confirm_comments(patch=patch, comments=comments, points=points)
|
||||
for comment in comments:
|
||||
if comment["code"]:
|
||||
if not (comment["code"].startswith("-") or comment["code"].isspace()):
|
||||
if not (comment["code"].isspace()):
|
||||
result.append(comment)
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -152,7 +152,7 @@ Respond with a concise thought, then provide the appropriate response category:
|
|||
"""
|
||||
|
||||
|
||||
QUICK_THINK_EXAMPLES ="""
|
||||
QUICK_THINK_EXAMPLES = """
|
||||
# Example
|
||||
|
||||
1. Request: "How do I design an online document editing platform that supports real-time collaboration?"
|
||||
|
|
@ -190,4 +190,4 @@ Response Category: AMBIGUOUS.
|
|||
# Instruction
|
||||
"""
|
||||
|
||||
QUICK_THINK_PROMPT = QUICK_THINK_PROMPT.format(examples=QUICK_THINK_EXAMPLES)
|
||||
QUICK_THINK_PROMPT = QUICK_THINK_PROMPT.format(examples=QUICK_THINK_EXAMPLES)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from metagpt.actions.di.execute_nb_code import ExecuteNbCode
|
||||
|
|
@ -31,7 +33,7 @@ class DataAnalyst(RoleZero):
|
|||
tools: list[str] = ["Plan", "DataAnalyst", "RoleZero", "Browser"]
|
||||
custom_tools: list[str] = ["web scraping", "Terminal"]
|
||||
custom_tool_recommender: ToolRecommender = None
|
||||
experience_retriever: ExpRetriever = KeywordExpRetriever()
|
||||
experience_retriever: Annotated[ExpRetriever, Field(exclude=True)] = KeywordExpRetriever()
|
||||
|
||||
use_reflection: bool = True
|
||||
write_code: WriteAnalysisCode = Field(default_factory=WriteAnalysisCode, exclude=True)
|
||||
|
|
|
|||
|
|
@ -4,9 +4,9 @@ import inspect
|
|||
import json
|
||||
import re
|
||||
import traceback
|
||||
from typing import Callable, Dict, List, Literal, Tuple
|
||||
from typing import Annotated, Callable, Dict, List, Literal, Optional, Tuple
|
||||
|
||||
from pydantic import model_validator
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from metagpt.actions import Action, UserRequirement
|
||||
from metagpt.actions.analyze_requirements import AnalyzeRequirementsRestrictions
|
||||
|
|
@ -41,7 +41,11 @@ from metagpt.utils.common import (
|
|||
extract_image_paths,
|
||||
is_support_image_input,
|
||||
)
|
||||
from metagpt.utils.repair_llm_raw_output import RepairType, repair_llm_raw_output
|
||||
from metagpt.utils.repair_llm_raw_output import (
|
||||
RepairType,
|
||||
repair_escape_error,
|
||||
repair_llm_raw_output,
|
||||
)
|
||||
from metagpt.utils.report import ThoughtReporter
|
||||
|
||||
|
||||
|
|
@ -53,12 +57,13 @@ class RoleZero(Role):
|
|||
name: str = "Zero"
|
||||
profile: str = "RoleZero"
|
||||
goal: str = ""
|
||||
system_msg: Optional[list[str]] = None # Use None to conform to the default value at llm.aask
|
||||
system_prompt: str = SYSTEM_PROMPT # Use None to conform to the default value at llm.aask
|
||||
cmd_prompt: str = CMD_PROMPT
|
||||
cmd_prompt_current_state: str = ""
|
||||
thought_guidance: str = THOUGHT_GUIDANCE
|
||||
instruction: str = ROLE_INSTRUCTION
|
||||
task_type_desc: str = None
|
||||
task_type_desc: Optional[str] = None
|
||||
|
||||
# React Mode
|
||||
react_mode: Literal["react"] = "react"
|
||||
|
|
@ -66,15 +71,15 @@ class RoleZero(Role):
|
|||
|
||||
# Tools
|
||||
tools: list[str] = [] # Use special symbol ["<all>"] to indicate use of all registered tools
|
||||
tool_recommender: ToolRecommender = None
|
||||
tool_execution_map: dict[str, Callable] = {}
|
||||
tool_recommender: Optional[ToolRecommender] = None
|
||||
tool_execution_map: Annotated[dict[str, Callable], Field(exclude=True)] = {}
|
||||
special_tool_commands: list[str] = ["Plan.finish_current_task", "end", "Bash.run"]
|
||||
# Equipped with three basic tools by default for optional use
|
||||
editor: Editor = Editor()
|
||||
browser: Browser = Browser()
|
||||
|
||||
# Experience
|
||||
experience_retriever: ExpRetriever = DummyExpRetriever()
|
||||
experience_retriever: Annotated[ExpRetriever, Field(exclude=True)] = DummyExpRetriever()
|
||||
|
||||
# Others
|
||||
command_rsp: str = "" # the raw string containing the commands
|
||||
|
|
@ -330,10 +335,20 @@ class RoleZero(Role):
|
|||
if commands.endswith("]") and not commands.startswith("["):
|
||||
commands = "[" + commands
|
||||
commands = json.loads(repair_llm_raw_output(output=commands, req_keys=[None], repair_type=RepairType.JSON))
|
||||
except json.JSONDecodeError:
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(f"Failed to parse JSON for: {self.command_rsp}. Trying to repair...")
|
||||
commands = await self.llm.aask(msg=JSON_REPAIR_PROMPT.format(json_data=self.command_rsp))
|
||||
commands = json.loads(CodeParser.parse_code(block=None, lang="json", text=commands))
|
||||
commands = await self.llm.aask(
|
||||
msg=JSON_REPAIR_PROMPT.format(json_data=self.command_rsp, json_decode_error=str(e))
|
||||
)
|
||||
try:
|
||||
commands = json.loads(CodeParser.parse_code(block=None, lang="json", text=commands))
|
||||
except json.JSONDecodeError:
|
||||
# repair escape error of code and math
|
||||
commands = CodeParser.parse_code(block=None, lang="json", text=self.command_rsp)
|
||||
new_command = repair_escape_error(commands)
|
||||
commands = json.loads(
|
||||
repair_llm_raw_output(output=new_command, req_keys=[None], repair_type=RepairType.JSON)
|
||||
)
|
||||
except Exception as e:
|
||||
tb = traceback.format_exc()
|
||||
print(tb)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,9 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Annotated
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from metagpt.actions.di.run_command import RunCommand
|
||||
from metagpt.prompts.di.team_leader import (
|
||||
FINISH_CURRENT_TASK_CMD,
|
||||
|
|
@ -24,7 +28,7 @@ class TeamLeader(RoleZero):
|
|||
|
||||
tools: list[str] = ["Plan", "RoleZero", "TeamLeader"]
|
||||
|
||||
experience_retriever: ExpRetriever = SimpleExpRetriever()
|
||||
experience_retriever: Annotated[ExpRetriever, Field(exclude=True)] = SimpleExpRetriever()
|
||||
|
||||
def _update_tool_execution(self):
|
||||
self.tool_execution_map.update(
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ from metagpt.const import (
|
|||
MESSAGE_ROUTE_FROM,
|
||||
MESSAGE_ROUTE_TO,
|
||||
MESSAGE_ROUTE_TO_ALL,
|
||||
SERDESER_PATH,
|
||||
SYSTEM_DESIGN_FILE_REPO,
|
||||
TASK_FILE_REPO,
|
||||
)
|
||||
|
|
@ -56,6 +57,8 @@ from metagpt.utils.common import (
|
|||
any_to_str_set,
|
||||
aread,
|
||||
import_class,
|
||||
read_json_file,
|
||||
write_json_file,
|
||||
)
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.report import TaskReporter
|
||||
|
|
@ -127,6 +130,65 @@ class SerializationMixin(BaseModel, extra="forbid"):
|
|||
cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
@handle_exception
|
||||
def serialize(self, file_path: str = None) -> str:
|
||||
"""Serializes the current instance to a JSON file.
|
||||
|
||||
If an exception occurs, `handle_exception` will catch it and return `None`.
|
||||
|
||||
Args:
|
||||
file_path (str, optional): The path to the JSON file where the instance will be saved. Defaults to None.
|
||||
|
||||
Returns:
|
||||
str: The path to the JSON file where the instance was saved.
|
||||
"""
|
||||
|
||||
file_path = file_path or self.get_serialization_path()
|
||||
|
||||
serialized_data = self.model_dump()
|
||||
|
||||
write_json_file(file_path, serialized_data)
|
||||
logger.info(f"{self.__class__.__qualname__} serialization successful. File saved at: {file_path}")
|
||||
|
||||
return file_path
|
||||
|
||||
@classmethod
|
||||
@handle_exception
|
||||
def deserialize(cls, file_path: str = None) -> BaseModel:
|
||||
"""Deserializes a JSON file to an instance of cls.
|
||||
|
||||
If an exception occurs, `handle_exception` will catch it and return `None`.
|
||||
|
||||
Args:
|
||||
file_path (str, optional): The path to the JSON file to read from. Defaults to None.
|
||||
|
||||
Returns:
|
||||
An instance of the cls.
|
||||
"""
|
||||
|
||||
file_path = file_path or cls.get_serialization_path()
|
||||
|
||||
data: dict = read_json_file(file_path)
|
||||
|
||||
model = cls(**data)
|
||||
logger.info(f"{cls.__qualname__} deserialization successful. Instance created from file: {file_path}")
|
||||
|
||||
return model
|
||||
|
||||
@classmethod
|
||||
def get_serialization_path(cls) -> str:
|
||||
"""Get the serialization path for the class.
|
||||
|
||||
This method constructs a file path for serialization based on the class name.
|
||||
The default path is constructed as './workspace/storage/ClassName.json', where 'ClassName'
|
||||
is the name of the class.
|
||||
|
||||
Returns:
|
||||
str: The path to the serialization file.
|
||||
"""
|
||||
|
||||
return str(SERDESER_PATH / f"{cls.__qualname__}.json")
|
||||
|
||||
|
||||
class SimpleMessage(BaseModel):
|
||||
content: str
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from pydantic import BaseModel
|
|||
class ExpRetriever(BaseModel):
|
||||
"""interface for experience retriever"""
|
||||
|
||||
def retrieve(self, context: str) -> str:
|
||||
def retrieve(self, context: str = "") -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from playwright.async_api import (
|
|||
Request,
|
||||
async_playwright,
|
||||
)
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.utils.a11y_tree import (
|
||||
|
|
@ -43,7 +44,7 @@ from metagpt.utils.report import BrowserReporter
|
|||
"type",
|
||||
],
|
||||
)
|
||||
class Browser:
|
||||
class Browser(BaseModel):
|
||||
"""A tool for browsing the web. Don't initialize a new instance of this class if one already exists.
|
||||
|
||||
Note: If you plan to use the browser to assist you in completing tasks, then using the browser should be a standalone
|
||||
|
|
@ -66,16 +67,17 @@ class Browser:
|
|||
>>> await browser.close_tab()
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.playwright: Optional[Playwright] = None
|
||||
self.browser_instance: Optional[Browser_] = None
|
||||
self.browser_ctx: Optional[BrowserContext] = None
|
||||
self.page: Optional[Page] = None
|
||||
self.accessibility_tree: list = []
|
||||
self.headless: bool = True
|
||||
self.proxy = get_proxy_from_env()
|
||||
self.is_empty_page = True
|
||||
self.reporter = BrowserReporter()
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
playwright: Optional[Playwright] = None
|
||||
browser_instance: Optional[Browser_] = None
|
||||
browser_ctx: Optional[BrowserContext] = None
|
||||
page: Optional[Page] = None
|
||||
accessibility_tree: list = Field(default_factory=list)
|
||||
headless: bool = True
|
||||
proxy: Optional[str] = Field(default_factory=get_proxy_from_env)
|
||||
is_empty_page: bool = True
|
||||
reporter: BrowserReporter = Field(default_factory=BrowserReporter)
|
||||
|
||||
async def start(self) -> None:
|
||||
"""Starts Playwright and launches a browser"""
|
||||
|
|
|
|||
|
|
@ -5,9 +5,8 @@ import subprocess
|
|||
from pathlib import Path
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from metagpt.const import DEFAULT_WORKSPACE_ROOT
|
||||
from metagpt.logs import logger
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.utils import read_docx
|
||||
|
|
@ -24,12 +23,12 @@ class FileBlock(BaseModel):
|
|||
|
||||
|
||||
@register_tool()
|
||||
class Editor:
|
||||
class Editor(BaseModel):
|
||||
"""A tool for reading, understanding, writing, and editing files"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
print(f"Editor initialized with root path at: {DEFAULT_WORKSPACE_ROOT}")
|
||||
self.resource = EditorReporter()
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
resource: EditorReporter = EditorReporter()
|
||||
|
||||
def write(self, path: str, content: str):
|
||||
"""Write the whole content to a file. When used, make sure content arg contains the full content of the file."""
|
||||
|
|
|
|||
|
|
@ -31,6 +31,10 @@ async def default_get_env(key: str, app_name: str = None) -> str:
|
|||
if app_key in os.environ:
|
||||
return os.environ[app_key]
|
||||
|
||||
env_app_key = app_key.replace("-", "_") # "-" is not supported by linux environment variable
|
||||
if env_app_key in os.environ:
|
||||
return os.environ[env_app_key]
|
||||
|
||||
from metagpt.context import Context
|
||||
|
||||
context = Context()
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@
|
|||
@File : search_engine_serpapi.py
|
||||
"""
|
||||
import warnings
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import aiohttp
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
|
@ -24,6 +24,7 @@ class SerpAPIWrapper(BaseModel):
|
|||
"hl": "en",
|
||||
}
|
||||
)
|
||||
url: str = "https://serpapi.com/search"
|
||||
aiosession: Optional[aiohttp.ClientSession] = None
|
||||
proxy: Optional[str] = None
|
||||
|
||||
|
|
@ -49,22 +50,18 @@ class SerpAPIWrapper(BaseModel):
|
|||
async def results(self, query: str, max_results: int) -> dict:
|
||||
"""Use aiohttp to run query through SerpAPI and return the results async."""
|
||||
|
||||
def construct_url_and_params() -> Tuple[str, Dict[str, str]]:
|
||||
params = self.get_params(query)
|
||||
params["source"] = "python"
|
||||
params["num"] = max_results
|
||||
params["output"] = "json"
|
||||
url = "https://serpapi.com/search"
|
||||
return url, params
|
||||
params = self.get_params(query)
|
||||
params["source"] = "python"
|
||||
params["num"] = max_results
|
||||
params["output"] = "json"
|
||||
|
||||
url, params = construct_url_and_params()
|
||||
if not self.aiosession:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, params=params, proxy=self.proxy) as response:
|
||||
async with session.get(self.url, params=params, proxy=self.proxy) as response:
|
||||
response.raise_for_status()
|
||||
res = await response.json()
|
||||
else:
|
||||
async with self.aiosession.get(url, params=params, proxy=self.proxy) as response:
|
||||
async with self.aiosession.get(self.url, params=params, proxy=self.proxy) as response:
|
||||
response.raise_for_status()
|
||||
res = await response.json()
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
"""
|
||||
import json
|
||||
import warnings
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import aiohttp
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
|
@ -17,6 +17,7 @@ class SerperWrapper(BaseModel):
|
|||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
api_key: str
|
||||
url: str = "https://google.serper.dev/search"
|
||||
payload: dict = Field(default_factory=lambda: {"page": 1, "num": 10})
|
||||
aiosession: Optional[aiohttp.ClientSession] = None
|
||||
proxy: Optional[str] = None
|
||||
|
|
@ -33,6 +34,7 @@ class SerperWrapper(BaseModel):
|
|||
"To use serper search engine, make sure you provide the `api_key` when constructing an object. You can obtain "
|
||||
"an API key from https://serper.dev/."
|
||||
)
|
||||
|
||||
return values
|
||||
|
||||
async def run(self, query: str, max_results: int = 8, as_string: bool = True, **kwargs: Any) -> str:
|
||||
|
|
@ -46,20 +48,16 @@ class SerperWrapper(BaseModel):
|
|||
async def results(self, queries: list[str], max_results: int = 8) -> dict:
|
||||
"""Use aiohttp to run query through Serper and return the results async."""
|
||||
|
||||
def construct_url_and_payload_and_headers() -> Tuple[str, Dict[str, str]]:
|
||||
payloads = self.get_payloads(queries, max_results)
|
||||
url = "https://google.serper.dev/search"
|
||||
headers = self.get_headers()
|
||||
return url, payloads, headers
|
||||
payloads = self.get_payloads(queries, max_results)
|
||||
headers = self.get_headers()
|
||||
|
||||
url, payloads, headers = construct_url_and_payload_and_headers()
|
||||
if not self.aiosession:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, data=payloads, headers=headers, proxy=self.proxy) as response:
|
||||
async with session.post(self.url, data=payloads, headers=headers, proxy=self.proxy) as response:
|
||||
response.raise_for_status()
|
||||
res = await response.json()
|
||||
else:
|
||||
async with self.aiosession.get.post(url, data=payloads, headers=headers, proxy=self.proxy) as response:
|
||||
async with self.aiosession.post(self.url, data=payloads, headers=headers, proxy=self.proxy) as response:
|
||||
response.raise_for_status()
|
||||
res = await response.json()
|
||||
|
||||
|
|
|
|||
|
|
@ -64,6 +64,10 @@ class ToolRecommender(BaseModel):
|
|||
@field_validator("tools", mode="before")
|
||||
@classmethod
|
||||
def validate_tools(cls, v: list[str]) -> dict[str, Tool]:
|
||||
# If `v` is already a dictionary (e.g., during deserialization), return it as is.
|
||||
if isinstance(v, dict):
|
||||
return v
|
||||
|
||||
# One can use special symbol ["<all>"] to indicate use of all registered tools
|
||||
if v == ["<all>"]:
|
||||
return TOOL_REGISTRY.get_all_tools()
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@ import time
|
|||
import traceback
|
||||
from asyncio import iscoroutinefunction
|
||||
from datetime import datetime
|
||||
from functools import partial
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
||||
|
|
@ -577,13 +578,32 @@ def read_json_file(json_file: str, encoding="utf-8") -> list[Any]:
|
|||
return data
|
||||
|
||||
|
||||
def write_json_file(json_file: str, data: list, encoding: str = None, indent: int = 4):
|
||||
def handle_unknown_serialization(x: Any) -> str:
|
||||
"""For `to_jsonable_python` debug, unknown values will be logged instead of raising an exception."""
|
||||
|
||||
if inspect.ismethod(x):
|
||||
logger.error(f"Method: {x.__self__.__class__.__name__}.{x.__func__.__name__}")
|
||||
elif inspect.isfunction(x):
|
||||
logger.error(f"Function: {x.__name__}")
|
||||
elif hasattr(x, "__class__"):
|
||||
logger.error(f"Instance of: {x.__class__.__name__}")
|
||||
elif hasattr(x, "__name__"):
|
||||
logger.error(f"Class or module: {x.__name__}")
|
||||
else:
|
||||
logger.error(f"Unknown type: {type(x)}")
|
||||
|
||||
return f"<Unserializable {type(x).__name__} object>"
|
||||
|
||||
|
||||
def write_json_file(json_file: str, data: Any, encoding: str = None, indent: int = 4, use_fallback: bool = False):
|
||||
folder_path = Path(json_file).parent
|
||||
if not folder_path.exists():
|
||||
folder_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
custom_default = partial(to_jsonable_python, fallback=handle_unknown_serialization if use_fallback else None)
|
||||
|
||||
with open(json_file, "w", encoding=encoding) as fout:
|
||||
json.dump(data, fout, ensure_ascii=False, indent=indent, default=to_jsonable_python)
|
||||
json.dump(data, fout, ensure_ascii=False, indent=indent, default=custom_default)
|
||||
|
||||
|
||||
def read_csv_to_list(curr_file: str, header=False, strip_trail=True):
|
||||
|
|
|
|||
|
|
@ -347,3 +347,44 @@ def extract_state_value_from_output(content: str) -> str:
|
|||
matches = list(set(matches))
|
||||
state = matches[0] if len(matches) > 0 else "-1"
|
||||
return state
|
||||
|
||||
|
||||
def repair_escape_error(commands):
|
||||
"""
|
||||
Repaires escape errors in command responses.
|
||||
When RoleZero parses a command, the command may contain unknown escape characters.
|
||||
|
||||
This function has two steps:
|
||||
1. Transform unescaped substrings like "\d" and "\(" to "\\\\d" and "\\\\(".
|
||||
2. Transform escaped characters like '\f' to substrings like "\\\\f".
|
||||
|
||||
Example:
|
||||
When the original JSON string is " {"content":"\\\\( \\\\frac{1}{2} \\\\)"} ",
|
||||
The "content" will be parsed correctly to "\( \frac{1}{2} \)".
|
||||
|
||||
However, if the original JSON string is " {"content":"\( \frac{1}{2} \)"}" directly.
|
||||
It will cause a parsing error.
|
||||
|
||||
To repair the wrong JSON string, the following transformations will be used:
|
||||
"\(" ---> "\\\\("
|
||||
'\f' ---> "\\\\f"
|
||||
"\)" ---> "\\\\)"
|
||||
|
||||
"""
|
||||
escape_repair_map = {
|
||||
"\a": "\\\\a",
|
||||
"\b": "\\\\b",
|
||||
"\f": "\\\\f",
|
||||
"\r": "\\\\r",
|
||||
"\t": "\\\\t",
|
||||
"\v": "\\\\v",
|
||||
}
|
||||
new_command = ""
|
||||
for index, ch in enumerate(commands):
|
||||
if ch == "\\" and index + 1 < len(commands):
|
||||
if commands[index + 1] not in ["n", '"', " "]:
|
||||
new_command += "\\"
|
||||
elif ch in escape_repair_map:
|
||||
ch = escape_repair_map[ch]
|
||||
new_command += ch
|
||||
return new_command
|
||||
|
|
|
|||
|
|
@ -9,13 +9,15 @@
|
|||
"""
|
||||
|
||||
import json
|
||||
from typing import Annotated
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.write_code import WriteCode
|
||||
from metagpt.const import SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
|
||||
from metagpt.const import SERDESER_PATH, SYSTEM_DESIGN_FILE_REPO, TASK_FILE_REPO
|
||||
from metagpt.schema import (
|
||||
AIMessage,
|
||||
CodeSummarizeContext,
|
||||
|
|
@ -23,6 +25,7 @@ from metagpt.schema import (
|
|||
Message,
|
||||
MessageQueue,
|
||||
Plan,
|
||||
SerializationMixin,
|
||||
SystemMessage,
|
||||
Task,
|
||||
UMLClassAttribute,
|
||||
|
|
@ -398,5 +401,64 @@ def test_create_instruct_value(name, value):
|
|||
assert obj.model_dump() == value
|
||||
|
||||
|
||||
class TestUserModel(SerializationMixin, BaseModel):
|
||||
name: str
|
||||
value: int
|
||||
|
||||
|
||||
class TestUserModelWithExclude(TestUserModel):
|
||||
age: Annotated[int, Field(exclude=True)]
|
||||
|
||||
|
||||
class TestSerializationMixin:
|
||||
@pytest.fixture
|
||||
def mock_write_json_file(self, mocker):
|
||||
return mocker.patch("metagpt.schema.write_json_file")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_read_json_file(self, mocker):
|
||||
return mocker.patch("metagpt.schema.read_json_file")
|
||||
|
||||
@pytest.fixture
|
||||
def mock_user_model(self):
|
||||
return TestUserModel(name="test", value=42)
|
||||
|
||||
def test_serialize(self, mock_write_json_file, mock_user_model):
|
||||
file_path = "test.json"
|
||||
|
||||
mock_user_model.serialize(file_path)
|
||||
|
||||
mock_write_json_file.assert_called_once_with(file_path, mock_user_model.model_dump())
|
||||
|
||||
def test_deserialize(self, mock_read_json_file):
|
||||
file_path = "test.json"
|
||||
data = {"name": "test", "value": 42}
|
||||
mock_read_json_file.return_value = data
|
||||
|
||||
model = TestUserModel.deserialize(file_path)
|
||||
|
||||
mock_read_json_file.assert_called_once_with(file_path)
|
||||
assert model == TestUserModel(**data)
|
||||
|
||||
def test_serialize_with_exclude(self, mock_write_json_file):
|
||||
model = TestUserModelWithExclude(name="test", value=42, age=10)
|
||||
file_path = "test.json"
|
||||
|
||||
model.serialize(file_path)
|
||||
|
||||
expected_data = {
|
||||
"name": "test",
|
||||
"value": 42,
|
||||
"__module_class_name": "tests.metagpt.test_schema.TestUserModelWithExclude",
|
||||
}
|
||||
|
||||
mock_write_json_file.assert_called_once_with(file_path, expected_data)
|
||||
|
||||
def test_get_serialization_path(self):
|
||||
expected_path = str(SERDESER_PATH / "TestUserModel.json")
|
||||
|
||||
assert TestUserModel.get_serialization_path() == expected_path
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-s"])
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue