mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
Update AGS
This commit is contained in:
parent
4d376649cc
commit
aeac3fe3f9
8 changed files with 457 additions and 57 deletions
168
examples/ags/demo/medprompt.py
Normal file
168
examples/ags/demo/medprompt.py
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
# 第一段代码是MedPrompt,一种利用LLM产生多种答案,然后进行洗牌投票来选出最优决策的方法
|
||||
# 我需要你首先理解这个方法,然后将这个方法与我的代码结合起来
|
||||
# 我的代码如下,我们会接收到多个答案,我需要你将这个答案利用MedPrompt的方法进行处理。
|
||||
# 在我的代码中,产生llm answer是用 await ActionNode.from_pydantic(ScEnsembleOp).fill(context=prompt, llm=self.llm) 实现的。
|
||||
|
||||
class ScEnsemble(Ensemble):
    """Ensemble operator that shows the LLM a numbered list of candidate solutions."""

    def __init__(self, name: str = "Ensembler", llm: LLM = LLM()):
        super().__init__(name, llm)

    async def __call__(self, solutions: List, problem_description):
        """Fill ``ScEnsembleOp`` from a prompt listing every candidate.

        :param solutions: candidate solutions; each is rendered with ``str``
            and prefixed with ``Solution<index>: ``.
        :param problem_description: text substituted into ``ENSEMBLE_PROMPT``.
        :return: the filled node's ``instruct_content`` dumped to a dict.
        """
        # One "Solution<i>: <text>" line per candidate, each newline-terminated.
        numbered = "".join(
            f"Solution{position}: {candidate!s}\n"
            for position, candidate in enumerate(solutions)
        )
        prompt = ENSEMBLE_PROMPT.format(
            solutions=numbered, problem_description=problem_description
        )
        filled = await ActionNode.from_pydantic(ScEnsembleOp).fill(context=prompt, llm=self.llm)
        return filled.instruct_content.model_dump()
