Merge branch 'json_repair' into 'mgx_ops'

update: 1. DA 增加 JSON 容错;2. 提示词中强制要求 JSON 输出以 [ 开头(目前平台是流式解析 JSON 的,而 JSON 容错要等 JSON 全部输出完才会执行,所以如果 JSON 不是以 [ 开头,流式解析就会失效)

See merge request pub/MetaGPT!194
This commit is contained in:
林义章 2024-06-28 08:31:23 +00:00
commit b4cf9cee3b
5 changed files with 47 additions and 8 deletions

View file

@ -40,4 +40,5 @@ Some text indicating your thoughts, such as how you should update the plan statu
...
]
```
Notice: your output JSON data section must start with **```json [**
"""

View file

@ -49,8 +49,8 @@ Some text indicating your thoughts, such as how you should update the plan statu
...
]
```
Notice: your output JSON data section must start with **```json [**
"""
JSON_REPAIR_PROMPT = """
## json data
{json_data}

View file

@ -9,6 +9,7 @@ from metagpt.actions import Action
from metagpt.actions.di.write_analysis_code import WriteAnalysisCode
from metagpt.logs import logger
from metagpt.prompts.di.data_analyst import CMD_PROMPT
from metagpt.prompts.di.role_zero import JSON_REPAIR_PROMPT
from metagpt.roles.di.data_interpreter import DataInterpreter
from metagpt.schema import Message, TaskResult
from metagpt.strategy.experience_retriever import KeywordExpRetriever
@ -21,6 +22,7 @@ from metagpt.strategy.thinking_command import (
from metagpt.tools.tool_recommend import BM25ToolRecommender
from metagpt.utils.common import CodeParser
from metagpt.utils.report import ThoughtReporter
from metagpt.utils.repair_llm_raw_output import repair_llm_raw_output, RepairType
class DataAnalyst(DataInterpreter):
@ -83,11 +85,26 @@ class DataAnalyst(DataInterpreter):
# print(*context, sep="\n" + "*" * 5 + "\n")
async with ThoughtReporter(enable_llm_stream=True):
rsp = await self.llm.aask(context)
self.commands = json.loads(CodeParser.parse_code(block=None, lang="json", text=rsp))
# 临时方案待role zero的版本完成可将本注释内的代码直接替换掉
# -------------开始---------------
try:
commands = CodeParser.parse_code(block=None, lang="json", text=rsp)
commands = json.loads(repair_llm_raw_output(output=commands, req_keys=[None], repair_type=RepairType.JSON))
except json.JSONDecodeError as e:
commands = await self.llm.aask(msg=JSON_REPAIR_PROMPT.format(json_data=rsp))
commands = json.loads(CodeParser.parse_code(block=None, lang="json", text=commands))
except Exception as e:
tb = traceback.format_exc()
print(tb)
# 为了对LLM不按格式生成进行容错
if isinstance(commands, dict):
commands = commands["commands"] if "commands" in commands else [commands]
# -------------结束---------------
self.rc.working_memory.add(Message(content=rsp, role="assistant"))
await run_commands(self, self.commands, self.rc.working_memory)
await run_commands(self, commands, self.rc.working_memory)
return bool(self.rc.todo)
async def _act(self) -> Message:

View file

@ -100,7 +100,8 @@ def run_plan_command(role: Role, cmd: list[dict]):
elif cmd["command_name"] == Command.FINISH_CURRENT_TASK.cmd_name:
if role.planner.plan.is_plan_finished():
return
role.planner.plan.current_task.update_task_result(task_result=role.task_result)
if role.task_result:
role.planner.plan.current_task.update_task_result(task_result=role.task_result)
role.planner.plan.finish_current_task()
role.rc.working_memory.clear()

View file

@ -1,6 +1,7 @@
from __future__ import annotations
import json
import traceback
from typing import Any
import numpy as np
@ -9,11 +10,13 @@ from rank_bm25 import BM25Okapi
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.prompts.di.role_zero import JSON_REPAIR_PROMPT
from metagpt.schema import Plan
from metagpt.tools import TOOL_REGISTRY
from metagpt.tools.tool_data_type import Tool
from metagpt.tools.tool_registry import validate_tool_names
from metagpt.utils.common import CodeParser
from metagpt.utils.repair_llm_raw_output import RepairType, repair_llm_raw_output
TOOL_INFO_PROMPT = """
## Capabilities
@ -132,8 +135,25 @@ class ToolRecommender(BaseModel):
topk=topk,
)
rsp = await LLM().aask(prompt, stream=False)
rsp = CodeParser.parse_code(text=rsp)
ranked_tools = json.loads(rsp)
# 临时方案待role zero的版本完成可将本注释内的代码直接替换掉
# -------------开始---------------
try:
ranked_tools = CodeParser.parse_code(block=None, lang="json", text=rsp)
ranked_tools = json.loads(
repair_llm_raw_output(output=ranked_tools, req_keys=[None], repair_type=RepairType.JSON)
)
except json.JSONDecodeError:
ranked_tools = await self.llm.aask(msg=JSON_REPAIR_PROMPT.format(json_data=rsp))
ranked_tools = json.loads(CodeParser.parse_code(block=None, lang="json", text=ranked_tools))
except Exception:
tb = traceback.format_exc()
print(tb)
# 为了对LLM不按格式生成进行容错
if isinstance(ranked_tools, dict):
ranked_tools = list(ranked_tools.values())[0]
# -------------结束---------------
valid_tools = validate_tool_names(ranked_tools)