diff --git a/examples/ags/benchmark/hotpotqa.py b/examples/ags/benchmark/hotpotqa.py
index 2e9be830d..375882511 100644
--- a/examples/ags/benchmark/hotpotqa.py
+++ b/examples/ags/benchmark/hotpotqa.py
@@ -104,11 +104,12 @@ def f1_score(predicted_bag: Set[str], gold_bag: Set[str]) -> float:
     f1 = (2 * precision * recall) / (precision + recall) if not (precision == 0.0 and recall == 0.0) else 0.0
     return f1
 
-async def load_data(file_path: str, samples=20) -> List[dict]:
+async def load_data(file_path: str, samples=20, total_length=1000) -> List[dict]:
     data = []
     async with aiofiles.open(file_path, mode="r") as file:
         async for line in file:
             data.append(json.loads(line))
+    data = data[:total_length]
     random_indices = generate_random_indices(len(data), samples)
     data = [data[i] for i in random_indices]
     return data
diff --git a/examples/ags/experiments/baselines/cot_hotpotqa.py b/examples/ags/experiments/baselines/cot_hotpotqa.py
index e69de29bb..22bd69438 100644
--- a/examples/ags/experiments/baselines/cot_hotpotqa.py
+++ b/examples/ags/experiments/baselines/cot_hotpotqa.py
@@ -0,0 +1,58 @@
+import asyncio
+
+from examples.ags.scripts.operator import Operator
+from examples.ags.scripts.graph import SolveGraph
+from examples.ags.benchmark.hotpotqa import hotpotqa_evaluation
+from metagpt.actions.action_node import ActionNode
+from metagpt.configs.models_config import ModelsConfig
+from metagpt.llm import LLM
+from pydantic import BaseModel, Field
+from typing import Optional, Tuple
+
+HOTPOTQA_PROMPT = """
+问题: {question}
+
+上下文:
+{context}
+
+请一步步思考,并在最后给出你的答案和支持性句子。使用XML标签包裹内容。
+"""
+
+class GenerateOp(BaseModel):
+    answer: str = Field(default="", description="问题的答案")
+    supporting_sentences: str = Field(default="", description="支持性句子")
+
+class CoTGenerate(Operator):
+    def __init__(self, llm: LLM, name: str = "Generate"):
+        super().__init__(name, llm)
+
+    async def __call__(self, question: str, context: str, mode: Optional[str] = None) -> Tuple[str, str]:
+        prompt = HOTPOTQA_PROMPT.format(question=question, context=context)
+        fill_kwargs = {"context": prompt, "llm": self.llm}
+        if mode:
+            fill_kwargs["mode"] = mode
+        node = await ActionNode.from_pydantic(GenerateOp).fill(**fill_kwargs)
+        response = node.instruct_content.model_dump()
+        return response["answer"], response["supporting_sentences"]
+
+class CoTSolveGraph(SolveGraph):
+    def __init__(self, name: str, llm_config, dataset: str):
+        super().__init__(name, llm_config, dataset)
+        self.cot_generate = CoTGenerate(self.llm)
+
+    async def __call__(self, question: str, context: str) -> Tuple[str, str]:
+        answer, supporting_sentences = await self.cot_generate(question, context, mode="context_fill")
+        return answer, supporting_sentences
+
+if __name__ == "__main__":
+    async def main():
+        llm_config = ModelsConfig.default().get("gpt-4o-mini")
+        # llm_config = ModelsConfig.default().get("gpt-35-turbo-1106")
+        graph = CoTSolveGraph(name="CoT", llm_config=llm_config, dataset="HotpotQA")
+        file_path = "examples/ags/data/hotpotqa.jsonl"
+        samples = 50  # TODO: run the experiment on the first 1000 samples
+        path = "examples/ags/data/baselines/general/hotpotqa"
+        score = await hotpotqa_evaluation(graph, file_path, samples, path)
+        return score
+
+    asyncio.run(main())