|
||||
|
||||
class Medprompt(QASystem):
    """Single-agent implementation of the Medprompt system.

    Taken from https://arxiv.org/abs/2311.16452.

    A single agent is queried ``num_reasoning_steps`` times. Before every
    query the answer options of the multiple-choice question are shuffled,
    the label returned by the agent is mapped back to the original label, and
    the most frequent label over all steps is returned as the final answer.

    IMPORTANT (note from the original author): the current implementation
    only contains the first three steps of the Medprompt setup. Additional
    improvements can be made by including the kNN and ensemble with choice
    shuffling as well.
    """

    def __init__(
        self,
        agents: list,
        num_reasoning_steps: int,
        debate_prompts: dict,
        verbose: bool = False,
        name: Optional[str] = None,
        mock: bool = False,  # Unused
        agent_prompts: Optional[dict] = None,  # Unused
    ):
        """Store the single agent, the number of sampling steps and the prompts.

        :param agents: list holding exactly one agent exposing ``answer(...)``.
        :param num_reasoning_steps: how many times the agent is queried.
        :param debate_prompts: prompt dict; ``answer`` reads its ``"system"`` key.
        :param verbose: forwarded to the ``QASystem`` base class.
        :param name: accepted but not used by this class.
        :param mock: accepted but not used by this class.
        :param agent_prompts: accepted but not used by this class.
        """
        super().__init__(verbose=verbose)

        assert len(agents) == 1
        self._num_reasoning_steps = num_reasoning_steps
        self._agent = agents[0]
        self._agent_names = [type(agent).__name__ for agent in agents]
        self.prompts = debate_prompts

    # Setup debate metrics
    def metrics(
        self, info: Dict[str, Any], format_solution_fn: Callable, solution: str
    ) -> Dict[str, Any]:
        """Delegate to ``construct_agent_metrics`` for the single agent ``Agent_0``."""
        return construct_agent_metrics(
            info=info,
            format_solution_fn=format_solution_fn,
            solution=solution,
            verbose=self._verbose,  # presumably set by QASystem.__init__ -- confirm
            agents=["Agent_0"],
            agent_names=self._agent_names,
            num_rounds=self._num_reasoning_steps,
        )

    @staticmethod
    def shuffle_answers(question: str) -> Tuple[str, Any]:
        """Shuffle the answer texts of a multiple-choice question.

        The answer labels (A, B, C, ...) stay in order; only the texts move.

        :param question: question text followed by option lines of the form
            ``"<LABEL>: <text>"``, the first one introduced by ``"\\n<LABEL>:"``.
        :return: ``(shuffled_question, answer_mapping)`` where ``answer_mapping``
            maps each label of the shuffled question to the original label of
            the text now shown under it.
        :raises ValueError: if no answer section is found or an option line
            does not contain ``": "``.
        """
        # Find the start of the answer section (e.g., '\nA:')
        section = re.search(r"\n[A-Z]:", question)
        if section is None:
            raise ValueError("no answer section of the form '\\nA:' found in question")
        answer_section_start = section.start()

        # Split the question from the answers
        main_question = question[:answer_section_start]
        answers = question[answer_section_start + 1 :].split("\n")

        # Extract answer texts (everything after the first ": " of each line)
        answer_texts = []
        for line in answers:
            _, separator, text = line.partition(": ")
            if not separator:
                raise ValueError(f"malformed answer line: {line!r}")
            answer_texts.append(text)

        # Shuffle positions rather than texts: looking texts up again with
        # list.index() would send duplicate option texts to the first label.
        order = list(range(len(answer_texts)))
        random.shuffle(order)

        answer_mapping = {
            chr(65 + new_pos): answers[old_pos][0]
            for new_pos, old_pos in enumerate(order)
        }

        # Reassemble the shuffled answers under the labels A, B, C, ...
        shuffled_answers = [
            f"{chr(65 + new_pos)}: {answer_texts[old_pos]}"
            for new_pos, old_pos in enumerate(order)
        ]

        # Reassemble the question
        shuffled_question = main_question + "\n" + "\n".join(shuffled_answers)
        return shuffled_question, answer_mapping

    def answer(
        self,
        question: str,
    ) -> Tuple[str, Any]:
        """Query the agent ``num_reasoning_steps`` times and majority-vote the answers.

        :param question: multiple-choice question passed to ``shuffle_answers``.
        :return: ``(answer, details)``; ``details`` holds the per-step
            ``"response"``, ``"agent_answers"`` and ``"agent_info"`` dicts, each
            keyed by ``"Agent_0"`` and then by ``"Reasoning_<i>"``.
        """
        agent_answers: Any = {"Agent_0": {}}
        agent_info: Any = {"Agent_0": {}}
        agent_responses: Any = {"Agent_0": {}}
        if self._verbose:
            print("#######################")
            print("REASONING STEP")
            print("#######################")

        for i in range(self._num_reasoning_steps):
            try:
                # TODO: Provide the options to the system as well. This would
                # make it much easier to shuffle the answers. Furthermore, remove
                # all questions without options in load_datasets.py.
                shuffled_question, answer_mapping = self.shuffle_answers(question)
            except Exception as e:
                # Best effort: fall back to the unshuffled question with an
                # identity label mapping.
                shuffled_question = question
                answer_mapping = {"A": "A", "B": "B", "C": "C", "D": "D", "E": "E"}
                print("question: ", question)
                print("Shuffling failed, using original question: ", e)

            answer, info = self._agent.answer(
                question=shuffled_question,
                system_message=self.prompts["system"],
            )

            # Map the answer back to the original answer
            if answer in answer_mapping:
                answer = answer_mapping[answer]

            step = f"Reasoning_{i}"
            agent_answers["Agent_0"][step] = answer
            agent_responses["Agent_0"][step] = info["response"]
            agent_info["Agent_0"][step] = info

        final_answers = [
            agent_answers["Agent_0"][f"Reasoning_{i}"]
            for i in range(self._num_reasoning_steps)
        ]
        # most_frequent is defined outside this block; presumably returns
        # (value, count) -- confirm.
        answer, _ = most_frequent(final_answers)

        return answer, {
            "response": agent_responses,
            "agent_answers": agent_answers,
            "agent_info": agent_info,
        }
