format SPO code and add figure from paper

This commit is contained in:
isaacJinyu 2025-02-07 21:32:43 +08:00
parent 322003aad7
commit 852dc20a84
13 changed files with 156 additions and 148 deletions

View file

@ -3,11 +3,12 @@
# @Author : all
# @Desc : Evaluation for different datasets
import asyncio
from typing import Dict, Any
from metagpt.ext.spo.utils import load
from metagpt.ext.spo.prompts.evaluate_prompt import EVALUATE_PROMPT
import random
from metagpt.ext.spo.utils.llm_client import SPO_LLM, extract_content
from typing import Any, Dict
from metagpt.ext.spo.prompts.evaluate_prompt import EVALUATE_PROMPT
from metagpt.ext.spo.utils import load
from metagpt.ext.spo.utils.llm_client import SPO_LLM, RequestType, extract_content
from metagpt.logs import logger
@ -17,7 +18,6 @@ class QuickExecute:
"""
def __init__(self, prompt: str):
self.prompt = prompt
self.llm = SPO_LLM.get_instance()
@ -28,12 +28,12 @@ class QuickExecute:
async def fetch_answer(q: str) -> Dict[str, Any]:
messages = [{"role": "user", "content": f"{self.prompt}\n\n{q}"}]
try:
answer = await self.llm.responser(type="execute", messages=messages)
return {'question': q, 'answer': answer}
answer = await self.llm.responser(request_type=RequestType.EXECUTE, messages=messages)
return {"question": q, "answer": answer}
except Exception as e:
return {'question': q, 'answer': str(e)}
return {"question": q, "answer": str(e)}
tasks = [fetch_answer(item['question']) for item in qa]
tasks = [fetch_answer(item["question"]) for item in qa]
answers = await asyncio.gather(*tasks)
return answers
@ -56,15 +56,18 @@ class QuickEvaluate:
else:
is_swapped = False
messages = [{"role": "user", "content": EVALUATE_PROMPT.format(
requirement=requirement,
sample=samples,
new_sample=new_samples,
answers=str(qa))}]
messages = [
{
"role": "user",
"content": EVALUATE_PROMPT.format(
requirement=requirement, sample=samples, new_sample=new_samples, answers=str(qa)
),
}
]
try:
response = await self.llm.responser(type="evaluate", messages=messages)
choose = extract_content(response, 'choose')
response = await self.llm.responser(request_type=RequestType.EVALUATE, messages=messages)
choose = extract_content(response, "choose")
return choose == "A" if is_swapped else choose == "B"
except Exception as e:
@ -72,9 +75,8 @@ class QuickEvaluate:
return False
if __name__ == "__main__":
execute = QuickExecute(prompt="Answer the Question")
executor = QuickExecute(prompt="Answer the Question")
answers = asyncio.run(execute.prompt_evaluate())
answers = asyncio.run(executor.prompt_execute())
print(answers)

View file

@ -4,27 +4,26 @@
# @Desc : optimizer for prompt
import asyncio
from metagpt.ext.spo.utils.data_utils import DataUtils
from metagpt.ext.spo.utils.evaluation_utils import EvaluationUtils
from metagpt.ext.spo.utils.prompt_utils import PromptUtils
from metagpt.ext.spo.prompts.optimize_prompt import PROMPT_OPTIMIZE_PROMPT
from metagpt.ext.spo.utils import load
from metagpt.ext.spo.utils.data_utils import DataUtils
from metagpt.ext.spo.utils.evaluation_utils import EvaluationUtils
from metagpt.ext.spo.utils.llm_client import SPO_LLM, RequestType, extract_content
from metagpt.ext.spo.utils.prompt_utils import PromptUtils
from metagpt.logs import logger
from metagpt.ext.spo.utils.llm_client import extract_content, SPO_LLM
class PromptOptimizer:
def __init__(
self,
optimized_path: str = None,
initial_round: int = 1,
max_rounds: int = 10,
name: str = "",
template: str = "",
iteration: bool = True,
self,
optimized_path: str = None,
initial_round: int = 1,
max_rounds: int = 10,
name: str = "",
template: str = "",
iteration: bool = True,
) -> None:
self.dataset = name
self.root_path = f"{optimized_path}/{self.dataset}"
self.top_scores = []
@ -40,7 +39,6 @@ class PromptOptimizer:
def optimize(self):
if self.iteration:
for opt_round in range(self.max_rounds):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
@ -55,7 +53,6 @@ class PromptOptimizer:
logger.info(f"Prompt generated in round {self.round}: {prompt}")
async def _optimize_prompt(self):
prompt_path = f"{self.root_path}/prompts"
load.set_file_name(self.template)
@ -69,18 +66,16 @@ class PromptOptimizer:
self.prompt = prompt
self.prompt_utils.write_prompt(directory, prompt=self.prompt)
new_samples = await self.evaluation_utils.execute_prompt(self, directory, initial=True)
_, answers = await self.evaluation_utils.evaluate_prompt(self, None, new_samples, path=prompt_path,
data=data, initial=True)
_, answers = await self.evaluation_utils.evaluate_prompt(
self, None, new_samples, path=prompt_path, data=data, initial=True
)
self.prompt_utils.write_answers(directory, answers=answers)
_, requirements, qa, count = load.load_meta_data()
directory = self.prompt_utils.create_round_directory(prompt_path, self.round + 1)
top_round = self.data_utils.get_best_round()
samples = top_round
samples = self.data_utils.get_best_round()
logger.info(f"choose {samples['round']}")
@ -88,12 +83,16 @@ class PromptOptimizer:
best_answer = self.data_utils.list_to_markdown(samples["answers"])
optimize_prompt = PROMPT_OPTIMIZE_PROMPT.format(
prompt=samples["prompt"], answers=best_answer,
prompt=samples["prompt"],
answers=best_answer,
requirements=requirements,
golden_answers=golden_answer,
count=count)
count=count,
)
response = await self.llm.responser(type="optimize", messages=[{"role": "user", "content": optimize_prompt}])
response = await self.llm.responser(
request_type=RequestType.OPTIMIZE, messages=[{"role": "user", "content": optimize_prompt}]
)
modification = extract_content(response, "modification")
@ -110,8 +109,9 @@ class PromptOptimizer:
new_samples = await self.evaluation_utils.execute_prompt(self, directory, data)
success, answers = await self.evaluation_utils.evaluate_prompt(self, samples, new_samples, path=prompt_path,
data=data, initial=False)
success, answers = await self.evaluation_utils.evaluate_prompt(
self, samples, new_samples, path=prompt_path, data=data, initial=False
)
self.prompt_utils.write_answers(directory, answers=answers)
@ -122,7 +122,6 @@ class PromptOptimizer:
return prompt
async def _test_prompt(self):
load.set_file_name(self.template)
prompt_path = f"{self.root_path}/prompts"

View file

@ -1,11 +1,11 @@
import datetime
import json
import os
import random
from typing import Union, List, Dict
import pandas as pd
from metagpt.logs import logger
from typing import Dict, List, Union
import pandas as pd
from metagpt.logs import logger
class DataUtils:
@ -14,7 +14,7 @@ class DataUtils:
self.top_scores = []
def load_results(self, path: str) -> list:
result_path = os.path.join(path, "results.json")
result_path = self.get_results_file_path(path)
if os.path.exists(result_path):
with open(result_path, "r") as json_file:
try:
@ -24,7 +24,6 @@ class DataUtils:
return []
def get_best_round(self):
self._load_scores()
for entry in self.top_scores:
@ -44,11 +43,6 @@ class DataUtils:
with open(json_file_path, "w") as json_file:
json.dump(data, json_file, default=str, indent=4)
def save_cost(self, directory: str, data: Union[List, Dict]):
json_file = os.path.join(directory, 'cost.json')
with open(json_file, "w", encoding="utf-8") as file:
json.dump(data, file, default=str, indent=4)
def _load_scores(self):
rounds_dir = os.path.join(self.root_path, "prompts")
result_file = os.path.join(rounds_dir, "results.json")
@ -65,12 +59,14 @@ class DataUtils:
df = pd.DataFrame(data)
for index, row in df.iterrows():
self.top_scores.append({
"round": row["round"],
"succeed": row["succeed"],
"prompt": row["prompt"],
"answers": row['answers']
})
self.top_scores.append(
{
"round": row["round"],
"succeed": row["succeed"],
"prompt": row["prompt"],
"answers": row["answers"],
}
)
self.top_scores.sort(key=lambda x: x["round"], reverse=True)

View file

@ -1,22 +1,24 @@
import tiktoken
from metagpt.ext.spo.components.evaluator import QuickEvaluate, QuickExecute
from metagpt.logs import logger
import tiktoken
EVALUATION_REPETITION = 4
def count_tokens(sample):
if sample is None:
if not sample:
return 0
else:
encoding = tiktoken.get_encoding("cl100k_base")
return len(encoding.encode(str(sample['answers'])))
return len(encoding.encode(str(sample["answers"])))
class EvaluationUtils:
def __init__(self, root_path: str):
self.root_path = root_path
async def execute_prompt(self, optimizer, prompt_path, initial=False):
optimizer.prompt = optimizer.prompt_utils.load_prompt(optimizer.round, prompt_path)
executor = QuickExecute(prompt=optimizer.prompt)
@ -29,7 +31,6 @@ class EvaluationUtils:
return new_data
async def evaluate_prompt(self, optimizer, samples, new_samples, path, data, initial=False):
evaluator = QuickEvaluate()
new_token = count_tokens(new_samples)
@ -47,8 +48,9 @@ class EvaluationUtils:
false_count = evaluation_results.count(False)
succeed = true_count > false_count
new_data = optimizer.data_utils.create_result_data(new_samples['round'], new_samples['answers'],
new_samples['prompt'], succeed, new_token)
new_data = optimizer.data_utils.create_result_data(
new_samples["round"], new_samples["answers"], new_samples["prompt"], succeed, new_token
)
data.append(new_data)
@ -56,6 +58,6 @@ class EvaluationUtils:
optimizer.data_utils.save_results(result_path, data)
answers = new_samples['answers']
answers = new_samples["answers"]
return succeed, answers

View file

@ -1,12 +1,20 @@
import asyncio
import re
from enum import Enum
from typing import Optional
from metagpt.configs.models_config import ModelsConfig
from metagpt.llm import LLM
import asyncio
class RequestType(Enum):
OPTIMIZE = "optimize"
EVALUATE = "evaluate"
EXECUTE = "execute"
class SPO_LLM:
_instance: Optional['SPO_LLM'] = None
_instance: Optional["SPO_LLM"] = None
def __init__(self, optimize_kwargs=None, evaluate_kwargs=None, execute_kwargs=None):
self.evaluate_llm = LLM(llm_config=self._load_llm_config(evaluate_kwargs))
@ -14,7 +22,7 @@ class SPO_LLM:
self.execute_llm = LLM(llm_config=self._load_llm_config(execute_kwargs))
def _load_llm_config(self, kwargs: dict):
model = kwargs.get('model')
model = kwargs.get("model")
if not model:
raise ValueError("'model' parameter is required")
@ -31,23 +39,24 @@ class SPO_LLM:
return config
except AttributeError as e:
except AttributeError:
raise ValueError(f"Model '{model}' not found in configuration")
except Exception as e:
raise ValueError(f"Error loading configuration for model '{model}': {str(e)}")
async def responser(self, type: str, messages):
if type == "optimize":
response = await self.optimize_llm.acompletion(messages)
elif type == "evaluate":
response = await self.evaluate_llm.acompletion(messages)
elif type == "execute":
response = await self.execute_llm.acompletion(messages)
else:
raise ValueError("Please set the correct name: optimize, evaluate or execute")
async def responser(self, request_type: RequestType, messages: list):
llm_mapping = {
RequestType.OPTIMIZE: self.optimize_llm,
RequestType.EVALUATE: self.evaluate_llm,
RequestType.EXECUTE: self.execute_llm,
}
rsp = response.choices[0].message.content
return rsp
llm = llm_mapping.get(request_type)
if not llm:
raise ValueError(f"Invalid request type. Valid types: {', '.join([t.value for t in RequestType])}")
response = await llm.acompletion(messages)
return response.choices[0].message.content
@classmethod
def initialize(cls, optimize_kwargs, evaluate_kwargs, execute_kwargs):
@ -61,8 +70,9 @@ class SPO_LLM:
raise RuntimeError("SPO_LLM not initialized. Call initialize() first.")
return cls._instance
def extract_content(xml_string, tag):
pattern = rf'<{tag}>(.*?)</{tag}>'
pattern = rf"<{tag}>(.*?)</{tag}>"
match = re.search(pattern, xml_string, re.DOTALL)
return match.group(1).strip() if match else None
@ -72,23 +82,20 @@ async def spo():
SPO_LLM.initialize(
optimize_kwargs={"model": "gpt-4o", "temperature": 0.7},
evaluate_kwargs={"model": "gpt-4o-mini", "temperature": 0.3},
execute_kwargs={"model": "gpt-4o-mini", "temperature": 0.3}
execute_kwargs={"model": "gpt-4o-mini", "temperature": 0.3},
)
llm = SPO_LLM.get_instance()
# test messages
hello_msg = [{"role": "user", "content": "hello"}]
response = await llm.responser(type='execute', messages=hello_msg)
response = await llm.responser(request_type=RequestType.EXECUTE, messages=hello_msg)
print(f"AI: {response}")
response = await llm.responser(type='optimize', messages=hello_msg)
response = await llm.responser(request_type=RequestType.OPTIMIZE, messages=hello_msg)
print(f"AI: {response}")
response = await llm.responser(type='evaluate', messages=hello_msg)
response = await llm.responser(request_type=RequestType.EVALUATE, messages=hello_msg)
print(f"AI: {response}")
if __name__ == "__main__":
asyncio.run(spo())

View file

@ -1,10 +1,12 @@
import yaml
import random
import os
import random
FILE_NAME = ''
import yaml
FILE_NAME = ""
SAMPLE_K = 3
def set_file_name(name):
global FILE_NAME
FILE_NAME = name
@ -12,13 +14,13 @@ def set_file_name(name):
def load_meta_data(k=SAMPLE_K):
# load yaml file
config_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'settings', FILE_NAME)
config_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "settings", FILE_NAME)
if not os.path.exists(config_path):
raise FileNotFoundError(f"Configuration file '{FILE_NAME}' not found in settings directory")
try:
with open(config_path, 'r', encoding='utf-8') as file:
with open(config_path, "r", encoding="utf-8") as file:
data = yaml.safe_load(file)
except yaml.YAMLError as e:
raise ValueError(f"Error parsing YAML file '{FILE_NAME}': {str(e)}")
@ -27,14 +29,14 @@ def load_meta_data(k=SAMPLE_K):
qa = []
for item in data['faq']:
question = item['question']
answer = item['answer']
qa.append({'question': question, 'answer': answer})
for item in data["faq"]:
question = item["question"]
answer = item["answer"]
qa.append({"question": question, "answer": answer})
prompt = data['prompt']
requirements = data['requirements']
count = data['count']
prompt = data["prompt"]
requirements = data["requirements"]
count = data["count"]
if isinstance(count, int):
count = f", within {count} words"
@ -44,4 +46,3 @@ def load_meta_data(k=SAMPLE_K):
random_qa = random.sample(qa, min(k, len(qa)))
return prompt, requirements, random_qa, count

View file

@ -1,4 +1,5 @@
import os
from metagpt.logs import logger
@ -15,14 +16,13 @@ class PromptUtils:
prompt_file_name = f"{prompts_path}/prompt.txt"
try:
with open(prompt_file_name, 'r', encoding='utf-8') as file:
with open(prompt_file_name, "r", encoding="utf-8") as file:
return file.read()
except FileNotFoundError as e:
logger.info(f"Error loading prompt for round {round_number}: {e}")
raise
def write_answers(self, directory: str, answers: dict, name: str = "answers.txt"):
with open(os.path.join(directory, name), "w", encoding="utf-8") as file:
for item in answers:
file.write(f"Question:\n{item['question']}\n")
@ -30,7 +30,5 @@ class PromptUtils:
file.write("\n")
def write_prompt(self, directory: str, prompt: str):
with open(os.path.join(directory, "prompt.txt"), "w", encoding="utf-8") as file:
file.write(prompt)