Merge pull request #500 from better629/feat_repair_llmoutput

Feat repair llmoutput
This commit is contained in:
Sirui Hong 2023-12-12 15:16:34 +08:00 committed by GitHub
commit 5a7b0115ea
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
16 changed files with 1012 additions and 23 deletions

View file

@ -34,6 +34,15 @@ RPM: 10
#### if zhipuai from `https://open.bigmodel.cn`. You can set here or export API_KEY="YOUR_API_KEY"
# ZHIPUAI_API_KEY: "YOUR_API_KEY"
#### if using a self-hosted open LLM with an OpenAI-compatible interface
#OPEN_LLM_API_BASE: "http://127.0.0.1:8000/v1"
#OPEN_LLM_API_MODEL: "llama2-13b"
#
##### if using the Fireworks API
#FIREWORKS_API_KEY: "YOUR_API_KEY"
#FIREWORKS_API_BASE: "https://api.fireworks.ai/inference/v1"
#FIREWORKS_API_MODEL: "YOUR_LLM_MODEL" # example, accounts/fireworks/models/llama-v2-13b-chat
#### for Search
## Supported values: serpapi/google/serper/ddg
@ -94,4 +103,9 @@ MODEL_FOR_RESEARCHER_REPORT: gpt-3.5-turbo-16k
### browser path for pyppeteer engine, support Chrome, Chromium,MS Edge
#PYPPETEER_EXECUTABLE_PATH: "/usr/bin/google-chrome-stable"
### Repair a non-OpenAI LLM's output when parsing JSON text (only relevant if PROMPT_FORMAT=json).
### Non-OpenAI LLMs do not always follow the output instructions, so this setting activates a post-processing
### repair step on the content extracted from the LLM's raw output. Warning: it improves the result but does not fix all cases.
# REPAIR_LLM_OUTPUT: false
PROMPT_FORMAT: json #json or markdown

View file

@ -5,7 +5,7 @@
@Author : alexanderwu
@File : action.py
"""
import re
from abc import ABC
from typing import Optional
@ -15,7 +15,8 @@ from metagpt.actions.action_output import ActionOutput
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.utils.common import OutputParser
from metagpt.utils.custom_decoder import CustomDecoder
from metagpt.utils.utils import general_after_log
from metagpt.provider.postprecess.llm_output_postprecess import llm_output_postprecess
class Action(ABC):
@ -49,7 +50,11 @@ class Action(ABC):
system_msgs.append(self.prefix)
return await self.llm.aask(prompt, system_msgs)
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
@retry(
wait=wait_random_exponential(min=1, max=60),
stop=stop_after_attempt(6),
after=general_after_log(logger),
)
async def _aask_v1(
self,
prompt: str,
@ -63,24 +68,16 @@ class Action(ABC):
system_msgs = []
system_msgs.append(self.prefix)
content = await self.llm.aask(prompt, system_msgs)
logger.debug(content)
logger.debug(f"llm raw output:\n{content}")
output_class = ActionOutput.create_model_class(output_class_name, output_data_mapping)
if format == "json":
pattern = r"\[CONTENT\](\s*\{.*?\}\s*)\[/CONTENT\]"
matches = re.findall(pattern, content, re.DOTALL)
for match in matches:
if match:
content = match
break
parsed_data = CustomDecoder(strict=False).decode(content)
parsed_data = llm_output_postprecess(output=content, schema=output_class.schema(), req_key="[/CONTENT]")
else: # using markdown parser
parsed_data = OutputParser.parse_data_with_mapping(content, output_data_mapping)
logger.debug(parsed_data)
logger.debug(f"parsed_data:\n{parsed_data}")
instruct_content = output_class(**parsed_data)
return ActionOutput(content, instruct_content)

View file

@ -46,10 +46,18 @@ class Config(metaclass=Singleton):
self.openai_api_key = self._get("OPENAI_API_KEY")
self.anthropic_api_key = self._get("Anthropic_API_KEY")
self.zhipuai_api_key = self._get("ZHIPUAI_API_KEY")
self.open_llm_api_base = self._get("OPEN_LLM_API_BASE")
self.open_llm_api_model = self._get("OPEN_LLM_API_MODEL")
self.fireworks_api_key = self._get("FIREWORKS_API_KEY")
if (not self.openai_api_key or "YOUR_API_KEY" == self.openai_api_key) and \
(not self.anthropic_api_key or "YOUR_API_KEY" == self.anthropic_api_key) and \
(not self.zhipuai_api_key or "YOUR_API_KEY" == self.zhipuai_api_key):
raise NotConfiguredException("Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY first")
(not self.zhipuai_api_key or "YOUR_API_KEY" == self.zhipuai_api_key) and \
(not self.open_llm_api_base) and \
(not self.fireworks_api_key or "YOUR_API_KEY" == self.fireworks_api_key):
raise NotConfiguredException("Set OPENAI_API_KEY or Anthropic_API_KEY or ZHIPUAI_API_KEY first "
"or FIREWORKS_API_KEY or OPEN_LLM_API_BASE")
self.openai_api_base = self._get("OPENAI_API_BASE")
openai_proxy = self._get("OPENAI_PROXY") or self.global_proxy
if openai_proxy:
@ -69,6 +77,9 @@ class Config(metaclass=Singleton):
self.domain = self._get("DOMAIN")
self.spark_url = self._get("SPARK_URL")
self.fireworks_api_base = self._get("FIREWORKS_API_BASE")
self.fireworks_api_model = self._get("FIREWORKS_API_MODEL")
self.claude_api_key = self._get("Anthropic_API_KEY")
self.serpapi_api_key = self._get("SERPAPI_API_KEY")
self.serper_api_key = self._get("SERPER_API_KEY")
@ -93,6 +104,7 @@ class Config(metaclass=Singleton):
self.mermaid_engine = self._get("MERMAID_ENGINE", "nodejs")
self.pyppeteer_executable_path = self._get("PYPPETEER_EXECUTABLE_PATH", "")
self.repair_llm_output = self._get("REPAIR_LLM_OUTPUT", False)
self.prompt_format = self._get("PROMPT_FORMAT", "markdown")
def _init_with_config_files_and_env(self, configs: dict, yaml_file):

View file

@ -6,12 +6,13 @@
@File : llm.py
"""
from metagpt.logs import logger
from metagpt.config import CONFIG
from metagpt.provider.anthropic_api import Claude2 as Claude
from metagpt.provider.openai_api import OpenAIGPTAPI
from metagpt.provider.zhipuai_api import ZhiPuAIGPTAPI
from metagpt.provider.spark_api import SparkAPI
from metagpt.provider.open_llm_api import OpenLLMGPTAPI
from metagpt.provider.fireworks_api import FireWorksGPTAPI
from metagpt.provider.human_provider import HumanProvider
@ -20,12 +21,14 @@ def LLM() -> "BaseGPTAPI":
# TODO a little trick, can use registry to initialize LLM instance further
if CONFIG.openai_api_key:
llm = OpenAIGPTAPI()
elif CONFIG.claude_api_key:
llm = Claude()
elif CONFIG.spark_api_key:
llm = SparkAPI()
elif CONFIG.zhipuai_api_key:
llm = ZhiPuAIGPTAPI()
elif CONFIG.open_llm_api_base:
llm = OpenLLMGPTAPI()
elif CONFIG.fireworks_api_key:
llm = FireWorksGPTAPI()
else:
raise RuntimeError("You should config a LLM configuration first")