|
||||
|
|
@ -5,37 +5,46 @@
|
|||
|
||||
from metagpt.llm import LLM
|
||||
|
||||
from examples.ags.w_action_node.operator import Generate, GenerateCode, Review, Revise, Ensemble, ScEnsemble
|
||||
from examples.ags.w_action_node.operator import Generate, GenerateCode, GenerateCodeBlock, Review, Revise, Ensemble, MdEnsemble
|
||||
|
||||
class Graph:
    """Base class for solver graphs: keeps a name and the LLM handed to the graph."""

    def __init__(self, name: str, llm: "LLM") -> None:
        """
        :param name: name of this graph.
        :param llm: LLM instance, stored as ``self.model``.
        """
        self.name = name
        # TODO: should a different Graph be used for each operator?
        self.model = llm

    def __call__(self, *args, **kwargs):
        """Entry point of the graph; subclasses must override it.

        :raises NotImplementedError: always, on the base class.
        """
        # The exception has to be raised, not merely constructed, and the
        # method has to accept ``self`` to be callable on an instance at all.
        raise NotImplementedError("Subclasses must implement __call__ method")
|
||||
|
||||
|
||||
class HumanEvalGraph(Graph):
|
||||
def __init__(self, name:str, llm: LLM, criteria:str) -> None:
|
||||
def __init__(self, name:str, llm: LLM, criteria:str, vote_count:int =3) -> None:
|
||||
super().__init__(name, llm)
|
||||
self.criteria = criteria # TODO 自动构建图时,图的初始参数与图所使用的算子要求的外部参数相关
|
||||
self.criteria = criteria # TODO 自动构建图时,图的初始参数与图所使用的算子要求的外部参数相匹配
|
||||
self.generate_code = GenerateCode(llm=llm)
|
||||
self.generate_code_block = GenerateCodeBlock(llm=llm)
|
||||
self.review = Review(llm=llm, criteria=criteria)
|
||||
self.revise = Revise(llm=llm)
|
||||
self.ensemble = Ensemble(llm=llm)
|
||||
self.scensemble = ScEnsemble(llm=llm)
|
||||
|
||||
# async def __call__(self, problem:str, ensemble_count:int = 2):
|
||||
# solution_list = []
|
||||
# for _ in range(ensemble_count):
|
||||
# solution = await self.single_solve(problem, 3)
|
||||
# solution_list.append(solution)
|
||||
# solution = await self.ensemble(solution_list, problem)
|
||||
# return solution
|
||||
self.mdensemble = MdEnsemble(llm=llm, vote_count=vote_count)
|
||||
|
||||
async def __call__(self, problem:str):
|
||||
async def __call__(self, problem:str, ensemble_count:int = 3):
|
||||
solution_list = []
|
||||
for _ in range(ensemble_count):
|
||||
# solution = await self.generate_code(problem)
|
||||
solution = await self.generate_code_block(problem)
|
||||
solution = solution.get('code_solution')
|
||||
solution_list.append(solution)
|
||||
solution = await self.mdensemble(solution_list, problem)
|
||||
return solution
|
||||
|
||||
async def review_revise_ensemble(self, problem:str, ensemble_count:int = 2):
|
||||
solution_list = []
|
||||
for _ in range(ensemble_count):
|
||||
solution = await self.single_solve(problem, 3)
|
||||
solution_list.append(solution)
|
||||
solution = await self.ensemble(solution_list, problem)
|
||||
return solution
|
||||
|
||||
async def simple_ensemble(self, problem:str):
|
||||
solution_list = []
|
||||
for _ in range(3):
|
||||
solution = await self.generate_code(problem)
|
||||
|
|
@ -44,15 +53,6 @@ class HumanEvalGraph(Graph):
|
|||
solution = await self.ensemble(solution_list, problem)
|
||||
return solution
|
||||
|
||||
# async def __call__(self, problem:str):
|
||||
# solution_list = []
|
||||
# for _ in range(3):
|
||||
# solution = await self.generate_code(problem)
|
||||
# solution = solution.get('code_solution')
|
||||
# solution_list.append(solution)
|
||||
# solution = await self.scensemble(solution_list, problem)
|
||||
# return solution
|
||||
|
||||
async def single_solve(self, problem:str, max_loop:int):
|
||||
solution = await self.generate_code(problem)
|
||||
solution = solution.get('code_solution')
|
||||
|
|
|
|||
|
|
@ -3,13 +3,15 @@
|
|||
# @Author : didi
|
||||
# @Desc : operator demo of ags
|
||||
|
||||
from typing import List
|
||||
import random
|
||||
from typing import List, Tuple, Any, Dict
|
||||
from collections import Counter
|
||||
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.llm import LLM
|
||||
|
||||
from examples.ags.w_action_node.operator_an import GenerateOp, GenerateCodeOp, ReviewOp, ReviseOp, EnsembleOp
|
||||
from examples.ags.w_action_node.prompt import GENERATE_PROMPT, GENERATE_CODE_PROMPT, REVIEW_PROMPT, REVISE_PROMPT, ENSEMBLE_PROMPT
|
||||
from examples.ags.w_action_node.operator_an import GenerateOp, GenerateCodeOp, GenerateCodeBlockOp ,ReviewOp, ReviseOp, EnsembleOp, MdEnsembleOp
|
||||
from examples.ags.w_action_node.prompt import GENERATE_PROMPT, GENERATE_CODE_PROMPT, REVIEW_PROMPT, REVISE_PROMPT, ENSEMBLE_PROMPT, MD_ENSEMBLE_PROMPT
|
||||
|
||||
class Operator:
|
||||
def __init__(self, name, llm:LLM=None):
|
||||
|
|
@ -40,6 +42,17 @@ class GenerateCode(Operator):
|
|||
response = node.instruct_content.model_dump()
|
||||
return response
|
||||
|
||||
class GenerateCodeBlock(Operator):
    """Operator that fills ``GenerateCodeBlockOp`` through ActionNode's ``code_fill`` mode."""

    def __init__(self, name: str = "Coder", llm: LLM = LLM()):
        super().__init__(name, llm)

    async def __call__(self, problem_description):
        """Format ``GENERATE_CODE_PROMPT`` for the problem and return the filled fields.

        :param problem_description: text substituted into the prompt template.
        :return: the filled node's ``instruct_content`` dumped to a dict.
        """
        code_prompt = GENERATE_CODE_PROMPT.format(problem_description=problem_description)
        code_node = ActionNode.from_pydantic(GenerateCodeBlockOp)
        filled = await code_node.fill(context=code_prompt, llm=self.llm, mode='code_fill')
        return filled.instruct_content.model_dump()
|
||||
|
||||
class Review(Operator):
|
||||
|
||||
def __init__(self, criteria, name:str ="Reviewer", llm: LLM = LLM()):
|
||||
|
|
@ -77,17 +90,79 @@ class Ensemble(Operator):
|
|||
response = node.instruct_content.model_dump()
|
||||
return response
|
||||
|
||||
class MdEnsemble(Ensemble):
|
||||
|
||||
class ScEnsemble(Operator):
|
||||
|
||||
def __init__(self, name:str ="Ensembler", llm: LLM = LLM()):
|
||||
def __init__(self, name:str ="MdEnsembler", llm: LLM = LLM(), vote_count:int=3):
|
||||
super().__init__(name, llm)
|
||||
self.vote_count = vote_count
|
||||
|
||||
@staticmethod
def shuffle_answers(solutions: List[str]) -> Tuple[List[str], Dict[str, int]]:
    """Return a shuffled copy of *solutions* and a letter -> original-index map.

    :param solutions: candidate solutions; the list itself is not modified.
    :return: ``(shuffled_solutions, answer_mapping)`` where
        ``answer_mapping["A"]`` is the index in *solutions* of the item now at
        position 0 of ``shuffled_solutions``, ``"B"`` of position 1, and so on.
    """
    # Shuffle positions instead of values: no per-item list.index() scan, and
    # identical solutions keep distinct original indices.
    order = list(range(len(solutions)))
    random.shuffle(order)
    shuffled_solutions = [solutions[original] for original in order]
    answer_mapping = {
        chr(65 + position): original
        for position, original in enumerate(order)
    }
    return shuffled_solutions, answer_mapping
|
||||
|
||||
@staticmethod
def most_frequent(lst: List[Any]) -> Tuple[Any, int]:
    """Return ``(item, count)`` for the most common item of *lst*.

    :param lst: hashable items to count.
    :return: the top entry of ``Counter(lst).most_common(1)``, or
        ``(None, 0)`` when there is nothing to count.
    """
    top = Counter(lst).most_common(1)
    if not top:
        return (None, 0)
    return top[0]
