mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-08 15:05:17 +02:00
use llm cache to make exp_pool
This commit is contained in:
parent
d902a6f18c
commit
c624c0ffc7
41 changed files with 844 additions and 368 deletions
|
|
@ -75,8 +75,10 @@ s3:
|
|||
bucket: "test"
|
||||
|
||||
exp_pool:
|
||||
enable_read: true
|
||||
enable_write: true
|
||||
enable_read: false
|
||||
enable_write: false
|
||||
persist_path: .chroma_exp_data # The directory.
|
||||
init_exp: false # If set to true, basic experiences associated with the roles will be added to the experience pool.
|
||||
|
||||
azure_tts_subscription_key: "YOUR_SUBSCRIPTION_KEY"
|
||||
azure_tts_region: "eastus"
|
||||
|
|
|
|||
|
|
@ -7,16 +7,15 @@ from metagpt.exp_pool import exp_cache, exp_manager
|
|||
from metagpt.logs import logger
|
||||
|
||||
|
||||
@exp_cache(pass_exps_to_func=True)
|
||||
async def produce(req, exps=None):
|
||||
logger.info(f"Previous experiences: {exps}")
|
||||
@exp_cache()
|
||||
async def produce(req=""):
|
||||
return f"{req} {uuid.uuid4().hex}"
|
||||
|
||||
|
||||
async def main():
|
||||
req = "Water"
|
||||
|
||||
resp = await produce(req)
|
||||
resp = await produce(req=req)
|
||||
logger.info(f"The resp of `produce{req}` is: {resp}")
|
||||
|
||||
exps = await exp_manager.query_exps(req)
|
||||
|
|
|
|||
|
|
@ -50,9 +50,9 @@ async def generate_novel():
|
|||
"Fill the empty nodes with your own ideas. Be creative! Use your own words!"
|
||||
"I will tip you $100,000 if you write a good novel."
|
||||
)
|
||||
novel_node = await ActionNode.from_pydantic(Novel).fill(context=instruction, llm=LLM())
|
||||
novel_node = await ActionNode.from_pydantic(Novel).fill(req=instruction, llm=LLM())
|
||||
chap_node = await ActionNode.from_pydantic(Chapters).fill(
|
||||
context=f"### instruction\n{instruction}\n### novel\n{novel_node.content}", llm=LLM()
|
||||
req=f"### instruction\n{instruction}\n### novel\n{novel_node.content}", llm=LLM()
|
||||
)
|
||||
print(chap_node.instruct_content)
|
||||
|
||||
|
|
|
|||
|
|
@ -90,7 +90,7 @@ class Action(SerializationMixin, ContextMixin, BaseModel):
|
|||
msgs = args[0]
|
||||
context = "## History Messages\n"
|
||||
context += "\n".join([f"{idx}: {i}" for idx, i in enumerate(reversed(msgs))])
|
||||
return await self.node.fill(context=context, llm=self.llm)
|
||||
return await self.node.fill(req=context, llm=self.llm)
|
||||
|
||||
async def run(self, *args, **kwargs):
|
||||
"""Run action"""
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
|
|||
|
||||
from metagpt.actions.action_outcls_registry import register_action_outcls
|
||||
from metagpt.const import MARKDOWN_TITLE_PREFIX, USE_CONFIG_TIMEOUT
|
||||
from metagpt.exp_pool import exp_cache
|
||||
from metagpt.llm import BaseLLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
|
||||
|
|
@ -465,9 +466,33 @@ class ActionNode:
|
|||
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def deserialize_to_action_node(cls, serialized_data) -> "ActionNode":
|
||||
"""Customized deserialization, it will be triggered when a perfect experience is found.
|
||||
|
||||
ActionNode cannot be serialized, it throws an error 'cannot pickle 'SSLContext' object'.
|
||||
"""
|
||||
|
||||
class InstructContent:
|
||||
def __init__(self, json_data):
|
||||
self.json_data = json_data
|
||||
|
||||
def model_dump_json(self):
|
||||
return self.json_data
|
||||
|
||||
action_node = cls(key="", expected_type=Type[str], instruction="", example="")
|
||||
action_node.instruct_content = InstructContent(serialized_data)
|
||||
|
||||
return action_node
|
||||
|
||||
@exp_cache(
|
||||
resp_serialize=lambda action_node: action_node.instruct_content.model_dump_json(),
|
||||
resp_deserialize=lambda resp: ActionNode.deserialize_to_action_node(resp),
|
||||
)
|
||||
async def fill(
|
||||
self,
|
||||
context,
|
||||
*,
|
||||
req,
|
||||
llm,
|
||||
schema="json",
|
||||
mode="auto",
|
||||
|
|
@ -478,7 +503,7 @@ class ActionNode:
|
|||
):
|
||||
"""Fill the node(s) with mode.
|
||||
|
||||
:param context: Everything we should know when filling node.
|
||||
:param req: Everything we should know when filling node.
|
||||
:param llm: Large Language Model with pre-defined system message.
|
||||
:param schema: json/markdown, determine example and output format.
|
||||
- raw: free form text
|
||||
|
|
@ -497,7 +522,7 @@ class ActionNode:
|
|||
:return: self
|
||||
"""
|
||||
self.set_llm(llm)
|
||||
self.set_context(context)
|
||||
self.set_context(req)
|
||||
if self.schema:
|
||||
schema = self.schema
|
||||
|
||||
|
|
|
|||
|
|
@ -178,12 +178,12 @@ class WriteDesign(Action):
|
|||
)
|
||||
|
||||
async def _new_system_design(self, context):
|
||||
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=self.prompt_schema)
|
||||
node = await DESIGN_API_NODE.fill(req=context, llm=self.llm, schema=self.prompt_schema)
|
||||
return node
|
||||
|
||||
async def _merge(self, prd_doc, system_design_doc):
|
||||
context = NEW_REQ_TEMPLATE.format(old_design=system_design_doc.content, context=prd_doc.content)
|
||||
node = await REFINED_DESIGN_NODE.fill(context=context, llm=self.llm, schema=self.prompt_schema)
|
||||
node = await REFINED_DESIGN_NODE.fill(req=context, llm=self.llm, schema=self.prompt_schema)
|
||||
system_design_doc.content = node.instruct_content.model_dump_json()
|
||||
return system_design_doc
|
||||
|
||||
|
|
|
|||
|
|
@ -22,4 +22,4 @@ class GenerateQuestions(Action):
|
|||
name: str = "GenerateQuestions"
|
||||
|
||||
async def run(self, context) -> ActionNode:
|
||||
return await QUESTIONS.fill(context=context, llm=self.llm)
|
||||
return await QUESTIONS.fill(req=context, llm=self.llm)
|
||||
|
|
|
|||
|
|
@ -22,4 +22,4 @@ class PrepareInterview(Action):
|
|||
name: str = "PrepareInterview"
|
||||
|
||||
async def run(self, context):
|
||||
return await QUESTIONS.fill(context=context, llm=self.llm)
|
||||
return await QUESTIONS.fill(req=context, llm=self.llm)
|
||||
|
|
|
|||
|
|
@ -151,12 +151,12 @@ class WriteTasks(Action):
|
|||
return task_doc
|
||||
|
||||
async def _run_new_tasks(self, context: str):
|
||||
node = await PM_NODE.fill(context, self.llm, schema=self.prompt_schema)
|
||||
node = await PM_NODE.fill(req=context, llm=self.llm, schema=self.prompt_schema)
|
||||
return node
|
||||
|
||||
async def _merge(self, system_design_doc, task_doc) -> Document:
|
||||
context = NEW_REQ_TEMPLATE.format(context=system_design_doc.content, old_task=task_doc.content)
|
||||
node = await REFINED_PM_NODE.fill(context, self.llm, schema=self.prompt_schema)
|
||||
node = await REFINED_PM_NODE.fill(req=context, llm=self.llm, schema=self.prompt_schema)
|
||||
task_doc.content = node.instruct_content.model_dump_json()
|
||||
return task_doc
|
||||
|
||||
|
|
|
|||
|
|
@ -578,7 +578,7 @@ class WriteCodeAN(Action):
|
|||
|
||||
async def run(self, context):
|
||||
self.llm.system_prompt = "You are an outstanding engineer and can implement any code"
|
||||
return await WRITE_MOVE_NODE.fill(context=context, llm=self.llm, schema="json")
|
||||
return await WRITE_MOVE_NODE.fill(req=context, llm=self.llm, schema="json")
|
||||
|
||||
|
||||
async def main():
|
||||
|
|
|
|||
|
|
@ -229,7 +229,7 @@ class WriteCodePlanAndChange(Action):
|
|||
code=await self.get_old_codes(),
|
||||
)
|
||||
logger.info("Writing code plan and change..")
|
||||
return await WRITE_CODE_PLAN_AND_CHANGE_NODE.fill(context=context, llm=self.llm, schema="json")
|
||||
return await WRITE_CODE_PLAN_AND_CHANGE_NODE.fill(req=context, llm=self.llm, schema="json")
|
||||
|
||||
async def get_old_codes(self) -> str:
|
||||
old_codes = await self.repo.srcs.get_all()
|
||||
|
|
|
|||
|
|
@ -211,7 +211,7 @@ class WritePRD(Action):
|
|||
context = CONTEXT_TEMPLATE.format(requirements=requirement, project_name=project_name)
|
||||
exclude = [PROJECT_NAME.key] if project_name else []
|
||||
node = await WRITE_PRD_NODE.fill(
|
||||
context=context, llm=self.llm, exclude=exclude, schema=self.prompt_schema
|
||||
req=context, llm=self.llm, exclude=exclude, schema=self.prompt_schema
|
||||
) # schema=schema
|
||||
return node
|
||||
|
||||
|
|
@ -238,7 +238,7 @@ class WritePRD(Action):
|
|||
async def _is_bugfix(self, context: str) -> bool:
|
||||
if not self.repo.code_files_exists():
|
||||
return False
|
||||
node = await WP_ISSUE_TYPE_NODE.fill(context, self.llm)
|
||||
node = await WP_ISSUE_TYPE_NODE.fill(req=context, llm=self.llm)
|
||||
return node.get("issue_type") == "BUG"
|
||||
|
||||
async def get_related_docs(self, req: Document, docs: list[Document]) -> list[Document]:
|
||||
|
|
@ -248,14 +248,14 @@ class WritePRD(Action):
|
|||
|
||||
async def _is_related(self, req: Document, old_prd: Document) -> bool:
|
||||
context = NEW_REQ_TEMPLATE.format(old_prd=old_prd.content, requirements=req.content)
|
||||
node = await WP_IS_RELATIVE_NODE.fill(context, self.llm)
|
||||
node = await WP_IS_RELATIVE_NODE.fill(req=context, llm=self.llm)
|
||||
return node.get("is_relative") == "YES"
|
||||
|
||||
async def _merge(self, req: Document, related_doc: Document) -> Document:
|
||||
if not self.project_name:
|
||||
self.project_name = Path(self.project_path).name
|
||||
prompt = NEW_REQ_TEMPLATE.format(requirements=req.content, old_prd=related_doc.content)
|
||||
node = await REFINED_PRD_NODE.fill(context=prompt, llm=self.llm, schema=self.prompt_schema)
|
||||
node = await REFINED_PRD_NODE.fill(req=prompt, llm=self.llm, schema=self.prompt_schema)
|
||||
related_doc.content = node.instruct_content.model_dump_json()
|
||||
await self._rename_workspace(node)
|
||||
return related_doc
|
||||
|
|
|
|||
|
|
@ -36,4 +36,4 @@ class WriteReview(Action):
|
|||
name: str = "WriteReview"
|
||||
|
||||
async def run(self, context):
|
||||
return await WRITE_REVIEW_NODE.fill(context=context, llm=self.llm, schema="json")
|
||||
return await WRITE_REVIEW_NODE.fill(req=context, llm=self.llm, schema="json")
|
||||
|
|
|
|||
|
|
@ -4,5 +4,9 @@ from metagpt.utils.yaml_model import YamlModel
|
|||
|
||||
|
||||
class ExperiencePoolConfig(YamlModel):
|
||||
enable_read: bool = Field(default=True, description="Enable to read from experience pool.")
|
||||
enable_write: bool = Field(default=True, description="Enable to write to experience pool.")
|
||||
enable_read: bool = Field(default=False, description="Enable to read from experience pool.")
|
||||
enable_write: bool = Field(default=False, description="Enable to write to experience pool.")
|
||||
persist_path: str = Field(default=".chroma_exp_data", description="The persist path for experience pool.")
|
||||
init_exp: bool = Field(
|
||||
default=False, description="Put some basic experiences associated with the roles into the experience pool."
|
||||
)
|
||||
|
|
|
|||
7
metagpt/exp_pool/context_builders/__init__.py
Normal file
7
metagpt/exp_pool/context_builders/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
"""Context builders init."""
|
||||
|
||||
from metagpt.exp_pool.context_builders.base import BaseContextBuilder
|
||||
from metagpt.exp_pool.context_builders.simple import SimpleContextBuilder
|
||||
from metagpt.exp_pool.context_builders.role_zero import RoleZeroContextBuilder
|
||||
|
||||
__all__ = ["BaseContextBuilder", "SimpleContextBuilder", "RoleZeroContextBuilder"]
|
||||
52
metagpt/exp_pool/context_builders/base.py
Normal file
52
metagpt/exp_pool/context_builders/base.py
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
"""Base context builder."""
|
||||
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from metagpt.exp_pool.schema import Experience
|
||||
|
||||
EXP_TEMPLATE = """Given the request: {req}, We can get the response: {resp}, Which scored: {score}."""
|
||||
|
||||
|
||||
class BaseContextBuilder(BaseModel, ABC):
    """Abstract base for builders that combine experiences with call parameters into a context.

    Subclasses implement `build`. This base class supplies the shared helpers: formatting
    `exps` as a numbered list and replacing the text found between two markers.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Experiences rendered by `format_exps`.
    exps: list[Experience] = []

    @abstractmethod
    async def build(self, *args, **kwargs) -> Any:
        """Build context from parameters."""

    def format_exps(self) -> str:
        """Format experiences into a numbered list of strings.

        Returns:
            str: One `EXP_TEMPLATE` line per experience, prefixed with "1. ", "2. ", ...
                and joined with newlines. An empty string when `exps` is empty.
        """

        return "\n".join(
            f"{i}. " + EXP_TEMPLATE.format(req=exp.req, resp=exp.resp, score=exp.metric.score.val)
            for i, exp in enumerate(self.exps, start=1)
        )

    @staticmethod
    def replace_content_between_markers(text: str, start_marker: str, end_marker: str, new_content: str) -> str:
        """Replace the content between `start_marker` and `end_marker` in the text with `new_content`.

        Both markers are matched literally and must each be followed/preceded by a newline
        (`start_marker\\n ... \\nend_marker`). Every non-overlapping occurrence is replaced,
        and the text is returned unchanged when no such pair is found.

        Args:
            text (str): The original text.
            start_marker (str): The marker indicating the start of the content to be replaced, such as '# Example'.
            end_marker (str): The marker indicating the end of the content to be replaced, such as '# Instruction'.
            new_content (str): The new content to replace the old content.

        Returns:
            str: The text with the content replaced. The markers themselves are kept, and a
                blank line separates `new_content` from `end_marker`.
        """

        # Escape the markers so regex metacharacters in them (e.g. '*', '(', '[') are treated as plain text.
        pattern = re.compile(
            f"({re.escape(start_marker)}\n)(.*?)(\n{re.escape(end_marker)})",
            re.DOTALL,
        )

        def replacement(match):
            # A callable is used so backslashes in `new_content` are not read as group references.
            return f"{match.group(1)}{new_content}\n{match.group(3)}"

        return pattern.sub(replacement, text)
|
||||
26
metagpt/exp_pool/context_builders/role_zero.py
Normal file
26
metagpt/exp_pool/context_builders/role_zero.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
"""RoleZero context builder."""
|
||||
|
||||
from metagpt.exp_pool.context_builders.base import BaseContextBuilder
|
||||
|
||||
|
||||
class RoleZeroContextBuilder(BaseContextBuilder):
    """Context builder that injects formatted experiences into the last message of `req`."""

    async def build(self, *args, **kwargs) -> list[dict]:
        """Builds the context by updating the req with formatted experiences.

        If there are no experiences, retains the original examples in req, otherwise replaces the
        examples in the last message's "content" with the formatted experiences.

        Args:
            **kwargs: Must carry the messages under the `req` keyword; positional arguments are ignored.

        Returns:
            list[dict]: `req` itself when it is empty or there are no experiences. Otherwise a new
                list whose last message is a copy with updated "content"; the caller's list and
                message dicts are left unmodified.
        """

        req = kwargs.get("req", [])
        if not req:
            return req

        exps_str = self.format_exps()
        if not exps_str:
            return req

        # Work on copies of the list and of its last message: writing into `req[-1]` directly
        # would silently rewrite the caller's own messages.
        updated_req = list(req)
        last_msg = dict(req[-1])
        last_msg["content"] = self.replace_example_content(last_msg.get("content", ""), exps_str)
        updated_req[-1] = last_msg

        return updated_req

    def replace_example_content(self, text: str, new_example_content: str) -> str:
        """Replace the text between the '# Example' and '# Instruction' markers with `new_example_content`."""

        return self.replace_content_between_markers(text, "# Example", "# Instruction", new_example_content)
|
||||
24
metagpt/exp_pool/context_builders/simple.py
Normal file
24
metagpt/exp_pool/context_builders/simple.py
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
"""Simple context builder."""
|
||||
|
||||
|
||||
from metagpt.exp_pool.context_builders.base import BaseContextBuilder
|
||||
|
||||
SIMPLE_CONTEXT_TEMPLATE = """
|
||||
{req}
|
||||
|
||||
### Experiences
|
||||
-----
|
||||
{exps}
|
||||
-----
|
||||
|
||||
## Instruction
|
||||
Consider **Experiences** to generate a better answer.
|
||||
"""
|
||||
|
||||
|
||||
class SimpleContextBuilder(BaseContextBuilder):
    """Context builder that appends the formatted experiences to the request text."""

    async def build(self, *args, **kwargs) -> str:
        """Return the `req` keyword argument, wrapped in `SIMPLE_CONTEXT_TEMPLATE` when experiences exist.

        When `format_exps()` yields an empty string, `req` (default "") is returned untouched.
        """

        request = kwargs.get("req", "")
        formatted_exps = self.format_exps()

        if not formatted_exps:
            return request

        return SIMPLE_CONTEXT_TEMPLATE.format(req=request, exps=formatted_exps)
|
||||
|
|
@ -2,18 +2,19 @@
|
|||
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
from typing import Any, Callable, Optional, TypeVar
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.exp_pool.context_builders import BaseContextBuilder, SimpleContextBuilder
|
||||
from metagpt.exp_pool.manager import ExperienceManager, exp_manager
|
||||
from metagpt.exp_pool.perfect_judges import BasePerfectJudge, SimplePerfectJudge
|
||||
from metagpt.exp_pool.schema import Experience, Metric, QueryType, Score
|
||||
from metagpt.exp_pool.scorers import ExperienceScorer, SimpleScorer
|
||||
from metagpt.exp_pool.scorers import BaseScorer, SimpleScorer
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.async_helper import NestAsyncio
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.reflection import get_class_name
|
||||
|
||||
ReturnType = TypeVar("ReturnType")
|
||||
|
||||
|
|
@ -21,42 +22,64 @@ ReturnType = TypeVar("ReturnType")
|
|||
def exp_cache(
|
||||
_func: Optional[Callable[..., ReturnType]] = None,
|
||||
query_type: QueryType = QueryType.SEMANTIC,
|
||||
scorer: Optional[ExperienceScorer] = None,
|
||||
manager: Optional[ExperienceManager] = None,
|
||||
pass_exps_to_func: bool = False,
|
||||
scorer: Optional[BaseScorer] = None,
|
||||
perfect_judge: Optional[BasePerfectJudge] = None,
|
||||
context_builder: Optional[BaseContextBuilder] = None,
|
||||
req_serialize: Optional[Callable[..., str]] = None,
|
||||
resp_serialize: Optional[Callable[..., str]] = None,
|
||||
resp_deserialize: Optional[Callable[[str], Any]] = None,
|
||||
tag: Optional[str] = None,
|
||||
):
|
||||
"""Decorator to get a perfect experience, otherwise, it executes the function, and create a new experience.
|
||||
|
||||
This can be applied to both synchronous and asynchronous functions.
|
||||
1. This can be applied to both synchronous and asynchronous functions.
|
||||
2. The function must have a `req` parameter, and it must be provided as a keyword argument.
|
||||
3. If `config.exp_pool.enable_read` is False, the decorator will just directly execute the function.
|
||||
|
||||
Args:
|
||||
_func: Just to make the decorator more flexible, for example, it can be used directly with @exp_cache by default, without the need for @exp_cache().
|
||||
query_type: The type of query to be used when fetching experiences.
|
||||
scorer: Evaluate experience. Default SimpleScorer.
|
||||
manager: How to fetch, evaluate and save experience, etc. Default exp_manager.
|
||||
pass_exps_to_func: To control whether imperfect experiences are passed to the function, if True, the func must have a parameter named 'exps'.
|
||||
manager: How to fetch, evaluate and save experience, etc. Default to `exp_manager`.
|
||||
scorer: Evaluate experience. Default to `SimpleScorer()`.
|
||||
perfect_judge: Determines if an experience is perfect. Defaults to `SimplePerfectJudge()`.
|
||||
context_builder: Build the context from exps and the function parameters. Default to `SimpleContextBuilder()`.
|
||||
req_serialize: Serializes the request for storage. Defaults to `lambda req: str(req)`.
|
||||
resp_serialize: Serializes the function's return value for storage. Defaults to `lambda resp: str(resp)`.
|
||||
resp_deserialize: Deserializes the stored response back to the function's return value. Defaults to `lambda resp: resp`.
|
||||
tag: An optional tag for the experience. Default to `ClassName.method_name` or `function_name`.
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[..., ReturnType]) -> Callable[..., ReturnType]:
|
||||
if not config.exp_pool.enable_read:
|
||||
return func
|
||||
|
||||
@functools.wraps(func)
|
||||
async def get_or_create(args: Any, kwargs: Any) -> ReturnType:
|
||||
logger.info("exp_cache is enabled.")
|
||||
handler = ExpCacheHandler(
|
||||
func=func,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
query_type=query_type,
|
||||
exp_manager=manager,
|
||||
exp_scorer=scorer,
|
||||
pass_exps_to_func=pass_exps_to_func,
|
||||
exp_perfect_judge=perfect_judge,
|
||||
context_builder=context_builder,
|
||||
req_serialize=req_serialize,
|
||||
resp_serialize=resp_serialize,
|
||||
resp_deserialize=resp_deserialize,
|
||||
tag=tag,
|
||||
)
|
||||
|
||||
await handler.fetch_experiences(query_type)
|
||||
if exp := handler.get_one_perfect_experience():
|
||||
await handler.fetch_experiences()
|
||||
if exp := await handler.get_one_perfect_exp():
|
||||
return exp
|
||||
|
||||
await handler.execute_function()
|
||||
await handler.process_experience()
|
||||
|
||||
return handler._result
|
||||
return handler._raw_resp
|
||||
|
||||
return ExpCacheHandler.choose_wrapper(func, get_or_create)
|
||||
|
||||
|
|
@ -69,39 +92,59 @@ class ExpCacheHandler(BaseModel):
|
|||
func: Callable
|
||||
args: Any
|
||||
kwargs: Any
|
||||
query_type: QueryType = QueryType.SEMANTIC
|
||||
exp_manager: Optional[ExperienceManager] = None
|
||||
exp_scorer: Optional[ExperienceScorer] = None
|
||||
pass_exps_to_func: bool = False
|
||||
exp_scorer: Optional[BaseScorer] = None
|
||||
exp_perfect_judge: Optional[BasePerfectJudge] = None
|
||||
context_builder: Optional[BaseContextBuilder] = None
|
||||
req_serialize: Optional[Callable[..., str]] = None
|
||||
resp_serialize: Optional[Callable[..., str]] = None
|
||||
resp_deserialize: Optional[Callable[[str], Any]] = None
|
||||
tag: Optional[str] = None
|
||||
|
||||
_exps: list[Experience] = None
|
||||
_result: Any = None
|
||||
_req: str = ""
|
||||
_resp: str = ""
|
||||
_raw_resp: Any = None
|
||||
_score: Score = None
|
||||
_req: str = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def initialize(self):
|
||||
if self.exp_manager is None:
|
||||
self.exp_manager = exp_manager
|
||||
self._validate_params()
|
||||
|
||||
if self.exp_scorer is None:
|
||||
self.exp_scorer = SimpleScorer()
|
||||
self.exp_manager = self.exp_manager or exp_manager
|
||||
self.exp_scorer = self.exp_scorer or SimpleScorer()
|
||||
self.exp_perfect_judge = self.exp_perfect_judge or SimplePerfectJudge()
|
||||
self.context_builder = self.context_builder or SimpleContextBuilder()
|
||||
self.req_serialize = self.req_serialize or (lambda resp: str(resp))
|
||||
self.resp_serialize = self.resp_serialize or (lambda resp: str(resp))
|
||||
self.resp_deserialize = self.resp_deserialize or (lambda resp: resp)
|
||||
self.tag = self.tag or self._generate_tag()
|
||||
|
||||
self._req = self.generate_req_identifier(self.func, *self.args, **self.kwargs)
|
||||
self._req = self.req_serialize(self.kwargs["req"])
|
||||
|
||||
return self
|
||||
|
||||
async def fetch_experiences(self, query_type: QueryType):
|
||||
async def fetch_experiences(self):
|
||||
"""Fetch experiences by query_type."""
|
||||
|
||||
self._exps = await self.exp_manager.query_exps(self._req, query_type=query_type)
|
||||
self._exps = await self.exp_manager.query_exps(self._req, query_type=self.query_type, tag=self.tag)
|
||||
|
||||
def get_one_perfect_experience(self) -> Optional[Experience]:
|
||||
"""Get a potentially perfect experience."""
|
||||
return self.exp_manager.extract_one_perfect_exp(self._exps)
|
||||
async def get_one_perfect_exp(self) -> Optional[Any]:
|
||||
"""Get a potentially perfect experience, and resolve resp."""
|
||||
|
||||
for exp in self._exps:
|
||||
if await self.exp_perfect_judge.is_perfect_exp(exp, self._req, *self.args, **self.kwargs):
|
||||
logger.info(f"Get one perfect experience: {exp.req[:20]}...")
|
||||
return self.resp_deserialize(exp.resp)
|
||||
|
||||
return None
|
||||
|
||||
async def execute_function(self):
|
||||
"""Execute the function, and save the result."""
|
||||
self._result = await self._execute_function()
|
||||
"""Execute the function, and save resp."""
|
||||
|
||||
self._raw_resp = await self._execute_function()
|
||||
self._resp = self.resp_serialize(self._raw_resp)
|
||||
|
||||
@handle_exception
|
||||
async def process_experience(self):
|
||||
|
|
@ -110,41 +153,21 @@ class ExpCacheHandler(BaseModel):
|
|||
Evaluates and saves experience.
|
||||
Use `handle_exception` to ensure robustness, do not stop subsequent operations.
|
||||
"""
|
||||
|
||||
await self.evaluate_experience()
|
||||
self.save_experience()
|
||||
|
||||
async def evaluate_experience(self):
|
||||
"""Evaluate the experience, and save the score."""
|
||||
|
||||
self._score = await self.exp_scorer.evaluate(self.func, self._result, self.args, self.kwargs)
|
||||
self._score = await self.exp_scorer.evaluate(self.func, self._resp, self.args, self.kwargs)
|
||||
|
||||
def save_experience(self):
|
||||
"""Save the new experience."""
|
||||
|
||||
exp = Experience(req=self._req, resp=self._result, metric=Metric(score=self._score))
|
||||
|
||||
exp = Experience(req=self._req, resp=self._resp, tag=self.tag, metric=Metric(score=self._score))
|
||||
self.exp_manager.create_exp(exp)
|
||||
|
||||
@classmethod
|
||||
def generate_req_identifier(cls, func, *args, **kwargs) -> str:
|
||||
"""Generate a unique request identifier for any given function and its arguments.
|
||||
|
||||
Serializing args and kwargs into JSON strings and replacing ',' with '~' and ':' with '!'.
|
||||
|
||||
Return Example:
|
||||
SimpleClass.test_method@[1~2]@{"c"!3}
|
||||
"""
|
||||
cls_name = get_class_name(func)
|
||||
func_name = f"{cls_name}.{func.__name__}" if cls_name else func.__name__
|
||||
|
||||
if cls_name and args and inspect.isfunction(func):
|
||||
args = args[1:]
|
||||
|
||||
args = cls._serialize_and_replace(args)
|
||||
kwargs = cls._serialize_and_replace(kwargs)
|
||||
|
||||
return f"{func_name}@{args}@{kwargs}"
|
||||
|
||||
@staticmethod
|
||||
def choose_wrapper(func, wrapped_func):
|
||||
"""Choose how to run wrapped_func based on whether the function is asynchronous."""
|
||||
|
|
@ -158,25 +181,31 @@ class ExpCacheHandler(BaseModel):
|
|||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
|
||||
@classmethod
|
||||
def _serialize_and_replace(cls, data):
|
||||
json_str = json.dumps(data)
|
||||
return json_str.replace(", ", "~").replace(": ", "!")
|
||||
def _validate_params(self):
|
||||
if "req" not in self.kwargs:
|
||||
raise ValueError("`req` must be provided as a keyword argument.")
|
||||
|
||||
def _generate_tag(self) -> str:
|
||||
"""Generates a tag for the self.func.
|
||||
|
||||
"ClassName.method_name" if the first argument is a class instance, otherwise just "function_name".
|
||||
"""
|
||||
|
||||
if self.args and hasattr(self.args[0], "__class__"):
|
||||
cls_name = type(self.args[0]).__name__
|
||||
return f"{cls_name}.{self.func.__name__}"
|
||||
|
||||
return self.func.__name__
|
||||
|
||||
async def _build_context(self) -> str:
|
||||
self.context_builder.exps = self._exps
|
||||
|
||||
return await self.context_builder.build(*self.args, **self.kwargs)
|
||||
|
||||
async def _execute_function(self):
|
||||
if self.pass_exps_to_func:
|
||||
return await self._execute_function_with_exps()
|
||||
self.kwargs["req"] = await self._build_context()
|
||||
|
||||
return await self._execute_function_without_exps()
|
||||
|
||||
async def _execute_function_without_exps(self):
|
||||
if asyncio.iscoroutinefunction(self.func):
|
||||
return await self.func(*self.args, **self.kwargs)
|
||||
|
||||
return self.func(*self.args, **self.kwargs)
|
||||
|
||||
async def _execute_function_with_exps(self):
|
||||
if asyncio.iscoroutinefunction(self.func):
|
||||
return await self.func(*self.args, **self.kwargs, exps=self._exps)
|
||||
|
||||
return self.func(*self.args, **self.kwargs, exps=self._exps)
|
||||
|
|
|
|||
|
|
@ -1,13 +1,22 @@
|
|||
"""Experience Manager."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
from metagpt.config2 import Config, config
|
||||
from metagpt.exp_pool.schema import MAX_SCORE, Experience, QueryType
|
||||
from metagpt.exp_pool.schema import (
|
||||
DEFAULT_COLLECTION_NAME,
|
||||
DEFAULT_SIMILARITY_TOP_K,
|
||||
EntryType,
|
||||
Experience,
|
||||
Metric,
|
||||
QueryType,
|
||||
Score,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig
|
||||
from metagpt.strategy.experience_retriever import ENGINEER_EXAMPLE, TL_EXAMPLE
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
|
||||
|
||||
|
|
@ -27,14 +36,33 @@ class ExperienceManager(BaseModel):
|
|||
@model_validator(mode="after")
|
||||
def initialize(self):
|
||||
if self.storage is None:
|
||||
self.storage = SimpleEngine.from_objs(
|
||||
retriever_configs=[
|
||||
ChromaRetrieverConfig(collection_name="experience_pool", persist_path=".chroma_exp_data")
|
||||
],
|
||||
ranker_configs=[LLMRankerConfig()],
|
||||
)
|
||||
retriever_configs = [
|
||||
ChromaRetrieverConfig(
|
||||
persist_path=self.config.exp_pool.persist_path,
|
||||
collection_name=DEFAULT_COLLECTION_NAME,
|
||||
similarity_top_k=DEFAULT_SIMILARITY_TOP_K,
|
||||
)
|
||||
]
|
||||
ranker_configs = [LLMRankerConfig()]
|
||||
|
||||
self.storage = SimpleEngine.from_objs(retriever_configs=retriever_configs, ranker_configs=ranker_configs)
|
||||
|
||||
self.init_exp_pool()
|
||||
|
||||
return self
|
||||
|
||||
@handle_exception
|
||||
def init_exp_pool(self):
|
||||
if not self.config.exp_pool.init_exp:
|
||||
return
|
||||
|
||||
if self._has_exps():
|
||||
return
|
||||
|
||||
self._init_teamleader_exps()
|
||||
self._init_engineer2_exps()
|
||||
logger.info("`init_exp_pool` done.")
|
||||
|
||||
@handle_exception
|
||||
def create_exp(self, exp: Experience):
|
||||
"""Adds an experience to the storage if writing is enabled.
|
||||
|
|
@ -74,39 +102,26 @@ class ExperienceManager(BaseModel):
|
|||
|
||||
return exps
|
||||
|
||||
def extract_one_perfect_exp(self, exps: list[Experience]) -> Optional[Experience]:
|
||||
"""Extracts the first 'perfect' experience from a list of experiences.
|
||||
def _has_exps(self) -> bool:
|
||||
vector_store: ChromaVectorStore = self.storage._retriever._vector_store
|
||||
|
||||
Args:
|
||||
exps (list[Experience]): The experiences to evaluate.
|
||||
return bool(vector_store._get(limit=1, where={}).ids)
|
||||
|
||||
Returns:
|
||||
Optional[Experience]: The first perfect experience if found, otherwise None.
|
||||
"""
|
||||
for exp in exps:
|
||||
if self.is_perfect_exp(exp):
|
||||
return exp
|
||||
def _init_exp(self, req: str, resp: str, tag: str, metric: Metric = None):
|
||||
exp = Experience(
|
||||
req=req,
|
||||
resp=resp,
|
||||
entry_type=EntryType.MANUAL,
|
||||
tag=tag,
|
||||
metric=metric or Metric(score=Score(val=9, reason="Manual")),
|
||||
)
|
||||
self.create_exp(exp)
|
||||
|
||||
return None
|
||||
def _init_teamleader_exps(self):
|
||||
self._init_exp(req=TL_EXAMPLE, resp=TL_EXAMPLE, tag="TeamLeader.llm_cached_aask")
|
||||
|
||||
@staticmethod
|
||||
def is_perfect_exp(exp: Experience) -> bool:
|
||||
"""Determines if an experience is considered 'perfect'.
|
||||
|
||||
Args:
|
||||
exp (Experience): The experience to evaluate.
|
||||
|
||||
Returns:
|
||||
bool: True if the experience is manually entered, otherwise False.
|
||||
"""
|
||||
if not exp:
|
||||
return False
|
||||
|
||||
# TODO: need more metrics
|
||||
if exp.metric and exp.metric.score.val == MAX_SCORE:
|
||||
return True
|
||||
|
||||
return False
|
||||
def _init_engineer2_exps(self):
|
||||
self._init_exp(req=ENGINEER_EXAMPLE, resp=ENGINEER_EXAMPLE, tag="Engineer2.llm_cached_aask")
|
||||
|
||||
|
||||
exp_manager = ExperienceManager()
|
||||
|
|
|
|||
6
metagpt/exp_pool/perfect_judges/__init__.py
Normal file
6
metagpt/exp_pool/perfect_judges/__init__.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
"""Perfect judges init."""
|
||||
|
||||
from metagpt.exp_pool.perfect_judges.base import BasePerfectJudge
|
||||
from metagpt.exp_pool.perfect_judges.simple import SimplePerfectJudge
|
||||
|
||||
__all__ = ["BasePerfectJudge", "SimplePerfectJudge"]
|
||||
20
metagpt/exp_pool/perfect_judges/base.py
Normal file
20
metagpt/exp_pool/perfect_judges/base.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
"""Base perfect judge."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from metagpt.exp_pool.schema import Experience
|
||||
|
||||
|
||||
class BasePerfectJudge(BaseModel, ABC):
    """Abstract interface for deciding whether a stored experience is 'perfect'.

    Concrete judges implement ``is_perfect_exp``.
    """

    # Allow fields whose types pydantic cannot validate natively.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    @abstractmethod
    async def is_perfect_exp(self, exp: Experience, serialized_req: str, *args, **kwargs) -> bool:
        """Determine whether the experience is perfect.

        Args:
            exp (Experience): The experience to evaluate.
            serialized_req (str): The serialized request to compare against the experience's request.
            *args: Extra positional arguments for concrete judges.
            **kwargs: Extra keyword arguments for concrete judges.

        Returns:
            bool: True if the experience is judged perfect, otherwise False.
        """
|
||||
27
metagpt/exp_pool/perfect_judges/simple.py
Normal file
27
metagpt/exp_pool/perfect_judges/simple.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
"""Simple perfect judge."""
|
||||
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from metagpt.exp_pool.perfect_judges.base import BasePerfectJudge
|
||||
from metagpt.exp_pool.schema import MAX_SCORE, Experience
|
||||
|
||||
|
||||
class SimplePerfectJudge(BasePerfectJudge):
    """Judge that accepts only an exact request match carrying the maximum score."""

    model_config = ConfigDict(arbitrary_types_allowed=True)

    async def is_perfect_exp(self, exp: Experience, serialized_req: str, *args, **kwargs) -> bool:
        """Decide whether ``exp`` is a perfect answer for ``serialized_req``.

        Args:
            exp (Experience): The experience to evaluate.
            serialized_req (str): The serialized request to compare against the experience's request.

        Returns:
            bool: True only when the experience has a metric with a score, its
            request equals ``serialized_req`` and its score value equals
            ``MAX_SCORE``; False otherwise.
        """
        metric = exp.metric
        score = metric.score if metric else None
        if not score:
            # No metric, or a metric without a score: cannot be perfect.
            return False

        same_request = serialized_req == exp.req
        return same_request and score.val == MAX_SCORE
|
||||
|
|
@ -1,13 +1,16 @@
|
|||
"""Experience schema."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
from llama_index.core.schema import TextNode
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
MAX_SCORE = 10
|
||||
|
||||
DEFAULT_COLLECTION_NAME = "experience_pool"
|
||||
DEFAULT_SIMILARITY_TOP_K = 2
|
||||
|
||||
|
||||
class QueryType(str, Enum):
|
||||
"""Type of query experiences."""
|
||||
|
|
@ -59,7 +62,7 @@ class Experience(BaseModel):
|
|||
"""Experience."""
|
||||
|
||||
req: str = Field(..., description="")
|
||||
resp: Any = Field(..., description="The type is string/json/code.")
|
||||
resp: str = Field(..., description="The type is string/json/code.")
|
||||
metric: Optional[Metric] = Field(default=None, description="Metric.")
|
||||
exp_type: ExperienceType = Field(default=ExperienceType.SUCCESS, description="The type of experience.")
|
||||
entry_type: EntryType = Field(default=EntryType.AUTOMATIC, description="Type of entry: Manual or Automatic.")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""Experience scorers init."""
|
||||
"""Scorers init."""
|
||||
|
||||
from metagpt.exp_pool.scorers.base import ExperienceScorer
|
||||
from metagpt.exp_pool.scorers.base import BaseScorer
|
||||
from metagpt.exp_pool.scorers.simple import SimpleScorer
|
||||
|
||||
__all__ = ["ExperienceScorer", "SimpleScorer"]
|
||||
__all__ = ["BaseScorer", "SimpleScorer"]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""Experience Scorers."""
|
||||
"""Base scorer."""
|
||||
|
||||
from abc import abstractmethod
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
|
@ -8,7 +8,7 @@ from pydantic import BaseModel, ConfigDict
|
|||
from metagpt.exp_pool.schema import Score
|
||||
|
||||
|
||||
class ExperienceScorer(BaseModel):
|
||||
class BaseScorer(BaseModel, ABC):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
"""Simple Scorer."""
|
||||
"""Simple scorer."""
|
||||
|
||||
import inspect
|
||||
import json
|
||||
|
|
@ -7,7 +7,7 @@ from typing import Any, Callable
|
|||
from pydantic import Field
|
||||
|
||||
from metagpt.exp_pool.schema import Score
|
||||
from metagpt.exp_pool.scorers.base import ExperienceScorer
|
||||
from metagpt.exp_pool.scorers.base import BaseScorer
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.utils.common import CodeParser
|
||||
|
|
@ -54,7 +54,7 @@ Follow instructions, generate output and make sure it follows the **Constraint**
|
|||
"""
|
||||
|
||||
|
||||
class SimpleScorer(ExperienceScorer):
|
||||
class SimpleScorer(BaseScorer):
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
async def evaluate(self, func: Callable, result: Any, args: tuple = None, kwargs: dict = None) -> Score:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import json
|
||||
import re
|
||||
|
|
@ -10,8 +11,14 @@ from pydantic import model_validator
|
|||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.actions.di.run_command import RunCommand
|
||||
from metagpt.exp_pool import exp_cache
|
||||
from metagpt.exp_pool.context_builders import RoleZeroContextBuilder
|
||||
from metagpt.logs import logger
|
||||
from metagpt.prompts.di.role_zero import CMD_PROMPT, ROLE_INSTRUCTION, JSON_REPAIR_PROMPT
|
||||
from metagpt.prompts.di.role_zero import (
|
||||
CMD_PROMPT,
|
||||
JSON_REPAIR_PROMPT,
|
||||
ROLE_INSTRUCTION,
|
||||
)
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import AIMessage, Message, UserMessage
|
||||
from metagpt.strategy.experience_retriever import DummyExpRetriever, ExpRetriever
|
||||
|
|
@ -21,8 +28,8 @@ from metagpt.tools.libs.editor import Editor
|
|||
from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.utils.common import CodeParser
|
||||
from metagpt.utils.repair_llm_raw_output import RepairType, repair_llm_raw_output
|
||||
from metagpt.utils.report import ThoughtReporter
|
||||
from metagpt.utils.repair_llm_raw_output import repair_llm_raw_output, RepairType
|
||||
|
||||
|
||||
@register_tool(include_functions=["ask_human", "reply_to_human"])
|
||||
|
|
@ -154,11 +161,37 @@ class RoleZero(Role):
|
|||
context = self.llm.format_msg(memory + [UserMessage(content=prompt)])
|
||||
# print(*context, sep="\n" + "*" * 5 + "\n")
|
||||
async with ThoughtReporter(enable_llm_stream=True):
|
||||
self.command_rsp = await self.llm.aask(context, system_msgs=self.system_msg)
|
||||
self.command_rsp = await self.llm_cached_aask(req=context, system_msgs=self.system_msg)
|
||||
self.rc.memory.add(AIMessage(content=self.command_rsp))
|
||||
|
||||
return True
|
||||
|
||||
# The lambda defers the ``RoleZero`` lookup to call time: the class name is not
# bound yet while its own body (and this decorator expression) is being evaluated.
@exp_cache(context_builder=RoleZeroContextBuilder(), req_serialize=lambda req: RoleZero._req_serialize(req))
async def llm_cached_aask(self, *, req: list[dict], system_msgs: list[str]) -> str:
    """Ask the LLM, routed through the ``exp_cache`` decorator.

    Args:
        req: Formatted chat messages forwarded to ``self.llm.aask``. Keyword-only;
            presumably ``exp_cache`` locates the request by the ``req`` keyword —
            confirm against the decorator.
        system_msgs: System messages forwarded to ``self.llm.aask``.

    Returns:
        str: The LLM response text.
    """
    llm_rsp = await self.llm.aask(req, system_msgs=system_msgs)
    return llm_rsp
|
||||
|
||||
@staticmethod
|
||||
def _req_serialize(req: list[dict]) -> str:
|
||||
"""Serialize the request for database storage, ensuring it is a string.
|
||||
|
||||
This function deep copies the request and modifies the content of the last element
|
||||
to remove unnecessary sections, making the request more concise.
|
||||
"""
|
||||
|
||||
req_copy = copy.deepcopy(req)
|
||||
|
||||
last_content = req_copy[-1]["content"]
|
||||
last_content = RoleZeroContextBuilder.replace_content_between_markers(
|
||||
last_content, "# Data Structure", "# Current Plan", ""
|
||||
)
|
||||
last_content = RoleZeroContextBuilder.replace_content_between_markers(
|
||||
last_content, "# Example", "# Instruction", ""
|
||||
)
|
||||
|
||||
req_copy[-1]["content"] = last_content
|
||||
|
||||
return json.dumps(req_copy)
|
||||
|
||||
async def _act(self) -> Message:
|
||||
if self.use_fixed_sop:
|
||||
return await super()._act()
|
||||
|
|
@ -166,7 +199,7 @@ class RoleZero(Role):
|
|||
try:
|
||||
commands = CodeParser.parse_code(block=None, lang="json", text=self.command_rsp)
|
||||
commands = json.loads(repair_llm_raw_output(output=commands, req_keys=[None], repair_type=RepairType.JSON))
|
||||
except json.JSONDecodeError as e:
|
||||
except json.JSONDecodeError:
|
||||
commands = await self.llm.aask(msg=JSON_REPAIR_PROMPT.format(json_data=self.command_rsp))
|
||||
commands = json.loads(CodeParser.parse_code(block=None, lang="json", text=commands))
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ class NaiveSolver(BaseSolver):
|
|||
self.graph.topological_sort()
|
||||
for key in self.graph.execution_order:
|
||||
op = self.graph.nodes[key]
|
||||
await op.fill(self.context, self.llm, mode="root")
|
||||
await op.fill(req=self.context, llm=self.llm, mode="root")
|
||||
|
||||
|
||||
class TOTSolver(BaseSolver):
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
"""class tools, including method inspection, class attributes, inheritance relationships, etc."""
|
||||
import inspect
|
||||
|
||||
|
||||
def check_methods(C, *methods):
|
||||
|
|
@ -17,25 +16,3 @@ def check_methods(C, *methods):
|
|||
else:
|
||||
return NotImplemented
|
||||
return True
|
||||
|
||||
|
||||
def get_class_name(func) -> str:
    """Return the name of the class that owns ``func``, or "" if none is found.

    - Bound methods and class methods: the name is taken from ``func.__self__``
      (the class itself for class methods, the instance's class otherwise).
    - Plain functions: the second-to-last component of ``__qualname__`` is
      used when it is a valid identifier (so ``<locals>`` is rejected).
    - Anything else yields an empty string.
    """
    if inspect.ismethod(func):
        owner = func.__self__
        cls = owner if inspect.isclass(owner) else owner.__class__
        return cls.__name__

    if not inspect.isfunction(func):
        return ""

    # Drop the function's own name; what remains is its enclosing scope chain.
    *scopes, _own_name = func.__qualname__.split(".")
    if scopes and scopes[-1].isidentifier():
        return scopes[-1]

    return ""
|
||||
|
|
|
|||
|
|
@ -91,10 +91,10 @@ async def test_action_node_two_layer():
|
|||
assert node_b in root.children.values()
|
||||
|
||||
# FIXME: ADD MARKDOWN SUPPORT. NEED TO TUNE MARKDOWN SYMBOL FIRST.
|
||||
answer1 = await root.fill(context="what's the answer to 123+456?", schema="json", strgy="simple", llm=LLM())
|
||||
answer1 = await root.fill(req="what's the answer to 123+456?", schema="json", strgy="simple", llm=LLM())
|
||||
assert "579" in answer1.content
|
||||
|
||||
answer2 = await root.fill(context="what's the answer to 123+456?", schema="json", strgy="complex", llm=LLM())
|
||||
answer2 = await root.fill(req="what's the answer to 123+456?", schema="json", strgy="complex", llm=LLM())
|
||||
assert "579" in answer2.content
|
||||
|
||||
|
||||
|
|
@ -112,7 +112,7 @@ async def test_action_node_review():
|
|||
with pytest.raises(RuntimeError):
|
||||
_ = await node_a.review()
|
||||
|
||||
_ = await node_a.fill(context=None, llm=LLM())
|
||||
_ = await node_a.fill(req=None, llm=LLM())
|
||||
setattr(node_a.instruct_content, key, "game snake") # wrong content to review
|
||||
|
||||
review_comments = await node_a.review(review_mode=ReviewMode.AUTO)
|
||||
|
|
@ -126,7 +126,7 @@ async def test_action_node_review():
|
|||
with pytest.raises(RuntimeError):
|
||||
_ = await node.review()
|
||||
|
||||
_ = await node.fill(context=None, llm=LLM())
|
||||
_ = await node.fill(req=None, llm=LLM())
|
||||
|
||||
review_comments = await node.review(review_mode=ReviewMode.AUTO)
|
||||
assert len(review_comments) == 1
|
||||
|
|
@ -151,7 +151,7 @@ async def test_action_node_revise():
|
|||
with pytest.raises(RuntimeError):
|
||||
_ = await node_a.review()
|
||||
|
||||
_ = await node_a.fill(context=None, llm=LLM())
|
||||
_ = await node_a.fill(req=None, llm=LLM())
|
||||
setattr(node_a.instruct_content, key, "game snake") # wrong content to revise
|
||||
revise_contents = await node_a.revise(revise_mode=ReviseMode.AUTO)
|
||||
assert len(revise_contents) == 1
|
||||
|
|
@ -164,7 +164,7 @@ async def test_action_node_revise():
|
|||
with pytest.raises(RuntimeError):
|
||||
_ = await node.revise()
|
||||
|
||||
_ = await node.fill(context=None, llm=LLM())
|
||||
_ = await node.fill(req=None, llm=LLM())
|
||||
setattr(node.instruct_content, key, "game snake")
|
||||
revise_contents = await node.revise(revise_mode=ReviseMode.AUTO)
|
||||
assert len(revise_contents) == 1
|
||||
|
|
@ -257,7 +257,7 @@ async def test_action_node_with_image(mocker):
|
|||
invoice_path = Path(__file__).parent.joinpath("..", "..", "data", "invoices", "invoice-2.png")
|
||||
img_base64 = encode_image(invoice_path)
|
||||
mocker.patch("metagpt.provider.openai_api.OpenAILLM._cons_kwargs", _cons_kwargs)
|
||||
node = await invoice.fill(context="", llm=LLM(), images=[img_base64])
|
||||
node = await invoice.fill(req="", llm=LLM(), images=[img_base64])
|
||||
assert node.instruct_content.invoice
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -38,7 +38,7 @@ async def test_write_design_an(mocker):
|
|||
mocker.patch("metagpt.actions.design_api_an.REFINED_DESIGN_NODE.fill", return_value=root)
|
||||
|
||||
prompt = NEW_REQ_TEMPLATE.format(old_design=DESIGN_SAMPLE, context=dict_to_markdown(REFINED_PRD_JSON))
|
||||
node = await REFINED_DESIGN_NODE.fill(prompt, llm)
|
||||
node = await REFINED_DESIGN_NODE.fill(req=prompt, llm=llm)
|
||||
|
||||
assert "Refined Implementation Approach" in node.instruct_content.model_dump()
|
||||
assert "Refined File list" in node.instruct_content.model_dump()
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ async def test_project_management_an(mocker):
|
|||
root.instruct_content.model_dump = mock_task_json
|
||||
mocker.patch("metagpt.actions.project_management_an.PM_NODE.fill", return_value=root)
|
||||
|
||||
node = await PM_NODE.fill(dict_to_markdown(REFINED_DESIGN_JSON), llm)
|
||||
node = await PM_NODE.fill(req=dict_to_markdown(REFINED_DESIGN_JSON), llm=llm)
|
||||
|
||||
assert "Logic Analysis" in node.instruct_content.model_dump()
|
||||
assert "Task list" in node.instruct_content.model_dump()
|
||||
|
|
@ -59,7 +59,7 @@ async def test_project_management_an_inc(mocker):
|
|||
mocker.patch("metagpt.actions.project_management_an.REFINED_PM_NODE.fill", return_value=root)
|
||||
|
||||
prompt = NEW_REQ_TEMPLATE.format(old_task=TASK_SAMPLE, context=dict_to_markdown(REFINED_DESIGN_JSON))
|
||||
node = await REFINED_PM_NODE.fill(prompt, llm)
|
||||
node = await REFINED_PM_NODE.fill(req=prompt, llm=llm)
|
||||
|
||||
assert "Refined Logic Analysis" in node.instruct_content.model_dump()
|
||||
assert "Refined Task list" in node.instruct_content.model_dump()
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ async def test_write_prd_an(mocker):
|
|||
requirements=NEW_REQUIREMENT_SAMPLE,
|
||||
old_prd=PRD_SAMPLE,
|
||||
)
|
||||
node = await REFINED_PRD_NODE.fill(prompt, llm)
|
||||
node = await REFINED_PRD_NODE.fill(req=prompt, llm=llm)
|
||||
|
||||
assert "Refined Requirements" in node.instruct_content.model_dump()
|
||||
assert "Refined Product Goals" in node.instruct_content.model_dump()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,45 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.exp_pool.context_builders.base import (
|
||||
EXP_TEMPLATE,
|
||||
BaseContextBuilder,
|
||||
Experience,
|
||||
)
|
||||
from metagpt.exp_pool.schema import Metric, Score
|
||||
|
||||
|
||||
class TestBaseContextBuilder:
    """Tests for the helpers exposed by ``BaseContextBuilder``."""

    class ConcreteContextBuilder(BaseContextBuilder):
        """Subclass with a no-op ``build`` so a builder instance can be created."""

        async def build(self, *args, **kwargs):
            return None

    @pytest.fixture
    def context_builder(self):
        return self.ConcreteContextBuilder()

    def test_format_exps(self, context_builder):
        """Experiences are rendered as a numbered list, one template per line."""
        samples = [("req1", "resp1", 8), ("req2", "resp2", 9)]
        context_builder.exps = [
            Experience(req=req, resp=resp, metric=Metric(score=Score(val=val))) for req, resp, val in samples
        ]

        expected_lines = [
            f"{idx}. {EXP_TEMPLATE.format(req=req, resp=resp, score=val)}"
            for idx, (req, resp, val) in enumerate(samples, start=1)
        ]

        assert context_builder.format_exps() == "\n".join(expected_lines)

    def test_replace_content_between_markers(self):
        """Text between the two markers is swapped for the new content."""
        original = "Start\n# Example\nOld content\n# Instruction\nEnd"

        replaced = BaseContextBuilder.replace_content_between_markers(
            original, "# Example", "# Instruction", "New content"
        )

        assert replaced == "Start\n# Example\nNew content\n\n# Instruction\nEnd"

    def test_replace_content_between_markers_no_match(self):
        """Text without the markers comes back unchanged."""
        original = "Start\nNo markers\nEnd"

        replaced = BaseContextBuilder.replace_content_between_markers(
            original, "# Example", "# Instruction", "New content"
        )

        assert replaced == original