View file

@ -0,0 +1,24 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : fireworks.ai's api
import openai
from metagpt.config import CONFIG
from metagpt.provider.openai_api import OpenAIGPTAPI, CostManager, RateLimiter
class FireWorksGPTAPI(OpenAIGPTAPI):
    """OpenAI-style provider configured from the FIREWORKS_* values held by CONFIG."""

    def __init__(self):
        # Must run first: it sets `self.rpm`, which `RateLimiter.__init__` below reads.
        self.__init_fireworks(CONFIG)
        self.llm = openai
        self.model = CONFIG.fireworks_api_model
        self.auto_max_tokens = False
        self._cost_manager = CostManager()
        # NOTE(review): `OpenAIGPTAPI.__init__` is not called here; only the RateLimiter
        # initialiser is invoked explicitly.
        RateLimiter.__init__(self, rpm=self.rpm)

    def __init_fireworks(self, config: "Config"):
        """Set the module-level `openai` key/base from the Fireworks settings and read RPM (default 10)."""
        # These assignments mutate attributes of the imported `openai` module itself.
        openai.api_key = config.fireworks_api_key
        openai.api_base = config.fireworks_api_base
        self.rpm = int(config.get("RPM", 10))

View file

@ -0,0 +1,47 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : self-host open llm model with openai-compatible interface
import openai
from metagpt.logs import logger
from metagpt.config import CONFIG
from metagpt.provider.openai_api import OpenAIGPTAPI, CostManager, RateLimiter
class OpenLLMCostManager(CostManager):
    """Cost manager for self-hosted open LLMs: token usage is counted, no price is added."""

    def update_cost(self, prompt_tokens, completion_tokens, model):
        """
        Accumulate token usage for one API call and log it.

        Args:
            prompt_tokens (int): The number of tokens used in the prompt.
            completion_tokens (int): The number of tokens used in the completion.
            model (str): The model used for the API call (not used by this implementation).
        """
        self.total_prompt_tokens += prompt_tokens
        self.total_completion_tokens += completion_tokens

        usage_message = (
            f"Max budget: ${CONFIG.max_budget:.3f} | "
            f"prompt_tokens: {prompt_tokens}, completion_tokens: {completion_tokens}"
        )
        logger.info(usage_message)

        # `self.total_cost` is not modified in this method; it is only mirrored onto CONFIG.
        CONFIG.total_cost = self.total_cost
class OpenLLMGPTAPI(OpenAIGPTAPI):
    """OpenAI-style provider for a self-hosted model, configured from the OPEN_LLM_* values on CONFIG."""

    def __init__(self):
        # Must run first: it sets `self.rpm`, which `RateLimiter.__init__` below reads.
        self.__init_openllm(CONFIG)
        self.llm = openai
        self.model = CONFIG.open_llm_api_model
        self.auto_max_tokens = False
        # Token counts are tracked, but this cost manager adds no price per call.
        self._cost_manager = OpenLLMCostManager()
        # NOTE(review): `OpenAIGPTAPI.__init__` is not called here; only the RateLimiter
        # initialiser is invoked explicitly.
        RateLimiter.__init__(self, rpm=self.rpm)

    def __init_openllm(self, config: "Config"):
        """Set the module-level `openai` key/base for the self-hosted endpoint and read RPM (default 10)."""
        # These assignments mutate attributes of the imported `openai` module itself.
        openai.api_key = "sk-xx"  # self-host api doesn't need api-key, use the default value
        openai.api_base = config.open_llm_api_base
        self.rpm = int(config.get("RPM", 10))

View file

@ -0,0 +1,3 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :

View file

@ -0,0 +1,72 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : base llm postprocess plugin to do the operations like repair the raw llm output
from typing import Union
from metagpt.logs import logger
from metagpt.utils.repair_llm_raw_output import RepairType
from metagpt.utils.repair_llm_raw_output import repair_llm_raw_output, extract_content_from_output, \
retry_parse_json_text
class BasePostPrecessPlugin(object):
    """Default post-processing pipeline that turns raw LLM text into parsed JSON data.

    Each step is a separate `run_*` method so that a subclass can override a single step.
    """

    model = None  # the plugin of the `model`, use to judge in `llm_postprecess`

    def run_repair_llm_output(self, output: str, schema: dict, req_key: str = "[/CONTENT]") -> Union[dict, list]:
        """
        Run the full repair-and-parse sequence:

        1. repair key spelling/tag problems using the schema's field names plus `req_key`
        2. extract the content enclosed by the req_key pair ( xx[REQ_KEY]xxx[/REQ_KEY]xx )
        3. repair the json text of the extracted content
        4. parse the json text, repairing it according to the raised exception in a retry loop
        """
        field_names = list(schema["properties"])  # Custom ActionOutput's fields
        text = self.run_repair_llm_raw_output(output, req_keys=field_names + [req_key])
        text = self.run_extract_content_from_output(text, right_key=req_key)
        # the JSON repair does not use a key, so a single placeholder key drives one pass
        text = self.run_repair_llm_raw_output(text, req_keys=[None], repair_type=RepairType.JSON)
        return self.run_retry_parse_json_text(text)

    def run_repair_llm_raw_output(self, content: str, req_keys: list[str], repair_type: str = None) -> str:
        """Delegate to `repair_llm_raw_output`; subclasses may re-implement this step."""
        return repair_llm_raw_output(content, req_keys=req_keys, repair_type=repair_type)

    def run_extract_content_from_output(self, content: str, right_key: str) -> str:
        """Delegate to `extract_content_from_output`; subclasses may re-implement this step."""
        return extract_content_from_output(content, right_key=right_key)

    def run_retry_parse_json_text(self, content: str) -> Union[dict, list]:
        """Log the extracted text and delegate to `retry_parse_json_text`; subclasses may re-implement."""
        logger.info(f"extracted json CONTENT from output:\n{content}")
        # passed by keyword: the retry callback writes the repaired text back under `output`
        return retry_parse_json_text(output=content)

    def run(self, output: str, schema: dict, req_key: str = "[/CONTENT]") -> Union[dict, list]:
        """
        Entry point for prompts that require json-format output wrapped in an outer key pair, like

            [REQ_KEY]
            {
                "Key": "value"
            }
            [/REQ_KEY]

        Args
            output (str): llm raw output
            schema: output json schema; must contain a non-empty "properties" mapping
            req_key: outer pair right key, usually in `[/REQ_KEY]` format (must contain "/")
        """
        assert len(schema.get("properties")) > 0
        assert "/" in req_key

        # current, postprocess only deal the repair_llm_raw_output
        return self.run_repair_llm_output(output=output, schema=schema, req_key=req_key)

