mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-29 15:59:42 +02:00
feat: merge
This commit is contained in:
commit
f76078dedf
95 changed files with 1629 additions and 948 deletions
|
|
@ -13,7 +13,7 @@ from metagpt.actions.add_requirement import UserRequirement
|
|||
from metagpt.actions.debug_error import DebugError
|
||||
from metagpt.actions.design_api import WriteDesign
|
||||
from metagpt.actions.design_api_review import DesignReview
|
||||
from metagpt.actions.project_management import AssignTasks, WriteTasks
|
||||
from metagpt.actions.project_management import WriteTasks
|
||||
from metagpt.actions.research import CollectLinks, WebBrowseAndSummarize, ConductResearch
|
||||
from metagpt.actions.run_code import RunCode
|
||||
from metagpt.actions.search_and_summarize import SearchAndSummarize
|
||||
|
|
@ -38,7 +38,6 @@ class ActionType(Enum):
|
|||
RUN_CODE = RunCode
|
||||
DEBUG_ERROR = DebugError
|
||||
WRITE_TASKS = WriteTasks
|
||||
ASSIGN_TASKS = AssignTasks
|
||||
SEARCH_AND_SUMMARIZE = SearchAndSummarize
|
||||
COLLECT_LINKS = CollectLinks
|
||||
WEB_BROWSE_AND_SUMMARIZE = WebBrowseAndSummarize
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ from __future__ import annotations
|
|||
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.llm import LLM
|
||||
|
|
@ -19,50 +19,31 @@ from metagpt.schema import (
|
|||
CodeSummarizeContext,
|
||||
CodingContext,
|
||||
RunCodeContext,
|
||||
SerializationMixin,
|
||||
TestingContext,
|
||||
)
|
||||
|
||||
action_subclass_registry = {}
|
||||
|
||||
class Action(SerializationMixin, is_polymorphic_base=True):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])
|
||||
|
||||
class Action(BaseModel):
|
||||
name: str = ""
|
||||
llm: BaseLLM = Field(default_factory=LLM, exclude=True)
|
||||
context: Union[dict, CodingContext, CodeSummarizeContext, TestingContext, RunCodeContext, str, None] = ""
|
||||
prefix = "" # aask*时会加上prefix,作为system_message
|
||||
desc = "" # for skill manager
|
||||
prefix: str = "" # aask*时会加上prefix,作为system_message
|
||||
desc: str = "" # for skill manager
|
||||
node: ActionNode = Field(default=None, exclude=True)
|
||||
|
||||
# builtin variables
|
||||
builtin_class_name: str = ""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init_with_instruction(self, instruction: str):
|
||||
"""Initialize action with instruction"""
|
||||
self.node = ActionNode(key=self.name, expected_type=str, instruction=instruction, example="", schema="raw")
|
||||
return self
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
def __init__(self, **data: Any):
|
||||
super().__init__(**data)
|
||||
|
||||
# deserialize child classes dynamically for inherited `action`
|
||||
object.__setattr__(self, "builtin_class_name", self.__class__.__name__)
|
||||
self.__fields__["builtin_class_name"].default = self.__class__.__name__
|
||||
|
||||
if "instruction" in kwargs:
|
||||
self.__init_with_instruction(kwargs["instruction"])
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
action_subclass_registry[cls.__name__] = cls
|
||||
|
||||
def dict(self, *args, **kwargs) -> "DictStrAny":
|
||||
obj_dict = super().dict(*args, **kwargs)
|
||||
if "llm" in obj_dict:
|
||||
obj_dict.pop("llm")
|
||||
return obj_dict
|
||||
if "instruction" in data:
|
||||
self.__init_with_instruction(data["instruction"])
|
||||
|
||||
def set_prefix(self, prefix):
|
||||
"""Set prefix for later usage"""
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ NOTE: You should use typing.List instead of list to do type annotation. Because
|
|||
import json
|
||||
from typing import Any, Dict, List, Optional, Tuple, Type
|
||||
|
||||
from pydantic import BaseModel, create_model, root_validator, validator
|
||||
from pydantic import BaseModel, create_model, field_validator, model_validator
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
|
|
@ -137,13 +137,15 @@ class ActionNode:
|
|||
"""基于pydantic v1的模型动态生成,用来检验结果类型正确性"""
|
||||
new_class = create_model(class_name, **mapping)
|
||||
|
||||
@validator("*", allow_reuse=True)
|
||||
@field_validator("*", mode="before")
|
||||
@classmethod
|
||||
def check_name(v, field):
|
||||
if field.name not in mapping.keys():
|
||||
raise ValueError(f"Unrecognized block: {field.name}")
|
||||
return v
|
||||
|
||||
@root_validator(pre=True, allow_reuse=True)
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def check_missing_fields(values):
|
||||
required_fields = set(mapping.keys())
|
||||
missing_fields = required_fields - set(values.keys())
|
||||
|
|
@ -273,7 +275,9 @@ class ActionNode:
|
|||
output_class = self.create_model_class(output_class_name, output_data_mapping)
|
||||
|
||||
if schema == "json":
|
||||
parsed_data = llm_output_postprecess(output=content, schema=output_class.schema(), req_key=f"[/{TAG}]")
|
||||
parsed_data = llm_output_postprecess(
|
||||
output=content, schema=output_class.model_json_schema(), req_key=f"[/{TAG}]"
|
||||
)
|
||||
else: # using markdown parser
|
||||
parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping)
|
||||
|
||||
|
|
@ -282,7 +286,7 @@ class ActionNode:
|
|||
return content, instruct_content
|
||||
|
||||
def get(self, key):
|
||||
return self.instruct_content.dict()[key]
|
||||
return self.instruct_content.model_dump()[key]
|
||||
|
||||
def set_recursive(self, name, value):
|
||||
setattr(self, name, value)
|
||||
|
|
@ -348,17 +352,3 @@ class ActionNode:
|
|||
cls = self.create_children_class()
|
||||
self.instruct_content = cls(**tmp)
|
||||
return self
|
||||
|
||||
|
||||
def action_node_example():
|
||||
node = ActionNode(key="key-0", expected_type=str, instruction="instruction-a", example="example-b")
|
||||
|
||||
logger.info(node.compile(context="123", schema="raw", mode="auto"))
|
||||
logger.info(node.compile(context="123", schema="json", mode="auto"))
|
||||
logger.info(node.compile(context="123", schema="markdown", mode="auto"))
|
||||
logger.info(node.to_dict())
|
||||
logger.info(node)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
action_node_example()
|
||||
|
|
|
|||
|
|
@ -10,6 +10,3 @@ from metagpt.actions import Action
|
|||
|
||||
class UserRequirement(Action):
|
||||
"""User Requirement without any implementation details"""
|
||||
|
||||
async def run(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ class WriteDesign(Action):
|
|||
logger.info("Nothing has changed.")
|
||||
# Wait until all files under `docs/system_designs/` are processed before sending the publish message,
|
||||
# leaving room for global optimization in subsequent steps.
|
||||
return ActionOutput(content=changed_files.json(), instruct_content=changed_files)
|
||||
return ActionOutput(content=changed_files.model_dump_json(), instruct_content=changed_files)
|
||||
|
||||
async def _new_system_design(self, context, schema=CONFIG.prompt_schema):
|
||||
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema)
|
||||
|
|
@ -88,7 +88,7 @@ class WriteDesign(Action):
|
|||
async def _merge(self, prd_doc, system_design_doc, schema=CONFIG.prompt_schema):
|
||||
context = NEW_REQ_TEMPLATE.format(old_design=system_design_doc.content, context=prd_doc.content)
|
||||
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=schema)
|
||||
system_design_doc.content = node.instruct_content.json(ensure_ascii=False)
|
||||
system_design_doc.content = node.instruct_content.model_dump_json()
|
||||
return system_design_doc
|
||||
|
||||
async def _update_system_design(self, filename, prds_file_repo, system_design_file_repo) -> Document:
|
||||
|
|
@ -99,7 +99,7 @@ class WriteDesign(Action):
|
|||
doc = Document(
|
||||
root_path=SYSTEM_DESIGN_FILE_REPO,
|
||||
filename=filename,
|
||||
content=system_design.instruct_content.json(ensure_ascii=False),
|
||||
content=system_design.instruct_content.model_dump_json(),
|
||||
)
|
||||
else:
|
||||
doc = await self._merge(prd_doc=prd, system_design_doc=old_system_design_doc)
|
||||
|
|
|
|||
|
|
@ -8,7 +8,6 @@
|
|||
from typing import List
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.mermaid import MMC1, MMC2
|
||||
|
||||
IMPLEMENTATION_APPROACH = ActionNode(
|
||||
|
|
@ -63,12 +62,3 @@ NODES = [
|
|||
]
|
||||
|
||||
DESIGN_API_NODE = ActionNode.from_children("DesignAPI", NODES)
|
||||
|
||||
|
||||
def main():
|
||||
prompt = DESIGN_API_NODE.compile(context="")
|
||||
logger.info(prompt)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
|
|||
|
|
@ -39,6 +39,8 @@ class PrepareDocuments(Action):
|
|||
path = Path(CONFIG.project_path)
|
||||
if path.exists() and not CONFIG.inc:
|
||||
shutil.rmtree(path)
|
||||
CONFIG.project_path = path
|
||||
CONFIG.project_name = path.name
|
||||
CONFIG.git_repo = GitRepository(local_path=path, auto_init=True)
|
||||
|
||||
async def run(self, with_messages, **kwargs):
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ class WriteTasks(Action):
|
|||
logger.info("Nothing has changed.")
|
||||
# Wait until all files under `docs/tasks/` are processed before sending the publish_message, leaving room for
|
||||
# global optimization in subsequent steps.
|
||||
return ActionOutput(content=change_files.json(), instruct_content=change_files)
|
||||
return ActionOutput(content=change_files.model_dump_json(), instruct_content=change_files)
|
||||
|
||||
async def _update_tasks(self, filename, system_design_file_repo, tasks_file_repo):
|
||||
system_design_doc = await system_design_file_repo.get(filename)
|
||||
|
|
@ -83,7 +83,7 @@ class WriteTasks(Action):
|
|||
else:
|
||||
rsp = await self._run_new_tasks(context=system_design_doc.content)
|
||||
task_doc = Document(
|
||||
root_path=TASK_FILE_REPO, filename=filename, content=rsp.instruct_content.json(ensure_ascii=False)
|
||||
root_path=TASK_FILE_REPO, filename=filename, content=rsp.instruct_content.model_dump_json()
|
||||
)
|
||||
await tasks_file_repo.save(
|
||||
filename=filename, content=task_doc.content, dependencies={system_design_doc.root_relative_path}
|
||||
|
|
@ -102,7 +102,7 @@ class WriteTasks(Action):
|
|||
async def _merge(self, system_design_doc, task_doc, schema=CONFIG.prompt_schema) -> Document:
|
||||
context = NEW_REQ_TEMPLATE.format(context=system_design_doc.content, old_tasks=task_doc.content)
|
||||
node = await PM_NODE.fill(context, self.llm, schema)
|
||||
task_doc.content = node.instruct_content.json(ensure_ascii=False)
|
||||
task_doc.content = node.instruct_content.model_dump_json()
|
||||
return task_doc
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -123,9 +123,3 @@ class WriteTasks(Action):
|
|||
@staticmethod
|
||||
async def _save_pdf(task_doc):
|
||||
await FileRepository.save_as(doc=task_doc, with_suffix=".md", relative_path=TASK_PDF_FILE_REPO)
|
||||
|
||||
|
||||
class AssignTasks(Action):
|
||||
async def run(self, *args, **kwargs):
|
||||
# Here you should implement the actual action
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -50,7 +50,7 @@ class RebuildClassView(Action):
|
|||
|
||||
# try:
|
||||
# node = await REBUILD_CLASS_VIEW_NODE.fill(context=f"```{code_type}\n{src_code}\n```", llm=self.llm, to=format)
|
||||
# class_view = node.instruct_content.dict()["Class View"]
|
||||
# class_view = node.instruct_content.model_dump()["Class View"]
|
||||
# except Exception as e:
|
||||
# class_view = RepoParser.rebuild_class_view(src_code, code_type)
|
||||
# await graph_db.insert(subject=concat_namespace(filename, class_name), predicate=GraphKeyword.HAS_CLASS_VIEW, object_=class_view)
|
||||
|
|
|
|||
|
|
@ -84,8 +84,9 @@ class CollectLinks(Action):
|
|||
context: Optional[str] = None
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
desc: str = "Collect links from a search engine."
|
||||
|
||||
search_engine: SearchEngine = Field(default_factory=SearchEngine)
|
||||
rank_func: Union[Callable[[list[str]], None], None] = None
|
||||
rank_func: Optional[Callable[[list[str]], None]] = None
|
||||
|
||||
async def run(
|
||||
self,
|
||||
|
|
@ -181,18 +182,18 @@ class WebBrowseAndSummarize(Action):
|
|||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
desc: str = "Explore the web and provide summaries of articles and webpages."
|
||||
browse_func: Union[Callable[[list[str]], None], None] = None
|
||||
web_browser_engine: WebBrowserEngine = Field(
|
||||
default_factory=lambda: WebBrowserEngine(
|
||||
engine=WebBrowserEngineType.CUSTOM if WebBrowseAndSummarize.browse_func else None,
|
||||
run_func=WebBrowseAndSummarize.browse_func,
|
||||
)
|
||||
)
|
||||
web_browser_engine: Optional[WebBrowserEngine] = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if CONFIG.model_for_researcher_summary:
|
||||
self.llm.model = CONFIG.model_for_researcher_summary
|
||||
|
||||
self.web_browser_engine = WebBrowserEngine(
|
||||
engine=WebBrowserEngineType.CUSTOM if self.browse_func else None,
|
||||
run_func=self.browse_func,
|
||||
)
|
||||
|
||||
async def run(
|
||||
self,
|
||||
url: str,
|
||||
|
|
|
|||
|
|
@ -82,11 +82,13 @@ class RunCode(Action):
|
|||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
@classmethod
|
||||
@handle_exception
|
||||
async def run_text(cls, code) -> Tuple[str, str]:
|
||||
# We will document_store the result in this dictionary
|
||||
namespace = {}
|
||||
exec(code, namespace)
|
||||
try:
|
||||
# We will document_store the result in this dictionary
|
||||
namespace = {}
|
||||
exec(code, namespace)
|
||||
except Exception as e:
|
||||
return "", str(e)
|
||||
return namespace.get("result", ""), ""
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@
|
|||
from typing import Any, Optional
|
||||
|
||||
import pydantic
|
||||
from pydantic import Field, root_validator
|
||||
from pydantic import Field, model_validator
|
||||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.config import CONFIG, Config
|
||||
|
|
@ -114,10 +114,10 @@ class SearchAndSummarize(Action):
|
|||
engine: Optional[SearchEngineType] = CONFIG.search_engine
|
||||
search_func: Optional[Any] = None
|
||||
search_engine: SearchEngine = None
|
||||
result: str = ""
|
||||
|
||||
result = ""
|
||||
|
||||
@root_validator
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def validate_engine_and_run_func(cls, values):
|
||||
engine = values.get("engine")
|
||||
search_func = values.get("search_func")
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ class WritePRD(Action):
|
|||
await docs_file_repo.save(filename=REQUIREMENT_FILENAME, content="")
|
||||
bug_fix = BugFixContext(filename=BUGFIX_FILENAME)
|
||||
return Message(
|
||||
content=bug_fix.json(),
|
||||
content=bug_fix.model_dump_json(),
|
||||
instruct_content=bug_fix,
|
||||
role="",
|
||||
cause_by=FixBug,
|
||||
|
|
@ -112,7 +112,7 @@ class WritePRD(Action):
|
|||
# Once all files under 'docs/prds/' have been compared with the newly added requirements, trigger the
|
||||
# 'publish' message to transition the workflow to the next stage. This design allows room for global
|
||||
# optimization in subsequent steps.
|
||||
return ActionOutput(content=change_files.json(), instruct_content=change_files)
|
||||
return ActionOutput(content=change_files.model_dump_json(), instruct_content=change_files)
|
||||
|
||||
async def _run_new_requirement(self, requirements, schema=CONFIG.prompt_schema) -> ActionOutput:
|
||||
# sas = SearchAndSummarize()
|
||||
|
|
@ -139,7 +139,7 @@ class WritePRD(Action):
|
|||
CONFIG.project_name = Path(CONFIG.project_path).name
|
||||
prompt = NEW_REQ_TEMPLATE.format(requirements=new_requirement_doc.content, old_prd=prd_doc.content)
|
||||
node = await WRITE_PRD_NODE.fill(context=prompt, llm=self.llm, schema=schema)
|
||||
prd_doc.content = node.instruct_content.json(ensure_ascii=False)
|
||||
prd_doc.content = node.instruct_content.model_dump_json()
|
||||
await self._rename_workspace(node)
|
||||
return prd_doc
|
||||
|
||||
|
|
@ -151,7 +151,7 @@ class WritePRD(Action):
|
|||
new_prd_doc = Document(
|
||||
root_path=PRDS_FILE_REPO,
|
||||
filename=FileRepository.new_filename() + ".json",
|
||||
content=prd.instruct_content.json(ensure_ascii=False),
|
||||
content=prd.instruct_content.model_dump_json(),
|
||||
)
|
||||
elif await self._is_relative(requirement_doc, prd_doc):
|
||||
new_prd_doc = await self._merge(requirement_doc, prd_doc)
|
||||
|
|
@ -181,18 +181,13 @@ class WritePRD(Action):
|
|||
|
||||
@staticmethod
|
||||
async def _rename_workspace(prd):
|
||||
if CONFIG.project_path: # Updating on the old version has already been specified if it's valid. According to
|
||||
# Section 2.2.3.10 of RFC 135
|
||||
if not CONFIG.project_name:
|
||||
CONFIG.project_name = Path(CONFIG.project_path).name
|
||||
return
|
||||
|
||||
if not CONFIG.project_name:
|
||||
if isinstance(prd, (ActionOutput, ActionNode)):
|
||||
ws_name = prd.instruct_content.dict()["Project Name"]
|
||||
ws_name = prd.instruct_content.model_dump()["Project Name"]
|
||||
else:
|
||||
ws_name = CodeParser.parse_str(block="Project Name", text=prd)
|
||||
CONFIG.project_name = ws_name
|
||||
if ws_name:
|
||||
CONFIG.project_name = ws_name
|
||||
CONFIG.git_repo.rename_root(CONFIG.project_name)
|
||||
|
||||
async def _is_bugfix(self, context) -> bool:
|
||||
|
|
|
|||
|
|
@ -19,6 +19,7 @@ class WritePRDReview(Action):
|
|||
name: str = ""
|
||||
context: Optional[str] = None
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
prd: Optional[str] = None
|
||||
desc: str = "Based on the PRD, conduct a PRD Review, providing clear and detailed feedback"
|
||||
prd_review_prompt_template: str = """
|
||||
|
|
|
|||
|
|
@ -72,6 +72,7 @@ class Config(metaclass=Singleton):
|
|||
self.inc = False
|
||||
self.reqa_file = ""
|
||||
self.max_auto_summarize_code = 0
|
||||
self.git_reinit = False
|
||||
|
||||
self._init_with_config_files_and_env(yaml_file)
|
||||
# The agent needs to be billed per user, so billing information cannot be destroyed when the session ends.
|
||||
|
|
@ -146,7 +147,7 @@ class Config(metaclass=Singleton):
|
|||
if not self._get("DISABLE_LLM_PROVIDER_CHECK"):
|
||||
_ = self.get_default_llm_provider_enum()
|
||||
|
||||
# self.openai_base_url = self._get("OPENAI_BASE_URL")
|
||||
self.openai_base_url = self._get("OPENAI_BASE_URL")
|
||||
self.openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
|
||||
self.openai_api_type = self._get("OPENAI_API_TYPE")
|
||||
self.openai_api_version = self._get("OPENAI_API_VERSION")
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ from langchain.document_loaders import (
|
|||
UnstructuredWordDocumentLoader,
|
||||
)
|
||||
from langchain.text_splitter import CharacterTextSplitter
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from tqdm import tqdm
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
|
|
@ -117,13 +117,12 @@ class IndexableDocument(Document):
|
|||
Advanced document handling: For vector databases or search engines.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
data: Union[pd.DataFrame, list]
|
||||
content_col: Optional[str] = Field(default="")
|
||||
meta_col: Optional[str] = Field(default="")
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@classmethod
|
||||
def from_path(cls, data_path: Path, content_col="content", meta_col="metadata"):
|
||||
if not data_path.exists():
|
||||
|
|
|
|||
|
|
@ -1,111 +0,0 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2023/5/28 00:00
|
||||
@Author : alexanderwu
|
||||
@File : milvus_store.py
|
||||
"""
|
||||
from typing import TypedDict
|
||||
|
||||
import numpy as np
|
||||
from pymilvus import Collection, CollectionSchema, DataType, FieldSchema, connections
|
||||
|
||||
from metagpt.document_store.base_store import BaseStore
|
||||
|
||||
type_mapping = {int: DataType.INT64, str: DataType.VARCHAR, float: DataType.DOUBLE, np.ndarray: DataType.FLOAT_VECTOR}
|
||||
|
||||
|
||||
def columns_to_milvus_schema(columns: dict, primary_col_name: str = "", desc: str = ""):
|
||||
"""Assume the structure of columns is str: regular type"""
|
||||
fields = []
|
||||
for col, ctype in columns.items():
|
||||
if ctype == str:
|
||||
mcol = FieldSchema(name=col, dtype=type_mapping[ctype], max_length=100)
|
||||
elif ctype == np.ndarray:
|
||||
mcol = FieldSchema(name=col, dtype=type_mapping[ctype], dim=2)
|
||||
else:
|
||||
mcol = FieldSchema(name=col, dtype=type_mapping[ctype], is_primary=(col == primary_col_name))
|
||||
fields.append(mcol)
|
||||
schema = CollectionSchema(fields, description=desc)
|
||||
return schema
|
||||
|
||||
|
||||
class MilvusConnection(TypedDict):
|
||||
alias: str
|
||||
host: str
|
||||
port: str
|
||||
|
||||
|
||||
class MilvusStore(BaseStore):
|
||||
"""
|
||||
FIXME: ADD TESTS
|
||||
https://milvus.io/docs/v2.0.x/create_collection.md
|
||||
"""
|
||||
|
||||
def __init__(self, connection):
|
||||
connections.connect(**connection)
|
||||
self.collection = None
|
||||
|
||||
def _create_collection(self, name, schema):
|
||||
collection = Collection(name=name, schema=schema, using="default", shards_num=2, consistency_level="Strong")
|
||||
return collection
|
||||
|
||||
def create_collection(self, name, columns):
|
||||
schema = columns_to_milvus_schema(columns, "idx")
|
||||
self.collection = self._create_collection(name, schema)
|
||||
return self.collection
|
||||
|
||||
def drop(self, name):
|
||||
Collection(name).drop()
|
||||
|
||||
def load_collection(self):
|
||||
self.collection.load()
|
||||
|
||||
def build_index(self, field="emb"):
|
||||
self.collection.create_index(field, {"index_type": "FLAT", "metric_type": "L2", "params": {}})
|
||||
|
||||
def search(self, query: list[list[float]], *args, **kwargs):
|
||||
"""
|
||||
FIXME: ADD TESTS
|
||||
https://milvus.io/docs/v2.0.x/search.md
|
||||
All search and query operations within Milvus are executed in memory. Load the collection to memory before conducting a vector similarity search.
|
||||
Note the above description, is this logic serious? This should take a long time, right?
|
||||
"""
|
||||
search_params = {"metric_type": "L2", "params": {"nprobe": 10}}
|
||||
results = self.collection.search(
|
||||
data=query,
|
||||
anns_field=kwargs.get("field", "emb"),
|
||||
param=search_params,
|
||||
limit=10,
|
||||
expr=None,
|
||||
consistency_level="Strong",
|
||||
)
|
||||
# FIXME: results contain id, but to get the actual value from the id, we still need to call the query interface
|
||||
return results
|
||||
|
||||
def write(self, name, schema, *args, **kwargs):
|
||||
"""
|
||||
FIXME: ADD TESTS
|
||||
https://milvus.io/docs/v2.0.x/create_collection.md
|
||||
:param args:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def add(self, data, *args, **kwargs):
|
||||
"""
|
||||
FIXME: ADD TESTS
|
||||
https://milvus.io/docs/v2.0.x/insert_data.md
|
||||
import random
|
||||
data = [
|
||||
[i for i in range(2000)],
|
||||
[i for i in range(10000, 12000)],
|
||||
[[random.random() for _ in range(2)] for _ in range(2000)],
|
||||
]
|
||||
|
||||
:param args:
|
||||
:param kwargs:
|
||||
:return:
|
||||
"""
|
||||
self.collection.insert(data)
|
||||
|
|
@ -15,11 +15,11 @@ import asyncio
|
|||
from pathlib import Path
|
||||
from typing import Iterable, Set
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles.role import Role, role_subclass_registry
|
||||
from metagpt.roles.role import Role
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.common import is_subscribed, read_json_file, write_json_file
|
||||
|
||||
|
|
@ -29,30 +29,17 @@ class Environment(BaseModel):
|
|||
Environment, hosting a batch of roles, roles can publish messages to the environment, and can be observed by other roles
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
desc: str = Field(default="") # 环境描述
|
||||
roles: dict[str, Role] = Field(default_factory=dict)
|
||||
members: dict[Role, Set] = Field(default_factory=dict)
|
||||
roles: dict[str, SerializeAsAny[Role]] = Field(default_factory=dict, validate_default=True)
|
||||
members: dict[Role, Set] = Field(default_factory=dict, exclude=True)
|
||||
history: str = "" # For debug
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
roles = []
|
||||
for role_key, role in kwargs.get("roles", {}).items():
|
||||
current_role = kwargs["roles"][role_key]
|
||||
if isinstance(current_role, dict):
|
||||
item_class_name = current_role.get("builtin_class_name", None)
|
||||
for name, subclass in role_subclass_registry.items():
|
||||
registery_class_name = subclass.__fields__["builtin_class_name"].default
|
||||
if item_class_name == registery_class_name:
|
||||
current_role = subclass(**current_role)
|
||||
break
|
||||
kwargs["roles"][role_key] = current_role
|
||||
roles.append(current_role)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.add_roles(roles) # add_roles again to init the Role.set_env
|
||||
@model_validator(mode="after")
|
||||
def init_roles(self):
|
||||
self.add_roles(self.roles.values())
|
||||
return self
|
||||
|
||||
def serialize(self, stg_path: Path):
|
||||
roles_path = stg_path.joinpath("roles.json")
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@
|
|||
@Time : 2023/6/5 01:44
|
||||
@Author : alexanderwu
|
||||
@File : skill_manager.py
|
||||
@Modified By: mashenquan, 2023/8/20. Remove useless `_llm`
|
||||
@Modified By: mashenquan, 2023/8/20. Remove useless `llm`
|
||||
"""
|
||||
from metagpt.actions import Action
|
||||
from metagpt.const import PROMPT_PATH
|
||||
|
|
|
|||
|
|
@ -73,7 +73,7 @@ class BrainMemory(BaseModel):
|
|||
redis = Redis()
|
||||
if not redis.is_valid or not redis_key:
|
||||
return False
|
||||
v = self.json(ensure_ascii=False)
|
||||
v = self.model_dump_json()
|
||||
if self.cacheable:
|
||||
await redis.set(key=redis_key, data=v, timeout_sec=timeout_sec)
|
||||
logger.debug(f"REDIS SET {redis_key} {v}")
|
||||
|
|
@ -99,6 +99,7 @@ class BrainMemory(BaseModel):
|
|||
if msg.id:
|
||||
if self.to_int(msg.id, 0) <= self.to_int(self.last_history_id, -1):
|
||||
return
|
||||
|
||||
self.history.append(msg)
|
||||
self.last_history_id = str(msg.id)
|
||||
self.is_dirty = True
|
||||
|
|
@ -156,7 +157,7 @@ class BrainMemory(BaseModel):
|
|||
if left == 0:
|
||||
break
|
||||
m.content = m.content[0:left]
|
||||
msgs.append(m.dict())
|
||||
msgs.append(m.model_dump())
|
||||
break
|
||||
msgs.append(m)
|
||||
total_length += delta
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import ConfigDict, Field
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.memory import Memory
|
||||
|
|
@ -22,13 +22,12 @@ class LongTermMemory(Memory):
|
|||
- update memory when it changed
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
memory_storage: MemoryStorage = Field(default_factory=MemoryStorage)
|
||||
rc: Optional["RoleContext"] = None
|
||||
msg_from_recover: bool = False
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def recover_memory(self, role_id: str, rc: "RoleContext"):
|
||||
messages = self.memory_storage.recover_memory(role_id)
|
||||
self.rc = rc
|
||||
|
|
|
|||
|
|
@ -8,9 +8,9 @@
|
|||
"""
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
from typing import Iterable, Set
|
||||
from typing import DefaultDict, Iterable, Set
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, SerializeAsAny
|
||||
|
||||
from metagpt.const import IGNORED_MESSAGE_ID
|
||||
from metagpt.schema import Message
|
||||
|
|
@ -25,23 +25,14 @@ from metagpt.utils.common import (
|
|||
class Memory(BaseModel):
|
||||
"""The most basic memory: super-memory"""
|
||||
|
||||
storage: list[Message] = []
|
||||
index: dict[str, list[Message]] = Field(default_factory=defaultdict(list))
|
||||
storage: list[SerializeAsAny[Message]] = []
|
||||
index: DefaultDict[str, list[SerializeAsAny[Message]]] = Field(default_factory=lambda: defaultdict(list))
|
||||
ignore_id: bool = False
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
index = kwargs.get("index", {})
|
||||
new_index = defaultdict(list)
|
||||
for action_str, value in index.items():
|
||||
new_index[action_str] = [Message(**item_dict) for item_dict in value]
|
||||
kwargs["index"] = new_index
|
||||
super(Memory, self).__init__(**kwargs)
|
||||
self.index = new_index
|
||||
|
||||
def serialize(self, stg_path: Path):
|
||||
"""stg_path = ./storage/team/environment/ or ./storage/team/environment/roles/{role_class}_{role_name}/"""
|
||||
memory_path = stg_path.joinpath("memory.json")
|
||||
storage = self.dict()
|
||||
storage = self.model_dump()
|
||||
write_json_file(memory_path, storage)
|
||||
|
||||
@classmethod
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ class OpenAILLM(BaseLLM):
|
|||
self.aclient = AsyncOpenAI(**kwargs)
|
||||
|
||||
def _make_client_kwargs(self) -> dict:
|
||||
kwargs = {"api_key": self.config.OPENAI_API_KEY, "base_url": self.config.OPENAI_BASE_URL}
|
||||
kwargs = {"api_key": self.config.openai_api_key, "base_url": self.config.openai_base_url}
|
||||
|
||||
# to use proxy, openai v1 needs http_client
|
||||
if proxy_params := self._get_proxy_params():
|
||||
|
|
@ -81,8 +81,8 @@ class OpenAILLM(BaseLLM):
|
|||
params = {}
|
||||
if self.config.openai_proxy:
|
||||
params = {"proxies": self.config.openai_proxy}
|
||||
if self.config.OPENAI_BASE_URL:
|
||||
params["base_url"] = self.config.OPENAI_BASE_URL
|
||||
if self.config.openai_base_url:
|
||||
params["base_url"] = self.config.openai_base_url
|
||||
|
||||
return params
|
||||
|
||||
|
|
|
|||
|
|
@ -65,22 +65,20 @@ class Assistant(Role):
|
|||
prompt += f"If the text explicitly want you to {desc}, return `[SKILL]: {name}` brief and clear. For instance: [SKILL]: {name}\n"
|
||||
prompt += 'Otherwise, return `[TALK]: {talk}` brief and clear. For instance: if {talk} is "xxxx" return [TALK]: xxxx\n\n'
|
||||
prompt += f"Now what specific action is explicitly mentioned in the text: {last_talk}\n"
|
||||
rsp = await self._llm.aask(prompt, [])
|
||||
rsp = await self.llm.aask(prompt, [])
|
||||
logger.info(f"THINK: {prompt}\n, THINK RESULT: {rsp}\n")
|
||||
return await self._plan(rsp, last_talk=last_talk)
|
||||
|
||||
async def act(self) -> Message:
|
||||
result = await self._rc.todo.run()
|
||||
result = await self.rc.todo.run()
|
||||
if not result:
|
||||
return None
|
||||
if isinstance(result, str):
|
||||
msg = Message(content=result, role="assistant", cause_by=self._rc.todo)
|
||||
msg = Message(content=result, role="assistant", cause_by=self.rc.todo)
|
||||
elif isinstance(result, Message):
|
||||
msg = result
|
||||
else:
|
||||
msg = Message(
|
||||
content=result.content, instruct_content=result.instruct_content, cause_by=type(self._rc.todo)
|
||||
)
|
||||
msg = Message(content=result.content, instruct_content=result.instruct_content, cause_by=type(self.rc.todo))
|
||||
self.memory.add_answer(msg)
|
||||
return msg
|
||||
|
||||
|
|
@ -99,8 +97,8 @@ class Assistant(Role):
|
|||
async def talk_handler(self, text, **kwargs) -> bool:
|
||||
history = self.memory.history_text
|
||||
text = kwargs.get("last_talk") or text
|
||||
self._rc.todo = TalkAction(
|
||||
context=text, knowledge=self.memory.get_knowledge(), history_summary=history, llm=self._llm, **kwargs
|
||||
self.rc.todo = TalkAction(
|
||||
context=text, knowledge=self.memory.get_knowledge(), history_summary=history, llm=self.llm, **kwargs
|
||||
)
|
||||
return True
|
||||
|
||||
|
|
@ -110,13 +108,11 @@ class Assistant(Role):
|
|||
if not skill:
|
||||
logger.info(f"skill not found: {text}")
|
||||
return await self.talk_handler(text=last_talk, **kwargs)
|
||||
action = ArgumentsParingAction(skill=skill, llm=self._llm, ask=last_talk, **kwargs)
|
||||
action = ArgumentsParingAction(skill=skill, llm=self.llm, ask=last_talk, **kwargs)
|
||||
await action.run(**kwargs)
|
||||
if action.args is None:
|
||||
return await self.talk_handler(text=last_talk, **kwargs)
|
||||
self._rc.todo = SkillAction(
|
||||
skill=skill, args=action.args, llm=self._llm, name=skill.name, desc=skill.description
|
||||
)
|
||||
self.rc.todo = SkillAction(skill=skill, args=action.args, llm=self.llm, name=skill.name, desc=skill.description)
|
||||
return True
|
||||
|
||||
async def refine_memory(self) -> str:
|
||||
|
|
@ -125,16 +121,16 @@ class Assistant(Role):
|
|||
return None
|
||||
if not self.memory.is_history_available:
|
||||
return last_talk
|
||||
history_summary = await self.memory.summarize(max_words=800, keep_language=True, llm=self._llm)
|
||||
if last_talk and await self.memory.is_related(text1=last_talk, text2=history_summary, llm=self._llm):
|
||||
history_summary = await self.memory.summarize(max_words=800, keep_language=True, llm=self.llm)
|
||||
if last_talk and await self.memory.is_related(text1=last_talk, text2=history_summary, llm=self.llm):
|
||||
# Merge relevant content.
|
||||
merged = await self.memory.rewrite(sentence=last_talk, context=history_summary, llm=self._llm)
|
||||
merged = await self.memory.rewrite(sentence=last_talk, context=history_summary, llm=self.llm)
|
||||
return f"{merged} {last_talk}"
|
||||
|
||||
return last_talk
|
||||
|
||||
def get_memory(self) -> str:
|
||||
return self.memory.json()
|
||||
return self.memory.model_dump_json()
|
||||
|
||||
def load_memory(self, jsn):
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -109,7 +109,7 @@ class Engineer(Role):
|
|||
coding_context = await todo.run()
|
||||
# Code review
|
||||
if review:
|
||||
action = WriteCodeReview(context=coding_context, llm=self._llm)
|
||||
action = WriteCodeReview(context=coding_context, llm=self.llm)
|
||||
self._init_action_system_message(action)
|
||||
coding_context = await action.run()
|
||||
await src_file_repo.save(
|
||||
|
|
@ -118,9 +118,12 @@ class Engineer(Role):
|
|||
content=coding_context.code_doc.content,
|
||||
)
|
||||
msg = Message(
|
||||
content=coding_context.json(), instruct_content=coding_context, role=self.profile, cause_by=WriteCode
|
||||
content=coding_context.model_dump_json(),
|
||||
instruct_content=coding_context,
|
||||
role=self.profile,
|
||||
cause_by=WriteCode,
|
||||
)
|
||||
self._rc.memory.add(msg)
|
||||
self.rc.memory.add(msg)
|
||||
|
||||
changed_files.add(coding_context.code_doc.filename)
|
||||
if not changed_files:
|
||||
|
|
@ -129,12 +132,12 @@ class Engineer(Role):
|
|||
|
||||
async def _act(self) -> Message | None:
|
||||
"""Determines the mode of action based on whether code review is used."""
|
||||
if self._rc.todo is None:
|
||||
if self.rc.todo is None:
|
||||
return None
|
||||
if isinstance(self._rc.todo, WriteCode):
|
||||
if isinstance(self.rc.todo, WriteCode):
|
||||
self.next_todo_action = any_to_name(SummarizeCode)
|
||||
return await self._act_write_code()
|
||||
if isinstance(self._rc.todo, SummarizeCode):
|
||||
if isinstance(self.rc.todo, SummarizeCode):
|
||||
self.next_todo_action = any_to_name(WriteCode)
|
||||
return await self._act_summarize()
|
||||
return None
|
||||
|
|
@ -170,7 +173,7 @@ class Engineer(Role):
|
|||
tasks.append(todo.context.dict())
|
||||
await code_summaries_file_repo.save(
|
||||
filename=Path(todo.context.design_filename).name,
|
||||
content=todo.context.json(),
|
||||
content=todo.context.model_dump_json(),
|
||||
dependencies=dependencies,
|
||||
)
|
||||
else:
|
||||
|
|
@ -193,7 +196,7 @@ class Engineer(Role):
|
|||
)
|
||||
|
||||
async def _is_pass(self, summary) -> (str, str):
|
||||
rsp = await self._llm.aask(msg=IS_PASS_PROMPT.format(context=summary), stream=False)
|
||||
rsp = await self.llm.aask(msg=IS_PASS_PROMPT.format(context=summary), stream=False)
|
||||
logger.info(rsp)
|
||||
if "YES" in rsp:
|
||||
return True, rsp
|
||||
|
|
@ -204,17 +207,17 @@ class Engineer(Role):
|
|||
CONFIG.src_workspace = CONFIG.git_repo.workdir / CONFIG.git_repo.workdir.name
|
||||
write_code_filters = any_to_str_set([WriteTasks, SummarizeCode, FixBug])
|
||||
summarize_code_filters = any_to_str_set([WriteCode, WriteCodeReview])
|
||||
if not self._rc.news:
|
||||
if not self.rc.news:
|
||||
return None
|
||||
msg = self._rc.news[0]
|
||||
msg = self.rc.news[0]
|
||||
if msg.cause_by in write_code_filters:
|
||||
logger.debug(f"TODO WriteCode:{msg.json()}")
|
||||
logger.debug(f"TODO WriteCode:{msg.model_dump_json()}")
|
||||
await self._new_code_actions(bug_fix=msg.cause_by == any_to_str(FixBug))
|
||||
return self._rc.todo
|
||||
return self.rc.todo
|
||||
if msg.cause_by in summarize_code_filters and msg.sent_from == any_to_str(self):
|
||||
logger.debug(f"TODO SummarizeCode:{msg.json()}")
|
||||
logger.debug(f"TODO SummarizeCode:{msg.model_dump_json()}")
|
||||
await self._new_summarize_actions()
|
||||
return self._rc.todo
|
||||
return self.rc.todo
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -241,7 +244,9 @@ class Engineer(Role):
|
|||
context = await Engineer._new_coding_context(
|
||||
filename, src_file_repo, task_file_repo, design_file_repo, dependency
|
||||
)
|
||||
coding_doc = Document(root_path=str(src_file_repo.root_path), filename=filename, content=context.json())
|
||||
coding_doc = Document(
|
||||
root_path=str(src_file_repo.root_path), filename=filename, content=context.model_dump_json()
|
||||
)
|
||||
return coding_doc
|
||||
|
||||
async def _new_code_actions(self, bug_fix=False):
|
||||
|
|
@ -266,15 +271,15 @@ class Engineer(Role):
|
|||
filename=task_filename, design_doc=design_doc, task_doc=task_doc, code_doc=old_code_doc
|
||||
)
|
||||
coding_doc = Document(
|
||||
root_path=str(src_file_repo.root_path), filename=task_filename, content=context.json()
|
||||
root_path=str(src_file_repo.root_path), filename=task_filename, content=context.model_dump_json()
|
||||
)
|
||||
if task_filename in changed_files.docs:
|
||||
logger.warning(
|
||||
f"Log to expose potential conflicts: {coding_doc.json()} & "
|
||||
f"{changed_files.docs[task_filename].json()}"
|
||||
f"Log to expose potential conflicts: {coding_doc.model_dump_json()} & "
|
||||
f"{changed_files.docs[task_filename].model_dump_json()}"
|
||||
)
|
||||
changed_files.docs[task_filename] = coding_doc
|
||||
self.code_todos = [WriteCode(context=i, llm=self._llm) for i in changed_files.docs.values()]
|
||||
self.code_todos = [WriteCode(context=i, llm=self.llm) for i in changed_files.docs.values()]
|
||||
# Code directly modified by the user.
|
||||
dependency = await CONFIG.git_repo.get_dependency()
|
||||
for filename in changed_src_files:
|
||||
|
|
@ -288,10 +293,10 @@ class Engineer(Role):
|
|||
dependency=dependency,
|
||||
)
|
||||
changed_files.docs[filename] = coding_doc
|
||||
self.code_todos.append(WriteCode(context=coding_doc, llm=self._llm))
|
||||
self.code_todos.append(WriteCode(context=coding_doc, llm=self.llm))
|
||||
|
||||
if self.code_todos:
|
||||
self._rc.todo = self.code_todos[0]
|
||||
self.rc.todo = self.code_todos[0]
|
||||
|
||||
async def _new_summarize_actions(self):
|
||||
src_file_repo = CONFIG.git_repo.new_file_repository(CONFIG.src_workspace)
|
||||
|
|
@ -304,9 +309,9 @@ class Engineer(Role):
|
|||
summarizations[ctx].append(filename)
|
||||
for ctx, filenames in summarizations.items():
|
||||
ctx.codes_filenames = filenames
|
||||
self.summarize_todos.append(SummarizeCode(context=ctx, llm=self._llm))
|
||||
self.summarize_todos.append(SummarizeCode(context=ctx, llm=self.llm))
|
||||
if self.summarize_todos:
|
||||
self._rc.todo = self.summarize_todos[0]
|
||||
self.rc.todo = self.summarize_todos[0]
|
||||
|
||||
@property
|
||||
def todo(self) -> str:
|
||||
|
|
|
|||
|
|
@ -69,8 +69,8 @@ class InvoiceOCRAssistant(Role):
|
|||
Returns:
|
||||
A message containing the result of the action.
|
||||
"""
|
||||
msg = self._rc.memory.get(k=1)[0]
|
||||
todo = self._rc.todo
|
||||
msg = self.rc.memory.get(k=1)[0]
|
||||
todo = self.rc.todo
|
||||
if isinstance(todo, InvoiceOCR):
|
||||
self.origin_query = msg.content
|
||||
invoice_path: InvoicePath = msg.instruct_content
|
||||
|
|
@ -87,11 +87,11 @@ class InvoiceOCRAssistant(Role):
|
|||
else:
|
||||
self._init_actions([GenerateTable])
|
||||
|
||||
self._rc.todo = None
|
||||
self.rc.todo = None
|
||||
content = INVOICE_OCR_SUCCESS
|
||||
resp = OCRResults(ocr_result=json.dumps(resp))
|
||||
msg = Message(content=content, instruct_content=resp)
|
||||
self._rc.memory.add(msg)
|
||||
self.rc.memory.add(msg)
|
||||
return await super().react()
|
||||
elif isinstance(todo, GenerateTable):
|
||||
ocr_results: OCRResults = msg.instruct_content
|
||||
|
|
@ -108,5 +108,5 @@ class InvoiceOCRAssistant(Role):
|
|||
resp = ReplyData(content=resp)
|
||||
|
||||
msg = Message(content=content, instruct_content=resp)
|
||||
self._rc.memory.add(msg)
|
||||
self.rc.memory.add(msg)
|
||||
return msg
|
||||
|
|
|
|||
|
|
@ -40,12 +40,13 @@ class ProductManager(Role):
|
|||
|
||||
async def _think(self) -> bool:
|
||||
"""Decide what to do"""
|
||||
if CONFIG.git_repo:
|
||||
if CONFIG.git_repo and not CONFIG.git_reinit:
|
||||
self._set_state(1)
|
||||
else:
|
||||
self._set_state(0)
|
||||
CONFIG.git_reinit = False
|
||||
self.todo_action = any_to_name(WritePRD)
|
||||
return bool(self._rc.todo)
|
||||
return bool(self.rc.todo)
|
||||
|
||||
async def _observe(self, ignore_memory=False) -> int:
|
||||
return await super()._observe(ignore_memory=True)
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ class QaEngineer(Role):
|
|||
)
|
||||
logger.info(f"Writing {test_doc.filename}..")
|
||||
context = TestingContext(filename=test_doc.filename, test_doc=test_doc, code_doc=code_doc)
|
||||
context = await WriteTest(context=context, llm=self._llm).run()
|
||||
context = await WriteTest(context=context, llm=self.llm).run()
|
||||
await tests_file_repo.save(
|
||||
filename=context.test_doc.filename,
|
||||
content=context.test_doc.content,
|
||||
|
|
@ -86,7 +86,7 @@ class QaEngineer(Role):
|
|||
)
|
||||
self.publish_message(
|
||||
Message(
|
||||
content=run_code_context.json(),
|
||||
content=run_code_context.model_dump_json(),
|
||||
role=self.profile,
|
||||
cause_by=WriteTest,
|
||||
sent_from=self,
|
||||
|
|
@ -106,11 +106,11 @@ class QaEngineer(Role):
|
|||
return
|
||||
run_code_context.code = src_doc.content
|
||||
run_code_context.test_code = test_doc.content
|
||||
result = await RunCode(context=run_code_context, llm=self._llm).run()
|
||||
result = await RunCode(context=run_code_context, llm=self.llm).run()
|
||||
run_code_context.output_filename = run_code_context.test_filename + ".json"
|
||||
await CONFIG.git_repo.new_file_repository(TEST_OUTPUTS_FILE_REPO).save(
|
||||
filename=run_code_context.output_filename,
|
||||
content=result.json(),
|
||||
content=result.model_dump_json(),
|
||||
dependencies={src_doc.root_relative_path, test_doc.root_relative_path},
|
||||
)
|
||||
run_code_context.code = None
|
||||
|
|
@ -120,7 +120,7 @@ class QaEngineer(Role):
|
|||
mappings = {"Engineer": "Alex", "QaEngineer": "Edward"}
|
||||
self.publish_message(
|
||||
Message(
|
||||
content=run_code_context.json(),
|
||||
content=run_code_context.model_dump_json(),
|
||||
role=self.profile,
|
||||
cause_by=RunCode,
|
||||
sent_from=self,
|
||||
|
|
@ -130,14 +130,14 @@ class QaEngineer(Role):
|
|||
|
||||
async def _debug_error(self, msg):
|
||||
run_code_context = RunCodeContext.loads(msg.content)
|
||||
code = await DebugError(context=run_code_context, llm=self._llm).run()
|
||||
code = await DebugError(context=run_code_context, llm=self.llm).run()
|
||||
await FileRepository.save_file(
|
||||
filename=run_code_context.test_filename, content=code, relative_path=TEST_CODES_FILE_REPO
|
||||
)
|
||||
run_code_context.output = None
|
||||
self.publish_message(
|
||||
Message(
|
||||
content=run_code_context.json(),
|
||||
content=run_code_context.model_dump_json(),
|
||||
role=self.profile,
|
||||
cause_by=DebugError,
|
||||
sent_from=self,
|
||||
|
|
@ -159,7 +159,7 @@ class QaEngineer(Role):
|
|||
code_filters = any_to_str_set({SummarizeCode})
|
||||
test_filters = any_to_str_set({WriteTest, DebugError})
|
||||
run_filters = any_to_str_set({RunCode})
|
||||
for msg in self._rc.news:
|
||||
for msg in self.rc.news:
|
||||
# Decide what to do based on observed msg type, currently defined by human,
|
||||
# might potentially be moved to _think, that is, let the agent decides for itself
|
||||
if msg.cause_by in code_filters:
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
|
@ -41,20 +42,20 @@ class Researcher(Role):
|
|||
logger.warning(f"The language `{self.language}` has not been tested, it may not work.")
|
||||
|
||||
async def _think(self) -> bool:
|
||||
if self._rc.todo is None:
|
||||
if self.rc.todo is None:
|
||||
self._set_state(0)
|
||||
return True
|
||||
|
||||
if self._rc.state + 1 < len(self._states):
|
||||
self._set_state(self._rc.state + 1)
|
||||
if self.rc.state + 1 < len(self.states):
|
||||
self._set_state(self.rc.state + 1)
|
||||
else:
|
||||
self._rc.todo = None
|
||||
self.rc.todo = None
|
||||
return False
|
||||
|
||||
async def _act(self) -> Message:
|
||||
logger.info(f"{self._setting}: to do {self._rc.todo}({self._rc.todo.name})")
|
||||
todo = self._rc.todo
|
||||
msg = self._rc.memory.get(k=1)[0]
|
||||
logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})")
|
||||
todo = self.rc.todo
|
||||
msg = self.rc.memory.get(k=1)[0]
|
||||
if isinstance(msg.instruct_content, Report):
|
||||
instruct_content = msg.instruct_content
|
||||
topic = instruct_content.topic
|
||||
|
|
@ -78,14 +79,14 @@ class Researcher(Role):
|
|||
else:
|
||||
summaries = instruct_content.summaries
|
||||
summary_text = "\n---\n".join(f"url: {url}\nsummary: {summary}" for (url, summary) in summaries)
|
||||
content = await self._rc.todo.run(topic, summary_text, system_text=research_system_text)
|
||||
content = await self.rc.todo.run(topic, summary_text, system_text=research_system_text)
|
||||
ret = Message(
|
||||
content="",
|
||||
instruct_content=Report(topic=topic, content=content),
|
||||
role=self.profile,
|
||||
cause_by=self._rc.todo,
|
||||
cause_by=self.rc.todo,
|
||||
)
|
||||
self._rc.memory.add(ret)
|
||||
self.rc.memory.add(ret)
|
||||
return ret
|
||||
|
||||
def research_system_text(self, topic, current_task: Action) -> str:
|
||||
|
|
@ -107,9 +108,11 @@ class Researcher(Role):
|
|||
return msg
|
||||
|
||||
def write_report(self, topic: str, content: str):
|
||||
filename = re.sub(r'[\\/:"*?<>|]+', " ", topic)
|
||||
filename = filename.replace("\n", "")
|
||||
if not RESEARCH_PATH.exists():
|
||||
RESEARCH_PATH.mkdir(parents=True)
|
||||
filepath = RESEARCH_PATH / f"{topic}.md"
|
||||
filepath = RESEARCH_PATH / f"{filename}.md"
|
||||
filepath.write_text(content)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -10,8 +10,8 @@
|
|||
consolidated within the `_observe` function.
|
||||
2. Standardize the message filtering for string label matching. Role objects can access the message labels
|
||||
they've subscribed to through the `subscribed_tags` property.
|
||||
3. Move the message receive buffer from the global variable `self._rc.env.memory` to the role's private variable
|
||||
`self._rc.msg_buffer` for easier message identification and asynchronous appending of messages.
|
||||
3. Move the message receive buffer from the global variable `self.rc.env.memory` to the role's private variable
|
||||
`self.rc.msg_buffer` for easier message identification and asynchronous appending of messages.
|
||||
4. Standardize the way messages are passed: `publish_message` sends messages out, while `put_message` places
|
||||
messages into the Role object's private message receive buffer. There are no other message transmit methods.
|
||||
5. Standardize the parameters for the `run` function: the `test_message` parameter is used for testing purposes
|
||||
|
|
@ -24,12 +24,11 @@ from __future__ import annotations
|
|||
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Iterable, Set, Type
|
||||
from typing import Any, Iterable, Optional, Set, Type
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field, SerializeAsAny, model_validator
|
||||
|
||||
from metagpt.actions import Action, ActionOutput
|
||||
from metagpt.actions.action import action_subclass_registry
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.actions.add_requirement import UserRequirement
|
||||
from metagpt.const import SERDESER_PATH
|
||||
|
|
@ -37,7 +36,7 @@ from metagpt.llm import LLM, HumanProvider
|
|||
from metagpt.logs import logger
|
||||
from metagpt.memory import Memory
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.schema import Message, MessageQueue
|
||||
from metagpt.schema import Message, MessageQueue, SerializationMixin
|
||||
from metagpt.utils.common import (
|
||||
any_to_name,
|
||||
any_to_str,
|
||||
|
|
@ -92,8 +91,10 @@ class RoleReactMode(str, Enum):
|
|||
class RoleContext(BaseModel):
|
||||
"""Role Runtime Context"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
# # env exclude=True to avoid `RecursionError: maximum recursion depth exceeded in comparison`
|
||||
env: "Environment" = Field(default=None, exclude=True)
|
||||
env: "Environment" = Field(default=None, exclude=True) # # avoid circular import
|
||||
# TODO judge if ser&deser
|
||||
msg_buffer: MessageQueue = Field(
|
||||
default_factory=MessageQueue, exclude=True
|
||||
|
|
@ -109,9 +110,6 @@ class RoleContext(BaseModel):
|
|||
) # see `Role._set_react_mode` for definitions of the following two attributes
|
||||
max_react_loop: int = 1
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def check(self, role_id: str):
|
||||
# if hasattr(CONFIG, "long_term_memory") and CONFIG.long_term_memory:
|
||||
# self.long_term_memory.recover_memory(role_id, self)
|
||||
|
|
@ -128,12 +126,11 @@ class RoleContext(BaseModel):
|
|||
return self.memory.get()
|
||||
|
||||
|
||||
role_subclass_registry = {}
|
||||
|
||||
|
||||
class Role(BaseModel):
|
||||
class Role(SerializationMixin, is_polymorphic_base=True):
|
||||
"""Role/Agent"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, exclude=["llm"])
|
||||
|
||||
name: str = ""
|
||||
profile: str = ""
|
||||
goal: str = ""
|
||||
|
|
@ -141,84 +138,40 @@ class Role(BaseModel):
|
|||
desc: str = ""
|
||||
is_human: bool = False
|
||||
|
||||
_llm: BaseLLM = Field(default_factory=LLM) # Each role has its own LLM, use different system message
|
||||
_role_id: str = ""
|
||||
_states: list[str] = []
|
||||
_actions: list[Action] = []
|
||||
_rc: RoleContext = Field(default_factory=RoleContext)
|
||||
llm: BaseLLM = Field(default_factory=LLM, exclude=True) # Each role has its own LLM, use different system message
|
||||
role_id: str = ""
|
||||
states: list[str] = []
|
||||
actions: list[SerializeAsAny[Action]] = Field(default=[], validate_default=True)
|
||||
rc: RoleContext = Field(default_factory=RoleContext)
|
||||
subscription: set[str] = set()
|
||||
|
||||
# builtin variables
|
||||
recovered: bool = False # to tag if a recovered role
|
||||
latest_observed_msg: Message = None # record the latest observed message when interrupted
|
||||
builtin_class_name: str = ""
|
||||
|
||||
_private_attributes = {
|
||||
"_llm": None,
|
||||
"_role_id": _role_id,
|
||||
"_states": [],
|
||||
"_actions": [],
|
||||
"_rc": RoleContext(),
|
||||
"_subscription": set(),
|
||||
}
|
||||
latest_observed_msg: Optional[Message] = None # record the latest observed message when interrupted
|
||||
|
||||
__hash__ = object.__hash__ # support Role as hashable type in `Environment.members`
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
exclude = ["_llm"]
|
||||
@model_validator(mode="after")
|
||||
def check_subscription(self) -> set:
|
||||
if not self.subscription:
|
||||
self.subscription = {any_to_str(self), self.name} if self.name else {any_to_str(self)}
|
||||
return self
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
for index in range(len(kwargs.get("_actions", []))):
|
||||
current_action = kwargs["_actions"][index]
|
||||
if isinstance(current_action, dict):
|
||||
item_class_name = current_action.get("builtin_class_name", None)
|
||||
for name, subclass in action_subclass_registry.items():
|
||||
registery_class_name = subclass.__fields__["builtin_class_name"].default
|
||||
if item_class_name == registery_class_name:
|
||||
current_action = subclass(**current_action)
|
||||
break
|
||||
kwargs["_actions"][index] = current_action
|
||||
def __init__(self, **data: Any):
|
||||
# --- avoid PydanticUndefinedAnnotation name 'Environment' is not defined #
|
||||
from metagpt.environment import Environment
|
||||
|
||||
super().__init__(**kwargs)
|
||||
Environment
|
||||
# ------
|
||||
Role.model_rebuild()
|
||||
super().__init__(**data)
|
||||
|
||||
# 关于私有变量的初始化 https://github.com/pydantic/pydantic/issues/655
|
||||
self._private_attributes["_llm"] = LLM() if not self.is_human else HumanProvider()
|
||||
self._private_attributes["_role_id"] = str(self._setting)
|
||||
self.subscription = {any_to_str(self), self.name} if self.name else {any_to_str(self)}
|
||||
|
||||
for key in self._private_attributes.keys():
|
||||
if key in kwargs:
|
||||
object.__setattr__(self, key, kwargs[key])
|
||||
if key == "_rc":
|
||||
_rc = RoleContext(**kwargs["_rc"])
|
||||
object.__setattr__(self, "_rc", _rc)
|
||||
else:
|
||||
if key == "_rc":
|
||||
# # Warning, if use self._private_attributes["_rc"],
|
||||
# # self._rc will be a shared object between roles, so init one or reset it inside `_reset`
|
||||
object.__setattr__(self, key, RoleContext())
|
||||
else:
|
||||
object.__setattr__(self, key, self._private_attributes[key])
|
||||
|
||||
self._llm.system_prompt = self._get_prefix()
|
||||
|
||||
# deserialize child classes dynamically for inherited `role`
|
||||
object.__setattr__(self, "builtin_class_name", self.__class__.__name__)
|
||||
self.__fields__["builtin_class_name"].default = self.__class__.__name__
|
||||
|
||||
if "actions" in kwargs:
|
||||
self._init_actions(kwargs["actions"])
|
||||
|
||||
self._watch(kwargs.get("watch") or [UserRequirement])
|
||||
|
||||
def __init_subclass__(cls, **kwargs: Any) -> None:
|
||||
super().__init_subclass__(**kwargs)
|
||||
role_subclass_registry[cls.__name__] = cls
|
||||
self.llm.system_prompt = self._get_prefix()
|
||||
self._watch(data.get("watch") or [UserRequirement])
|
||||
|
||||
def _reset(self):
|
||||
object.__setattr__(self, "_states", [])
|
||||
object.__setattr__(self, "_actions", [])
|
||||
self.states = []
|
||||
self.actions = []
|
||||
|
||||
@property
|
||||
def _setting(self):
|
||||
|
|
@ -231,12 +184,12 @@ class Role(BaseModel):
|
|||
else stg_path
|
||||
)
|
||||
|
||||
role_info = self.dict(exclude={"_rc": {"memory": True, "msg_buffer": True}, "_llm": True})
|
||||
role_info = self.model_dump(exclude={"rc": {"memory": True, "msg_buffer": True}, "llm": True})
|
||||
role_info.update({"role_class": self.__class__.__name__, "module_name": self.__module__})
|
||||
role_info_path = stg_path.joinpath("role_info.json")
|
||||
write_json_file(role_info_path, role_info)
|
||||
|
||||
self._rc.memory.serialize(stg_path) # serialize role's memory alone
|
||||
self.rc.memory.serialize(stg_path) # serialize role's memory alone
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, stg_path: Path) -> "Role":
|
||||
|
|
@ -260,13 +213,13 @@ class Role(BaseModel):
|
|||
action.set_prefix(self._get_prefix())
|
||||
|
||||
def refresh_system_message(self):
|
||||
self._llm.system_prompt = self._get_prefix()
|
||||
self.llm.system_prompt = self._get_prefix()
|
||||
|
||||
def set_recovered(self, recovered: bool = False):
|
||||
self.recovered = recovered
|
||||
|
||||
def set_memory(self, memory: Memory):
|
||||
self._rc.memory = memory
|
||||
self.rc.memory = memory
|
||||
|
||||
def init_actions(self, actions):
|
||||
self._init_actions(actions)
|
||||
|
|
@ -276,7 +229,7 @@ class Role(BaseModel):
|
|||
for idx, action in enumerate(actions):
|
||||
if not isinstance(action, Action):
|
||||
## 默认初始化
|
||||
i = action(name="", llm=self._llm)
|
||||
i = action(name="", llm=self.llm)
|
||||
else:
|
||||
if self.is_human and not isinstance(action.llm, HumanProvider):
|
||||
logger.warning(
|
||||
|
|
@ -285,10 +238,9 @@ class Role(BaseModel):
|
|||
f"try passing in Action classes instead of initialized instances"
|
||||
)
|
||||
i = action
|
||||
# i.set_env(self._rc.env)
|
||||
self._init_action_system_message(i)
|
||||
self._actions.append(i)
|
||||
self._states.append(f"{idx}. {action}")
|
||||
self.actions.append(i)
|
||||
self.states.append(f"{idx}. {action}")
|
||||
|
||||
def _set_react_mode(self, react_mode: str, max_react_loop: int = 1):
|
||||
"""Set strategy of the Role reacting to observed Message. Variation lies in how
|
||||
|
|
@ -307,20 +259,20 @@ class Role(BaseModel):
|
|||
Defaults to 1, i.e. _think -> _act (-> return result and end)
|
||||
"""
|
||||
assert react_mode in RoleReactMode.values(), f"react_mode must be one of {RoleReactMode.values()}"
|
||||
self._rc.react_mode = react_mode
|
||||
self.rc.react_mode = react_mode
|
||||
if react_mode == RoleReactMode.REACT:
|
||||
self._rc.max_react_loop = max_react_loop
|
||||
self.rc.max_react_loop = max_react_loop
|
||||
|
||||
def _watch(self, actions: Iterable[Type[Action]] | Iterable[Action]):
|
||||
"""Watch Actions of interest. Role will select Messages caused by these Actions from its personal message
|
||||
buffer during _observe.
|
||||
"""
|
||||
self._rc.watch = {any_to_str(t) for t in actions}
|
||||
self.rc.watch = {any_to_str(t) for t in actions}
|
||||
# check RoleContext after adding watch actions
|
||||
self._rc.check(self._role_id)
|
||||
self.rc.check(self.role_id)
|
||||
|
||||
def is_watch(self, caused_by: str):
|
||||
return caused_by in self._rc.watch
|
||||
return caused_by in self.rc.watch
|
||||
|
||||
def subscribe(self, tags: Set[str]):
|
||||
"""Used to receive Messages with certain tags from the environment. Message will be put into personal message
|
||||
|
|
@ -328,19 +280,19 @@ class Role(BaseModel):
|
|||
or profile.
|
||||
"""
|
||||
self.subscription = tags
|
||||
if self._rc.env: # According to the routing feature plan in Chapter 2.2.3.2 of RFC 113
|
||||
self._rc.env.set_subscription(self, self.subscription)
|
||||
if self.rc.env: # According to the routing feature plan in Chapter 2.2.3.2 of RFC 113
|
||||
self.rc.env.set_subscription(self, self.subscription)
|
||||
|
||||
def _set_state(self, state: int):
|
||||
"""Update the current state."""
|
||||
self._rc.state = state
|
||||
logger.debug(f"actions={self._actions}, state={state}")
|
||||
self._rc.todo = self._actions[self._rc.state] if state >= 0 else None
|
||||
self.rc.state = state
|
||||
logger.debug(f"actions={self.actions}, state={state}")
|
||||
self.rc.todo = self.actions[self.rc.state] if state >= 0 else None
|
||||
|
||||
def set_env(self, env: "Environment"):
|
||||
"""Set the environment in which the role works. The role can talk to the environment and can also receive
|
||||
messages by observing."""
|
||||
self._rc.env = env
|
||||
self.rc.env = env
|
||||
if env:
|
||||
env.set_subscription(self, self.subscription)
|
||||
self.refresh_system_message() # add env message to system message
|
||||
|
|
@ -348,7 +300,7 @@ class Role(BaseModel):
|
|||
@property
|
||||
def action_count(self):
|
||||
"""Return number of action"""
|
||||
return len(self._actions)
|
||||
return len(self.actions)
|
||||
|
||||
def _get_prefix(self):
|
||||
"""Get the role prefix"""
|
||||
|
|
@ -360,38 +312,38 @@ class Role(BaseModel):
|
|||
if self.constraints:
|
||||
prefix += CONSTRAINT_TEMPLATE.format(**{"constraints": self.constraints})
|
||||
|
||||
if self._rc.env and self._rc.env.desc:
|
||||
other_role_names = ", ".join(self._rc.env.role_names())
|
||||
env_desc = f"You are in {self._rc.env.desc} with roles({other_role_names})."
|
||||
if self.rc.env and self.rc.env.desc:
|
||||
other_role_names = ", ".join(self.rc.env.role_names())
|
||||
env_desc = f"You are in {self.rc.env.desc} with roles({other_role_names})."
|
||||
prefix += env_desc
|
||||
return prefix
|
||||
|
||||
async def _think(self) -> bool:
|
||||
"""Consider what to do and decide on the next course of action. Return false if nothing can be done."""
|
||||
if len(self._actions) == 1:
|
||||
if len(self.actions) == 1:
|
||||
# If there is only one action, then only this one can be performed
|
||||
self._set_state(0)
|
||||
|
||||
return True
|
||||
|
||||
if self.recovered and self._rc.state >= 0:
|
||||
self._set_state(self._rc.state) # action to run from recovered state
|
||||
if self.recovered and self.rc.state >= 0:
|
||||
self._set_state(self.rc.state) # action to run from recovered state
|
||||
self.set_recovered(False) # avoid max_react_loop out of work
|
||||
return True
|
||||
|
||||
prompt = self._get_prefix()
|
||||
prompt += STATE_TEMPLATE.format(
|
||||
history=self._rc.history,
|
||||
states="\n".join(self._states),
|
||||
n_states=len(self._states) - 1,
|
||||
previous_state=self._rc.state,
|
||||
history=self.rc.history,
|
||||
states="\n".join(self.states),
|
||||
n_states=len(self.states) - 1,
|
||||
previous_state=self.rc.state,
|
||||
)
|
||||
|
||||
next_state = await self._llm.aask(prompt)
|
||||
next_state = await self.llm.aask(prompt)
|
||||
next_state = extract_state_value_from_output(next_state)
|
||||
logger.debug(f"{prompt=}")
|
||||
|
||||
if (not next_state.isdigit() and next_state != "-1") or int(next_state) not in range(-1, len(self._states)):
|
||||
if (not next_state.isdigit() and next_state != "-1") or int(next_state) not in range(-1, len(self.states)):
|
||||
logger.warning(f"Invalid answer of state, {next_state=}, will be set to -1")
|
||||
next_state = -1
|
||||
else:
|
||||
|
|
@ -402,21 +354,21 @@ class Role(BaseModel):
|
|||
return True
|
||||
|
||||
async def _act(self) -> Message:
|
||||
logger.info(f"{self._setting}: to do {self._rc.todo}({self._rc.todo.name})")
|
||||
response = await self._rc.todo.run(self._rc.history)
|
||||
logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})")
|
||||
response = await self.rc.todo.run(self.rc.history)
|
||||
if isinstance(response, (ActionOutput, ActionNode)):
|
||||
msg = Message(
|
||||
content=response.content,
|
||||
instruct_content=response.instruct_content,
|
||||
role=self._setting,
|
||||
cause_by=self._rc.todo,
|
||||
cause_by=self.rc.todo,
|
||||
sent_from=self,
|
||||
)
|
||||
elif isinstance(response, Message):
|
||||
msg = response
|
||||
else:
|
||||
msg = Message(content=response, role=self.profile, cause_by=self._rc.todo, sent_from=self)
|
||||
self._rc.memory.add(msg)
|
||||
msg = Message(content=response, role=self.profile, cause_by=self.rc.todo, sent_from=self)
|
||||
self.rc.memory.add(msg)
|
||||
|
||||
return msg
|
||||
|
||||
|
|
@ -426,7 +378,7 @@ class Role(BaseModel):
|
|||
observed_pure = [msg.dict(exclude={"id": True}) for msg in observed]
|
||||
existed_pure = [msg.dict(exclude={"id": True}) for msg in existed]
|
||||
for idx, new in enumerate(observed_pure):
|
||||
if (new["cause_by"] in self._rc.watch or self.name in new["send_to"]) and new not in existed_pure:
|
||||
if (new["cause_by"] in self.rc.watch or self.name in new["send_to"]) and new not in existed_pure:
|
||||
news.append(observed[idx])
|
||||
return news
|
||||
|
||||
|
|
@ -437,59 +389,59 @@ class Role(BaseModel):
|
|||
if self.recovered:
|
||||
news = [self.latest_observed_msg] if self.latest_observed_msg else []
|
||||
if not news:
|
||||
news = self._rc.msg_buffer.pop_all()
|
||||
news = self.rc.msg_buffer.pop_all()
|
||||
# Store the read messages in your own memory to prevent duplicate processing.
|
||||
old_messages = [] if ignore_memory else self._rc.memory.get()
|
||||
self._rc.memory.add_batch(news)
|
||||
old_messages = [] if ignore_memory else self.rc.memory.get()
|
||||
self.rc.memory.add_batch(news)
|
||||
# Filter out messages of interest.
|
||||
self._rc.news = [n for n in news if n.cause_by in self._rc.watch and n not in old_messages]
|
||||
self.latest_observed_msg = self._rc.news[-1] if self._rc.news else None # record the latest observed msg
|
||||
self.rc.news = [n for n in news if n.cause_by in self.rc.watch and n not in old_messages]
|
||||
self.latest_observed_msg = self.rc.news[-1] if self.rc.news else None # record the latest observed msg
|
||||
|
||||
# Design Rules:
|
||||
# If you need to further categorize Message objects, you can do so using the Message.set_meta function.
|
||||
# msg_buffer is a receiving buffer, avoid adding message data and operations to msg_buffer.
|
||||
news_text = [f"{i.role}: {i.content[:20]}..." for i in self._rc.news]
|
||||
news_text = [f"{i.role}: {i.content[:20]}..." for i in self.rc.news]
|
||||
if news_text:
|
||||
logger.debug(f"{self._setting} observed: {news_text}")
|
||||
return len(self._rc.news)
|
||||
return len(self.rc.news)
|
||||
|
||||
# async def _observe(self, ignore_memory=False) -> int:
|
||||
# """Prepare new messages for processing from the message buffer and other sources."""
|
||||
# # Read unprocessed messages from the msg buffer.
|
||||
# news = self._rc.msg_buffer.pop_all()
|
||||
# news = self.rc.msg_buffer.pop_all()
|
||||
# if self.recovered:
|
||||
# news = [self.latest_observed_msg] if self.latest_observed_msg else []
|
||||
# else:
|
||||
# self.latest_observed_msg = news[-1] if len(news) > 0 else None # record the latest observed msg
|
||||
#
|
||||
# # Store the read messages in your own memory to prevent duplicate processing.
|
||||
# old_messages = [] if ignore_memory else self._rc.memory.get()
|
||||
# self._rc.memory.add_batch(news)
|
||||
# old_messages = [] if ignore_memory else self.rc.memory.get()
|
||||
# self.rc.memory.add_batch(news)
|
||||
# # Filter out messages of interest.
|
||||
# self._rc.news = self._find_news(news, old_messages)
|
||||
# self.rc.news = self._find_news(news, old_messages)
|
||||
#
|
||||
# # Design Rules:
|
||||
# # If you need to further categorize Message objects, you can do so using the Message.set_meta function.
|
||||
# # msg_buffer is a receiving buffer, avoid adding message data and operations to msg_buffer.
|
||||
# news_text = [f"{i.role}: {i.content[:20]}..." for i in self._rc.news]
|
||||
# news_text = [f"{i.role}: {i.content[:20]}..." for i in self.rc.news]
|
||||
# if news_text:
|
||||
# logger.debug(f"{self._setting} observed: {news_text}")
|
||||
# return len(self._rc.news)
|
||||
# return len(self.rc.news)
|
||||
|
||||
def publish_message(self, msg):
|
||||
"""If the role belongs to env, then the role's messages will be broadcast to env"""
|
||||
if not msg:
|
||||
return
|
||||
if not self._rc.env:
|
||||
if not self.rc.env:
|
||||
# If env does not exist, do not publish the message
|
||||
return
|
||||
self._rc.env.publish_message(msg)
|
||||
self.rc.env.publish_message(msg)
|
||||
|
||||
def put_message(self, message):
|
||||
"""Place the message into the Role object's private message buffer."""
|
||||
if not message:
|
||||
return
|
||||
self._rc.msg_buffer.push(message)
|
||||
self.rc.msg_buffer.push(message)
|
||||
|
||||
async def _react(self) -> Message:
|
||||
"""Think first, then act, until the Role _think it is time to stop and requires no more todo.
|
||||
|
|
@ -498,22 +450,22 @@ class Role(BaseModel):
|
|||
"""
|
||||
actions_taken = 0
|
||||
rsp = Message(content="No actions taken yet") # will be overwritten after Role _act
|
||||
while actions_taken < self._rc.max_react_loop:
|
||||
while actions_taken < self.rc.max_react_loop:
|
||||
# think
|
||||
await self._think()
|
||||
if self._rc.todo is None:
|
||||
if self.rc.todo is None:
|
||||
break
|
||||
# act
|
||||
logger.debug(f"{self._setting}: {self._rc.state=}, will do {self._rc.todo}")
|
||||
logger.debug(f"{self._setting}: {self.rc.state=}, will do {self.rc.todo}")
|
||||
rsp = await self._act() # 这个rsp是否需要publish_message?
|
||||
actions_taken += 1
|
||||
return rsp # return output from the last action
|
||||
|
||||
async def _act_by_order(self) -> Message:
|
||||
"""switch action each time by order defined in _init_actions, i.e. _act (Action1) -> _act (Action2) -> ..."""
|
||||
start_idx = self._rc.state if self._rc.state >= 0 else 0 # action to run from recovered state
|
||||
rsp = Message(content="No actions taken yet") # return default message if _actions=[]
|
||||
for i in range(start_idx, len(self._states)):
|
||||
start_idx = self.rc.state if self.rc.state >= 0 else 0 # action to run from recovered state
|
||||
rsp = Message(content="No actions taken yet") # return default message if actions=[]
|
||||
for i in range(start_idx, len(self.states)):
|
||||
self._set_state(i)
|
||||
rsp = await self._act()
|
||||
return rsp # return output from the last action
|
||||
|
|
@ -525,18 +477,18 @@ class Role(BaseModel):
|
|||
|
||||
async def react(self) -> Message:
|
||||
"""Entry to one of three strategies by which Role reacts to the observed Message"""
|
||||
if self._rc.react_mode == RoleReactMode.REACT:
|
||||
if self.rc.react_mode == RoleReactMode.REACT:
|
||||
rsp = await self._react()
|
||||
elif self._rc.react_mode == RoleReactMode.BY_ORDER:
|
||||
elif self.rc.react_mode == RoleReactMode.BY_ORDER:
|
||||
rsp = await self._act_by_order()
|
||||
elif self._rc.react_mode == RoleReactMode.PLAN_AND_ACT:
|
||||
elif self.rc.react_mode == RoleReactMode.PLAN_AND_ACT:
|
||||
rsp = await self._plan_and_act()
|
||||
self._set_state(state=-1) # current reaction is complete, reset state to -1 and todo back to None
|
||||
return rsp
|
||||
|
||||
def get_memories(self, k=0) -> list[Message]:
|
||||
"""A wrapper to return the most recent k memories of this role, return all when k=0"""
|
||||
return self._rc.memory.get(k=k)
|
||||
return self.rc.memory.get(k=k)
|
||||
|
||||
@role_raise_decorator
|
||||
async def run(self, with_message=None) -> Message | None:
|
||||
|
|
@ -561,7 +513,7 @@ class Role(BaseModel):
|
|||
rsp = await self.react()
|
||||
|
||||
# Reset the next action to be taken.
|
||||
self._rc.todo = None
|
||||
self.rc.todo = None
|
||||
# Send the response message to the Environment object to have it relay the message to the subscribers.
|
||||
self.publish_message(rsp)
|
||||
return rsp
|
||||
|
|
@ -569,12 +521,12 @@ class Role(BaseModel):
|
|||
@property
|
||||
def is_idle(self) -> bool:
|
||||
"""If true, all actions have been executed."""
|
||||
return not self._rc.news and not self._rc.todo and self._rc.msg_buffer.empty()
|
||||
return not self.rc.news and not self.rc.todo and self.rc.msg_buffer.empty()
|
||||
|
||||
async def think(self) -> Action:
|
||||
"""The exported `think` function"""
|
||||
await self._think()
|
||||
return self._rc.todo
|
||||
return self.rc.todo
|
||||
|
||||
async def act(self) -> ActionOutput:
|
||||
"""The exported `act` function"""
|
||||
|
|
@ -584,6 +536,6 @@ class Role(BaseModel):
|
|||
@property
|
||||
def todo(self) -> str:
|
||||
"""AgentStore uses this attribute to display to the user what actions the current role should take."""
|
||||
if self._actions:
|
||||
return any_to_name(self._actions[0])
|
||||
if self.actions:
|
||||
return any_to_name(self.actions[0])
|
||||
return ""
|
||||
|
|
|
|||
|
|
@ -57,19 +57,19 @@ class Searcher(Role):
|
|||
|
||||
async def _act_sp(self) -> Message:
|
||||
"""Performs the search action in a single process."""
|
||||
logger.info(f"{self._setting}: to do {self._rc.todo}({self._rc.todo.name})")
|
||||
response = await self._rc.todo.run(self._rc.memory.get(k=0))
|
||||
logger.info(f"{self._setting}: to do {self.rc.todo}({self.rc.todo.name})")
|
||||
response = await self.rc.todo.run(self.rc.memory.get(k=0))
|
||||
|
||||
if isinstance(response, (ActionOutput, ActionNode)):
|
||||
msg = Message(
|
||||
content=response.content,
|
||||
instruct_content=response.instruct_content,
|
||||
role=self.profile,
|
||||
cause_by=self._rc.todo,
|
||||
cause_by=self.rc.todo,
|
||||
)
|
||||
else:
|
||||
msg = Message(content=response, role=self.profile, cause_by=self._rc.todo)
|
||||
self._rc.memory.add(msg)
|
||||
msg = Message(content=response, role=self.profile, cause_by=self.rc.todo)
|
||||
self.rc.memory.add(msg)
|
||||
return msg
|
||||
|
||||
async def _act(self) -> Message:
|
||||
|
|
|
|||
|
|
@ -7,13 +7,13 @@
|
|||
@Modified By: mashenquan, 2023-11-1. In accordance with Chapter 2.2.1 and 2.2.2 of RFC 116, utilize the new message
|
||||
distribution feature for message filtering.
|
||||
"""
|
||||
from typing import Any, Type
|
||||
from typing import Any, Callable, Union
|
||||
|
||||
from pydantic import Field
|
||||
from semantic_kernel import Kernel
|
||||
from semantic_kernel.planning import SequentialPlanner
|
||||
from semantic_kernel.planning.action_planner.action_planner import ActionPlanner
|
||||
from semantic_kernel.planning.basic_planner import BasicPlanner
|
||||
from semantic_kernel.planning.basic_planner import BasicPlanner, Plan
|
||||
|
||||
from metagpt.actions import UserRequirement
|
||||
from metagpt.actions.execute_task import ExecuteTask
|
||||
|
|
@ -41,17 +41,17 @@ class SkAgent(Role):
|
|||
goal: str = "Execute task based on passed in task description"
|
||||
constraints: str = ""
|
||||
|
||||
plan: Any = None
|
||||
plan: Plan = None
|
||||
planner_cls: Any = None
|
||||
planner: Any = None
|
||||
planner: Union[BasicPlanner, SequentialPlanner, ActionPlanner] = None
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
kernel: Kernel = Field(default_factory=Kernel)
|
||||
import_semantic_skill_from_directory: Type[Kernel.import_semantic_skill_from_directory] = None
|
||||
import_skill: Type[Kernel.import_skill] = None
|
||||
import_semantic_skill_from_directory: Callable = None
|
||||
import_skill: Callable = None
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
def __init__(self, **data: Any) -> None:
|
||||
"""Initializes the Engineer role with given attributes."""
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(**data)
|
||||
self._init_actions([ExecuteTask()])
|
||||
self._watch([UserRequirement])
|
||||
self.kernel = make_sk_kernel()
|
||||
|
|
@ -71,10 +71,10 @@ class SkAgent(Role):
|
|||
self._set_state(0)
|
||||
# how funny the interface is inconsistent
|
||||
if isinstance(self.planner, BasicPlanner):
|
||||
self.plan = await self.planner.create_plan_async(self._rc.important_memory[-1].content, self.kernel)
|
||||
self.plan = await self.planner.create_plan_async(self.rc.important_memory[-1].content, self.kernel)
|
||||
logger.info(self.plan.generated_plan)
|
||||
elif any(isinstance(self.planner, cls) for cls in [SequentialPlanner, ActionPlanner]):
|
||||
self.plan = await self.planner.create_plan_async(self._rc.important_memory[-1].content)
|
||||
self.plan = await self.planner.create_plan_async(self.rc.important_memory[-1].content)
|
||||
|
||||
async def _act(self) -> Message:
|
||||
# how funny the interface is inconsistent
|
||||
|
|
@ -85,6 +85,6 @@ class SkAgent(Role):
|
|||
result = (await self.plan.invoke_async()).result
|
||||
logger.info(result)
|
||||
|
||||
msg = Message(content=result, role=self.profile, cause_by=self._rc.todo)
|
||||
self._rc.memory.add(msg)
|
||||
msg = Message(content=result, role=self.profile, cause_by=self.rc.todo)
|
||||
self.rc.memory.add(msg)
|
||||
return msg
|
||||
|
|
|
|||
|
|
@ -42,34 +42,34 @@ class Teacher(Role):
|
|||
|
||||
async def _think(self) -> bool:
|
||||
"""Everything will be done part by part."""
|
||||
if not self._actions:
|
||||
if not self._rc.news or self._rc.news[0].cause_by != any_to_str(UserRequirement):
|
||||
if not self.actions:
|
||||
if not self.rc.news or self.rc.news[0].cause_by != any_to_str(UserRequirement):
|
||||
raise ValueError("Lesson content invalid.")
|
||||
actions = []
|
||||
print(TeachingPlanBlock.TOPICS)
|
||||
for topic in TeachingPlanBlock.TOPICS:
|
||||
act = WriteTeachingPlanPart(context=self._rc.news[0].content, topic=topic, llm=self._llm)
|
||||
act = WriteTeachingPlanPart(context=self.rc.news[0].content, topic=topic, llm=self.llm)
|
||||
actions.append(act)
|
||||
self._init_actions(actions)
|
||||
|
||||
if self._rc.todo is None:
|
||||
if self.rc.todo is None:
|
||||
self._set_state(0)
|
||||
return True
|
||||
|
||||
if self._rc.state + 1 < len(self._states):
|
||||
self._set_state(self._rc.state + 1)
|
||||
if self.rc.state + 1 < len(self.states):
|
||||
self._set_state(self.rc.state + 1)
|
||||
return True
|
||||
|
||||
self._rc.todo = None
|
||||
self.rc.todo = None
|
||||
return False
|
||||
|
||||
async def _react(self) -> Message:
|
||||
ret = Message(content="")
|
||||
while True:
|
||||
await self._think()
|
||||
if self._rc.todo is None:
|
||||
if self.rc.todo is None:
|
||||
break
|
||||
logger.debug(f"{self._setting}: {self._rc.state=}, will do {self._rc.todo}")
|
||||
logger.debug(f"{self._setting}: {self.rc.state=}, will do {self.rc.todo}")
|
||||
msg = await self._act()
|
||||
if ret.content != "":
|
||||
ret.content += "\n\n\n"
|
||||
|
|
@ -104,7 +104,7 @@ class Teacher(Role):
|
|||
def course_title(self):
|
||||
"""Return course title of teaching plan"""
|
||||
default_title = "teaching_plan"
|
||||
for act in self._actions:
|
||||
for act in self.actions:
|
||||
if act.topic != TeachingPlanBlock.COURSE_TITLE:
|
||||
continue
|
||||
if act.rsp is None:
|
||||
|
|
|
|||
|
|
@ -34,9 +34,9 @@ class TutorialAssistant(Role):
|
|||
constraints: str = "Strictly follow Markdown's syntax, with neat and standardized layout"
|
||||
language: str = "Chinese"
|
||||
|
||||
topic = ""
|
||||
main_title = ""
|
||||
total_content = ""
|
||||
topic: str = ""
|
||||
main_title: str = ""
|
||||
total_content: str = ""
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
|
@ -71,9 +71,9 @@ class TutorialAssistant(Role):
|
|||
Returns:
|
||||
A message containing the result of the action.
|
||||
"""
|
||||
todo = self._rc.todo
|
||||
todo = self.rc.todo
|
||||
if type(todo) is WriteDirectory:
|
||||
msg = self._rc.memory.get(k=1)[0]
|
||||
msg = self.rc.memory.get(k=1)[0]
|
||||
self.topic = msg.content
|
||||
resp = await todo.run(topic=self.topic)
|
||||
logger.info(resp)
|
||||
|
|
|
|||
|
|
@ -23,9 +23,17 @@ from abc import ABC
|
|||
from asyncio import Queue, QueueEmpty, wait_for
|
||||
from json import JSONDecodeError
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Set, Type, TypeVar
|
||||
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
PrivateAttr,
|
||||
field_serializer,
|
||||
field_validator,
|
||||
)
|
||||
from pydantic_core import core_schema
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.const import (
|
||||
|
|
@ -46,6 +54,64 @@ from metagpt.utils.serialize import (
|
|||
)
|
||||
|
||||
|
||||
class SerializationMixin(BaseModel):
|
||||
"""SereDeserMixin for subclass' ser&deser"""
|
||||
|
||||
__is_polymorphic_base = False
|
||||
__subclasses_map__ = {}
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, source: type["SerializationMixin"], handler: Callable[[Any], core_schema.CoreSchema]
|
||||
) -> core_schema.CoreSchema:
|
||||
schema = handler(source)
|
||||
og_schema_ref = schema["ref"]
|
||||
schema["ref"] += ":mixin"
|
||||
|
||||
return core_schema.no_info_before_validator_function(
|
||||
cls.__deserialize_with_real_type__,
|
||||
schema=schema,
|
||||
ref=og_schema_ref,
|
||||
serialization=core_schema.wrap_serializer_function_ser_schema(cls.__serialize_add_class_type__),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __serialize_add_class_type__(
|
||||
cls,
|
||||
value,
|
||||
handler: core_schema.SerializerFunctionWrapHandler,
|
||||
) -> Any:
|
||||
ret = handler(value)
|
||||
if not len(cls.__subclasses__()):
|
||||
# only subclass add `__module_class_name`
|
||||
ret["__module_class_name"] = f"{cls.__module__}.{cls.__qualname__}"
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def __deserialize_with_real_type__(cls, value: Any):
|
||||
if not isinstance(value, dict):
|
||||
return value
|
||||
|
||||
if not cls.__is_polymorphic_base or (len(cls.__subclasses__()) and "__module_class_name" not in value):
|
||||
# add right condition to init BaseClass like Action()
|
||||
return value
|
||||
module_class_name = value.get("__module_class_name", None)
|
||||
if module_class_name is None:
|
||||
raise ValueError("Missing field: __module_class_name")
|
||||
|
||||
class_type = cls.__subclasses_map__.get(module_class_name, None)
|
||||
|
||||
if class_type is None:
|
||||
raise TypeError("Trying to instantiate {module_class_name} which not defined yet.")
|
||||
|
||||
return class_type(**value)
|
||||
|
||||
def __init_subclass__(cls, is_polymorphic_base: bool = False, **kwargs):
|
||||
cls.__is_polymorphic_base = is_polymorphic_base
|
||||
cls.__subclasses_map__[f"{cls.__module__}.{cls.__qualname__}"] = cls
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
|
||||
class SimpleMessage(BaseModel):
|
||||
content: str
|
||||
role: str
|
||||
|
|
@ -102,33 +168,64 @@ class Documents(BaseModel):
|
|||
class Message(BaseModel):
|
||||
"""list[<role>: <content>]"""
|
||||
|
||||
id: str # According to Section 2.2.3.1.1 of RFC 135
|
||||
id: str = Field(default="", validate_default=True) # According to Section 2.2.3.1.1 of RFC 135
|
||||
content: str
|
||||
instruct_content: BaseModel = None
|
||||
instruct_content: Optional[BaseModel] = Field(default=None, validate_default=True)
|
||||
role: str = "user" # system / user / assistant
|
||||
cause_by: str = ""
|
||||
sent_from: str = ""
|
||||
send_to: Set = Field(default_factory={MESSAGE_ROUTE_TO_ALL})
|
||||
cause_by: str = Field(default="", validate_default=True)
|
||||
sent_from: str = Field(default="", validate_default=True)
|
||||
send_to: set = Field(default={MESSAGE_ROUTE_TO_ALL}, validate_default=True)
|
||||
|
||||
def __init__(self, content: str = "", **kwargs):
|
||||
ic = kwargs.get("instruct_content", None)
|
||||
@field_validator("id", mode="before")
|
||||
@classmethod
|
||||
def check_id(cls, id: str) -> str:
|
||||
return id if id else uuid.uuid4().hex
|
||||
|
||||
@field_validator("instruct_content", mode="before")
|
||||
@classmethod
|
||||
def check_instruct_content(cls, ic: Any) -> BaseModel:
|
||||
if ic and not isinstance(ic, BaseModel) and "class" in ic:
|
||||
# compatible with custom-defined ActionOutput
|
||||
mapping = actionoutput_str_to_mapping(ic["mapping"])
|
||||
|
||||
actionnode_class = import_class("ActionNode", "metagpt.actions.action_node") # avoid circular import
|
||||
ic_obj = actionnode_class.create_model_class(class_name=ic["class"], mapping=mapping)
|
||||
ic_new = ic_obj(**ic["value"])
|
||||
kwargs["instruct_content"] = ic_new
|
||||
ic = ic_obj(**ic["value"])
|
||||
return ic
|
||||
|
||||
kwargs["id"] = kwargs.get("id", uuid.uuid4().hex)
|
||||
kwargs["content"] = kwargs.get("content", content)
|
||||
kwargs["cause_by"] = any_to_str(
|
||||
kwargs.get("cause_by", import_class("UserRequirement", "metagpt.actions.add_requirement"))
|
||||
)
|
||||
kwargs["sent_from"] = any_to_str(kwargs.get("sent_from", ""))
|
||||
kwargs["send_to"] = any_to_str_set(kwargs.get("send_to", {MESSAGE_ROUTE_TO_ALL}))
|
||||
super(Message, self).__init__(**kwargs)
|
||||
@field_validator("cause_by", mode="before")
|
||||
@classmethod
|
||||
def check_cause_by(cls, cause_by: Any) -> str:
|
||||
return any_to_str(cause_by if cause_by else import_class("UserRequirement", "metagpt.actions.add_requirement"))
|
||||
|
||||
@field_validator("sent_from", mode="before")
|
||||
@classmethod
|
||||
def check_sent_from(cls, sent_from: Any) -> str:
|
||||
return any_to_str(sent_from if sent_from else "")
|
||||
|
||||
@field_validator("send_to", mode="before")
|
||||
@classmethod
|
||||
def check_send_to(cls, send_to: Any) -> set:
|
||||
return any_to_str_set(send_to if send_to else {MESSAGE_ROUTE_TO_ALL})
|
||||
|
||||
@field_serializer("instruct_content", mode="plain")
|
||||
def ser_instruct_content(self, ic: BaseModel) -> Union[str, None]:
|
||||
ic_dict = None
|
||||
if ic:
|
||||
# compatible with custom-defined ActionOutput
|
||||
schema = ic.model_json_schema()
|
||||
# `Documents` contain definitions
|
||||
if "definitions" not in schema:
|
||||
# TODO refine with nested BaseModel
|
||||
mapping = actionoutout_schema_to_mapping(schema)
|
||||
mapping = actionoutput_mapping_to_str(mapping)
|
||||
|
||||
ic_dict = {"class": schema["title"], "mapping": mapping, "value": ic.model_dump()}
|
||||
return ic_dict
|
||||
|
||||
def __init__(self, content: str = "", **data: Any):
|
||||
data["content"] = data.get("content", content)
|
||||
super().__init__(**data)
|
||||
|
||||
def __setattr__(self, key, val):
|
||||
"""Override `@property.setter`, convert non-string parameters into string parameters."""
|
||||
|
|
@ -142,26 +239,10 @@ class Message(BaseModel):
|
|||
new_val = val
|
||||
super().__setattr__(key, new_val)
|
||||
|
||||
def dict(self, *args, **kwargs) -> "DictStrAny":
|
||||
"""overwrite the `dict` to dump dynamic pydantic model"""
|
||||
obj_dict = super(Message, self).dict(*args, **kwargs)
|
||||
ic = self.instruct_content
|
||||
if ic:
|
||||
# compatible with custom-defined ActionOutput
|
||||
schema = ic.schema()
|
||||
# `Documents` contain definitions
|
||||
if "definitions" not in schema:
|
||||
# TODO refine with nested BaseModel
|
||||
mapping = actionoutout_schema_to_mapping(schema)
|
||||
mapping = actionoutput_mapping_to_str(mapping)
|
||||
|
||||
obj_dict["instruct_content"] = {"class": schema["title"], "mapping": mapping, "value": ic.dict()}
|
||||
return obj_dict
|
||||
|
||||
def __str__(self):
|
||||
# prefix = '-'.join([self.role, str(self.cause_by)])
|
||||
if self.instruct_content:
|
||||
return f"{self.role}: {self.instruct_content.dict()}"
|
||||
return f"{self.role}: {self.instruct_content.model_dump()}"
|
||||
return f"{self.role}: {self.content}"
|
||||
|
||||
def __repr__(self):
|
||||
|
|
@ -173,7 +254,7 @@ class Message(BaseModel):
|
|||
|
||||
def dump(self) -> str:
|
||||
"""Convert the object to json string"""
|
||||
return self.json(exclude_none=True)
|
||||
return self.model_dump_json(exclude_none=True, warnings=False)
|
||||
|
||||
@staticmethod
|
||||
@handle_exception(exception_type=JSONDecodeError, default_return=None)
|
||||
|
|
@ -224,19 +305,9 @@ class AIMessage(Message):
|
|||
class MessageQueue(BaseModel):
|
||||
"""Message queue which supports asynchronous updates."""
|
||||
|
||||
_queue: Queue = Field(default_factory=Queue)
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
_private_attributes = {"_queue": Queue()}
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
for key in self._private_attributes.keys():
|
||||
if key in kwargs:
|
||||
object.__setattr__(self, key, kwargs[key])
|
||||
else:
|
||||
object.__setattr__(self, key, Queue())
|
||||
_queue: Queue = PrivateAttr(default_factory=Queue)
|
||||
|
||||
def pop(self) -> Message | None:
|
||||
"""Pop one message from the queue."""
|
||||
|
|
@ -312,28 +383,28 @@ class BaseContext(BaseModel, ABC):
|
|||
|
||||
class CodingContext(BaseContext):
|
||||
filename: str
|
||||
design_doc: Optional[Document]
|
||||
task_doc: Optional[Document]
|
||||
code_doc: Optional[Document]
|
||||
design_doc: Optional[Document] = None
|
||||
task_doc: Optional[Document] = None
|
||||
code_doc: Optional[Document] = None
|
||||
|
||||
|
||||
class TestingContext(BaseContext):
|
||||
filename: str
|
||||
code_doc: Document
|
||||
test_doc: Optional[Document]
|
||||
test_doc: Optional[Document] = None
|
||||
|
||||
|
||||
class RunCodeContext(BaseContext):
|
||||
mode: str = "script"
|
||||
code: Optional[str]
|
||||
code: Optional[str] = None
|
||||
code_filename: str = ""
|
||||
test_code: Optional[str]
|
||||
test_code: Optional[str] = None
|
||||
test_filename: str = ""
|
||||
command: List[str] = Field(default_factory=list)
|
||||
working_directory: str = ""
|
||||
additional_python_paths: List[str] = Field(default_factory=list)
|
||||
output_filename: Optional[str]
|
||||
output: Optional[str]
|
||||
output_filename: Optional[str] = None
|
||||
output: Optional[str] = None
|
||||
|
||||
|
||||
class RunCodeResult(BaseContext):
|
||||
|
|
|
|||
4
metagpt/strategy/__init__.py
Normal file
4
metagpt/strategy/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 12/23/2023 4:51 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
108
metagpt/strategy/base.py
Normal file
108
metagpt/strategy/base.py
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 12/25/2023 9:16 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
from typing import List
|
||||
|
||||
from anytree import Node, RenderTree
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BaseParser(BaseModel):
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def propose(self, current_state: str, **kwargs) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def sample(self, current_state: str, **kwargs) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
def value(self, input: str, **kwargs) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class BaseEvaluator(BaseModel):
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
def status_verify(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ThoughtNode(Node):
|
||||
"""A node representing a thought in the thought tree."""
|
||||
|
||||
name: str = ""
|
||||
value: int = 0
|
||||
id: int = 0
|
||||
valid_status: bool = True
|
||||
|
||||
def update_value(self, value) -> None:
|
||||
"""Update the value of the thought node."""
|
||||
self.value = value
|
||||
|
||||
def update_valid_status(self, status) -> None:
|
||||
"""Update the validity status of the thought node."""
|
||||
self.valid_status = status
|
||||
|
||||
|
||||
class ThoughtTree(RenderTree):
|
||||
"""A tree structure to represent thoughts."""
|
||||
|
||||
@property
|
||||
def all_nodes(self) -> List[ThoughtNode]:
|
||||
"""
|
||||
Get a list of all nodes in the thought tree.
|
||||
|
||||
Returns:
|
||||
List[ThoughtNode]: A list containing all nodes in the thought tree.
|
||||
"""
|
||||
all_nodes = [node for _, _, node in self]
|
||||
return all_nodes
|
||||
|
||||
def update_node(self, thought: List[dict] = [], current_node: ThoughtNode = None) -> List[ThoughtNode]:
|
||||
"""
|
||||
Update the tree with new thoughts.
|
||||
|
||||
Args:
|
||||
thought (List[dict]): A list of dictionaries representing thought information.
|
||||
current_node (ThoughtNode): The current node under which new thoughts will be added.
|
||||
|
||||
Returns:
|
||||
List[ThoughtNode]: A list of ThoughtNode instances representing the updated tree nodes.
|
||||
"""
|
||||
nodes = []
|
||||
for node_info in thought:
|
||||
node = ThoughtNode(
|
||||
name=node_info["node_state_instruction"], parent=current_node, id=int(node_info["node_id"])
|
||||
)
|
||||
nodes.append(node)
|
||||
return nodes
|
||||
|
||||
def parse_node_path(self, node) -> List[str]:
|
||||
"""
|
||||
Parse and retrieve the hierarchical path of the given thought node.
|
||||
|
||||
This method traverses the parent nodes of the provided 'node' and constructs
|
||||
the full path from the root node to the given node.
|
||||
|
||||
Args:
|
||||
node: The thought node for which the hierarchical path needs to be parsed.
|
||||
|
||||
Returns:
|
||||
List[str]: A list representing the full hierarchical path of the given thought node.
|
||||
The list is ordered from the root node to the provided node.
|
||||
"""
|
||||
full_node_path = []
|
||||
while node is not None:
|
||||
full_node_path.append(node.name)
|
||||
node = node.parent
|
||||
full_node_path.reverse()
|
||||
return full_node_path
|
||||
|
||||
def show(self) -> None:
|
||||
"""Print the updated tree."""
|
||||
print("\nUpdated Tree:")
|
||||
for pre, _, node in self:
|
||||
print(f"{pre}{node.name}, value: {node.value}, valid_status: {node.valid_status}")
|
||||
4
metagpt/strategy/examples/__init__.py
Normal file
4
metagpt/strategy/examples/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 12/26/2023 3:32 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
73
metagpt/strategy/examples/creative_writing.py
Normal file
73
metagpt/strategy/examples/creative_writing.py
Normal file
|
|
@ -0,0 +1,73 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 12/25/2023 1:06 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import re
|
||||
|
||||
from metagpt.strategy.prompt_templates.creative_writing import cot_prompt, vote_prompt
|
||||
from metagpt.strategy.tot import TreeofThought
|
||||
from metagpt.strategy.tot_schema import (
|
||||
BaseEvaluator,
|
||||
BaseParser,
|
||||
Strategy,
|
||||
ThoughtSolverConfig,
|
||||
)
|
||||
|
||||
|
||||
class TextGenParser(BaseParser):
|
||||
propose_prompt: str = cot_prompt
|
||||
value_prompt: str = vote_prompt
|
||||
|
||||
def __call__(self, input_text: str) -> str:
|
||||
return input_text
|
||||
|
||||
def propose(self, current_state: str, **kwargs) -> str:
|
||||
return self.propose_prompt.format(input=current_state, **kwargs)
|
||||
|
||||
def value(self, input: str = "", **kwargs) -> str:
|
||||
# node_result = self(input)
|
||||
id = kwargs.get("node_id", "0")
|
||||
return self.value_prompt + f"Choice {id}:\n{input}\n"
|
||||
|
||||
|
||||
class TextGenEvaluator(BaseEvaluator):
|
||||
value_map = {"impossible": 0.001, "likely": 1, "sure": 20} # TODO: ad hoc
|
||||
status_map = {val: key for key, val in value_map.items()}
|
||||
|
||||
def __call__(self, evaluation: str, **kwargs) -> float:
|
||||
try:
|
||||
value = 0
|
||||
node_id = kwargs.get("node_id", "0")
|
||||
pattern = r".*best choice is .*(\d+).*"
|
||||
match = re.match(pattern, evaluation, re.DOTALL)
|
||||
|
||||
if match:
|
||||
vote = int(match.groups()[0])
|
||||
print(vote)
|
||||
if vote == int(node_id):
|
||||
value = 1
|
||||
except:
|
||||
value = 0
|
||||
return value
|
||||
|
||||
def status_verify(self, value):
|
||||
status = False
|
||||
if value in self.status_map:
|
||||
status_value = self.status_map[value]
|
||||
if status_value != "impossible":
|
||||
status = True
|
||||
return status
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
initial_prompt = """It isn't difficult to do a handstand if you just stand on your hands. It caught him off guard that space smelled of seared steak. When she didn’t like a guy who was trying to pick her up, she started using sign language. Each person who knows you has a different perception of who you are."""
|
||||
|
||||
parser = TextGenParser()
|
||||
evaluator = TextGenEvaluator()
|
||||
|
||||
config = ThoughtSolverConfig(n_generate_sample=3, parser=parser, evaluator=evaluator)
|
||||
|
||||
tot_base = TreeofThought(strategy=Strategy.BFS, config=config)
|
||||
asyncio.run(tot_base.solve(init_prompt=initial_prompt))
|
||||
64
metagpt/strategy/examples/game24.py
Normal file
64
metagpt/strategy/examples/game24.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 12/25/2023 1:36 AM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import re
|
||||
|
||||
from metagpt.strategy.prompt_templates.game24 import propose_prompt, value_prompt
|
||||
from metagpt.strategy.tot import TreeofThought
|
||||
from metagpt.strategy.tot_schema import (
|
||||
BaseEvaluator,
|
||||
BaseParser,
|
||||
Strategy,
|
||||
ThoughtSolverConfig,
|
||||
)
|
||||
|
||||
|
||||
class Game24Parser(BaseParser):
|
||||
propose_prompt: str = propose_prompt
|
||||
value_prompt: str = value_prompt
|
||||
|
||||
def __call__(self, input_text: str) -> str:
|
||||
last_line = input_text.strip().split("\n")[-1]
|
||||
return last_line.split("left: ")[-1].split(")")[0]
|
||||
|
||||
def propose(self, current_state: str, **kwargs) -> str:
|
||||
return self.propose_prompt.format(input=current_state, **kwargs)
|
||||
|
||||
def value(self, input: str = "", **kwargs) -> str:
|
||||
node_result = self(input)
|
||||
return self.value_prompt.format(input=node_result)
|
||||
|
||||
|
||||
class Game24Evaluator(BaseEvaluator):
|
||||
value_map = {"impossible": 0.001, "likely": 1, "sure": 20} # TODO: ad hoc
|
||||
status_map = {val: key for key, val in value_map.items()}
|
||||
|
||||
def __call__(self, evaluation: str, **kwargs) -> float:
|
||||
try:
|
||||
matches = re.findall(r"\b(impossible|sure|likely)\b", evaluation)
|
||||
value = self.value_map[matches[0]]
|
||||
except:
|
||||
value = 0.001
|
||||
return value
|
||||
|
||||
def status_verify(self, value):
|
||||
status = False
|
||||
if value in self.status_map:
|
||||
status_value = self.status_map[value]
|
||||
if status_value != "impossible":
|
||||
status = True
|
||||
return status
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import asyncio
|
||||
|
||||
initial_prompt = """4 5 6 10"""
|
||||
parser = Game24Parser()
|
||||
evaluator = Game24Evaluator()
|
||||
|
||||
config = ThoughtSolverConfig(n_generate_sample=5, parser=parser, evaluator=evaluator)
|
||||
|
||||
tot = TreeofThought(strategy=Strategy.BFS, config=config)
|
||||
asyncio.run(tot.solve(init_prompt=initial_prompt))
|
||||
4
metagpt/strategy/prompt_templates/__init__.py
Normal file
4
metagpt/strategy/prompt_templates/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 12/23/2023 5:21 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
25
metagpt/strategy/prompt_templates/creative_writing.py
Normal file
25
metagpt/strategy/prompt_templates/creative_writing.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
standard_prompt = """
|
||||
Write a coherent passage of 4 short paragraphs. The end sentence of each paragraph must be: {input}
|
||||
"""
|
||||
|
||||
cot_prompt = """
|
||||
Write a coherent passage of 4 short paragraphs. The end sentence of each paragraph must be: {input}
|
||||
|
||||
Make a plan then write. Your output should be of the following format:
|
||||
|
||||
Plan:
|
||||
Your plan here.
|
||||
|
||||
Passage:
|
||||
Your passage here.
|
||||
"""
|
||||
|
||||
|
||||
vote_prompt = """Given an instruction and several choices, decide which choice is most promising. Analyze each choice in detail, then conclude in the last line "The best choice is {s}", where s the integer id of the choice.
|
||||
"""
|
||||
|
||||
compare_prompt = """Briefly analyze the coherency of the following two passages. Conclude in the last line "The more coherent passage is 1", "The more coherent passage is 2", or "The two passages are similarly coherent".
|
||||
"""
|
||||
|
||||
score_prompt = """Analyze the following passage, then at the last line conclude "Thus the coherency score is {s}", where s is an integer from 1 to 10.
|
||||
"""
|
||||
139
metagpt/strategy/prompt_templates/game24.py
Normal file
139
metagpt/strategy/prompt_templates/game24.py
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
# 5-shot
|
||||
standard_prompt = """Use numbers and basic arithmetic operations (+ - * /) to obtain 24.
|
||||
Input: 4 4 6 8
|
||||
Answer: (4 + 8) * (6 - 4) = 24
|
||||
Input: 2 9 10 12
|
||||
Answer: 2 * 12 * (10 - 9) = 24
|
||||
Input: 4 9 10 13
|
||||
Answer: (13 - 9) * (10 - 4) = 24
|
||||
Input: 1 4 8 8
|
||||
Answer: (8 / 4 + 1) * 8 = 24
|
||||
Input: 5 5 5 9
|
||||
Answer: 5 + 5 + 5 + 9 = 24
|
||||
Input: {input}
|
||||
"""
|
||||
|
||||
# 5-shot
|
||||
cot_prompt = """Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Each step, you are only allowed to choose two of the remaining numbers to obtain a new number.
|
||||
Input: 4 4 6 8
|
||||
Steps:
|
||||
4 + 8 = 12 (left: 4 6 12)
|
||||
6 - 4 = 2 (left: 2 12)
|
||||
2 * 12 = 24 (left: 24)
|
||||
Answer: (6 - 4) * (4 + 8) = 24
|
||||
Input: 2 9 10 12
|
||||
Steps:
|
||||
12 * 2 = 24 (left: 9 10 24)
|
||||
10 - 9 = 1 (left: 1 24)
|
||||
24 * 1 = 24 (left: 24)
|
||||
Answer: (12 * 2) * (10 - 9) = 24
|
||||
Input: 4 9 10 13
|
||||
Steps:
|
||||
13 - 10 = 3 (left: 3 4 9)
|
||||
9 - 3 = 6 (left: 4 6)
|
||||
4 * 6 = 24 (left: 24)
|
||||
Answer: 4 * (9 - (13 - 10)) = 24
|
||||
Input: 1 4 8 8
|
||||
Steps:
|
||||
8 / 4 = 2 (left: 1 2 8)
|
||||
1 + 2 = 3 (left: 3 8)
|
||||
3 * 8 = 24 (left: 24)
|
||||
Answer: (1 + 8 / 4) * 8 = 24
|
||||
Input: 5 5 5 9
|
||||
Steps:
|
||||
5 + 5 = 10 (left: 5 9 10)
|
||||
10 + 5 = 15 (left: 9 15)
|
||||
15 + 9 = 24 (left: 24)
|
||||
Answer: ((5 + 5) + 5) + 9 = 24
|
||||
Input: {input}
|
||||
"""
|
||||
|
||||
# 1-shot
|
||||
propose_prompt = """Here is an Example for 1 input and 8 possible thoughts:
|
||||
Input: 2 8 8 14
|
||||
Possible next steps:
|
||||
2 + 8 = 10 (left: 8 10 14)
|
||||
8 / 2 = 4 (left: 4 8 14)
|
||||
14 + 2 = 16 (left: 8 8 16)
|
||||
2 * 8 = 16 (left: 8 14 16)
|
||||
8 - 2 = 6 (left: 6 8 14)
|
||||
14 - 8 = 6 (left: 2 6 8)
|
||||
14 / 2 = 7 (left: 7 8 8)
|
||||
14 - 2 = 12 (left: 8 8 12)
|
||||
|
||||
Here is my task for 1 input and {n_generate_sample} possible thoughts:
|
||||
Input: {input}
|
||||
Possible next steps:
|
||||
|
||||
|
||||
"""
|
||||
|
||||
value_prompt = """Evaluate if given numbers can reach 24 (sure/likely/impossible)
|
||||
10 14
|
||||
10 + 14 = 24
|
||||
sure
|
||||
11 12
|
||||
11 + 12 = 23
|
||||
12 - 11 = 1
|
||||
11 * 12 = 132
|
||||
11 / 12 = 0.91
|
||||
impossible
|
||||
4 4 10
|
||||
4 + 4 + 10 = 8 + 10 = 18
|
||||
4 * 10 - 4 = 40 - 4 = 36
|
||||
(10 - 4) * 4 = 6 * 4 = 24
|
||||
sure
|
||||
4 9 11
|
||||
9 + 11 + 4 = 20 + 4 = 24
|
||||
sure
|
||||
5 7 8
|
||||
5 + 7 + 8 = 12 + 8 = 20
|
||||
(8 - 5) * 7 = 3 * 7 = 21
|
||||
I cannot obtain 24 now, but numbers are within a reasonable range
|
||||
likely
|
||||
5 6 6
|
||||
5 + 6 + 6 = 17
|
||||
(6 - 5) * 6 = 1 * 6 = 6
|
||||
I cannot obtain 24 now, but numbers are within a reasonable range
|
||||
likely
|
||||
10 10 11
|
||||
10 + 10 + 11 = 31
|
||||
(11 - 10) * 10 = 10
|
||||
10 10 10 are all too big
|
||||
impossible
|
||||
1 3 3
|
||||
1 * 3 * 3 = 9
|
||||
(1 + 3) * 3 = 12
|
||||
1 3 3 are all too small
|
||||
impossible
|
||||
{input}
|
||||
"""
|
||||
|
||||
value_last_step_prompt = """Use numbers and basic arithmetic operations (+ - * /) to obtain 24. Given an input and an answer, give a judgement (sure/impossible) if the answer is correct, i.e. it uses each input exactly once and no other numbers, and reach 24.
|
||||
Input: 4 4 6 8
|
||||
Answer: (4 + 8) * (6 - 4) = 24
|
||||
Judge:
|
||||
sure
|
||||
Input: 2 9 10 12
|
||||
Answer: 2 * 12 * (10 - 9) = 24
|
||||
Judge:
|
||||
sure
|
||||
Input: 4 9 10 13
|
||||
Answer: (13 - 9) * (10 - 4) = 24
|
||||
Judge:
|
||||
sure
|
||||
Input: 4 4 6 8
|
||||
Answer: (4 + 8) * (6 - 4) + 1 = 25
|
||||
Judge:
|
||||
impossible
|
||||
Input: 2 9 10 12
|
||||
Answer: 2 * (12 - 10) = 24
|
||||
Judge:
|
||||
impossible
|
||||
Input: 4 9 10 13
|
||||
Answer: (13 - 4) * (10 - 9) = 24
|
||||
Judge:
|
||||
impossible
|
||||
Input: {input}
|
||||
Answer: {answer}
|
||||
Judge:"""
|
||||
272
metagpt/strategy/tot.py
Normal file
272
metagpt/strategy/tot.py
Normal file
|
|
@ -0,0 +1,272 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 12/23/2023 4:51 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import asyncio
|
||||
from typing import Any, List
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.strategy.base import ThoughtNode, ThoughtTree
|
||||
from metagpt.strategy.tot_schema import MethodSelect, Strategy, ThoughtSolverConfig
|
||||
from metagpt.utils.common import CodeParser
|
||||
|
||||
OUTPUT_FORMAT = """
|
||||
Output a list of jsons following the format:
|
||||
```json
|
||||
[
|
||||
{
|
||||
"node_id": str = "unique identifier for a solution, can be an ordinal",
|
||||
"node_state_instruction": "specified sample of solution",
|
||||
},
|
||||
...
|
||||
]
|
||||
```
|
||||
"""
|
||||
|
||||
|
||||
class ThoughtSolverBase(BaseModel):
|
||||
thought_tree: str = ""
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True)
|
||||
config: ThoughtSolverConfig = Field(default_factory=ThoughtSolverConfig)
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
self.llm.use_system_prompt = False
|
||||
|
||||
async def solve(self, init_prompt):
|
||||
"""
|
||||
Solve method for subclasses to implement.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement the solve method")
|
||||
|
||||
async def generate_thoughts(self, current_state="", current_node=None) -> List[ThoughtNode]:
|
||||
"""
|
||||
Generate children thoughts based on the current state.
|
||||
|
||||
Args:
|
||||
current_state (str): The current state for which thoughts are generated.
|
||||
current_node (ThoughtNode): The current node in the thought tree.
|
||||
|
||||
Returns:
|
||||
List[ThoughtNode]: List of nodes representing the generated thoughts.
|
||||
"""
|
||||
state_prompt = self.config.parser.propose(
|
||||
current_state=current_state, **{"n_generate_sample": self.config.n_generate_sample}
|
||||
)
|
||||
rsp = await self.llm.aask(msg=state_prompt + "\n" + OUTPUT_FORMAT)
|
||||
thoughts = CodeParser.parse_code(block=None, text=rsp)
|
||||
thoughts = eval(thoughts)
|
||||
# fixme 避免不跟随,生成过多nodes
|
||||
# valid_thoughts = [_node for idx, _node in enumerate(thoughts) if idx < self.n_generate_sample]
|
||||
return self.thought_tree.update_node(thoughts, current_node=current_node)
|
||||
|
||||
async def evaluate_node(self, node, parent_value) -> None:
|
||||
"""
|
||||
Evaluate a node and update its status and value.
|
||||
|
||||
Args:
|
||||
node (ThoughtNode): The node to be evaluated.
|
||||
parent_value (float): The parent node's value.
|
||||
|
||||
Returns:
|
||||
None
|
||||
"""
|
||||
eval_prompt = self.config.parser.value(input=node.name, **{"node_id": node.id})
|
||||
evaluation = await self.llm.aask(msg=eval_prompt)
|
||||
|
||||
value = self.config.evaluator(evaluation, **{"node_id": node.id})
|
||||
status = self.config.evaluator.status_verify(value)
|
||||
|
||||
node.update_valid_status(status=status)
|
||||
# 累计分数
|
||||
node.update_value(parent_value + value)
|
||||
|
||||
def select_nodes(self, thought_nodes: List[ThoughtNode]) -> List[ThoughtNode]:
|
||||
"""
|
||||
Select nodes based on the configured selection method.
|
||||
|
||||
Args:
|
||||
thought_nodes (List[ThoughtNode]): List of nodes to be selected.
|
||||
|
||||
Returns:
|
||||
List[ThoughtNode]: List of selected nodes.
|
||||
"""
|
||||
# selection
|
||||
if self.config.method_select == MethodSelect.SAMPLE:
|
||||
raise NotImplementedError
|
||||
elif self.config.method_select == MethodSelect.GREEDY:
|
||||
select_nodes = sorted(thought_nodes, key=lambda x: x.value, reverse=True)[: self.config.n_select_sample]
|
||||
for node in thought_nodes:
|
||||
if node not in select_nodes:
|
||||
node.parent = None # 从树中删除节点
|
||||
return select_nodes
|
||||
|
||||
def update_solution(self):
|
||||
"""
|
||||
Select the result with the highest score.
|
||||
|
||||
Returns:
|
||||
- List[ThoughtNode]: List of nodes representing the best solution.
|
||||
- List[str]: List of node names forming the best solution path.
|
||||
"""
|
||||
best_node = max(self.thought_tree.all_nodes, key=lambda x: x.value, default=None)
|
||||
best_solution_path = self.thought_tree.parse_node_path(best_node)
|
||||
return [best_node], best_solution_path
|
||||
|
||||
|
||||
class BFSSolver(ThoughtSolverBase):
|
||||
async def solve(self, init_prompt=""):
|
||||
"""
|
||||
Solve the problem using Breadth-First Search (BFS) strategy.
|
||||
|
||||
Args:
|
||||
init_prompt (str): The initial prompt for the solver.
|
||||
|
||||
Returns:
|
||||
List[str]: The best solution path obtained through BFS.
|
||||
"""
|
||||
root = ThoughtNode(init_prompt)
|
||||
self.thought_tree = ThoughtTree(root)
|
||||
current_nodes = [root]
|
||||
for step in range(self.config.max_steps):
|
||||
solutions = await self._bfs_build(current_nodes)
|
||||
|
||||
selected_nodes = self.select_nodes(solutions)
|
||||
current_nodes = selected_nodes
|
||||
|
||||
self.thought_tree.show()
|
||||
|
||||
best_solution, best_solution_path = self.update_solution()
|
||||
logger.info(f"best solution is: {best_solution_path}")
|
||||
return best_solution_path
|
||||
|
||||
async def _bfs_build(self, current_nodes):
|
||||
"""
|
||||
Build the thought tree using Breadth-First Search (BFS) strategy.
|
||||
|
||||
Args:
|
||||
current_nodes (List[ThoughtNode]): Current nodes to expand.
|
||||
|
||||
Returns:
|
||||
List[ThoughtNode]: The solutions obtained after expanding the current nodes.
|
||||
"""
|
||||
tasks = []
|
||||
for node in current_nodes:
|
||||
current_state = self.config.parser(node.name)
|
||||
current_value = node.value
|
||||
tasks.append(self.generate_and_evaluate_nodes(current_state, current_value, node))
|
||||
|
||||
thought_nodes_list = await asyncio.gather(*tasks)
|
||||
solutions = [child_node for thought_nodes in thought_nodes_list for child_node in thought_nodes]
|
||||
return solutions
|
||||
|
||||
async def generate_and_evaluate_nodes(self, current_state, current_value, node):
|
||||
thought_nodes = await self.generate_thoughts(current_state, current_node=node)
|
||||
await asyncio.gather(
|
||||
*(self.evaluate_node(child_node, parent_value=current_value) for child_node in thought_nodes)
|
||||
)
|
||||
return thought_nodes
|
||||
|
||||
|
||||
class DFSSolver(ThoughtSolverBase):
|
||||
async def _dfs(self, root_node):
|
||||
"""
|
||||
Perform Depth-First Search (DFS) on the thought tree.
|
||||
|
||||
Args:
|
||||
root_node (ThoughtNode): The root node of the thought tree.
|
||||
|
||||
Returns:
|
||||
List[str]: The solution path obtained through DFS.
|
||||
"""
|
||||
impossible_state_cnt = 0
|
||||
node = root_node
|
||||
for step in range(self.max_steps):
|
||||
current_state = self.config.parser(node.name)
|
||||
current_value = node.value
|
||||
thought_nodes = await self.generate_thoughts(current_state, current_node=node)
|
||||
await self.evaluate_node(thought_nodes[0], parent_value=current_value)
|
||||
if thought_nodes[0].valid_status is False:
|
||||
impossible_state_cnt += 1
|
||||
if impossible_state_cnt >= 2:
|
||||
logger.info("impossible state reached, break")
|
||||
break
|
||||
node = thought_nodes[0]
|
||||
_solution_path = self.thought_tree.parse_node_path(node)
|
||||
self.thought_tree.show()
|
||||
|
||||
return _solution_path
|
||||
|
||||
async def solve(self, init_prompt="", root=ThoughtNode("")):
|
||||
"""
|
||||
Solve the problem using Depth-First Search (DFS) strategy.
|
||||
|
||||
Args:
|
||||
init_prompt (str): The initial prompt for the solver.
|
||||
|
||||
Returns:
|
||||
List[str]: The best solution path obtained through DFS.
|
||||
"""
|
||||
root = ThoughtNode(init_prompt)
|
||||
self.thought_tree = ThoughtTree(root)
|
||||
for n in range(self.config.n_solution_sample):
|
||||
# fixme: 需要产生回退,当前节点不可用时回退到父节点,产生新的节点继续探索
|
||||
await self._dfs(root)
|
||||
|
||||
best_solution, best_solution_path = self.update_solution()
|
||||
logger.info(f"best solution is: {best_solution_path}")
|
||||
return best_solution_path
|
||||
|
||||
|
||||
class MCTSSolver(ThoughtSolverBase):
|
||||
async def solve(self, init_prompt=""):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TreeofThought(BaseModel):
|
||||
config: ThoughtSolverConfig = Field(default_factory=ThoughtSolverConfig)
|
||||
solver: ThoughtSolverBase = Field(default_factory=ThoughtSolverBase)
|
||||
strategy: Strategy = Field(default=Strategy.BFS)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
self._initialize_solver(self.strategy)
|
||||
|
||||
def _initialize_solver(self, strategy):
|
||||
"""
|
||||
Initialize the solver based on the chosen strategy.
|
||||
|
||||
Args:
|
||||
strategy (Strategy): The strategy to use for solving.
|
||||
|
||||
Returns:
|
||||
ThoughtSolverBase: An instance of the appropriate solver.
|
||||
"""
|
||||
if strategy == Strategy.BFS:
|
||||
self.solver = BFSSolver(config=self.config)
|
||||
elif strategy == Strategy.DFS:
|
||||
self.solver = DFSSolver(config=self.config)
|
||||
elif strategy == Strategy.MCTS:
|
||||
self.solver = MCTSSolver(config=self.config)
|
||||
else:
|
||||
raise NotImplementedError(f"Invalid strategy: {strategy}, only support BFS/DFS/MCTS currently!")
|
||||
|
||||
async def solve(self, init_prompt=""):
|
||||
"""
|
||||
Solve the problem using the specified strategy.
|
||||
|
||||
Args:
|
||||
init_prompt (str): The initial prompt for the solver.
|
||||
strategy (str): The strategy to use for solving.
|
||||
|
||||
Returns:
|
||||
Any: The solution obtained using the selected strategy.
|
||||
"""
|
||||
await self.solver.solve(init_prompt)
|
||||
30
metagpt/strategy/tot_schema.py
Normal file
30
metagpt/strategy/tot_schema.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 12/25/2023 9:14 PM
|
||||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
from enum import Enum
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.strategy.base import BaseEvaluator, BaseParser
|
||||
|
||||
|
||||
class MethodSelect(Enum):
|
||||
SAMPLE = "sample"
|
||||
GREEDY = "greedy"
|
||||
|
||||
|
||||
class Strategy(Enum):
|
||||
BFS = "BFS"
|
||||
DFS = "DFS"
|
||||
MCTS = "MCTS"
|
||||
|
||||
|
||||
class ThoughtSolverConfig(BaseModel):
|
||||
max_steps: int = 3
|
||||
method_select: str = MethodSelect.GREEDY # ["sample"/"greedy"]
|
||||
n_generate_sample: int = 5 # per node
|
||||
n_select_sample: int = 3 # per path
|
||||
n_solution_sample: int = 5 # only for dfs
|
||||
parser: BaseParser = Field(default_factory=BaseParser)
|
||||
evaluator: BaseEvaluator = Field(default_factory=BaseEvaluator)
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
import asyncio
|
||||
from typing import AsyncGenerator, Awaitable, Callable
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from metagpt.logs import logger
|
||||
from metagpt.roles import Role
|
||||
|
|
@ -33,10 +33,9 @@ class SubscriptionRunner(BaseModel):
|
|||
>>> asyncio.run(main())
|
||||
"""
|
||||
|
||||
tasks: dict[Role, asyncio.Task] = Field(default_factory=dict)
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
tasks: dict[Role, asyncio.Task] = Field(default_factory=dict)
|
||||
|
||||
async def subscribe(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -10,8 +10,9 @@
|
|||
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from metagpt.actions import UserRequirement
|
||||
from metagpt.config import CONFIG
|
||||
|
|
@ -34,32 +35,27 @@ class Team(BaseModel):
|
|||
dedicated to env any multi-agent activity, such as collaboratively writing executable code.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
env: Environment = Field(default_factory=Environment)
|
||||
investment: float = Field(default=10.0)
|
||||
idea: str = Field(default="")
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
if "roles" in kwargs:
|
||||
self.hire(kwargs["roles"])
|
||||
if "env_desc" in kwargs:
|
||||
self.env.desc = kwargs["env_desc"]
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
def __init__(self, **data: Any):
|
||||
super(Team, self).__init__(**data)
|
||||
if "roles" in data:
|
||||
self.hire(data["roles"])
|
||||
if "env_desc" in data:
|
||||
self.env.desc = data["env_desc"]
|
||||
|
||||
def serialize(self, stg_path: Path = None):
|
||||
stg_path = SERDESER_PATH.joinpath("team") if stg_path is None else stg_path
|
||||
|
||||
team_info_path = stg_path.joinpath("team_info.json")
|
||||
write_json_file(team_info_path, self.dict(exclude={"env": True}))
|
||||
write_json_file(team_info_path, self.model_dump(exclude={"env": True}))
|
||||
|
||||
self.env.serialize(stg_path.joinpath("environment")) # save environment alone
|
||||
|
||||
@classmethod
|
||||
def recover(cls, stg_path: Path) -> "Team":
|
||||
return cls.deserialize(stg_path)
|
||||
|
||||
@classmethod
|
||||
def deserialize(cls, stg_path: Path) -> "Team":
|
||||
"""stg_path = ./storage/team"""
|
||||
|
|
@ -76,7 +72,6 @@ class Team(BaseModel):
|
|||
# recover environment
|
||||
environment = Environment.deserialize(stg_path=stg_path.joinpath("environment"))
|
||||
team_info.update({"env": environment})
|
||||
|
||||
team = Team(**team_info)
|
||||
return team
|
||||
|
||||
|
|
@ -121,7 +116,7 @@ class Team(BaseModel):
|
|||
return self.run_project(idea=idea, send_to=send_to)
|
||||
|
||||
def _save(self):
|
||||
logger.info(self.json(ensure_ascii=False))
|
||||
logger.info(self.model_dump_json())
|
||||
|
||||
@serialize_decorator
|
||||
async def run(self, n_round=3, idea="", send_to="", auto_archive=True):
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ from typing import Optional
|
|||
from urllib.parse import urlparse
|
||||
|
||||
import httplib2
|
||||
from pydantic import BaseModel, validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
from metagpt.logs import logger
|
||||
|
|
@ -25,15 +25,14 @@ except ImportError:
|
|||
|
||||
|
||||
class GoogleAPIWrapper(BaseModel):
|
||||
google_api_key: Optional[str] = None
|
||||
google_cse_id: Optional[str] = None
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
google_api_key: Optional[str] = Field(default=None, validate_default=True)
|
||||
google_cse_id: Optional[str] = Field(default=None, validate_default=True)
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
executor: Optional[futures.Executor] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@validator("google_api_key", always=True)
|
||||
@field_validator("google_api_key", mode="before")
|
||||
@classmethod
|
||||
def check_google_api_key(cls, val: str):
|
||||
val = val or CONFIG.google_api_key
|
||||
|
|
@ -45,7 +44,7 @@ class GoogleAPIWrapper(BaseModel):
|
|||
)
|
||||
return val
|
||||
|
||||
@validator("google_cse_id", always=True)
|
||||
@field_validator("google_cse_id", mode="before")
|
||||
@classmethod
|
||||
def check_google_cse_id(cls, val: str):
|
||||
val = val or CONFIG.google_cse_id
|
||||
|
|
|
|||
|
|
@ -8,13 +8,15 @@
|
|||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
|
||||
|
||||
class SerpAPIWrapper(BaseModel):
|
||||
search_engine: Any #: :meta private:
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
search_engine: Any = None #: :meta private:
|
||||
params: dict = Field(
|
||||
default={
|
||||
"engine": "google",
|
||||
|
|
@ -23,13 +25,11 @@ class SerpAPIWrapper(BaseModel):
|
|||
"hl": "en",
|
||||
}
|
||||
)
|
||||
serpapi_api_key: Optional[str] = None
|
||||
# should add `validate_default=True` to check with default value
|
||||
serpapi_api_key: Optional[str] = Field(default=None, validate_default=True)
|
||||
aiosession: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@validator("serpapi_api_key", always=True)
|
||||
@field_validator("serpapi_api_key", mode="before")
|
||||
@classmethod
|
||||
def check_serpapi_api_key(cls, val: str):
|
||||
val = val or CONFIG.serpapi_api_key
|
||||
|
|
|
|||
|
|
@ -9,21 +9,18 @@ import json
|
|||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import aiohttp
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from metagpt.config import CONFIG
|
||||
|
||||
|
||||
class SerperWrapper(BaseModel):
|
||||
search_engine: Any #: :meta private:
|
||||
search_engine: Any = None #: :meta private:
|
||||
payload: dict = Field(default={"page": 1, "num": 10})
|
||||
serper_api_key: Optional[str] = None
|
||||
serper_api_key: Optional[str] = Field(default=None, validate_default=True)
|
||||
aiosession: Optional[aiohttp.ClientSession] = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@validator("serper_api_key", always=True)
|
||||
@field_validator("serper_api_key", mode="before")
|
||||
@classmethod
|
||||
def check_serper_api_key(cls, val: str):
|
||||
val = val or CONFIG.serper_api_key
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ from typing import Any, Callable, List, Tuple, Union, get_args, get_origin
|
|||
|
||||
import aiofiles
|
||||
import loguru
|
||||
from pydantic.json import pydantic_encoder
|
||||
from pydantic_core import to_jsonable_python
|
||||
from tenacity import RetryCallState, _utils
|
||||
|
||||
from metagpt.const import MESSAGE_ROUTE_TO_ALL
|
||||
|
|
@ -472,7 +472,7 @@ def write_json_file(json_file: str, data: list, encoding=None):
|
|||
folder_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with open(json_file, "w", encoding=encoding) as fout:
|
||||
json.dump(data, fout, ensure_ascii=False, indent=4, default=pydantic_encoder)
|
||||
json.dump(data, fout, ensure_ascii=False, indent=4, default=to_jsonable_python)
|
||||
|
||||
|
||||
def import_class(class_name: str, module_name: str) -> type:
|
||||
|
|
@ -512,7 +512,7 @@ def role_raise_decorator(func):
|
|||
except KeyboardInterrupt as kbi:
|
||||
logger.error(f"KeyboardInterrupt: {kbi} occurs, start to serialize the project")
|
||||
if self.latest_observed_msg:
|
||||
self._rc.memory.delete(self.latest_observed_msg)
|
||||
self.rc.memory.delete(self.latest_observed_msg)
|
||||
# raise again to make it captured outside
|
||||
raise Exception(format_trackback_info(limit=None))
|
||||
except Exception:
|
||||
|
|
@ -522,7 +522,7 @@ def role_raise_decorator(func):
|
|||
"we delete the newest role communication message in the role's memory."
|
||||
)
|
||||
# remove role newest observed msg to make it observed again
|
||||
self._rc.memory.delete(self.latest_observed_msg)
|
||||
self.rc.memory.delete(self.latest_observed_msg)
|
||||
# raise again to make it captured outside
|
||||
raise Exception(format_trackback_info(limit=None))
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from typing import Generator, Optional
|
|||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
from bs4 import BeautifulSoup
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, PrivateAttr
|
||||
|
||||
|
||||
class WebPage(BaseModel):
|
||||
|
|
@ -13,11 +13,8 @@ class WebPage(BaseModel):
|
|||
html: str
|
||||
url: str
|
||||
|
||||
class Config:
|
||||
underscore_attrs_are_private = True
|
||||
|
||||
_soup: Optional[BeautifulSoup] = None
|
||||
_title: Optional[str] = None
|
||||
_soup: Optional[BeautifulSoup] = PrivateAttr(default=None)
|
||||
_title: Optional[str] = PrivateAttr(default=None)
|
||||
|
||||
@property
|
||||
def soup(self) -> BeautifulSoup:
|
||||
|
|
|
|||
|
|
@ -62,10 +62,10 @@ def serialize_message(message: "Message"):
|
|||
ic = message_cp.instruct_content
|
||||
if ic:
|
||||
# model create by pydantic create_model like `pydantic.main.prd`, can't pickle.dump directly
|
||||
schema = ic.schema()
|
||||
schema = ic.model_json_schema()
|
||||
mapping = actionoutout_schema_to_mapping(schema)
|
||||
|
||||
message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.dict()}
|
||||
message_cp.instruct_content = {"class": schema["title"], "mapping": mapping, "value": ic.model_dump()}
|
||||
msg_ser = pickle.dumps(message_cp)
|
||||
|
||||
return msg_ser
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue