Update Action Graph Solver Version 0.1

This commit is contained in:
didi 2024-07-31 18:19:11 +08:00
parent 5446c7e490
commit 686b1cd130
14 changed files with 10286 additions and 0 deletions

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because one or more lines are too long

View file

@ -0,0 +1,70 @@
# -*- coding: utf-8 -*-
# @Date :
# @Author : issac
# @Desc : test on gsm8k
import json
import re
import os
# 读取原始数据集
def read_jsonl(path: str):
with open(path) as fh:
return [json.loads(line) for line in fh.readlines() if line]
# 和图/和基础模型直接交互得到答案
def LLM(question):
answer = ""
# 这里就是输入问题question返回答案answer
# answer = 根据question生成的回答
return answer
def gsm_extract_answer(completion):
ANS_RE = re.compile(r"#### (\-?[0-9\.\,]+)")
INVALID_ANS = "[invalid]"
match = ANS_RE.search(completion)
if match:
match_str = match.group(1).strip()
match_str = match_str.replace(",", "")
return match_str
else:
return INVALID_ANS
def gsm_is_correct(data):
INVALID_ANS = "[invalid]"
gt_answer = gsm_extract_answer(data["answer"])
assert gt_answer != INVALID_ANS
return gsm_extract_answer(data["answer_llm"]) == gt_answer
# 提取数据集并得到测试答案
def get_examples(split):
path = os.path.join("", f"{split}.jsonl")
output_path = "gsm8k_generate.jsonl"
examples = read_jsonl(path)
processed_examples = [] # 用于存储处理后的样本
for ex in examples:
answer_llm = LLM(ex['question'])
ex['answer_llm'] = answer_llm
ex['is_correct'] = gsm_is_correct(ex)
# 将处理后的样本添加到列表中
processed_examples.append(ex)
# 将处理后的样本写入到新的 JSONL 文件
with open(output_path, 'w', encoding='utf-8') as f:
for example in processed_examples:
# 将字典转换为 JSON 格式的字符串,并写入新行
json_line = json.dumps(example) + '\n'
f.write(json_line)
print(f"{len(examples)} {split} examples")
return examples
if __name__ == "__main__":
example = get_examples("gsm")
print(example[:5])

View file

@ -0,0 +1,155 @@
# -*- coding: utf-8 -*-
# @Date :
# @Author : issac
# @Desc : test on hotpotqa
import sys
import json
import re
import string
from collections import Counter
import pickle
def normalize_answer(s):
def remove_articles(text):
return re.sub(r'\b(a|an|the)\b', ' ', text)
def white_space_fix(text):
return ' '.join(text.split())
def remove_punc(text):
exclude = set(string.punctuation)
return ''.join(ch for ch in text if ch not in exclude)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def f1_score(prediction, ground_truth):
normalized_prediction = normalize_answer(prediction)
normalized_ground_truth = normalize_answer(ground_truth)
ZERO_METRIC = (0, 0, 0)
if normalized_prediction in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
return ZERO_METRIC
if normalized_ground_truth in ['yes', 'no', 'noanswer'] and normalized_prediction != normalized_ground_truth:
return ZERO_METRIC
prediction_tokens = normalized_prediction.split()
ground_truth_tokens = normalized_ground_truth.split()
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
num_same = sum(common.values())
if num_same == 0:
return ZERO_METRIC
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * precision * recall) / (precision + recall)
return f1, precision, recall
def exact_match_score(prediction, ground_truth):
return (normalize_answer(prediction) == normalize_answer(ground_truth))
def update_answer(metrics, prediction, gold):
em = exact_match_score(prediction, gold)
f1, prec, recall = f1_score(prediction, gold)
metrics['em'] += float(em)
metrics['f1'] += f1
metrics['prec'] += prec
metrics['recall'] += recall
return em, prec, recall
def update_sp(metrics, prediction, gold):
cur_sp_pred = set(map(tuple, prediction))
gold_sp_pred = set(map(tuple, gold))
tp, fp, fn = 0, 0, 0
for e in cur_sp_pred:
if e in gold_sp_pred:
tp += 1
else:
fp += 1
for e in gold_sp_pred:
if e not in cur_sp_pred:
fn += 1
prec = 1.0 * tp / (tp + fp) if tp + fp > 0 else 0.0
recall = 1.0 * tp / (tp + fn) if tp + fn > 0 else 0.0
f1 = 2 * prec * recall / (prec + recall) if prec + recall > 0 else 0.0
em = 1.0 if fp + fn == 0 else 0.0
metrics['sp_em'] += em
metrics['sp_f1'] += f1
metrics['sp_prec'] += prec
metrics['sp_recall'] += recall
return em, prec, recall
def LLM(question):
answer = ""
# 这里就是输入问题question返回答案answer
# answer = 根据question生成的回答
return answer
def eval(prediction_file, gold_file):
with open(prediction_file) as f:
prediction = json.load(f)
with open(gold_file) as f:
gold = json.load(f)
metrics = {'em': 0, 'f1': 0, 'prec': 0, 'recall': 0,
'sp_em': 0, 'sp_f1': 0, 'sp_prec': 0, 'sp_recall': 0,
'joint_em': 0, 'joint_f1': 0, 'joint_prec': 0, 'joint_recall': 0}
for dp in gold:
cur_id = dp['_id']
can_eval_joint = True
if cur_id not in prediction['answer']:
print('missing answer {}'.format(cur_id))
can_eval_joint = False
else:
em, prec, recall = update_answer(
metrics, prediction['answer'][cur_id], dp['answer'])
N = len(gold)
for k in metrics.keys():
metrics[k] /= N
print(metrics)
def LLM(question):
answer = question
# 这里就是输入问题question返回答案answer
# answer = 根据question生成的回答
return answer
def answer(prediction_file, gold_file):
with open(gold_file) as f:
gold = json.load(f)
# 初始化预测字典,包含 answer 和 sp 两个键,初始为空字典
prediction = {'answer': {}}
for dp in gold:
cur_id = dp['_id']
paragraphs = [item[1] for item in dp['context'] if isinstance(item[1], list)] # 确保 item[1] 是列表
# 将所有文本段落连接成一个字符串
context_str = "\n".join(" ".join(paragraph) for paragraph in paragraphs)
question = dp['question']
# 构建输入字符串
input_llm = f"question{question}\n\ncontext{context_str}"
# 假设 LLM 是一个函数,返回模型的预测答案
response = LLM(input_llm)
# 将预测答案存储在字典中,键为 cur_id
prediction['answer'][cur_id] = response
# 将预测结果写入文件
with open(prediction_file, 'w') as f:
json.dump(prediction, f)
if __name__ == '__main__':
answer('hotpot_pre.json', 'your path here')
eval('hotpot_pre.json', 'your path here')

