mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-25 00:36:55 +02:00
refactor action_output and action_node
This commit is contained in:
parent
a75ab7971f
commit
d159bfc4e1
7 changed files with 18 additions and 39 deletions
|
|
@ -6,7 +6,7 @@
|
|||
@File : action_node.py
|
||||
"""
|
||||
import json
|
||||
from typing import Dict, Generic, List, Optional, Type, TypeVar
|
||||
from typing import Any, Dict, Generic, List, Optional, Tuple, Type, TypeVar
|
||||
|
||||
from pydantic import BaseModel, create_model, root_validator, validator
|
||||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
|
@ -127,7 +127,7 @@ class ActionNode(Generic[T]):
|
|||
return self.get_self_mapping()
|
||||
|
||||
@classmethod
|
||||
def create_model_class(cls, class_name: str, mapping: Dict[str, Type]):
|
||||
def create_model_class(cls, class_name: str, mapping: Dict[str, Tuple[Type, Any]]):
|
||||
"""基于pydantic v1的模型动态生成,用来检验结果类型正确性"""
|
||||
new_class = create_model(class_name, **mapping)
|
||||
|
||||
|
|
|
|||
|
|
@ -6,9 +6,7 @@
|
|||
@File : action_output
|
||||
"""
|
||||
|
||||
from typing import Dict, Type
|
||||
|
||||
from pydantic import BaseModel, create_model, root_validator, validator
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class ActionOutput:
|
||||
|
|
@ -18,25 +16,3 @@ class ActionOutput:
|
|||
def __init__(self, content: str, instruct_content: BaseModel):
|
||||
self.content = content
|
||||
self.instruct_content = instruct_content
|
||||
|
||||
@classmethod
|
||||
def create_model_class(cls, class_name: str, mapping: Dict[str, Type]):
|
||||
new_class = create_model(class_name, **mapping)
|
||||
|
||||
@validator("*", allow_reuse=True)
|
||||
def check_name(v, field):
|
||||
if field.name not in mapping.keys():
|
||||
raise ValueError(f"Unrecognized block: {field.name}")
|
||||
return v
|
||||
|
||||
@root_validator(pre=True, allow_reuse=True)
|
||||
def check_missing_fields(values):
|
||||
required_fields = set(mapping.keys())
|
||||
missing_fields = required_fields - set(values.keys())
|
||||
if missing_fields:
|
||||
raise ValueError(f"Missing fields: {missing_fields}")
|
||||
return values
|
||||
|
||||
new_class.__validator_check_name = classmethod(check_name)
|
||||
new_class.__root_validator_check_missing_fields = classmethod(check_missing_fields)
|
||||
return new_class
|
||||
|
|
|
|||
|
|
@ -69,7 +69,7 @@ class WritePRD(Action):
|
|||
content: Optional[str] = None
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM)
|
||||
|
||||
async def run(self, with_messages, format=CONFIG.prompt_format, *args, **kwargs) -> ActionOutput | Message:
|
||||
async def run(self, with_messages, schema=CONFIG.prompt_schema, *args, **kwargs) -> ActionOutput | Message:
|
||||
# Determine which requirement documents need to be rewritten: Use LLM to assess whether new requirements are
|
||||
# related to the PRD. If they are related, rewrite the PRD.
|
||||
docs_file_repo = CONFIG.git_repo.new_file_repository(relative_path=DOCS_FILE_REPO)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,12 @@
|
|||
import copy
|
||||
import pickle
|
||||
|
||||
<<<<<<< HEAD
|
||||
from metagpt.utils.common import import_class
|
||||
=======
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.schema import Message
|
||||
>>>>>>> 09e2f05 (refactor action_output and action_node)
|
||||
|
||||
|
||||
def actionoutout_schema_to_mapping(schema: dict) -> dict:
|
||||
|
|
@ -104,13 +109,11 @@ def deserialize_general_message(message_dict: dict) -> "Message":
|
|||
return message
|
||||
|
||||
|
||||
def deserialize_message(message_ser: str) -> "Message":
|
||||
def deserialize_message(message_ser: str) -> Message:
|
||||
message = pickle.loads(message_ser)
|
||||
if message.instruct_content:
|
||||
ic = message.instruct_content
|
||||
|
||||
actionoutput_class = import_class("ActionOutput", "metagpt.actions.action_output")
|
||||
ic_obj = actionoutput_class.create_model_class(class_name=ic["class"], mapping=ic["mapping"])
|
||||
ic_obj = ActionNode.create_model_class(class_name=ic["class"], mapping=ic["mapping"])
|
||||
ic_new = ic_obj(**ic["value"])
|
||||
message.instruct_content = ic_new
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
"""
|
||||
from typing import List, Tuple
|
||||
|
||||
from metagpt.actions import ActionOutput
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
|
||||
t_dict = {
|
||||
"Required Python third-party packages": '"""\nflask==1.1.2\npygame==2.0.1\n"""\n',
|
||||
|
|
@ -37,12 +37,12 @@ WRITE_TASKS_OUTPUT_MAPPING = {
|
|||
|
||||
|
||||
def test_create_model_class():
|
||||
test_class = ActionOutput.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING)
|
||||
test_class = ActionNode.create_model_class("test_class", WRITE_TASKS_OUTPUT_MAPPING)
|
||||
assert test_class.__name__ == "test_class"
|
||||
|
||||
|
||||
def test_create_model_class_with_mapping():
|
||||
t = ActionOutput.create_model_class("test_class_1", WRITE_TASKS_OUTPUT_MAPPING)
|
||||
t = ActionNode.create_model_class("test_class_1", WRITE_TASKS_OUTPUT_MAPPING)
|
||||
t1 = t(**t_dict)
|
||||
value = t1.dict()["Task list"]
|
||||
assert value == ["game.py", "app.py", "static/css/styles.css", "static/js/script.js", "templates/index.html"]
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@
|
|||
from typing import List
|
||||
|
||||
from metagpt.actions import UserRequirement, WritePRD
|
||||
from metagpt.actions.action_output import ActionOutput
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.memory.memory_storage import MemoryStorage
|
||||
from metagpt.schema import Message
|
||||
|
||||
|
|
@ -42,7 +42,7 @@ def test_idea_message():
|
|||
def test_actionout_message():
|
||||
out_mapping = {"field1": (str, ...), "field2": (List[str], ...)}
|
||||
out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]}
|
||||
ic_obj = ActionOutput.create_model_class("prd", out_mapping)
|
||||
ic_obj = ActionNode.create_model_class("prd", out_mapping)
|
||||
|
||||
role_id = "UTUser2(Architect)"
|
||||
content = "The user has requested the creation of a command-line interface (CLI) snake game"
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@
|
|||
from typing import List, Tuple
|
||||
|
||||
from metagpt.actions import WritePRD
|
||||
from metagpt.actions.action_output import ActionOutput
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.schema import Message
|
||||
from metagpt.utils.serialize import (
|
||||
actionoutout_schema_to_mapping,
|
||||
|
|
@ -54,7 +54,7 @@ def test_actionoutout_schema_to_mapping():
|
|||
def test_serialize_and_deserialize_message():
|
||||
out_mapping = {"field1": (str, ...), "field2": (List[str], ...)}
|
||||
out_data = {"field1": "field1 value", "field2": ["field2 value1", "field2 value2"]}
|
||||
ic_obj = ActionOutput.create_model_class("prd", out_mapping)
|
||||
ic_obj = ActionNode.create_model_class("prd", out_mapping)
|
||||
|
||||
message = Message(
|
||||
content="prd demand", instruct_content=ic_obj(**out_data), role="user", cause_by=WritePRD
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue