mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
Update
This commit is contained in:
parent
7c2501e08b
commit
d97f90f9c7
3 changed files with 52 additions and 10 deletions
|
|
@ -11,7 +11,6 @@ import pandas as pd
|
|||
from deepeval.benchmarks import GSM8K
|
||||
|
||||
from examples.ags.benchmark.gsm8k import GraphModel
|
||||
from examples.ags.w_action_node.graph import SolveGraph
|
||||
|
||||
# TODO 完成实验数据集的手动划分
|
||||
|
||||
|
|
@ -126,8 +125,8 @@ class Evaluator:
|
|||
dataset = params["dataset"]
|
||||
llm_config = params["llm_config"]
|
||||
|
||||
# TODO 给到的是load出来的Graph,怎么让他做实例化?
|
||||
graph = SolveGraph(name="Gsm8K", llm_config=llm_config, dataset=dataset)
|
||||
# TODO 给到的是load出来的Graph,怎么让他做实例化?graph_class 可以跟我这样用吗?
|
||||
graph = graph_class(name="Gsm8K", llm_config=llm_config, dataset=dataset)
|
||||
model = GraphModel(graph)
|
||||
benchmark = GSM8K(n_problems=samples, n_shots=0, enable_cot=False)
|
||||
|
||||
|
|
|
|||
|
|
@ -24,8 +24,8 @@ from examples.ags.w_action_node.prompts.optimize_prompt import (
|
|||
OPERATOR_TEMPLATE,
|
||||
)
|
||||
from metagpt.actions.action_node import ActionNode
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.llm_provider_registry import create_llm_instance
|
||||
|
||||
config_iterate_path = "iterate"
|
||||
|
||||
|
|
@ -49,16 +49,17 @@ class Optimizer:
|
|||
def __init__(
|
||||
self,
|
||||
dataset: DatasetType,
|
||||
opt_llm: LLM,
|
||||
exec_llm: LLM,
|
||||
opt_llm_config,
|
||||
exec_llm_config,
|
||||
operators: List,
|
||||
optimized_path: str = None,
|
||||
sample: int = 6,
|
||||
q_type: str = "math", # math,code,quiz
|
||||
op: str = "Generator", # 需要优化的Operator
|
||||
) -> None:
|
||||
self.optimize_llm = opt_llm
|
||||
self.execute_llm = exec_llm
|
||||
self.optimize_llm_config = opt_llm_config
|
||||
self.execute_llm_config = exec_llm_config
|
||||
self.optimize_llm = create_llm_instance(self.optimize_llm_config)
|
||||
self.dataset = dataset
|
||||
self.graph = None # 初始化为 None,稍后加载
|
||||
self.operators = operators
|
||||
|
|
@ -140,6 +141,7 @@ class Optimizer:
|
|||
graph_module_name = f"{graphs_path}.round_{round_number}.graph"
|
||||
try:
|
||||
graph_module = __import__(graph_module_name, fromlist=[""])
|
||||
# TODO 这里似乎有BUG
|
||||
graph_class = getattr(graph_module, f"{self.dataset}Graph")
|
||||
self.graph = graph_class
|
||||
except ImportError as e:
|
||||
|
|
@ -407,7 +409,9 @@ class Optimizer:
|
|||
with open(os.path.join(directory, "experience.json"), "w", encoding="utf-8") as file:
|
||||
json.dump(experience, file, ensure_ascii=False, indent=4)
|
||||
|
||||
score = evaluator.validation_evaluate(self.dataset, self.graph)
|
||||
score = evaluator.validation_evaluate(
|
||||
self.dataset, self.graph, {"dataset": self.dataset, "llm_config": self.execute_llm_config}
|
||||
)
|
||||
experience["after"] = score
|
||||
experience["succeed"] = bool(score > experience["before"])
|
||||
return score
|
||||
|
|
|
|||
41
optimize.py
41
optimize.py
|
|
@ -3,4 +3,43 @@
|
|||
# @Author : didi
|
||||
# @Desc : Experiment of graph optimization
|
||||
|
||||
# TODO 实现两个LLM Config
|
||||
|
||||
from examples.ags.w_action_node.optimizer import Optimizer
|
||||
from metagpt.configs.models_config import ModelsConfig
|
||||
|
||||
# 配置实验参数
|
||||
dataset = "Gsm8K" # 数据集选择为GSM8K
|
||||
sample = 6 # 采样数量
|
||||
q_type = "math" # 问题类型为数学
|
||||
optimized_path = "examples/ags/w_action_node/optimized" # 优化结果保存路径
|
||||
|
||||
# 初始化LLM模型
|
||||
deepseek_llm_config = ModelsConfig.default().get("deepseek-coder")
|
||||
claude_llm_config = ModelsConfig.default().get("claude-3.5-sonnet")
|
||||
|
||||
# 初始化操作符列表
|
||||
gsm8k_operators = [
|
||||
"Generate",
|
||||
"ContextualGenerate",
|
||||
"Format",
|
||||
"Review",
|
||||
"Revise",
|
||||
"FuEnsemble",
|
||||
"MdEnsemble",
|
||||
"ScEnsemble",
|
||||
"Rephrase",
|
||||
]
|
||||
|
||||
# 创建优化器实例
|
||||
optimizer = Optimizer(
|
||||
dataset=dataset,
|
||||
opt_llm_config=claude_llm_config,
|
||||
exec_llm_config=deepseek_llm_config,
|
||||
operators=gsm8k_operators,
|
||||
optimized_path=optimized_path,
|
||||
sample=sample,
|
||||
q_type=q_type,
|
||||
)
|
||||
|
||||
# 运行优化器
|
||||
optimizer.optimize("Graph")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue