From 1060292cbf76620281bfd08532bfc24fcb84f194 Mon Sep 17 00:00:00 2001 From: geekan Date: Wed, 3 Jan 2024 14:21:21 +0800 Subject: [PATCH] refine code --- metagpt/strategy/tot.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/metagpt/strategy/tot.py b/metagpt/strategy/tot.py index 4f33698bf..e67d272c7 100644 --- a/metagpt/strategy/tot.py +++ b/metagpt/strategy/tot.py @@ -2,6 +2,8 @@ # @Date : 12/23/2023 4:51 PM # @Author : stellahong (stellahong@fuzhi.ai) # @Desc : +from __future__ import annotations + import asyncio from typing import Any, List @@ -31,7 +33,7 @@ Output a list of jsons following the format: class ThoughtSolverBase(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) - thought_tree: str = "" + thought_tree: ThoughtTree | None = None llm: BaseLLM = Field(default_factory=LLM, exclude=True) config: ThoughtSolverConfig = Field(default_factory=ThoughtSolverConfig) @@ -60,7 +62,7 @@ class ThoughtSolverBase(BaseModel): current_state=current_state, **{"n_generate_sample": self.config.n_generate_sample} ) rsp = await self.llm.aask(msg=state_prompt + "\n" + OUTPUT_FORMAT) - thoughts = CodeParser.parse_code(block=None, text=rsp) + thoughts = CodeParser.parse_code(block="", text=rsp) thoughts = eval(thoughts) # fixme 避免不跟随,生成过多nodes # valid_thoughts = [_node for idx, _node in enumerate(thoughts) if idx < self.n_generate_sample] @@ -97,15 +99,16 @@ class ThoughtSolverBase(BaseModel): Returns: List[ThoughtNode]: List of selected nodes. """ - # selection + # nodes to be selected + nodes = [] if self.config.method_select == MethodSelect.SAMPLE: raise NotImplementedError elif self.config.method_select == MethodSelect.GREEDY: - select_nodes = sorted(thought_nodes, key=lambda x: x.value, reverse=True)[: self.config.n_select_sample] + nodes = sorted(thought_nodes, key=lambda x: x.value, reverse=True)[: self.config.n_select_sample] for node in thought_nodes: - if node not in select_nodes: + if node not in nodes: node.parent = None # 从树中删除节点 - return select_nodes + return nodes def update_solution(self): """