|
||||
|
||||
async def __call__(self, solutions:List, problem_description):
|
||||
solution_text = ""
|
||||
for solution in solutions:
|
||||
solution_text += str(solution) + "\n"
|
||||
prompt = ENSEMBLE_PROMPT.format(solutions=solution_text, problem_description=problem_description)
|
||||
node = await ActionNode.from_pydantic(EnsembleOp).fill(context=prompt, llm=self.llm)
|
||||
response = node.instruct_content.model_dump()
|
||||
return response
|
||||
async def __call__(self, solutions:List[str], problem_description:str,):
    """Run ``self.vote_count`` shuffled voting rounds and return the winning solution.

    Each round shuffles the candidates, labels them ``A``, ``B``, ... in the
    prompt, asks the LLM to fill ``MdEnsembleOp`` and maps the returned
    ``solution_letter`` back to the original candidate. Rounds whose letter is
    not one of the offered labels contribute no vote.

    :param solutions: candidate solutions to choose between.
    :param problem_description: text substituted into ``MD_ENSEMBLE_PROMPT``.
    :return: ``{"final_solution": <most voted solution>}``; the value comes
        from ``self.most_frequent`` applied to the collected votes.
    """
    votes = []

    for _round in range(self.vote_count):
        shuffled_solutions, answer_mapping = self.shuffle_answers(solutions)

        # One "<letter>: <solution>" line per candidate, newline-terminated.
        labelled = "".join(
            f"{chr(65 + position)}: {candidate!s}\n"
            for position, candidate in enumerate(shuffled_solutions)
        )

        prompt = MD_ENSEMBLE_PROMPT.format(solutions=labelled, problem_description=problem_description)
        node = await ActionNode.from_pydantic(MdEnsembleOp).fill(context=prompt, llm=self.llm)
        fields = node.instruct_content.model_dump()

        picked = fields.get('solution_letter', '').strip().upper()
        if picked not in answer_mapping:
            continue
        votes.append(solutions[answer_mapping[picked]])

    final_answer, _ = self.most_frequent(votes)
    return {"final_solution": final_answer}
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
# def load_llm_configs(*config_names):
|
||||
# """
|
||||
# Load multiple LLM configurations and return a list of initialized LLMs.
|
||||
|
||||
# :param config_names: Variable number of configuration file names (without .yaml extension)
|
||||
# :return: List of initialized LLM objects
|
||||
# """
|
||||
# llms = []
|
||||
# for config_name in config_names:
|
||||
# config_path = Path(f"~/.metagpt/{config_name}.yaml").expanduser()
|
||||
# if config_path.exists():
|
||||
# config = Config.from_yaml_file(config_path)
|
||||
# llms.append(LLM(config.llm))
|
||||
# else:
|
||||
# print(f"Warning: Configuration file {config_path} not found. Skipping.")
|
||||
# return llms
|
||||
|
||||
|
||||
# 使用函数加载多个 LLM 配置
|
||||
# llms = load_llm_configs("gpt-4o", "sonnet-35") # 你可以根据需要添加或删除配置
|
||||
|
|
@ -11,6 +11,9 @@ class GenerateOp(BaseModel):
|
|||
class GenerateCodeOp(BaseModel):
|
||||
code_solution: str = Field(default="", description="Your Code Solution for this problem")
|
||||
|
||||
class GenerateCodeBlockOp(BaseModel):
|
||||
code_solution: str = Field(default="", description="Your Code Solution for this problem")
|
||||
|
||||
class ReviewOp(BaseModel):
|
||||
review_result: bool = Field(default=False, description="The Review Result (Bool). If you think this solution looks good for you, return 'true'; If not, return 'false'")
|
||||
feedback: str = Field(default="", description="Your FeedBack for this problem based on the criteria. If the review result is true, you can put it 'nothing here'.")
|
||||
|
|
@ -21,5 +24,6 @@ class ReviseOp(BaseModel):
|
|||
class EnsembleOp(BaseModel):
|
||||
final_solution: str = Field(default="", description="Final ensemble solution for this problem")
|
||||
|
||||
class ScEnsembleOp(BaseModel):
    # Number of the candidate the LLM selects. The default is an int so that
    # it matches the declared field type (it used to be the empty string).
    solution_number: int = Field(default=0, description="Choose The Best Solution Between These, and output the solution number")