View file

@ -0,0 +1,171 @@
# -*- coding: utf-8 -*-
# @Date : 7/7/2024 17:07 PM
# @Author : didi
# @Desc : test on human eval graph
import os
import json
import subprocess
import sys
import asyncio
import aiofiles
from metagpt.llm import LLM
from evalplus.data import get_human_eval_plus
from examples.ags.w_action_node.utils import jsonl_ranker
from examples.ags.w_action_node.graph import HumanEvalGraph
from examples.ags.w_action_node.operator import GenerateCode, GenerateCodeBlock
generate_code = GenerateCode(llm=LLM())
generate_code_block = GenerateCodeBlock(llm=LLM())
solver = HumanEvalGraph(name="solver", llm=LLM(), criteria='correctness, efficiency, readability', vote_count=5)
async def sample_generate(id, result_path:str="samples.jsonl",mode:str="ags"):
case = get_human_eval_plus()[f"{id}"]
if mode == "ags":
solution_result = await solver(case['prompt'],ensemble_count=5)
sample_dict = dict(task_id=case['task_id'], solution=solution_result['final_solution'])
elif mode == "alpha":
solution_result = await solver.alpha_codium(case['task_id'], case['prompt'], ensemble_count=5)
sample_dict = dict(task_id=case['task_id'], solution=solution_result['final_solution'])
elif mode == "llm":
solution_result = await generate_code_block(case['prompt'],case['entry_point'])
sample_dict = dict(task_id=case['task_id'], solution=solution_result['code_solution'])
print(sample_dict)
with open(result_path, mode='a') as f:
f.write(json.dumps(sample_dict) + '\n')
jsonl_ranker(result_path, result_path)
async def samples_generate(mode:str, result_path:str="samples.jsonl"):
cases = list(get_human_eval_plus().values())
file_lock = asyncio.Lock()
async def solve_and_write(case, mode):
try:
if mode == 'llm':
solution_result = await generate_code_block(problem_description=case['prompt'], function_name=case['entry_point'])
# solution_result = await generate_code(case['prompt'])
sample_dict = {
'task_id': case['task_id'],
'solution': solution_result['code_solution']
}
elif mode == "ags":
solution_result = await solver(case['prompt'], ensemble_count=5)
sample_dict = {
'task_id': case['task_id'],
'solution': solution_result['final_solution']
}
elif mode == "alpha":
solution_result = await solver.alpha_codium(case['task_id'], case['prompt'], ensemble_count=5)
sample_dict = {
'task_id': case['task_id'],
'solution': solution_result['final_solution']
}
# TODO 解决 final_solution 问题之后就可以开始正式测评了
async with file_lock:
async with aiofiles.open(result_path, mode='a') as f:
await f.write(json.dumps(sample_dict) + '\n')
return None
except Exception as e:
print(e)
return case['task_id']
tasks = [solve_and_write(case, mode) for case in cases]
results = await asyncio.gather(*tasks)
failed_tasks = [task_id for task_id in results if task_id is not None]
if failed_tasks:
print(failed_tasks)
if mode == 'llm':
for task_id in failed_tasks:
case = get_human_eval_plus()[task_id]
for _ in range(3):
try:
solution_result = await generate_code_block(case['prompt'],function_name=case['entry_point'])
task_dict = {
'task_id': case['task_id'],
'solution': solution_result['code_solution']
}
with open(result_path, mode='a') as f:
f.write(json.dumps(task_dict) + '\n')
failed_tasks.remove(task_id)
break
except Exception as e:
print(f"{e} \n failure {task_id}")
elif mode == "ags" or mode == "alpha":
for task_id in failed_tasks:
try:
await sample_generate(task_id,result_path,mode)
except Exception as e:
print(f"failure {task_id}")
jsonl_ranker(result_path, result_path)
if not failed_tasks:
# 自动 sanitize
# result_path = automatic_sanitize(result_path)
if automatic_evalplus(result_path):
eval_path = result_path[:-6]+"_eval_results.json"
unpassed_exapmle = extract_failure_tests(eval_path)
print(unpassed_exapmle)
else:
print(failed_tasks)
def automatic_sanitize(result_path: str = "samples.jsonl"):
"""
在命令行中自动执行 evalplus.sanitize --samples result_path
返回result_path前缀加上"-sanitized.jsonl"
"""
command = ["evalplus.sanitize", "--samples", result_path]
try:
subprocess.run(command, check=True)
except subprocess.CalledProcessError as e:
print(f"执行命令时出错: {e}")
return None
# 构建sanitized文件路径
base_name = os.path.splitext(result_path)[0]
sanitized_path = f"{base_name}-sanitized.jsonl"
return sanitized_path
def automatic_evalplus(result_path:str ="samples.jsonl"):
"""
在命令行中自动执行 evalplus.evaluate --dataset humaneval --samples samples.jsonl --parallel 2 --base-only
"""
command = [
sys.executable, # 使用当前 Python 解释器
"-m",
"evalplus.evaluate",
"--dataset", "humaneval",
"--samples", result_path,
"--parallel", "2",
"--base-only"
]
try:
result = subprocess.run(command, check=True, capture_output=True, text=True)
print("输出:", result.stdout)
return True
except subprocess.CalledProcessError as e:
print("错误输出:", e.stderr)
return False
def extract_failure_tests(file_path:str = "samples_eval_results.json"):
with open(file_path, 'r') as f:
task_results = json.load(f)
failed_tests = []
for task in task_results['eval'].values():
if task[0]["base_status"] == "fail":
failed_test = {
"task_id": task[0]["task_id"],
# "solution": task["solution"],
# "fail_tests": task["base_fail_tests"]
}
failed_tests.append(failed_test)
print(len(failed_tests))
return failed_tests

View file

