This commit is contained in:
didi 2024-08-06 10:53:12 +08:00
parent 47470fb74c
commit 008c5f0f1f
3 changed files with 88 additions and 34 deletions

View file

@ -47,10 +47,10 @@ class HumanEvalGraph(Graph):
self.fuensemble = FuEnsemble(llm=llm)
self.mdensemble = MdEnsemble(llm=llm, vote_count=vote_count)
async def __call__(self, problem: str, ensemble_count: int = 3):
async def __call__(self, problem: str, function_name: str, ensemble_count: int = 3):
solution_list = []
for _ in range(ensemble_count):
solution = await self.generate_code_block(problem)
solution = await self.generate_code_block(problem, function_name)
solution = solution.get("code_solution")
solution_list.append(solution)
solution = await self.mdensemble("code", solution_list, problem)
@ -73,7 +73,7 @@ class HumanEvalGraph(Graph):
solution = solution.get("code_solution")
solution_list.append(solution)
solution = await self.mdensemble("code", solution_list, problem)
solution = await self.tester(problem_id, problem, rephrase_problem, solution, test_cases)
solution = await self.tester(problem_id, problem, rephrase_problem, solution, test_cases, entry_point)
return solution
async def review_revise_ensemble(self, problem: str, ensemble_count: int = 2, revise_round: int = 3):

View file

@ -40,7 +40,7 @@ from examples.ags.w_action_node.prompt import (
REVIEW_PROMPT,
REVISE_PROMPT,
)
from examples.ags.w_action_node.utils import test_cases_2_test_functions
from examples.ags.w_action_node.utils import test_case_2_test_function
from metagpt.actions.action_node import ActionNode
from metagpt.llm import LLM
from metagpt.logs import logger
@ -165,6 +165,7 @@ class MdEnsemble(Operator):
async def __call__(self, solution_type: str, solutions: List[str], problem_description: str):
all_responses = []
# When the ensemble candidates are code, deduplicate them using their AST
# TODO AgentLess + try weighting
if solution_type == "code":
unique_structures = {}
updated_solutions = []
@ -209,6 +210,10 @@ class MdEnsemble(Operator):
return {"final_solution": final_answer}
class Md_Ensmble:
    """Empty placeholder class; it defines no attributes or behaviour."""
class ScEnsemble(Operator):
"""
Paper: Self-Consistency Improves Chain of Thought Reasoning in Language Models
@ -359,34 +364,66 @@ class Test(Operator):
def __init__(self, name: str = "Test", llm: LLM = LLM()):
super().__init__(name, llm)
def exec_code(self, solution, test_cases, problem_id):
    """Run the candidate solution against all public test cases in one script.

    Args:
        solution: Mapping whose ``"final_solution"`` entry holds the candidate
            source code.
        test_cases: Test cases passed to ``test_cases_2_test_functions`` to
            build the test script.
        problem_id: Identifier appended to ``tester.txt`` when the run fails
            (concatenated as a string).

    Returns:
        ``[]`` when the script runs without raising; a dict keyed
        ``"test_fail_case"`` (error type, message, formatted traceback) when an
        ``AssertionError`` is raised; a dict keyed ``"exec_fail_case"`` holding
        the exception message for any other exception.
    """
    # TODO
    # 1. Capture more detailed test error information
    # 2. Replace the public test dataset; the data currently used (the Reflexion
    #    dataset) has label leakage -> solve this by having an LLM extract the
    #    tests and emit them directly as assert code strings
    # 3. Run each test case separately -> 1
    candidate_source = solution["final_solution"]
    test_code = test_cases_2_test_functions(candidate_source, test_cases)

    # NOTE(security): this executes generated code in-process with full
    # interpreter privileges; it is not a sandbox.
    # Run in a throwaway copy of the module namespace so that names defined by
    # the generated code cannot overwrite this module's own globals
    # (``logger``, ``traceback``, ...) or leak into later runs.
    exec_namespace = dict(globals())
    try:
        exec(test_code, exec_namespace)
    except AssertionError as e:
        tb_str = traceback.format_exception(type(e), e, e.__traceback__)
        with open("tester.txt", "a") as f:
            f.write("test_error" + problem_id + "\n")
        error_information = {
            "test_fail_case": {"error_type": "AssertionError", "error_message": str(e), "traceback": tb_str}
        }
        logger.info(f"test error: {error_information}")
        return error_information
    except Exception as e:
        # Anything other than a failed assertion means the script itself broke.
        with open("tester.txt", "a") as f:
            f.write(problem_id + "\n")
        return {"exec_fail_case": str(e)}
    return []
# def exec_code(self, solution, test_cases, problem_id):
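# # TODO
# # 1. Capture more detailed test error information
# # 2. Replace the public test dataset; the data currently used (the Reflexion dataset) has label leakage -> solve this by having an LLM extract the tests and emit them directly as assert code strings
# # 3. Run each test case separately -> 1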
# solution = solution["final_solution"]
# test_code = test_cases_2_test_functions(solution, test_cases)
# fail_case = []
# try:
# exec(test_code, globals())
# except AssertionError as e:
# exc_type, exc_value, exc_traceback = sys.exc_info()
# tb_str = traceback.format_exception(exc_type, exc_value, exc_traceback)
# with open("tester.txt", "a") as f:
# f.write("test_error" + problem_id + "\n")
# error_infomation = {
# "test_fail_case": {"error_type": "AssertionError", "error_message": str(e), "traceback": tb_str}
# }
# logger.info(f"test error: {error_infomation}")
# return error_infomation
# except Exception as e:
# with open("tester.txt", "a") as f:
# f.write(problem_id + "\n")
# return {"exec_fail_case": str(e)}
# return []
async def __call__(self, problem_id, problem, rephrase_problem, solution, test_cases):
result = self.exec_code(solution, test_cases, problem_id)
if result == []:
def exec_code(self, solution, test_cases, problem_id, entry_point):
    """Run the candidate solution against each public test case separately.

    Args:
        solution: Mapping whose ``"final_solution"`` entry holds the candidate
            source code.
        test_cases: Iterable of test cases; each one is turned into its own
            script by ``test_case_2_test_function``.
        problem_id: Identifier appended to ``tester.txt`` when a case fails
            (concatenated as a string).
        entry_point: Name of the function under test, forwarded to
            ``test_case_2_test_function``.

    Returns:
        ``"no error"`` when every case runs without raising; otherwise a list
        with one ``{"test_fail_case": {...}}`` dict per case that raised
        ``AssertionError``. If any case raises a different exception, the run
        stops immediately and ``{"exec_fail_case": <message>}`` is returned
        instead (assertion failures collected so far are not included).
    """
    candidate_source = solution["final_solution"]
    fail_cases = []

    for test_case in test_cases:
        test_code = test_case_2_test_function(candidate_source, test_case, entry_point)

        # NOTE(security): this executes generated code in-process with full
        # interpreter privileges; it is not a sandbox.
        # Each case gets a fresh copy of the module namespace so that names
        # defined by the generated code cannot overwrite this module's own
        # globals (``logger``, ``traceback``, ...) or leak between cases.
        exec_namespace = dict(globals())
        try:
            exec(test_code, exec_namespace)
        except AssertionError as e:
            tb_str = traceback.format_exception(type(e), e, e.__traceback__)
            with open("tester.txt", "a") as f:
                f.write("test_error" + problem_id + "\n")
            error_information = {
                "test_fail_case": {
                    "test_case": test_case,
                    "error_type": "AssertionError",
                    "error_message": str(e),
                    "traceback": tb_str,
                }
            }
            fail_cases.append(error_information)
            logger.info(f"test error: {error_information}")
        except Exception as e:
            # A non-assertion error means the script itself is broken; report
            # it right away rather than running the remaining cases.
            with open("tester.txt", "a") as f:
                f.write(problem_id + "\n")
            return {"exec_fail_case": str(e)}

    if fail_cases:
        return fail_cases
    return "no error"
async def __call__(self, problem_id, problem, rephrase_problem, solution, test_cases, entry_point):
result = self.exec_code(solution, test_cases, problem_id, entry_point)
if result == "no error":
return solution
elif "exec_fail_case" in result:
result = result["exec_fail_case"]
@ -401,7 +438,6 @@ class Test(Operator):
response = node.instruct_content.model_dump()
return {"final_solution": response["refined_solution"]}
else:
result = result["test_fail_case"]
prompt = REFLECTION_ON_PUBLIC_TEST_PROMPT.format(
problem_description=problem,
rephrase_problem=rephrase_problem,

View file

@ -47,7 +47,9 @@ def parse_python_literal(s):
return s
def extract_test_cases_from_jsonl(problem_id: str, file_path: str = "public_test_reflexion.jsonl"):
def extract_test_cases_from_jsonl(
problem_id: str, file_path: str = "examples/ags/benchmark/data/humaneval_public_test.jsonl"
):
# Keep the original hard-coded test cases
hardcoded_cases = {
"HumanEval/32": "",
@ -63,7 +65,7 @@ def extract_test_cases_from_jsonl(problem_id: str, file_path: str = "public_test
with open(file_path, "r") as file:
for line in file:
data = json.loads(line)
if data.get("id") == problem_id:
if data.get("task_id") == problem_id:
return data.get("test")
return None # 如果没有找到问题,返回 None
@ -128,3 +130,19 @@ def test_cases_2_test_functions(solution: str, test_cases: str):
{test_cases}
"""
return tester_function
def test_case_2_test_function(solution: str, test_case: str, entry_point: str):
    """Build a self-contained script that runs one test case against a solution.

    The script contains the solution source, a ``check(candidate)`` function
    whose body is the test case, a ``test_check()`` function that calls
    ``check`` with the entry point, and a final call to ``test_check()``.

    Args:
        solution: Source code of the candidate solution.
        test_case: Source of the test case; it refers to the function under
            test as ``candidate``. May span several lines.
        entry_point: Name of the function in ``solution`` to pass to ``check``.

    Returns:
        The script as a string, ready to be passed to ``exec``.
    """
    import textwrap  # local import: the module's import block is outside this block

    # Normalise the test case and indent every line, so that multi-line cases
    # stay entirely inside the body of ``check``.
    check_body = textwrap.indent(textwrap.dedent(test_case).strip("\n"), "    ")
    if not check_body.strip():
        # A ``def`` with no body would make the generated script a syntax error.
        check_body = "    pass"

    tester_function = f"""
{solution}


def check(candidate):
{check_body}


def test_check():
    check({entry_point})


test_check()
"""
    return tester_function


# The name starts with "test_", so pytest would try to collect this helper as a
# test in any module that imports it; mark it explicitly as not a test.
test_case_2_test_function.__test__ = False