|
||||
class MdEnsembleOp(BaseModel):
|
||||
thought: str = Field(default="", description="Analyze the solutions and think what's the best step by step.")
|
||||
solution_letter: str = Field(default="", description="Choose The Best Solution, and output the solution letter")
|
||||
|
|
@ -9,8 +9,20 @@ Generate Solution for the following problem: {problem_description}
|
|||
"""
|
||||
|
||||
GENERATE_CODE_PROMPT = """
|
||||
Generate Code Solution for the following problem: {problem_description}
|
||||
Below is an instruction that describes a task, paired with an input that provides further context.
|
||||
Write a response that appropriately completes the request.
|
||||
|
||||
### Instruction:
|
||||
Write a program to perform the given task.
|
||||
|
||||
Input:
|
||||
{problem_description}
|
||||
|
||||
### Response:
|
||||
"""
|
||||
# GENERATE_CODE_PROMPT = """
|
||||
# Generate Code Solution for the following problem: {problem_description}
|
||||
# """
|
||||
|
||||
REVIEW_PROMPT = """
|
||||
For the question described as {problem_description},
|
||||
|
|
@ -28,3 +40,14 @@ ENSEMBLE_PROMPT = """
|
|||
For the question described as {problem_description}, Solutions: {solutions}
|
||||
Please select the solution that appears most frequently from these options and ensemble this to provide best solution.
|
||||
"""
|
||||
|
||||
# Prompt used by MdEnsemble: candidates are listed under letter labels and the
# model is asked for the letter of the best one, matching the
# ``solution_letter`` field of MdEnsembleOp.
MD_ENSEMBLE_PROMPT = """
# Context
For the question described as {problem_description},
Solutions can be seen below:
{solutions}

# Instruction
Based on the problem and solution candidates, carefully analyze which is the best answer. Focus solely on the correctness of the solution in addressing the problem.
Provide your final decision by writing the chosen solution letter (e.g., A).
"""
|
||||
32
examples/ags/w_action_node/utils.py
Normal file
32
examples/ags/w_action_node/utils.py
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
# @Date : 7/2/2024 17:36 PM
|
||||
# @Author : didi
|
||||
# @Desc : utils for experiment
|
||||
|
||||
import json
|
||||
import re
|
||||
from typing import List, Dict, Any
|
||||
|
||||
def extract_task_id(task_id: str) -> int:
    """Return the first run of digits that directly follows a '/' in *task_id*.

    :param task_id: identifier such as ``"HumanEval/10"``.
    :return: the digits as an int, or 0 when no ``/<digits>`` part exists.
    """
    found = re.search(r'/(\d+)', task_id)
    if found is None:
        return 0
    return int(found.group(1))
|
||||
|
||||
def jsonl_ranker(input_file: str, output_file: str):
    """
    Read a JSONL file, sort the entries based on task_id, and write to a new JSONL file.

    The input is read completely before the output is opened, so both
    parameters may name the same file. Blank lines in the input are ignored.

    :param input_file: Path to the input JSONL file
    :param output_file: Path to the output JSONL file
    """
    # Read and parse the JSONL file; skip blank lines (e.g. a trailing empty
    # line), which json.loads would reject.
    with open(input_file, 'r', encoding='utf-8') as f:
        data = [json.loads(line) for line in f if line.strip()]

    # Sort the data based on the numeric part of task_id
    sorted_data = sorted(data, key=lambda x: extract_task_id(x['task_id']))

    # Write the sorted data to a new JSONL file, one object per line
    with open(output_file, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(item) + '\n' for item in sorted_data)
|
||||
72
he_test.py
72
he_test.py
|
|
@ -1,33 +1,71 @@
|
|||
import json
|
||||
import asyncio
|
||||
|
||||
import aiofiles
|
||||
from metagpt.llm import LLM
|
||||
from evalplus.data import get_human_eval_plus, write_jsonl
|
||||
from examples.ags.w_action_node.utils import jsonl_ranker
|
||||
from examples.ags.w_action_node.graph import HumanEvalGraph
|
||||
from examples.ags.w_action_node.operator import GenerateCode
|
||||
|
||||
generate_code = GenerateCode(llm=LLM())
|
||||
case = get_human_eval_plus()['HumanEval/10']
|
||||
solver = HumanEvalGraph(name="solver", llm=LLM(), criteria='correctness, efficiency, readability')
|
||||
|
||||
async def sample_generate(case):
|
||||
solution_result = await solver(case['prompt'])
|
||||
solver = HumanEvalGraph(name="solver", llm=LLM(), criteria='correctness, efficiency, readability', vote_count=5)
|
||||
|
||||
async def sample_generate(id):
    """Solve one HumanEval+ task and append the result to ``samples.jsonl``.

    :param id: key into ``get_human_eval_plus()`` (formatted with an f-string
        before the lookup), e.g. ``'HumanEval/101'``.

    After appending, ``samples.jsonl`` is re-sorted in place with ``jsonl_ranker``.
    """
    problem = get_human_eval_plus()[f"{id}"]
    outcome = await solver(problem['prompt'], ensemble_count=3)
    record = {'task_id': problem['task_id'], 'solution': outcome['final_solution']}
    print(record)
    with open("samples.jsonl", mode='a') as out:
        out.write(json.dumps(record) + '\n')
    jsonl_ranker("samples.jsonl", "samples.jsonl")
|
||||
|
||||
async def samples_generate_sequence():
    """Solve every HumanEval+ task one after another and write ``samples.jsonl``."""
    collected = []
    for problem in get_human_eval_plus().values():
        outcome = await solver(problem['prompt'])
        collected.append(
            {'task_id': problem['task_id'], 'solution': outcome['final_solution']}
        )
    write_jsonl("samples.jsonl", collected)
|
||||
async def samples_generate(mode:str):
    """Solve all HumanEval+ tasks concurrently and append results to ``samples.jsonl``.

    :param mode: ``'llm'`` to call ``generate_code`` directly, ``'ags'`` to
        call the ``solver`` graph with ``ensemble_count=3``.
    :raises ValueError: if *mode* is neither ``'llm'`` nor ``'ags'``.

    Tasks whose attempt raised are retried once, sequentially, through
    ``sample_generate``; finally ``samples.jsonl`` is sorted in place.
    """
    # Reject unknown modes up front. Otherwise no branch below assigns a
    # result, every task fails inside the worker and is silently retried.
    if mode not in ('llm', 'ags'):
        raise ValueError(f"unsupported mode {mode!r}; expected 'llm' or 'ags'")

    cases = list(get_human_eval_plus().values())
    file_lock = asyncio.Lock()

    async def solve_and_write(case, mode):
        """Solve one case and append it; return None on success, the task_id on failure."""
        try:
            if mode == 'llm':
                solution_result = await generate_code(case['prompt'])
                sample_dict = {
                    'task_id': case['task_id'],
                    'solution': solution_result['code_solution']
                }
            elif mode == "ags":
                solution_result = await solver(case['prompt'], ensemble_count=3)
                sample_dict = {
                    'task_id': case['task_id'],
                    'solution': solution_result['final_solution']
                }

            # Appends from concurrent workers go through one lock.
            async with file_lock:
                async with aiofiles.open("samples.jsonl", mode='a') as f:
                    await f.write(json.dumps(sample_dict) + '\n')
            return None

        except Exception as e:
            print(e)
            return case['task_id']

    tasks = [solve_and_write(case, mode) for case in cases]
    results = await asyncio.gather(*tasks)
    failed_tasks = [task_id for task_id in results if task_id is not None]

    # TODO: this part is still not automated enough
    # NOTE(review): the retry always goes through sample_generate, which uses
    # the solver graph even when mode == 'llm' -- confirm this is intended.
    for task_id in failed_tasks:
        try:
            await sample_generate(task_id)
        except Exception:
            print(f"failure {task_id}")
    jsonl_ranker("samples.jsonl", "samples.jsonl")
|
||||
|
||||
async def samples_generate_ags():
|
||||
sample_list = []
|
||||
cases = list(get_human_eval_plus().values())
|
||||
|
||||
async def solve_with_id(case):
|
||||
solution_result = await solver(case['prompt'])
|
||||
solution_result = await solver(case['prompt'], ensemble_count=3)
|
||||
return case['task_id'], solution_result['final_solution']
|
||||
|
||||
tasks = [solve_with_id(case) for case in cases]
|
||||
|
|
@ -56,8 +94,10 @@ async def samples_generate_llm():
|
|||
|
||||
write_jsonl("samples.jsonl", sample_list)
|
||||
|
||||
# asyncio.run(sample_generate(case))
|
||||
# asyncio.run(sample_generate('HumanEval/101'))
|
||||
# asyncio.run(samples_generate_llm())
|
||||
asyncio.run(samples_generate_ags())
|
||||
asyncio.run(samples_generate(mode='ags'))
|
||||
# jsonl_ranker("samples.jsonl", "samples.jsonl")
|
||||
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -41,6 +41,7 @@ TAG = "CONTENT"
|
|||
LANGUAGE_CONSTRAINT = "Language: Please use the same language as Human INPUT."
|
||||
FORMAT_CONSTRAINT = f"Format: output wrapped inside [{TAG}][/{TAG}] like format example, nothing else."
|
||||
|
||||
|
||||
SIMPLE_TEMPLATE = """
|
||||
## context
|
||||
{context}
|
||||
|
|
@ -147,6 +148,8 @@ class ActionNode:
|
|||
prevs: List["ActionNode"] # previous nodes
|
||||
nexts: List["ActionNode"] # next nodes
|
||||
|
||||
MODE_CODE_FILL = "code_fill"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
key: str,
|
||||
|
|
@ -464,6 +467,56 @@ class ActionNode:
|
|||
|
||||
return self
|
||||
|
||||
def get_field_name(self):
    """
    Return the name of the field that this node fills.

    The model class comes from ``self.create_class()``. If its
    ``model_fields`` holds exactly one entry, that entry's name is returned;
    with any other number of fields the node's own ``key`` is returned.
    """
    field_names = list(self.create_class().model_fields)
    # A single-field model is unambiguous; otherwise rely on self.key.
    return field_names[0] if len(field_names) == 1 else self.key