View file

@ -0,0 +1,23 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : the entry of choosing which PostProcessPlugin to deal particular LLM model's output
from typing import Union
from metagpt.provider.postprecess.base_postprecess_plugin import BasePostPrecessPlugin
def llm_output_postprecess(output: str, schema: dict, req_key: str = "[/CONTENT]",
                           model_name: str = None) -> Union[dict, str]:
    """
    Post-process raw LLM output with a plugin and return the plugin's result.

    `BasePostPrecessPlugin` is always used at present; `model_name` is accepted but not consulted yet.
    """
    # TODO choose different model's plugin according to the model_name
    plugin = BasePostPrecessPlugin()
    return plugin.run(output=output, schema=schema, req_key=req_key)

View file

@ -19,6 +19,8 @@ from metagpt.llm import LLM, HumanProvider
from metagpt.logs import logger
from metagpt.memory import Memory, LongTermMemory
from metagpt.schema import Message
from metagpt.utils.repair_llm_raw_output import extract_state_value_from_output
PREFIX_TEMPLATE = """You are a {profile}, named {name}, your goal is {goal}, and the constraint is {constraints}. """
@ -49,6 +51,7 @@ ROLE_TEMPLATE = """Your response should be based on the previous conversation hi
{name}: {result}
"""
class RoleReactMode(str, Enum):
REACT = "react"
BY_ORDER = "by_order"
@ -58,6 +61,7 @@ class RoleReactMode(str, Enum):
def values(cls):
return [item.value for item in cls]
class RoleSetting(BaseModel):
"""Role Settings"""
name: str
@ -79,11 +83,11 @@ class RoleContext(BaseModel):
env: 'Environment' = Field(default=None)
memory: Memory = Field(default_factory=Memory)
long_term_memory: LongTermMemory = Field(default_factory=LongTermMemory)
state: int = Field(default=-1) # -1 indicates initial or termination state where todo is None
state: int = Field(default=-1) # -1 indicates initial or termination state where todo is None
todo: Action = Field(default=None)
watch: set[Type[Action]] = Field(default_factory=set)
news: list[Type[Message]] = Field(default=[])
react_mode: RoleReactMode = RoleReactMode.REACT # see `Role._set_react_mode` for definitions of the following two attributes
react_mode: RoleReactMode = RoleReactMode.REACT # see `Role._set_react_mode` for definitions of the following two attributes
max_react_loop: int = 1
class Config:
@ -127,8 +131,9 @@ class Role:
i = action("", llm=self._llm)
else:
if self._setting.is_human and not isinstance(action.llm, HumanProvider):
logger.warning(f"is_human attribute does not take effect,"
f"as Role's {str(action)} was initialized using LLM, try passing in Action classes instead of initialized instances")
logger.warning(f"is_human attribute does not take effect, "
f"as Role's {str(action)} was initialized using LLM, "
f"try passing in Action classes instead of initialized instances")
i = action
i.set_prefix(self._get_prefix(), self.profile)
self._actions.append(i)
@ -193,6 +198,7 @@ class Role:
n_states=len(self._states) - 1, previous_state=self._rc.state)
# print(prompt)
next_state = await self._llm.aask(prompt)
next_state = extract_state_value_from_output(next_state)
logger.debug(f"{prompt=}")
if (not next_state.isdigit() and next_state != "-1") \
or int(next_state) not in range(-1, len(self._states)):

View file

@ -0,0 +1,59 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : pure async http_client
from typing import Optional, Any, Mapping, Union
from aiohttp.client import DEFAULT_TIMEOUT
import aiohttp
async def apost(url: str,
                params: Optional[Mapping[str, str]] = None,
                json: Any = None,
                data: Any = None,
                headers: Optional[dict] = None,
                as_json: bool = False,
                encoding: str = "utf-8",
                timeout: int = DEFAULT_TIMEOUT.total) -> Union[str, dict]:
    """
    Send one HTTP POST in a throw-away `aiohttp.ClientSession` and return the response body.

    Returns the result of `resp.json()` when `as_json` is true; otherwise the raw body
    decoded with `encoding`.
    """
    async with aiohttp.ClientSession() as session:
        async with session.post(
            url=url,
            params=params,
            json=json,
            data=data,
            headers=headers,
            timeout=timeout
        ) as resp:
            if as_json:
                return await resp.json()
            raw_body = await resp.read()
            return raw_body.decode(encoding)
async def apost_stream(url: str,
                       params: Optional[Mapping[str, str]] = None,
                       json: Any = None,
                       data: Any = None,
                       headers: Optional[dict] = None,
                       encoding: str = "utf-8",
                       timeout: int = DEFAULT_TIMEOUT.total) -> Any:
    """
    Send one HTTP POST and yield the response content piece by piece, decoded with `encoding`.

    usage:
        result = apost_stream(url="xx")
        async for line in result:
            deal_with(line)
    """
    async with aiohttp.ClientSession() as session:
        async with session.post(
            url=url,
            params=params,
            json=json,
            data=data,
            headers=headers,
            timeout=timeout
        ) as resp:
            async for raw_line in resp.content:
                yield raw_line.decode(encoding)

View file

@ -0,0 +1,307 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : repair llm raw output with particular conditions
import copy
from enum import Enum
from typing import Union, Callable
import regex as re
from tenacity import retry, stop_after_attempt, wait_fixed, after_log, RetryCallState
from metagpt.logs import logger
from metagpt.config import CONFIG
from metagpt.utils.custom_decoder import CustomDecoder
class RepairType(Enum):
    """Kinds of repair that can be applied to raw LLM output."""

    CS = "case sensitivity"  # a required key is present but with different letter case
    RKPM = "required key pair missing"  # condition like `[key] xx` which lacks `[/key]`
    SCM = "special character missing"  # the closing `[/key]` was emitted as `[key]`, i.e. the `/` is missing
    JSON = "json format"  # stray or mismatched bracket at the start/end of the json text
def repair_case_sensitivity(output: str, req_key: str) -> str:
    """
    Rewrite wrongly-cased occurrences of `req_key` in `output` to the exact required spelling.

    Usually `req_key` is the key name of the expected json or markdown content and does not
    appear in the value part, e.g. the target is `"Shared Knowledge": ""` but the model wrote
    `"Shared knowledge": ""`.

    Args:
        output: text produced by the LLM.
        req_key: the key exactly as it must be spelled.

    Returns:
        `output` unchanged if it already contains `req_key` verbatim (or no case-insensitive
        match exists); otherwise `output` with every case-insensitive match replaced by `req_key`.
    """
    if req_key in output:
        return output

    # Match on the original string instead of locating the key in `output.lower()`:
    # lowercasing can change the string length for some characters, which would make an
    # index computed on the lowered copy point at the wrong slice of `output`.
    key_pattern = re.compile(re.escape(req_key), re.IGNORECASE)
    # A callable replacement keeps backslashes in `req_key` from being treated as escapes.
    repaired, fixed_count = key_pattern.subn(lambda _match: req_key, output)
    if fixed_count:
        logger.info(f"repair_case_sensitivity: {req_key}")
    return repaired
