diff --git a/metagpt/ext/spo/app.py b/metagpt/ext/spo/app.py index 183757124..7a93cddf2 100644 --- a/metagpt/ext/spo/app.py +++ b/metagpt/ext/spo/app.py @@ -1,6 +1,7 @@ import asyncio import sys from pathlib import Path +from typing import Dict import streamlit as st import yaml @@ -13,14 +14,14 @@ from metagpt.ext.spo.components.optimizer import PromptOptimizer # noqa: E402 from metagpt.ext.spo.utils.llm_client import SPO_LLM, RequestType # noqa: E402 -def load_yaml_template(template_path): +def load_yaml_template(template_path: Path) -> Dict: if template_path.exists(): with open(template_path, "r", encoding="utf-8") as f: return yaml.safe_load(f) return {"prompt": "", "requirements": "", "count": None, "faq": [{"question": "", "answer": ""}]} -def save_yaml_template(template_path, data): +def save_yaml_template(template_path: Path, data: Dict) -> None: template_format = { "prompt": str(data.get("prompt", "")), "requirements": str(data.get("requirements", "")), diff --git a/metagpt/ext/spo/components/evaluator.py b/metagpt/ext/spo/components/evaluator.py index 5e95f3719..952ef211b 100644 --- a/metagpt/ext/spo/components/evaluator.py +++ b/metagpt/ext/spo/components/evaluator.py @@ -47,7 +47,7 @@ class QuickEvaluate: def __init__(self): self.llm = SPO_LLM.get_instance() - async def prompt_evaluate(self, samples: list, new_samples: list) -> bool: + async def prompt_evaluate(self, samples: dict, new_samples: dict) -> bool: _, requirement, qa, _ = load.load_meta_data() if random.random() < 0.5: diff --git a/metagpt/ext/spo/components/optimizer.py b/metagpt/ext/spo/components/optimizer.py index fca67bc56..6b5a0824f 100644 --- a/metagpt/ext/spo/components/optimizer.py +++ b/metagpt/ext/spo/components/optimizer.py @@ -4,6 +4,7 @@ # @Desc : optimizer for prompt import asyncio +from typing import List from metagpt.ext.spo.prompts.optimize_prompt import PROMPT_OPTIMIZE_PROMPT from metagpt.ext.spo.utils import load @@ -74,7 +75,7 @@ class PromptOptimizer: return self.prompt 
- async def _handle_first_round(self, prompt_path, data): + async def _handle_first_round(self, prompt_path: str, data: List[dict]) -> None: logger.info("\n⚡ RUNNING Round 1 PROMPT ⚡\n") directory = self.prompt_utils.create_round_directory(prompt_path, self.round) diff --git a/metagpt/ext/spo/utils/data_utils.py b/metagpt/ext/spo/utils/data_utils.py index 7bf57dbf7..0d0d99def 100644 --- a/metagpt/ext/spo/utils/data_utils.py +++ b/metagpt/ext/spo/utils/data_utils.py @@ -79,7 +79,7 @@ class DataUtils: return self.top_scores - def list_to_markdown(self, questions_list): + def list_to_markdown(self, questions_list: list): """ Convert a list of question-answer dictionaries to a formatted Markdown string. diff --git a/metagpt/ext/spo/utils/evaluation_utils.py b/metagpt/ext/spo/utils/evaluation_utils.py index 4a0403a9a..9e633b9bf 100644 --- a/metagpt/ext/spo/utils/evaluation_utils.py +++ b/metagpt/ext/spo/utils/evaluation_utils.py @@ -1,3 +1,5 @@ +from typing import Any, List, Optional, Tuple + import tiktoken from metagpt.ext.spo.components.evaluator import QuickEvaluate, QuickExecute @@ -6,7 +8,7 @@ from metagpt.logs import logger EVALUATION_REPETITION = 4 -def count_tokens(sample): +def count_tokens(sample: dict): if not sample: return 0 else: @@ -15,10 +17,10 @@ def count_tokens(sample): class EvaluationUtils: - def __init__(self, root_path: str): + def __init__(self, root_path: str) -> None: self.root_path = root_path - async def execute_prompt(self, optimizer, prompt_path): + async def execute_prompt(self, optimizer: Any, prompt_path: str) -> dict: optimizer.prompt = optimizer.prompt_utils.load_prompt(optimizer.round, prompt_path) executor = QuickExecute(prompt=optimizer.prompt) @@ -30,7 +32,15 @@ class EvaluationUtils: return new_data - async def evaluate_prompt(self, optimizer, samples, new_samples, path, data, initial=False): + async def evaluate_prompt( + self, + optimizer: Any, + samples: Optional[dict], + new_samples: dict, + path: str, + data: List[dict], + 
initial: bool = False, + ) -> Tuple[bool, dict]: evaluator = QuickEvaluate() new_token = count_tokens(new_samples) diff --git a/metagpt/ext/spo/utils/llm_client.py b/metagpt/ext/spo/utils/llm_client.py index 689d2a5ef..81524d3c1 100644 --- a/metagpt/ext/spo/utils/llm_client.py +++ b/metagpt/ext/spo/utils/llm_client.py @@ -1,10 +1,11 @@ import asyncio import re from enum import Enum -from typing import Optional +from typing import Any, List, Optional from metagpt.configs.models_config import ModelsConfig from metagpt.llm import LLM +from metagpt.logs import logger class RequestType(Enum): @@ -16,12 +17,17 @@ class RequestType(Enum): class SPO_LLM: _instance: Optional["SPO_LLM"] = None - def __init__(self, optimize_kwargs=None, evaluate_kwargs=None, execute_kwargs=None): + def __init__( + self, + optimize_kwargs: Optional[dict] = None, + evaluate_kwargs: Optional[dict] = None, + execute_kwargs: Optional[dict] = None, + ) -> None: self.evaluate_llm = LLM(llm_config=self._load_llm_config(evaluate_kwargs)) self.optimize_llm = LLM(llm_config=self._load_llm_config(optimize_kwargs)) self.execute_llm = LLM(llm_config=self._load_llm_config(execute_kwargs)) - def _load_llm_config(self, kwargs: dict): + def _load_llm_config(self, kwargs: dict) -> Any: model = kwargs.get("model") if not model: raise ValueError("'model' parameter is required") @@ -44,7 +50,7 @@ class SPO_LLM: except Exception as e: raise ValueError(f"Error loading configuration for model '{model}': {str(e)}") - async def responser(self, request_type: RequestType, messages: list): + async def responser(self, request_type: RequestType, messages: List[dict]) -> str: llm_mapping = { RequestType.OPTIMIZE: self.optimize_llm, RequestType.EVALUATE: self.evaluate_llm, @@ -59,25 +65,25 @@ class SPO_LLM: return response.choices[0].message.content @classmethod - def initialize(cls, optimize_kwargs, evaluate_kwargs, execute_kwargs): + def initialize(cls, optimize_kwargs: dict, evaluate_kwargs: dict, execute_kwargs: dict) -> 
None: """Initialize the global instance""" cls._instance = cls(optimize_kwargs, evaluate_kwargs, execute_kwargs) @classmethod - def get_instance(cls): + def get_instance(cls) -> "SPO_LLM": """Get the global instance""" if cls._instance is None: raise RuntimeError("SPO_LLM not initialized. Call initialize() first.") return cls._instance -def extract_content(xml_string, tag): +def extract_content(xml_string: str, tag: str) -> Optional[str]: pattern = rf"<{tag}>(.*?)" match = re.search(pattern, xml_string, re.DOTALL) return match.group(1).strip() if match else None -async def spo(): +async def main(): # test LLM SPO_LLM.initialize( optimize_kwargs={"model": "gpt-4o", "temperature": 0.7}, @@ -90,12 +96,12 @@ async def spo(): # test messages hello_msg = [{"role": "user", "content": "hello"}] response = await llm.responser(request_type=RequestType.EXECUTE, messages=hello_msg) - print(f"AI: {response}") + logger.info(f"AI: {response}") response = await llm.responser(request_type=RequestType.OPTIMIZE, messages=hello_msg) - print(f"AI: {response}") + logger.info(f"AI: {response}") response = await llm.responser(request_type=RequestType.EVALUATE, messages=hello_msg) - print(f"AI: {response}") + logger.info(f"AI: {response}") if __name__ == "__main__": - asyncio.run(spo()) + asyncio.run(main()) diff --git a/metagpt/ext/spo/utils/load.py b/metagpt/ext/spo/utils/load.py index 3f9ab0c27..bf0d8af4e 100644 --- a/metagpt/ext/spo/utils/load.py +++ b/metagpt/ext/spo/utils/load.py @@ -7,12 +7,12 @@ FILE_NAME = "" SAMPLE_K = 3 -def set_file_name(name): +def set_file_name(name: str): global FILE_NAME FILE_NAME = name -def load_meta_data(k=SAMPLE_K): +def load_meta_data(k: int = SAMPLE_K): # load yaml file config_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "settings", FILE_NAME)