format SPO code and add figures from paper

This commit is contained in:
isaacJinyu 2025-02-07 21:32:43 +08:00
parent 322003aad7
commit 852dc20a84
13 changed files with 156 additions and 148 deletions

Binary file added (93 KiB, not shown)

Binary file added (200 KiB, not shown)

Binary file added (294 KiB, not shown)

Binary file added (268 KiB, not shown)

View file

@ -4,6 +4,10 @@ # SPO 🤖 | Self-Supervised Prompt Optimizer
A next-generation prompt engineering system implementing **Self-Supervised Prompt Optimization (SPO)**. It achieves state-of-the-art performance with 17.8-90.9× higher cost efficiency than conventional methods. 🚀
<p align="center">
<a href=""><img src="../../docs/resources/spo/SPO-method.png" alt="Framework of AFlow" title="Framework of AFlow <sub>1</sub>" width="80%"></a>
</p>
## ✨ Core Advantages
- 💸 **Ultra-Low Cost** - _$0.15 per task optimization_
@ -11,6 +15,25 @@ ## ✨ Core Advantages
- ⚡ **Universal Adaptation** - _Closed & open-ended tasks supported_
- 🔄 **Self-Evolving** - _Auto-optimization via LLM-as-judge mechanism_ (see the evaluation sketch below)
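
The LLM-as-judge mechanism is a pairwise comparison: outputs from the newly optimized prompt are judged against outputs from the current best prompt, with the presentation order randomly swapped to reduce position bias (this is what `QuickEvaluate` does later in this commit). Below is a minimal, self-contained sketch of that idea; the `judge` callable and the prompt wording are illustrative stand-ins, not the repo's API.

```python
import random

def pairwise_judge(judge, requirement: str, best_output: str, candidate_output: str) -> bool:
    """Return True if the candidate output beats the current best according to an LLM judge."""
    # Randomly swap the presentation order to reduce position bias,
    # mirroring the is_swapped logic in QuickEvaluate.
    is_swapped = random.random() < 0.5
    first, second = (candidate_output, best_output) if is_swapped else (best_output, candidate_output)
    prompt = (
        f"Requirement: {requirement}\n\n"
        f"Sample A:\n{first}\n\nSample B:\n{second}\n\n"
        "Which sample better satisfies the requirement? Answer with A or B only."
    )
    choice = judge(prompt).strip()  # `judge` is a stand-in for an LLM call returning "A" or "B"
    # The candidate wins if the judge picked whichever slot the candidate occupied.
    return choice == "A" if is_swapped else choice == "B"

# Toy usage with a stand-in judge that always prefers the second sample;
# the result depends on the random swap, as it would across repeated trials.
print(pairwise_judge(lambda _: "B", "Write a vivid haiku", "old prompt's poem", "new prompt's poem"))
```

In the code below, `QuickEvaluate` repeats this comparison several times (`EVALUATION_REPETITION = 4`) and keeps the new prompt only if it wins the majority vote.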
Read our paper on arXiv (coming soon).
## 📊 Experiment
### Closed Tasks
<p align="center">
<a href=""><img src="../../docs/resources/spo/SPO-closed_task_table.png" alt="Framework of AFlow" title="Framework of AFlow <sub>1</sub>" width="80%"></a>
<a href=""><img src="../../docs/resources/spo/SPO-closed_task_figure.png" alt="Framework of AFlow" title="Framework of AFlow <sub>1</sub>" width="80%"></a>
</p>
*SPO demonstrates superior cost efficiency, requiring only 1.1% to 5.6% of the cost of state-of-the-art methods while maintaining competitive performance.*
### Open-ended Tasks
<p align="center">
<a href=""><img src="../../docs/resources/spo/SPO-open-ended _task_figure.png" alt="Framework of AFlow" title="Framework of AFlow <sub>1</sub>" width="80%"></a>
</p>
*SPO significantly improves model performance across all model configurations in open-ended tasks.*
## 🚀 Quick Start
### 1. Configure Your API Key ⚙️
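
Your API key and model names go in MetaGPT's standard LLM configuration (typically `config/config2.yaml`). Once that is set, a minimal run using the interfaces from this commit (`SPO_LLM.initialize` and `PromptOptimizer`) looks roughly like the sketch below; the model names and paths simply mirror the CLI defaults in this commit and are assumptions, not requirements.

```python
from metagpt.ext.spo.components.optimizer import PromptOptimizer
from metagpt.ext.spo.utils.llm_client import SPO_LLM

# Set up the three LLM roles SPO uses: prompt optimization, pairwise
# evaluation, and prompt execution (values mirror the CLI defaults).
SPO_LLM.initialize(
    optimize_kwargs={"model": "claude-3-5-sonnet-20240620", "temperature": 0.7},
    evaluate_kwargs={"model": "gpt-4o-mini", "temperature": 0.3},
    execute_kwargs={"model": "gpt-4o-mini", "temperature": 0},
)

# Iteratively optimize the prompt defined in settings/Poem.yaml for up to
# 10 rounds, writing results under workspace/Poem/prompts.
optimizer = PromptOptimizer(
    optimized_path="workspace",
    initial_round=1,
    max_rounds=10,
    template="Poem.yaml",
    name="Poem",
)
optimizer.optimize()
```

The equivalent command-line entry point, with these same values as argparse defaults, is the optimize script diff later in this commit.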

View file

@ -1,38 +1,27 @@
import argparse
from metagpt.ext.spo.components.optimizer import PromptOptimizer
from metagpt.ext.spo.utils.llm_client import SPO_LLM
def parse_args():
parser = argparse.ArgumentParser(description='SPO PromptOptimizer CLI')
parser = argparse.ArgumentParser(description="SPO PromptOptimizer CLI")
# LLM parameter
parser.add_argument('--opt-model', type=str, default='claude-3-5-sonnet-20240620',
help='Model for optimization')
parser.add_argument('--opt-temp', type=float, default=0.7,
help='Temperature for optimization')
parser.add_argument('--eval-model', type=str, default='gpt-4o-mini',
help='Model for evaluation')
parser.add_argument('--eval-temp', type=float, default=0.3,
help='Temperature for evaluation')
parser.add_argument('--exec-model', type=str, default='gpt-4o-mini',
help='Model for execution')
parser.add_argument('--exec-temp', type=float, default=0,
help='Temperature for execution')
parser.add_argument("--opt-model", type=str, default="claude-3-5-sonnet-20240620", help="Model for optimization")
parser.add_argument("--opt-temp", type=float, default=0.7, help="Temperature for optimization")
parser.add_argument("--eval-model", type=str, default="gpt-4o-mini", help="Model for evaluation")
parser.add_argument("--eval-temp", type=float, default=0.3, help="Temperature for evaluation")
parser.add_argument("--exec-model", type=str, default="gpt-4o-mini", help="Model for execution")
parser.add_argument("--exec-temp", type=float, default=0, help="Temperature for execution")
# PromptOptimizer parameter
parser.add_argument('--workspace', type=str, default='workspace',
help='Path for optimized output')
parser.add_argument('--initial-round', type=int, default=1,
help='Initial round number')
parser.add_argument('--max-rounds', type=int, default=10,
help='Maximum number of rounds')
parser.add_argument('--template', type=str, default='Poem.yaml',
help='Template file name')
parser.add_argument('--name', type=str, default='Poem',
help='Project name')
parser.add_argument('--no-iteration', action='store_false', dest='iteration',
help='Disable iteration mode')
parser.add_argument("--workspace", type=str, default="workspace", help="Path for optimized output")
parser.add_argument("--initial-round", type=int, default=1, help="Initial round number")
parser.add_argument("--max-rounds", type=int, default=10, help="Maximum number of rounds")
parser.add_argument("--template", type=str, default="Poem.yaml", help="Template file name")
parser.add_argument("--name", type=str, default="Poem", help="Project name")
parser.add_argument("--no-iteration", action="store_false", dest="iteration", help="Disable iteration mode")
return parser.parse_args()
@ -41,18 +30,9 @@ def main():
args = parse_args()
SPO_LLM.initialize(
optimize_kwargs={
"model": args.opt_model,
"temperature": args.opt_temp
},
evaluate_kwargs={
"model": args.eval_model,
"temperature": args.eval_temp
},
execute_kwargs={
"model": args.exec_model,
"temperature": args.exec_temp
}
optimize_kwargs={"model": args.opt_model, "temperature": args.opt_temp},
evaluate_kwargs={"model": args.eval_model, "temperature": args.eval_temp},
execute_kwargs={"model": args.exec_model, "temperature": args.exec_temp},
)
optimizer = PromptOptimizer(

View file

@ -3,11 +3,12 @@
# @Author : all
# @Desc : Evaluation for different datasets
import asyncio
from typing import Dict, Any
from metagpt.ext.spo.utils import load
from metagpt.ext.spo.prompts.evaluate_prompt import EVALUATE_PROMPT
import random
from metagpt.ext.spo.utils.llm_client import SPO_LLM, extract_content
from typing import Any, Dict
from metagpt.ext.spo.prompts.evaluate_prompt import EVALUATE_PROMPT
from metagpt.ext.spo.utils import load
from metagpt.ext.spo.utils.llm_client import SPO_LLM, RequestType, extract_content
from metagpt.logs import logger
@ -17,7 +18,6 @@ class QuickExecute:
"""
def __init__(self, prompt: str):
self.prompt = prompt
self.llm = SPO_LLM.get_instance()
@ -28,12 +28,12 @@ class QuickExecute:
async def fetch_answer(q: str) -> Dict[str, Any]:
messages = [{"role": "user", "content": f"{self.prompt}\n\n{q}"}]
try:
answer = await self.llm.responser(type="execute", messages=messages)
return {'question': q, 'answer': answer}
answer = await self.llm.responser(request_type=RequestType.EXECUTE, messages=messages)
return {"question": q, "answer": answer}
except Exception as e:
return {'question': q, 'answer': str(e)}
return {"question": q, "answer": str(e)}
tasks = [fetch_answer(item['question']) for item in qa]
tasks = [fetch_answer(item["question"]) for item in qa]
answers = await asyncio.gather(*tasks)
return answers
@ -56,15 +56,18 @@ class QuickEvaluate:
else:
is_swapped = False
messages = [{"role": "user", "content": EVALUATE_PROMPT.format(
requirement=requirement,
sample=samples,
new_sample=new_samples,
answers=str(qa))}]
messages = [
{
"role": "user",
"content": EVALUATE_PROMPT.format(
requirement=requirement, sample=samples, new_sample=new_samples, answers=str(qa)
),
}
]
try:
response = await self.llm.responser(type="evaluate", messages=messages)
choose = extract_content(response, 'choose')
response = await self.llm.responser(request_type=RequestType.EVALUATE, messages=messages)
choose = extract_content(response, "choose")
return choose == "A" if is_swapped else choose == "B"
except Exception as e:
@ -72,9 +75,8 @@ class QuickEvaluate:
return False
if __name__ == "__main__":
execute = QuickExecute(prompt="Answer the Question")
executor = QuickExecute(prompt="Answer the Question")
answers = asyncio.run(execute.prompt_evaluate())
answers = asyncio.run(executor.prompt_execute())
print(answers)

View file

@ -4,27 +4,26 @@
# @Desc : optimizer for prompt
import asyncio
from metagpt.ext.spo.utils.data_utils import DataUtils
from metagpt.ext.spo.utils.evaluation_utils import EvaluationUtils
from metagpt.ext.spo.utils.prompt_utils import PromptUtils
from metagpt.ext.spo.prompts.optimize_prompt import PROMPT_OPTIMIZE_PROMPT
from metagpt.ext.spo.utils import load
from metagpt.ext.spo.utils.data_utils import DataUtils
from metagpt.ext.spo.utils.evaluation_utils import EvaluationUtils
from metagpt.ext.spo.utils.llm_client import SPO_LLM, RequestType, extract_content
from metagpt.ext.spo.utils.prompt_utils import PromptUtils
from metagpt.logs import logger
from metagpt.ext.spo.utils.llm_client import extract_content, SPO_LLM
class PromptOptimizer:
def __init__(
self,
optimized_path: str = None,
initial_round: int = 1,
max_rounds: int = 10,
name: str = "",
template: str = "",
iteration: bool = True,
self,
optimized_path: str = None,
initial_round: int = 1,
max_rounds: int = 10,
name: str = "",
template: str = "",
iteration: bool = True,
) -> None:
self.dataset = name
self.root_path = f"{optimized_path}/{self.dataset}"
self.top_scores = []
@ -40,7 +39,6 @@ class PromptOptimizer:
def optimize(self):
if self.iteration:
for opt_round in range(self.max_rounds):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
@ -55,7 +53,6 @@ class PromptOptimizer:
logger.info(f"Prompt generated in round {self.round}: {prompt}")
async def _optimize_prompt(self):
prompt_path = f"{self.root_path}/prompts"
load.set_file_name(self.template)
@ -69,18 +66,16 @@ class PromptOptimizer:
self.prompt = prompt
self.prompt_utils.write_prompt(directory, prompt=self.prompt)
new_samples = await self.evaluation_utils.execute_prompt(self, directory, initial=True)
_, answers = await self.evaluation_utils.evaluate_prompt(self, None, new_samples, path=prompt_path,
data=data, initial=True)
_, answers = await self.evaluation_utils.evaluate_prompt(
self, None, new_samples, path=prompt_path, data=data, initial=True
)
self.prompt_utils.write_answers(directory, answers=answers)
_, requirements, qa, count = load.load_meta_data()
directory = self.prompt_utils.create_round_directory(prompt_path, self.round + 1)
top_round = self.data_utils.get_best_round()
samples = top_round
samples = self.data_utils.get_best_round()
logger.info(f"choose {samples['round']}")
@ -88,12 +83,16 @@ class PromptOptimizer:
best_answer = self.data_utils.list_to_markdown(samples["answers"])
optimize_prompt = PROMPT_OPTIMIZE_PROMPT.format(
prompt=samples["prompt"], answers=best_answer,
prompt=samples["prompt"],
answers=best_answer,
requirements=requirements,
golden_answers=golden_answer,
count=count)
count=count,
)
response = await self.llm.responser(type="optimize", messages=[{"role": "user", "content": optimize_prompt}])
response = await self.llm.responser(
request_type=RequestType.OPTIMIZE, messages=[{"role": "user", "content": optimize_prompt}]
)
modification = extract_content(response, "modification")
@ -110,8 +109,9 @@ class PromptOptimizer:
new_samples = await self.evaluation_utils.execute_prompt(self, directory, data)
success, answers = await self.evaluation_utils.evaluate_prompt(self, samples, new_samples, path=prompt_path,
data=data, initial=False)
success, answers = await self.evaluation_utils.evaluate_prompt(
self, samples, new_samples, path=prompt_path, data=data, initial=False
)
self.prompt_utils.write_answers(directory, answers=answers)
@ -122,7 +122,6 @@ class PromptOptimizer:
return prompt
async def _test_prompt(self):
load.set_file_name(self.template)
prompt_path = f"{self.root_path}/prompts"

View file

@ -1,11 +1,11 @@
import datetime
import json
import os
import random
from typing import Union, List, Dict
import pandas as pd
from metagpt.logs import logger
from typing import Dict, List, Union
import pandas as pd
from metagpt.logs import logger
class DataUtils:
@ -14,7 +14,7 @@ class DataUtils:
self.top_scores = []
def load_results(self, path: str) -> list:
result_path = os.path.join(path, "results.json")
result_path = self.get_results_file_path(path)
if os.path.exists(result_path):
with open(result_path, "r") as json_file:
try:
@ -24,7 +24,6 @@ class DataUtils:
return []
def get_best_round(self):
self._load_scores()
for entry in self.top_scores:
@ -44,11 +43,6 @@ class DataUtils:
with open(json_file_path, "w") as json_file:
json.dump(data, json_file, default=str, indent=4)
def save_cost(self, directory: str, data: Union[List, Dict]):
json_file = os.path.join(directory, 'cost.json')
with open(json_file, "w", encoding="utf-8") as file:
json.dump(data, file, default=str, indent=4)
def _load_scores(self):
rounds_dir = os.path.join(self.root_path, "prompts")
result_file = os.path.join(rounds_dir, "results.json")
@ -65,12 +59,14 @@ class DataUtils:
df = pd.DataFrame(data)
for index, row in df.iterrows():
self.top_scores.append({
"round": row["round"],
"succeed": row["succeed"],
"prompt": row["prompt"],
"answers": row['answers']
})
self.top_scores.append(
{
"round": row["round"],
"succeed": row["succeed"],
"prompt": row["prompt"],
"answers": row["answers"],
}
)
self.top_scores.sort(key=lambda x: x["round"], reverse=True)

View file

@ -1,22 +1,24 @@
import tiktoken
from metagpt.ext.spo.components.evaluator import QuickEvaluate, QuickExecute
from metagpt.logs import logger
import tiktoken
EVALUATION_REPETITION = 4
def count_tokens(sample):
if sample is None:
if not sample:
return 0
else:
encoding = tiktoken.get_encoding("cl100k_base")
return len(encoding.encode(str(sample['answers'])))
return len(encoding.encode(str(sample["answers"])))
class EvaluationUtils:
def __init__(self, root_path: str):
self.root_path = root_path
async def execute_prompt(self, optimizer, prompt_path, initial=False):
optimizer.prompt = optimizer.prompt_utils.load_prompt(optimizer.round, prompt_path)
executor = QuickExecute(prompt=optimizer.prompt)
@ -29,7 +31,6 @@ class EvaluationUtils:
return new_data
async def evaluate_prompt(self, optimizer, samples, new_samples, path, data, initial=False):
evaluator = QuickEvaluate()
new_token = count_tokens(new_samples)
@ -47,8 +48,9 @@ class EvaluationUtils:
false_count = evaluation_results.count(False)
succeed = true_count > false_count
new_data = optimizer.data_utils.create_result_data(new_samples['round'], new_samples['answers'],
new_samples['prompt'], succeed, new_token)
new_data = optimizer.data_utils.create_result_data(
new_samples["round"], new_samples["answers"], new_samples["prompt"], succeed, new_token
)
data.append(new_data)
@ -56,6 +58,6 @@ class EvaluationUtils:
optimizer.data_utils.save_results(result_path, data)
answers = new_samples['answers']
answers = new_samples["answers"]
return succeed, answers

View file

@ -1,12 +1,20 @@
import asyncio
import re
from enum import Enum
from typing import Optional
from metagpt.configs.models_config import ModelsConfig
from metagpt.llm import LLM
import asyncio
class RequestType(Enum):
OPTIMIZE = "optimize"
EVALUATE = "evaluate"
EXECUTE = "execute"
class SPO_LLM:
_instance: Optional['SPO_LLM'] = None
_instance: Optional["SPO_LLM"] = None
def __init__(self, optimize_kwargs=None, evaluate_kwargs=None, execute_kwargs=None):
self.evaluate_llm = LLM(llm_config=self._load_llm_config(evaluate_kwargs))
@ -14,7 +22,7 @@ class SPO_LLM:
self.execute_llm = LLM(llm_config=self._load_llm_config(execute_kwargs))
def _load_llm_config(self, kwargs: dict):
model = kwargs.get('model')
model = kwargs.get("model")
if not model:
raise ValueError("'model' parameter is required")
@ -31,23 +39,24 @@ class SPO_LLM:
return config
except AttributeError as e:
except AttributeError:
raise ValueError(f"Model '{model}' not found in configuration")
except Exception as e:
raise ValueError(f"Error loading configuration for model '{model}': {str(e)}")
async def responser(self, type: str, messages):
if type == "optimize":
response = await self.optimize_llm.acompletion(messages)
elif type == "evaluate":
response = await self.evaluate_llm.acompletion(messages)
elif type == "execute":
response = await self.execute_llm.acompletion(messages)
else:
raise ValueError("Please set the correct name: optimize, evaluate or execute")
async def responser(self, request_type: RequestType, messages: list):
llm_mapping = {
RequestType.OPTIMIZE: self.optimize_llm,
RequestType.EVALUATE: self.evaluate_llm,
RequestType.EXECUTE: self.execute_llm,
}
rsp = response.choices[0].message.content
return rsp
llm = llm_mapping.get(request_type)
if not llm:
raise ValueError(f"Invalid request type. Valid types: {', '.join([t.value for t in RequestType])}")
response = await llm.acompletion(messages)
return response.choices[0].message.content
@classmethod
def initialize(cls, optimize_kwargs, evaluate_kwargs, execute_kwargs):
@ -61,8 +70,9 @@ class SPO_LLM:
raise RuntimeError("SPO_LLM not initialized. Call initialize() first.")
return cls._instance
def extract_content(xml_string, tag):
pattern = rf'<{tag}>(.*?)</{tag}>'
pattern = rf"<{tag}>(.*?)</{tag}>"
match = re.search(pattern, xml_string, re.DOTALL)
return match.group(1).strip() if match else None
@ -72,23 +82,20 @@ async def spo():
SPO_LLM.initialize(
optimize_kwargs={"model": "gpt-4o", "temperature": 0.7},
evaluate_kwargs={"model": "gpt-4o-mini", "temperature": 0.3},
execute_kwargs={"model": "gpt-4o-mini", "temperature": 0.3}
execute_kwargs={"model": "gpt-4o-mini", "temperature": 0.3},
)
llm = SPO_LLM.get_instance()
# test messages
hello_msg = [{"role": "user", "content": "hello"}]
response = await llm.responser(type='execute', messages=hello_msg)
response = await llm.responser(request_type=RequestType.EXECUTE, messages=hello_msg)
print(f"AI: {response}")
response = await llm.responser(type='optimize', messages=hello_msg)
response = await llm.responser(request_type=RequestType.OPTIMIZE, messages=hello_msg)
print(f"AI: {response}")
response = await llm.responser(type='evaluate', messages=hello_msg)
response = await llm.responser(request_type=RequestType.EVALUATE, messages=hello_msg)
print(f"AI: {response}")
if __name__ == "__main__":
asyncio.run(spo())

View file

@ -1,10 +1,12 @@
import yaml
import random
import os
import random
FILE_NAME = ''
import yaml
FILE_NAME = ""
SAMPLE_K = 3
def set_file_name(name):
global FILE_NAME
FILE_NAME = name
@ -12,13 +14,13 @@ def set_file_name(name):
def load_meta_data(k=SAMPLE_K):
# load yaml file
config_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'settings', FILE_NAME)
config_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "settings", FILE_NAME)
if not os.path.exists(config_path):
raise FileNotFoundError(f"Configuration file '{FILE_NAME}' not found in settings directory")
try:
with open(config_path, 'r', encoding='utf-8') as file:
with open(config_path, "r", encoding="utf-8") as file:
data = yaml.safe_load(file)
except yaml.YAMLError as e:
raise ValueError(f"Error parsing YAML file '{FILE_NAME}': {str(e)}")
@ -27,14 +29,14 @@ def load_meta_data(k=SAMPLE_K):
qa = []
for item in data['faq']:
question = item['question']
answer = item['answer']
qa.append({'question': question, 'answer': answer})
for item in data["faq"]:
question = item["question"]
answer = item["answer"]
qa.append({"question": question, "answer": answer})
prompt = data['prompt']
requirements = data['requirements']
count = data['count']
prompt = data["prompt"]
requirements = data["requirements"]
count = data["count"]
if isinstance(count, int):
count = f", within {count} words"
@ -44,4 +46,3 @@ def load_meta_data(k=SAMPLE_K):
random_qa = random.sample(qa, min(k, len(qa)))
return prompt, requirements, random_qa, count

View file

@ -1,4 +1,5 @@
import os
from metagpt.logs import logger
@ -15,14 +16,13 @@ class PromptUtils:
prompt_file_name = f"{prompts_path}/prompt.txt"
try:
with open(prompt_file_name, 'r', encoding='utf-8') as file:
with open(prompt_file_name, "r", encoding="utf-8") as file:
return file.read()
except FileNotFoundError as e:
logger.info(f"Error loading prompt for round {round_number}: {e}")
raise
def write_answers(self, directory: str, answers: dict, name: str = "answers.txt"):
with open(os.path.join(directory, name), "w", encoding="utf-8") as file:
for item in answers:
file.write(f"Question:\n{item['question']}\n")
@ -30,7 +30,5 @@ class PromptUtils:
file.write("\n")
def write_prompt(self, directory: str, prompt: str):
with open(os.path.join(directory, "prompt.txt"), "w", encoding="utf-8") as file:
file.write(prompt)