diff --git a/metagpt/ext/spo/app.py b/metagpt/ext/spo/app.py index 7a93cddf2..8a102ad85 100644 --- a/metagpt/ext/spo/app.py +++ b/metagpt/ext/spo/app.py @@ -217,7 +217,6 @@ def main(): max_rounds=max_rounds, template=f"{template_name}.yaml", name=template_name, - iteration=True, ) # Run optimization with progress bar @@ -228,7 +227,7 @@ def main(): st.header("Optimization Results") - prompt_path = f"{optimizer.root_path}/prompts" + prompt_path = optimizer.root_path / "prompts" result_data = optimizer.data_utils.load_results(prompt_path) st.session_state.optimization_results = result_data diff --git a/metagpt/ext/spo/components/optimizer.py b/metagpt/ext/spo/components/optimizer.py index 6b5a0824f..0ce588f44 100644 --- a/metagpt/ext/spo/components/optimizer.py +++ b/metagpt/ext/spo/components/optimizer.py @@ -4,6 +4,7 @@ # @Desc : optimizer for prompt import asyncio +from pathlib import Path from typing import List from metagpt.ext.spo.prompts.optimize_prompt import PROMPT_OPTIMIZE_PROMPT @@ -24,8 +25,8 @@ class PromptOptimizer: name: str = "", template: str = "", ) -> None: - self.dataset = name - self.root_path = f"{optimized_path}/{self.dataset}" + self.name = name + self.root_path = Path(optimized_path) / self.name self.top_scores = [] self.round = initial_round self.max_rounds = max_rounds @@ -55,7 +56,7 @@ class PromptOptimizer: logger.info("\n" + "=" * 50 + "\n") async def _optimize_prompt(self): - prompt_path = f"{self.root_path}/prompts" + prompt_path = self.root_path / "prompts" load.set_file_name(self.template) data = self.data_utils.load_results(prompt_path) @@ -75,7 +76,7 @@ class PromptOptimizer: return self.prompt - async def _handle_first_round(self, prompt_path: str, data: List[dict]) -> None: + async def _handle_first_round(self, prompt_path: Path, data: List[dict]) -> None: logger.info("\n⚡ RUNNING Round 1 PROMPT ⚡\n") directory = self.prompt_utils.create_round_directory(prompt_path, self.round) diff --git a/metagpt/ext/spo/utils/data_utils.py b/metagpt/ext/spo/utils/data_utils.py index 0d0d99def..17771c021 100644 --- a/metagpt/ext/spo/utils/data_utils.py +++ b/metagpt/ext/spo/utils/data_utils.py @@ -1,6 +1,6 @@ import datetime import json -import os +from pathlib import Path from typing import Dict, List, Union import pandas as pd @@ -9,18 +9,17 @@ from metagpt.logs import logger class DataUtils: - def __init__(self, root_path: str): + def __init__(self, root_path: Path): self.root_path = root_path self.top_scores = [] - def load_results(self, path: str) -> list: + def load_results(self, path: Path) -> list: result_path = self.get_results_file_path(path) - if os.path.exists(result_path): - with open(result_path, "r") as json_file: - try: - return json.load(json_file) - except json.JSONDecodeError: - return [] + if result_path.exists(): + try: + return json.loads(result_path.read_text()) + except json.JSONDecodeError: + return [] return [] def get_best_round(self): @@ -32,30 +31,28 @@ class DataUtils: return None - def get_results_file_path(self, prompt_path: str) -> str: - return os.path.join(prompt_path, "results.json") + def get_results_file_path(self, prompt_path: Path) -> Path: + return prompt_path / "results.json" def create_result_data(self, round: int, answers: list[dict], prompt: str, succeed: bool, tokens: int) -> dict: now = datetime.datetime.now() return {"round": round, "answers": answers, "prompt": prompt, "succeed": succeed, "tokens": tokens, "time": now} - def save_results(self, json_file_path: str, data: Union[List, Dict]): - with open(json_file_path, "w") as json_file: - json.dump(data, json_file, default=str, indent=4) + def save_results(self, json_file_path: Path, data: Union[List, Dict]): + json_path = json_file_path + json_path.write_text(json.dumps(data, default=str, indent=4)) def _load_scores(self): - rounds_dir = os.path.join(self.root_path, "prompts") - result_file = os.path.join(rounds_dir, "results.json") + rounds_dir = self.root_path / "prompts" + result_file = rounds_dir / "results.json" self.top_scores = [] try: - if not os.path.exists(result_file): + if not result_file.exists(): logger.warning(f"Results file not found at {result_file}") return self.top_scores - with open(result_file, "r", encoding="utf-8") as file: - data = json.load(file) - + data = json.loads(result_file.read_text(encoding="utf-8")) df = pd.DataFrame(data) for index, row in df.iterrows(): diff --git a/metagpt/ext/spo/utils/evaluation_utils.py b/metagpt/ext/spo/utils/evaluation_utils.py index 9e633b9bf..9814a70ba 100644 --- a/metagpt/ext/spo/utils/evaluation_utils.py +++ b/metagpt/ext/spo/utils/evaluation_utils.py @@ -1,3 +1,4 @@ +from pathlib import Path from typing import Any, List, Optional, Tuple import tiktoken @@ -17,10 +18,10 @@ def count_tokens(sample: dict): class EvaluationUtils: - def __init__(self, root_path: str) -> None: + def __init__(self, root_path: Path) -> None: self.root_path = root_path - async def execute_prompt(self, optimizer: Any, prompt_path: str) -> dict: + async def execute_prompt(self, optimizer: Any, prompt_path: Path) -> dict: optimizer.prompt = optimizer.prompt_utils.load_prompt(optimizer.round, prompt_path) executor = QuickExecute(prompt=optimizer.prompt) @@ -37,7 +38,7 @@ class EvaluationUtils: optimizer: Any, samples: Optional[dict], new_samples: dict, - path: str, + path: Path, data: List[dict], initial: bool = False, ) -> Tuple[bool, dict]: diff --git a/metagpt/ext/spo/utils/prompt_utils.py b/metagpt/ext/spo/utils/prompt_utils.py index 449611219..c1c960bb7 100644 --- a/metagpt/ext/spo/utils/prompt_utils.py +++ b/metagpt/ext/spo/utils/prompt_utils.py @@ -1,34 +1,34 @@ -import os +from pathlib import Path from metagpt.logs import logger class PromptUtils: - def __init__(self, root_path: str): + def __init__(self, root_path: Path): self.root_path = root_path - def create_round_directory(self, prompt_path: str, round_number: int) -> str: - directory = os.path.join(prompt_path, f"round_{round_number}") - os.makedirs(directory, exist_ok=True) + def create_round_directory(self, prompt_path: Path, round_number: int) -> Path: + directory = prompt_path / f"round_{round_number}" + directory.mkdir(parents=True, exist_ok=True) return directory - def load_prompt(self, round_number: int, prompts_path: str): - prompt_file_name = f"{prompts_path}/prompt.txt" + def load_prompt(self, round_number: int, prompts_path: Path): + prompt_file = prompts_path / "prompt.txt" try: - with open(prompt_file_name, "r", encoding="utf-8") as file: - return file.read() + return prompt_file.read_text(encoding="utf-8") except FileNotFoundError as e: logger.info(f"Error loading prompt for round {round_number}: {e}") raise - def write_answers(self, directory: str, answers: dict, name: str = "answers.txt"): - with open(os.path.join(directory, name), "w", encoding="utf-8") as file: + def write_answers(self, directory: Path, answers: dict, name: str = "answers.txt"): + answers_file = directory / name + with answers_file.open("w", encoding="utf-8") as file: for item in answers: file.write(f"Question:\n{item['question']}\n") file.write(f"Answer:\n{item['answer']}\n") file.write("\n") - def write_prompt(self, directory: str, prompt: str): - with open(os.path.join(directory, "prompt.txt"), "w", encoding="utf-8") as file: - file.write(prompt) + def write_prompt(self, directory: Path, prompt: str): + prompt_file = directory / "prompt.txt" + prompt_file.write_text(prompt, encoding="utf-8")