This commit is contained in:
didi 2024-09-25 16:46:20 +08:00
parent 6a84a9d49b
commit 8dfe2de34c
3 changed files with 73 additions and 51 deletions

View file

@ -47,13 +47,13 @@ class Evaluator:
elif dataset == "MATH":
return self._math_eval(graph, params, path, is_test)
elif dataset == "HumanEval":
return self._humaneval_eval(graph, params, is_test)
return self._humaneval_eval(graph, params, path, is_test)
elif dataset == "HotpotQA":
return self._hotpotqa_eval(graph, params, is_test)
return self._hotpotqa_eval(graph, params, path, is_test)
elif dataset == "MBPP":
return self._mbpp_eval(graph, params, is_test)
return self._mbpp_eval(graph, params, path, is_test)
elif dataset == "DROP":
return self._drop_eval(graph, params, is_test)
return self._drop_eval(graph, params, path, is_test)
# def graph_evaluate(self, dataset: DatasetType, graph, params: dict, path):
# """
@ -154,7 +154,7 @@ class Evaluator:
va_list = [0]
else:
data_path = "examples/ags/data/human-eval_validate.jsonl" # 替换为您的JSONL文件路径
va_list = [0]
va_list = None
graph = await load_graph()

View file

@ -12,7 +12,9 @@ from subprocess import PIPE, Popen, TimeoutExpired
from typing import Dict, List, Tuple
import concurrent.futures
import threading
from tenacity import retry, stop_after_attempt, wait_fixed
from examples.ags.scripts.utils import extract_test_cases_from_jsonl
from examples.ags.scripts.operator_an import (
CodeGenerateOp,
@ -371,10 +373,13 @@ class Rephrase(Operator):
class Test(Operator):
def __init__(self, name: str = "Test", llm: LLM = LLM()):
def __init__(self, llm, name: str = "Test"):
super().__init__(name, llm)
def exec_code(self, solution, test_cases, problem_id, entry_point):
def exec_code(self, solution, entry_point):
test_cases = extract_test_cases_from_jsonl(entry_point)
fail_cases = []
for test_case in test_cases:
test_code = test_case_2_test_function(solution, test_case, entry_point)
@ -384,7 +389,7 @@ class Test(Operator):
exc_type, exc_value, exc_traceback = sys.exc_info()
tb_str = traceback.format_exception(exc_type, exc_value, exc_traceback)
with open("tester.txt", "a") as f:
f.write("test_error" + problem_id + "\n")
f.write("test_error of " + entry_point + "\n")
error_infomation = {
"test_fail_case": {
"test_case": test_case,
@ -397,7 +402,7 @@ class Test(Operator):
logger.info(f"test error: {error_infomation}")
except Exception as e:
with open("tester.txt", "a") as f:
f.write(problem_id + "\n")
f.write(entry_point + "\n")
return {"exec_fail_case": str(e)}
if fail_cases != []:
return fail_cases
@ -405,19 +410,23 @@ class Test(Operator):
return "no error"
async def __call__(
self, problem_id, problem, rephrase_problem, solution, test_cases, entry_point, test_loop: int = 3
self, problem, solution, entry_point, test_loop: int = 3
):
solution = solution["final_solution"]
"""
"Test": {
"description": "Test the solution with test cases, if the solution is correct, return 'no error', if the solution is incorrect, return reflect on the soluion and the error information",
"interface": "test(problem: str, solution: str, entry_point: str) -> str"
}
"""
for _ in range(test_loop):
result = self.exec_code(solution, test_cases, problem_id, entry_point)
result = self.exec_code(solution, problem, entry_point)
if result == "no error":
return {"final_solution": solution}
return {"result": True, "solution": solution}
elif "exec_fail_case" in result:
result = result["exec_fail_case"]
prompt = REFLECTION_ON_PUBLIC_TEST_PROMPT.format(
problem_description=problem,
rephrase_problem=rephrase_problem,
code_solution=solution,
problem=problem,
solution=solution,
exec_pass=f"executed unsuccessfully, error: \n {result}",
test_fail="executed unsucessfully",
)
@ -426,9 +435,8 @@ class Test(Operator):
solution = response["refined_solution"]
else:
prompt = REFLECTION_ON_PUBLIC_TEST_PROMPT.format(
problem_description=problem,
rephrase_problem=rephrase_problem,
code_solution=solution,
problem=problem,
solution=solution,
exec_pass="executed successfully",
test_fail=result,
)
@ -442,7 +450,7 @@ class Programmer(Operator):
def __init__(self, llm: LLM, name: str = "Programmer"):
super().__init__(name, llm)
async def exec_code(self, code, timeout=180):
async def exec_code(code, timeout=180):
def run_code():
try:
# 创建一个新的全局命名空间
@ -461,13 +469,29 @@ class Programmer(Operator):
exc_type, exc_value, exc_traceback = sys.exc_info()
tb_str = traceback.format_exception(exc_type, exc_value, exc_traceback)
return "Error", f"执行错误: {str(e)}\n{''.join(tb_str)}"
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_code)
# 创建一个事件来标记任务完成
done_event = threading.Event()
result = ["Error", "执行无结果,子进程异常"]
def wrapper():
nonlocal result
result = run_code()
done_event.set()
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(wrapper)
try:
return future.result(timeout=timeout)
except concurrent.futures.TimeoutError:
return "Error", "代码执行超时"
# 等待任务完成或超时
if done_event.wait(timeout=timeout):
return result
else:
# 超时,尝试取消任务
future.cancel()
return "Error", "代码执行超时"
finally:
# 确保线程池被正确关闭
executor.shutdown(wait=False)
async def code_generate(self, problem, analysis, feedback, mode):
prompt = PYTHON_CODE_VERIFIER_PROMPT.format(problem=problem, analysis=analysis, feedback=feedback)