mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-05 14:55:18 +02:00
Update
This commit is contained in:
parent
ca1c8f8c5c
commit
772d2aea56
9 changed files with 583 additions and 65 deletions
|
|
@ -17,7 +17,6 @@ from examples.ags.w_action_node.operator import GenerateCode, GenerateCodeBlock
|
|||
|
||||
generate_code = GenerateCode(llm=LLM())
|
||||
generate_code_block = GenerateCodeBlock(llm=LLM())
|
||||
|
||||
solver = HumanEvalGraph(name="solver", llm=LLM(), criteria='correctness, efficiency, readability', vote_count=5)
|
||||
|
||||
async def sample_generate(id, result_path:str="samples.jsonl",mode:str="ags"):
|
||||
|
|
@ -25,7 +24,10 @@ async def sample_generate(id, result_path:str="samples.jsonl",mode:str="ags"):
|
|||
if mode == "ags":
|
||||
solution_result = await solver(case['prompt'],ensemble_count=5)
|
||||
sample_dict = dict(task_id=case['task_id'], solution=solution_result['final_solution'])
|
||||
else:
|
||||
elif mode == "alpha":
|
||||
solution_result = await solver.alpha_codium(case['task_id'], case['prompt'], ensemble_count=5)
|
||||
sample_dict = dict(task_id=case['task_id'], solution=solution_result['final_solution'])
|
||||
elif mode == "llm":
|
||||
solution_result = await generate_code_block(case['prompt'])
|
||||
sample_dict = dict(task_id=case['task_id'], solution=solution_result['code_solution'])
|
||||
with open(result_path, mode='a') as f:
|
||||
|
|
@ -39,7 +41,7 @@ async def samples_generate(mode:str, result_path:str="samples.jsonl"):
|
|||
async def solve_and_write(case, mode):
|
||||
try:
|
||||
if mode == 'llm':
|
||||
solution_result = await generate_code_block(case['prompt'])
|
||||
solution_result = await generate_code_block(problem_description=case['prompt'], function_name=case['entry_point'])
|
||||
# solution_result = await generate_code(case['prompt'])
|
||||
sample_dict = {
|
||||
'task_id': case['task_id'],
|
||||
|
|
@ -51,7 +53,13 @@ async def samples_generate(mode:str, result_path:str="samples.jsonl"):
|
|||
'task_id': case['task_id'],
|
||||
'solution': solution_result['final_solution']
|
||||
}
|
||||
|
||||
elif mode == "alpha":
|
||||
solution_result = await solver.alpha_codium(case['task_id'], case['prompt'], ensemble_count=5)
|
||||
sample_dict = {
|
||||
'task_id': case['task_id'],
|
||||
'solution': solution_result['final_solution']
|
||||
}
|
||||
# TODO 解决 final_solution 问题之后就可以开始正式测评了
|
||||
async with file_lock:
|
||||
async with aiofiles.open(result_path, mode='a') as f:
|
||||
await f.write(json.dumps(sample_dict) + '\n')
|
||||
|
|
@ -65,7 +73,6 @@ async def samples_generate(mode:str, result_path:str="samples.jsonl"):
|
|||
results = await asyncio.gather(*tasks)
|
||||
failed_tasks = [task_id for task_id in results if task_id is not None]
|
||||
|
||||
# TODO 这个地方还是不够自动化
|
||||
if failed_tasks:
|
||||
print(failed_tasks)
|
||||
if mode == 'llm':
|
||||
|
|
@ -73,7 +80,7 @@ async def samples_generate(mode:str, result_path:str="samples.jsonl"):
|
|||
case = get_human_eval_plus()[task_id]
|
||||
for _ in range(3):
|
||||
try:
|
||||
solution_result = await generate_code_block(case['prompt'])
|
||||
solution_result = await generate_code_block(case['prompt'],function_name=case['entry_point'])
|
||||
task_dict = {
|
||||
'task_id': case['task_id'],
|
||||
'solution': solution_result['code_solution']
|
||||
|
|
@ -84,17 +91,18 @@ async def samples_generate(mode:str, result_path:str="samples.jsonl"):
|
|||
break
|
||||
except Exception as e:
|
||||
print(f"{e} \n failure {task_id}")
|
||||
elif mode == "ags":
|
||||
elif mode == "ags" or mode == "alpha":
|
||||
for task_id in failed_tasks:
|
||||
try:
|
||||
await sample_generate(task_id,result_path)
|
||||
await sample_generate(task_id,result_path,mode)
|
||||
except Exception as e:
|
||||
print(f"failure {task_id}")
|
||||
|
||||
jsonl_ranker(result_path, result_path)
|
||||
|
||||
if not failed_tasks:
|
||||
# 自动 sanitize
|
||||
result_path = automatic_sanitize(result_path)
|
||||
# result_path = automatic_sanitize(result_path)
|
||||
if automatic_evalplus(result_path):
|
||||
eval_path = result_path[:-6]+"_eval_results.json"
|
||||
unpassed_exapmle = extract_failure_tests(eval_path)
|
||||
|
|
@ -107,7 +115,7 @@ async def samples_generate_ags():
|
|||
cases = list(get_human_eval_plus().values())
|
||||
|
||||
async def solve_with_id(case):
|
||||
solution_result = await solver(case['prompt'], ensemble_count=3)
|
||||
solution_result = await solver(case['prompt'], ensemble_count=5)
|
||||
return case['task_id'], solution_result['final_solution']
|
||||
|
||||
tasks = [solve_with_id(case) for case in cases]
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@
|
|||
|
||||
from metagpt.llm import LLM
|
||||
from typing import List
|
||||
from examples.ags.w_action_node.operator import Generate, GenerateCode, GenerateCodeBlock, Review, Revise, FuEnsemble, MdEnsemble, DbEnsemble
|
||||
|
||||
from examples.ags.w_action_node.operator import Generate, GenerateCode, GenerateCodeBlock, Review, Revise, FuEnsemble, MdEnsemble, DbEnsemble, Rephrase, Test
|
||||
from examples.ags.w_action_node.utils import extract_test_cases_from_jsonl
|
||||
class Graph:
|
||||
def __init__(self, name:str, llm:LLM) -> None:
|
||||
self.name = name
|
||||
|
|
@ -26,6 +26,8 @@ class HumanEvalGraph(Graph):
|
|||
self.generate_code_block = GenerateCodeBlock(llm=llm)
|
||||
self.review = Review(llm=llm, criteria=criteria)
|
||||
self.revise = Revise(llm=llm)
|
||||
self.rephrase = Rephrase(llm=llm)
|
||||
self.tester = Test(llm=llm)
|
||||
self.fuensemble = FuEnsemble(llm=llm)
|
||||
self.mdensemble = MdEnsemble(llm=llm, vote_count=vote_count)
|
||||
|
||||
|
|
@ -41,10 +43,28 @@ class HumanEvalGraph(Graph):
|
|||
break
|
||||
except Exception as e:
|
||||
print(e)
|
||||
# solution list 有5个
|
||||
solution = await self.mdensemble("code", solution_list, problem)
|
||||
return solution
|
||||
|
||||
|
||||
async def alpha_codium(self, problem_id:str, problem:str, ensemble_count:int = 3):
|
||||
# async def __call__(self,problem_id, problem:str, ensemble_count:int = 3):
|
||||
test_cases = extract_test_cases_from_jsonl(problem_id)
|
||||
rephrase_problem = await self.rephrase(problem) # 在rephrase 中拼接原始的问题描述
|
||||
solution_list = []
|
||||
for _ in range(ensemble_count):
|
||||
for retry_count in range(5):
|
||||
try:
|
||||
solution = await self.generate_code_block(problem, rephrase_problem)
|
||||
solution = solution.get('code_solution')
|
||||
solution_list.append(solution)
|
||||
break
|
||||
except Exception as e:
|
||||
print(e)
|
||||
solution = await self.mdensemble("code", solution_list, problem)
|
||||
solution = await self.tester(problem, rephrase_problem, solution, test_cases)
|
||||
return solution
|
||||
|
||||
async def review_revise_ensemble(self, problem:str, ensemble_count:int = 2):
|
||||
solution_list = []
|
||||
for _ in range(ensemble_count):
|
||||
|
|
@ -53,16 +73,16 @@ class HumanEvalGraph(Graph):
|
|||
solution = await self.ensemble(solution_list, problem)
|
||||
return solution
|
||||
|
||||
# async def simple_ensemble(self, problem:str, ensemble_count:int = 3):
|
||||
async def simple_ensemble(self, problem:str, ensemble_count:int = 3):
|
||||
# async def __call__(self, problem:str, ensemble_count:int = 3):
|
||||
# solution_list = []
|
||||
# for _ in range(ensemble_count):
|
||||
# solution = await self.generate_code(problem)
|
||||
# # solution = await self.generate_code_block(problem)
|
||||
# solution = solution.get('code_solution')
|
||||
# solution_list.append(solution)
|
||||
# solution = await self.fuensemble(solution_list, problem)
|
||||
# return solution
|
||||
solution_list = []
|
||||
for _ in range(ensemble_count):
|
||||
solution = await self.generate_code(problem)
|
||||
# solution = await self.generate_code_block(problem)
|
||||
solution = solution.get('code_solution')
|
||||
solution_list.append(solution)
|
||||
solution = await self.fuensemble(solution_list, problem)
|
||||
return solution
|
||||
|
||||
async def single_solve(self, problem:str, max_loop:int):
|
||||
solution = await self.generate_code(problem)
|
||||
|
|
|
|||
|
|
@ -10,9 +10,11 @@ from collections import Counter
|
|||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.llm import LLM
|
||||
|
||||
from examples.ags.w_action_node.operator_an import GenerateOp, GenerateCodeOp, GenerateCodeBlockOp ,ReviewOp, ReviseOp, FuEnsembleOp, MdEnsembleOp
|
||||
from examples.ags.w_action_node.prompt import GENERATE_PROMPT, GENERATE_CODE_PROMPT, GENERATE_CODEBLOCK_PROMPT, REVIEW_PROMPT, REVISE_PROMPT, FU_ENSEMBLE_PROMPT, MD_ENSEMBLE_PROMPT, DE_ENSEMBLE_ANGEL_PROMPT, DE_ENSEMBLE_DEVIL_PROMPT, DE_ENSEMBLE_JUDGE_UNIVERSAL_PROMPT, DE_ENSEMBLE_JUDGE_FINAL_PROMPT
|
||||
from examples.ags.w_action_node.prompt import DE_ENSEMBLE_CODE_FORMAT_PROMPT, DE_ENSEMBLE_TXT_FORMAT_PROMPT
|
||||
from examples.ags.w_action_node.operator_an import GenerateOp, GenerateCodeOp, GenerateCodeBlockOp ,ReviewOp, ReviseOp, FuEnsembleOp, MdEnsembleOp, ReflectionTestOp, RephraseOp
|
||||
from examples.ags.w_action_node.prompt import GENERATE_PROMPT, GENERATE_CODE_PROMPT, GENERATE_CODEBLOCK_PROMPT, REVIEW_PROMPT, REVISE_PROMPT, FU_ENSEMBLE_PROMPT, MD_ENSEMBLE_PROMPT, REFLECTION_ON_PUBILIC_TEST_PROMPT, REPHRASE_ON_PROBLEM_PROMPT, GENERATE_CODEBLOCK_REPHRASE_PROMPT
|
||||
from examples.ags.w_action_node.prompt import DE_ENSEMBLE_CODE_FORMAT_PROMPT, DE_ENSEMBLE_TXT_FORMAT_PROMPT, DE_ENSEMBLE_ANGEL_PROMPT, DE_ENSEMBLE_DEVIL_PROMPT, DE_ENSEMBLE_JUDGE_UNIVERSAL_PROMPT, DE_ENSEMBLE_JUDGE_FINAL_PROMPT
|
||||
from examples.ags.w_action_node.utils import test_cases_2_test_functions
|
||||
|
||||
class Operator:
|
||||
def __init__(self, name, llm:LLM):
|
||||
self.name = name
|
||||
|
|
@ -47,9 +49,15 @@ class GenerateCodeBlock(Operator):
|
|||
def __init__(self, name:str ="Coder", llm: LLM = LLM()):
|
||||
super().__init__(name, llm)
|
||||
|
||||
async def __call__(self, problem_description):
|
||||
async def __call__(self, problem_description, function_name):
|
||||
prompt = GENERATE_CODEBLOCK_PROMPT.format(problem_description=problem_description)
|
||||
node = await ActionNode.from_pydantic(GenerateCodeBlockOp).fill(context=prompt, llm=self.llm,mode='code_fill')
|
||||
node = await ActionNode.from_pydantic(GenerateCodeBlockOp).fill(context=prompt, llm=self.llm, mode='code_fill',function_name=function_name)
|
||||
response = node.instruct_content.model_dump()
|
||||
return response
|
||||
|
||||
async def rephrase_generate(self, problem_description, rephrase_problem, function_name):
|
||||
prompt = GENERATE_CODEBLOCK_REPHRASE_PROMPT.format(problem_description=problem_description,rephrase_problem=rephrase_problem)
|
||||
node = await ActionNode.from_pydantic(GenerateCodeBlockOp).fill(context=prompt, llm=self.llm, mode='code_fill', function_name=function_name)
|
||||
response = node.instruct_content.model_dump()
|
||||
return response
|
||||
|
||||
|
|
@ -153,7 +161,7 @@ class MdEnsemble(Operator):
|
|||
most_frequent_index = Counter(all_responses).most_common(1)[0][0]
|
||||
print(f"most frequent_index: {most_frequent_index}")
|
||||
final_answer = solutions[most_frequent_index]
|
||||
print(f"final answer: {final_answer}")
|
||||
print(f"final answer: \n{final_answer}")
|
||||
# final_answer, frequency = self.most_frequent(all_responses)
|
||||
return {"final_solution": final_answer}
|
||||
|
||||
|
|
@ -293,23 +301,68 @@ class DbEnsemble(Operator):
|
|||
|
||||
class Rephrase(Operator):
|
||||
"""
|
||||
|
||||
https://arxiv.org/abs/2404.14963
|
||||
1. AlphaCodium
|
||||
2. https://arxiv.org/abs/2404.14963
|
||||
"""
|
||||
pass
|
||||
def __init__(self, name:str ="Rephraser", llm: LLM = LLM()):
|
||||
super().__init__(name, llm)
|
||||
|
||||
async def __call__(self, problem_description:str)->str:
|
||||
prompt = REPHRASE_ON_PROBLEM_PROMPT.format(problem_description=problem_description)
|
||||
node = await ActionNode.from_pydantic(RephraseOp).fill(context=prompt, llm=self.llm)
|
||||
response = node.instruct_content.model_dump()
|
||||
return response["rephrased_problem"]
|
||||
|
||||
class Test(Operator):
|
||||
def __init__(self, name:str ="Tester", llm: LLM = LLM()):
|
||||
super().__init__(name, llm)
|
||||
|
||||
def test_cases_2_assert(self, test_cases):
|
||||
return f"assert {test_cases[0]}({test_cases[1]}) == {test_cases[2]} \n"
|
||||
|
||||
def exec_code(self, solution, test_cases):
|
||||
solution = solution["final_solution"]
|
||||
pass_case = []
|
||||
fail_case = []
|
||||
for test_case in test_cases:
|
||||
test_code = test_cases_2_test_functions(solution,test_case)
|
||||
try:
|
||||
exec(test_code)
|
||||
pass_case.append(self.test_cases_2_assert(test_case))
|
||||
except AssertionError as e:
|
||||
fail_case.append(self.test_cases_2_assert(test_case))
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return {"error":e}
|
||||
if fail_case != []:
|
||||
return fail_case
|
||||
return []
|
||||
|
||||
async def __call__(self, problem, rephrase_problem, solution, test_cases):
|
||||
result = self.exec_code(solution, test_cases)
|
||||
# 处理通过Public Tests的代码
|
||||
# TODO 这里的问题是,如果Test直接通过了就没有办法Check Multi Tests了
|
||||
if result == []:
|
||||
return solution
|
||||
# 处理代码执行失败的代码
|
||||
elif type(result) == dict:
|
||||
result = result["error"]
|
||||
prompt = REFLECTION_ON_PUBILIC_TEST_PROMPT.format(problem_description=problem, rephrase_problem=rephrase_problem, code_solution=solution, exec_pass=f"executed unsuccessfully, error: \n {result}", test_fail="executed unsucessfully")
|
||||
node = await ActionNode.from_pydantic(ReflectionTestOp).fill(context=prompt, llm=self.llm)
|
||||
response = node.instruct_content.model_dump()
|
||||
return {"final_solution":response["refined_solution"]}
|
||||
else:
|
||||
prompt = REFLECTION_ON_PUBILIC_TEST_PROMPT.format(problem_description=problem, rephrase_problem=rephrase_problem, code_solution=solution, exec_pass="executed successfully", test_fail=result)
|
||||
node = await ActionNode.from_pydantic(ReflectionTestOp).fill(context=prompt, llm=self.llm)
|
||||
response = node.instruct_content.model_dump()
|
||||
return {"final_solution":response["refined_solution"]}
|
||||
|
||||
class FindFact(Operator):
|
||||
pass
|
||||
|
||||
class SelfAsk(Operator):
|
||||
pass
|
||||
|
||||
class CodeReflection(Operator):
|
||||
"""
|
||||
Interpreter Part
|
||||
We run code here to get error information.
|
||||
"""
|
||||
|
||||
class Verify(Operator):
|
||||
"""
|
||||
? 还没有想好
|
||||
|
|
|
|||
|
|
@ -42,3 +42,15 @@ class MdEnsembleOp(BaseModel):
|
|||
description="The letter of the chosen best solution (only one letter)."
|
||||
)
|
||||
|
||||
class TestCaseExtractOp(BaseModel):
|
||||
test_cases: list = Field(default=[('<function name>', [5, 8, 7, 1], 12), ('<function name>', [3, 3, 3, 3, 3], 9)],
|
||||
description="Extracted test cases from the problem description")
|
||||
|
||||
class RephraseOp(BaseModel):
|
||||
rephrased_problem: str = Field(default="", description="Rephrased problem description for this problem")
|
||||
|
||||
class ReflectionTestOp(BaseModel):
|
||||
reflection: str = Field(default="", description="对关于代码执行错误或者测试用例失败step by step的思考")
|
||||
refined_solution: str = Field(default="", description="对于代码执行错误或者测试用例失败的修正方案")
|
||||
|
||||
|
||||
|
|
@ -3,11 +3,6 @@
|
|||
# @Author : didi
|
||||
# @Desc : prompts of operators
|
||||
|
||||
# TODO PromptBreeder 评分是怎么做的?
|
||||
# TODO 评估案例 GSM-8K 直接拿的DataSet
|
||||
#
|
||||
#
|
||||
|
||||
GENERATE_PROMPT = """
|
||||
Generate Solution for the following problem: {problem_description}
|
||||
"""
|
||||
|
|
@ -35,21 +30,35 @@ You are an expert programmer tasked with solving a coding problem.
|
|||
The above is an incomplete Python code fragment. Return the complete and correct code with no additional text.
|
||||
Please maintain the JSON format in your response.
|
||||
### Your Response:
|
||||
|
||||
"""
|
||||
GENERATE_CODEBLOCK_PROMPT = """
|
||||
You are an expert programmer tasked with solving a coding problem.
|
||||
|
||||
# GENERATE_CODEBLOCK_PROMPT = """
|
||||
# You are an expert programmer tasked with solving a coding problem.
|
||||
|
||||
# ### Problem Description:
|
||||
# {problem_description}
|
||||
|
||||
# ### Instructions:
|
||||
# The above is an incomplete Python code fragment. Return the complete and correct code with no additional text.
|
||||
# """
|
||||
|
||||
GENERATE_CODEBLOCK_REPHRASE_PROMPT = """
|
||||
You are given a code contest problem, and a self-reflection on the problem:
|
||||
|
||||
### Problem Description:
|
||||
{problem_description}
|
||||
|
||||
### Instructions:
|
||||
The above is an incomplete Python code fragment. Return the complete and correct code with no additional text.
|
||||
### self reflection on the problem
|
||||
{rephrase_problem}
|
||||
|
||||
=======================
|
||||
The above is an incomplete Python code fragment and reflection on it. Return the complete and correct code with no additional text.
|
||||
"""
|
||||
|
||||
# GENERATE_CODE_PROMPT = """
|
||||
# Generate Code Solution for the following problem: {problem_description}
|
||||
# """
|
||||
GENERATE_CODEBLOCK_PROMPT = """
|
||||
Please provide a self-contained Python script that solves the following problem in a markdown code block:
|
||||
{problem_description}
|
||||
"""
|
||||
|
||||
REVIEW_PROMPT = """
|
||||
For the question described as {problem_description},
|
||||
|
|
@ -135,3 +144,55 @@ You, as the moderator, will evaluate both sides' answers and determine if there
|
|||
Please strictly output in JSON format, do not output irrelevant content
|
||||
"""
|
||||
|
||||
EXTRACT_CASE_PROMPT = """
|
||||
You are given a coding problem, and you need to extract the test cases from the problem description.
|
||||
{problem_description}
|
||||
|
||||
一个problem中会有多个测试用例,每个测试用例包含三个部分:
|
||||
1. 函数名
|
||||
2. 输入
|
||||
3. 期望输出
|
||||
每个测试用例包裹在一个三元组之中,三元组之间用逗号分隔,整体用列表包裹。
|
||||
由于结果需要被解析到JSON中,True与False请表示为true, false;
|
||||
"""
|
||||
|
||||
REPHRASE_ON_PROBLEM_PROMPT = """
|
||||
You are given a code contest problem:
|
||||
|
||||
### problem
|
||||
{problem_description}
|
||||
|
||||
### instrcutions
|
||||
Given the code contest problem, Your Goal is:
|
||||
Reflect on the problem, and describe it in your own words, in bullet points. Pay attention to small details, nuances, notes and examples in the problem description.
|
||||
|
||||
"""
|
||||
|
||||
REFLECTION_ON_PUBILIC_TEST_PROMPT = """
|
||||
|
||||
You are given a code contest problem, and a self-reflection on the problem:
|
||||
### problem
|
||||
{problem_description}
|
||||
|
||||
### self reflection on the problem
|
||||
{rephrase_problem}
|
||||
|
||||
=======================
|
||||
A Python code solution was generated for the problem:
|
||||
### Code Solution
|
||||
{code_solution}
|
||||
|
||||
=======================
|
||||
This section of the code execution result is
|
||||
### Execution Result
|
||||
{exec_pass}
|
||||
|
||||
=======================
|
||||
However, when running the following input example, the code solution above failed to produce the expected output:
|
||||
#### Failed Test Case
|
||||
{test_fail}
|
||||
|
||||
Your goal is to analyze the code solution and the error, and propose a fixed code which will produce the expected output for the provided test input.
|
||||
The fixed code should keep the solution robust, and work for all other input examples as well.
|
||||
Make sure the fixed code has a reasonable runtime - less than three seconds on a modern computer, given the problem constraints for large input.
|
||||
"""
|
||||
|
|
@ -5,7 +5,11 @@
|
|||
|
||||
import json
|
||||
import re
|
||||
from typing import List, Dict, Any
|
||||
from typing import List, Dict, Any, Tuple
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from examples.ags.w_action_node.operator_an import TestCaseExtractOp
|
||||
from examples.ags.w_action_node.prompt import EXTRACT_CASE_PROMPT
|
||||
|
||||
def extract_task_id(task_id: str) -> int:
|
||||
"""Extract the numeric part of the task_id."""
|
||||
|
|
@ -29,4 +33,168 @@ def jsonl_ranker(input_file: str, output_file: str):
|
|||
# Write the sorted data to a new JSONL file
|
||||
with open(output_file, 'w') as f:
|
||||
for item in sorted_data:
|
||||
f.write(json.dumps(item) + '\n')
|
||||
f.write(json.dumps(item) + '\n')
|
||||
|
||||
# def extract_test_cases_from_jsonl(problem_id:str, file_path:str="public_test.jsonl"):
|
||||
# # TODO 这个JSONL效率有点神经病
|
||||
# if problem_id == "Humaneval/87":
|
||||
# return [ ["get_row", [[[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 1, 6], [1, 2, 3, 4, 5, 1]], 1], [(0, 0), (1, 4), (1, 0), (2, 5), (2, 0)]], ["get_row", [[], 1], []], ["get_row", [[[], [1], [1, 2, 3]], 3], [(2, 2)]] ]
|
||||
# elif problem_id == "Humaneval/95":
|
||||
# return [ ["check_dict_case", [{"a": "apple", "b": "banana"}], True], ["check_dict_case", [{"a": "apple", "A": "banana", "B": "banana"}], False], ["check_dict_case", [{"a": "apple", "8": "banana", "a": "apple"}], False], ["check_dict_case", [{"Name": "John", "Age": "36", "City": "Houston"}], False], ["check_dict_case", [{"STATE": "NC", "ZIP": "12345"}], True] ]
|
||||
# elif problem_id == "Humaneval/107":
|
||||
# return [ ["even_odd_palindrome", [3], (1, 2)], ["even_odd_palindrome", [12], (4, 6)] ]
|
||||
# elif problem_id == "Humaneval/112":
|
||||
# return [ ["reverse_delete", ["abcde", "ae"], ("bcd", False)], ["reverse_delete", ["abcdef", "b"], ("acdef", False)], ["reverse_delete", ["abcdedcba", "ab"], ("cdedc", True)] ]
|
||||
# elif problem_id == "Humaneval/127":
|
||||
# return [ ["intersection", [(1, 2), (2, 3)], "NO"], ["intersection", [(-1, 1), (0, 4)], "NO"], ["intersection", [(-3, -1), (-5, 5)], "YES"] ]
|
||||
# elif problem_id == "Humaneval/136":
|
||||
# return [ ["largest_smallest_integers", [2, 4, 1, 3, 5, 7], (None, 1)], ["largest_smallest_integers", [], (None, None)], ["largest_smallest_integers", [0], (None, None)] ]
|
||||
# elif problem_id == "Humaneval/148":
|
||||
# return [ ["bf", ["Jupiter", "Neptune"], ("Saturn", "Uranus")], ["bf", ["Earth", "Mercury"], ("Venus",)], ["bf", ["Mercury", "Uranus"], ("Venus", "Earth", "Mars", "Jupiter", "Saturn")], ["bf", ["InvalidPlanet", "Neptune"], ()], ["bf", ["Jupiter", "InvalidPlanet"], ()], ["bf", ["Mercury", "Mercury"], ()] ]
|
||||
# elif problem_id == "Humaneval/155":
|
||||
# return [ ["even_odd_count", [-12], (1, 1)], ["even_odd_count", [123], (1, 2)] ]
|
||||
|
||||
# with open(file_path, 'r') as file:
|
||||
# for line in file:
|
||||
# data = json.loads(line)
|
||||
# if problem_id in data:
|
||||
# return data[problem_id]
|
||||
|
||||
# return None
|
||||
|
||||
import json
|
||||
import ast
|
||||
|
||||
def parse_python_literal(s):
|
||||
try:
|
||||
return ast.literal_eval(s)
|
||||
except (ValueError, SyntaxError):
|
||||
return s
|
||||
|
||||
def extract_test_cases_from_jsonl(problem_id:str, file_path:str="public_test.jsonl"):
|
||||
# 保留原有的硬编码测试用例
|
||||
hardcoded_cases = {
|
||||
"HumanEval/87": [ ["get_row", [[[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 1, 6], [1, 2, 3, 4, 5, 1]], 1], [(0, 0), (1, 4), (1, 0), (2, 5), (2, 0)]], ["get_row", [[], 1], []], ["get_row", [[[], [1], [1, 2, 3]], 3], [(2, 2)]] ],
|
||||
"HumanEval/95": [ ["check_dict_case", [{"a": "apple", "b": "banana"}], True], ["check_dict_case", [{"a": "apple", "A": "banana", "B": "banana"}], False], ["check_dict_case", [{"a": "apple", "8": "banana", "a": "apple"}], False], ["check_dict_case", [{"Name": "John", "Age": "36", "City": "Houston"}], False], ["check_dict_case", [{"STATE": "NC", "ZIP": "12345"}], True] ],
|
||||
"HumanEval/107": [ ["even_odd_palindrome", [3], (1, 2)], ["even_odd_palindrome", [12], (4, 6)] ],
|
||||
"HumanEval/112": [ ["reverse_delete", ["abcde", "ae"], ("bcd", False)], ["reverse_delete", ["abcdef", "b"], ("acdef", False)], ["reverse_delete", ["abcdedcba", "ab"], ("cdedc", True)] ],
|
||||
"HumanEval/127": [ ["intersection", [(1, 2), (2, 3)], "NO"], ["intersection", [(-1, 1), (0, 4)], "NO"], ["intersection", [(-3, -1), (-5, 5)], "YES"] ],
|
||||
"HumanEval/136": [ ["largest_smallest_integers", [2, 4, 1, 3, 5, 7], (None, 1)], ["largest_smallest_integers", [], (None, None)], ["largest_smallest_integers", [0], (None, None)] ],
|
||||
"HumanEval/148": [ ["bf", ["Jupiter", "Neptune"], ("Saturn", "Uranus")], ["bf", ["Earth", "Mercury"], ("Venus",)], ["bf", ["Mercury", "Uranus"], ("Venus", "Earth", "Mars", "Jupiter", "Saturn")], ["bf", ["InvalidPlanet", "Neptune"], ()], ["bf", ["Jupiter", "InvalidPlanet"], ()], ["bf", ["Mercury", "Mercury"], ()] ],
|
||||
"HumanEval/155": [ ["even_odd_count", [-12], (1, 1)], ["even_odd_count", [123], (1, 2)] ]
|
||||
}
|
||||
|
||||
# 检查是否有硬编码的测试用例
|
||||
if problem_id in hardcoded_cases:
|
||||
return hardcoded_cases[problem_id]
|
||||
|
||||
# 如果没有硬编码的测试用例,从文件中读取
|
||||
with open(file_path, 'r') as file:
|
||||
for line in file:
|
||||
data = json.loads(line)
|
||||
if problem_id in data:
|
||||
problem_data = data[problem_id]
|
||||
# 处理测试用例
|
||||
for i, test_case in enumerate(problem_data):
|
||||
# 函数名保持不变
|
||||
# 参数列表需要解析
|
||||
test_case[1] = [parse_python_literal(arg) for arg in test_case[1]]
|
||||
# 预期输出需要解析
|
||||
test_case[2] = parse_python_literal(test_case[2])
|
||||
return problem_data
|
||||
|
||||
return None # 如果没有找到问题,返回 None
|
||||
|
||||
def extract_test_cases(docstring: str) -> List[Tuple[str, List[Any], Any]]:
|
||||
# 使用正则表达式匹配测试用例,现在捕获函数名和任意输出
|
||||
pattern = r'>>> (\w+)\((.*?)\)\n\s*(.*?)(?=\n|$)'
|
||||
matches = re.findall(pattern, docstring, re.DOTALL)
|
||||
|
||||
test_cases = []
|
||||
for match in matches:
|
||||
func_name, input_str, expected_output = match
|
||||
|
||||
# 处理输入
|
||||
input_list = []
|
||||
for item in input_str.split(','):
|
||||
item = item.strip()
|
||||
try:
|
||||
# 尝试将输入转换为数值类型
|
||||
if '.' in item:
|
||||
input_list.append(float(item))
|
||||
else:
|
||||
input_list.append(int(item))
|
||||
except ValueError:
|
||||
# 如果无法转换为数值,则保留为字符串
|
||||
input_list.append(item.strip("'\""))
|
||||
|
||||
# 处理输出
|
||||
try:
|
||||
# 尝试将输出转换为数值或布尔值
|
||||
if expected_output.lower() == 'true':
|
||||
expected_output = True
|
||||
elif expected_output.lower() == 'false':
|
||||
expected_output = False
|
||||
elif '.' in expected_output:
|
||||
expected_output = float(expected_output)
|
||||
else:
|
||||
expected_output = int(expected_output)
|
||||
except ValueError:
|
||||
# 如果无法转换,则保留为字符串
|
||||
expected_output = expected_output.strip("'\"")
|
||||
|
||||
test_cases.append([func_name, input_list, expected_output])
|
||||
|
||||
return test_cases
|
||||
|
||||
|
||||
async def llm_extract_test_case(id, problem_description: str, file_path:str="public_test.jsonl"):
|
||||
prompt = EXTRACT_CASE_PROMPT.format(problem_description=problem_description)
|
||||
node = await ActionNode.from_pydantic(TestCaseExtractOp).fill(context=prompt, llm=LLM())
|
||||
result = node.instruct_content.model_dump()
|
||||
with open(file_path,"a") as f:
|
||||
f.write(json.dumps({id:result["test_cases"]}) + '\n')
|
||||
return {id:result["test_cases"]}
|
||||
|
||||
import json
|
||||
|
||||
def test_cases_2_test_functions(solution: str, test_case: List):
|
||||
function_name = test_case[0]
|
||||
|
||||
def format_param(param):
|
||||
if isinstance(param, str):
|
||||
return repr(param)
|
||||
elif isinstance(param, (int, float, bool)):
|
||||
return str(param)
|
||||
elif isinstance(param, list):
|
||||
return '[' + ', '.join(format_param(item) for item in param) + ']'
|
||||
elif isinstance(param, tuple):
|
||||
return '(' + ', '.join(format_param(item) for item in param) + ')'
|
||||
elif isinstance(param, dict):
|
||||
return '{' + ', '.join(f'{format_param(k)}: {format_param(v)}' for k, v in param.items()) + '}'
|
||||
elif isinstance(param, type(None)):
|
||||
return 'None'
|
||||
else:
|
||||
raise ValueError(f"Unsupported parameter type: {type(param)}")
|
||||
|
||||
parameters = ', '.join(format_param(item) for item in test_case[1])
|
||||
print(type(test_case[2]), test_case[2])
|
||||
expected_output = format_param(test_case[2])
|
||||
print(expected_output)
|
||||
|
||||
tester_function = f"""
|
||||
{solution}
|
||||
|
||||
def check(candidate):
|
||||
assert candidate({parameters}) == {expected_output}
|
||||
|
||||
check({function_name})
|
||||
"""
|
||||
|
||||
print(f"""
|
||||
Generated test function:
|
||||
{tester_function}
|
||||
""")
|
||||
|
||||
return tester_function
|
||||
|
||||
44
he_test.py
44
he_test.py
|
|
@ -1,20 +1,46 @@
|
|||
import asyncio
|
||||
import json
|
||||
from metagpt.llm import LLM
|
||||
from evalplus.data import get_human_eval_plus, write_jsonl
|
||||
from examples.ags.benchmark.humaneval import sample_generate, samples_generate, extract_failure_tests, automatic_evalplus
|
||||
from examples.ags.w_action_node.utils import jsonl_ranker
|
||||
|
||||
from examples.ags.w_action_node.utils import jsonl_ranker, llm_extract_test_case
|
||||
from examples.ags.w_action_node.graph import HumanEvalGraph
|
||||
# 132 141 136 80 73
|
||||
# asyncio.run(sample_generate('HumanEval/118',result_path="llm_based_4.jsonl",mode="llm"))
|
||||
# asyncio.run(samples_generate(mode='ags',result_path="ags_based_1.jsonl"))
|
||||
# asyncio.run(sample_generate('HumanEval/118',result_path="llm_based_8.jsonl",mode="llm"))
|
||||
asyncio.run(samples_generate(mode='llm',result_path="llm_based_100.jsonl"))
|
||||
# jsonl_ranker("samples.jsonl", "samples.jsonl")
|
||||
|
||||
result_path = "ags_based_2.jsonl"
|
||||
if automatic_evalplus(result_path):
|
||||
unpassed_exapmle = extract_failure_tests(result_path[:-6]+"_eval_results.json")
|
||||
print(unpassed_exapmle)
|
||||
# result_path = "ags_based_6.jsonl"
|
||||
# if automatic_evalplus(result_path):
|
||||
# unpassed_exapmle = extract_failure_tests(result_path[:-6]+"_eval_results.json")
|
||||
# print(unpassed_exapmle)
|
||||
|
||||
# unpassed_exapmle = extract_failure_tests(file_path="2_eval_results.json")
|
||||
# print(unpassed_exapmle)
|
||||
|
||||
# for example in failure_list:
|
||||
# asyncio.run(sample_generate(example))
|
||||
# asyncio.run(sample_generate(example))
|
||||
|
||||
# TODO 抽取Public Test没搞完,先用几个测试跑一下流程
|
||||
# from evalplus.data import get_human_eval_plus
|
||||
|
||||
# id_list = [87, 95, 107, 112, 127, 136, 148, 155]
|
||||
# id_list = [155]
|
||||
# cases_id = [f"HumanEval/{case_id}" for case_id in id_list]
|
||||
# cases = {case_id: get_human_eval_plus()[case_id]['prompt'] for case_id in cases_id}
|
||||
# async def main(cases):
|
||||
# try:
|
||||
# tasks = [llm_extract_test_case(case_id, case) for case_id, case in cases.items()]
|
||||
# results = await asyncio.gather(*tasks)
|
||||
# except:
|
||||
# failed_tasks = [task_id for task_id in results if task_id is not None]
|
||||
# print(failed_tasks)
|
||||
# return results
|
||||
|
||||
# asyncio.run(main(cases))
|
||||
|
||||
# [72, 80, 82, 87, 90, 95, 107, 109, 112, 124, 126, 127, 128, 132, 134, 136, 137, 138, 148, 154, 155]
|
||||
|
||||
# case_prompt= get_human_eval_plus()["HumanEval/136"]['prompt']
|
||||
# solver = HumanEvalGraph(name="solver", llm=LLM(), criteria='correctness, efficiency, readability', vote_count=1)
|
||||
# result = asyncio.run(solver.alpha_codium(problem_id="HumanEval/136", problem=case_prompt, ensemble_count=1))
|
||||
|
|
@ -17,6 +17,7 @@ from pydantic import BaseModel, Field, create_model, model_validator
|
|||
from tenacity import retry, stop_after_attempt, wait_random_exponential
|
||||
|
||||
from metagpt.actions.action_outcls_registry import register_action_outcls
|
||||
from metagpt.actions.code_sanitize import sanitize
|
||||
from metagpt.const import USE_CONFIG_TIMEOUT
|
||||
from metagpt.llm import BaseLLM
|
||||
from metagpt.logs import logger
|
||||
|
|
@ -484,6 +485,7 @@ class ActionNode:
|
|||
async def code_fill(
|
||||
self,
|
||||
context,
|
||||
function_name=None,
|
||||
timeout=USE_CONFIG_TIMEOUT
|
||||
):
|
||||
"""
|
||||
|
|
@ -510,10 +512,10 @@ class ActionNode:
|
|||
import re
|
||||
field_name = self.get_field_name()
|
||||
prompt = context
|
||||
# prompt += "\nPlease wrap the generated code within triple backticks, like this: ```<code>```"
|
||||
content = await self.llm.aask(prompt, timeout=timeout)
|
||||
|
||||
extracted_code = extract_code_from_response(content)
|
||||
# TODO 在前置逻辑中完成entrypoint的提取就可以
|
||||
extracted_code = sanitize(code=content, entrypoint=function_name)
|
||||
# extracted_code = extract_code_from_response(content)
|
||||
result = {field_name: extracted_code}
|
||||
return result
|
||||
|
||||
|
|
@ -536,6 +538,7 @@ class ActionNode:
|
|||
images: Optional[Union[str, list[str]]] = None,
|
||||
timeout=USE_CONFIG_TIMEOUT,
|
||||
exclude=[],
|
||||
function_name: str = None
|
||||
):
|
||||
"""Fill the node(s) with mode.
|
||||
|
||||
|
|
@ -563,7 +566,7 @@ class ActionNode:
|
|||
schema = self.schema
|
||||
|
||||
if mode == self.MODE_CODE_FILL:
|
||||
result = await self.code_fill(context, timeout)
|
||||
result = await self.code_fill(context, function_name, timeout)
|
||||
self.instruct_content = self.create_class()(**result)
|
||||
return self
|
||||
|
||||
|
|
|
|||
167
metagpt/actions/code_sanitize.py
Normal file
167
metagpt/actions/code_sanitize.py
Normal file
|
|
@ -0,0 +1,167 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
@Time : 2024/7/24 16:37
|
||||
@Author : didi
|
||||
@File : code_node.py
|
||||
"""
|
||||
import os
|
||||
import ast
|
||||
import pathlib
|
||||
import traceback
|
||||
|
||||
from typing import Dict, Generator, List, Optional, Set, Tuple
|
||||
|
||||
import tree_sitter_python
|
||||
from tqdm import tqdm
|
||||
from tree_sitter import Language, Node, Parser
|
||||
|
||||
CLASS_TYPE = "class_definition"
|
||||
FUNCTION_TYPE = "function_definition"
|
||||
IMPORT_TYPE = ["import_statement", "import_from_statement"]
|
||||
IDENTIFIER_TYPE = "identifier"
|
||||
ATTRIBUTE_TYPE = "attribute"
|
||||
RETURN_TYPE = "return_statement"
|
||||
EXPRESSION_TYPE = "expression_statement"
|
||||
ASSIGNMENT_TYPE = "assignment"
|
||||
|
||||
def traverse_tree(node: Node) -> Generator[Node, None, None]:
|
||||
cursor = node.walk()
|
||||
depth = 0
|
||||
|
||||
visited_children = False
|
||||
while True:
|
||||
if not visited_children:
|
||||
yield cursor.node
|
||||
if not cursor.goto_first_child():
|
||||
depth += 1
|
||||
visited_children = True
|
||||
elif cursor.goto_next_sibling():
|
||||
visited_children = False
|
||||
elif not cursor.goto_parent() or depth == 0:
|
||||
break
|
||||
else:
|
||||
depth -= 1
|
||||
|
||||
def syntax_check(code, verbose=False):
|
||||
try:
|
||||
ast.parse(code)
|
||||
return True
|
||||
except (SyntaxError, MemoryError):
|
||||
if verbose:
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
def code_extract(text: str) -> str:
|
||||
lines = text.split("\n")
|
||||
longest_line_pair = (0, 0)
|
||||
longest_so_far = 0
|
||||
|
||||
for i in range(len(lines)):
|
||||
for j in range(i + 1, len(lines)):
|
||||
current_lines = "\n".join(lines[i : j + 1])
|
||||
if syntax_check(current_lines):
|
||||
current_length = sum(1 for line in lines[i : j + 1] if line.strip())
|
||||
if current_length > longest_so_far:
|
||||
longest_so_far = current_length
|
||||
longest_line_pair = (i, j)
|
||||
|
||||
return "\n".join(lines[longest_line_pair[0] : longest_line_pair[1] + 1])
|
||||
|
||||
def get_definition_name(node: Node) -> str:
|
||||
for child in node.children:
|
||||
if child.type == IDENTIFIER_TYPE:
|
||||
return child.text.decode("utf8")
|
||||
|
||||
def has_return_statement(node: Node) -> bool:
|
||||
traverse_nodes = traverse_tree(node)
|
||||
for node in traverse_nodes:
|
||||
if node.type == RETURN_TYPE:
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_deps(nodes: List[Tuple[str, Node]]) -> Dict[str, Set[str]]:
|
||||
def dfs_get_deps(node: Node, deps: Set[str]) -> None:
|
||||
for child in node.children:
|
||||
if child.type == IDENTIFIER_TYPE:
|
||||
deps.add(child.text.decode("utf8"))
|
||||
else:
|
||||
dfs_get_deps(child, deps)
|
||||
|
||||
name2deps = {}
|
||||
for name, node in nodes:
|
||||
deps = set()
|
||||
dfs_get_deps(node, deps)
|
||||
name2deps[name] = deps
|
||||
return name2deps
|
||||
|
||||
|
||||
def get_function_dependency(entrypoint: str, call_graph: Dict[str, str]) -> Set[str]:
|
||||
queue = [entrypoint]
|
||||
visited = {entrypoint}
|
||||
while queue:
|
||||
current = queue.pop(0)
|
||||
if current not in call_graph:
|
||||
continue
|
||||
for neighbour in call_graph[current]:
|
||||
if not (neighbour in visited):
|
||||
visited.add(neighbour)
|
||||
queue.append(neighbour)
|
||||
return visited
|
||||
|
||||
def sanitize(code: str, entrypoint: Optional[str] = None) -> str:
|
||||
code = code_extract(code)
|
||||
code_bytes = bytes(code, "utf8")
|
||||
parser = Parser(Language(tree_sitter_python.language()))
|
||||
tree = parser.parse(code_bytes)
|
||||
class_names = set()
|
||||
function_names = set()
|
||||
variable_names = set()
|
||||
|
||||
root_node = tree.root_node
|
||||
import_nodes = []
|
||||
definition_nodes = []
|
||||
|
||||
for child in root_node.children:
|
||||
if child.type in IMPORT_TYPE:
|
||||
import_nodes.append(child)
|
||||
elif child.type == CLASS_TYPE:
|
||||
name = get_definition_name(child)
|
||||
if not (
|
||||
name in class_names or name in variable_names or name in function_names
|
||||
):
|
||||
definition_nodes.append((name, child))
|
||||
class_names.add(name)
|
||||
elif child.type == FUNCTION_TYPE:
|
||||
name = get_definition_name(child)
|
||||
if not (
|
||||
name in function_names or name in variable_names or name in class_names
|
||||
) and has_return_statement(child):
|
||||
definition_nodes.append((name, child))
|
||||
function_names.add(get_definition_name(child))
|
||||
elif (
|
||||
child.type == EXPRESSION_TYPE and child.children[0].type == ASSIGNMENT_TYPE
|
||||
):
|
||||
subchild = child.children[0]
|
||||
name = get_definition_name(subchild)
|
||||
if not (
|
||||
name in variable_names or name in function_names or name in class_names
|
||||
):
|
||||
definition_nodes.append((name, subchild))
|
||||
variable_names.add(name)
|
||||
|
||||
if entrypoint:
|
||||
name2deps = get_deps(definition_nodes)
|
||||
reacheable = get_function_dependency(entrypoint, name2deps)
|
||||
|
||||
sanitized_output = b""
|
||||
|
||||
for node in import_nodes:
|
||||
sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n"
|
||||
|
||||
for pair in definition_nodes:
|
||||
name, node = pair
|
||||
if entrypoint and not (name in reacheable):
|
||||
continue
|
||||
sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n"
|
||||
return sanitized_output[:-1].decode("utf8")
|
||||
Loading…
Add table
Add a link
Reference in a new issue