@ -0,0 +1,115 @@
# -*- coding: utf-8 -*-
# @Date : 6/27/2024 22:07 PM
# @Author : didi
# @Desc : graph & an instance - humanevalgraph
from metagpt.llm import LLM
from typing import List
from examples.ags.w_action_node.operator import Generate, GenerateCode, GenerateCodeBlock, Review, Revise, FuEnsemble, MdEnsemble, DbEnsemble, Rephrase, Test
from examples.ags.w_action_node.utils import extract_test_cases_from_jsonl
from evalplus.data import get_human_eval_plus
class Graph:
def __init__(self, name:str, llm:LLM) -> None:
self.name = name
self.model = llm
def __call__():
NotImplementedError("Subclasses must implement __call__ method")
def optimize(dataset:List):
pass
class HumanEvalGraph(Graph):
def __init__(self, name:str, llm: LLM, criteria:str, vote_count:int =5) -> None:
super().__init__(name, llm)
self.criteria = criteria # TODO 自动构建图时,图的初始参数与图所使用的算子要求的外部参数相匹配
self.generate_code = GenerateCode(llm=llm)
self.generate_code_block = GenerateCodeBlock(llm=llm)
self.review = Review(llm=llm, criteria=criteria)
self.revise = Revise(llm=llm)
self.rephrase = Rephrase(llm=llm)
self.tester = Test(llm=llm)
self.fuensemble = FuEnsemble(llm=llm)
self.mdensemble = MdEnsemble(llm=llm, vote_count=vote_count)
async def __call__(self, problem:str, ensemble_count:int = 3):
solution_list = []
for _ in range(ensemble_count):
for retry_count in range(5):
try:
# solution = await self.generate_code(problem)
solution = await self.generate_code_block(problem)
solution = solution.get('code_solution')
solution_list.append(solution)
break
except Exception as e:
print(e)
solution = await self.mdensemble("code", solution_list, problem)
return solution
async def alpha_codium(self, problem_id:str, problem:str, ensemble_count:int = 3):
test_cases = extract_test_cases_from_jsonl(problem_id)
entry_point = get_human_eval_plus()[problem_id]['entry_point']
rephrase_problem = await self.rephrase(problem) # 在rephrase 中拼接原始的问题描述
solution_list = []
for _ in range(ensemble_count):
for retry_count in range(5):
try:
solution = await self.generate_code_block.rephrase_generate(problem, rephrase_problem, function_name=entry_point)
solution = solution.get('code_solution')
solution_list.append(solution)
break
except Exception as e:
print(e)
solution = await self.mdensemble("code", solution_list, problem)
solution = await self.tester(problem_id, problem, rephrase_problem, solution, test_cases)
return solution
async def review_revise_ensemble(self, problem:str, ensemble_count:int = 2):
solution_list = []
for _ in range(ensemble_count):
solution = await self.single_solve(problem, 3)
solution_list.append(solution)
solution = await self.ensemble(solution_list, problem)
return solution
async def simple_ensemble(self, problem:str, ensemble_count:int = 3):
solution_list = []
for _ in range(ensemble_count):
solution = await self.generate_code(problem)
# solution = await self.generate_code_block(problem)
solution = solution.get('code_solution')
solution_list.append(solution)
solution = await self.fuensemble(solution_list, problem)
return solution
async def single_solve(self, problem:str, max_loop:int):
solution = await self.generate_code(problem)
solution = solution.get('code_solution')
for _ in range(max_loop):
review_feedback = await self.review(problem, solution)
if review_feedback['review_result']:
break
solution = await self.revise(problem, solution, review_feedback['feedback'])
solution = solution.get('revised_solution')
return solution
class Gsm8kGraph(Graph):
def __init__(self, name:str, llm: LLM) -> None:
super().__init__(name, llm)
self.generate = Generate(llm=llm)
self.rephrase = Rephrase(llm=llm)
async def __call__(self, problem:str):
solution = self.generate(problem)
return solution
class HotpotQAGraph(Graph):
def __init__(self, name:str, llm: LLM) -> None:
super().__init__(name, llm)
self.generate = Generate(llm=llm)
self.rephrase = Rephrase(llm=llm)
async def __call__(self, problem:str):
solution = self.generate(problem)
return solution

View file