def repair_special_character_missing(output: str, req_key: str = "[/CONTENT]") -> str:
    """
    Restore a special character (currently only `/`) dropped from the last occurrence of `req_key`.

    fix
        1. target string `[CONTENT] xx [CONTENT] xxx [CONTENT]` lacks `/` in the last `[CONTENT]`
        2. target string `xx [CONTENT] xxx [CONTENT] xxxx` lacks `/` in the last `[CONTENT]`

    Nothing is changed when `req_key` is already present, or when the key without the special
    character occurs fewer than two times.
    """
    if req_key in output:
        return output

    for special_char in ("/",):
        bare_key = req_key.replace(special_char, "")
        if output.count(bare_key) <= 1:
            continue
        # req_key with special_character usually in the tail side
        last_at = output.rfind(bare_key)
        head, tail = output[:last_at], output[last_at + len(bare_key):]
        output = head + req_key + tail
        logger.info(f"repair_special_character_missing: {special_char} in {bare_key} as position {last_at}")

    return output
def repair_required_key_pair_missing(output: str, req_key: str = "[/CONTENT]") -> str:
    """
    Add the missing half of a `[req_key]` ... `[/req_key]` pair at the begin or end of the content.

    req_key format
        1. `[req_key]`, and its pair `[/req_key]`
        2. `[/req_key]`, and its pair `[req_key]`

    A `req_key` that is not wrapped in square brackets leaves `output` untouched.
    """
    slash = "/"  # special char
    if not (req_key.startswith("[") and req_key.endswith("]")):
        return output

    if slash in req_key:
        opening, closing = req_key.replace(slash, ""), req_key  # `[/req_key]` -> `[req_key]`
    else:
        opening, closing = req_key, f"{req_key[0]}{slash}{req_key[1:]}"  # `[req_key]` -> `[/req_key]`

    if opening not in output:
        output = opening + "\n" + output
    if closing in output:
        return output

    trimmed = output.strip()
    ends_with_opening = trimmed.endswith(opening)
    if trimmed.endswith("}") or (trimmed.endswith("]") and not ends_with_opening):
        # avoid [req_key]xx[req_key] case to append [/req_key]
        return output + "\n" + closing

    # Otherwise keep the text from the last opening key up to the last `}` or `]` after it,
    # dropping whatever trails behind, and close it.
    candidate = None
    start = output.rfind(opening)
    if start >= 0:
        tail = output[start:]
        last_bracket = max(tail.rfind("}"), tail.rindex("]"))
        candidate = tail[: last_bracket + 1]
    if candidate and not ends_with_opening:
        output = candidate + "\n" + closing
    return output
def repair_json_format(output: str) -> str:
    """
    Fix a stray or mismatched bracket at the very start or end of a json object text.

    Repairs (applied to the stripped text, at most one per call):
        `[{ xxx }`  -> `{ xxx }`   (extra leading `[`)
        `{ xxx }]`  -> `{ xxx }`   (extra trailing `]`)
        `{ xxx ]`   -> `{ xxx }`   (closing `]` where `}` was expected)

    Text that is wrapped in a matching `[{` ... `}]` pair is returned as is: removing only one
    of the two brackets would turn balanced text into unbalanced text.

    Args:
        output: candidate json text.

    Returns:
        The stripped, possibly repaired text.
    """
    output = output.strip()

    if output.startswith("[{") and output.endswith("}]"):
        # Balanced outer list brackets: nothing is "extra" here, so do not break the pairing.
        return output

    if output.startswith("[{"):
        output = output[1:]
        logger.info("repair_json_format: [{")
    elif output.endswith("}]"):
        output = output[:-1]
        logger.info("repair_json_format: }]")
    elif output.startswith("{") and output.endswith("]"):
        output = output[:-1] + "}"
        logger.info("repair_json_format: ]")

    return output
def _repair_llm_raw_output(output: str, req_key: str, repair_type: RepairType = None) -> str:
    """
    Apply one repair (when `repair_type` is given) or every non-JSON repair to `output` for `req_key`.

    The JSON repair runs only when requested explicitly and does not use `req_key`.
    """
    if repair_type:
        selected = [repair_type]
    else:
        selected = [kind for kind in RepairType if kind != RepairType.JSON]

    # (repair kind, handler taking (output, req_key)) pairs, checked in this order
    keyed_repairs = (
        (RepairType.CS, repair_case_sensitivity),
        (RepairType.RKPM, repair_required_key_pair_missing),
        (RepairType.SCM, repair_special_character_missing),
    )
    for kind in selected:
        for candidate, handler in keyed_repairs:
            if kind == candidate:
                output = handler(output, req_key)
                break
        else:
            if kind == RepairType.JSON:
                output = repair_json_format(output)
    return output
def repair_llm_raw_output(output: str, req_keys: list[str], repair_type: RepairType = None) -> str:
    """
    Repair raw LLM output for each key in `req_keys`; a no-op unless `CONFIG.repair_llm_output` is set.

    Open-source llm models often do not follow the instruction well and the output may be
    incomplete, so all repair methods except the JSON one are tried by default
    (pass `repair_type` to run exactly one).

    typical case
        1. case sensitivity
            target: "Original Requirements"
            output: "Original requirements"
        2. special character missing
            target: [/CONTENT]
            output: [CONTENT]
        3. json format
            target: { xxx }
            output: { xxx }]
    """
    if CONFIG.repair_llm_output:
        # do the repairation usually for non-openai models
        for key in req_keys:
            output = _repair_llm_raw_output(output=output, req_key=key, repair_type=repair_type)
    return output
