update docstring

This commit is contained in:
stellahsr 2023-12-28 16:31:24 +08:00
parent e94ccbf631
commit 86d497a0bd
2 changed files with 77 additions and 51 deletions

View file

@ -4,21 +4,20 @@
# @Desc :
from typing import List
from pydantic import BaseModel
from anytree import Node, RenderTree
from pydantic import BaseModel
class BaseParser(BaseModel):
def __call__(self, *args, **kwargs):
raise NotImplementedError
def propose(self, current_state: str, **kwargs) -> str:
raise NotImplementedError
def sample(self, current_state: str, **kwargs) -> str:
raise NotImplementedError
def value(self, input: str, **kwargs) -> str:
raise NotImplementedError
@ -26,22 +25,23 @@ class BaseParser(BaseModel):
class BaseEvaluator(BaseModel):
def __call__(self, *args, **kwargs):
raise NotImplementedError
def status_verify(self, *args, **kwargs):
raise NotImplementedError
class ThoughtNode(Node):
"""A node representing a thought in the thought tree."""
name: str = ""
value: int = 0
id: int = 0
valid_status: bool = True
def update_value(self, value) -> None:
"""Update the value of the thought node."""
self.value = value
def update_valid_status(self, status) -> None:
"""Update the validity status of the thought node."""
self.valid_status = status
@ -49,33 +49,60 @@ class ThoughtNode(Node):
class ThoughtTree(RenderTree):
"""A tree structure to represent thoughts."""
@property
def all_nodes(self) -> List[ThoughtNode]:
"""Get a list of all nodes in the thought tree."""
"""
Get a list of all nodes in the thought tree.
Returns:
List[ThoughtNode]: A list containing all nodes in the thought tree.
"""
all_nodes = [node for _, _, node in self]
return all_nodes
def update_node(self, thought: List[dict] = [], current_node: ThoughtNode = None) -> List[ThoughtNode]:
"""Update the tree with new thoughts."""
"""
Update the tree with new thoughts.
Args:
thought (List[dict]): A list of dictionaries representing thought information.
current_node (ThoughtNode): The current node under which new thoughts will be added.
Returns:
List[ThoughtNode]: A list of ThoughtNode instances representing the updated tree nodes.
"""
nodes = []
for node_info in thought:
node = ThoughtNode(name=node_info["node_state_instruction"], parent=current_node,
id=int(node_info["node_id"]))
node = ThoughtNode(
name=node_info["node_state_instruction"], parent=current_node, id=int(node_info["node_id"])
)
nodes.append(node)
return nodes
def parse_node_path(self, node) -> List[str]:
"""Parse the path of the given thought node."""
"""
Parse and retrieve the hierarchical path of the given thought node.
This method traverses the parent nodes of the provided 'node' and constructs
the full path from the root node to the given node.
Args:
node: The thought node for which the hierarchical path needs to be parsed.
Returns:
List[str]: A list representing the full hierarchical path of the given thought node.
The list is ordered from the root node to the provided node.
"""
full_node_path = []
while node is not None:
full_node_path.append(node.name)
node = node.parent
full_node_path.reverse()
return full_node_path
def show(self) -> None:
"""Print the updated tree."""
print("\nUpdated Tree:")
for pre, _, node in self:
print(f"{pre}{node.name}, value: {node.value}, valid_status: {node.valid_status}")
print(f"{pre}{node.name}, value: {node.value}, valid_status: {node.valid_status}")

View file

