update: 增加注释

This commit is contained in:
zhanglei 2024-06-28 15:36:57 +08:00
parent 8d973c2cf7
commit 853f5fc400
2 changed files with 7 additions and 0 deletions

View file

@ -86,6 +86,8 @@ class DataAnalyst(DataInterpreter):
async with ThoughtReporter(enable_llm_stream=True):
rsp = await self.llm.aask(context)
# 临时方案待role zero的版本完成可将本注释内的代码直接替换掉
# -------------开始---------------
try:
commands = CodeParser.parse_code(block=None, lang="json", text=rsp)
commands = json.loads(repair_llm_raw_output(output=commands, req_keys=[None], repair_type=RepairType.JSON))
@ -99,6 +101,7 @@ class DataAnalyst(DataInterpreter):
# 为了对LLM不按格式生成进行容错
if isinstance(commands, dict):
commands = commands["commands"] if "commands" in commands else [commands]
# -------------结束---------------
self.rc.working_memory.add(Message(content=rsp, role="assistant"))
await run_commands(self, commands, self.rc.working_memory)

View file

@ -136,6 +136,8 @@ class ToolRecommender(BaseModel):
)
rsp = await LLM().aask(prompt, stream=False)
# 临时方案待role zero的版本完成可将本注释内的代码直接替换掉
# -------------开始---------------
try:
ranked_tools = CodeParser.parse_code(block=None, lang="json", text=rsp)
ranked_tools = json.loads(repair_llm_raw_output(output=ranked_tools, req_keys=[None], repair_type=RepairType.JSON))
@ -149,6 +151,8 @@ class ToolRecommender(BaseModel):
# 为了对LLM不按格式生成进行容错
if isinstance(ranked_tools, dict):
ranked_tools = list(ranked_tools.values())[0]
# -------------结束---------------
valid_tools = validate_tool_names(ranked_tools)
return list(valid_tools.values())[:topk]