From 86d497a0bd274d881b5d733e664527f98d702712 Mon Sep 17 00:00:00 2001 From: stellahsr Date: Thu, 28 Dec 2023 16:31:24 +0800 Subject: [PATCH] update docstring --- metagpt/strategy/base.py | 67 ++++++++++++++++++++++++++++------------ metagpt/strategy/tot.py | 61 ++++++++++++++++++------------------ 2 files changed, 77 insertions(+), 51 deletions(-) diff --git a/metagpt/strategy/base.py b/metagpt/strategy/base.py index fb2adc8f2..5b535ab12 100644 --- a/metagpt/strategy/base.py +++ b/metagpt/strategy/base.py @@ -4,21 +4,20 @@ # @Desc : from typing import List -from pydantic import BaseModel from anytree import Node, RenderTree - +from pydantic import BaseModel class BaseParser(BaseModel): def __call__(self, *args, **kwargs): raise NotImplementedError - + def propose(self, current_state: str, **kwargs) -> str: raise NotImplementedError - + def sample(self, current_state: str, **kwargs) -> str: raise NotImplementedError - + def value(self, input: str, **kwargs) -> str: raise NotImplementedError @@ -26,22 +25,23 @@ class BaseParser(BaseModel): class BaseEvaluator(BaseModel): def __call__(self, *args, **kwargs): raise NotImplementedError - + def status_verify(self, *args, **kwargs): raise NotImplementedError - + + class ThoughtNode(Node): """A node representing a thought in the thought tree.""" - + name: str = "" value: int = 0 id: int = 0 valid_status: bool = True - + def update_value(self, value) -> None: """Update the value of the thought node.""" self.value = value - + def update_valid_status(self, status) -> None: """Update the validity status of the thought node.""" self.valid_status = status @@ -49,33 +49,60 @@ class ThoughtNode(Node): class ThoughtTree(RenderTree): """A tree structure to represent thoughts.""" - + @property def all_nodes(self) -> List[ThoughtNode]: - """Get a list of all nodes in the thought tree.""" + """ + Get a list of all nodes in the thought tree. + + Returns: + List[ThoughtNode]: A list containing all nodes in the thought tree. + """ all_nodes = [node for _, _, node in self] return all_nodes - + def update_node(self, thought: List[dict] = [], current_node: ThoughtNode = None) -> List[ThoughtNode]: - """Update the tree with new thoughts.""" + """ + Update the tree with new thoughts. + + Args: + thought (List[dict]): A list of dictionaries representing thought information. + current_node (ThoughtNode): The current node under which new thoughts will be added. + + Returns: + List[ThoughtNode]: A list of ThoughtNode instances representing the updated tree nodes. + """ nodes = [] for node_info in thought: - node = ThoughtNode(name=node_info["node_state_instruction"], parent=current_node, - id=int(node_info["node_id"])) + node = ThoughtNode( + name=node_info["node_state_instruction"], parent=current_node, id=int(node_info["node_id"]) + ) nodes.append(node) return nodes - + def parse_node_path(self, node) -> List[str]: - """Parse the path of the given thought node.""" + """ + Parse and retrieve the hierarchical path of the given thought node. + + This method traverses the parent nodes of the provided 'node' and constructs + the full path from the root node to the given node. + + Args: + node: The thought node for which the hierarchical path needs to be parsed. + + Returns: + List[str]: A list representing the full hierarchical path of the given thought node. + The list is ordered from the root node to the provided node. + """ full_node_path = [] while node is not None: full_node_path.append(node.name) node = node.parent full_node_path.reverse() return full_node_path - + def show(self) -> None: """Print the updated tree.""" print("\nUpdated Tree:") for pre, _, node in self: - print(f"{pre}{node.name}, value: {node.value}, valid_status: {node.valid_status}") \ No newline at end of file + print(f"{pre}{node.name}, value: {node.value}, valid_status: {node.valid_status}") diff --git a/metagpt/strategy/tot.py b/metagpt/strategy/tot.py index 8f4d129d8..7f080fa69 100644 --- a/metagpt/strategy/tot.py +++ b/metagpt/strategy/tot.py @@ -3,18 +3,16 @@ # @Author : stellahong (stellahong@fuzhi.ai) # @Desc : import asyncio -import json from typing import Any, List -from functools import wraps from pydantic import BaseModel, Field from metagpt.llm import LLM -from metagpt.provider.base_gpt_api import BaseGPTAPI from metagpt.logs import logger +from metagpt.provider.base_gpt_api import BaseGPTAPI +from metagpt.strategy.base import ThoughtNode, ThoughtTree +from metagpt.strategy.tot_schema import MethodSelect, Strategy, ThoughtSolverConfig from metagpt.utils.common import CodeParser -from metagpt.strategy.tot_schema import ThoughtSolverConfig, Strategy, MethodSelect -from metagpt.strategy.base import ThoughtNode, ThoughtTree, BaseParser, BaseEvaluator OUTPUT_FORMAT = """ Output a list of jsons following the format: @@ -34,17 +32,17 @@ class ThoughtSolverBase(BaseModel): thought_tree: str = "" llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True) config: ThoughtSolverConfig = Field(default_factory=ThoughtSolverConfig) - + def __init__(self, **kwargs: Any): super().__init__(**kwargs) self.llm.use_system_prompt = False - + async def solve(self, init_prompt): """ Solve method for subclasses to implement. """ raise NotImplementedError("Subclasses must implement the solve method") - + async def generate_thoughts(self, current_state="", current_node=None) -> List[ThoughtNode]: """ Generate children thoughts based on the current state. @@ -56,15 +54,16 @@ class ThoughtSolverBase(BaseModel): Returns: List[ThoughtNode]: List of nodes representing the generated thoughts. """ - state_prompt = self.config.parser.propose(current_state=current_state, - **{"n_generate_sample": self.config.n_generate_sample}) + state_prompt = self.config.parser.propose( + current_state=current_state, **{"n_generate_sample": self.config.n_generate_sample} + ) rsp = await self.llm.aask(msg=state_prompt + "\n" + OUTPUT_FORMAT) thoughts = CodeParser.parse_code(block=None, text=rsp) thoughts = eval(thoughts) # fixme 避免不跟随,生成过多nodes # valid_thoughts = [_node for idx, _node in enumerate(thoughts) if idx < self.n_generate_sample] return self.thought_tree.update_node(thoughts, current_node=current_node) - + async def evaluate_node(self, node, parent_value) -> None: """ Evaluate a node and update its status and value. @@ -78,14 +77,14 @@ class ThoughtSolverBase(BaseModel): """ eval_prompt = self.config.parser.value(input=node.name, **{"node_id": node.id}) evaluation = await self.llm.aask(msg=eval_prompt) - + value = self.config.evaluator(evaluation, **{"node_id": node.id}) status = self.config.evaluator.status_verify(value) - + node.update_valid_status(status=status) # 累计分数 node.update_value(parent_value + value) - + def select_nodes(self, thought_nodes: List[ThoughtNode]) -> List[ThoughtNode]: """ Select nodes based on the configured selection method. @@ -100,12 +99,12 @@ class ThoughtSolverBase(BaseModel): if self.config.method_select == MethodSelect.SAMPLE: raise NotImplementedError elif self.config.method_select == MethodSelect.GREEDY: - select_nodes = sorted(thought_nodes, key=lambda x: x.value, reverse=True)[:self.config.n_select_sample] + select_nodes = sorted(thought_nodes, key=lambda x: x.value, reverse=True)[: self.config.n_select_sample] for node in thought_nodes: if node not in select_nodes: node.parent = None # 从树中删除节点 return select_nodes - + def update_solution(self): """ Select the result with the highest score. @@ -135,16 +134,16 @@ class BFSSolver(ThoughtSolverBase): current_nodes = [root] for step in range(self.config.max_steps): solutions = await self._bfs_build(current_nodes) - + selected_nodes = self.select_nodes(solutions) current_nodes = selected_nodes - + self.thought_tree.show() - + best_solution, best_solution_path = self.update_solution() logger.info(f"best solution is: {best_solution_path}") return best_solution_path - + async def _bfs_build(self, current_nodes): """ Build the thought tree using Breadth-First Search (BFS) strategy. @@ -160,15 +159,16 @@ class BFSSolver(ThoughtSolverBase): current_state = self.config.parser(node.name) current_value = node.value tasks.append(self.generate_and_evaluate_nodes(current_state, current_value, node)) - + thought_nodes_list = await asyncio.gather(*tasks) solutions = [child_node for thought_nodes in thought_nodes_list for child_node in thought_nodes] return solutions - + async def generate_and_evaluate_nodes(self, current_state, current_value, node): thought_nodes = await self.generate_thoughts(current_state, current_node=node) await asyncio.gather( - *(self.evaluate_node(child_node, parent_value=current_value) for child_node in thought_nodes)) + *(self.evaluate_node(child_node, parent_value=current_value) for child_node in thought_nodes) + ) return thought_nodes @@ -186,7 +186,6 @@ class DFSSolver(ThoughtSolverBase): impossible_state_cnt = 0 node = root_node for step in range(self.max_steps): - current_state = self.config.parser(node.name) current_value = node.value thought_nodes = await self.generate_thoughts(current_state, current_node=node) @@ -199,9 +198,9 @@ class DFSSolver(ThoughtSolverBase): node = thought_nodes[0] _solution_path = self.thought_tree.parse_node_path(node) self.thought_tree.show() - + return _solution_path - + async def solve(self, init_prompt="", root=ThoughtNode("")): """ Solve the problem using Depth-First Search (DFS) strategy. @@ -217,7 +216,7 @@ class DFSSolver(ThoughtSolverBase): for n in range(self.config.n_solution_sample): # fixme: 需要产生回退,当前节点不可用时回退到父节点,产生新的节点继续探索 await self._dfs(root) - + best_solution, best_solution_path = self.update_solution() logger.info(f"best solution is: {best_solution_path}") return best_solution_path @@ -232,14 +231,14 @@ class TreeofThought(BaseModel): config: ThoughtSolverConfig = Field(default_factory=ThoughtSolverConfig) solver: ThoughtSolverBase = Field(default_factory=ThoughtSolverBase) strategy: Strategy = Field(default=Strategy.BFS) - + class Config: arbitrary_types_allowed = True - + def __init__(self, **kwargs: Any): super().__init__(**kwargs) self._initialize_solver(self.strategy) - + def _initialize_solver(self, strategy): """ Initialize the solver based on the chosen strategy. @@ -258,7 +257,7 @@ class TreeofThought(BaseModel): self.solver = MCTSSolver(config=self.config) else: raise NotImplementedError(f"Invalid strategy: {strategy}, only support BFS/DFS/MCTS currently!") - + async def solve(self, init_prompt=""): """ Solve the problem using the specified strategy.