def repair_invalid_json(output: str, error: str) -> str:
    """
    Try to repair the line of `output` that a JSON decode error message points at.

    error examples
        example 1. json.decoder.JSONDecodeError: Expecting ',' delimiter: line 154 column 1 (char 2765)
        example 2. xxx.JSONDecodeError: Expecting property name enclosed in double quotes: line 14 column 1 (char 266)

    Args:
        output: the json text that failed to parse.
        error: the exception message; the first `line <N>` found in it selects the line to repair.

    Returns:
        `output` unchanged when `error` carries no line number. Otherwise the text with triple
        quotes collapsed and, if the line exists and one of the known problems is recognised,
        with that line (stripped of surrounding whitespace) rewritten. A line that matches no
        known problem, or a line number outside the text, leaves the lines as they are instead
        of raising.
    """
    matches = re.findall(r"line ([0-9]+)", error, re.DOTALL)
    if not matches:
        return output

    line_no = int(matches[0]) - 1  # error messages count lines from 1

    # CustomDecoder can handle `"": ''` or `'': ""`, so triple quotes are collapsed to a single
    # quote character first; note that both `"""` and `'''` become `"`.
    output = output.replace('"""', '"').replace("'''", '"')
    lines = output.split("\n")
    if not 0 <= line_no < len(lines):
        # The reported line does not exist in this text; there is nothing to repair.
        return output

    line = lines[line_no].strip()
    # different general problems
    if line.endswith("],"):
        # problem, redundant char `]`
        new_line = line.replace("]", "")
    elif line.endswith("},"):
        if output.endswith("},"):
            new_line = line[:-1]
        else:
            # problem, redundant char `}`
            new_line = line.replace("}", "")
    elif "," not in line:
        # problem, the closing quote and the trailing `,` are missing
        new_line = f'{line}",'
    elif line == ",":
        # a lone `,`: put the missing closing double quote in front of it
        new_line = f'"{line}'
    elif '",' in line:
        # swap the closing `"` for `'` in case the value was opened with a single quote
        new_line = line[:-2] + "',"
    else:
        # No known problem matched: leave the text alone rather than guess.
        return output

    lines[line_no] = new_line
    output = "\n".join(lines)
    logger.info(f"repair_invalid_json, raw error: {error}")
    return output
def run_after_exp_and_passon_next_retry(logger: "loguru.Logger") -> Callable[["RetryCallState"], None]:
    """Build a tenacity `after` callback that repairs the failed json text for the next attempt."""

    def run_and_passon(retry_state: RetryCallState) -> None:
        """
        On a failed attempt, repair the `output` argument using the exception text and store
        the repaired text in `retry_state.kwargs["output"]`.

        RetryCallState fields read here: `outcome` (failed / exception()), `args`, `kwargs`
        and `attempt_number`.

        NOTE(review): the repaired text is only ever written to `kwargs`; when the wrapped
        function was called positionally `args` still holds the old text, so callers are
        expected to pass `output=` by keyword -- confirm against tenacity's call semantics.
        """
        if not retry_state.outcome.failed:
            return

        if retry_state.args:
            # can't be used as args=retry_state.args
            func_param_output = retry_state.args[0]
        elif retry_state.kwargs:
            func_param_output = retry_state.kwargs.get("output", "")

        exp_str = str(retry_state.outcome.exception())
        logger.warning(f"parse json from content inside [CONTENT][/CONTENT] failed at retry "
                       f"{retry_state.attempt_number}, try to fix it, exp: {exp_str}")

        retry_state.kwargs["output"] = repair_invalid_json(func_param_output, exp_str)

    return run_and_passon
@retry(
    stop=stop_after_attempt(3 if CONFIG.repair_llm_output else 0),
    wait=wait_fixed(1),
    after=run_after_exp_and_passon_next_retry(logger),
)
def retry_parse_json_text(output: str) -> Union[list, dict]:
    """
    Decode json text with `CustomDecoder(strict=False)`; the retry decorator's `after` callback
    repairs situations like extra chars [']', '}'] between attempts.

    The attempt limit is computed from `CONFIG.repair_llm_output` once, when the decorator is
    applied (i.e. when this function is defined), not on each call.

    Warning (original author's note)
        if CONFIG.repair_llm_output is False, retry _aask_v1 {x=3} times, and the retry_parse_json_text's retry not work
        if CONFIG.repair_llm_output is True, the _aask_v1 and the retry_parse_json_text will loop for {x=3*3} times.
            it's a two-layer retry cycle
    """
    logger.debug(f"output to json decode:\n{output}")

    # if CONFIG.repair_llm_output is True, it will try to fix output until the retry break
    decoder = CustomDecoder(strict=False)
    return decoder.decode(output)
def extract_content_from_output(content: str, right_key: str = "[/CONTENT]"):
    """
    Extract xxx from `[CONTENT](xxx)[/CONTENT]` using a regex pattern.

    If the first extraction does not start with `{`, a second extraction is attempted
    (appending `right_key` first when the extracted text lacks it). If it does start with `{`,
    anything from the first `right_key` onwards is cut off.
    """
    # TODO construct the extract pattern with the `right_key`
    pattern = r"\[CONTENT\]([\s\S]*)\[/CONTENT\]"

    def first_capture(text: str) -> str:
        """Return the first non-empty captured group, or `text` itself when there is none; stripped."""
        captures = re.findall(pattern, text, re.DOTALL)
        return next((found for found in captures if found), text).strip()

    extracted = first_capture(content)

    if extracted.startswith("{"):
        cut_at = extracted.find(right_key)
        if cut_at >= 0:
            extracted = extracted[:cut_at].strip()
        return extracted

    # TODO find a more general pattern
    # for `[CONTENT]xxx[CONTENT]xxxx[/CONTENT] situation
    logger.warning(f"extract_content try another pattern: {pattern}")
    if right_key in extracted:
        second_source = content
    else:
        second_source = extracted + "\n" + right_key
    return first_capture(second_source)
def extract_state_value_from_output(content: str) -> str:
    """
    Pull the state number out of the LLM's answer.

    For openai models, they will always return state number. But for open llm models, the
    instruction result maybe a long text contain target number, so here add a extraction to
    improve success rate.

    The first integer in the text (in reading order) is returned, including multi-digit values
    and a leading minus sign, so an answer of `-1` stays `-1`. A `-` directly attached to a
    preceding word character (as in `state-2`) is treated as a hyphen, not a sign.

    Args:
        content (str): llm's output from `Role._think`

    Returns:
        The number as a string, or "-1" when the text contains no digits.
    """
    content = content.strip()  # deal the output cases like " 0", "0\n" and so on.
    # TODO find the number using a more proper method not just extract from content using pattern
    # First alternative: optionally signed integer that does not follow a word character.
    # Second alternative: digits glued to a word (e.g. `abc1`), taken without a sign.
    found = re.search(r"(?<!\w)-?[0-9]+|[0-9]+", content)
    return found.group(0) if found else "-1"

19
metagpt/utils/utils.py Normal file
View file

@ -0,0 +1,19 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc :
import typing
from tenacity import after_log, _utils
def general_after_log(logger: "loguru.Logger", sec_format: str = "%0.3f") -> typing.Callable[["RetryCallState"], None]:
    """
    Build a tenacity `after` callback that logs, at error level, the called function's name,
    the seconds elapsed since the first attempt, the attempt ordinal and the raised exception.

    `sec_format` is the %-style format applied to the elapsed seconds.
    """

    def log_it(retry_state: "RetryCallState") -> None:
        called = retry_state.fn
        fn_name = "<unknown>" if called is None else _utils.get_callback_name(called)
        elapsed = sec_format % retry_state.seconds_since_start
        ordinal = _utils.to_ordinal(retry_state.attempt_number)
        logger.error(f"Finished call to '{fn_name}' after {elapsed}(s), "
                     f"this was the {ordinal} time calling it. "
                     f"exp: {retry_state.outcome.exception()}")

    return log_it