@ -0,0 +1,370 @@
# -*- coding: utf-8 -*-
# @Date : 6/27/2024 17:36 PM
# @Author : didi
# @Desc : operator demo of ags
import ast
import sys
import traceback
import random
from typing import List, Tuple, Any, Dict
from collections import Counter
from metagpt.actions.action_node import ActionNode
from metagpt.llm import LLM
from examples.ags.w_action_node.operator_an import GenerateOp, GenerateCodeOp, GenerateCodeBlockOp ,ReviewOp, ReviseOp, FuEnsembleOp, MdEnsembleOp, ReflectionTestOp, RephraseOp
from examples.ags.w_action_node.prompt import GENERATE_PROMPT, GENERATE_CODE_PROMPT, GENERATE_CODEBLOCK_PROMPT, REVIEW_PROMPT, REVISE_PROMPT, FU_ENSEMBLE_PROMPT, MD_ENSEMBLE_PROMPT, REFLECTION_ON_PUBILIC_TEST_PROMPT, REPHRASE_ON_PROBLEM_PROMPT, GENERATE_CODEBLOCK_REPHRASE_PROMPT
from examples.ags.w_action_node.prompt import DE_ENSEMBLE_CODE_FORMAT_PROMPT, DE_ENSEMBLE_TXT_FORMAT_PROMPT, DE_ENSEMBLE_ANGEL_PROMPT, DE_ENSEMBLE_DEVIL_PROMPT, DE_ENSEMBLE_JUDGE_UNIVERSAL_PROMPT, DE_ENSEMBLE_JUDGE_FINAL_PROMPT
from examples.ags.w_action_node.utils import test_cases_2_test_functions
class Operator:
def __init__(self, name, llm:LLM):
self.name = name
self.llm = llm
def __call__(self, *args, **kwargs):
raise NotImplementedError
class Generate(Operator):
def __init__(self, name:str ="Generator", llm: LLM = LLM()):
super().__init__(name, llm)
async def __call__(self, problem_description):
prompt = GENERATE_PROMPT.format(problem_description=problem_description)
node = await ActionNode.from_pydantic(GenerateOp).fill(context=prompt, llm=self.llm)
response = node.instruct_content.model_dump()
return response
class GenerateCode(Operator):
def __init__(self, name:str ="Coder", llm: LLM = LLM()):
super().__init__(name, llm)
async def __call__(self, problem_description):
prompt = GENERATE_CODE_PROMPT.format(problem_description=problem_description)
node = await ActionNode.from_pydantic(GenerateCodeOp).fill(context=prompt, llm=self.llm)
response = node.instruct_content.model_dump()
return response
class GenerateCodeBlock(Operator):
def __init__(self, name:str ="Coder", llm: LLM = LLM()):
super().__init__(name, llm)
async def __call__(self, problem_description, function_name):
prompt = GENERATE_CODEBLOCK_PROMPT.format(problem_description=problem_description)
node = await ActionNode.from_pydantic(GenerateCodeBlockOp).fill(context=prompt, llm=self.llm, mode='code_fill',function_name=function_name)
response = node.instruct_content.model_dump()
return response
async def rephrase_generate(self, problem_description, rephrase_problem, function_name):
prompt = GENERATE_CODEBLOCK_REPHRASE_PROMPT.format(problem_description=problem_description,rephrase_problem=rephrase_problem)
node = await ActionNode.from_pydantic(GenerateCodeBlockOp).fill(context=prompt, llm=self.llm, mode='code_fill', function_name=function_name)
response = node.instruct_content.model_dump()
return response
class Review(Operator):
def __init__(self, criteria, name:str ="Reviewer", llm: LLM = LLM()):
self.criteria = criteria
super().__init__(name, llm)
async def __call__(self, problem_description, solution):
prompt = REVIEW_PROMPT.format(problem_description=problem_description, solution=solution, criteria=self.criteria)
node = await ActionNode.from_pydantic(ReviewOp).fill(context=prompt, llm=self.llm)
response = node.instruct_content.model_dump()
return response
class Revise(Operator):
def __init__(self, name:str ="Reviser", llm: LLM = LLM()):
super().__init__(name, llm)
async def __call__(self, problem_description, solution, feedback):
prompt = REVISE_PROMPT.format(problem_description=problem_description, solution=solution, feedback=feedback)
node = await ActionNode.from_pydantic(ReviseOp).fill(context=prompt, llm=self.llm)
response = node.instruct_content.model_dump()
return response
class FuEnsemble(Operator):
def __init__(self, name:str ="FuseEnsembler", llm: LLM = LLM()):
super().__init__(name, llm)
async def __call__(self, solutions:List, problem_description):
solution_text = ""
for solution in solutions:
solution_text += str(solution) + "\n"
prompt = FU_ENSEMBLE_PROMPT.format(solutions=solution_text, problem_description=problem_description)
node = await ActionNode.from_pydantic(FuEnsembleOp).fill(context=prompt, llm=self.llm)
response = node.instruct_content.model_dump()
return response
class MdEnsemble(Operator):
def __init__(self, name:str ="MedEnsembler", llm: LLM = LLM(), vote_count:int=3):
super().__init__(name, llm)
self.vote_count = vote_count
@staticmethod
def shuffle_answers(solutions: List[str]) -> Tuple[List[str], Dict[str, str]]:
shuffled_solutions = solutions.copy()
random.shuffle(shuffled_solutions)
answer_mapping = {chr(65 + i): solutions.index(solution) for i, solution in enumerate(shuffled_solutions)}
return shuffled_solutions, answer_mapping
async def __call__(self, solution_type:str ,solutions:List[str], problem_description:str):
print(solutions)
all_responses = []
# 当Ensmeble方案是Code类型时我们使用AST进行去重
if solution_type == "code":
unique_structures = {}
updated_solutions = []
for solution in solutions:
try:
tree = ast.parse(solution)
structure_key = ast.dump(tree, annotate_fields=False, include_attributes=False)
if structure_key not in unique_structures:
unique_structures[structure_key] = solution
updated_solutions.append(solution)
except SyntaxError:
# If the solution has a syntax error, we'll skip it
continue
solutions = updated_solutions
updated_length = len(solutions)
if updated_length == 1:
return {"final_solution": solutions[0]}
for _ in range(self.vote_count):
shuffled_solutions, answer_mapping = self.shuffle_answers(solutions)
solution_text = ""
for index, solution in enumerate(shuffled_solutions):
solution_text += f"{chr(65 + index)}: \n{str(solution)}\n\n\n"
prompt = MD_ENSEMBLE_PROMPT.format(solutions=solution_text, problem_description=problem_description)
node = await ActionNode.from_pydantic(MdEnsembleOp).fill(context=prompt, llm=self.llm)
response = node.instruct_content.model_dump()
answer = response.get('solution_letter', '')
answer = answer.strip().upper()
if answer in answer_mapping:
original_index = answer_mapping[answer]
# print(f"original index: {original_index}")
all_responses.append(original_index)
most_frequent_index = Counter(all_responses).most_common(1)[0][0]
final_answer = solutions[most_frequent_index]
return {"final_solution": final_answer}
class ScEnsemble(Operator):
"""
self consistency ensemble
"""
pass
class DbEnsemble(Operator):
"""
(Should we be going MAD? A Look at Multi-Agent Debate Strategies for LLMs)
The system is a multi-round debate system where each agent is given the
question and responses generated by all agents. For each round, a judge
analyzes the responses provided determines whether to terminate the
debate or keep going. At the end of the debate the judge is also responsible
for determining the final answer.
"""
def __init__(self, name:str ="DebateEnsemble", llm: LLM = LLM()):
super().__init__(name, llm)
self.agents = ["angel","devil","judge"]
self.format_requirements = {
"txt":DE_ENSEMBLE_TXT_FORMAT_PROMPT,
"code":DE_ENSEMBLE_CODE_FORMAT_PROMPT
}
def get_system_prompt(self, name:str, mode:str='txt'):
if name == "angel":
if mode == "code":
return DE_ENSEMBLE_ANGEL_PROMPT + "\n" + DE_ENSEMBLE_CODE_FORMAT_PROMPT
return DE_ENSEMBLE_ANGEL_PROMPT + "\n" + DE_ENSEMBLE_TXT_FORMAT_PROMPT
elif name == "devil":
if mode == "code":
return DE_ENSEMBLE_DEVIL_PROMPT + "\n" + DE_ENSEMBLE_CODE_FORMAT_PROMPT
return DE_ENSEMBLE_DEVIL_PROMPT + "\n" + DE_ENSEMBLE_TXT_FORMAT_PROMPT
elif name == "judge":
if mode == "final":
return DE_ENSEMBLE_JUDGE_FINAL_PROMPT
return DE_ENSEMBLE_JUDGE_UNIVERSAL_PROMPT
def construct_messages(self, message_history_with_name, name, mode:str="txt", phase:str="universal"):
"""
基于name与mode来构建system message.
基于name来构建messages
"""
messages = []
messages.append({"role": "system", "content": self.get_system_prompt(name, mode)})
if name in ["angel", "devil"]:
messages = self._construct_debate(message_history_with_name, name, messages)
elif name == "judge":
messages = self._construct_judge(message_history_with_name, mode, messages)
return messages
def _construct_debate(self, message_history_with_name, name, messages):
user_message = ""
for message in message_history_with_name:
if message["name"] == "Judge":
continue
elif message["name"] == name:
if user_message:
messages.append({
"role": "user",
"name": "user",
"content": user_message.strip("\n"),
})
messages.append({
"role": "assistant",
"name": name,
"content": message["content"],
})
user_message = ""
else:
user_message += message["content"]
if user_message:
messages.append({
"role": "user",
"name": "user",
"content": user_message.strip("\n"),
})
return messages
def _construct_judge(self, message_history_with_name, mode, messages):
pass
async def debate_answer(self, message_history:List, role:str="angel"):
messages = self.construct_messages(message_history, role)
response = await self.llm.acompletion_text(messages=messages)
message_history.append({
"role":"user",
"name":role,
"content":response}
)
return message_history, response
async def judge_answer(self, message_history:List, phase:str="universal"):
messages = self.construct_messages(message_history, "judge", phase=phase)
response = await self.llm.acompletion_text(messages=messages)
message_history.append({
"role": "user",
"name": "judge",
"content": response}
)
return message_history, response
async def __call__(self, origin_solution:str, problem_description:str, max_round:int = 3, mode:str='txt'):
# 思路输入一个原始答案构建一个agent代表这个答案进行辩论另一个agentdevil使用debate llm的内容进行辩论法官在每一轮次做出决定是否终止到了maxround还没终止就由法官进行总结。
message_history_with_name = [
{"role":"user", "name":"angel", "content":origin_solution}
]
for index in range(max_round):
for agent in self.agents:
if agent == "angel":
if index == 0:
pass
message_history_with_name, rsp = self.debate_answer(message_history_with_name, role="angel")
elif agent == "devil":
message_history_with_name, rsp = self.debate_answer(message_history_with_name, role="devil")
elif agent == "judge":
message_history_with_name, judge_result = self.judge_answer(message_history_with_name, phase="universal")
if not judge_result["is_debating"]:
"""
这里需要在 self.judge_answer 中设置一个自动给出solution的地方
"""
return {"final_solution":judge_result["final_solution"]}
message_history_with_name.pop(-1)
message_history_with_name, judge_answer = self.judge_answer(message_history_with_name, phase="final")
return {"final_solution":judge_answer["debate_answer"]}
class Rephrase(Operator):
"""
1. AlphaCodium
2. https://arxiv.org/abs/2404.14963
"""
def __init__(self, name:str ="Rephraser", llm: LLM = LLM()):
super().__init__(name, llm)
async def __call__(self, problem_description:str)->str:
prompt = REPHRASE_ON_PROBLEM_PROMPT.format(problem_description=problem_description)
node = await ActionNode.from_pydantic(RephraseOp).fill(context=prompt, llm=self.llm)
response = node.instruct_content.model_dump()
return response["rephrased_problem"]
class Test(Operator):
def __init__(self, name:str ="Tester", llm: LLM = LLM()):
super().__init__(name, llm)
def exec_code(self, solution, test_cases, problem_id):
# TODO
# 1. 获取更加详细的Test error信息
# 2. 更换Public Test数据集当前使用的数据存在Label Leak(使用的Reflexion的数据集)
# 3. 实现单独测试每一个test case -> 1
solution = solution["final_solution"]
test_code = test_cases_2_test_functions(solution, test_cases)
print("test_code", test_code)
try:
exec(test_code, globals())
except AssertionError as e:
exc_type, exc_value, exc_traceback = sys.exc_info()
tb_str = traceback.format_exception(exc_type, exc_value, exc_traceback)
with open("tester.txt", "a") as f:
f.write("test_error" +problem_id + "\n")
error_infomation = {"test_fail_case": {
"error_type": "AssertionError",
"error_message": str(e),
"traceback": tb_str
}}
print("error here", error_infomation)
return error_infomation
except Exception as e:
with open("tester.txt", "a") as f:
f.write(problem_id + "\n")
return {"exec_fail_case":str(e)}
return []
async def __call__(self, problem_id, problem, rephrase_problem, solution, test_cases):
result = self.exec_code(solution, test_cases, problem_id)
print("result here", result)
if result == []:
return solution
elif "exec_fail_case" in result:
result = result["exec_fail_case"]
prompt = REFLECTION_ON_PUBILIC_TEST_PROMPT.format(problem_description=problem, rephrase_problem=rephrase_problem, code_solution=solution, exec_pass=f"executed unsuccessfully, error: \n {result}", test_fail="executed unsucessfully")
node = await ActionNode.from_pydantic(ReflectionTestOp).fill(context=prompt, llm=self.llm)
response = node.instruct_content.model_dump()
return {"final_solution":response["refined_solution"]}
else:
result = result["test_fail_case"]
prompt = REFLECTION_ON_PUBILIC_TEST_PROMPT.format(problem_description=problem, rephrase_problem=rephrase_problem, code_solution=solution, exec_pass="executed successfully", test_fail=result)
node = await ActionNode.from_pydantic(ReflectionTestOp).fill(context=prompt, llm=self.llm)
response = node.instruct_content.model_dump()
return {"final_solution":response["refined_solution"]}
class FindFact(Operator):
pass
class SelfAsk(Operator):
pass
class Verify(Operator):
"""
? 还没有想好
"""
pass

View file

@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
# @Date : 6/27/2024 19:46 PM
# @Author : didi
# @Desc : action nodes for operator
from pydantic import BaseModel, Field
class GenerateOp(BaseModel):
solution: str = Field(default="", description="Your Solution for this problem")
class GenerateCodeOp(BaseModel):
code_solution: str = Field(default="", description="Complete and correct code here.")
class GenerateCodeBlockOp(BaseModel):
code_solution: str = Field(default="", description="Your complete code solution for this problem")
class ReviewOp(BaseModel):
review_result: bool = Field(default=False, description="The Review Result (Bool). If you think this solution looks good for you, return 'true'; If not, return 'false'")
feedback: str = Field(default="", description="Your FeedBack for this problem based on the criteria. If the review result is true, you can put it 'nothing here'.")
class ReviseOp(BaseModel):
revised_solution: str = Field(default="", description="Based on the feedback, revised solution for this problem")
class FuEnsembleOp(BaseModel):
thought: str = Field(default="", description="Analyze the solutions and think how to combine the advantages of various solutions to form the best possible solution.")
final_solution: str = Field(default="", description="Output the final solution after analysis and integration")
class MdEnsembleOp(BaseModel):
thought: str = Field(
default="""Example thought process:
1. Examined the 'compare_one' function.
2. The function correctly handles both numeric and string inputs by converting strings to floats.
3. It properly compares two values and returns the larger one.
4. The function returns None if the values are equal, which might be useful in some contexts but could be improved by returning either value.
5. The use of 'isinstance' for type checking is a good practice.
6. The function handles decimal separators well by replacing ',' with '.'.
Overall, this solution effectively solves the problem of comparing two values, with good error handling and flexibility. It could be improved by specifying behavior for equal values, but it's a strong solution as is.""",
description="Step-by-step analysis of the solutions to determine the best one."
)
solution_letter: str = Field(
default="",
description="The letter of the chosen best solution (only one letter)."
)
class TestCaseExtractOp(BaseModel):
test_cases: list = Field(default=[('<function name>', [5, 8, 7, 1], 12), ('<function name>', [3, 3, 3, 3, 3], 9)],
description="Extracted test cases from the problem description")
class RephraseOp(BaseModel):
rephrased_problem: str = Field(default="", description="Rephrased problem description for this problem")
class ReflectionTestOp(BaseModel):
reflection: str = Field(default="", description="对关于代码执行错误或者测试用例失败step by step的思考")
refined_solution: str = Field(default="", description="对于代码执行错误或者测试用例失败的修正方案")

