mirror of https://github.com/FoundationAgents/MetaGPT.git
synced 2026-04-25 00:36:55 +02:00
adapt SPO to MetaGPT
This commit is contained in:
parent da1e103372
commit a56b0e340a

8 changed files with 190 additions and 70 deletions

metagpt/ext/spo/optimize.py (new file, 22 lines)

@@ -0,0 +1,22 @@
+from metagpt.ext.spo.scripts.optimizer import Optimizer
+from metagpt.ext.spo.scripts.utils.llm_client import SPO_LLM
+
+
+if __name__ == "__main__":
+
+    SPO_LLM.initialize(
+        optimize_kwargs={"model": "claude-3-5-sonnet-20240620", "temperature": 0.7},
+        evaluate_kwargs={"model": "gpt-4o-mini", "temperature": 0.3},
+        execute_kwargs={"model": "gpt-4o-mini", "temperature": 0.3}
+    )
+
+    optimizer = Optimizer(
+        optimized_path="workspace",
+        initial_round=1,
+        max_rounds=10,
+        template="Poem.yaml",
+        name="Poem",
+        iteration=True,
+    )
+
+    optimizer.optimize()
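Note that the entry script calls SPO_LLM.initialize() before constructing Optimizer: as the new llm_client.py below shows, Optimizer.__init__ fetches the shared client through SPO_LLM.get_instance(), which raises a RuntimeError when nothing has been initialized. A minimal, runnable sketch of that initialize-before-use contract, using a stand-in class rather than the real SPO_LLM:

from typing import Optional


class SharedClient:
    _instance: Optional["SharedClient"] = None

    @classmethod
    def initialize(cls) -> None:
        cls._instance = cls()

    @classmethod
    def get_instance(cls) -> "SharedClient":
        if cls._instance is None:
            raise RuntimeError("not initialized. Call initialize() first.")
        return cls._instance


try:
    SharedClient.get_instance()  # consumers must not run before initialize()
except RuntimeError as err:
    print(err)

SharedClient.initialize()
assert SharedClient.get_instance() is SharedClient._instance  # shared instance now exists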

metagpt/ext/spo/scripts/evaluator.py
@@ -5,10 +5,10 @@
 import asyncio
 from typing import Dict, Literal, Tuple, List, Any
 
-from utils import load
-from utils.llm_client import responser, extract_content
-from prompt.evaluate_prompt import EVALUATE_PROMPT
+from metagpt.ext.spo.scripts.utils import load
+from metagpt.ext.spo.prompts.evaluate_prompt import EVALUATE_PROMPT
 import random
+from metagpt.ext.spo.scripts.utils.llm_client import SPO_LLM, extract_content
 
 
 class QuickExecute:
@@ -16,21 +16,20 @@ class QuickExecute:
     Complete the evaluation for different datasets here.
     """
 
-    def __init__(self, prompt: str, k: int = 3, model=None):
+    def __init__(self, prompt: str):
 
         self.prompt = prompt
-        self.k = k
-        self.model = model
+        self.llm = SPO_LLM.get_instance()
 
     async def prompt_execute(self) -> tuple[Any]:
-        _, _, qa, _ = load.load_meta_data(k=self.k)
+        _, _, qa, _ = load.load_meta_data()
         answers = []
 
         async def fetch_answer(q: str) -> Dict[str, Any]:
             messages = [{"role": "user", "content": f"{self.prompt}\n\n{q}"}]
             try:
-                answer = await responser(messages, model=self.model['name'], temperature=self.model['temperature'])
-                return {'question': q, 'answer': answer.content}
+                answer = await self.llm.responser(role="execute", messages=messages)
+                return {'question': q, 'answer': answer}
             except Exception as e:
                 return {'question': q, 'answer': str(e)}
 
@@ -45,11 +44,11 @@ class QuickEvaluate:
     Complete the evaluation for different datasets here.
     """
 
-    def __init__(self, k: int = 3):
-        self.k = k
+    def __init__(self):
+        self.llm = SPO_LLM.get_instance()
 
-    async def prompt_evaluate(self, sample: list, new_sample: list, model: dict) -> bool:
-        _, requirement, qa, _ = load.load_meta_data(k=self.k)
+    async def prompt_evaluate(self, sample: list, new_sample: list) -> bool:
+        _, requirement, qa, _ = load.load_meta_data()
 
         if random.random() < 0.5:
             sample, new_sample = new_sample, sample
@@ -64,8 +63,8 @@ class QuickEvaluate:
                                                answers=str(qa))}]
 
         try:
-            response = await responser(messages, model=model['name'], temperature=model['temperature'])
-            choose = extract_content(response.content, 'choose')
+            response = await self.llm.responser(role="evaluate", messages=messages)
+            choose = extract_content(response, 'choose')
 
             if is_swapped:
                 return choose == "A"
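The random swap at the top of prompt_evaluate, combined with the is_swapped check here, counters LLM position bias: the judge sees the two samples in a random order, and the "A"/"B" verdict is mapped back afterwards. A condensed, runnable restatement of that logic, with a stand-in judge function:

import random


def debiased_compare(sample, new_sample, judge) -> bool:
    """Return True when new_sample wins, regardless of presentation order."""
    is_swapped = random.random() < 0.5
    if is_swapped:
        sample, new_sample = new_sample, sample
    choose = judge(sample, new_sample)  # judge answers "A" (first slot) or "B" (second slot)
    # After a swap, new_sample occupies slot A; otherwise it occupies slot B.
    return choose == "A" if is_swapped else choose == "B"


# A judge that always prefers the second slot no longer favors either sample:
wins = sum(debiased_compare("old", "new", lambda a, b: "B") for _ in range(1000))
print(wins)  # roughly 500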
@@ -78,7 +77,7 @@ class QuickEvaluate:
 
 
 if __name__ == "__main__":
-    execute = QuickExecute(prompt="Answer the Question,{question}", k=3)
+    execute = QuickExecute(prompt="Answer the Question,{question}")
 
     # Run the async method with asyncio.run
     answers = asyncio.run(execute.prompt_evaluate())

metagpt/ext/spo/scripts/optimizer.py
@@ -5,14 +5,14 @@
 
 import asyncio
 import time
-from optimizer_utils.data_utils import DataUtils
-from optimizer_utils.evaluation_utils import EvaluationUtils
-from optimizer_utils.prompt_utils import PromptUtils
-from prompt.optimize_prompt import PROMPT_OPTIMIZE_PROMPT
-from utils import load
-from utils.logs import logger
-from utils.llm_client import responser, extract_content
-from utils.token_manager import get_token_tracker
+from metagpt.ext.spo.scripts.utils.data_utils import DataUtils
+from metagpt.ext.spo.scripts.utils.evaluation_utils import EvaluationUtils
+from metagpt.ext.spo.scripts.utils.prompt_utils import PromptUtils
+from metagpt.ext.spo.prompts.optimize_prompt import PROMPT_OPTIMIZE_PROMPT
+from metagpt.ext.spo.scripts.utils import load
+from metagpt.logs import logger
+from metagpt.ext.spo.scripts.utils.llm_client import extract_content, SPO_LLM
 
 
 class Optimizer:
@@ -21,11 +21,8 @@ class Optimizer:
         optimized_path: str = None,
         initial_round: int = 1,
         max_rounds: int = 10,
-        name: str = "test",
-        template: str = "meta.yaml",
-        execute_model=None,
-        optimize_model=None,
-        evaluate_model=None,
+        name: str = "",
+        template: str = "",
         iteration: bool = True,
     ) -> None:
 
@@ -34,16 +31,13 @@ class Optimizer:
         self.top_scores = []
         self.round = initial_round
         self.max_rounds = max_rounds
-        self.execute_model = execute_model
-        self.optimize_model = optimize_model
-        self.evaluate_model = evaluate_model
         self.iteration = iteration
         self.template = template
 
         self.prompt_utils = PromptUtils(self.root_path)
         self.data_utils = DataUtils(self.root_path)
         self.evaluation_utils = EvaluationUtils(self.root_path)
-        self.token_tracker = get_token_tracker()
+        self.llm = SPO_LLM.get_instance()
 
     def optimize(self):
         if self.iteration is True:
@@ -55,8 +49,6 @@ class Optimizer:
                 self.round += 1
                 logger.info(f"Score for round {self.round}: {score}")
 
-                time.sleep(5)
-
         else:
             loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)
@@ -77,14 +69,12 @@ class Optimizer:
             prompt, _, _, _ = load.load_meta_data()
             self.prompt = prompt
             self.prompt_utils.write_prompt(directory, prompt=self.prompt)
-            new_sample = await self.evaluation_utils.execute_prompt(self, directory, data, model=self.execute_model,
-                                                                    initial=True)
-            _, answers = await self.evaluation_utils.evaluate_prompt(self, None, new_sample, model=self.evaluate_model,
-                                                                     path=prompt_path, data=data, initial=True)
+            new_sample = await self.evaluation_utils.execute_prompt(self, directory, initial=True)
+            _, answers = await self.evaluation_utils.evaluate_prompt(self, None, new_sample, path=prompt_path, data=data, initial=True)
             self.prompt_utils.write_answers(directory, answers=answers)
 
 
-        _, requirements, qa, count = load.load_meta_data(3)
+        _, requirements, qa, count = load.load_meta_data()
 
         directory = self.prompt_utils.create_round_directory(prompt_path, self.round + 1)
 
@@ -105,11 +95,10 @@ class Optimizer:
             golden_answers=golden_answer,
             count=count)
 
-        response = await responser(messages=[{"role": "user", "content": optimize_prompt}],
-                                   model=self.optimize_model['name'], temperature=self.optimize_model['temperature'])
+        response = await self.llm.responser(role="optimize", messages=[{"role": "user", "content": optimize_prompt}])
 
-        modification = extract_content(response.content, "modification")
-        prompt = extract_content(response.content, "prompt")
+        modification = extract_content(response, "modification")
+        prompt = extract_content(response, "prompt")
         if prompt:
             self.prompt = prompt
         else:
@@ -119,11 +108,10 @@ class Optimizer:
 
         self.prompt_utils.write_prompt(directory, prompt=self.prompt)
 
-        new_sample = await self.evaluation_utils.execute_prompt(self, directory, data, model=self.execute_model,
-                                                                initial=False)
+        new_sample = await self.evaluation_utils.execute_prompt(self, directory, data)
 
         success, answers = await self.evaluation_utils.evaluate_prompt(self, sample, new_sample,
-                                                                       model=self.evaluate_model, path=prompt_path,
+                                                                       path=prompt_path,
                                                                        data=data, initial=False)
 
         self.prompt_utils.write_answers(directory, answers=answers)
@@ -133,11 +121,6 @@ class Optimizer:
 
         logger.info(f"now is {self.round + 1}")
 
-        self.token_tracker.print_usage_report()
-        usage = self.token_tracker.get_total_usage()
-
-        self.data_utils.save_cost(directory, usage)
-
         return prompt
 
     async def _test_prompt(self):
@@ -150,8 +133,7 @@ class Optimizer:
         directory = self.prompt_utils.create_round_directory(prompt_path, self.round)
         # Load prompt using prompt_utils
 
-        new_sample = await self.evaluation_utils.execute_prompt(self, directory, data, model=self.execute_model,
-                                                                initial=False, k=100)
+        new_sample = await self.evaluation_utils.execute_prompt(self, directory, data)
         self.prompt_utils.write_answers(directory, answers=new_sample["answers"], name="test_answers.txt")
 
         logger.info(new_sample)

metagpt/ext/spo/scripts/utils/data_utils.py
@@ -1,9 +1,13 @@
 import datetime
 import json
 import os
+import random
+from typing import Union, List, Dict
 
 import pandas as pd
+import yaml
 
+FILE_NAME = ''
+SAMPLE_K = 3
+
 
 class DataUtils:
@@ -23,7 +27,7 @@ class DataUtils:
 
     def get_best_round(self):
 
-        top_rounds = self._load_scores()
+        self._load_scores()
 
         for entry in self.top_scores:
             if entry["succeed"]:
@@ -66,6 +70,39 @@ class DataUtils:
 
         return self.top_scores
 
+    def set_file_name(name):
+        global FILE_NAME
+        FILE_NAME = name
+
+    def load_meta_data(k=SAMPLE_K):
+
+        # Read the YAML file
+        config_path = os.path.join(os.path.dirname(__file__), '../settings', FILE_NAME)
+        with open(config_path, 'r', encoding='utf-8') as file:
+            data = yaml.safe_load(file)
+
+        qa = []
+
+        # Extract the questions and answers
+        for item in data['faq']:
+            question = item['question']
+            answer = item['answer']
+            qa.append({'question': question, 'answer': answer})
+
+        prompt = data['prompt']
+        requirements = data['requirements']
+        count = data['count']
+
+        if isinstance(count, int):
+            count = f", within {count} words"
+        else:
+            count = ""
+
+        # Randomly select three QA pairs
+        random_qa = random.sample(qa, min(k, len(qa)))  # make sure we don't exceed the list length
+
+        return prompt, requirements, random_qa, count
+
     def list_to_markdown(self, questions_list):
         """
         Convert a list of question-answer dictionaries to a formatted Markdown string.
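The load_meta_data body added above implies the schema of a settings template: top-level prompt, requirements, and count fields plus a faq list of question/answer pairs. A sketch of a matching file, parsed the same way the method parses it (the field values here are invented for illustration):

import yaml

settings_yaml = """
prompt: Write a short poem on the given topic.
requirements: The poem should rhyme and evoke one clear image.
count: 50
faq:
  - question: Write a poem about the sea.
    answer: The tide rolls in, the tide rolls out...
  - question: Write a poem about autumn.
    answer: Red leaves drift down on cooling air...
"""

data = yaml.safe_load(settings_yaml)
qa = [{'question': item['question'], 'answer': item['answer']} for item in data['faq']]
count = f", within {data['count']} words" if isinstance(data['count'], int) else ""
print(data['prompt'], data['requirements'], count, qa)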

metagpt/ext/spo/scripts/utils/evaluation_utils.py
@@ -1,7 +1,7 @@
 import asyncio
 
-from script.evaluator import QuickEvaluate, QuickExecute
-from utils.logs import logger
+from metagpt.ext.spo.scripts.evaluator import QuickEvaluate, QuickExecute
+from metagpt.logs import logger
 import tiktoken
 
 
@@ -16,10 +16,10 @@ class EvaluationUtils:
     def __init__(self, root_path: str):
         self.root_path = root_path
 
-    async def execute_prompt(self, optimizer, prompt_path, data, model, initial=False, k=3):
+    async def execute_prompt(self, optimizer, prompt_path, initial=False):
 
         optimizer.prompt = optimizer.prompt_utils.load_prompt(optimizer.round, prompt_path)
-        evaluator = QuickExecute(prompt=optimizer.prompt, k=k, model=model)
+        evaluator = QuickExecute(prompt=optimizer.prompt)
 
         answers = await evaluator.prompt_execute()
 
@@ -29,10 +29,9 @@ class EvaluationUtils:
 
         return new_data
 
-    async def evaluate_prompt(self, optimizer, sample, new_sample, path, data, model, initial=False):
+    async def evaluate_prompt(self, optimizer, sample, new_sample, path, data, initial=False):
 
-        evaluator = QuickEvaluate(k=3)
-        original_token = count_tokens(sample)
+        evaluator = QuickEvaluate()
         new_token = count_tokens(new_sample)
 
         if initial is True:
@@ -40,7 +39,7 @@ class EvaluationUtils:
         else:
             evaluation_results = []
             for _ in range(4):
-                result = await evaluator.prompt_evaluate(sample=sample, new_sample=new_sample, model=model)
+                result = await evaluator.prompt_evaluate(sample=sample, new_sample=new_sample)
                 evaluation_results.append(result)
 
             logger.info(evaluation_results)
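The loop collects four independent verdicts from prompt_evaluate; how they are combined lies outside this hunk. One plausible reading is a simple majority vote over the booleans (an assumption, not something this commit shows):

from typing import List


def majority_vote(evaluation_results: List[bool]) -> bool:
    # Accept the new prompt only if it wins more than half of the comparisons (assumed rule).
    return sum(evaluation_results) > len(evaluation_results) / 2


print(majority_vote([True, True, False, True]))    # True
print(majority_vote([True, False, False, False]))  # False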

metagpt/ext/spo/scripts/utils/llm_client.py (new file, 81 lines)

@@ -0,0 +1,81 @@
+import re
+from typing import Optional
+from metagpt.configs.models_config import ModelsConfig
+from metagpt.llm import LLM
+import asyncio
+
+
+class SPO_LLM:
+    _instance: Optional['SPO_LLM'] = None
+
+    def __init__(self, optimize_kwargs=None, evaluate_kwargs=None, execute_kwargs=None):
+        self.evaluate_llm = LLM(llm_config=self._load_llm_config(evaluate_kwargs))
+        self.optimize_llm = LLM(llm_config=self._load_llm_config(optimize_kwargs))
+        self.execute_llm = LLM(llm_config=self._load_llm_config(execute_kwargs))
+
+    def _load_llm_config(self, kwargs: dict):
+        model = kwargs.get('model')
+        config = ModelsConfig.default().get(model).model_copy()
+
+        for key, value in kwargs.items():
+            if hasattr(config, key):
+                setattr(config, key, value)
+
+        return config
+
+    async def responser(self, role: str, messages):
+        if role == "optimize":
+            response = await self.optimize_llm.acompletion(messages)
+        elif role == "evaluate":
+            response = await self.evaluate_llm.acompletion(messages)
+        elif role == "execute":
+            response = await self.execute_llm.acompletion(messages)
+        else:
+            raise ValueError("Please set the correct name: optimize, evaluate or execute")
+
+        rsp = response.choices[0].message.content
+        return rsp
+
+    @classmethod
+    def initialize(cls, optimize_kwargs, evaluate_kwargs, execute_kwargs):
+        """Initialize the global instance"""
+        cls._instance = cls(optimize_kwargs, evaluate_kwargs, execute_kwargs)
+
+    @classmethod
+    def get_instance(cls):
+        """Get the global instance"""
+        if cls._instance is None:
+            raise RuntimeError("SPO_LLM not initialized. Call initialize() first.")
+        return cls._instance
+
+
+def extract_content(xml_string, tag):
+    pattern = rf'<{tag}>(.*?)</{tag}>'
+    match = re.search(pattern, xml_string, re.DOTALL)
+    return match.group(1).strip() if match else None
+
+
+async def spo():
+    # Initialize the configuration at the entry point
+    SPO_LLM.initialize(
+        optimize_kwargs={"model": "gpt-4o-mini", "temperature": 0.7},
+        evaluate_kwargs={"model": "gpt-4o-mini", "temperature": 0.3},
+        execute_kwargs={"model": "gpt-4o-mini", "temperature": 0.3}
+    )
+
+    llm = SPO_LLM.get_instance()
+
+    # Test message
+    hello_msg = [{"role": "user", "content": "What model are you?"}]
+    response = await llm.responser(role='execute', messages=hello_msg)
+    print(f"AI reply: {response}")
+    response = await llm.responser(role='optimize', messages=hello_msg)
+    print(f"AI reply: {response}")
+    response = await llm.responser(role='evaluate', messages=hello_msg)
+    print(f"AI reply: {response}")
+
+
+if __name__ == "__main__":
+    asyncio.run(spo())
+
+
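extract_content above returns the first <tag>…</tag> span, stripped, or None when the tag is missing; re.DOTALL lets the span cross newlines. A few self-contained checks of that behavior (the function body is copied from the file so the snippet runs standalone):

import re


def extract_content(xml_string, tag):
    pattern = rf'<{tag}>(.*?)</{tag}>'
    match = re.search(pattern, xml_string, re.DOTALL)
    return match.group(1).strip() if match else None


reply = "<modification>Tighten the rubric.</modification>\n<prompt>Write a\nhaiku.</prompt>"
assert extract_content(reply, "modification") == "Tighten the rubric."
assert extract_content(reply, "prompt") == "Write a\nhaiku."  # spans newlines via re.DOTALL
assert extract_content(reply, "missing") is None              # absent tag -> None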

metagpt/ext/spo/scripts/utils/load.py
@@ -2,7 +2,8 @@ import yaml
 import random
 import os
 
-FILE_NAME = 'meta.yaml'  # default value
+FILE_NAME = 'meta.yaml'
+SAMPLE_K = 3
 
 
 def load_llm():
@@ -19,11 +20,10 @@ def set_file_name(name):
     FILE_NAME = name
 
 
-def load_meta_data(k=5):
+def load_meta_data(k=SAMPLE_K):
 
-    k = 5
     # Read the YAML file
-    config_path = os.path.join(os.path.dirname(__file__), '../settings', FILE_NAME)
+    config_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), 'settings', FILE_NAME)
     with open(config_path, 'r', encoding='utf-8') as file:
         data = yaml.safe_load(file)
 
@@ -44,7 +44,7 @@ def load_meta_data(k=5):
     else:
         count = ""
 
-    # Randomly select three QA pairs
+    # Randomly select k QA pairs
     random_qa = random.sample(qa, min(k, len(qa)))  # make sure we don't exceed the list length
 
     return prompt, requirements, random_qa, count
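The config_path change in the previous hunk relocates the settings lookup: the old expression resolved ../settings relative to utils/, while the new one climbs three directories to the spo package root. With a hypothetical file location, the difference is:

import os

# Hypothetical location of load.py, for illustration only:
file = "/repo/metagpt/ext/spo/scripts/utils/load.py"

old_path = os.path.join(os.path.dirname(file), "../settings", "meta.yaml")
new_path = os.path.join(
    os.path.dirname(os.path.dirname(os.path.dirname(file))), "settings", "meta.yaml"
)

print(os.path.normpath(old_path))  # /repo/metagpt/ext/spo/scripts/settings/meta.yaml
print(os.path.normpath(new_path))  # /repo/metagpt/ext/spo/settings/meta.yaml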

metagpt/ext/spo/scripts/utils/prompt_utils.py
@@ -4,7 +4,7 @@ import re
 import time
 import traceback
 from typing import List
-from utils.logs import logger
+from metagpt.logs import logger
 
 
 class PromptUtils: