This commit is contained in:
didi 2024-08-26 08:40:10 +08:00
parent 7c2501e08b
commit d97f90f9c7
3 changed files with 52 additions and 10 deletions

View file

@@ -11,7 +11,6 @@ import pandas as pd
from deepeval.benchmarks import GSM8K
from examples.ags.benchmark.gsm8k import GraphModel
from examples.ags.w_action_node.graph import SolveGraph
# TODO 完成实验数据集的手动划分
@@ -126,8 +125,8 @@ class Evaluator:
dataset = params["dataset"]
llm_config = params["llm_config"]
# TODO 给到的是load出来的Graph怎么让他做实例化
graph = SolveGraph(name="Gsm8K", llm_config=llm_config, dataset=dataset)
# TODO 给到的是load出来的Graph怎么让他做实例化graph_class 可以跟我这样用吗?
graph = graph_class(name="Gsm8K", llm_config=llm_config, dataset=dataset)
model = GraphModel(graph)
benchmark = GSM8K(n_problems=samples, n_shots=0, enable_cot=False)

View file

@@ -24,8 +24,8 @@ from examples.ags.w_action_node.prompts.optimize_prompt import (
OPERATOR_TEMPLATE,
)
from metagpt.actions.action_node import ActionNode
from metagpt.llm import LLM
from metagpt.logs import logger
from metagpt.provider.llm_provider_registry import create_llm_instance
config_iterate_path = "iterate"
@@ -49,16 +49,17 @@ class Optimizer:
def __init__(
self,
dataset: DatasetType,
opt_llm: LLM,
exec_llm: LLM,
opt_llm_config,
exec_llm_config,
operators: List,
optimized_path: str = None,
sample: int = 6,
q_type: str = "math", # math,code,quiz
op: str = "Generator", # 需要优化的Operator
) -> None:
self.optimize_llm = opt_llm
self.execute_llm = exec_llm
self.optimize_llm_config = opt_llm_config
self.execute_llm_config = exec_llm_config
self.optimize_llm = create_llm_instance(self.optimize_llm_config)
self.dataset = dataset
self.graph = None # 初始化为 None稍后加载
self.operators = operators
@@ -140,6 +141,7 @@ class Optimizer:
graph_module_name = f"{graphs_path}.round_{round_number}.graph"
try:
graph_module = __import__(graph_module_name, fromlist=[""])
# TODO 这里似乎有BUG
graph_class = getattr(graph_module, f"{self.dataset}Graph")
self.graph = graph_class
except ImportError as e:
@@ -407,7 +409,9 @@ class Optimizer:
with open(os.path.join(directory, "experience.json"), "w", encoding="utf-8") as file:
json.dump(experience, file, ensure_ascii=False, indent=4)
score = evaluator.validation_evaluate(self.dataset, self.graph)
score = evaluator.validation_evaluate(
self.dataset, self.graph, {"dataset": self.dataset, "llm_config": self.execute_llm_config}
)
experience["after"] = score
experience["succeed"] = bool(score > experience["before"])
return score