refine code

This commit is contained in:
geekan 2024-01-03 14:21:21 +08:00
parent feb89ec17f
commit 1060292cbf

View file

@ -2,6 +2,8 @@
# @Date : 12/23/2023 4:51 PM
# @Author : stellahong (stellahong@fuzhi.ai)
# @Desc :
from __future__ import annotations
import asyncio
from typing import Any, List
@ -31,7 +33,7 @@ Output a list of jsons following the format:
class ThoughtSolverBase(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
thought_tree: str = ""
thought_tree: ThoughtTree | None = None
llm: BaseLLM = Field(default_factory=LLM, exclude=True)
config: ThoughtSolverConfig = Field(default_factory=ThoughtSolverConfig)
@ -60,7 +62,7 @@ class ThoughtSolverBase(BaseModel):
current_state=current_state, **{"n_generate_sample": self.config.n_generate_sample}
)
rsp = await self.llm.aask(msg=state_prompt + "\n" + OUTPUT_FORMAT)
thoughts = CodeParser.parse_code(block=None, text=rsp)
thoughts = CodeParser.parse_code(block="", text=rsp)
thoughts = eval(thoughts)
# fixme 避免不跟随生成过多nodes
# valid_thoughts = [_node for idx, _node in enumerate(thoughts) if idx < self.n_generate_sample]
@ -97,15 +99,16 @@ class ThoughtSolverBase(BaseModel):
Returns:
List[ThoughtNode]: List of selected nodes.
"""
# selection
# nodes to be selected
nodes = []
if self.config.method_select == MethodSelect.SAMPLE:
raise NotImplementedError
elif self.config.method_select == MethodSelect.GREEDY:
select_nodes = sorted(thought_nodes, key=lambda x: x.value, reverse=True)[: self.config.n_select_sample]
nodes = sorted(thought_nodes, key=lambda x: x.value, reverse=True)[: self.config.n_select_sample]
for node in thought_nodes:
if node not in select_nodes:
if node not in nodes:
node.parent = None # 从树中删除节点
return select_nodes
return nodes
def update_solution(self):
"""