View file

@ -0,0 +1,38 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : unittest of ahttp_client
import pytest
from metagpt.utils.ahttp_client import apost, apost_stream
@pytest.mark.asyncio
async def test_apost():
    """`apost` returns decoded text by default and parsed json when `as_json=True`.

    Both requests go to the literal external URLs below, so this test needs network access.
    """
    # default mode: the body comes back as a decoded string
    result = await apost(
        url="https://www.baidu.com/"
    )
    assert "百度一下" in result

    # as_json mode: the body comes back parsed, so it can be indexed by key
    result = await apost(
        url="http://aider.meizu.com/app/weather/listWeather",
        data={"cityIds": "101240101"},
        as_json=True
    )
    assert result["code"] == "200"
@pytest.mark.asyncio
async def test_apost_stream():
    """`apost_stream` can be consumed with `async for`, with and without form data.

    Both requests go to the literal external URLs below, so this test needs network access.
    The assertions only check that every yielded item supports `len()`.
    """
    result = apost_stream(
        url="https://www.baidu.com/"
    )
    async for line in result:
        assert len(line) >= 0

    result = apost_stream(
        url="http://aider.meizu.com/app/weather/listWeather",
        data={"cityIds": "101240101"}
    )
    async for line in result:
        assert len(line) >= 0

View file

@ -6,6 +6,7 @@
@File : test_custom_decoder.py
"""
import pytest
from metagpt.utils.custom_decoder import CustomDecoder
@ -37,6 +38,46 @@ def test_parse_single_quote():
parsed_data = decoder.decode(input_data)
assert 'a"\n b' in parsed_data
input_data = """{
'a': "
b
"
}
"""
with pytest.raises(Exception):
parsed_data = decoder.decode(input_data)
input_data = """{
'a': '
b
'
}
"""
with pytest.raises(Exception):
parsed_data = decoder.decode(input_data)
def test_parse_double_quote():
    """A double-quoted key whose value spans several lines decodes, for a `"` or a `'` quoted value."""
    decoder = CustomDecoder(strict=False)

    # multi-line value in double quotes
    input_data = """{
"a": "
b
"
}
"""
    parsed_data = decoder.decode(input_data)
    assert parsed_data["a"] == "\n b\n"

    # multi-line value in single quotes
    input_data = """{
"a": '
b
'
}
"""
    parsed_data = decoder.decode(input_data)
    assert parsed_data["a"] == "\n b\n"
def test_parse_triple_double_quote():
# Create a custom JSON decoder
@ -54,6 +95,10 @@ def test_parse_triple_double_quote():
parsed_data = decoder.decode(input_data)
assert parsed_data["a"] == "b"
input_data = "{\"\"\"a\"\"\": '''b'''}"
parsed_data = decoder.decode(input_data)
assert parsed_data["a"] == "b"
def test_parse_triple_single_quote():
# Create a custom JSON decoder

View file

@ -0,0 +1,320 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# @Desc : unittest of repair_llm_raw_output
from metagpt.config import CONFIG
CONFIG.repair_llm_output = True
from metagpt.utils.repair_llm_raw_output import repair_llm_raw_output, RepairType, repair_invalid_json,\
extract_content_from_output, retry_parse_json_text
def test_repair_case_sensitivity():
    """Keys whose letter case differs from `req_keys` are rewritten to the required spelling."""
    raw_output = """{
"Original requirements": "Write a 2048 game",
"search Information": "",
"competitive Quadrant charT": "quadrantChart
Campaign A: [0.3, 0.6]",
"requirement analysis": "The 2048 game should be simple to play"
}"""
    target_output = """{
"Original Requirements": "Write a 2048 game",
"Search Information": "",
"Competitive Quadrant Chart": "quadrantChart
Campaign A: [0.3, 0.6]",
"Requirement Analysis": "The 2048 game should be simple to play"
}"""
    req_keys = ["Original Requirements", "Search Information", "Competitive Quadrant Chart", "Requirement Analysis"]
    output = repair_llm_raw_output(output=raw_output,
                                   req_keys=req_keys)
    assert output == target_output
def test_repair_special_character_missing():
    """A final `[CONTENT]` that should have been `[/CONTENT]` gets its `/` restored."""
    # closing tag written without `/`, content is not wrapped in braces
    raw_output = """[CONTENT]
"Anything UNCLEAR": "No unclear requirements or information."
[CONTENT]"""

    target_output = """[CONTENT]
"Anything UNCLEAR": "No unclear requirements or information."
[/CONTENT]"""
    req_keys = ["[/CONTENT]"]
    output = repair_llm_raw_output(output=raw_output,
                                   req_keys=req_keys)
    assert output == target_output

    # three `[CONTENT]` occurrences: only the last one is turned into the closing tag
    raw_output = """[CONTENT] tag
[CONTENT]
{
"Anything UNCLEAR": "No unclear requirements or information."
}
[CONTENT]"""
    target_output = """[CONTENT] tag