View file

@ -0,0 +1,185 @@
# -*- coding: utf-8 -*-
# @Date : 6/26/2024 17:07 PM
# @Author : didi
# @Desc : prompts of operators
GENERATE_PROMPT = """
Generate Solution for the following problem: {problem_description}
"""
GENERATE_CODE_PROMPT = """
You are an expert programmer tasked with solving a coding problem.
### Problem Description:
{problem_description}
### Instructions:
The above is an incomplete Python code fragment. Return the complete and correct code with no additional text.
Please maintain the JSON format in your response.
### Your Response:
"""
GENERATE_CODEBLOCK_REPHRASE_PROMPT = """
Please provide a self-contained Python script that solves the following problem in a markdown code block:
### Problem Description:
{problem_description}
### self reflection on the problem
{rephrase_problem}
When creating your solution:
1. Consider all edge cases and boundary conditions.
2. Avoid oversimplification - address all aspects of the problem.
3. Ensure your logic covers all stated requirements.
4. Avoid adding additional test cases beyond those provided in the problem description.
"""
GENERATE_CODEBLOCK_PROMPT ="""
Please provide a self-contained Python script that solves the following problem in a markdown code block:
{problem_description}
When creating your solution:
1. Consider all edge cases and boundary conditions.
2. Avoid oversimplification - address all aspects of the problem.
3. Ensure your logic covers all stated requirements.
4. Avoid adding additional test cases beyond those provided in the problem description.
"""
REVIEW_PROMPT = """
For the question described as {problem_description},
please review the following solution: {solution}, and provide a review result in boolean format.
If you believe the solution is capable of resolving the issue, return True; otherwise, return False, and include your comments
"""
REVISE_PROMPT = """
For the question described as {problem_description},
please evaluate and revise the solution provided: {solution}, taking into account the review feedbacks: {feedback}."
Then output the revised solution.
"""
FU_ENSEMBLE_PROMPT = """
### Given problem
{problem_description}
### We've got a list of solutions
<solutions>
{solutions}
</solutions>
### Instructions
Based on the given problem and solution candidates:
1. Analyze the pros and cons of each candidate solution
2. Consider how to integrate reasonable parts from different solutions
3. Formulate a more comprehensive and effective solution
"""
MD_ENSEMBLE_PROMPT = """
You are given a coding problem:
{problem_description}
Here is a list of possible solutions to the problem:
{solutions}
Using the inputs above, your goal is to choose the best solution to the code contest problem.
Don't just pick the most efficient solution. The main consideration is that the solution can fully solve the problem in a correct and robust manner.
Provide your final decision by writing the chosen solution letter (e.g., B).
Please maintain the JSON format in your response.
"""
DE_ENSEMBLE_TXT_FORMAT_PROMPT = """
Now please output your answer in json format, with the format as follows:
{\"Reason\": \"\", \"debate_answer\": \"the capital letter corresponding to the answer\"}.
Please strictly output in JSON format, do not output irrelevant content. """
DE_ENSEMBLE_CODE_FORMAT_PROMPT = """
Now please output your answer in json format, with the format as follows:
{{
"reason":"<为什么要这样做>",
"code_solution":"<你觉得合适的solution用代码表示出来>"
}}
Please strictly output in JSON format, do not output irrelevant content. """
DE_ENSEMBLE_ANGEL_PROMPT = """
Do you agree with my perspective? Please provide your reasons and answer.
"""
DE_ENSEMBLE_DEVIL_PROMPT = """
You agree with my answer 90% of the time and have almost no reservations. Affirm your agreement, share any additional thoughts if you have them, and conclude with the capital letter corresponding to your answer at the end of your response.
"""
DE_ENSEMBLE_JUDGE_FINAL_PROMPT = """
You, as the moderator, will evaluate both sides' answers and determine your
preference for an answer candidate. Please summarize your reasons for supporting affirmative/negative side and
give the final answer that you think is correct to conclude the debate. Now please output your answer in json format, with the format as follows:
{\"Reason\": \"\", \"debate_answer\": \"the capital letter corresponding to the answer\"}.
Please strictly output in JSON format, do not output irrelevant content.
"""
DE_ENSEMBLE_JUDGE_UNIVERSAL_PROMPT = """
You, as the moderator, will evaluate both sides' answers and determine if there is a clear
preference for an answer candidate. If so, please summarize your reasons for supporting affirmative/negative side and
give the final answer that you think is correct, and the debate will conclude. If not, the debate will continue to
the next round. Now please output your answer in json format, with the format as follows:
{\"Whether there is a preference\": \"Yes or No\", \"Supported Side\": \"Affirmative or Negative\",
\"Reason\": \"\", \"debate_answer\": \"the capital letter corresponding to the answer\"}.
Please strictly output in JSON format, do not output irrelevant content
"""
EXTRACT_CASE_PROMPT = """
You are given a coding problem, and you need to extract the test cases from the problem description.
{problem_description}
一个problem中会有多个测试用例每个测试用例包含三个部分
1. 函数名
2. 输入
3. 期望输出
每个测试用例包裹在一个三元组之中三元组之间用逗号分隔整体用列表包裹
由于结果需要被解析到JSON中True与False请表示为true, false;
"""
REPHRASE_ON_PROBLEM_PROMPT = """
You are given a code contest problem:
### problem
{problem_description}
### instrcutions
Given the code contest problem, Your Goal is:
Reflect on the problem, and describe it in your own words, in bullet points. Pay attention to small details, nuances, notes and examples in the problem description.
"""
REFLECTION_ON_PUBILIC_TEST_PROMPT = """
You are given a code contest problem, and a self-reflection on the problem:
### problem
{problem_description}
### self reflection on the problem
{rephrase_problem}
=======================
A Python code solution was generated for the problem:
### Code Solution
{code_solution}
=======================
This section of the code execution result is
### Execution Result
{exec_pass}
=======================
However, when running the following input example, the code solution above failed to produce the expected output:
#### Failed Test Case
{test_fail}
Your goal is to analyze the code solution and the error, and propose a fixed code which will produce the expected output for the provided test input.
The fixed code should keep the solution robust, and work for all other input examples as well.
Make sure the fixed code has a reasonable runtime - less than three seconds on a modern computer, given the problem constraints for large input.
"""