|
||||
|
|
@ -0,0 +1,38 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.exp_pool.context_builders.base import BaseContextBuilder
|
||||
from metagpt.exp_pool.context_builders.role_zero import RoleZeroContextBuilder
|
||||
|
||||
|
||||
class TestRoleZeroContextBuilder:
    """Tests for ``RoleZeroContextBuilder.build`` and ``replace_example_content``."""

    @pytest.fixture
    def context_builder(self):
        return RoleZeroContextBuilder()

    @pytest.mark.asyncio
    async def test_build_empty_req(self, context_builder):
        """An empty request is returned as an empty list."""
        built = await context_builder.build(req=[])
        assert built == []

    @pytest.mark.asyncio
    async def test_build_no_experiences(self, context_builder, mocker):
        """With no formatted experiences the request passes through untouched."""
        mocker.patch.object(BaseContextBuilder, "format_exps", return_value="")
        messages = [{"content": "Original content"}]

        built = await context_builder.build(req=messages)

        assert built == messages

    @pytest.mark.asyncio
    async def test_build_with_experiences(self, context_builder, mocker):
        """With experiences the message content becomes the replaced text."""
        mocker.patch.object(BaseContextBuilder, "format_exps", return_value="Formatted experiences")
        mocker.patch.object(RoleZeroContextBuilder, "replace_example_content", return_value="Updated content")
        messages = [{"content": "Original content"}]

        built = await context_builder.build(req=messages)

        assert built == [{"content": "Updated content"}]

    def test_replace_example_content(self, context_builder, mocker):
        """The call is delegated with the "# Example"/"# Instruction" markers."""
        mocker.patch.object(BaseContextBuilder, "replace_content_between_markers", return_value="Replaced content")

        replaced = context_builder.replace_example_content("Original text", "New example content")

        assert replaced == "Replaced content"
        context_builder.replace_content_between_markers.assert_called_once_with(
            "Original text", "# Example", "# Instruction", "New example content"
        )
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.exp_pool.context_builders.base import BaseContextBuilder
|
||||
from metagpt.exp_pool.context_builders.simple import (
|
||||
SIMPLE_CONTEXT_TEMPLATE,
|
||||
SimpleContextBuilder,
|
||||
)
|
||||
|
||||
|
||||
class TestSimpleContextBuilder:
    """Tests for ``SimpleContextBuilder.build``."""

    @pytest.fixture
    def context_builder(self):
        return SimpleContextBuilder()

    @pytest.mark.asyncio
    async def test_build_with_experiences(self, context_builder, mocker):
        """Request and experiences are both rendered into the template."""
        formatted = "Mocked experiences"
        mocker.patch.object(BaseContextBuilder, "format_exps", return_value=formatted)
        request = "Test request"

        built = await context_builder.build(req=request)

        assert built == SIMPLE_CONTEXT_TEMPLATE.format(req=request, exps=formatted)

    @pytest.mark.asyncio
    async def test_build_without_experiences(self, context_builder, mocker):
        """With no formatted experiences the request is returned as-is."""
        mocker.patch.object(BaseContextBuilder, "format_exps", return_value="")
        request = "Test request"

        built = await context_builder.build(req=request)

        assert built == request

    @pytest.mark.asyncio
    async def test_build_without_req(self, context_builder, mocker):
        """Omitting ``req`` renders the template with an empty request."""
        formatted = "Mocked experiences"
        mocker.patch.object(BaseContextBuilder, "format_exps", return_value=formatted)

        built = await context_builder.build()

        assert built == SIMPLE_CONTEXT_TEMPLATE.format(req="", exps=formatted)
|
||||
|
|
@ -1,29 +1,17 @@
|
|||
import asyncio
|
||||
import inspect
|
||||
|
||||
import pytest
|
||||
|
||||
from metagpt.exp_pool.context_builders import SimpleContextBuilder
|
||||
from metagpt.exp_pool.decorator import ExpCacheHandler, exp_cache
|
||||
from metagpt.exp_pool.manager import ExperienceManager
|
||||
from metagpt.exp_pool.perfect_judges import SimplePerfectJudge
|
||||
from metagpt.exp_pool.schema import Experience, QueryType, Score
|
||||
from metagpt.exp_pool.scorers import SimpleScorer
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
|
||||
|
||||
def for_test_function(a, b, c=None):
    """Return ``a + b``, additionally adding ``c`` when it is not None."""
    if c is None:
        return a + b
    return a + b + c
|
||||
|
||||
|
||||
class ForTestClass:
    """Small fixture class exposing one instance method and one class method."""

    def for_test_method(self, x, y):
        """Return the product of ``x`` and ``y``."""
        product = x * y
        return product

    @classmethod
    def for_test_class_method(cls, x, y):
        """Return ``x`` raised to the power ``y``."""
        power = x**y
        return power
|
||||
|
||||
|
||||
class TestExpCache:
|
||||
class TestExpCacheHandler:
|
||||
@pytest.fixture
|
||||
def mock_func(self, mocker):
|
||||
return mocker.AsyncMock()
|
||||
|
|
@ -34,7 +22,6 @@ class TestExpCache:
|
|||
manager.storage = mocker.MagicMock(spec=SimpleEngine)
|
||||
manager.query_exps = mocker.AsyncMock()
|
||||
manager.create_exp = mocker.MagicMock()
|
||||
manager.extract_one_perfect_exp = mocker.MagicMock()
|
||||
return manager
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -44,174 +31,165 @@ class TestExpCache:
|
|||
return scorer
|
||||
|
||||
@pytest.fixture
|
||||
def exp_cache_handler(self, mock_func, mock_exp_manager, mock_scorer):
|
||||
def mock_perfect_judge(self, mocker):
|
||||
return mocker.MagicMock(spec=SimplePerfectJudge)
|
||||
|
||||
@pytest.fixture
|
||||
def mock_context_builder(self, mocker):
|
||||
return mocker.MagicMock(spec=SimpleContextBuilder)
|
||||
|
||||
@pytest.fixture
|
||||
def exp_cache_handler(self, mock_func, mock_exp_manager, mock_scorer, mock_perfect_judge, mock_context_builder):
|
||||
return ExpCacheHandler(
|
||||
func=mock_func, args=(), kwargs={}, exp_manager=mock_exp_manager, exp_scorer=mock_scorer, pass_exps=False
|
||||
func=mock_func,
|
||||
args=(),
|
||||
kwargs={"req": "test_req"},
|
||||
exp_manager=mock_exp_manager,
|
||||
exp_scorer=mock_scorer,
|
||||
exp_perfect_judge=mock_perfect_judge,
|
||||
context_builder=mock_context_builder,
|
||||
)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_experiences(self, exp_cache_handler, mock_exp_manager):
|
||||
await exp_cache_handler.fetch_experiences(QueryType.SEMANTIC)
|
||||
mock_exp_manager.query_exps.assert_called_once()
|
||||
mock_exp_manager.query_exps.return_value = [Experience(req="test_req", resp="test_resp")]
|
||||
await exp_cache_handler.fetch_experiences()
|
||||
mock_exp_manager.query_exps.assert_called_once_with(
|
||||
"test_req", query_type=QueryType.SEMANTIC, tag=exp_cache_handler.tag
|
||||
)
|
||||
assert len(exp_cache_handler._exps) == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_perfect_experience_found(self, exp_cache_handler, mock_exp_manager, mock_func):
|
||||
# Setup: Assume perfect experience is found
|
||||
perfect_exp = Experience(req="req", resp="resp")
|
||||
mock_exp_manager.extract_one_perfect_exp.return_value = perfect_exp
|
||||
|
||||
# Exec
|
||||
exp_cache_handler._exps = [perfect_exp] # Simulate fetched experiences
|
||||
result = exp_cache_handler.get_one_perfect_experience()
|
||||
|
||||
# Assert
|
||||
assert result.resp == "resp"
|
||||
mock_func.assert_not_called() # Function should not be called
|
||||
async def test_get_one_perfect_exp(self, exp_cache_handler, mock_perfect_judge):
|
||||
exp = Experience(req="test_req", resp="perfect_resp")
|
||||
exp_cache_handler._exps = [exp]
|
||||
mock_perfect_judge.is_perfect_exp.return_value = True
|
||||
result = await exp_cache_handler.get_one_perfect_exp()
|
||||
assert result == "perfect_resp"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_execute_function_when_no_perfect_exp(self, exp_cache_handler, mock_exp_manager, mock_func):
|
||||
# Setup: No perfect experience
|
||||
mock_exp_manager.extract_one_perfect_exp.return_value = None
|
||||
mock_func.return_value = "Computed result"
|
||||
|
||||
# Exec
|
||||
async def test_execute_function(self, exp_cache_handler, mock_func, mock_context_builder):
|
||||
mock_context_builder.build.return_value = "built_context"
|
||||
mock_func.return_value = "function_result"
|
||||
await exp_cache_handler.execute_function()
|
||||
|
||||
# Assert
|
||||
assert exp_cache_handler._result == "Computed result"
|
||||
mock_func.assert_called_once()
|
||||
mock_context_builder.build.assert_called_once()
|
||||
mock_func.assert_called_once_with(req="built_context")
|
||||
assert exp_cache_handler._raw_resp == "function_result"
|
||||
assert exp_cache_handler._resp == "function_result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_evaluate_and_save_experience(self, exp_cache_handler, mock_scorer, mock_exp_manager):
|
||||
# Setup
|
||||
mock_scorer.evaluate.return_value = Score(value=100)
|
||||
exp_cache_handler._result = "Computed result"
|
||||
|
||||
# Exec
|
||||
await exp_cache_handler.evaluate_experience()
|
||||
exp_cache_handler.save_experience()
|
||||
|
||||
# Assert
|
||||
async def test_process_experience(self, exp_cache_handler, mock_scorer, mock_exp_manager):
|
||||
exp_cache_handler._resp = "test_resp"
|
||||
mock_scorer.evaluate.return_value = Score(val=8)
|
||||
await exp_cache_handler.process_experience()
|
||||
mock_scorer.evaluate.assert_called_once()
|
||||
mock_exp_manager.create_exp.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_function_execution_with_exps(self, exp_cache_handler, mock_exp_manager, mock_func):
|
||||
# Setup
|
||||
exp_cache_handler.pass_exps_to_func = True
|
||||
mock_func.return_value = "Async result with exps"
|
||||
mock_exp_manager.extract_one_perfect_exp.return_value = None
|
||||
exp_cache_handler._exps = [Experience(req="req", resp="resp")]
|
||||
async def test_evaluate_experience(self, exp_cache_handler, mock_scorer):
|
||||
exp_cache_handler._resp = "test_resp"
|
||||
mock_scorer.evaluate.return_value = Score(val=9)
|
||||
await exp_cache_handler.evaluate_experience()
|
||||
assert exp_cache_handler._score.val == 9
|
||||
|
||||
# Exec
|
||||
await exp_cache_handler.execute_function()
|
||||
|
||||
# Assert
|
||||
mock_func.assert_called_once_with(exps=exp_cache_handler._exps)
|
||||
assert exp_cache_handler._result == "Async result with exps"
|
||||
|
||||
def test_sync_function_execution_with_exps(self, mocker, exp_cache_handler, mock_exp_manager, mock_func):
|
||||
# Setup
|
||||
exp_cache_handler.func = mocker.Mock(return_value="Sync result with exps")
|
||||
exp_cache_handler.pass_exps_to_func = True
|
||||
mock_exp_manager.extract_one_perfect_exp.return_value = None
|
||||
exp_cache_handler._exps = [Experience(req="req", resp="resp")]
|
||||
|
||||
# Exec
|
||||
asyncio.get_event_loop().run_until_complete(exp_cache_handler.execute_function())
|
||||
|
||||
# Assert
|
||||
exp_cache_handler.func.assert_called_once_with(exps=exp_cache_handler._exps)
|
||||
assert exp_cache_handler._result == "Sync result with exps"
|
||||
|
||||
def test_wrapper_selection_async(self, mocker, exp_cache_handler, mock_func):
|
||||
# Setup
|
||||
mock_func = mocker.AsyncMock()
|
||||
|
||||
# Exec
|
||||
wrapper = ExpCacheHandler.choose_wrapper(mock_func, exp_cache_handler.execute_function)
|
||||
|
||||
# Assert
|
||||
assert asyncio.iscoroutinefunction(wrapper), "Wrapper should be asynchronous"
|
||||
|
||||
def test_wrapper_selection_sync(self, exp_cache_handler, mocker):
|
||||
# Setup
|
||||
sync_func = mocker.Mock()
|
||||
|
||||
# Exec
|
||||
wrapper = ExpCacheHandler.choose_wrapper(sync_func, exp_cache_handler.execute_function)
|
||||
|
||||
# Assert
|
||||
assert not asyncio.iscoroutinefunction(wrapper), "Wrapper should be synchronous"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"func, args, kwargs, expected",
|
||||
[
|
||||
(for_test_function, (1, 2), {"c": 3}, 'for_test_function@[1~2]@{"c"!3}'),
|
||||
(ForTestClass().for_test_method, (4, 5), {}, "ForTestClass.for_test_method@[4~5]@{}"),
|
||||
(ForTestClass.for_test_class_method, (6, 7), {}, "ForTestClass.for_test_class_method@[6~7]@{}"),
|
||||
(for_test_function, (), {}, "for_test_function@[]@{}"),
|
||||
(
|
||||
for_test_function,
|
||||
("hello", [1, 2]),
|
||||
{"key": "value"},
|
||||
'for_test_function@["hello"~[1~2]]@{"key"!"value"}',
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_generate_req_identifier(self, func, args, kwargs, expected):
|
||||
req_identifier = ExpCacheHandler.generate_req_identifier(func, *args, **kwargs)
|
||||
assert req_identifier == expected
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exp_cache_with_perfect_experience(self, mocker, mock_exp_manager):
|
||||
# Mock perfect experience
|
||||
perfect_exp = Experience(req="test_req", resp="perfect_response")
|
||||
mock_exp_manager.query_exps = mocker.AsyncMock(return_value=[perfect_exp])
|
||||
mock_exp_manager.extract_one_perfect_exp = mocker.MagicMock(return_value=perfect_exp)
|
||||
async_mock_func = mocker.AsyncMock()
|
||||
|
||||
# Setup
|
||||
decorated_func = exp_cache(async_mock_func, manager=mock_exp_manager)
|
||||
|
||||
# Exec
|
||||
result: Experience = await decorated_func()
|
||||
|
||||
# Assert
|
||||
assert result.resp == "perfect_response", "Should return the perfect experience response"
|
||||
async_mock_func.assert_not_called()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exp_cache_without_perfect_experience(self, mocker, mock_exp_manager):
|
||||
# Mock
|
||||
mock_exp_manager.query_exps = mocker.AsyncMock(return_value=[])
|
||||
mock_exp_manager.extract_one_perfect_exp = mocker.MagicMock(return_value=None)
|
||||
async_mock_func = mocker.AsyncMock(return_value="computed_response")
|
||||
async_mock_func.__signature__ = inspect.signature(for_test_function)
|
||||
|
||||
# Setup
|
||||
decorated_func = exp_cache(async_mock_func, manager=mock_exp_manager)
|
||||
|
||||
# Exec
|
||||
result = await decorated_func()
|
||||
|
||||
# Assert
|
||||
assert result == "computed_response", "Should execute and return the function's response"
|
||||
async_mock_func.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_exp_cache_saves_new_experience(self, mocker, mock_exp_manager, mock_scorer):
|
||||
# Mock
|
||||
mock_exp_manager.query_exps = mocker.AsyncMock(return_value=[])
|
||||
mock_exp_manager.extract_one_perfect_exp = mocker.MagicMock(return_value=None)
|
||||
async_mock_func = mocker.AsyncMock(return_value="computed_response")
|
||||
mock_scorer.evaluate = mocker.AsyncMock(return_value=Score(value=100))
|
||||
|
||||
# Setup
|
||||
decorated_func = exp_cache(async_mock_func, manager=mock_exp_manager, scorer=mock_scorer)
|
||||
|
||||
# Exec
|
||||
await decorated_func()
|
||||
|
||||
# Assert
|
||||
def test_save_experience(self, exp_cache_handler, mock_exp_manager):
|
||||
exp_cache_handler._req = "test_req"
|
||||
exp_cache_handler._resp = "test_resp"
|
||||
exp_cache_handler._score = Score(val=7)
|
||||
exp_cache_handler.save_experience()
|
||||
mock_exp_manager.create_exp.assert_called_once()
|
||||
|
||||
def test_choose_wrapper_async(self, mocker):
|
||||
async def async_func():
|
||||
pass
|
||||
|
||||
wrapper = ExpCacheHandler.choose_wrapper(async_func, mocker.AsyncMock())
|
||||
assert asyncio.iscoroutinefunction(wrapper)
|
||||
|
||||
def test_choose_wrapper_sync(self, mocker):
|
||||
def sync_func():
|
||||
pass
|
||||
|
||||
wrapper = ExpCacheHandler.choose_wrapper(sync_func, mocker.AsyncMock())
|
||||
assert not asyncio.iscoroutinefunction(wrapper)
|
||||
|
||||
def test_validate_params(self):
    """Building a handler with empty `args` and `kwargs` is expected to raise `ValueError`."""
    init_kwargs = {"func": lambda x: x, "args": (), "kwargs": {}}

    with pytest.raises(ValueError):
        ExpCacheHandler(**init_kwargs)
|
||||
|
||||
def test_generate_tag(self):
    """`_generate_tag` yields `Class.method` for a bound method and `<lambda>` for a lambda."""

    class TestClass:
        def test_method(self):
            pass

    # Case 1: bound method, with an instance supplied as the first positional arg.
    method_handler = ExpCacheHandler(
        func=TestClass().test_method,
        args=(TestClass(),),
        kwargs={"req": "test"},
    )
    assert method_handler._generate_tag() == "TestClass.test_method"

    # Case 2: anonymous function without positional args.
    lambda_handler = ExpCacheHandler(func=lambda x: x, args=(), kwargs={"req": "test"})
    assert lambda_handler._generate_tag() == "<lambda>"
|
||||
|
||||
|
||||
class TestExpCache:
    """Tests for the `exp_cache` decorator, run against mocked manager, scorer, judge and config."""

    @pytest.fixture
    def mock_exp_manager(self, mocker):
        """Return an `ExperienceManager`-spec'd mock with storage, query and create stubbed."""
        fake_manager = mocker.MagicMock(spec=ExperienceManager)
        fake_manager.configure_mock(
            storage=mocker.MagicMock(spec=SimpleEngine),
            query_exps=mocker.AsyncMock(),
            create_exp=mocker.MagicMock(),
        )
        return fake_manager

    @pytest.fixture
    def mock_scorer(self, mocker):
        """Return a `SimpleScorer`-spec'd mock whose async `evaluate` yields a default `Score`."""
        fake_scorer = mocker.MagicMock(spec=SimpleScorer)
        fake_scorer.evaluate = mocker.AsyncMock(return_value=Score())
        return fake_scorer

    @pytest.fixture
    def mock_perfect_judge(self, mocker):
        """Return a `SimplePerfectJudge`-spec'd mock."""
        fake_judge = mocker.MagicMock(spec=SimplePerfectJudge)
        return fake_judge

    @pytest.fixture
    def mock_config(self, mocker):
        """Patch the `config` object referenced by the decorator module and return the patch mock."""
        patched_config = mocker.patch("metagpt.exp_pool.decorator.config")
        return patched_config

    @pytest.mark.asyncio
    async def test_exp_cache_disabled(self, mock_config, mock_exp_manager):
        """With `enable_read` off, the function's own result is returned and no query is made."""
        mock_config.exp_pool.enable_read = False

        @exp_cache(manager=mock_exp_manager)
        async def test_func(req):
            return "result"

        outcome = await test_func(req="test")

        assert outcome == "result"
        mock_exp_manager.query_exps.assert_not_called()

    @pytest.mark.asyncio
    async def test_exp_cache_enabled_no_perfect_exp(self, mock_config, mock_exp_manager, mock_scorer):
        """With `enable_read` on and no stored experiences, the result is computed and saved."""
        mock_config.exp_pool.enable_read = True
        mock_exp_manager.query_exps.return_value = []

        @exp_cache(manager=mock_exp_manager, scorer=mock_scorer)
        async def test_func(req):
            return "computed_result"

        outcome = await test_func(req="test")

        assert outcome == "computed_result"
        mock_exp_manager.query_exps.assert_called()
        mock_exp_manager.create_exp.assert_called()

    @pytest.mark.asyncio
    async def test_exp_cache_enabled_with_perfect_exp(self, mock_config, mock_exp_manager, mock_perfect_judge):
        """When the judge accepts a stored experience, its response is returned and nothing is saved."""
        mock_config.exp_pool.enable_read = True
        stored_exp = Experience(req="test", resp="perfect_result")
        mock_exp_manager.query_exps.return_value = [stored_exp]
        mock_perfect_judge.is_perfect_exp.return_value = True

        @exp_cache(manager=mock_exp_manager, perfect_judge=mock_perfect_judge)
        async def test_func(req):
            return "should_not_be_called"

        outcome = await test_func(req="test")

        assert outcome == "perfect_result"
        mock_exp_manager.query_exps.assert_called_once()
        mock_exp_manager.create_exp.assert_not_called()
|
||||
|
|
|
|||
|
|
@ -4,20 +4,25 @@ from metagpt.config2 import Config
|
|||
from metagpt.configs.exp_pool_config import ExperiencePoolConfig
|
||||
from metagpt.configs.llm_config import LLMConfig
|
||||
from metagpt.exp_pool.manager import ExperienceManager
|
||||
from metagpt.exp_pool.schema import MAX_SCORE, Experience, Metric, Score
|
||||
from metagpt.exp_pool.schema import Experience
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
|
||||
|
||||
class TestExperienceManager:
|
||||
@pytest.fixture
|
||||
def mock_config(self):
|
||||
return Config(llm=LLMConfig(), exp_pool=ExperiencePoolConfig(enable_write=True, enable_read=True))
|
||||
return Config(
|
||||
llm=LLMConfig(), exp_pool=ExperiencePoolConfig(enable_write=True, enable_read=True, init_exp=False)
|
||||
)
|
||||
|
||||
@pytest.fixture
def mock_storage(self, mocker):
    """Return a `SimpleEngine`-spec'd mock whose retriever's vector store `_get` reports no ids."""
    # Build the nested retriever/vector-store chain from the innermost mock outwards.
    empty_lookup = mocker.MagicMock(ids=[])
    vector_store = mocker.MagicMock()
    vector_store._get = mocker.MagicMock(return_value=empty_lookup)

    retriever = mocker.MagicMock()
    retriever._vector_store = vector_store

    fake_engine = mocker.MagicMock(spec=SimpleEngine)
    fake_engine.add_objs = mocker.MagicMock()
    fake_engine.aretrieve = mocker.AsyncMock(return_value=[])
    fake_engine._retriever = retriever
    return fake_engine
|
||||
|
||||
@pytest.fixture
|
||||
|
|
@ -33,7 +38,7 @@ class TestExperienceManager:
|
|||
|
||||
def test_create_exp(self, mock_experience_manager, mock_experience):
|
||||
mock_experience_manager.create_exp(mock_experience)
|
||||
mock_experience_manager.storage.add_objs.assert_called_once_with([mock_experience])
|
||||
mock_experience_manager.storage.add_objs.assert_called_with([mock_experience])
|
||||
|
||||
def test_create_exp_write_disabled(self, mock_experience_manager, mock_experience, mock_config):
|
||||
mock_config.exp_pool.enable_write = False
|
||||
|
|
@ -60,18 +65,44 @@ class TestExperienceManager:
|
|||
result = await mock_experience_manager.query_exps("query")
|
||||
assert result == []
|
||||
|
||||
def test_extract_one_perfect_exp(self, mock_experience_manager):
|
||||
experiences = [
|
||||
Experience(req="req", resp="resp", metric=Metric(score=Score(val=MAX_SCORE))),
|
||||
Experience(req="req", resp="resp"),
|
||||
]
|
||||
perfect_exp: Experience = mock_experience_manager.extract_one_perfect_exp(experiences)
|
||||
assert perfect_exp is not None
|
||||
assert perfect_exp.metric.score.val == MAX_SCORE
|
||||
def test_init_exp_pool(self, mock_experience_manager, mock_config, mocker):
|
||||
mock_experience_manager._has_exps = mocker.MagicMock(return_value=False)
|
||||
mock_experience_manager._init_teamleader_exps = mocker.MagicMock()
|
||||
mock_experience_manager._init_engineer2_exps = mocker.MagicMock()
|
||||
|
||||
def test_is_perfect_exp(self):
|
||||
exp = Experience(req="req", resp="resp", metric=Metric(score=Score(val=MAX_SCORE)))
|
||||
assert ExperienceManager.is_perfect_exp(exp) == True
|
||||
mock_config.exp_pool.init_exp = True
|
||||
mock_experience_manager.init_exp_pool()
|
||||
|
||||
exp = Experience(req="req", resp="resp")
|
||||
assert ExperienceManager.is_perfect_exp(exp) == False
|
||||
mock_experience_manager._has_exps.assert_called_once()
|
||||
mock_experience_manager._init_teamleader_exps.assert_called_once()
|
||||
mock_experience_manager._init_engineer2_exps.assert_called_once()
|
||||
|
||||
def test_init_exp_pool_already_has_exps(self, mock_experience_manager, mock_config, mocker):
    """With `init_exp` on and `_has_exps` reporting True, neither role initializer is called."""
    manager = mock_experience_manager
    initializers = ("_init_teamleader_exps", "_init_engineer2_exps")

    manager._has_exps = mocker.MagicMock(return_value=True)
    for initializer in initializers:
        setattr(manager, initializer, mocker.MagicMock())

    mock_config.exp_pool.init_exp = True
    manager.init_exp_pool()

    manager._has_exps.assert_called_once()
    for initializer in initializers:
        getattr(manager, initializer).assert_not_called()
|
||||
|
||||
def test_has_exps(self, mock_experience_manager, mock_storage):
    """`_has_exps` is True when the vector store lookup reports ids and False when it reports none."""
    lookup_result = mock_storage._retriever._vector_store._get.return_value

    for stored_ids, expected in ((["id1"], True), ([], False)):
        lookup_result.ids = stored_ids
        assert mock_experience_manager._has_exps() is expected
|
||||
|
||||
def test_init_teamleader_exps(self, mock_experience_manager, mocker):
    """`_init_teamleader_exps` is expected to call `_init_exp` exactly once."""
    manager = mock_experience_manager
    manager._init_exp = mocker.MagicMock()

    manager._init_teamleader_exps()

    manager._init_exp.assert_called_once()
|
||||
|
||||
def test_init_engineer2_exps(self, mock_experience_manager, mocker):
    """`_init_engineer2_exps` is expected to call `_init_exp` exactly once."""
    manager = mock_experience_manager
    manager._init_exp = mocker.MagicMock()

    manager._init_engineer2_exps()

    manager._init_exp.assert_called_once()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,40 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.exp_pool.perfect_judges import SimplePerfectJudge
|
||||
from metagpt.exp_pool.schema import MAX_SCORE, Experience, Metric, Score
|
||||
|
||||
|
||||
class TestSimplePerfectJudge:
    """Tests for `SimplePerfectJudge.is_perfect_exp` across score, request and metric variations."""

    @pytest.fixture
    def simple_perfect_judge(self):
        """Return a fresh `SimplePerfectJudge`."""
        judge = SimplePerfectJudge()
        return judge

    @pytest.mark.asyncio
    async def test_is_perfect_exp_perfect_match(self, simple_perfect_judge):
        """Maximum score with an identical request is judged perfect."""
        top_metric = Metric(score=Score(val=MAX_SCORE))
        experience = Experience(req="test_request", resp="resp", metric=top_metric)

        verdict = await simple_perfect_judge.is_perfect_exp(experience, "test_request")

        assert verdict is True

    @pytest.mark.asyncio
    async def test_is_perfect_exp_imperfect_score(self, simple_perfect_judge):
        """A score one below the maximum is not judged perfect."""
        near_top_metric = Metric(score=Score(val=MAX_SCORE - 1))
        experience = Experience(req="test_request", resp="resp", metric=near_top_metric)

        verdict = await simple_perfect_judge.is_perfect_exp(experience, "test_request")

        assert verdict is False

    @pytest.mark.asyncio
    async def test_is_perfect_exp_mismatched_request(self, simple_perfect_judge):
        """Maximum score with a different request is not judged perfect."""
        top_metric = Metric(score=Score(val=MAX_SCORE))
        experience = Experience(req="test_request", resp="resp", metric=top_metric)

        verdict = await simple_perfect_judge.is_perfect_exp(experience, "different_request")

        assert verdict is False

    @pytest.mark.asyncio
    async def test_is_perfect_exp_no_metric(self, simple_perfect_judge):
        """An experience built without a metric is not judged perfect."""
        experience = Experience(req="test_request", resp="resp")

        verdict = await simple_perfect_judge.is_perfect_exp(experience, "test_request")

        assert verdict is False

    @pytest.mark.asyncio
    async def test_is_perfect_exp_no_score(self, simple_perfect_judge):
        """An experience whose metric is built without a score is not judged perfect."""
        experience = Experience(req="test_request", resp="resp", metric=Metric())

        verdict = await simple_perfect_judge.is_perfect_exp(experience, "test_request")

        assert verdict is False
|
||||
49
tests/metagpt/exp_pool/test_scorers/test_simple_scorer.py
Normal file
49
tests/metagpt/exp_pool/test_scorers/test_simple_scorer.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
import pytest
|
||||
|
||||
from metagpt.exp_pool.schema import Score
|
||||
from metagpt.exp_pool.scorers.simple import SIMPLE_SCORER_TEMPLATE, SimpleScorer
|
||||
from metagpt.llm import BaseLLM
|
||||
|
||||
|
||||
class TestSimpleScorer:
    """Tests for `SimpleScorer` construction and its `evaluate` prompt/response handling."""

    @pytest.fixture
    def mock_llm(self, mocker):
        """Return a `BaseLLM`-spec'd mock."""
        return mocker.MagicMock(spec=BaseLLM)

    @pytest.fixture
    def simple_scorer(self, mock_llm):
        """Return a `SimpleScorer` wired to the mocked LLM."""
        scorer = SimpleScorer(llm=mock_llm)
        return scorer

    def test_init(self, mock_llm):
        """The scorer exposes the LLM it was constructed with as a `BaseLLM`."""
        built = SimpleScorer(llm=mock_llm)
        assert isinstance(built.llm, BaseLLM)

    @pytest.mark.asyncio
    async def test_evaluate(self, simple_scorer, mock_llm):
        """`evaluate` sends the formatted template to the LLM and parses the JSON reply into a `Score`."""

        # Function under evaluation
        def mock_func(a, b):
            """This is a mock function."""
            return a + b

        # Canned LLM reply: a fenced JSON payload
        mock_llm.aask.return_value = '```json\n{"val": 8, "reason": "Good performance"}\n```'

        # Run the scorer
        score = await simple_scorer.evaluate(mock_func, 5, args=(2, 3), kwargs={})

        # The LLM must have been asked exactly once with the fully rendered template
        prompt_fields = {
            "func_name": mock_func.__name__,
            "func_doc": mock_func.__doc__,
            "func_signature": "(a, b)",
            "func_args": (2, 3),
            "func_kwargs": {},
            "func_result": 5,
        }
        expected_prompt = SIMPLE_SCORER_TEMPLATE.format(**prompt_fields)
        mock_llm.aask.assert_called_once_with(expected_prompt)

        # The reply must have been parsed into a Score
        assert isinstance(score, Score)
        assert score.val == 8
        assert score.reason == "Good performance"
|
||||
Loading…
Add table
Add a link
Reference in a new issue