This commit is contained in:
didi 2024-07-25 10:47:17 +08:00
parent ca1c8f8c5c
commit 772d2aea56
9 changed files with 583 additions and 65 deletions

View file

@ -17,7 +17,6 @@ from examples.ags.w_action_node.operator import GenerateCode, GenerateCodeBlock
generate_code = GenerateCode(llm=LLM())
generate_code_block = GenerateCodeBlock(llm=LLM())
solver = HumanEvalGraph(name="solver", llm=LLM(), criteria='correctness, efficiency, readability', vote_count=5)
async def sample_generate(id, result_path:str="samples.jsonl",mode:str="ags"):
@ -25,7 +24,10 @@ async def sample_generate(id, result_path:str="samples.jsonl",mode:str="ags"):
if mode == "ags":
solution_result = await solver(case['prompt'],ensemble_count=5)
sample_dict = dict(task_id=case['task_id'], solution=solution_result['final_solution'])
else:
elif mode == "alpha":
solution_result = await solver.alpha_codium(case['task_id'], case['prompt'], ensemble_count=5)
sample_dict = dict(task_id=case['task_id'], solution=solution_result['final_solution'])
elif mode == "llm":
solution_result = await generate_code_block(case['prompt'])
sample_dict = dict(task_id=case['task_id'], solution=solution_result['code_solution'])
with open(result_path, mode='a') as f:
@ -39,7 +41,7 @@ async def samples_generate(mode:str, result_path:str="samples.jsonl"):
async def solve_and_write(case, mode):
try:
if mode == 'llm':
solution_result = await generate_code_block(case['prompt'])
solution_result = await generate_code_block(problem_description=case['prompt'], function_name=case['entry_point'])
# solution_result = await generate_code(case['prompt'])
sample_dict = {
'task_id': case['task_id'],
@ -51,7 +53,13 @@ async def samples_generate(mode:str, result_path:str="samples.jsonl"):
'task_id': case['task_id'],
'solution': solution_result['final_solution']
}
elif mode == "alpha":
solution_result = await solver.alpha_codium(case['task_id'], case['prompt'], ensemble_count=5)
sample_dict = {
'task_id': case['task_id'],
'solution': solution_result['final_solution']
}
# TODO 解决 final_solution 问题之后就可以开始正式测评了
async with file_lock:
async with aiofiles.open(result_path, mode='a') as f:
await f.write(json.dumps(sample_dict) + '\n')
@ -65,7 +73,6 @@ async def samples_generate(mode:str, result_path:str="samples.jsonl"):
results = await asyncio.gather(*tasks)
failed_tasks = [task_id for task_id in results if task_id is not None]
# TODO 这个地方还是不够自动化
if failed_tasks:
print(failed_tasks)
if mode == 'llm':
@ -73,7 +80,7 @@ async def samples_generate(mode:str, result_path:str="samples.jsonl"):
case = get_human_eval_plus()[task_id]
for _ in range(3):
try:
solution_result = await generate_code_block(case['prompt'])
solution_result = await generate_code_block(case['prompt'],function_name=case['entry_point'])
task_dict = {
'task_id': case['task_id'],
'solution': solution_result['code_solution']
@ -84,17 +91,18 @@ async def samples_generate(mode:str, result_path:str="samples.jsonl"):
break
except Exception as e:
print(f"{e} \n failure {task_id}")
elif mode == "ags":
elif mode == "ags" or mode == "alpha":
for task_id in failed_tasks:
try:
await sample_generate(task_id,result_path)
await sample_generate(task_id,result_path,mode)
except Exception as e:
print(f"failure {task_id}")
jsonl_ranker(result_path, result_path)
if not failed_tasks:
# 自动 sanitize
result_path = automatic_sanitize(result_path)
# result_path = automatic_sanitize(result_path)
if automatic_evalplus(result_path):
eval_path = result_path[:-6]+"_eval_results.json"
unpassed_exapmle = extract_failure_tests(eval_path)
@ -107,7 +115,7 @@ async def samples_generate_ags():
cases = list(get_human_eval_plus().values())
async def solve_with_id(case):
solution_result = await solver(case['prompt'], ensemble_count=3)
solution_result = await solver(case['prompt'], ensemble_count=5)
return case['task_id'], solution_result['final_solution']
tasks = [solve_with_id(case) for case in cases]

View file

@ -5,8 +5,8 @@
from metagpt.llm import LLM
from typing import List
from examples.ags.w_action_node.operator import Generate, GenerateCode, GenerateCodeBlock, Review, Revise, FuEnsemble, MdEnsemble, DbEnsemble
from examples.ags.w_action_node.operator import Generate, GenerateCode, GenerateCodeBlock, Review, Revise, FuEnsemble, MdEnsemble, DbEnsemble, Rephrase, Test
from examples.ags.w_action_node.utils import extract_test_cases_from_jsonl
class Graph:
def __init__(self, name:str, llm:LLM) -> None:
self.name = name
@ -26,6 +26,8 @@ class HumanEvalGraph(Graph):
self.generate_code_block = GenerateCodeBlock(llm=llm)
self.review = Review(llm=llm, criteria=criteria)
self.revise = Revise(llm=llm)
self.rephrase = Rephrase(llm=llm)
self.tester = Test(llm=llm)
self.fuensemble = FuEnsemble(llm=llm)
self.mdensemble = MdEnsemble(llm=llm, vote_count=vote_count)
@ -41,10 +43,28 @@ class HumanEvalGraph(Graph):
break
except Exception as e:
print(e)
# solution list 有5个
solution = await self.mdensemble("code", solution_list, problem)
return solution
async def alpha_codium(self, problem_id:str, problem:str, ensemble_count:int = 3):
    """
    Solve `problem` with a rephrase -> sample -> vote -> test pipeline.

    Looks up the test cases stored for `problem_id`, rephrases the problem,
    collects up to `ensemble_count` candidate code solutions, selects one
    through `self.mdensemble`, and returns whatever `self.tester` returns
    for that selection.
    """
    # async def __call__(self,problem_id, problem:str, ensemble_count:int = 3):
    test_cases = extract_test_cases_from_jsonl(problem_id)
    rephrase_problem = await self.rephrase(problem) # the original problem description is concatenated inside rephrase
    solution_list = []
    for _ in range(ensemble_count):
        # Up to 5 attempts per candidate; a failed attempt is only printed,
        # so solution_list can end up shorter than ensemble_count.
        for retry_count in range(5):
            try:
                # NOTE(review): the second positional argument here is the
                # rephrased problem, while GenerateCodeBlock.__call__ names its
                # second parameter `function_name` — confirm whether
                # `rephrase_generate` was the intended call.
                solution = await self.generate_code_block(problem, rephrase_problem)
                solution = solution.get('code_solution')
                solution_list.append(solution)
                break
            except Exception as e:
                print(e)
    solution = await self.mdensemble("code", solution_list, problem)
    solution = await self.tester(problem, rephrase_problem, solution, test_cases)
    return solution
async def review_revise_ensemble(self, problem:str, ensemble_count:int = 2):
solution_list = []
for _ in range(ensemble_count):
@ -53,16 +73,16 @@ class HumanEvalGraph(Graph):
solution = await self.ensemble(solution_list, problem)
return solution
# async def simple_ensemble(self, problem:str, ensemble_count:int = 3):
async def simple_ensemble(self, problem:str, ensemble_count:int = 3):
# async def __call__(self, problem:str, ensemble_count:int = 3):
# solution_list = []
# for _ in range(ensemble_count):
# solution = await self.generate_code(problem)
# # solution = await self.generate_code_block(problem)
# solution = solution.get('code_solution')
# solution_list.append(solution)
# solution = await self.fuensemble(solution_list, problem)
# return solution
solution_list = []
for _ in range(ensemble_count):
solution = await self.generate_code(problem)
# solution = await self.generate_code_block(problem)
solution = solution.get('code_solution')
solution_list.append(solution)
solution = await self.fuensemble(solution_list, problem)
return solution
async def single_solve(self, problem:str, max_loop:int):
solution = await self.generate_code(problem)

View file

@ -10,9 +10,11 @@ from collections import Counter
from metagpt.actions.action_node import ActionNode
from metagpt.llm import LLM
from examples.ags.w_action_node.operator_an import GenerateOp, GenerateCodeOp, GenerateCodeBlockOp ,ReviewOp, ReviseOp, FuEnsembleOp, MdEnsembleOp
from examples.ags.w_action_node.prompt import GENERATE_PROMPT, GENERATE_CODE_PROMPT, GENERATE_CODEBLOCK_PROMPT, REVIEW_PROMPT, REVISE_PROMPT, FU_ENSEMBLE_PROMPT, MD_ENSEMBLE_PROMPT, DE_ENSEMBLE_ANGEL_PROMPT, DE_ENSEMBLE_DEVIL_PROMPT, DE_ENSEMBLE_JUDGE_UNIVERSAL_PROMPT, DE_ENSEMBLE_JUDGE_FINAL_PROMPT
from examples.ags.w_action_node.prompt import DE_ENSEMBLE_CODE_FORMAT_PROMPT, DE_ENSEMBLE_TXT_FORMAT_PROMPT
from examples.ags.w_action_node.operator_an import GenerateOp, GenerateCodeOp, GenerateCodeBlockOp ,ReviewOp, ReviseOp, FuEnsembleOp, MdEnsembleOp, ReflectionTestOp, RephraseOp
from examples.ags.w_action_node.prompt import GENERATE_PROMPT, GENERATE_CODE_PROMPT, GENERATE_CODEBLOCK_PROMPT, REVIEW_PROMPT, REVISE_PROMPT, FU_ENSEMBLE_PROMPT, MD_ENSEMBLE_PROMPT, REFLECTION_ON_PUBILIC_TEST_PROMPT, REPHRASE_ON_PROBLEM_PROMPT, GENERATE_CODEBLOCK_REPHRASE_PROMPT
from examples.ags.w_action_node.prompt import DE_ENSEMBLE_CODE_FORMAT_PROMPT, DE_ENSEMBLE_TXT_FORMAT_PROMPT, DE_ENSEMBLE_ANGEL_PROMPT, DE_ENSEMBLE_DEVIL_PROMPT, DE_ENSEMBLE_JUDGE_UNIVERSAL_PROMPT, DE_ENSEMBLE_JUDGE_FINAL_PROMPT
from examples.ags.w_action_node.utils import test_cases_2_test_functions
class Operator:
def __init__(self, name, llm:LLM):
self.name = name
@ -47,9 +49,15 @@ class GenerateCodeBlock(Operator):
def __init__(self, name:str ="Coder", llm: LLM = LLM()):
super().__init__(name, llm)
async def __call__(self, problem_description):
async def __call__(self, problem_description, function_name):
prompt = GENERATE_CODEBLOCK_PROMPT.format(problem_description=problem_description)
node = await ActionNode.from_pydantic(GenerateCodeBlockOp).fill(context=prompt, llm=self.llm,mode='code_fill')
node = await ActionNode.from_pydantic(GenerateCodeBlockOp).fill(context=prompt, llm=self.llm, mode='code_fill',function_name=function_name)
response = node.instruct_content.model_dump()
return response
async def rephrase_generate(self, problem_description, rephrase_problem, function_name):
    """
    Generate a code block from the problem plus its rephrasing.

    Fills GENERATE_CODEBLOCK_REPHRASE_PROMPT with both texts, runs a
    GenerateCodeBlockOp action node in 'code_fill' mode and returns the
    dumped instruct content.
    """
    filled_prompt = GENERATE_CODEBLOCK_REPHRASE_PROMPT.format(
        problem_description=problem_description,
        rephrase_problem=rephrase_problem,
    )
    action_node = ActionNode.from_pydantic(GenerateCodeBlockOp)
    filled_node = await action_node.fill(
        context=filled_prompt,
        llm=self.llm,
        mode='code_fill',
        function_name=function_name,
    )
    return filled_node.instruct_content.model_dump()
@ -153,7 +161,7 @@ class MdEnsemble(Operator):
most_frequent_index = Counter(all_responses).most_common(1)[0][0]
print(f"most frequent_index: {most_frequent_index}")
final_answer = solutions[most_frequent_index]
print(f"final answer: {final_answer}")
print(f"final answer: \n{final_answer}")
# final_answer, frequency = self.most_frequent(all_responses)
return {"final_solution": final_answer}
@ -293,23 +301,68 @@ class DbEnsemble(Operator):
class Rephrase(Operator):
    """
    Operator that asks the LLM to restate a problem description.

    References:
    1. AlphaCodium
    2. https://arxiv.org/abs/2404.14963
    """

    def __init__(self, name:str ="Rephraser", llm: LLM = LLM()):
        super().__init__(name, llm)

    async def __call__(self, problem_description:str)->str:
        """Fill REPHRASE_ON_PROBLEM_PROMPT and return the `rephrased_problem` field of the result."""
        filled_prompt = REPHRASE_ON_PROBLEM_PROMPT.format(problem_description=problem_description)
        action_node = ActionNode.from_pydantic(RephraseOp)
        filled_node = await action_node.fill(context=filled_prompt, llm=self.llm)
        return filled_node.instruct_content.model_dump()["rephrased_problem"]
class Test(Operator):
    """
    Run a candidate solution against public test cases and, when it fails,
    ask the LLM for a refined solution.
    """

    def __init__(self, name:str ="Tester", llm: LLM = LLM()):
        super().__init__(name, llm)

    def test_cases_2_assert(self, test_cases):
        """
        Render one test case ``[function_name, args, expected]`` as an assert line.

        The arguments are spread into the call (``f(1, 2)``, not ``f([1, 2])``)
        and every value is rendered with ``repr`` so strings stay quoted,
        matching the call that is actually executed.
        """
        function_name, args, expected = test_cases[0], test_cases[1], test_cases[2]
        if isinstance(args, (list, tuple)):
            rendered_args = ", ".join(repr(arg) for arg in args)
        else:
            # Tolerate a single bare argument instead of an argument list.
            rendered_args = repr(args)
        return f"assert {function_name}({rendered_args}) == {expected!r} \n"

    def exec_code(self, solution, test_cases):
        """
        Execute the solution against every test case.

        Args:
            solution: dict holding the code under the "final_solution" key.
            test_cases: iterable of ``[function_name, args, expected]``; ``None``
                is treated as "no test cases".

        Returns:
            ``[]`` when nothing failed, a list of assert strings for the test
            cases that raised AssertionError, or ``{"error": exc}`` for the
            first exception of any other kind.
        """
        solution = solution["final_solution"]
        fail_case = []
        for test_case in test_cases or []:
            test_code = test_cases_2_test_functions(solution, test_case)
            try:
                # SECURITY: this runs model-generated code inside the current
                # process; it is not sandboxed.
                # A dedicated globals dict is required: with a bare exec() the
                # definitions land in this method's locals, so functions in the
                # solution cannot see their own imports, helpers or themselves.
                exec(test_code, {"__name__": "__public_test__"})
            except AssertionError:
                fail_case.append(self.test_cases_2_assert(test_case))
            except Exception as e:
                print(e)
                return {"error": e}
        return fail_case

    async def __call__(self, problem, rephrase_problem, solution, test_cases):
        """
        Return `solution` unchanged when it passes, otherwise a dict
        ``{"final_solution": <refined code>}`` produced by the LLM.
        """
        result = self.exec_code(solution, test_cases)
        # Solution passed the public tests.
        # TODO: when the tests pass directly there is no way to check multiple tests.
        if isinstance(result, list) and not result:
            return solution
        if isinstance(result, dict):
            # The code itself failed to execute.
            exec_pass = f"executed unsuccessfully, error: \n {result['error']}"
            test_fail = "executed unsucessfully"
        else:
            # The code ran but some assertions failed; show them as plain text.
            exec_pass = "executed successfully"
            test_fail = "".join(result)
        prompt = REFLECTION_ON_PUBILIC_TEST_PROMPT.format(
            problem_description=problem,
            rephrase_problem=rephrase_problem,
            # Pass the code text, not the wrapping dict.
            code_solution=solution["final_solution"],
            exec_pass=exec_pass,
            test_fail=test_fail,
        )
        node = await ActionNode.from_pydantic(ReflectionTestOp).fill(context=prompt, llm=self.llm)
        response = node.instruct_content.model_dump()
        return {"final_solution": response["refined_solution"]}
class FindFact(Operator):
pass
class SelfAsk(Operator):
pass
class CodeReflection(Operator):
"""
Interpreter Part
We run code here to get error information.
"""
class Verify(Operator):
"""
? 还没有想好

View file

@ -42,3 +42,15 @@ class MdEnsembleOp(BaseModel):
description="The letter of the chosen best solution (only one letter)."
)
# Structured-output schema for test cases extracted from a problem description.
class TestCaseExtractOp(BaseModel):
    # The default shows the expected shape: a list of 3-item entries.
    # presumably (function name, inputs, expected output) — confirm against EXTRACT_CASE_PROMPT
    test_cases: list = Field(default=[('<function name>', [5, 8, 7, 1], 12), ('<function name>', [3, 3, 3, 3, 3], 9)],
                             description="Extracted test cases from the problem description")
# Structured-output schema holding a single rephrased problem statement.
class RephraseOp(BaseModel):
    rephrased_problem: str = Field(default="", description="Rephrased problem description for this problem")
# Structured-output schema for reflecting on a failed run and proposing a fix.
class ReflectionTestOp(BaseModel):
    # description (Chinese): step-by-step thinking about the code execution error or test-case failure
    reflection: str = Field(default="", description="对关于代码执行错误或者测试用例失败step by step的思考")
    # description (Chinese): the corrected solution for the code execution error or test-case failure
    refined_solution: str = Field(default="", description="对于代码执行错误或者测试用例失败的修正方案")

View file

@ -3,11 +3,6 @@
# @Author : didi
# @Desc : prompts of operators
# TODO PromptBreeder 评分是怎么做的?
# TODO 评估案例 GSM-8K 直接拿的DataSet
#
#
GENERATE_PROMPT = """
Generate Solution for the following problem: {problem_description}
"""
@ -35,21 +30,35 @@ You are an expert programmer tasked with solving a coding problem.
The above is an incomplete Python code fragment. Return the complete and correct code with no additional text.
Please maintain the JSON format in your response.
### Your Response:
"""
GENERATE_CODEBLOCK_PROMPT = """
You are an expert programmer tasked with solving a coding problem.
# GENERATE_CODEBLOCK_PROMPT = """
# You are an expert programmer tasked with solving a coding problem.
# ### Problem Description:
# {problem_description}
# ### Instructions:
# The above is an incomplete Python code fragment. Return the complete and correct code with no additional text.
# """
GENERATE_CODEBLOCK_REPHRASE_PROMPT = """
You are given a code contest problem, and a self-reflection on the problem:
### Problem Description:
{problem_description}
### Instructions:
The above is an incomplete Python code fragment. Return the complete and correct code with no additional text.
### self reflection on the problem
{rephrase_problem}
=======================
The above is an incomplete Python code fragment and reflection on it. Return the complete and correct code with no additional text.
"""
# GENERATE_CODE_PROMPT = """
# Generate Code Solution for the following problem: {problem_description}
# """
GENERATE_CODEBLOCK_PROMPT = """
Please provide a self-contained Python script that solves the following problem in a markdown code block:
{problem_description}
"""
REVIEW_PROMPT = """
For the question described as {problem_description},
@ -135,3 +144,55 @@ You, as the moderator, will evaluate both sides' answers and determine if there
Please strictly output in JSON format, do not output irrelevant content
"""
# Prompt template (placeholder: {problem_description}) asking the model to
# extract test cases from a problem description.
# The Chinese part of the prompt says: a problem contains several test cases,
# each made of three parts (1. function name, 2. input, 3. expected output);
# each test case is wrapped in a triple, triples are comma-separated and the
# whole is wrapped in a list; because the result is parsed as JSON, True and
# False must be written as true and false.
EXTRACT_CASE_PROMPT = """
You are given a coding problem, and you need to extract the test cases from the problem description.
{problem_description}
一个problem中会有多个测试用例每个测试用例包含三个部分
1. 函数名
2. 输入
3. 期望输出
每个测试用例包裹在一个三元组之中三元组之间用逗号分隔整体用列表包裹
由于结果需要被解析到JSON中True与False请表示为true, false;
"""
# Prompt template (placeholder: {problem_description}) asking the model to
# restate a contest problem in its own words as bullet points.
# NOTE(review): the section header below is misspelled ("instrcutions"); it is
# part of the text sent to the model, so it is left untouched here.
REPHRASE_ON_PROBLEM_PROMPT = """
You are given a code contest problem:
### problem
{problem_description}
### instrcutions
Given the code contest problem, Your Goal is:
Reflect on the problem, and describe it in your own words, in bullet points. Pay attention to small details, nuances, notes and examples in the problem description.
"""
# Prompt template asking the model to analyse a failing solution and propose
# fixed code. Placeholders: {problem_description}, {rephrase_problem},
# {code_solution}, {exec_pass} (execution result text) and {test_fail}
# (the failed test case text).
# NOTE(review): "PUBILIC" in the constant name looks like a misspelling of
# "PUBLIC"; the name is imported by other modules, so it is kept as is.
REFLECTION_ON_PUBILIC_TEST_PROMPT = """
You are given a code contest problem, and a self-reflection on the problem:
### problem
{problem_description}
### self reflection on the problem
{rephrase_problem}
=======================
A Python code solution was generated for the problem:
### Code Solution
{code_solution}
=======================
This section of the code execution result is
### Execution Result
{exec_pass}
=======================
However, when running the following input example, the code solution above failed to produce the expected output:
#### Failed Test Case
{test_fail}
Your goal is to analyze the code solution and the error, and propose a fixed code which will produce the expected output for the provided test input.
The fixed code should keep the solution robust, and work for all other input examples as well.
Make sure the fixed code has a reasonable runtime - less than three seconds on a modern computer, given the problem constraints for large input.
"""

View file

@ -5,7 +5,11 @@
import json
import re
from typing import List, Dict, Any
from typing import List, Dict, Any, Tuple
from metagpt.llm import LLM
from metagpt.actions.action_node import ActionNode
from examples.ags.w_action_node.operator_an import TestCaseExtractOp
from examples.ags.w_action_node.prompt import EXTRACT_CASE_PROMPT
def extract_task_id(task_id: str) -> int:
"""Extract the numeric part of the task_id."""
@ -29,4 +33,168 @@ def jsonl_ranker(input_file: str, output_file: str):
# Write the sorted data to a new JSONL file
with open(output_file, 'w') as f:
for item in sorted_data:
f.write(json.dumps(item) + '\n')
f.write(json.dumps(item) + '\n')
# def extract_test_cases_from_jsonl(problem_id:str, file_path:str="public_test.jsonl"):
# # TODO 这个JSONL效率有点神经病
# if problem_id == "Humaneval/87":
# return [ ["get_row", [[[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 1, 6], [1, 2, 3, 4, 5, 1]], 1], [(0, 0), (1, 4), (1, 0), (2, 5), (2, 0)]], ["get_row", [[], 1], []], ["get_row", [[[], [1], [1, 2, 3]], 3], [(2, 2)]] ]
# elif problem_id == "Humaneval/95":
# return [ ["check_dict_case", [{"a": "apple", "b": "banana"}], True], ["check_dict_case", [{"a": "apple", "A": "banana", "B": "banana"}], False], ["check_dict_case", [{"a": "apple", "8": "banana", "a": "apple"}], False], ["check_dict_case", [{"Name": "John", "Age": "36", "City": "Houston"}], False], ["check_dict_case", [{"STATE": "NC", "ZIP": "12345"}], True] ]
# elif problem_id == "Humaneval/107":
# return [ ["even_odd_palindrome", [3], (1, 2)], ["even_odd_palindrome", [12], (4, 6)] ]
# elif problem_id == "Humaneval/112":
# return [ ["reverse_delete", ["abcde", "ae"], ("bcd", False)], ["reverse_delete", ["abcdef", "b"], ("acdef", False)], ["reverse_delete", ["abcdedcba", "ab"], ("cdedc", True)] ]
# elif problem_id == "Humaneval/127":
# return [ ["intersection", [(1, 2), (2, 3)], "NO"], ["intersection", [(-1, 1), (0, 4)], "NO"], ["intersection", [(-3, -1), (-5, 5)], "YES"] ]
# elif problem_id == "Humaneval/136":
# return [ ["largest_smallest_integers", [2, 4, 1, 3, 5, 7], (None, 1)], ["largest_smallest_integers", [], (None, None)], ["largest_smallest_integers", [0], (None, None)] ]
# elif problem_id == "Humaneval/148":
# return [ ["bf", ["Jupiter", "Neptune"], ("Saturn", "Uranus")], ["bf", ["Earth", "Mercury"], ("Venus",)], ["bf", ["Mercury", "Uranus"], ("Venus", "Earth", "Mars", "Jupiter", "Saturn")], ["bf", ["InvalidPlanet", "Neptune"], ()], ["bf", ["Jupiter", "InvalidPlanet"], ()], ["bf", ["Mercury", "Mercury"], ()] ]
# elif problem_id == "Humaneval/155":
# return [ ["even_odd_count", [-12], (1, 1)], ["even_odd_count", [123], (1, 2)] ]
# with open(file_path, 'r') as file:
# for line in file:
# data = json.loads(line)
# if problem_id in data:
# return data[problem_id]
# return None
import json
import ast
def parse_python_literal(s):
    """
    Parse `s` as a Python literal, falling back to `s` itself.

    Args:
        s: usually a string such as "[1, 2]" or "(1, 'a')"; any other value
            is returned unchanged because it cannot be parsed.

    Returns:
        The evaluated literal, or `s` unchanged when it is not a valid literal.
    """
    try:
        return ast.literal_eval(s)
    except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
        # literal_eval raises TypeError for e.g. an unhashable dict key
        # ("{[1]: 2}") and Memory/RecursionError for pathological nesting;
        # none of those should abort test-case loading.
        return s
def extract_test_cases_from_jsonl(problem_id:str, file_path:str="public_test.jsonl"):
    """
    Return the public test cases for `problem_id`.

    Hard-coded cases take precedence; otherwise the JSONL file is scanned for
    the first line whose object contains `problem_id`.

    Args:
        problem_id: key such as "HumanEval/87".
        file_path: JSONL file where each line is ``{problem_id: [[name, args, expected], ...]}``.

    Returns:
        A list of ``[function_name, args, expected]`` entries, or None when the
        problem is not found.

    Raises:
        FileNotFoundError: if `file_path` does not exist and the problem is not hard-coded.
    """
    # Keep the original hard-coded test cases.
    # NOTE(review): the third "HumanEval/95" case repeats the key "a", so the
    # dict literal collapses to two entries.
    hardcoded_cases = {
        "HumanEval/87": [ ["get_row", [[[1, 2, 3, 4, 5, 6], [1, 2, 3, 4, 1, 6], [1, 2, 3, 4, 5, 1]], 1], [(0, 0), (1, 4), (1, 0), (2, 5), (2, 0)]], ["get_row", [[], 1], []], ["get_row", [[[], [1], [1, 2, 3]], 3], [(2, 2)]] ],
        "HumanEval/95": [ ["check_dict_case", [{"a": "apple", "b": "banana"}], True], ["check_dict_case", [{"a": "apple", "A": "banana", "B": "banana"}], False], ["check_dict_case", [{"a": "apple", "8": "banana", "a": "apple"}], False], ["check_dict_case", [{"Name": "John", "Age": "36", "City": "Houston"}], False], ["check_dict_case", [{"STATE": "NC", "ZIP": "12345"}], True] ],
        "HumanEval/107": [ ["even_odd_palindrome", [3], (1, 2)], ["even_odd_palindrome", [12], (4, 6)] ],
        "HumanEval/112": [ ["reverse_delete", ["abcde", "ae"], ("bcd", False)], ["reverse_delete", ["abcdef", "b"], ("acdef", False)], ["reverse_delete", ["abcdedcba", "ab"], ("cdedc", True)] ],
        "HumanEval/127": [ ["intersection", [(1, 2), (2, 3)], "NO"], ["intersection", [(-1, 1), (0, 4)], "NO"], ["intersection", [(-3, -1), (-5, 5)], "YES"] ],
        "HumanEval/136": [ ["largest_smallest_integers", [2, 4, 1, 3, 5, 7], (None, 1)], ["largest_smallest_integers", [], (None, None)], ["largest_smallest_integers", [0], (None, None)] ],
        "HumanEval/148": [ ["bf", ["Jupiter", "Neptune"], ("Saturn", "Uranus")], ["bf", ["Earth", "Mercury"], ("Venus",)], ["bf", ["Mercury", "Uranus"], ("Venus", "Earth", "Mars", "Jupiter", "Saturn")], ["bf", ["InvalidPlanet", "Neptune"], ()], ["bf", ["Jupiter", "InvalidPlanet"], ()], ["bf", ["Mercury", "Mercury"], ()] ],
        "HumanEval/155": [ ["even_odd_count", [-12], (1, 1)], ["even_odd_count", [123], (1, 2)] ]
    }
    # Use the hard-coded test cases when available.
    if problem_id in hardcoded_cases:
        return hardcoded_cases[problem_id]

    # Otherwise read them from the file.
    with open(file_path, 'r', encoding='utf-8') as file:
        for line in file:
            # Blank lines (e.g. a trailing newline) are not valid JSON; skip them.
            if not line.strip():
                continue
            data = json.loads(line)
            if problem_id not in data:
                continue
            problem_data = data[problem_id]
            for test_case in problem_data:
                # The function name stays as is; arguments and the expected
                # output are stored as text and parsed back into Python values.
                test_case[1] = [parse_python_literal(arg) for arg in test_case[1]]
                test_case[2] = parse_python_literal(test_case[2])
            return problem_data
    return None  # problem not found
def extract_test_cases(docstring: str) -> List[Tuple[str, List[Any], Any]]:
    """
    Extract doctest-style examples (``>>> f(args)`` followed by the output line).

    Arguments and expected output are parsed as Python literals, so lists,
    tuples, dicts and strings containing commas are preserved. Text that is
    not a valid literal falls back to the older heuristic: numeric conversion,
    case-insensitive true/false for the output, otherwise a string with the
    surrounding quotes stripped.

    Returns:
        A list of ``[function_name, args, expected_output]`` entries.
    """
    # Capture the function name, the raw argument text and the output line.
    pattern = r'>>> (\w+)\((.*?)\)\n\s*(.*?)(?=\n|$)'
    literal_errors = (ValueError, TypeError, SyntaxError, MemoryError, RecursionError)

    def parse_arguments(input_str: str) -> List[Any]:
        if not input_str.strip():
            return []  # call without arguments
        try:
            # Evaluate the whole argument list as one tuple so commas inside
            # nested literals are not treated as argument separators.
            return list(ast.literal_eval(f"({input_str},)"))
        except literal_errors:
            pass
        values = []
        for item in input_str.split(','):
            item = item.strip()
            try:
                values.append(float(item) if '.' in item else int(item))
            except ValueError:
                # Not a number: keep it as a string.
                values.append(item.strip("'\""))
        return values

    def parse_expected(text: str) -> Any:
        try:
            return ast.literal_eval(text)
        except literal_errors:
            pass
        lowered = text.lower()
        if lowered == 'true':
            return True
        if lowered == 'false':
            return False
        try:
            return float(text) if '.' in text else int(text)
        except ValueError:
            # Not convertible: keep it as a string.
            return text.strip("'\"")

    return [
        [func_name, parse_arguments(input_str), parse_expected(expected_output)]
        for func_name, input_str, expected_output in re.findall(pattern, docstring, re.DOTALL)
    ]
async def llm_extract_test_case(id, problem_description: str, file_path:str="public_test.jsonl"):
    """
    Ask the LLM to extract test cases and append them to a JSONL file.

    Fills EXTRACT_CASE_PROMPT with the problem, runs a TestCaseExtractOp
    action node, appends ``{id: test_cases}`` as one JSON line to `file_path`
    and returns that same mapping.
    """
    extraction_prompt = EXTRACT_CASE_PROMPT.format(problem_description=problem_description)
    filled_node = await ActionNode.from_pydantic(TestCaseExtractOp).fill(context=extraction_prompt, llm=LLM())
    extracted = filled_node.instruct_content.model_dump()
    record = {id: extracted["test_cases"]}
    with open(file_path, "a") as out_file:
        out_file.write(json.dumps(record) + '\n')
    return record
import json
def test_cases_2_test_functions(solution: str, test_case: List):
function_name = test_case[0]
def format_param(param):
if isinstance(param, str):
return repr(param)
elif isinstance(param, (int, float, bool)):
return str(param)
elif isinstance(param, list):
return '[' + ', '.join(format_param(item) for item in param) + ']'
elif isinstance(param, tuple):
return '(' + ', '.join(format_param(item) for item in param) + ')'
elif isinstance(param, dict):
return '{' + ', '.join(f'{format_param(k)}: {format_param(v)}' for k, v in param.items()) + '}'
elif isinstance(param, type(None)):
return 'None'
else:
raise ValueError(f"Unsupported parameter type: {type(param)}")
parameters = ', '.join(format_param(item) for item in test_case[1])
print(type(test_case[2]), test_case[2])
expected_output = format_param(test_case[2])
print(expected_output)
tester_function = f"""
{solution}
def check(candidate):
assert candidate({parameters}) == {expected_output}
check({function_name})
"""
print(f"""
Generated test function:
{tester_function}
""")
return tester_function