View file

@ -0,0 +1,122 @@
# -*- coding: utf-8 -*-
# @Date : 7/2/2024 17:36 PM
# @Author : didi
# @Desc : utils for experiment
import json
import re
import ast
from typing import List, Dict, Any, Tuple
from metagpt.llm import LLM
from metagpt.actions.action_node import ActionNode
from examples.ags.w_action_node.operator_an import TestCaseExtractOp
from examples.ags.w_action_node.prompt import EXTRACT_CASE_PROMPT
def extract_task_id(task_id: str) -> int:
"""Extract the numeric part of the task_id."""
match = re.search(r'/(\d+)', task_id)
return int(match.group(1)) if match else 0
def jsonl_ranker(input_file: str, output_file: str):
"""
Read a JSONL file, sort the entries based on task_id, and write to a new JSONL file.
:param input_file: Path to the input JSONL file
:param output_file: Path to the output JSONL file
"""
# Read and parse the JSONL file
with open(input_file, 'r') as f:
data = [json.loads(line) for line in f]
# Sort the data based on the numeric part of task_id
sorted_data = sorted(data, key=lambda x: extract_task_id(x['task_id']))
# Write the sorted data to a new JSONL file
with open(output_file, 'w') as f:
for item in sorted_data:
f.write(json.dumps(item) + '\n')
def parse_python_literal(s):
try:
return ast.literal_eval(s)
except (ValueError, SyntaxError):
return s
def extract_test_cases_from_jsonl(problem_id:str, file_path:str="public_test_reflexion.jsonl"):
# 保留原有的硬编码测试用例
hardcoded_cases = {
"HumanEval/32": "",
"HumanEval/38": "",
"HumanEval/50": "",
}
# 检查是否有硬编码的测试用例
if problem_id in hardcoded_cases:
return hardcoded_cases[problem_id]
# 如果没有硬编码的测试用例,从文件中读取
with open(file_path, 'r') as file:
for line in file:
data = json.loads(line)
if data.get("id") == problem_id:
return data.get("test")
return None # 如果没有找到问题,返回 None
def extract_test_cases(docstring: str) -> List[Tuple[str, List[Any], Any]]:
# 使用正则表达式匹配测试用例,现在捕获函数名和任意输出
pattern = r'>>> (\w+)\((.*?)\)\n\s*(.*?)(?=\n|$)'
matches = re.findall(pattern, docstring, re.DOTALL)
test_cases = []
for match in matches:
func_name, input_str, expected_output = match
# 处理输入
input_list = []
for item in input_str.split(','):
item = item.strip()
try:
# 尝试将输入转换为数值类型
if '.' in item:
input_list.append(float(item))
else:
input_list.append(int(item))
except ValueError:
# 如果无法转换为数值,则保留为字符串
input_list.append(item.strip("'\""))
# 处理输出
try:
# 尝试将输出转换为数值或布尔值
if expected_output.lower() == 'true':
expected_output = True
elif expected_output.lower() == 'false':
expected_output = False
elif '.' in expected_output:
expected_output = float(expected_output)
else:
expected_output = int(expected_output)
except ValueError:
# 如果无法转换,则保留为字符串
expected_output = expected_output.strip("'\"")
test_cases.append([func_name, input_list, expected_output])
return test_cases
async def llm_extract_test_case(id, problem_description: str, file_path:str="public_test.jsonl"):
prompt = EXTRACT_CASE_PROMPT.format(problem_description=problem_description)
node = await ActionNode.from_pydantic(TestCaseExtractOp).fill(context=prompt, llm=LLM())
result = node.instruct_content.model_dump()
with open(file_path,"a") as f:
f.write(json.dumps({id:result["test_cases"]}) + '\n')
return {id:result["test_cases"]}
def test_cases_2_test_functions(solution: str, test_cases: str):
tester_function = f"""
{solution}
{test_cases}
"""
return tester_function

View file