[CONTENT]
{
"Anything UNCLEAR": "No unclear requirements or information."
}
[/CONTENT]"""
    output = repair_llm_raw_output(output=raw_output,
                                   req_keys=req_keys)
    assert output == target_output

    # single-line variant
    raw_output = '[CONTENT] {"a": "b"} [CONTENT]'
    target_output = '[CONTENT] {"a": "b"} [/CONTENT]'

    output = repair_llm_raw_output(output=raw_output,
                                   req_keys=["[/CONTENT]"])
    print("output\n", output)
    assert output == target_output
def test_required_key_pair_missing():
    """A missing `[/CONTENT]` is appended; text trailing the last json-like part is dropped."""
    # content ends with `}`: the closing tag is appended on a new line
    raw_output = '[CONTENT] {"a": "b"}'
    target_output = '[CONTENT] {"a": "b"}\n[/CONTENT]'

    output = repair_llm_raw_output(output=raw_output,
                                   req_keys=["[/CONTENT]"])
    assert output == target_output

    # content ends with `]` (not the opening tag): the closing tag is still appended
    raw_output = '''[CONTENT]
{
"key": "value"
]'''
    target_output = '''[CONTENT]
{
"key": "value"
]
[/CONTENT]'''

    output = repair_llm_raw_output(output=raw_output,
                                   req_keys=["[/CONTENT]"])
    assert output == target_output

    # extra text before the last `[CONTENT]` and after the `}` is removed from the result
    raw_output = '''[CONTENT] tag
[CONTENT]
{
"key": "value"
}
xxx
'''
    target_output = '''[CONTENT]
{
"key": "value"
}
[/CONTENT]'''
    output = repair_llm_raw_output(output=raw_output,
                                   req_keys=["[/CONTENT]"])
    assert output == target_output
def test_repair_json_format():
    """With `RepairType.JSON`, a stray `[`/`]` or a `]` in place of `}` is fixed; `req_keys` is a placeholder."""
    # extra trailing `]`
    raw_output = "{ xxx }]"
    target_output = "{ xxx }"

    output = repair_llm_raw_output(output=raw_output,
                                   req_keys=[None],
                                   repair_type=RepairType.JSON)
    assert output == target_output

    # extra leading `[`
    raw_output = "[{ xxx }"
    target_output = "{ xxx }"

    output = repair_llm_raw_output(output=raw_output,
                                   req_keys=[None],
                                   repair_type=RepairType.JSON)
    assert output == target_output

    # closing `]` where `}` was expected
    raw_output = "{ xxx ]"
    target_output = "{ xxx }"

    output = repair_llm_raw_output(output=raw_output,
                                   req_keys=[None],
                                   repair_type=RepairType.JSON)
    assert output == target_output
def test_repair_invalid_json():
    """`repair_invalid_json` rewrites the line named in the error; repeated calls refine it step by step."""
    # one pass: the redundant `}` on line 3 is removed
    raw_output = """{
"key": "value"
},
}"""
    target_output = """{
"key": "value"
,
}"""
    output = repair_invalid_json(raw_output, "Expecting ',' delimiter: line 3 column 1")
    assert output == target_output

    # two passes on line 4: `},` -> `,` -> `",`
    raw_output = """{
"key": "
value
},
}"""
    target_output = """{
"key": "
value
",
}"""
    output = repair_invalid_json(raw_output, "Expecting ',' delimiter: line 4 column 1")
    output = repair_invalid_json(output, "Expecting ',' delimiter: line 4 column 1")
    assert output == target_output

    # three passes on line 4: `},` -> `,` -> `",` -> `',`
    raw_output = """{
"key": '
value
},
}"""
    target_output = """{
"key": '
value
',
}"""
    output = repair_invalid_json(raw_output, "Expecting ',' delimiter: line 4 column 1")
    output = repair_invalid_json(output, "Expecting ',' delimiter: line 4 column 1")
    output = repair_invalid_json(output, "Expecting ',' delimiter: line 4 column 1")
    assert output == target_output
def test_retry_parse_json_text():
    """Ensure ``retry_parse_json_text`` yields the intended dict despite a stray closer after a value."""
    expected = {
        "Original Requirements": "Create a 2048 game",
        "Competitive Quadrant Chart": "quadrantChart\n\ttitle Reach and engagement of campaigns\n\t\tx-axis",
        "Requirement Analysis": "The requirements are clear and well-defined"
    }
    # The chart value is followed by a spurious "]" in the first text and a spurious "}" in the second.
    stray_bracket_text = """{
"Original Requirements": "Create a 2048 game",
"Competitive Quadrant Chart": "quadrantChart\n\ttitle Reach and engagement of campaigns\n\t\tx-axis"
],
"Requirement Analysis": "The requirements are clear and well-defined"
}"""
    stray_brace_text = """{
"Original Requirements": "Create a 2048 game",
"Competitive Quadrant Chart": "quadrantChart\n\ttitle Reach and engagement of campaigns\n\t\tx-axis"
},
"Requirement Analysis": "The requirements are clear and well-defined"
}"""
    for broken_text in (stray_bracket_text, stray_brace_text):
        parsed = retry_parse_json_text(output=broken_text)
        assert parsed == expected
def test_extract_content_from_output():
    """Check extraction of the payload between [CONTENT] and [/CONTENT] from chatty LLM replies.

    Reply shapes listed for this test:
        xxx [CONTENT] xxxx [/CONTENT]
        xxx [CONTENT] xxx [CONTENT] xxxx [/CONTENT]
        xxx [CONTENT] xxxx [/CONTENT] xxx [CONTENT][/CONTENT] xxx [CONTENT][/CONTENT]  # target pair is the last one
    """
    # Case 1: a single tag pair surrounded by explanatory text before and after.
    llm_reply = (
        'Sure! Here is the properly formatted JSON output based on the given context:\n\n[CONTENT]\n{\n"'
        'Required Python third-party packages": [\n"pygame==2.0.4",\n"pytest"\n],\n"Required Other language '
        'third-party packages": [\n"No third-party packages are required."\n],\n"Full API spec": "\nopenapi: '
        '3.0.0\n\ndescription: A JSON object representing the game state.\n\npaths:\n game:\n get:\n '
        'summary: Get the current game state.\n responses:\n 200:\n description: Game state.'
        '\n\n moves:\n post:\n summary: Make a move.\n requestBody:\n description: Move to be '
        'made.\n content:\n applicationjson:\n schema:\n type: object\n '
        ' properties:\n x:\n type: integer\n y:\n '
        ' type: integer\n tile:\n type: object\n '
        'properties:\n value:\n type: integer\n x:\n '
        ' type: integer\n y:\n type: integer\n\n '
        'undo-move:\n post:\n summary: Undo the last move.\n responses:\n 200:\n '
        ' description: Undone move.\n\n end-game:\n post:\n summary: End the game.\n responses:\n '
        ' 200:\n description: Game ended.\n\n start-game:\n post:\n summary: Start a new '
        'game.\n responses:\n 200:\n description: Game started.\n\n game-over:\n get:\n '
        ' summary: Check if the game is over.\n responses:\n 200:\n description: Game '
        'over.\n 404:\n description: Game not over.\n\n score:\n get:\n summary: Get the '
        'current score.\n responses:\n 200:\n description: Score.\n\n tile:\n get:\n '
        'summary: Get a specific tile.\n parameters:\n tile_id:\n type: integer\n '
        'description: ID of the tile to get.\n responses:\n 200:\n description: Tile.\n\n '
        'tiles:\n get:\n summary: Get all tiles.\n responses:\n 200:\n description: '
        'Tiles.\n\n level:\n get:\n summary: Get the current level.\n responses:\n 200:\n '
        ' description: Level.\n\n level-up:\n post:\n summary: Level up.\n responses:\n '
        '200:\n description: Level up successful.\n\n level-down:\n post:\n summary: Level '
        'down.\n responses:\n 200:\n description: Level down successful.\n\n restart:\n '
        'post:\n summary: Restart the game.\n responses:\n 200:\n description: Game '
        'restarted.\n\n help:\n get:\n summary: Get help.\n responses:\n 200:\n '
        'description: Help.\n\n version:\n get:\n summary: Get the version of the game.\n '
        'responses:\n 200:\n description: Version.\n\n}\n\n"Logic Analysis": [\n"game.py",'
        '\n"Contains the game logic."\n],\n"Task list": [\n"game.py",\n"Contains the game logic and should be '
        'done first."\n],\n"Shared Knowledge": "\n\'game.py\' contains the game logic.\n",\n"Anything '
        'UNCLEAR": "How to start the game."\n]\n\n[/CONTENT] Great! Your JSON output is properly formatted '
        'and correctly includes all the required sections. Here\'s a breakdown of what each section '
        'contains:\n\nRequired Python third-party packages:\n\n* pygame==2.0.4\n* pytest\n\nRequired Other '
        'language third-party packages:\n\n* No third-party packages are required.\n\nFull API spec:\n\n* '
        'openapi: 3.0.0\n* description: A JSON object representing the game state.\n* paths:\n + game: '
        'Get the current game state.\n + moves: Make a move.\n + undo-move: Undo the last move.\n + '
        'end-game: End the game.\n + start-game: Start a new game.\n + game-over: Check if the game is '
        'over.\n + score: Get the current score.\n + tile: Get a specific tile.\n + tiles: Get all tiles.\n '
        '+ level: Get the current level.\n + level-up: Level up.\n + level-down: Level down.\n + restart: '
        'Restart the game.\n + help: Get help.\n + version: Get the version of the game.\n\nLogic '
        'Analysis:\n\n* game.py contains the game logic.\n\nTask list:\n\n* game.py contains the game logic '
        'and should be done first.\n\nShared Knowledge:\n\n* \'game.py\' contains the game logic.\n\nAnything '
        'UNCLEAR:\n\n* How to start the game.\n\nGreat job! This JSON output should provide a clear and '
        'comprehensive overview of the project\'s requirements and dependencies.'
    )
    content = extract_content_from_output(llm_reply)
    assert content.startswith('{\n"Required Python third-party packages')
    assert content.endswith('UNCLEAR": "How to start the game."\n]')

    # Case 2: the prose mentions "[CONTENT] tag" before the real tag pair opens.
    llm_reply = (
        'Sure, I would be happy to help! Here is the information you provided, formatted as a JSON object '
        'inside the [CONTENT] tag:\n\n[CONTENT]\n{\n"Original Requirements": "Create a 2048 game",\n"Search '
        'Information": "Search results for 2048 game",\n"Requirements": [\n"Create a game with the same rules '
        'as the original 2048 game",\n"Implement a user interface that is easy to use and understand",\n"Add a '
        'scoreboard to track the player progress",\n"Allow the player to undo and redo moves",\n"Implement a '
        'game over screen to display the final score"\n],\n"Product Goals": [\n"Create a fun and engaging game '
        'experience for the player",\n"Design a user interface that is visually appealing and easy to use",\n"'
        'Optimize the game for performance and responsiveness"\n],\n"User Stories": [\n"As a player, I want to '
        'be able to move tiles around the board to combine numbers",\n"As a player, I want to be able to undo '
        'and redo moves to correct mistakes",\n"As a player, I want to see the final score and game over screen'
        ' when I win"\n],\n"Competitive Analysis": [\n"Competitor A: 2048 game with a simple user interface and'
        ' basic graphics",\n"Competitor B: 2048 game with a more complex user interface and better graphics",'
        '\n"Competitor C: 2048 game with a unique twist on the rules and a more challenging gameplay experience"'
        '\n],\n"Competitive Quadrant Chart": "quadrantChart\\n\ttitle Reach and engagement of campaigns\\n\t\t'
        'x-axis Low Reach --> High Reach\\n\t\ty-axis Low Engagement --> High Engagement\\n\tquadrant-1 We '
        'should expand\\n\tquadrant-2 Need to promote\\n\tquadrant-3 Re-evaluate\\n\tquadrant-4 May be '
        'improved\\n\tCampaign A: [0.3, 0.6]\\n\tCampaign B: [0.45, 0.23]\\n\tCampaign C: [0.57, 0.69]\\n\t'
        'Campaign D: [0.78, 0.34]\\n\tCampaign E: [0.40, 0.34]\\n\tCampaign F: [0.35, 0.78]"\n],\n"Requirement '
        'Analysis": "The requirements are clear and well-defined, but there may be some ambiguity around the '
        'specific implementation details",\n"Requirement Pool": [\n["P0", "Implement a game with the same '
        'rules as the original 2048 game"],\n["P1", "Add a scoreboard to track the player progress"],\n["P2", '
        '"Allow the player to undo and redo moves"]\n],\n"UI Design draft": "The UI should be simple and easy '
        'to use, with a clean and visually appealing design. The game board should be the main focus of the '
        'UI, with clear and concise buttons for the player to interact with.",\n"Anything UNCLEAR": ""\n}\n'
        '[/CONTENT]\n\nI hope this helps! Let me know if you have any further questions or if there anything '
        'else I can do to assist you.'
    )
    content = extract_content_from_output(llm_reply)
    assert content.startswith('{\n"Original Requirements"')
    assert content.endswith('"Anything UNCLEAR": ""\n}')

    # Case 3: the reply repeats the whole tag pair twice and also mentions "[CONTENT][/CONTENT]" in prose.
    llm_reply = """ Sure, I'd be happy to help! Here's the JSON output for the given context:\n\n[CONTENT]\n{
"Implementation approach": "We will use the open-source framework PyGame to create a 2D game engine, which will
provide us with a robust and efficient way to handle game logic and rendering. PyGame is widely used in the game
development community and has a large number of resources and tutorials available online.",\n"Python package name":
"pygame_2048",\n"File list": ["main.py", "game.py", "constants.py", "ui.py"],\n"Data structures and interface
definitions": '\nclassDiagram\n class Game{\n +int score\n +list<tile> tiles\n +function
move_tile(tile, int dx, int dy)\n +function undo_move()\n +function get_highest_score()\n }\n
class Tile{\n +int value\n +int x\n +int y\n }\n ...\n Game "1" -- "1" Food: has\n',
\n"Program call flow": '\nsequenceDiagram\n participant M as Main\n participant G as Game\n ...\n G->>M:
end game\n',\n"Anything UNCLEAR": "The requirement is clear to me."\n}\n[/CONTENT] Here's the JSON output for the
given context, wrapped inside the [CONTENT][/CONTENT] format:\n\n[CONTENT]\n{\n"Implementation approach": "We will
use the open-source framework PyGame to create a 2D game engine, which will provide us with a robust and efficient
way to handle game logic and rendering. PyGame is widely used in the game development community and has a large
number of resources and tutorials available online.",\n"Python package name": "pygame_2048",\n"File list":
["main.py", "game.py", "constants.py", "ui.py"],\n"Data structures and interface definitions": '\nclassDiagram\n
class Game{\n +int score\n +list<tile> tiles\n +function move_tile(tile, int dx, int dy)\n
+function undo_move()\n +function get_highest_score()\n }\n class Tile{\n +int value\n +int x\n
+int y\n }\n ...\n Game "1" -- "1" Food: has\n',\n"Program call flow": '\nsequenceDiagram\n participant
M as Main\n participant G as Game\n ...\n G->>M: end game\n',\n"Anything UNCLEAR": "The requirement is
clear to me."\n}\n[/CONTENT] Great! Your JSON output is well-formatted and provides all the necessary
information for a developer to understand the design and implementation of the 2048 game.
"""
    content = extract_content_from_output(llm_reply)
    assert content.startswith('{\n"Implementation approach"')
    assert content.endswith('"Anything UNCLEAR": "The requirement is clear to me."\n}')