mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-14 15:25:17 +02:00
update docstring
This commit is contained in:
parent
e94ccbf631
commit
86d497a0bd
2 changed files with 77 additions and 51 deletions
|
|
@ -4,21 +4,20 @@
|
|||
# @Desc :
|
||||
from typing import List
|
||||
|
||||
from pydantic import BaseModel
|
||||
from anytree import Node, RenderTree
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class BaseParser(BaseModel):
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def propose(self, current_state: str, **kwargs) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def sample(self, current_state: str, **kwargs) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def value(self, input: str, **kwargs) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
@ -26,22 +25,23 @@ class BaseParser(BaseModel):
|
|||
class BaseEvaluator(BaseModel):
|
||||
def __call__(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def status_verify(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
||||
class ThoughtNode(Node):
|
||||
"""A node representing a thought in the thought tree."""
|
||||
|
||||
|
||||
name: str = ""
|
||||
value: int = 0
|
||||
id: int = 0
|
||||
valid_status: bool = True
|
||||
|
||||
|
||||
def update_value(self, value) -> None:
|
||||
"""Update the value of the thought node."""
|
||||
self.value = value
|
||||
|
||||
|
||||
def update_valid_status(self, status) -> None:
|
||||
"""Update the validity status of the thought node."""
|
||||
self.valid_status = status
|
||||
|
|
@ -49,33 +49,60 @@ class ThoughtNode(Node):
|
|||
|
||||
class ThoughtTree(RenderTree):
|
||||
"""A tree structure to represent thoughts."""
|
||||
|
||||
|
||||
@property
|
||||
def all_nodes(self) -> List[ThoughtNode]:
|
||||
"""Get a list of all nodes in the thought tree."""
|
||||
"""
|
||||
Get a list of all nodes in the thought tree.
|
||||
|
||||
Returns:
|
||||
List[ThoughtNode]: A list containing all nodes in the thought tree.
|
||||
"""
|
||||
all_nodes = [node for _, _, node in self]
|
||||
return all_nodes
|
||||
|
||||
|
||||
def update_node(self, thought: List[dict] = [], current_node: ThoughtNode = None) -> List[ThoughtNode]:
|
||||
"""Update the tree with new thoughts."""
|
||||
"""
|
||||
Update the tree with new thoughts.
|
||||
|
||||
Args:
|
||||
thought (List[dict]): A list of dictionaries representing thought information.
|
||||
current_node (ThoughtNode): The current node under which new thoughts will be added.
|
||||
|
||||
Returns:
|
||||
List[ThoughtNode]: A list of ThoughtNode instances representing the updated tree nodes.
|
||||
"""
|
||||
nodes = []
|
||||
for node_info in thought:
|
||||
node = ThoughtNode(name=node_info["node_state_instruction"], parent=current_node,
|
||||
id=int(node_info["node_id"]))
|
||||
node = ThoughtNode(
|
||||
name=node_info["node_state_instruction"], parent=current_node, id=int(node_info["node_id"])
|
||||
)
|
||||
nodes.append(node)
|
||||
return nodes
|
||||
|
||||
|
||||
def parse_node_path(self, node) -> List[str]:
|
||||
"""Parse the path of the given thought node."""
|
||||
"""
|
||||
Parse and retrieve the hierarchical path of the given thought node.
|
||||
|
||||
This method traverses the parent nodes of the provided 'node' and constructs
|
||||
the full path from the root node to the given node.
|
||||
|
||||
Args:
|
||||
node: The thought node for which the hierarchical path needs to be parsed.
|
||||
|
||||
Returns:
|
||||
List[str]: A list representing the full hierarchical path of the given thought node.
|
||||
The list is ordered from the root node to the provided node.
|
||||
"""
|
||||
full_node_path = []
|
||||
while node is not None:
|
||||
full_node_path.append(node.name)
|
||||
node = node.parent
|
||||
full_node_path.reverse()
|
||||
return full_node_path
|
||||
|
||||
|
||||
def show(self) -> None:
|
||||
"""Print the updated tree."""
|
||||
print("\nUpdated Tree:")
|
||||
for pre, _, node in self:
|
||||
print(f"{pre}{node.name}, value: {node.value}, valid_status: {node.valid_status}")
|
||||
print(f"{pre}{node.name}, value: {node.value}, valid_status: {node.valid_status}")
|
||||
|
|
|
|||
|
|
@ -3,18 +3,16 @@
|
|||
# @Author : stellahong (stellahong@fuzhi.ai)
|
||||
# @Desc :
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Any, List
|
||||
from functools import wraps
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.base_gpt_api import BaseGPTAPI
|
||||
from metagpt.strategy.base import ThoughtNode, ThoughtTree
|
||||
from metagpt.strategy.tot_schema import MethodSelect, Strategy, ThoughtSolverConfig
|
||||
from metagpt.utils.common import CodeParser
|
||||
from metagpt.strategy.tot_schema import ThoughtSolverConfig, Strategy, MethodSelect
|
||||
from metagpt.strategy.base import ThoughtNode, ThoughtTree, BaseParser, BaseEvaluator
|
||||
|
||||
OUTPUT_FORMAT = """
|
||||
Output a list of jsons following the format:
|
||||
|
|
@ -34,17 +32,17 @@ class ThoughtSolverBase(BaseModel):
|
|||
thought_tree: str = ""
|
||||
llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True)
|
||||
config: ThoughtSolverConfig = Field(default_factory=ThoughtSolverConfig)
|
||||
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
self.llm.use_system_prompt = False
|
||||
|
||||
|
||||
async def solve(self, init_prompt):
|
||||
"""
|
||||
Solve method for subclasses to implement.
|
||||
"""
|
||||
raise NotImplementedError("Subclasses must implement the solve method")
|
||||
|
||||
|
||||
async def generate_thoughts(self, current_state="", current_node=None) -> List[ThoughtNode]:
|
||||
"""
|
||||
Generate children thoughts based on the current state.
|
||||
|
|
@ -56,15 +54,16 @@ class ThoughtSolverBase(BaseModel):
|
|||
Returns:
|
||||
List[ThoughtNode]: List of nodes representing the generated thoughts.
|
||||
"""
|
||||
state_prompt = self.config.parser.propose(current_state=current_state,
|
||||
**{"n_generate_sample": self.config.n_generate_sample})
|
||||
state_prompt = self.config.parser.propose(
|
||||
current_state=current_state, **{"n_generate_sample": self.config.n_generate_sample}
|
||||
)
|
||||
rsp = await self.llm.aask(msg=state_prompt + "\n" + OUTPUT_FORMAT)
|
||||
thoughts = CodeParser.parse_code(block=None, text=rsp)
|
||||
thoughts = eval(thoughts)
|
||||
# fixme 避免不跟随,生成过多nodes
|
||||
# valid_thoughts = [_node for idx, _node in enumerate(thoughts) if idx < self.n_generate_sample]
|
||||
return self.thought_tree.update_node(thoughts, current_node=current_node)
|
||||
|
||||
|
||||
async def evaluate_node(self, node, parent_value) -> None:
|
||||
"""
|
||||
Evaluate a node and update its status and value.
|
||||
|
|
@ -78,14 +77,14 @@ class ThoughtSolverBase(BaseModel):
|
|||
"""
|
||||
eval_prompt = self.config.parser.value(input=node.name, **{"node_id": node.id})
|
||||
evaluation = await self.llm.aask(msg=eval_prompt)
|
||||
|
||||
|
||||
value = self.config.evaluator(evaluation, **{"node_id": node.id})
|
||||
status = self.config.evaluator.status_verify(value)
|
||||
|
||||
|
||||
node.update_valid_status(status=status)
|
||||
# 累计分数
|
||||
node.update_value(parent_value + value)
|
||||
|
||||
|
||||
def select_nodes(self, thought_nodes: List[ThoughtNode]) -> List[ThoughtNode]:
|
||||
"""
|
||||
Select nodes based on the configured selection method.
|
||||
|
|
@ -100,12 +99,12 @@ class ThoughtSolverBase(BaseModel):
|
|||
if self.config.method_select == MethodSelect.SAMPLE:
|
||||
raise NotImplementedError
|
||||
elif self.config.method_select == MethodSelect.GREEDY:
|
||||
select_nodes = sorted(thought_nodes, key=lambda x: x.value, reverse=True)[:self.config.n_select_sample]
|
||||
select_nodes = sorted(thought_nodes, key=lambda x: x.value, reverse=True)[: self.config.n_select_sample]
|
||||
for node in thought_nodes:
|
||||
if node not in select_nodes:
|
||||
node.parent = None # 从树中删除节点
|
||||
return select_nodes
|
||||
|
||||
|
||||
def update_solution(self):
|
||||
"""
|
||||
Select the result with the highest score.
|
||||
|
|
@ -135,16 +134,16 @@ class BFSSolver(ThoughtSolverBase):
|
|||
current_nodes = [root]
|
||||
for step in range(self.config.max_steps):
|
||||
solutions = await self._bfs_build(current_nodes)
|
||||
|
||||
|
||||
selected_nodes = self.select_nodes(solutions)
|
||||
current_nodes = selected_nodes
|
||||
|
||||
|
||||
self.thought_tree.show()
|
||||
|
||||
|
||||
best_solution, best_solution_path = self.update_solution()
|
||||
logger.info(f"best solution is: {best_solution_path}")
|
||||
return best_solution_path
|
||||
|
||||
|
||||
async def _bfs_build(self, current_nodes):
|
||||
"""
|
||||
Build the thought tree using Breadth-First Search (BFS) strategy.
|
||||
|
|
@ -160,15 +159,16 @@ class BFSSolver(ThoughtSolverBase):
|
|||
current_state = self.config.parser(node.name)
|
||||
current_value = node.value
|
||||
tasks.append(self.generate_and_evaluate_nodes(current_state, current_value, node))
|
||||
|
||||
|
||||
thought_nodes_list = await asyncio.gather(*tasks)
|
||||
solutions = [child_node for thought_nodes in thought_nodes_list for child_node in thought_nodes]
|
||||
return solutions
|
||||
|
||||
|
||||
async def generate_and_evaluate_nodes(self, current_state, current_value, node):
|
||||
thought_nodes = await self.generate_thoughts(current_state, current_node=node)
|
||||
await asyncio.gather(
|
||||
*(self.evaluate_node(child_node, parent_value=current_value) for child_node in thought_nodes))
|
||||
*(self.evaluate_node(child_node, parent_value=current_value) for child_node in thought_nodes)
|
||||
)
|
||||
return thought_nodes
|
||||
|
||||
|
||||
|
|
@ -186,7 +186,6 @@ class DFSSolver(ThoughtSolverBase):
|
|||
impossible_state_cnt = 0
|
||||
node = root_node
|
||||
for step in range(self.max_steps):
|
||||
|
||||
current_state = self.config.parser(node.name)
|
||||
current_value = node.value
|
||||
thought_nodes = await self.generate_thoughts(current_state, current_node=node)
|
||||
|
|
@ -199,9 +198,9 @@ class DFSSolver(ThoughtSolverBase):
|
|||
node = thought_nodes[0]
|
||||
_solution_path = self.thought_tree.parse_node_path(node)
|
||||
self.thought_tree.show()
|
||||
|
||||
|
||||
return _solution_path
|
||||
|
||||
|
||||
async def solve(self, init_prompt="", root=ThoughtNode("")):
|
||||
"""
|
||||
Solve the problem using Depth-First Search (DFS) strategy.
|
||||
|
|
@ -217,7 +216,7 @@ class DFSSolver(ThoughtSolverBase):
|
|||
for n in range(self.config.n_solution_sample):
|
||||
# fixme: 需要产生回退,当前节点不可用时回退到父节点,产生新的节点继续探索
|
||||
await self._dfs(root)
|
||||
|
||||
|
||||
best_solution, best_solution_path = self.update_solution()
|
||||
logger.info(f"best solution is: {best_solution_path}")
|
||||
return best_solution_path
|
||||
|
|
@ -232,14 +231,14 @@ class TreeofThought(BaseModel):
|
|||
config: ThoughtSolverConfig = Field(default_factory=ThoughtSolverConfig)
|
||||
solver: ThoughtSolverBase = Field(default_factory=ThoughtSolverBase)
|
||||
strategy: Strategy = Field(default=Strategy.BFS)
|
||||
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(**kwargs)
|
||||
self._initialize_solver(self.strategy)
|
||||
|
||||
|
||||
def _initialize_solver(self, strategy):
|
||||
"""
|
||||
Initialize the solver based on the chosen strategy.
|
||||
|
|
@ -258,7 +257,7 @@ class TreeofThought(BaseModel):
|
|||
self.solver = MCTSSolver(config=self.config)
|
||||
else:
|
||||
raise NotImplementedError(f"Invalid strategy: {strategy}, only support BFS/DFS/MCTS currently!")
|
||||
|
||||
|
||||
async def solve(self, init_prompt=""):
|
||||
"""
|
||||
Solve the problem using the specified strategy.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue