Merge branch 'main' into dev

This commit is contained in:
geekan 2024-02-01 13:25:05 +08:00
commit f3cb2cfbed
4 changed files with 43 additions and 10 deletions

View file

@ -42,8 +42,8 @@ Determine the ONE file to rewrite in order to fix the error, for example, xyz.py
Determine if all of the code works fine, if so write PASS, else FAIL,
WRITE ONLY ONE WORD, PASS OR FAIL, IN THIS SECTION
## Send To:
Please write Engineer if the errors are due to problematic development codes, and QaEngineer to problematic test codes, and NoOne if there are no errors,
WRITE ONLY ONE WORD, Engineer OR QaEngineer OR NoOne, IN THIS SECTION.
Please write NoOne if there are no errors, Engineer if the errors are due to problematic development codes, else QaEngineer,
WRITE ONLY ONE WORD, NoOne OR Engineer OR QaEngineer, IN THIS SECTION.
---
You should fill in necessary instruction, status, send to, and finally return all content between the --- segment line.
"""

View file

@ -108,7 +108,7 @@ class BaseLLM(ABC):
def get_choice_delta_text(self, rsp: dict) -> str:
"""Required to provide the first text of stream choice"""
return rsp.get("choices")[0]["delta"]["content"]
return rsp.get("choices", [{}])[0].get("delta", {}).get("content", "")
def get_choice_function(self, rsp: dict) -> dict:
"""Required to provide the first function of choice

View file

@ -119,15 +119,22 @@ def repair_json_format(output: str) -> str:
logger.info(f"repair_json_format: {'}]'}")
elif output.startswith("{") and output.endswith("]"):
output = output[:-1] + "}"
# remove `#` in output json str, usually appeared in `glm-4`
# remove comments in output json string, after json value content, maybe start with #, maybe start with //
arr = output.split("\n")
new_arr = []
for line in arr:
idx = line.find("#")
if idx >= 0:
line = line[:idx]
new_arr.append(line)
for json_line in arr:
# look for # or // comments and make sure they are not inside the string value
comment_index = -1
for match in re.finditer(r"(\".*?\"|\'.*?\')|(#|//)", json_line):
if match.group(1): # if the string value
continue
if match.group(2): # if comments
comment_index = match.start(2)
break
# if comments, then delete them
if comment_index != -1:
json_line = json_line[:comment_index].rstrip()
new_arr.append(json_line)
output = "\n".join(new_arr)
return output

View file

@ -141,6 +141,32 @@ def test_repair_json_format():
output = repair_llm_raw_output(output=raw_output, req_keys=[None], repair_type=RepairType.JSON)
assert output == target_output
raw_output = """
{
"Language": "en_us", // define language
"Programming Language": "Python" # define code language
}
"""
target_output = """{
"Language": "en_us",
"Programming Language": "Python"
}"""
output = repair_llm_raw_output(output=raw_output, req_keys=[None], repair_type=RepairType.JSON)
assert output == target_output
raw_output = """
{
"Language": "#en_us#", // define language
"Programming Language": "//Python # Code // Language//" # define code language
}
"""
target_output = """{
"Language": "#en_us#",
"Programming Language": "//Python # Code // Language//"
}"""
output = repair_llm_raw_output(output=raw_output, req_keys=[None], repair_type=RepairType.JSON)
assert output == target_output
def test_repair_invalid_json():
from metagpt.utils.repair_llm_raw_output import repair_invalid_json