From 008c5f0f1f7c89a60be9be1f4a889e68d651c344 Mon Sep 17 00:00:00 2001 From: didi <84363704+didiforgithub@users.noreply.github.com> Date: Tue, 6 Aug 2024 10:53:12 +0800 Subject: [PATCH] Update --- examples/ags/w_action_node/graph.py | 6 +- examples/ags/w_action_node/operator.py | 94 ++++++++++++++++++-------- examples/ags/w_action_node/utils.py | 22 +++++- 3 files changed, 88 insertions(+), 34 deletions(-) diff --git a/examples/ags/w_action_node/graph.py b/examples/ags/w_action_node/graph.py index c0557a1dd..314f16af2 100644 --- a/examples/ags/w_action_node/graph.py +++ b/examples/ags/w_action_node/graph.py @@ -47,10 +47,10 @@ class HumanEvalGraph(Graph): self.fuensemble = FuEnsemble(llm=llm) self.mdensemble = MdEnsemble(llm=llm, vote_count=vote_count) - async def __call__(self, problem: str, ensemble_count: int = 3): + async def __call__(self, problem: str, function_name: str, ensemble_count: int = 3): solution_list = [] for _ in range(ensemble_count): - solution = await self.generate_code_block(problem) + solution = await self.generate_code_block(problem, function_name) solution = solution.get("code_solution") solution_list.append(solution) solution = await self.mdensemble("code", solution_list, problem) @@ -73,7 +73,7 @@ class HumanEvalGraph(Graph): solution = solution.get("code_solution") solution_list.append(solution) solution = await self.mdensemble("code", solution_list, problem) - solution = await self.tester(problem_id, problem, rephrase_problem, solution, test_cases) + solution = await self.tester(problem_id, problem, rephrase_problem, solution, test_cases, entry_point) return solution async def review_revise_ensemble(self, problem: str, ensemble_count: int = 2, revise_round: int = 3): diff --git a/examples/ags/w_action_node/operator.py b/examples/ags/w_action_node/operator.py index 72c0b30fc..b7a1ad384 100644 --- a/examples/ags/w_action_node/operator.py +++ b/examples/ags/w_action_node/operator.py @@ -40,7 +40,7 @@ from examples.ags.w_action_node.prompt import ( REVIEW_PROMPT, REVISE_PROMPT, ) -from examples.ags.w_action_node.utils import test_cases_2_test_functions +from examples.ags.w_action_node.utils import test_case_2_test_function from metagpt.actions.action_node import ActionNode from metagpt.llm import LLM from metagpt.logs import logger @@ -165,6 +165,7 @@ class MdEnsemble(Operator): async def __call__(self, solution_type: str, solutions: List[str], problem_description: str): all_responses = [] # 当Ensmeble方案是Code类型时,我们使用AST进行去重 + # TODO AgentLess + 尝试权重 if solution_type == "code": unique_structures = {} updated_solutions = [] @@ -209,6 +210,10 @@ class MdEnsemble(Operator): return {"final_solution": final_answer} +class Md_Ensmble: + pass + + class ScEnsemble(Operator): """ Paper: Self-Consistency Improves Chain of Thought Reasoning in Language Models @@ -359,34 +364,66 @@ class Test(Operator): def __init__(self, name: str = "Test", llm: LLM = LLM()): super().__init__(name, llm) - def exec_code(self, solution, test_cases, problem_id): - # TODO - # 1. 获取更加详细的Test error信息 - # 2. 更换Public Test数据集,当前使用的数据存在Label Leak(使用的Reflexion的数据集) -> 这个问题使用LLM抽取解决,直接生成为assert代码串 - # 3. 实现单独测试每一个test case -> 1 - solution = solution["final_solution"] - test_code = test_cases_2_test_functions(solution, test_cases) - try: - exec(test_code, globals()) - except AssertionError as e: - exc_type, exc_value, exc_traceback = sys.exc_info() - tb_str = traceback.format_exception(exc_type, exc_value, exc_traceback) - with open("tester.txt", "a") as f: - f.write("test_error" + problem_id + "\n") - error_infomation = { - "test_fail_case": {"error_type": "AssertionError", "error_message": str(e), "traceback": tb_str} - } - logger.info(f"test error: {error_infomation}") - return error_infomation - except Exception as e: - with open("tester.txt", "a") as f: - f.write(problem_id + "\n") - return {"exec_fail_case": str(e)} - return [] + # def exec_code(self, solution, test_cases, problem_id): + # # TODO + # # 1. 获取更加详细的Test error信息 + # # 2. 更换Public Test数据集,当前使用的数据存在Label Leak(使用的Reflexion的数据集) -> 这个问题使用LLM抽取解决,直接生成为assert代码串 + # # 3. 实现单独测试每一个test case -> 1 + # solution = solution["final_solution"] + # test_code = test_cases_2_test_functions(solution, test_cases) + # fail_case = [] + # try: + # exec(test_code, globals()) + # except AssertionError as e: + # exc_type, exc_value, exc_traceback = sys.exc_info() + # tb_str = traceback.format_exception(exc_type, exc_value, exc_traceback) + # with open("tester.txt", "a") as f: + # f.write("test_error" + problem_id + "\n") + # error_infomation = { + # "test_fail_case": {"error_type": "AssertionError", "error_message": str(e), "traceback": tb_str} + # } + # logger.info(f"test error: {error_infomation}") + # return error_infomation + # except Exception as e: + # with open("tester.txt", "a") as f: + # f.write(problem_id + "\n") + # return {"exec_fail_case": str(e)} + # return [] - async def __call__(self, problem_id, problem, rephrase_problem, solution, test_cases): - result = self.exec_code(solution, test_cases, problem_id) - if result == []: + def exec_code(self, solution, test_cases, problem_id, entry_point): + solution = solution["final_solution"] + fail_cases = [] + for test_case in test_cases: + test_code = test_case_2_test_function(solution, test_case, entry_point) + try: + exec(test_code, globals()) + except AssertionError as e: + exc_type, exc_value, exc_traceback = sys.exc_info() + tb_str = traceback.format_exception(exc_type, exc_value, exc_traceback) + with open("tester.txt", "a") as f: + f.write("test_error" + problem_id + "\n") + error_infomation = { + "test_fail_case": { + "test_case": test_case, + "error_type": "AssertionError", + "error_message": str(e), + "traceback": tb_str, + } + } + fail_cases.append(error_infomation) + logger.info(f"test error: {error_infomation}") + except Exception as e: + with open("tester.txt", "a") as f: + f.write(problem_id + "\n") + return {"exec_fail_case": str(e)} + if fail_cases != []: + return fail_cases + else: + return "no error" + + async def __call__(self, problem_id, problem, rephrase_problem, solution, test_cases, entry_point): + result = self.exec_code(solution, test_cases, problem_id, entry_point) + if result == "no error": return solution elif "exec_fail_case" in result: result = result["exec_fail_case"] @@ -401,7 +438,6 @@ class Test(Operator): response = node.instruct_content.model_dump() return {"final_solution": response["refined_solution"]} else: - result = result["test_fail_case"] prompt = REFLECTION_ON_PUBLIC_TEST_PROMPT.format( problem_description=problem, rephrase_problem=rephrase_problem, diff --git a/examples/ags/w_action_node/utils.py b/examples/ags/w_action_node/utils.py index fd3341cca..13de2a27d 100644 --- a/examples/ags/w_action_node/utils.py +++ b/examples/ags/w_action_node/utils.py @@ -47,7 +47,9 @@ def parse_python_literal(s): return s -def extract_test_cases_from_jsonl(problem_id: str, file_path: str = "public_test_reflexion.jsonl"): +def extract_test_cases_from_jsonl( + problem_id: str, file_path: str = "examples/ags/benchmark/data/humaneval_public_test.jsonl" +): # 保留原有的硬编码测试用例 hardcoded_cases = { "HumanEval/32": "", @@ -63,7 +65,7 @@ def extract_test_cases_from_jsonl(problem_id: str, file_path: str = "public_test with open(file_path, "r") as file: for line in file: data = json.loads(line) - if data.get("id") == problem_id: + if data.get("task_id") == problem_id: return data.get("test") return None # 如果没有找到问题,返回 None @@ -128,3 +130,19 @@ def test_cases_2_test_functions(solution: str, test_cases: str): {test_cases} """ return tester_function + + +def test_case_2_test_function(solution: str, test_case: str, entry_point: str): + tester_function = f""" +{solution} + + +def check(candidate): + {test_case} + +def test_check(): + check({entry_point}) + +test_check() +""" + return tester_function