@ -3,18 +3,16 @@
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
import asyncio
import json
from typing import Any, List
from functools import wraps
from pydantic import BaseModel, Field
from metagpt.llm import LLM
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.logs import logger
from metagpt.provider.base_gpt_api import BaseGPTAPI
from metagpt.strategy.base import ThoughtNode, ThoughtTree
from metagpt.strategy.tot_schema import MethodSelect, Strategy, ThoughtSolverConfig
from metagpt.utils.common import CodeParser
from metagpt.strategy.tot_schema import ThoughtSolverConfig, Strategy, MethodSelect
from metagpt.strategy.base import ThoughtNode, ThoughtTree, BaseParser, BaseEvaluator
OUTPUT_FORMAT = """
Output a list of jsons following the format:
@ -34,17 +32,17 @@ class ThoughtSolverBase(BaseModel):
thought_tree: str = ""
llm: BaseGPTAPI = Field(default_factory=LLM, exclude=True)
config: ThoughtSolverConfig = Field(default_factory=ThoughtSolverConfig)
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self.llm.use_system_prompt = False
async def solve(self, init_prompt):
"""
Solve method for subclasses to implement.
"""
raise NotImplementedError("Subclasses must implement the solve method")
async def generate_thoughts(self, current_state="", current_node=None) -> List[ThoughtNode]:
"""
Generate children thoughts based on the current state.
@ -56,15 +54,16 @@ class ThoughtSolverBase(BaseModel):
Returns:
List[ThoughtNode]: List of nodes representing the generated thoughts.
"""
state_prompt = self.config.parser.propose(current_state=current_state,
**{"n_generate_sample": self.config.n_generate_sample})
state_prompt = self.config.parser.propose(
current_state=current_state, **{"n_generate_sample": self.config.n_generate_sample}
)
rsp = await self.llm.aask(msg=state_prompt + "\n" + OUTPUT_FORMAT)
thoughts = CodeParser.parse_code(block=None, text=rsp)
thoughts = eval(thoughts)
# fixme 避免不跟随生成过多nodes
# valid_thoughts = [_node for idx, _node in enumerate(thoughts) if idx < self.n_generate_sample]
return self.thought_tree.update_node(thoughts, current_node=current_node)
async def evaluate_node(self, node, parent_value) -> None:
"""
Evaluate a node and update its status and value.
@ -78,14 +77,14 @@ class ThoughtSolverBase(BaseModel):
"""
eval_prompt = self.config.parser.value(input=node.name, **{"node_id": node.id})
evaluation = await self.llm.aask(msg=eval_prompt)
value = self.config.evaluator(evaluation, **{"node_id": node.id})
status = self.config.evaluator.status_verify(value)
node.update_valid_status(status=status)
# 累计分数
node.update_value(parent_value + value)
def select_nodes(self, thought_nodes: List[ThoughtNode]) -> List[ThoughtNode]:
"""
Select nodes based on the configured selection method.
@ -100,12 +99,12 @@ class ThoughtSolverBase(BaseModel):
if self.config.method_select == MethodSelect.SAMPLE:
raise NotImplementedError
elif self.config.method_select == MethodSelect.GREEDY:
select_nodes = sorted(thought_nodes, key=lambda x: x.value, reverse=True)[:self.config.n_select_sample]
select_nodes = sorted(thought_nodes, key=lambda x: x.value, reverse=True)[: self.config.n_select_sample]
for node in thought_nodes:
if node not in select_nodes:
node.parent = None # 从树中删除节点
return select_nodes
def update_solution(self):
"""
Select the result with the highest score.
@ -135,16 +134,16 @@ class BFSSolver(ThoughtSolverBase):
current_nodes = [root]
for step in range(self.config.max_steps):
solutions = await self._bfs_build(current_nodes)
selected_nodes = self.select_nodes(solutions)
current_nodes = selected_nodes
self.thought_tree.show()
best_solution, best_solution_path = self.update_solution()
logger.info(f"best solution is: {best_solution_path}")
return best_solution_path
async def _bfs_build(self, current_nodes):
"""
Build the thought tree using Breadth-First Search (BFS) strategy.
@ -160,15 +159,16 @@ class BFSSolver(ThoughtSolverBase):
current_state = self.config.parser(node.name)
current_value = node.value
tasks.append(self.generate_and_evaluate_nodes(current_state, current_value, node))
thought_nodes_list = await asyncio.gather(*tasks)
solutions = [child_node for thought_nodes in thought_nodes_list for child_node in thought_nodes]
return solutions
async def generate_and_evaluate_nodes(self, current_state, current_value, node):
thought_nodes = await self.generate_thoughts(current_state, current_node=node)
await asyncio.gather(
*(self.evaluate_node(child_node, parent_value=current_value) for child_node in thought_nodes))
*(self.evaluate_node(child_node, parent_value=current_value) for child_node in thought_nodes)
)
return thought_nodes
@ -186,7 +186,6 @@ class DFSSolver(ThoughtSolverBase):
impossible_state_cnt = 0
node = root_node
for step in range(self.max_steps):
current_state = self.config.parser(node.name)
current_value = node.value
thought_nodes = await self.generate_thoughts(current_state, current_node=node)
@ -199,9 +198,9 @@ class DFSSolver(ThoughtSolverBase):
node = thought_nodes[0]
_solution_path = self.thought_tree.parse_node_path(node)
self.thought_tree.show()
return _solution_path
async def solve(self, init_prompt="", root=ThoughtNode("")):
"""
Solve the problem using Depth-First Search (DFS) strategy.
@ -217,7 +216,7 @@ class DFSSolver(ThoughtSolverBase):
for n in range(self.config.n_solution_sample):
# fixme: 需要产生回退,当前节点不可用时回退到父节点,产生新的节点继续探索
await self._dfs(root)
best_solution, best_solution_path = self.update_solution()
logger.info(f"best solution is: {best_solution_path}")
return best_solution_path
@ -232,14 +231,14 @@ class TreeofThought(BaseModel):
config: ThoughtSolverConfig = Field(default_factory=ThoughtSolverConfig)
solver: ThoughtSolverBase = Field(default_factory=ThoughtSolverBase)
strategy: Strategy = Field(default=Strategy.BFS)
class Config:
arbitrary_types_allowed = True
def __init__(self, **kwargs: Any):
super().__init__(**kwargs)
self._initialize_solver(self.strategy)
def _initialize_solver(self, strategy):
"""
Initialize the solver based on the chosen strategy.
@ -258,7 +257,7 @@ class TreeofThought(BaseModel):
self.solver = MCTSSolver(config=self.config)
else:
raise NotImplementedError(f"Invalid strategy: {strategy}, only support BFS/DFS/MCTS currently!")
async def solve(self, init_prompt=""):
"""
Solve the problem using the specified strategy.