From 772d2aea563fe87a2e87ab92255a4f3865fc610e Mon Sep 17 00:00:00 2001 From: didi <84363704+didiforgithub@users.noreply.github.com> Date: Thu, 25 Jul 2024 10:47:17 +0800 Subject: [PATCH] Update --- examples/ags/benchmark/humaneval.py | 28 ++-- examples/ags/w_action_node/graph.py | 44 ++++-- examples/ags/w_action_node/operator.py | 83 +++++++++-- examples/ags/w_action_node/operator_an.py | 12 ++ examples/ags/w_action_node/prompt.py | 87 +++++++++-- examples/ags/w_action_node/utils.py | 172 +++++++++++++++++++++- he_test.py | 44 ++++-- metagpt/actions/action_node.py | 11 +- metagpt/actions/code_sanitize.py | 167 +++++++++++++++++++++ 9 files changed, 583 insertions(+), 65 deletions(-) create mode 100644 metagpt/actions/code_sanitize.py diff --git a/examples/ags/benchmark/humaneval.py b/examples/ags/benchmark/humaneval.py index ce0c02bbc..5a8cef297 100644 --- a/examples/ags/benchmark/humaneval.py +++ b/examples/ags/benchmark/humaneval.py @@ -17,7 +17,6 @@ from examples.ags.w_action_node.operator import GenerateCode, GenerateCodeBlock generate_code = GenerateCode(llm=LLM()) generate_code_block = GenerateCodeBlock(llm=LLM()) - solver = HumanEvalGraph(name="solver", llm=LLM(), criteria='correctness, efficiency, readability', vote_count=5) async def sample_generate(id, result_path:str="samples.jsonl",mode:str="ags"): @@ -25,7 +24,10 @@ async def sample_generate(id, result_path:str="samples.jsonl",mode:str="ags"): if mode == "ags": solution_result = await solver(case['prompt'],ensemble_count=5) sample_dict = dict(task_id=case['task_id'], solution=solution_result['final_solution']) - else: + elif mode == "alpha": + solution_result = await solver.alpha_codium(case['task_id'], case['prompt'], ensemble_count=5) + sample_dict = dict(task_id=case['task_id'], solution=solution_result['final_solution']) + elif mode == "llm": solution_result = await generate_code_block(case['prompt']) sample_dict = dict(task_id=case['task_id'], solution=solution_result['code_solution']) with open(result_path, mode='a') as f: @@ -39,7 +41,7 @@ async def samples_generate(mode:str, result_path:str="samples.jsonl"): async def solve_and_write(case, mode): try: if mode == 'llm': - solution_result = await generate_code_block(case['prompt']) + solution_result = await generate_code_block(problem_description=case['prompt'], function_name=case['entry_point']) # solution_result = await generate_code(case['prompt']) sample_dict = { 'task_id': case['task_id'], @@ -51,7 +53,13 @@ async def samples_generate(mode:str, result_path:str="samples.jsonl"): 'task_id': case['task_id'], 'solution': solution_result['final_solution'] } - + elif mode == "alpha": + solution_result = await solver.alpha_codium(case['task_id'], case['prompt'], ensemble_count=5) + sample_dict = { + 'task_id': case['task_id'], + 'solution': solution_result['final_solution'] + } + # TODO 解决 final_solution 问题之后就可以开始正式测评了 async with file_lock: async with aiofiles.open(result_path, mode='a') as f: await f.write(json.dumps(sample_dict) + '\n') @@ -65,7 +73,6 @@ async def samples_generate(mode:str, result_path:str="samples.jsonl"): results = await asyncio.gather(*tasks) failed_tasks = [task_id for task_id in results if task_id is not None] - # TODO 这个地方还是不够自动化 if failed_tasks: print(failed_tasks) if mode == 'llm': @@ -73,7 +80,7 @@ async def samples_generate(mode:str, result_path:str="samples.jsonl"): case = get_human_eval_plus()[task_id] for _ in range(3): try: - solution_result = await generate_code_block(case['prompt']) + solution_result = await generate_code_block(case['prompt'],function_name=case['entry_point']) task_dict = { 'task_id': case['task_id'], 'solution': solution_result['code_solution'] @@ -84,17 +91,18 @@ async def samples_generate(mode:str, result_path:str="samples.jsonl"): break except Exception as e: print(f"{e} \n failure {task_id}") - elif mode == "ags": + elif mode == "ags" or mode == "alpha": for task_id in failed_tasks: try: - await sample_generate(task_id,result_path) + await sample_generate(task_id,result_path,mode) except Exception as e: print(f"failure {task_id}") + jsonl_ranker(result_path, result_path) if not failed_tasks: # 自动 sanitize - result_path = automatic_sanitize(result_path) + # result_path = automatic_sanitize(result_path) if automatic_evalplus(result_path): eval_path = result_path[:-6]+"_eval_results.json" unpassed_exapmle = extract_failure_tests(eval_path) @@ -107,7 +115,7 @@ async def samples_generate_ags(): cases = list(get_human_eval_plus().values()) async def solve_with_id(case): - solution_result = await solver(case['prompt'], ensemble_count=3) + solution_result = await solver(case['prompt'], ensemble_count=5) return case['task_id'], solution_result['final_solution'] tasks = [solve_with_id(case) for case in cases] diff --git a/examples/ags/w_action_node/graph.py b/examples/ags/w_action_node/graph.py index 7b029b5a9..fe9a91ce9 100644 --- a/examples/ags/w_action_node/graph.py +++ b/examples/ags/w_action_node/graph.py @@ -5,8 +5,8 @@ from metagpt.llm import LLM from typing import List -from examples.ags.w_action_node.operator import Generate, GenerateCode, GenerateCodeBlock, Review, Revise, FuEnsemble, MdEnsemble, DbEnsemble - +from examples.ags.w_action_node.operator import Generate, GenerateCode, GenerateCodeBlock, Review, Revise, FuEnsemble, MdEnsemble, DbEnsemble, Rephrase, Test +from examples.ags.w_action_node.utils import extract_test_cases_from_jsonl class Graph: def __init__(self, name:str, llm:LLM) -> None: self.name = name @@ -26,6 +26,8 @@ class HumanEvalGraph(Graph): self.generate_code_block = GenerateCodeBlock(llm=llm) self.review = Review(llm=llm, criteria=criteria) self.revise = Revise(llm=llm) + self.rephrase = Rephrase(llm=llm) + self.tester = Test(llm=llm) self.fuensemble = FuEnsemble(llm=llm) self.mdensemble = MdEnsemble(llm=llm, vote_count=vote_count) @@ -41,10 +43,28 @@ class HumanEvalGraph(Graph): break except Exception as e: print(e) - # solution list 有5个 solution = await self.mdensemble("code", solution_list, problem) return solution + + async def alpha_codium(self, problem_id:str, problem:str, ensemble_count:int = 3): + # async def __call__(self,problem_id, problem:str, ensemble_count:int = 3): + test_cases = extract_test_cases_from_jsonl(problem_id) + rephrase_problem = await self.rephrase(problem) # 在rephrase 中拼接原始的问题描述 + solution_list = [] + for _ in range(ensemble_count): + for retry_count in range(5): + try: + solution = await self.generate_code_block(problem, rephrase_problem) + solution = solution.get('code_solution') + solution_list.append(solution) + break + except Exception as e: + print(e) + solution = await self.mdensemble("code", solution_list, problem) + solution = await self.tester(problem, rephrase_problem, solution, test_cases) + return solution + async def review_revise_ensemble(self, problem:str, ensemble_count:int = 2): solution_list = [] for _ in range(ensemble_count): @@ -53,16 +73,16 @@ class HumanEvalGraph(Graph): solution = await self.ensemble(solution_list, problem) return solution - # async def simple_ensemble(self, problem:str, ensemble_count:int = 3): + async def simple_ensemble(self, problem:str, ensemble_count:int = 3): # async def __call__(self, problem:str, ensemble_count:int = 3): - # solution_list = [] - # for _ in range(ensemble_count): - # solution = await self.generate_code(problem) - # # solution = await self.generate_code_block(problem) - # solution = solution.get('code_solution') - # solution_list.append(solution) - # solution = await self.fuensemble(solution_list, problem) - # return solution + solution_list = [] + for _ in range(ensemble_count): + solution = await self.generate_code(problem) + # solution = await self.generate_code_block(problem) + solution = solution.get('code_solution') + solution_list.append(solution) + solution = await self.fuensemble(solution_list, problem) + return solution async def single_solve(self, problem:str, max_loop:int): solution = await self.generate_code(problem) diff --git a/examples/ags/w_action_node/operator.py b/examples/ags/w_action_node/operator.py index fe9ee5da7..ef6c0fb53 100644 --- a/examples/ags/w_action_node/operator.py +++ b/examples/ags/w_action_node/operator.py @@ -10,9 +10,11 @@ from collections import Counter from metagpt.actions.action_node import ActionNode from metagpt.llm import LLM -from examples.ags.w_action_node.operator_an import GenerateOp, GenerateCodeOp, GenerateCodeBlockOp ,ReviewOp, ReviseOp, FuEnsembleOp, MdEnsembleOp -from examples.ags.w_action_node.prompt import GENERATE_PROMPT, GENERATE_CODE_PROMPT, GENERATE_CODEBLOCK_PROMPT, REVIEW_PROMPT, REVISE_PROMPT, FU_ENSEMBLE_PROMPT, MD_ENSEMBLE_PROMPT, DE_ENSEMBLE_ANGEL_PROMPT, DE_ENSEMBLE_DEVIL_PROMPT, DE_ENSEMBLE_JUDGE_UNIVERSAL_PROMPT, DE_ENSEMBLE_JUDGE_FINAL_PROMPT -from examples.ags.w_action_node.prompt import DE_ENSEMBLE_CODE_FORMAT_PROMPT, DE_ENSEMBLE_TXT_FORMAT_PROMPT +from examples.ags.w_action_node.operator_an import GenerateOp, GenerateCodeOp, GenerateCodeBlockOp ,ReviewOp, ReviseOp, FuEnsembleOp, MdEnsembleOp, ReflectionTestOp, RephraseOp +from examples.ags.w_action_node.prompt import GENERATE_PROMPT, GENERATE_CODE_PROMPT, GENERATE_CODEBLOCK_PROMPT, REVIEW_PROMPT, REVISE_PROMPT, FU_ENSEMBLE_PROMPT, MD_ENSEMBLE_PROMPT, REFLECTION_ON_PUBILIC_TEST_PROMPT, REPHRASE_ON_PROBLEM_PROMPT, GENERATE_CODEBLOCK_REPHRASE_PROMPT +from examples.ags.w_action_node.prompt import DE_ENSEMBLE_CODE_FORMAT_PROMPT, DE_ENSEMBLE_TXT_FORMAT_PROMPT, DE_ENSEMBLE_ANGEL_PROMPT, DE_ENSEMBLE_DEVIL_PROMPT, DE_ENSEMBLE_JUDGE_UNIVERSAL_PROMPT, DE_ENSEMBLE_JUDGE_FINAL_PROMPT +from examples.ags.w_action_node.utils import test_cases_2_test_functions + class Operator: def __init__(self, name, llm:LLM): self.name = name @@ -47,9 +49,15 @@ class GenerateCodeBlock(Operator): def __init__(self, name:str ="Coder", llm: LLM = LLM()): super().__init__(name, llm) - async def __call__(self, problem_description): + async def __call__(self, problem_description, function_name): prompt = GENERATE_CODEBLOCK_PROMPT.format(problem_description=problem_description) - node = await ActionNode.from_pydantic(GenerateCodeBlockOp).fill(context=prompt, llm=self.llm,mode='code_fill') + node = await ActionNode.from_pydantic(GenerateCodeBlockOp).fill(context=prompt, llm=self.llm, mode='code_fill',function_name=function_name) + response = node.instruct_content.model_dump() + return response + + async def rephrase_generate(self, problem_description, rephrase_problem, function_name): + prompt = GENERATE_CODEBLOCK_REPHRASE_PROMPT.format(problem_description=problem_description,rephrase_problem=rephrase_problem) + node = await ActionNode.from_pydantic(GenerateCodeBlockOp).fill(context=prompt, llm=self.llm, mode='code_fill', function_name=function_name) response = node.instruct_content.model_dump() return response @@ -153,7 +161,7 @@ class MdEnsemble(Operator): most_frequent_index = Counter(all_responses).most_common(1)[0][0] print(f"most frequent_index: {most_frequent_index}") final_answer = solutions[most_frequent_index] - print(f"final answer: {final_answer}") + print(f"final answer: \n{final_answer}") # final_answer, frequency = self.most_frequent(all_responses) return {"final_solution": final_answer} @@ -293,23 +301,68 @@ class DbEnsemble(Operator): class Rephrase(Operator): """ - - https://arxiv.org/abs/2404.14963 + 1. AlphaCodium + 2. https://arxiv.org/abs/2404.14963 """ - pass + def __init__(self, name:str ="Rephraser", llm: LLM = LLM()): + super().__init__(name, llm) + async def __call__(self, problem_description:str)->str: + prompt = REPHRASE_ON_PROBLEM_PROMPT.format(problem_description=problem_description) + node = await ActionNode.from_pydantic(RephraseOp).fill(context=prompt, llm=self.llm) + response = node.instruct_content.model_dump() + return response["rephrased_problem"] + +class Test(Operator): + def __init__(self, name:str ="Tester", llm: LLM = LLM()): + super().__init__(name, llm) + + def test_cases_2_assert(self, test_cases): + return f"assert {test_cases[0]}({test_cases[1]}) == {test_cases[2]} \n" + + def exec_code(self, solution, test_cases): + solution = solution["final_solution"] + pass_case = [] + fail_case = [] + for test_case in test_cases: + test_code = test_cases_2_test_functions(solution,test_case) + try: + exec(test_code) + pass_case.append(self.test_cases_2_assert(test_case)) + except AssertionError as e: + fail_case.append(self.test_cases_2_assert(test_case)) + except Exception as e: + print(e) + return {"error":e} + if fail_case != []: + return fail_case + return [] + + async def __call__(self, problem, rephrase_problem, solution, test_cases): + result = self.exec_code(solution, test_cases) + # 处理通过Public Tests的代码 + # TODO 这里的问题是,如果Test直接通过了就没有办法Check Multi Tests了 + if result == []: + return solution + # 处理代码执行失败的代码 + elif type(result) == dict: + result = result["error"] + prompt = REFLECTION_ON_PUBILIC_TEST_PROMPT.format(problem_description=problem, rephrase_problem=rephrase_problem, code_solution=solution, exec_pass=f"executed unsuccessfully, error: \n {result}", test_fail="executed unsucessfully") + node = await ActionNode.from_pydantic(ReflectionTestOp).fill(context=prompt, llm=self.llm) + response = node.instruct_content.model_dump() + return {"final_solution":response["refined_solution"]} + else: + prompt = REFLECTION_ON_PUBILIC_TEST_PROMPT.format(problem_description=problem, rephrase_problem=rephrase_problem, code_solution=solution, exec_pass="executed successfully", test_fail=result) + node = await ActionNode.from_pydantic(ReflectionTestOp).fill(context=prompt, llm=self.llm) + response = node.instruct_content.model_dump() + return {"final_solution":response["refined_solution"]} + class FindFact(Operator): pass class SelfAsk(Operator): pass -class CodeReflection(Operator): - """ - Interpreter Part - We run code here to get error information. - """ - class Verify(Operator): """ ? 还没有想好 diff --git a/examples/ags/w_action_node/operator_an.py b/examples/ags/w_action_node/operator_an.py index 928c0f67a..2cad6b9fc 100644 --- a/examples/ags/w_action_node/operator_an.py +++ b/examples/ags/w_action_node/operator_an.py @@ -42,3 +42,15 @@ class MdEnsembleOp(BaseModel): description="The letter of the chosen best solution (only one letter)." ) +class TestCaseExtractOp(BaseModel): + test_cases: list = Field(default=[('', [5, 8, 7, 1], 12), ('', [3, 3, 3, 3, 3], 9)], + description="Extracted test cases from the problem description") + +class RephraseOp(BaseModel): + rephrased_problem: str = Field(default="", description="Rephrased problem description for this problem") + +class ReflectionTestOp(BaseModel): + reflection: str = Field(default="", description="对关于代码执行错误或者测试用例失败step by step的思考") + refined_solution: str = Field(default="", description="对于代码执行错误或者测试用例失败的修正方案") + + \ No newline at end of file diff --git a/examples/ags/w_action_node/prompt.py b/examples/ags/w_action_node/prompt.py index 9822fab49..6e41b4280 100644 --- a/examples/ags/w_action_node/prompt.py +++ b/examples/ags/w_action_node/prompt.py @@ -3,11 +3,6 @@ # @Author : didi # @Desc : prompts of operators -# TODO PromptBreeder 评分是怎么做的? -# TODO 评估案例 GSM-8K 直接拿的DataSet -# -# - GENERATE_PROMPT = """ Generate Solution for the following problem: {problem_description} """ @@ -35,21 +30,35 @@ You are an expert programmer tasked with solving a coding problem. The above is an incomplete Python code fragment. Return the complete and correct code with no additional text. Please maintain the JSON format in your response. ### Your Response: - """ -GENERATE_CODEBLOCK_PROMPT = """ -You are an expert programmer tasked with solving a coding problem. + +# GENERATE_CODEBLOCK_PROMPT = """ +# You are an expert programmer tasked with solving a coding problem. + +# ### Problem Description: +# {problem_description} + +# ### Instructions: +# The above is an incomplete Python code fragment. Return the complete and correct code with no additional text. +# """ + +GENERATE_CODEBLOCK_REPHRASE_PROMPT = """ +You are given a code contest problem, and a self-reflection on the problem: ### Problem Description: {problem_description} -### Instructions: -The above is an incomplete Python code fragment. Return the complete and correct code with no additional text. +### self reflection on the problem +{rephrase_problem} + +======================= +The above is an incomplete Python code fragment and reflection on it. Return the complete and correct code with no additional text. """ -# GENERATE_CODE_PROMPT = """ -# Generate Code Solution for the following problem: {problem_description} -# """ +GENERATE_CODEBLOCK_PROMPT = """ +Please provide a self-contained Python script that solves the following problem in a markdown code block: +{problem_description} +""" REVIEW_PROMPT = """ For the question described as {problem_description}, @@ -135,3 +144,55 @@ You, as the moderator, will evaluate both sides' answers and determine if there Please strictly output in JSON format, do not output irrelevant content """ +EXTRACT_CASE_PROMPT = """ +You are given a coding problem, and you need to extract the test cases from the problem description. +{problem_description} + +一个problem中会有多个测试用例,每个测试用例包含三个部分: +1. 函数名 +2. 输入 +3. 期望输出 +每个测试用例包裹在一个三元组之中,三元组之间用逗号分隔,整体用列表包裹。 +由于结果需要被解析到JSON中,True与False请表示为true, false; +""" + +REPHRASE_ON_PROBLEM_PROMPT = """ +You are given a code contest problem: + +### problem +{problem_description} + +### instrcutions +Given the code contest problem, Your Goal is: +Reflect on the problem, and describe it in your own words, in bullet points. Pay attention to small details, nuances, notes and examples in the problem description. + +""" + +REFLECTION_ON_PUBILIC_TEST_PROMPT = """ + +You are given a code contest problem, and a self-reflection on the problem: +### problem +{problem_description} + +### self reflection on the problem +{rephrase_problem} + +======================= +A Python code solution was generated for the problem: +### Code Solution +{code_solution} + +======================= +This section of the code execution result is +### Execution Result +{exec_pass} + +======================= +However, when running the following input example, the code solution above failed to produce the expected output: +#### Failed Test Case +{test_fail} + +Your goal is to analyze the code solution and the error, and propose a fixed code which will produce the expected output for the provided test input. +The fixed code should keep the solution robust, and work for all other input examples as well. +Make sure the fixed code has a reasonable runtime - less than three seconds on a modern computer, given the problem constraints for large input. +""" \ No newline at end of file diff --git a/examples/ags/w_action_node/utils.py b/examples/ags/w_action_node/utils.py index 98d97dd8f..6380c7bd5 100644 --- a/examples/ags/w_action_node/utils.py +++ b/examples/ags/w_action_node/utils.py @@ -5,7 +5,11 @@ import json import re -from typing import List, Dict, Any +from typing import List, Dict, Any, Tuple +from metagpt.llm import LLM +from metagpt.actions.action_node import ActionNode +from examples.ags.w_action_node.operator_an import TestCaseExtractOp +from examples.ags.w_action_node.prompt import EXTRACT_CASE_PROMPT def extract_task_id(task_id: str) -> int: """Extract the numeric part of the task_id.""" @@ -29,4 +33,168 @@ def jsonl_ranker(input_file: str, output_file: str): # Write the sorted data to a new JSONL file with open(output_file, 'w') as f: for item in sorted_data: - f.write(json.dumps(item) + '\n') \ No newline at end of file + f.write(json.dumps(item) + '\n') + +# def extract_test_cases_from_jsonl(problem_id:str, file_path:str="public_test.jsonl"): +# # TODO 这个JSONL效率有点神经病 +# if problem_id == "Humaneval/87": +# return [ ["get_row", [[[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 1, 6], [1, 2, 3, 4, 5, 1]], 1], [(0, 0), (1, 4), (1, 0), (2, 5), (2, 0)]], ["get_row", [[], 1], []], ["get_row", [[[], [1], [1, 2, 3]], 3], [(2, 2)]] ] +# elif problem_id == "Humaneval/95": +# return [ ["check_dict_case", [{"a": "apple", "b": "banana"}], True], ["check_dict_case", [{"a": "apple", "A": "banana", "B": "banana"}], False], ["check_dict_case", [{"a": "apple", "8": "banana", "a": "apple"}], False], ["check_dict_case", [{"Name": "John", "Age": "36", "City": "Houston"}], False], ["check_dict_case", [{"STATE": "NC", "ZIP": "12345"}], True] ] +# elif problem_id == "Humaneval/107": +# return [ ["even_odd_palindrome", [3], (1, 2)], ["even_odd_palindrome", [12], (4, 6)] ] +# elif problem_id == "Humaneval/112": +# return [ ["reverse_delete", ["abcde", "ae"], ("bcd", False)], ["reverse_delete", ["abcdef", "b"], ("acdef", False)], ["reverse_delete", ["abcdedcba", "ab"], ("cdedc", True)] ] +# elif problem_id == "Humaneval/127": +# return [ ["intersection", [(1, 2), (2, 3)], "NO"], ["intersection", [(-1, 1), (0, 4)], "NO"], ["intersection", [(-3, -1), (-5, 5)], "YES"] ] +# elif problem_id == "Humaneval/136": +# return [ ["largest_smallest_integers", [2, 4, 1, 3, 5, 7], (None, 1)], ["largest_smallest_integers", [], (None, None)], ["largest_smallest_integers", [0], (None, None)] ] +# elif problem_id == "Humaneval/148": +# return [ ["bf", ["Jupiter", "Neptune"], ("Saturn", "Uranus")], ["bf", ["Earth", "Mercury"], ("Venus",)], ["bf", ["Mercury", "Uranus"], ("Venus", "Earth", "Mars", "Jupiter", "Saturn")], ["bf", ["InvalidPlanet", "Neptune"], ()], ["bf", ["Jupiter", "InvalidPlanet"], ()], ["bf", ["Mercury", "Mercury"], ()] ] +# elif problem_id == "Humaneval/155": +# return [ ["even_odd_count", [-12], (1, 1)], ["even_odd_count", [123], (1, 2)] ] + +# with open(file_path, 'r') as file: +# for line in file: +# data = json.loads(line) +# if problem_id in data: +# return data[problem_id] + +# return None + +import json +import ast + +def parse_python_literal(s): + try: + return ast.literal_eval(s) + except (ValueError, SyntaxError): + return s + +def extract_test_cases_from_jsonl(problem_id:str, file_path:str="public_test.jsonl"): + # 保留原有的硬编码测试用例 + hardcoded_cases = { + "HumanEval/87": [ ["get_row", [[[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 1, 6], [1, 2, 3, 4, 5, 1]], 1], [(0, 0), (1, 4), (1, 0), (2, 5), (2, 0)]], ["get_row", [[], 1], []], ["get_row", [[[], [1], [1, 2, 3]], 3], [(2, 2)]] ], + "HumanEval/95": [ ["check_dict_case", [{"a": "apple", "b": "banana"}], True], ["check_dict_case", [{"a": "apple", "A": "banana", "B": "banana"}], False], ["check_dict_case", [{"a": "apple", "8": "banana", "a": "apple"}], False], ["check_dict_case", [{"Name": "John", "Age": "36", "City": "Houston"}], False], ["check_dict_case", [{"STATE": "NC", "ZIP": "12345"}], True] ], + "HumanEval/107": [ ["even_odd_palindrome", [3], (1, 2)], ["even_odd_palindrome", [12], (4, 6)] ], + "HumanEval/112": [ ["reverse_delete", ["abcde", "ae"], ("bcd", False)], ["reverse_delete", ["abcdef", "b"], ("acdef", False)], ["reverse_delete", ["abcdedcba", "ab"], ("cdedc", True)] ], + "HumanEval/127": [ ["intersection", [(1, 2), (2, 3)], "NO"], ["intersection", [(-1, 1), (0, 4)], "NO"], ["intersection", [(-3, -1), (-5, 5)], "YES"] ], + "HumanEval/136": [ ["largest_smallest_integers", [2, 4, 1, 3, 5, 7], (None, 1)], ["largest_smallest_integers", [], (None, None)], ["largest_smallest_integers", [0], (None, None)] ], + "HumanEval/148": [ ["bf", ["Jupiter", "Neptune"], ("Saturn", "Uranus")], ["bf", ["Earth", "Mercury"], ("Venus",)], ["bf", ["Mercury", "Uranus"], ("Venus", "Earth", "Mars", "Jupiter", "Saturn")], ["bf", ["InvalidPlanet", "Neptune"], ()], ["bf", ["Jupiter", "InvalidPlanet"], ()], ["bf", ["Mercury", "Mercury"], ()] ], + "HumanEval/155": [ ["even_odd_count", [-12], (1, 1)], ["even_odd_count", [123], (1, 2)] ] + } + + # 检查是否有硬编码的测试用例 + if problem_id in hardcoded_cases: + return hardcoded_cases[problem_id] + + # 如果没有硬编码的测试用例,从文件中读取 + with open(file_path, 'r') as file: + for line in file: + data = json.loads(line) + if problem_id in data: + problem_data = data[problem_id] + # 处理测试用例 + for i, test_case in enumerate(problem_data): + # 函数名保持不变 + # 参数列表需要解析 + test_case[1] = [parse_python_literal(arg) for arg in test_case[1]] + # 预期输出需要解析 + test_case[2] = parse_python_literal(test_case[2]) + return problem_data + + return None # 如果没有找到问题,返回 None + +def extract_test_cases(docstring: str) -> List[Tuple[str, List[Any], Any]]: + # 使用正则表达式匹配测试用例,现在捕获函数名和任意输出 + pattern = r'>>> (\w+)\((.*?)\)\n\s*(.*?)(?=\n|$)' + matches = re.findall(pattern, docstring, re.DOTALL) + + test_cases = [] + for match in matches: + func_name, input_str, expected_output = match + + # 处理输入 + input_list = [] + for item in input_str.split(','): + item = item.strip() + try: + # 尝试将输入转换为数值类型 + if '.' in item: + input_list.append(float(item)) + else: + input_list.append(int(item)) + except ValueError: + # 如果无法转换为数值,则保留为字符串 + input_list.append(item.strip("'\"")) + + # 处理输出 + try: + # 尝试将输出转换为数值或布尔值 + if expected_output.lower() == 'true': + expected_output = True + elif expected_output.lower() == 'false': + expected_output = False + elif '.' in expected_output: + expected_output = float(expected_output) + else: + expected_output = int(expected_output) + except ValueError: + # 如果无法转换,则保留为字符串 + expected_output = expected_output.strip("'\"") + + test_cases.append([func_name, input_list, expected_output]) + + return test_cases + + +async def llm_extract_test_case(id, problem_description: str, file_path:str="public_test.jsonl"): + prompt = EXTRACT_CASE_PROMPT.format(problem_description=problem_description) + node = await ActionNode.from_pydantic(TestCaseExtractOp).fill(context=prompt, llm=LLM()) + result = node.instruct_content.model_dump() + with open(file_path,"a") as f: + f.write(json.dumps({id:result["test_cases"]}) + '\n') + return {id:result["test_cases"]} + +import json + +def test_cases_2_test_functions(solution: str, test_case: List): + function_name = test_case[0] + + def format_param(param): + if isinstance(param, str): + return repr(param) + elif isinstance(param, (int, float, bool)): + return str(param) + elif isinstance(param, list): + return '[' + ', '.join(format_param(item) for item in param) + ']' + elif isinstance(param, tuple): + return '(' + ', '.join(format_param(item) for item in param) + ')' + elif isinstance(param, dict): + return '{' + ', '.join(f'{format_param(k)}: {format_param(v)}' for k, v in param.items()) + '}' + elif isinstance(param, type(None)): + return 'None' + else: + raise ValueError(f"Unsupported parameter type: {type(param)}") + + parameters = ', '.join(format_param(item) for item in test_case[1]) + print(type(test_case[2]), test_case[2]) + expected_output = format_param(test_case[2]) + print(expected_output) + + tester_function = f""" +{solution} + +def check(candidate): + assert candidate({parameters}) == {expected_output} + +check({function_name}) + """ + + print(f""" + Generated test function: + {tester_function} + """) + + return tester_function + \ No newline at end of file diff --git a/he_test.py b/he_test.py index a0cbed79f..3b348ca02 100644 --- a/he_test.py +++ b/he_test.py @@ -1,20 +1,46 @@ import asyncio +import json from metagpt.llm import LLM +from evalplus.data import get_human_eval_plus, write_jsonl from examples.ags.benchmark.humaneval import sample_generate, samples_generate, extract_failure_tests, automatic_evalplus -from examples.ags.w_action_node.utils import jsonl_ranker - +from examples.ags.w_action_node.utils import jsonl_ranker, llm_extract_test_case +from examples.ags.w_action_node.graph import HumanEvalGraph # 132 141 136 80 73 -# asyncio.run(sample_generate('HumanEval/118',result_path="llm_based_4.jsonl",mode="llm")) -# asyncio.run(samples_generate(mode='ags',result_path="ags_based_1.jsonl")) +# asyncio.run(sample_generate('HumanEval/118',result_path="llm_based_8.jsonl",mode="llm")) +asyncio.run(samples_generate(mode='llm',result_path="llm_based_100.jsonl")) # jsonl_ranker("samples.jsonl", "samples.jsonl") -result_path = "ags_based_2.jsonl" -if automatic_evalplus(result_path): - unpassed_exapmle = extract_failure_tests(result_path[:-6]+"_eval_results.json") - print(unpassed_exapmle) +# result_path = "ags_based_6.jsonl" +# if automatic_evalplus(result_path): +# unpassed_exapmle = extract_failure_tests(result_path[:-6]+"_eval_results.json") +# print(unpassed_exapmle) # unpassed_exapmle = extract_failure_tests(file_path="2_eval_results.json") # print(unpassed_exapmle) # for example in failure_list: -# asyncio.run(sample_generate(example)) \ No newline at end of file +# asyncio.run(sample_generate(example)) + +# TODO 抽取Public Test没搞完,先用几个测试跑一下流程 +# from evalplus.data import get_human_eval_plus + +# id_list = [87, 95, 107, 112, 127, 136, 148, 155] +# id_list = [155] +# cases_id = [f"HumanEval/{case_id}" for case_id in id_list] +# cases = {case_id: get_human_eval_plus()[case_id]['prompt'] for case_id in cases_id} +# async def main(cases): +# try: +# tasks = [llm_extract_test_case(case_id, case) for case_id, case in cases.items()] +# results = await asyncio.gather(*tasks) +# except: +# failed_tasks = [task_id for task_id in results if task_id is not None] +# print(failed_tasks) +# return results + +# asyncio.run(main(cases)) + +# [72, 80, 82, 87, 90, 95, 107, 109, 112, 124, 126, 127, 128, 132, 134, 136, 137, 138, 148, 154, 155] + +# case_prompt= get_human_eval_plus()["HumanEval/136"]['prompt'] +# solver = HumanEvalGraph(name="solver", llm=LLM(), criteria='correctness, efficiency, readability', vote_count=1) +# result = asyncio.run(solver.alpha_codium(problem_id="HumanEval/136", problem=case_prompt, ensemble_count=1)) \ No newline at end of file diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 738073277..c31900d3b 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -17,6 +17,7 @@ from pydantic import BaseModel, Field, create_model, model_validator from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action_outcls_registry import register_action_outcls +from metagpt.actions.code_sanitize import sanitize from metagpt.const import USE_CONFIG_TIMEOUT from metagpt.llm import BaseLLM from metagpt.logs import logger @@ -484,6 +485,7 @@ class ActionNode: async def code_fill( self, context, + function_name=None, timeout=USE_CONFIG_TIMEOUT ): """ @@ -510,10 +512,10 @@ class ActionNode: import re field_name = self.get_field_name() prompt = context - # prompt += "\nPlease wrap the generated code within triple backticks, like this: ``````" content = await self.llm.aask(prompt, timeout=timeout) - - extracted_code = extract_code_from_response(content) + # TODO 在前置逻辑中完成entrypoint的提取就可以 + extracted_code = sanitize(code=content, entrypoint=function_name) + # extracted_code = extract_code_from_response(content) result = {field_name: extracted_code} return result @@ -536,6 +538,7 @@ class ActionNode: images: Optional[Union[str, list[str]]] = None, timeout=USE_CONFIG_TIMEOUT, exclude=[], + function_name: str = None ): """Fill the node(s) with mode. @@ -563,7 +566,7 @@ class ActionNode: schema = self.schema if mode == self.MODE_CODE_FILL: - result = await self.code_fill(context, timeout) + result = await self.code_fill(context, function_name, timeout) self.instruct_content = self.create_class()(**result) return self diff --git a/metagpt/actions/code_sanitize.py b/metagpt/actions/code_sanitize.py new file mode 100644 index 000000000..958c712df --- /dev/null +++ b/metagpt/actions/code_sanitize.py @@ -0,0 +1,167 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +@Time : 2024/7/24 16:37 +@Author : didi +@File : code_node.py +""" +import os +import ast +import pathlib +import traceback + +from typing import Dict, Generator, List, Optional, Set, Tuple + +import tree_sitter_python +from tqdm import tqdm +from tree_sitter import Language, Node, Parser + +CLASS_TYPE = "class_definition" +FUNCTION_TYPE = "function_definition" +IMPORT_TYPE = ["import_statement", "import_from_statement"] +IDENTIFIER_TYPE = "identifier" +ATTRIBUTE_TYPE = "attribute" +RETURN_TYPE = "return_statement" +EXPRESSION_TYPE = "expression_statement" +ASSIGNMENT_TYPE = "assignment" + +def traverse_tree(node: Node) -> Generator[Node, None, None]: + cursor = node.walk() + depth = 0 + + visited_children = False + while True: + if not visited_children: + yield cursor.node + if not cursor.goto_first_child(): + depth += 1 + visited_children = True + elif cursor.goto_next_sibling(): + visited_children = False + elif not cursor.goto_parent() or depth == 0: + break + else: + depth -= 1 + +def syntax_check(code, verbose=False): + try: + ast.parse(code) + return True + except (SyntaxError, MemoryError): + if verbose: + traceback.print_exc() + return False + +def code_extract(text: str) -> str: + lines = text.split("\n") + longest_line_pair = (0, 0) + longest_so_far = 0 + + for i in range(len(lines)): + for j in range(i + 1, len(lines)): + current_lines = "\n".join(lines[i : j + 1]) + if syntax_check(current_lines): + current_length = sum(1 for line in lines[i : j + 1] if line.strip()) + if current_length > longest_so_far: + longest_so_far = current_length + longest_line_pair = (i, j) + + return "\n".join(lines[longest_line_pair[0] : longest_line_pair[1] + 1]) + +def get_definition_name(node: Node) -> str: + for child in node.children: + if child.type == IDENTIFIER_TYPE: + return child.text.decode("utf8") + +def has_return_statement(node: Node) -> bool: + traverse_nodes = traverse_tree(node) + for node in traverse_nodes: + if node.type == RETURN_TYPE: + return True + return False + +def get_deps(nodes: List[Tuple[str, Node]]) -> Dict[str, Set[str]]: + def dfs_get_deps(node: Node, deps: Set[str]) -> None: + for child in node.children: + if child.type == IDENTIFIER_TYPE: + deps.add(child.text.decode("utf8")) + else: + dfs_get_deps(child, deps) + + name2deps = {} + for name, node in nodes: + deps = set() + dfs_get_deps(node, deps) + name2deps[name] = deps + return name2deps + + +def get_function_dependency(entrypoint: str, call_graph: Dict[str, str]) -> Set[str]: + queue = [entrypoint] + visited = {entrypoint} + while queue: + current = queue.pop(0) + if current not in call_graph: + continue + for neighbour in call_graph[current]: + if not (neighbour in visited): + visited.add(neighbour) + queue.append(neighbour) + return visited + +def sanitize(code: str, entrypoint: Optional[str] = None) -> str: + code = code_extract(code) + code_bytes = bytes(code, "utf8") + parser = Parser(Language(tree_sitter_python.language())) + tree = parser.parse(code_bytes) + class_names = set() + function_names = set() + variable_names = set() + + root_node = tree.root_node + import_nodes = [] + definition_nodes = [] + + for child in root_node.children: + if child.type in IMPORT_TYPE: + import_nodes.append(child) + elif child.type == CLASS_TYPE: + name = get_definition_name(child) + if not ( + name in class_names or name in variable_names or name in function_names + ): + definition_nodes.append((name, child)) + class_names.add(name) + elif child.type == FUNCTION_TYPE: + name = get_definition_name(child) + if not ( + name in function_names or name in variable_names or name in class_names + ) and has_return_statement(child): + definition_nodes.append((name, child)) + function_names.add(get_definition_name(child)) + elif ( + child.type == EXPRESSION_TYPE and child.children[0].type == ASSIGNMENT_TYPE + ): + subchild = child.children[0] + name = get_definition_name(subchild) + if not ( + name in variable_names or name in function_names or name in class_names + ): + definition_nodes.append((name, subchild)) + variable_names.add(name) + + if entrypoint: + name2deps = get_deps(definition_nodes) + reacheable = get_function_dependency(entrypoint, name2deps) + + sanitized_output = b"" + + for node in import_nodes: + sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n" + + for pair in definition_nodes: + name, node = pair + if entrypoint and not (name in reacheable): + continue + sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n" + return sanitized_output[:-1].decode("utf8")