@ -17,6 +17,7 @@ from pydantic import BaseModel, Field, create_model, model_validator
from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.actions.action_outcls_registry import register_action_outcls
from metagpt.actions.code_sanitize import sanitize
from metagpt.const import USE_CONFIG_TIMEOUT
from metagpt.llm import BaseLLM
from metagpt.logs import logger
@ -41,6 +42,7 @@ TAG = "CONTENT"
LANGUAGE_CONSTRAINT = "Language: Please use the same language as Human INPUT."
FORMAT_CONSTRAINT = f"Format: output wrapped inside [{TAG}][/{TAG}] like format example, nothing else."
SIMPLE_TEMPLATE = """
## context
{context}
@ -147,6 +149,8 @@ class ActionNode:
prevs: List["ActionNode"] # previous nodes
nexts: List["ActionNode"] # next nodes
MODE_CODE_FILL = "code_fill"
def __init__(
self,
key: str,
@ -464,6 +468,68 @@ class ActionNode:
return self
def get_field_name(self):
"""
Get the field name from the Pydantic model associated with this ActionNode.
"""
model_class = self.create_class()
fields = model_class.model_fields
# Assuming there's only one field in the model
if len(fields) == 1:
return next(iter(fields))
# If there are multiple fields, we might want to use self.key to find the right one
return self.key
async def code_fill(
self,
context,
function_name=None,
timeout=USE_CONFIG_TIMEOUT
):
"""
fill CodeBlock Node
"""
def extract_code_from_response(response):
"""
Extracts code wrapped in triple backticks from the response,
removing any language specifier.
:param response: The full response from the LLM
:return: The extracted code, or None if no code is found
"""
code_pattern = r"```(?:\w+\n)?([\s\S]*?)```"
matches = re.findall(code_pattern, response)
if matches:
# The first group in the regex contains the code without the language specifier
code = matches[0].strip()
return code
return None
import re
field_name = self.get_field_name()
prompt = context
# print("generate prompt", "\n", prompt)
content = await self.llm.aask(prompt, timeout=timeout)
# print("generate content", "\n", content)
extracted_code = sanitize(code=content, entrypoint=function_name)
# extracted_code = extract_code_from_response(content)
result = {field_name: extracted_code}
# print("final_result", "\n", result)
return result
async def messages_fill(
self,
):
"""
参考这个代码只不过LLM调用方式改成使用
参考
"""
pass
async def fill(
self,
context,
@ -474,6 +540,7 @@ class ActionNode:
images: Optional[Union[str, list[str]]] = None,
timeout=USE_CONFIG_TIMEOUT,
exclude=[],
function_name: str = None
):
"""Fill the node(s) with mode.
@ -500,6 +567,11 @@ class ActionNode:
if self.schema:
schema = self.schema
if mode == self.MODE_CODE_FILL:
result = await self.code_fill(context, function_name, timeout)
self.instruct_content = self.create_class()(**result)
return self
if strgy == "simple":
return await self.simple_fill(schema=schema, mode=mode, images=images, timeout=timeout, exclude=exclude)
elif strgy == "complex":

View file

@ -0,0 +1,167 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/7/24 16:37
@Author : didi
@File : code_node.py
"""
import os
import ast
import pathlib
import traceback
from typing import Dict, Generator, List, Optional, Set, Tuple
import tree_sitter_python
from tqdm import tqdm
from tree_sitter import Language, Node, Parser
CLASS_TYPE = "class_definition"
FUNCTION_TYPE = "function_definition"
IMPORT_TYPE = ["import_statement", "import_from_statement"]
IDENTIFIER_TYPE = "identifier"
ATTRIBUTE_TYPE = "attribute"
RETURN_TYPE = "return_statement"
EXPRESSION_TYPE = "expression_statement"
ASSIGNMENT_TYPE = "assignment"
def traverse_tree(node: Node) -> Generator[Node, None, None]:
cursor = node.walk()
depth = 0
visited_children = False
while True:
if not visited_children:
yield cursor.node
if not cursor.goto_first_child():
depth += 1
visited_children = True
elif cursor.goto_next_sibling():
visited_children = False
elif not cursor.goto_parent() or depth == 0:
break
else:
depth -= 1
def syntax_check(code, verbose=False):
try:
ast.parse(code)
return True
except (SyntaxError, MemoryError):
if verbose:
traceback.print_exc()
return False
def code_extract(text: str) -> str:
lines = text.split("\n")
longest_line_pair = (0, 0)
longest_so_far = 0
for i in range(len(lines)):
for j in range(i + 1, len(lines)):
current_lines = "\n".join(lines[i : j + 1])
if syntax_check(current_lines):
current_length = sum(1 for line in lines[i : j + 1] if line.strip())
if current_length > longest_so_far:
longest_so_far = current_length
longest_line_pair = (i, j)
return "\n".join(lines[longest_line_pair[0] : longest_line_pair[1] + 1])
def get_definition_name(node: Node) -> str:
for child in node.children:
if child.type == IDENTIFIER_TYPE:
return child.text.decode("utf8")
def has_return_statement(node: Node) -> bool:
traverse_nodes = traverse_tree(node)
for node in traverse_nodes:
if node.type == RETURN_TYPE:
return True
return False
def get_deps(nodes: List[Tuple[str, Node]]) -> Dict[str, Set[str]]:
def dfs_get_deps(node: Node, deps: Set[str]) -> None:
for child in node.children:
if child.type == IDENTIFIER_TYPE:
deps.add(child.text.decode("utf8"))
else:
dfs_get_deps(child, deps)
name2deps = {}
for name, node in nodes:
deps = set()
dfs_get_deps(node, deps)
name2deps[name] = deps
return name2deps
def get_function_dependency(entrypoint: str, call_graph: Dict[str, str]) -> Set[str]:
queue = [entrypoint]
visited = {entrypoint}
while queue:
current = queue.pop(0)
if current not in call_graph:
continue
for neighbour in call_graph[current]:
if not (neighbour in visited):
visited.add(neighbour)
queue.append(neighbour)
return visited
def sanitize(code: str, entrypoint: Optional[str] = None) -> str:
code = code_extract(code)
code_bytes = bytes(code, "utf8")
parser = Parser(Language(tree_sitter_python.language()))
tree = parser.parse(code_bytes)
class_names = set()
function_names = set()
variable_names = set()
root_node = tree.root_node
import_nodes = []
definition_nodes = []
for child in root_node.children:
if child.type in IMPORT_TYPE:
import_nodes.append(child)
elif child.type == CLASS_TYPE:
name = get_definition_name(child)
if not (
name in class_names or name in variable_names or name in function_names
):
definition_nodes.append((name, child))
class_names.add(name)
elif child.type == FUNCTION_TYPE:
name = get_definition_name(child)
if not (
name in function_names or name in variable_names or name in class_names
) and has_return_statement(child):
definition_nodes.append((name, child))
function_names.add(get_definition_name(child))
elif (
child.type == EXPRESSION_TYPE and child.children[0].type == ASSIGNMENT_TYPE
):
subchild = child.children[0]
name = get_definition_name(subchild)
if not (
name in variable_names or name in function_names or name in class_names
):
definition_nodes.append((name, subchild))
variable_names.add(name)
if entrypoint:
name2deps = get_deps(definition_nodes)
reacheable = get_function_dependency(entrypoint, name2deps)
sanitized_output = b""
for node in import_nodes:
sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n"
for pair in definition_nodes:
name, node = pair
if entrypoint and not (name in reacheable):
continue
sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n"
return sanitized_output[:-1].decode("utf8")

10
test.py Normal file
View file

@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
# @Date : 6/27/2024 18:00 PM
# @Author : didi
# @Desc : test on humaneval graph
import asyncio
from examples.ags.benchmark.humaneval import sample_generate, samples_generate
asyncio.run(sample_generate('HumanEval/id',result_path="result_path",mode="alpha"))
asyncio.run(samples_generate(mode='alpha',result_path="result_path"))