|
||||
|
||||
async def code_fill(
    self,
    context,
    timeout=USE_CONFIG_TIMEOUT
):
    """
    Fill a code-block node: ask the LLM for code and pull it out of the reply.

    :param context: prompt text; an instruction to wrap the code in triple
        backticks is appended before it is sent.
    :param timeout: forwarded to ``self.llm.aask``.
    :return: ``{field_name: code}`` with ``field_name`` from
        ``self.get_field_name()``. ``code`` is the content of the first fenced
        block in the reply, or the stripped reply itself when the reply
        contains no complete fenced block.
    """
    import re

    def extract_code_from_response(response):
        """
        Extracts code wrapped in triple backticks from the response,
        removing any language specifier.

        :param response: The full response from the LLM
        :return: The extracted code, or None if no code is found
        """
        code_pattern = r"```(?:\w+\n)?([\s\S]*?)```"
        matches = re.findall(code_pattern, response)
        if matches:
            # The first group in the regex contains the code without the language specifier
            return matches[0].strip()
        return None

    field_name = self.get_field_name()
    prompt = context
    prompt += "\nPlease wrap the generated code within triple backticks, like this: ```<code>```"
    content = await self.llm.aask(prompt, timeout=timeout)

    extracted_code = extract_code_from_response(content)
    if extracted_code is None:
        # The model ignored the fencing instruction. Use the raw reply rather
        # than handing None to a field that is expected to hold text.
        extracted_code = content.strip()
    result = {field_name: extracted_code}
    return result
|
||||
|
||||
async def fill(
|
||||
self,
|
||||
context,
|
||||
|
|
@ -500,6 +553,11 @@ class ActionNode:
|
|||
if self.schema:
|
||||
schema = self.schema
|
||||
|
||||
if mode == self.MODE_CODE_FILL:
|
||||
result = await self.code_fill(context, timeout)
|
||||
self.instruct_content = self.create_class()(**result)
|
||||
return self
|
||||
|
||||
if strgy == "simple":
|
||||
return await self.simple_fill(schema=schema, mode=mode, images=images, timeout=timeout, exclude=exclude)
|
||||
elif strgy == "complex":
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue