mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
add spo typing hint
This commit is contained in:
parent
424726eed9
commit
368d77b196
7 changed files with 41 additions and 23 deletions
|
|
@ -1,6 +1,7 @@
|
|||
import asyncio
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Dict
|
||||
|
||||
import streamlit as st
|
||||
import yaml
|
||||
|
|
@ -13,14 +14,14 @@ from metagpt.ext.spo.components.optimizer import PromptOptimizer # noqa: E402
|
|||
from metagpt.ext.spo.utils.llm_client import SPO_LLM, RequestType # noqa: E402
|
||||
|
||||
|
||||
def load_yaml_template(template_path):
|
||||
def load_yaml_template(template_path: Path) -> Dict:
|
||||
if template_path.exists():
|
||||
with open(template_path, "r", encoding="utf-8") as f:
|
||||
return yaml.safe_load(f)
|
||||
return {"prompt": "", "requirements": "", "count": None, "faq": [{"question": "", "answer": ""}]}
|
||||
|
||||
|
||||
def save_yaml_template(template_path, data):
|
||||
def save_yaml_template(template_path: Path, data: Dict) -> None:
|
||||
template_format = {
|
||||
"prompt": str(data.get("prompt", "")),
|
||||
"requirements": str(data.get("requirements", "")),
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ class QuickEvaluate:
|
|||
def __init__(self):
|
||||
self.llm = SPO_LLM.get_instance()
|
||||
|
||||
async def prompt_evaluate(self, samples: list, new_samples: list) -> bool:
|
||||
async def prompt_evaluate(self, samples: dict, new_samples: dict) -> bool:
|
||||
_, requirement, qa, _ = load.load_meta_data()
|
||||
|
||||
if random.random() < 0.5:
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@
|
|||
# @Desc : optimizer for prompt
|
||||
|
||||
import asyncio
|
||||
from typing import List
|
||||
|
||||
from metagpt.ext.spo.prompts.optimize_prompt import PROMPT_OPTIMIZE_PROMPT
|
||||
from metagpt.ext.spo.utils import load
|
||||
|
|
@ -74,7 +75,7 @@ class PromptOptimizer:
|
|||
|
||||
return self.prompt
|
||||
|
||||
async def _handle_first_round(self, prompt_path, data):
|
||||
async def _handle_first_round(self, prompt_path: str, data: List[dict]) -> None:
|
||||
logger.info("\n⚡ RUNNING Round 1 PROMPT ⚡\n")
|
||||
directory = self.prompt_utils.create_round_directory(prompt_path, self.round)
|
||||
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ class DataUtils:
|
|||
|
||||
return self.top_scores
|
||||
|
||||
def list_to_markdown(self, questions_list):
|
||||
def list_to_markdown(self, questions_list: list):
|
||||
"""
|
||||
Convert a list of question-answer dictionaries to a formatted Markdown string.
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from typing import Any, List, Optional, Tuple
|
||||
|
||||
import tiktoken
|
||||
|
||||
from metagpt.ext.spo.components.evaluator import QuickEvaluate, QuickExecute
|
||||
|
|
@ -6,7 +8,7 @@ from metagpt.logs import logger
|
|||
EVALUATION_REPETITION = 4
|
||||
|
||||
|
||||
def count_tokens(sample):
|
||||
def count_tokens(sample: dict):
|
||||
if not sample:
|
||||
return 0
|
||||
else:
|
||||
|
|
@ -15,10 +17,10 @@ def count_tokens(sample):
|
|||
|
||||
|
||||
class EvaluationUtils:
|
||||
def __init__(self, root_path: str):
|
||||
def __init__(self, root_path: str) -> None:
|
||||
self.root_path = root_path
|
||||
|
||||
async def execute_prompt(self, optimizer, prompt_path):
|
||||
async def execute_prompt(self, optimizer: Any, prompt_path: str) -> dict:
|
||||
optimizer.prompt = optimizer.prompt_utils.load_prompt(optimizer.round, prompt_path)
|
||||
executor = QuickExecute(prompt=optimizer.prompt)
|
||||
|
||||
|
|
@ -30,7 +32,15 @@ class EvaluationUtils:
|
|||
|
||||
return new_data
|
||||
|
||||
async def evaluate_prompt(self, optimizer, samples, new_samples, path, data, initial=False):
|
||||
async def evaluate_prompt(
|
||||
self,
|
||||
optimizer: Any,
|
||||
samples: Optional[dict],
|
||||
new_samples: dict,
|
||||
path: str,
|
||||
data: List[dict],
|
||||
initial: bool = False,
|
||||
) -> Tuple[bool, dict]:
|
||||
evaluator = QuickEvaluate()
|
||||
new_token = count_tokens(new_samples)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,10 +1,11 @@
|
|||
import asyncio
|
||||
import re
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from metagpt.configs.models_config import ModelsConfig
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.logs import logger
|
||||
|
||||
|
||||
class RequestType(Enum):
|
||||
|
|
@ -16,12 +17,17 @@ class RequestType(Enum):
|
|||
class SPO_LLM:
|
||||
_instance: Optional["SPO_LLM"] = None
|
||||
|
||||
def __init__(self, optimize_kwargs=None, evaluate_kwargs=None, execute_kwargs=None):
|
||||
def __init__(
|
||||
self,
|
||||
optimize_kwargs: Optional[dict] = None,
|
||||
evaluate_kwargs: Optional[dict] = None,
|
||||
execute_kwargs: Optional[dict] = None,
|
||||
) -> None:
|
||||
self.evaluate_llm = LLM(llm_config=self._load_llm_config(evaluate_kwargs))
|
||||
self.optimize_llm = LLM(llm_config=self._load_llm_config(optimize_kwargs))
|
||||
self.execute_llm = LLM(llm_config=self._load_llm_config(execute_kwargs))
|
||||
|
||||
def _load_llm_config(self, kwargs: dict):
|
||||
def _load_llm_config(self, kwargs: dict) -> Any:
|
||||
model = kwargs.get("model")
|
||||
if not model:
|
||||
raise ValueError("'model' parameter is required")
|
||||
|
|
@ -44,7 +50,7 @@ class SPO_LLM:
|
|||
except Exception as e:
|
||||
raise ValueError(f"Error loading configuration for model '{model}': {str(e)}")
|
||||
|
||||
async def responser(self, request_type: RequestType, messages: list):
|
||||
async def responser(self, request_type: RequestType, messages: List[dict]) -> str:
|
||||
llm_mapping = {
|
||||
RequestType.OPTIMIZE: self.optimize_llm,
|
||||
RequestType.EVALUATE: self.evaluate_llm,
|
||||
|
|
@ -59,25 +65,25 @@ class SPO_LLM:
|
|||
return response.choices[0].message.content
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, optimize_kwargs, evaluate_kwargs, execute_kwargs):
|
||||
def initialize(cls, optimize_kwargs: dict, evaluate_kwargs: dict, execute_kwargs: dict) -> None:
|
||||
"""Initialize the global instance"""
|
||||
cls._instance = cls(optimize_kwargs, evaluate_kwargs, execute_kwargs)
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls):
|
||||
def get_instance(cls) -> "SPO_LLM":
|
||||
"""Get the global instance"""
|
||||
if cls._instance is None:
|
||||
raise RuntimeError("SPO_LLM not initialized. Call initialize() first.")
|
||||
return cls._instance
|
||||
|
||||
|
||||
def extract_content(xml_string, tag):
|
||||
def extract_content(xml_string: str, tag: str) -> Optional[str]:
|
||||
pattern = rf"<{tag}>(.*?)</{tag}>"
|
||||
match = re.search(pattern, xml_string, re.DOTALL)
|
||||
return match.group(1).strip() if match else None
|
||||
|
||||
|
||||
async def spo():
|
||||
async def main():
|
||||
# test LLM
|
||||
SPO_LLM.initialize(
|
||||
optimize_kwargs={"model": "gpt-4o", "temperature": 0.7},
|
||||
|
|
@ -90,12 +96,12 @@ async def spo():
|
|||
# test messages
|
||||
hello_msg = [{"role": "user", "content": "hello"}]
|
||||
response = await llm.responser(request_type=RequestType.EXECUTE, messages=hello_msg)
|
||||
print(f"AI: {response}")
|
||||
logger(f"AI: {response}")
|
||||
response = await llm.responser(request_type=RequestType.OPTIMIZE, messages=hello_msg)
|
||||
print(f"AI: {response}")
|
||||
logger(f"AI: {response}")
|
||||
response = await llm.responser(request_type=RequestType.EVALUATE, messages=hello_msg)
|
||||
print(f"AI: {response}")
|
||||
logger(f"AI: {response}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(spo())
|
||||
asyncio.run(main())
|
||||
|
|
|
|||
|
|
@ -7,12 +7,12 @@ FILE_NAME = ""
|
|||
SAMPLE_K = 3
|
||||
|
||||
|
||||
def set_file_name(name):
|
||||
def set_file_name(name: str):
|
||||
global FILE_NAME
|
||||
FILE_NAME = name
|
||||
|
||||
|
||||
def load_meta_data(k=SAMPLE_K):
|
||||
def load_meta_data(k: int = SAMPLE_K):
|
||||
# load yaml file
|
||||
config_path = os.path.join(os.path.dirname(os.path.dirname(__file__)), "settings", FILE_NAME)
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue