add spo typing hint

This commit is contained in:
xiangjinyu 2025-02-12 18:50:20 +08:00
parent 424726eed9
commit 368d77b196
7 changed files with 41 additions and 23 deletions

View file

@ -1,6 +1,7 @@
import asyncio
import sys
from pathlib import Path
from typing import Dict
import streamlit as st
import yaml
@ -13,14 +14,14 @@ from metagpt.ext.spo.components.optimizer import PromptOptimizer # noqa: E402
from metagpt.ext.spo.utils.llm_client import SPO_LLM, RequestType # noqa: E402
def load_yaml_template(template_path):
def load_yaml_template(template_path: Path) -> Dict:
if template_path.exists():
with open(template_path, "r", encoding="utf-8") as f:
return yaml.safe_load(f)
return {"prompt": "", "requirements": "", "count": None, "faq": [{"question": "", "answer": ""}]}
def save_yaml_template(template_path, data):
def save_yaml_template(template_path: Path, data: Dict) -> None:
template_format = {
"prompt": str(data.get("prompt", "")),
"requirements": str(data.get("requirements", "")),

View file

@ -47,7 +47,7 @@ class QuickEvaluate:
def __init__(self):
self.llm = SPO_LLM.get_instance()
async def prompt_evaluate(self, samples: list, new_samples: list) -> bool:
async def prompt_evaluate(self, samples: dict, new_samples: dict) -> bool:
_, requirement, qa, _ = load.load_meta_data()
if random.random() < 0.5:

View file

@ -4,6 +4,7 @@
# @Desc : optimizer for prompt
import asyncio
from typing import List
from metagpt.ext.spo.prompts.optimize_prompt import PROMPT_OPTIMIZE_PROMPT
from metagpt.ext.spo.utils import load
@ -74,7 +75,7 @@ class PromptOptimizer:
return self.prompt
async def _handle_first_round(self, prompt_path, data):
async def _handle_first_round(self, prompt_path: str, data: List[dict]) -> None:
logger.info("\n⚡ RUNNING Round 1 PROMPT ⚡\n")
directory = self.prompt_utils.create_round_directory(prompt_path, self.round)

View file

@ -79,7 +79,7 @@ class DataUtils:
return self.top_scores
def list_to_markdown(self, questions_list):
def list_to_markdown(self, questions_list: list):
"""
Convert a list of question-answer dictionaries to a formatted Markdown string.

View file

@ -1,3 +1,5 @@
from typing import Any, List, Optional, Tuple
import tiktoken
from metagpt.ext.spo.components.evaluator import QuickEvaluate, QuickExecute
@ -6,7 +8,7 @@ from metagpt.logs import logger
EVALUATION_REPETITION = 4
def count_tokens(sample):
def count_tokens(sample: dict):
if not sample:
return 0
else:
@ -15,10 +17,10 @@ def count_tokens(sample):
class EvaluationUtils:
def __init__(self, root_path: str):
def __init__(self, root_path: str) -> None:
self.root_path = root_path
async def execute_prompt(self, optimizer, prompt_path):
async def execute_prompt(self, optimizer: Any, prompt_path: str) -> dict:
optimizer.prompt = optimizer.prompt_utils.load_prompt(optimizer.round, prompt_path)
executor = QuickExecute(prompt=optimizer.prompt)
@ -30,7 +32,15 @@ class EvaluationUtils:
return new_data
async def evaluate_prompt(self, optimizer, samples, new_samples, path, data, initial=False):
async def evaluate_prompt(
self,
optimizer: Any,
samples: Optional[dict],
new_samples: dict,
path: str,
data: List[dict],
initial: bool = False,
) -> Tuple[bool, dict]:
evaluator = QuickEvaluate()
new_token = count_tokens(new_samples)

View file

@ -1,10 +1,11 @@
import asyncio
import re
from enum import Enum
from typing import Optional
from typing import Any, List, Optional
from metagpt.configs.models_config import ModelsConfig
from metagpt.llm import LLM
from metagpt.logs import logger
class RequestType(Enum):
@ -16,12 +17,17 @@ class RequestType(Enum):
class SPO_LLM:
_instance: Optional["SPO_LLM"] = None
def __init__(self, optimize_kwargs=None, evaluate_kwargs=None, execute_kwargs=None):
def __init__(
self,
optimize_kwargs: Optional[dict] = None,
evaluate_kwargs: Optional[dict] = None,
execute_kwargs: Optional[dict] = None,
) -> None:
self.evaluate_llm = LLM(llm_config=self._load_llm_config(evaluate_kwargs))
self.optimize_llm = LLM(llm_config=self._load_llm_config(optimize_kwargs))
self.execute_llm = LLM(llm_config=self._load_llm_config(execute_kwargs))
def _load_llm_config(self, kwargs: dict):
def _load_llm_config(self, kwargs: dict) -> Any:
model = kwargs.get("model")
if not model:
raise ValueError("'model' parameter is required")
@ -44,7 +50,7 @@ class SPO_LLM:
except Exception as e:
raise ValueError(f"Error loading configuration for model '{model}': {str(e)}")
async def responser(self, request_type: RequestType, messages: list):
async def responser(self, request_type: RequestType, messages: List[dict]) -> str:
llm_mapping = {
RequestType.OPTIMIZE: self.optimize_llm,
RequestType.EVALUATE: self.evaluate_llm,
@ -59,25 +65,25 @@ class SPO_LLM:
return response.choices[0].message.content
@classmethod
def initialize(cls, optimize_kwargs, evaluate_kwargs, execute_kwargs):
def initialize(cls, optimize_kwargs: dict, evaluate_kwargs: dict, execute_kwargs: dict) -> None:
"""Initialize the global instance"""
cls._instance = cls(optimize_kwargs, evaluate_kwargs, execute_kwargs)
@classmethod
def get_instance(cls):
def get_instance(cls) -> "SPO_LLM":
"""Get the global instance"""
if cls._instance is None:
raise RuntimeError("SPO_LLM not initialized. Call initialize() first.")
return cls._instance
def extract_content(xml_string, tag):
def extract_content(xml_string: str, tag: str) -> Optional[str]:
pattern = rf"<{tag}>(.*?)</{tag}>"
match = re.search(pattern, xml_string, re.DOTALL)
return match.group(1).strip() if match else None
async def spo():
async def main():
# test LLM
SPO_LLM.initialize(
optimize_kwargs={"model": "gpt-4o", "temperature": 0.7},
@ -90,12 +96,12 @@ async def spo():
# test messages
hello_msg = [{"role": "user", "content": "hello"}]
response = await llm.responser(request_type=RequestType.EXECUTE, messages=hello_msg)
print(f"AI: {response}")
logger(f"AI: {response}")
response = await llm.responser(request_type=RequestType.OPTIMIZE, messages=hello_msg)
print(f"AI: {response}")
logger(f"AI: {response}")
response = await llm.responser(request_type=RequestType.EVALUATE, messages=hello_msg)
print(f"AI: {response}")
logger(f"AI: {response}")
if __name__ == "__main__":
asyncio.run(spo())
asyncio.run(main())

View file

@ -7,12 +7,12 @@ FILE_NAME = ""
SAMPLE_K = 3
def set_file_name(name):
def set_file_name(name: str):
global FILE_NAME
FILE_NAME = name
def load_meta_data(k=SAMPLE_K):
def load_meta_data(k: int = SAMPLE_K):
# load yaml file
config_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "settings", FILE_NAME)