diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index ca41c76a5..162ab90eb 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -12,7 +12,7 @@ import json from enum import Enum from typing import Any, Dict, List, Optional, Tuple, Type, Union -from pydantic import BaseModel, create_model, model_validator +from pydantic import BaseModel, Field, create_model, model_validator from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action_outcls_registry import register_action_outcls @@ -186,11 +186,27 @@ class ActionNode: obj.add_children(nodes) return obj - def get_children_mapping(self, exclude=None) -> Dict[str, Tuple[Type, Any]]: + def get_children_mapping_old(self, exclude=None) -> Dict[str, Tuple[Type, Any]]: """获得子ActionNode的字典,以key索引""" exclude = exclude or [] return {k: (v.expected_type, ...) for k, v in self.children.items() if k not in exclude} + def get_children_mapping(self, exclude=None) -> Dict[str, Tuple[Type, Any]]: + """获得子ActionNode的字典,以key索引,支持多级结构""" + exclude = exclude or [] + mapping = {} + + def _get_mapping(node: "ActionNode", prefix: str = ""): + for key, child in node.children.items(): + if key in exclude: + continue + full_key = f"{prefix}{key}" + mapping[full_key] = (child.expected_type, ...) + _get_mapping(child, prefix=f"{full_key}.") + + _get_mapping(self) + return mapping + def get_self_mapping(self) -> Dict[str, Tuple[Type, Any]]: """get self key: type mapping""" return {self.key: (self.expected_type, ...)} @@ -616,3 +632,62 @@ class ActionNode: self.update_instruct_content(revise_contents) return revise_contents + + @classmethod + def from_pydantic(cls, model: Type[BaseModel], key: str = None): + """ + Creates an ActionNode tree from a Pydantic model. + + Args: + model (Type[BaseModel]): The Pydantic model to convert. + + Returns: + ActionNode: The root node of the created ActionNode tree. + """ + key = key or model.__name__ + root_node = cls(key=model.__name__, expected_type=Type[model], instruction="", example="") + + for field_name, field_model in model.model_fields.items(): + # Extracting field details + expected_type = field_model.annotation + instruction = field_model.description or "" + example = field_model.default + + # Check if the field is a Pydantic model itself. + # Use isinstance to avoid typing.List, typing.Dict, etc. (they are instances of type, not subclasses) + if isinstance(expected_type, type) and issubclass(expected_type, BaseModel): + # Recursively process the nested model + child_node = cls.from_pydantic(expected_type, key=field_name) + else: + child_node = cls(key=field_name, expected_type=expected_type, instruction=instruction, example=example) + + root_node.add_child(child_node) + + return root_node + + +class ToolUse(BaseModel): + tool_name: str = Field(default="a", description="tool name", examples=[]) + + +class Task(BaseModel): + task_id: int = Field(default="1", description="task id", examples=[1, 2, 3]) + name: str = Field(default="Get data from ...", description="task name", examples=[]) + dependent_task_ids: List[int] = Field(default=[], description="dependent task ids", examples=[1, 2, 3]) + tool: ToolUse = Field(default=ToolUse(), description="tool use", examples=[]) + + +class Tasks(BaseModel): + tasks: List[Task] = Field(default=[], description="tasks", examples=[]) + + +if __name__ == "__main__": + node = ActionNode.from_pydantic(Tasks) + print("Tasks") + print(Tasks.model_json_schema()) + print("Task") + print(Task.model_json_schema()) + print(node) + prompt = node.compile(context="") + node.create_children_class() + print(prompt)