use llm cache to make exp_pool

This commit is contained in:
seehi 2024-07-08 10:09:36 +08:00
parent d902a6f18c
commit c624c0ffc7
41 changed files with 844 additions and 368 deletions

View file

@ -75,8 +75,10 @@ s3:
bucket: "test"
exp_pool:
enable_read: true
enable_write: true
enable_read: false
enable_write: false
persist_path: .chroma_exp_data # The directory.
init_exp: false # If set to true, basic experiences associated with the roles will be added to the experience pool.
azure_tts_subscription_key: "YOUR_SUBSCRIPTION_KEY"
azure_tts_region: "eastus"

View file

@ -7,16 +7,15 @@ from metagpt.exp_pool import exp_cache, exp_manager
from metagpt.logs import logger
@exp_cache(pass_exps_to_func=True)
async def produce(req, exps=None):
logger.info(f"Previous experiences: {exps}")
@exp_cache()
async def produce(req=""):
return f"{req} {uuid.uuid4().hex}"
async def main():
req = "Water"
resp = await produce(req)
resp = await produce(req=req)
logger.info(f"The resp of `produce{req}` is: {resp}")
exps = await exp_manager.query_exps(req)

View file

@ -50,9 +50,9 @@ async def generate_novel():
"Fill the empty nodes with your own ideas. Be creative! Use your own words!"
"I will tip you $100,000 if you write a good novel."
)
novel_node = await ActionNode.from_pydantic(Novel).fill(context=instruction, llm=LLM())
novel_node = await ActionNode.from_pydantic(Novel).fill(req=instruction, llm=LLM())
chap_node = await ActionNode.from_pydantic(Chapters).fill(
context=f"### instruction\n{instruction}\n### novel\n{novel_node.content}", llm=LLM()
req=f"### instruction\n{instruction}\n### novel\n{novel_node.content}", llm=LLM()
)
print(chap_node.instruct_content)

View file

@ -90,7 +90,7 @@ class Action(SerializationMixin, ContextMixin, BaseModel):
msgs = args[0]
context = "## History Messages\n"
context += "\n".join([f"{idx}: {i}" for idx, i in enumerate(reversed(msgs))])
return await self.node.fill(context=context, llm=self.llm)
return await self.node.fill(req=context, llm=self.llm)
async def run(self, *args, **kwargs):
"""Run action"""

View file

@ -18,6 +18,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.actions.action_outcls_registry import register_action_outcls
from metagpt.const import MARKDOWN_TITLE_PREFIX, USE_CONFIG_TIMEOUT
from metagpt.exp_pool import exp_cache
from metagpt.llm import BaseLLM
from metagpt.logs import logger
from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
@ -465,9 +466,33 @@ class ActionNode:
return self
@classmethod
def deserialize_to_action_node(cls, serialized_data) -> "ActionNode":
"""Customized deserialization, it will be triggered when a perfect experience is found.
ActionNode cannot be serialized, it throws an error 'cannot pickle 'SSLContext' object'.
"""
class InstructContent:
def __init__(self, json_data):
self.json_data = json_data
def model_dump_json(self):
return self.json_data
action_node = cls(key="", expected_type=Type[str], instruction="", example="")
action_node.instruct_content = InstructContent(serialized_data)
return action_node
@exp_cache(
resp_serialize=lambda action_node: action_node.instruct_content.model_dump_json(),
resp_deserialize=lambda resp: ActionNode.deserialize_to_action_node(resp),
)
async def fill(
self,
context,
*,
req,
llm,
schema="json",
mode="auto",
@ -478,7 +503,7 @@ class ActionNode:
):
"""Fill the node(s) with mode.
:param context: Everything we should know when filling node.
:param req: Everything we should know when filling node.
:param llm: Large Language Model with pre-defined system message.
:param schema: json/markdown, determine example and output format.
- raw: free form text
@ -497,7 +522,7 @@ class ActionNode:
:return: self
"""
self.set_llm(llm)
self.set_context(context)
self.set_context(req)
if self.schema:
schema = self.schema

View file

@ -178,12 +178,12 @@ class WriteDesign(Action):
)
async def _new_system_design(self, context):
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=self.prompt_schema)
node = await DESIGN_API_NODE.fill(req=context, llm=self.llm, schema=self.prompt_schema)
return node
async def _merge(self, prd_doc, system_design_doc):
context = NEW_REQ_TEMPLATE.format(old_design=system_design_doc.content, context=prd_doc.content)
node = await REFINED_DESIGN_NODE.fill(context=context, llm=self.llm, schema=self.prompt_schema)
node = await REFINED_DESIGN_NODE.fill(req=context, llm=self.llm, schema=self.prompt_schema)
system_design_doc.content = node.instruct_content.model_dump_json()
return system_design_doc

View file

@ -22,4 +22,4 @@ class GenerateQuestions(Action):
name: str = "GenerateQuestions"
async def run(self, context) -> ActionNode:
return await QUESTIONS.fill(context=context, llm=self.llm)
return await QUESTIONS.fill(req=context, llm=self.llm)

View file

@ -22,4 +22,4 @@ class PrepareInterview(Action):
name: str = "PrepareInterview"
async def run(self, context):
return await QUESTIONS.fill(context=context, llm=self.llm)
return await QUESTIONS.fill(req=context, llm=self.llm)

View file

@ -151,12 +151,12 @@ class WriteTasks(Action):
return task_doc
async def _run_new_tasks(self, context: str):
node = await PM_NODE.fill(context, self.llm, schema=self.prompt_schema)
node = await PM_NODE.fill(req=context, llm=self.llm, schema=self.prompt_schema)
return node
async def _merge(self, system_design_doc, task_doc) -> Document:
context = NEW_REQ_TEMPLATE.format(context=system_design_doc.content, old_task=task_doc.content)
node = await REFINED_PM_NODE.fill(context, self.llm, schema=self.prompt_schema)
node = await REFINED_PM_NODE.fill(req=context, llm=self.llm, schema=self.prompt_schema)
task_doc.content = node.instruct_content.model_dump_json()
return task_doc

View file

@ -578,7 +578,7 @@ class WriteCodeAN(Action):
async def run(self, context):
self.llm.system_prompt = "You are an outstanding engineer and can implement any code"
return await WRITE_MOVE_NODE.fill(context=context, llm=self.llm, schema="json")
return await WRITE_MOVE_NODE.fill(req=context, llm=self.llm, schema="json")
async def main():

View file

@ -229,7 +229,7 @@ class WriteCodePlanAndChange(Action):
code=await self.get_old_codes(),
)
logger.info("Writing code plan and change..")
return await WRITE_CODE_PLAN_AND_CHANGE_NODE.fill(context=context, llm=self.llm, schema="json")
return await WRITE_CODE_PLAN_AND_CHANGE_NODE.fill(req=context, llm=self.llm, schema="json")
async def get_old_codes(self) -> str:
old_codes = await self.repo.srcs.get_all()

View file

@ -211,7 +211,7 @@ class WritePRD(Action):
context = CONTEXT_TEMPLATE.format(requirements=requirement, project_name=project_name)
exclude = [PROJECT_NAME.key] if project_name else []
node = await WRITE_PRD_NODE.fill(
context=context, llm=self.llm, exclude=exclude, schema=self.prompt_schema
req=context, llm=self.llm, exclude=exclude, schema=self.prompt_schema
) # schema=schema
return node
@ -238,7 +238,7 @@ class WritePRD(Action):
async def _is_bugfix(self, context: str) -> bool:
if not self.repo.code_files_exists():
return False
node = await WP_ISSUE_TYPE_NODE.fill(context, self.llm)
node = await WP_ISSUE_TYPE_NODE.fill(req=context, llm=self.llm)
return node.get("issue_type") == "BUG"
async def get_related_docs(self, req: Document, docs: list[Document]) -> list[Document]:
@ -248,14 +248,14 @@ class WritePRD(Action):
async def _is_related(self, req: Document, old_prd: Document) -> bool:
context = NEW_REQ_TEMPLATE.format(old_prd=old_prd.content, requirements=req.content)
node = await WP_IS_RELATIVE_NODE.fill(context, self.llm)
node = await WP_IS_RELATIVE_NODE.fill(req=context, llm=self.llm)
return node.get("is_relative") == "YES"
async def _merge(self, req: Document, related_doc: Document) -> Document:
if not self.project_name:
self.project_name = Path(self.project_path).name
prompt = NEW_REQ_TEMPLATE.format(requirements=req.content, old_prd=related_doc.content)
node = await REFINED_PRD_NODE.fill(context=prompt, llm=self.llm, schema=self.prompt_schema)
node = await REFINED_PRD_NODE.fill(req=prompt, llm=self.llm, schema=self.prompt_schema)
related_doc.content = node.instruct_content.model_dump_json()
await self._rename_workspace(node)
return related_doc

View file

@ -36,4 +36,4 @@ class WriteReview(Action):
name: str = "WriteReview"
async def run(self, context):
return await WRITE_REVIEW_NODE.fill(context=context, llm=self.llm, schema="json")
return await WRITE_REVIEW_NODE.fill(req=context, llm=self.llm, schema="json")

View file

@ -4,5 +4,9 @@ from metagpt.utils.yaml_model import YamlModel
class ExperiencePoolConfig(YamlModel):
enable_read: bool = Field(default=True, description="Enable to read from experience pool.")
enable_write: bool = Field(default=True, description="Enable to write to experience pool.")
enable_read: bool = Field(default=False, description="Enable to read from experience pool.")
enable_write: bool = Field(default=False, description="Enable to write to experience pool.")
persist_path: str = Field(default=".chroma_exp_data", description="The persist path for experience pool.")
init_exp: bool = Field(
default=False, description="Put some basic experiences associated with the roles into the experience pool."
)

View file

@ -0,0 +1,7 @@
"""Context builders init."""
from metagpt.exp_pool.context_builders.base import BaseContextBuilder
from metagpt.exp_pool.context_builders.simple import SimpleContextBuilder
from metagpt.exp_pool.context_builders.role_zero import RoleZeroContextBuilder
__all__ = ["BaseContextBuilder", "SimpleContextBuilder", "RoleZeroContextBuilder"]

View file

@ -0,0 +1,52 @@
"""Base context builder."""
import re
from abc import ABC, abstractmethod
from typing import Any
from pydantic import BaseModel, ConfigDict
from metagpt.exp_pool.schema import Experience
EXP_TEMPLATE = """Given the request: {req}, We can get the response: {resp}, Which scored: {score}."""
class BaseContextBuilder(BaseModel, ABC):
model_config = ConfigDict(arbitrary_types_allowed=True)
exps: list[Experience] = []
@abstractmethod
async def build(self, *args, **kwargs) -> Any:
"""Build context from parameters."""
def format_exps(self) -> str:
"""Format experiences into a numbered list of strings."""
result = []
for i, exp in enumerate(self.exps, start=1):
result.append(f"{i}. " + EXP_TEMPLATE.format(req=exp.req, resp=exp.resp, score=exp.metric.score.val))
return "\n".join(result)
@staticmethod
def replace_content_between_markers(text: str, start_marker: str, end_marker: str, new_content: str) -> str:
"""Replace the content between `start_marker` and `end_marker` in the text with `new_content`.
Args:
text (str): The original text.
new_content (str): The new content to replace the old content.
start_marker (str): The marker indicating the start of the content to be replaced, such as '# Example'.
end_marker (str): The marker indicating the end of the content to be replaced, such as '# Instruction'.
Returns:
str: The text with the content replaced.
"""
pattern = re.compile(f"({start_marker}\n)(.*?)(\n{end_marker})", re.DOTALL)
def replacement(match):
return f"{match.group(1)}{new_content}\n{match.group(3)}"
replaced_text = pattern.sub(replacement, text)
return replaced_text

View file

@ -0,0 +1,26 @@
"""RoleZero context builder."""
from metagpt.exp_pool.context_builders.base import BaseContextBuilder
class RoleZeroContextBuilder(BaseContextBuilder):
async def build(self, *args, **kwargs) -> list[dict]:
"""Builds the context by updating the req with formatted experiences.
If there are no experiences, retains the original examples in req, otherwise replaces the examples with the formatted experiences.
"""
req = kwargs.get("req", [])
if not req:
return req
exps_str = self.format_exps()
if not exps_str:
return req
req[-1]["content"] = self.replace_example_content(req[-1].get("content", ""), exps_str)
return req
def replace_example_content(self, text: str, new_example_content: str) -> str:
return self.replace_content_between_markers(text, "# Example", "# Instruction", new_example_content)

View file

@ -0,0 +1,24 @@
"""Simple context builder."""
from metagpt.exp_pool.context_builders.base import BaseContextBuilder
SIMPLE_CONTEXT_TEMPLATE = """
{req}
### Experiences
-----
{exps}
-----
## Instruction
Consider **Experiences** to generate a better answer.
"""
class SimpleContextBuilder(BaseContextBuilder):
async def build(self, *args, **kwargs) -> str:
req = kwargs.get("req", "")
exps = self.format_exps()
return SIMPLE_CONTEXT_TEMPLATE.format(req=req, exps=exps) if exps else req

View file

@ -2,18 +2,19 @@
import asyncio
import functools
import inspect
import json
from typing import Any, Callable, Optional, TypeVar
from pydantic import BaseModel, ConfigDict, model_validator
from metagpt.config2 import config
from metagpt.exp_pool.context_builders import BaseContextBuilder, SimpleContextBuilder
from metagpt.exp_pool.manager import ExperienceManager, exp_manager
from metagpt.exp_pool.perfect_judges import BasePerfectJudge, SimplePerfectJudge
from metagpt.exp_pool.schema import Experience, Metric, QueryType, Score
from metagpt.exp_pool.scorers import ExperienceScorer, SimpleScorer
from metagpt.exp_pool.scorers import BaseScorer, SimpleScorer
from metagpt.logs import logger
from metagpt.utils.async_helper import NestAsyncio
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.reflection import get_class_name
ReturnType = TypeVar("ReturnType")
@ -21,42 +22,64 @@ ReturnType = TypeVar("ReturnType")
def exp_cache(
_func: Optional[Callable[..., ReturnType]] = None,
query_type: QueryType = QueryType.SEMANTIC,
scorer: Optional[ExperienceScorer] = None,
manager: Optional[ExperienceManager] = None,
pass_exps_to_func: bool = False,
scorer: Optional[BaseScorer] = None,
perfect_judge: Optional[BasePerfectJudge] = None,
context_builder: Optional[BaseContextBuilder] = None,
req_serialize: Optional[Callable[..., str]] = None,
resp_serialize: Optional[Callable[..., str]] = None,
resp_deserialize: Optional[Callable[[str], Any]] = None,
tag: Optional[str] = None,
):
"""Decorator to get a perfect experience, otherwise, it executes the function, and create a new experience.
This can be applied to both synchronous and asynchronous functions.
1. This can be applied to both synchronous and asynchronous functions.
2. The function must have a `req` parameter, and it must be provided as a keyword argument.
3. If `config.exp_pool.enable_read` is False, the decorator will just directly execute the function.
Args:
_func: Just to make the decorator more flexible, for example, it can be used directly with @exp_cache by default, without the need for @exp_cache().
query_type: The type of query to be used when fetching experiences.
scorer: Evaluate experience. Default SimpleScorer.
manager: How to fetch, evaluate and save experience, etc. Default exp_manager.
pass_exps_to_func: To control whether imperfect experiences are passed to the function, if True, the func must have a parameter named 'exps'.
manager: How to fetch, evaluate and save experience, etc. Default to `exp_manager`.
scorer: Evaluate experience. Default to `SimpleScorer()`.
perfect_judge: Determines if an experience is perfect. Defaults to `SimplePerfectJudge()`.
context_builder: Build the context from exps and the function parameters. Default to `SimpleContextBuilder()`.
req_serialize: Serializes the request for storage. Defaults to `lambda req: str(req)`.
resp_serialize: Serializes the function's return value for storage. Defaults to `lambda resp: str(resp)`.
resp_deserialize: Deserializes the stored response back to the function's return value. Defaults to `lambda resp: resp`.
tag: An optional tag for the experience. Default to `ClassName.method_name` or `function_name`.
"""
def decorator(func: Callable[..., ReturnType]) -> Callable[..., ReturnType]:
if not config.exp_pool.enable_read:
return func
@functools.wraps(func)
async def get_or_create(args: Any, kwargs: Any) -> ReturnType:
logger.info("exp_cache is enabled.")
handler = ExpCacheHandler(
func=func,
args=args,
kwargs=kwargs,
query_type=query_type,
exp_manager=manager,
exp_scorer=scorer,
pass_exps_to_func=pass_exps_to_func,
exp_perfect_judge=perfect_judge,
context_builder=context_builder,
req_serialize=req_serialize,
resp_serialize=resp_serialize,
resp_deserialize=resp_deserialize,
tag=tag,
)
await handler.fetch_experiences(query_type)
if exp := handler.get_one_perfect_experience():
await handler.fetch_experiences()
if exp := await handler.get_one_perfect_exp():
return exp
await handler.execute_function()
await handler.process_experience()
return handler._result
return handler._raw_resp
return ExpCacheHandler.choose_wrapper(func, get_or_create)
@ -69,39 +92,59 @@ class ExpCacheHandler(BaseModel):
func: Callable
args: Any
kwargs: Any
query_type: QueryType = QueryType.SEMANTIC
exp_manager: Optional[ExperienceManager] = None
exp_scorer: Optional[ExperienceScorer] = None
pass_exps_to_func: bool = False
exp_scorer: Optional[BaseScorer] = None
exp_perfect_judge: Optional[BasePerfectJudge] = None
context_builder: Optional[BaseContextBuilder] = None
req_serialize: Optional[Callable[..., str]] = None
resp_serialize: Optional[Callable[..., str]] = None
resp_deserialize: Optional[Callable[[str], Any]] = None
tag: Optional[str] = None
_exps: list[Experience] = None
_result: Any = None
_req: str = ""
_resp: str = ""
_raw_resp: Any = None
_score: Score = None
_req: str = None
@model_validator(mode="after")
def initialize(self):
if self.exp_manager is None:
self.exp_manager = exp_manager
self._validate_params()
if self.exp_scorer is None:
self.exp_scorer = SimpleScorer()
self.exp_manager = self.exp_manager or exp_manager
self.exp_scorer = self.exp_scorer or SimpleScorer()
self.exp_perfect_judge = self.exp_perfect_judge or SimplePerfectJudge()
self.context_builder = self.context_builder or SimpleContextBuilder()
self.req_serialize = self.req_serialize or (lambda resp: str(resp))
self.resp_serialize = self.resp_serialize or (lambda resp: str(resp))
self.resp_deserialize = self.resp_deserialize or (lambda resp: resp)
self.tag = self.tag or self._generate_tag()
self._req = self.generate_req_identifier(self.func, *self.args, **self.kwargs)
self._req = self.req_serialize(self.kwargs["req"])
return self
async def fetch_experiences(self, query_type: QueryType):
async def fetch_experiences(self):
"""Fetch experiences by query_type."""
self._exps = await self.exp_manager.query_exps(self._req, query_type=query_type)
self._exps = await self.exp_manager.query_exps(self._req, query_type=self.query_type, tag=self.tag)
def get_one_perfect_experience(self) -> Optional[Experience]:
"""Get a potentially perfect experience."""
return self.exp_manager.extract_one_perfect_exp(self._exps)
async def get_one_perfect_exp(self) -> Optional[Any]:
"""Get a potentially perfect experience, and resolve resp."""
for exp in self._exps:
if await self.exp_perfect_judge.is_perfect_exp(exp, self._req, *self.args, **self.kwargs):
logger.info(f"Get one perfect experience: {exp.req[:20]}...")
return self.resp_deserialize(exp.resp)
return None
async def execute_function(self):
"""Execute the function, and save the result."""
self._result = await self._execute_function()
"""Execute the function, and save resp."""
self._raw_resp = await self._execute_function()
self._resp = self.resp_serialize(self._raw_resp)
@handle_exception
async def process_experience(self):
@ -110,41 +153,21 @@ class ExpCacheHandler(BaseModel):
Evaluates and saves experience.
Use `handle_exception` to ensure robustness, do not stop subsequent operations.
"""
await self.evaluate_experience()
self.save_experience()
async def evaluate_experience(self):
"""Evaluate the experience, and save the score."""
self._score = await self.exp_scorer.evaluate(self.func, self._result, self.args, self.kwargs)
self._score = await self.exp_scorer.evaluate(self.func, self._resp, self.args, self.kwargs)
def save_experience(self):
"""Save the new experience."""
exp = Experience(req=self._req, resp=self._result, metric=Metric(score=self._score))
exp = Experience(req=self._req, resp=self._resp, tag=self.tag, metric=Metric(score=self._score))
self.exp_manager.create_exp(exp)
@classmethod
def generate_req_identifier(cls, func, *args, **kwargs) -> str:
"""Generate a unique request identifier for any given function and its arguments.
Serializing args and kwargs into JSON strings and replacing ',' with '~' and ':' with '!'.
Return Example:
SimpleClass.test_method@[1~2]@{"c"!3}
"""
cls_name = get_class_name(func)
func_name = f"{cls_name}.{func.__name__}" if cls_name else func.__name__
if cls_name and args and inspect.isfunction(func):
args = args[1:]
args = cls._serialize_and_replace(args)
kwargs = cls._serialize_and_replace(kwargs)
return f"{func_name}@{args}@{kwargs}"
@staticmethod
def choose_wrapper(func, wrapped_func):
"""Choose how to run wrapped_func based on whether the function is asynchronous."""
@ -158,25 +181,31 @@ class ExpCacheHandler(BaseModel):
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
@classmethod
def _serialize_and_replace(cls, data):
json_str = json.dumps(data)
return json_str.replace(", ", "~").replace(": ", "!")
def _validate_params(self):
if "req" not in self.kwargs:
raise ValueError("`req` must be provided as a keyword argument.")
def _generate_tag(self) -> str:
"""Generates a tag for the self.func.
"ClassName.method_name" if the first argument is a class instance, otherwise just "function_name".
"""
if self.args and hasattr(self.args[0], "__class__"):
cls_name = type(self.args[0]).__name__
return f"{cls_name}.{self.func.__name__}"
return self.func.__name__
async def _build_context(self) -> str:
self.context_builder.exps = self._exps
return await self.context_builder.build(*self.args, **self.kwargs)
async def _execute_function(self):
if self.pass_exps_to_func:
return await self._execute_function_with_exps()
self.kwargs["req"] = await self._build_context()
return await self._execute_function_without_exps()
async def _execute_function_without_exps(self):
if asyncio.iscoroutinefunction(self.func):
return await self.func(*self.args, **self.kwargs)
return self.func(*self.args, **self.kwargs)
async def _execute_function_with_exps(self):
if asyncio.iscoroutinefunction(self.func):
return await self.func(*self.args, **self.kwargs, exps=self._exps)
return self.func(*self.args, **self.kwargs, exps=self._exps)

View file

@ -1,13 +1,22 @@
"""Experience Manager."""
from typing import Optional
from llama_index.vector_stores.chroma import ChromaVectorStore
from pydantic import BaseModel, ConfigDict, model_validator
from metagpt.config2 import Config, config
from metagpt.exp_pool.schema import MAX_SCORE, Experience, QueryType
from metagpt.exp_pool.schema import (
DEFAULT_COLLECTION_NAME,
DEFAULT_SIMILARITY_TOP_K,
EntryType,
Experience,
Metric,
QueryType,
Score,
)
from metagpt.logs import logger
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig
from metagpt.strategy.experience_retriever import ENGINEER_EXAMPLE, TL_EXAMPLE
from metagpt.utils.exceptions import handle_exception
@ -27,14 +36,33 @@ class ExperienceManager(BaseModel):
@model_validator(mode="after")
def initialize(self):
if self.storage is None:
self.storage = SimpleEngine.from_objs(
retriever_configs=[
ChromaRetrieverConfig(collection_name="experience_pool", persist_path=".chroma_exp_data")
],
ranker_configs=[LLMRankerConfig()],
)
retriever_configs = [
ChromaRetrieverConfig(
persist_path=self.config.exp_pool.persist_path,
collection_name=DEFAULT_COLLECTION_NAME,
similarity_top_k=DEFAULT_SIMILARITY_TOP_K,
)
]
ranker_configs = [LLMRankerConfig()]
self.storage = SimpleEngine.from_objs(retriever_configs=retriever_configs, ranker_configs=ranker_configs)
self.init_exp_pool()
return self
@handle_exception
def init_exp_pool(self):
if not self.config.exp_pool.init_exp:
return
if self._has_exps():
return
self._init_teamleader_exps()
self._init_engineer2_exps()
logger.info("`init_exp_pool` done.")
@handle_exception
def create_exp(self, exp: Experience):
"""Adds an experience to the storage if writing is enabled.
@ -74,39 +102,26 @@ class ExperienceManager(BaseModel):
return exps
def extract_one_perfect_exp(self, exps: list[Experience]) -> Optional[Experience]:
"""Extracts the first 'perfect' experience from a list of experiences.
def _has_exps(self) -> bool:
vector_store: ChromaVectorStore = self.storage._retriever._vector_store
Args:
exps (list[Experience]): The experiences to evaluate.
return bool(vector_store._get(limit=1, where={}).ids)
Returns:
Optional[Experience]: The first perfect experience if found, otherwise None.
"""
for exp in exps:
if self.is_perfect_exp(exp):
return exp
def _init_exp(self, req: str, resp: str, tag: str, metric: Metric = None):
exp = Experience(
req=req,
resp=resp,
entry_type=EntryType.MANUAL,
tag=tag,
metric=metric or Metric(score=Score(val=9, reason="Manual")),
)
self.create_exp(exp)
return None
def _init_teamleader_exps(self):
self._init_exp(req=TL_EXAMPLE, resp=TL_EXAMPLE, tag="TeamLeader.llm_cached_aask")
@staticmethod
def is_perfect_exp(exp: Experience) -> bool:
"""Determines if an experience is considered 'perfect'.
Args:
exp (Experience): The experience to evaluate.
Returns:
bool: True if the experience is manually entered, otherwise False.
"""
if not exp:
return False
# TODO: need more metrics
if exp.metric and exp.metric.score.val == MAX_SCORE:
return True
return False
def _init_engineer2_exps(self):
self._init_exp(req=ENGINEER_EXAMPLE, resp=ENGINEER_EXAMPLE, tag="Engineer2.llm_cached_aask")
exp_manager = ExperienceManager()

View file

@ -0,0 +1,6 @@
"""Perfect judges init."""
from metagpt.exp_pool.perfect_judges.base import BasePerfectJudge
from metagpt.exp_pool.perfect_judges.simple import SimplePerfectJudge
__all__ = ["BasePerfectJudge", "SimplePerfectJudge"]

View file

@ -0,0 +1,20 @@
"""Base perfect judge."""
from abc import ABC, abstractmethod
from pydantic import BaseModel, ConfigDict
from metagpt.exp_pool.schema import Experience
class BasePerfectJudge(BaseModel, ABC):
model_config = ConfigDict(arbitrary_types_allowed=True)
@abstractmethod
async def is_perfect_exp(self, exp: Experience, serialized_req: str, *args, **kwargs) -> bool:
"""Determine whether the experience is perfect.
Args:
exp (Experience): The experience to evaluate.
serialized_req (str): The serialized request to compare against the experience's request.
"""

View file

@ -0,0 +1,27 @@
"""Simple perfect judge."""
from pydantic import ConfigDict
from metagpt.exp_pool.perfect_judges.base import BasePerfectJudge
from metagpt.exp_pool.schema import MAX_SCORE, Experience
class SimplePerfectJudge(BasePerfectJudge):
model_config = ConfigDict(arbitrary_types_allowed=True)
async def is_perfect_exp(self, exp: Experience, serialized_req: str, *args, **kwargs) -> bool:
"""Determine whether the experience is perfect.
Args:
exp (Experience): The experience to evaluate.
serialized_req (str): The serialized request to compare against the experience's request.
Returns:
bool: True if the serialized request matches the experience's request and the experience's score is perfect, False otherwise.
"""
if not exp.metric or not exp.metric.score:
return False
return serialized_req == exp.req and exp.metric.score.val == MAX_SCORE

View file

@ -1,13 +1,16 @@
"""Experience schema."""
from enum import Enum
from typing import Any, Optional
from typing import Optional
from llama_index.core.schema import TextNode
from pydantic import BaseModel, Field
MAX_SCORE = 10
DEFAULT_COLLECTION_NAME = "experience_pool"
DEFAULT_SIMILARITY_TOP_K = 2
class QueryType(str, Enum):
"""Type of query experiences."""
@ -59,7 +62,7 @@ class Experience(BaseModel):
"""Experience."""
req: str = Field(..., description="")
resp: Any = Field(..., description="The type is string/json/code.")
resp: str = Field(..., description="The type is string/json/code.")
metric: Optional[Metric] = Field(default=None, description="Metric.")
exp_type: ExperienceType = Field(default=ExperienceType.SUCCESS, description="The type of experience.")
entry_type: EntryType = Field(default=EntryType.AUTOMATIC, description="Type of entry: Manual or Automatic.")

View file

@ -1,6 +1,6 @@
"""Experience scorers init."""
"""Scorers init."""
from metagpt.exp_pool.scorers.base import ExperienceScorer
from metagpt.exp_pool.scorers.base import BaseScorer
from metagpt.exp_pool.scorers.simple import SimpleScorer
__all__ = ["ExperienceScorer", "SimpleScorer"]
__all__ = ["BaseScorer", "SimpleScorer"]

View file

@ -1,6 +1,6 @@
"""Experience Scorers."""
"""Base scorer."""
from abc import abstractmethod
from abc import ABC, abstractmethod
from typing import Any, Callable
from pydantic import BaseModel, ConfigDict
@ -8,7 +8,7 @@ from pydantic import BaseModel, ConfigDict
from metagpt.exp_pool.schema import Score
class ExperienceScorer(BaseModel):
class BaseScorer(BaseModel, ABC):
model_config = ConfigDict(arbitrary_types_allowed=True)
@abstractmethod

View file

@ -1,4 +1,4 @@
"""Simple Scorer."""
"""Simple scorer."""
import inspect
import json
@ -7,7 +7,7 @@ from typing import Any, Callable
from pydantic import Field
from metagpt.exp_pool.schema import Score
from metagpt.exp_pool.scorers.base import ExperienceScorer
from metagpt.exp_pool.scorers.base import BaseScorer
from metagpt.llm import LLM
from metagpt.provider.base_llm import BaseLLM
from metagpt.utils.common import CodeParser
@ -54,7 +54,7 @@ Follow instructions, generate output and make sure it follows the **Constraint**
"""
class SimpleScorer(ExperienceScorer):
class SimpleScorer(BaseScorer):
llm: BaseLLM = Field(default_factory=LLM)
async def evaluate(self, func: Callable, result: Any, args: tuple = None, kwargs: dict = None) -> Score:

View file

@ -1,5 +1,6 @@
from __future__ import annotations
import copy
import inspect
import json
import re
@ -10,8 +11,14 @@ from pydantic import model_validator
from metagpt.actions import Action
from metagpt.actions.di.run_command import RunCommand
from metagpt.exp_pool import exp_cache
from metagpt.exp_pool.context_builders import RoleZeroContextBuilder
from metagpt.logs import logger
from metagpt.prompts.di.role_zero import CMD_PROMPT, ROLE_INSTRUCTION, JSON_REPAIR_PROMPT
from metagpt.prompts.di.role_zero import (
CMD_PROMPT,
JSON_REPAIR_PROMPT,
ROLE_INSTRUCTION,
)
from metagpt.roles import Role
from metagpt.schema import AIMessage, Message, UserMessage
from metagpt.strategy.experience_retriever import DummyExpRetriever, ExpRetriever
@ -21,8 +28,8 @@ from metagpt.tools.libs.editor import Editor
from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender
from metagpt.tools.tool_registry import register_tool
from metagpt.utils.common import CodeParser
from metagpt.utils.repair_llm_raw_output import RepairType, repair_llm_raw_output
from metagpt.utils.report import ThoughtReporter
from metagpt.utils.repair_llm_raw_output import repair_llm_raw_output, RepairType
@register_tool(include_functions=["ask_human", "reply_to_human"])
@ -154,11 +161,37 @@ class RoleZero(Role):
context = self.llm.format_msg(memory + [UserMessage(content=prompt)])
# print(*context, sep="\n" + "*" * 5 + "\n")
async with ThoughtReporter(enable_llm_stream=True):
self.command_rsp = await self.llm.aask(context, system_msgs=self.system_msg)
self.command_rsp = await self.llm_cached_aask(req=context, system_msgs=self.system_msg)
self.rc.memory.add(AIMessage(content=self.command_rsp))
return True
@exp_cache(context_builder=RoleZeroContextBuilder(), req_serialize=lambda req: RoleZero._req_serialize(req))
async def llm_cached_aask(self, *, req: list[dict], system_msgs: list[str]) -> str:
return await self.llm.aask(req, system_msgs=system_msgs)
@staticmethod
def _req_serialize(req: list[dict]) -> str:
"""Serialize the request for database storage, ensuring it is a string.
This function deep copies the request and modifies the content of the last element
to remove unnecessary sections, making the request more concise.
"""
req_copy = copy.deepcopy(req)
last_content = req_copy[-1]["content"]
last_content = RoleZeroContextBuilder.replace_content_between_markers(
last_content, "# Data Structure", "# Current Plan", ""
)
last_content = RoleZeroContextBuilder.replace_content_between_markers(
last_content, "# Example", "# Instruction", ""
)
req_copy[-1]["content"] = last_content
return json.dumps(req_copy)
async def _act(self) -> Message:
if self.use_fixed_sop:
return await super()._act()
@ -166,7 +199,7 @@ class RoleZero(Role):
try:
commands = CodeParser.parse_code(block=None, lang="json", text=self.command_rsp)
commands = json.loads(repair_llm_raw_output(output=commands, req_keys=[None], repair_type=RepairType.JSON))
except json.JSONDecodeError as e:
except json.JSONDecodeError:
commands = await self.llm.aask(msg=JSON_REPAIR_PROMPT.format(json_data=self.command_rsp))
commands = json.loads(CodeParser.parse_code(block=None, lang="json", text=commands))
except Exception as e:

View file

@ -39,7 +39,7 @@ class NaiveSolver(BaseSolver):
self.graph.topological_sort()
for key in self.graph.execution_order:
op = self.graph.nodes[key]
await op.fill(self.context, self.llm, mode="root")
await op.fill(req=self.context, llm=self.llm, mode="root")
class TOTSolver(BaseSolver):

View file

@ -1,5 +1,4 @@
"""class tools, including method inspection, class attributes, inheritance relationships, etc."""
import inspect
def check_methods(C, *methods):
@ -17,25 +16,3 @@ def check_methods(C, *methods):
else:
return NotImplemented
return True
def get_class_name(func) -> str:
"""Returns the class name of the object that a method belongs to.
- If `func` is a bound method or a class method, extracts the class name directly from the method.
- Returns an empty string if it's a regular function or cannot determine the class.
"""
if inspect.ismethod(func):
if inspect.isclass(func.__self__):
return func.__self__.__name__
return func.__self__.__class__.__name__
if inspect.isfunction(func):
qualname_parts = func.__qualname__.split(".")
if len(qualname_parts) > 1:
class_name = qualname_parts[-2]
if class_name.isidentifier():
return class_name
return ""

View file

@ -91,10 +91,10 @@ async def test_action_node_two_layer():
assert node_b in root.children.values()
# FIXME: ADD MARKDOWN SUPPORT. NEED TO TUNE MARKDOWN SYMBOL FIRST.
answer1 = await root.fill(context="what's the answer to 123+456?", schema="json", strgy="simple", llm=LLM())
answer1 = await root.fill(req="what's the answer to 123+456?", schema="json", strgy="simple", llm=LLM())
assert "579" in answer1.content
answer2 = await root.fill(context="what's the answer to 123+456?", schema="json", strgy="complex", llm=LLM())
answer2 = await root.fill(req="what's the answer to 123+456?", schema="json", strgy="complex", llm=LLM())
assert "579" in answer2.content
@ -112,7 +112,7 @@ async def test_action_node_review():
with pytest.raises(RuntimeError):
_ = await node_a.review()
_ = await node_a.fill(context=None, llm=LLM())
_ = await node_a.fill(req=None, llm=LLM())
setattr(node_a.instruct_content, key, "game snake") # wrong content to review
review_comments = await node_a.review(review_mode=ReviewMode.AUTO)
@ -126,7 +126,7 @@ async def test_action_node_review():
with pytest.raises(RuntimeError):
_ = await node.review()
_ = await node.fill(context=None, llm=LLM())
_ = await node.fill(req=None, llm=LLM())
review_comments = await node.review(review_mode=ReviewMode.AUTO)
assert len(review_comments) == 1
@ -151,7 +151,7 @@ async def test_action_node_revise():
with pytest.raises(RuntimeError):
_ = await node_a.review()
_ = await node_a.fill(context=None, llm=LLM())
_ = await node_a.fill(req=None, llm=LLM())
setattr(node_a.instruct_content, key, "game snake") # wrong content to revise
revise_contents = await node_a.revise(revise_mode=ReviseMode.AUTO)
assert len(revise_contents) == 1
@ -164,7 +164,7 @@ async def test_action_node_revise():
with pytest.raises(RuntimeError):
_ = await node.revise()
_ = await node.fill(context=None, llm=LLM())
_ = await node.fill(req=None, llm=LLM())
setattr(node.instruct_content, key, "game snake")
revise_contents = await node.revise(revise_mode=ReviseMode.AUTO)
assert len(revise_contents) == 1
@ -257,7 +257,7 @@ async def test_action_node_with_image(mocker):
invoice_path = Path(__file__).parent.joinpath("..", "..", "data", "invoices", "invoice-2.png")
img_base64 = encode_image(invoice_path)
mocker.patch("metagpt.provider.openai_api.OpenAILLM._cons_kwargs", _cons_kwargs)
node = await invoice.fill(context="", llm=LLM(), images=[img_base64])
node = await invoice.fill(req="", llm=LLM(), images=[img_base64])
assert node.instruct_content.invoice

View file

@ -38,7 +38,7 @@ async def test_write_design_an(mocker):
mocker.patch("metagpt.actions.design_api_an.REFINED_DESIGN_NODE.fill", return_value=root)
prompt = NEW_REQ_TEMPLATE.format(old_design=DESIGN_SAMPLE, context=dict_to_markdown(REFINED_PRD_JSON))
node = await REFINED_DESIGN_NODE.fill(prompt, llm)
node = await REFINED_DESIGN_NODE.fill(req=prompt, llm=llm)
assert "Refined Implementation Approach" in node.instruct_content.model_dump()
assert "Refined File list" in node.instruct_content.model_dump()

View file

@ -42,7 +42,7 @@ async def test_project_management_an(mocker):
root.instruct_content.model_dump = mock_task_json
mocker.patch("metagpt.actions.project_management_an.PM_NODE.fill", return_value=root)
node = await PM_NODE.fill(dict_to_markdown(REFINED_DESIGN_JSON), llm)
node = await PM_NODE.fill(req=dict_to_markdown(REFINED_DESIGN_JSON), llm=llm)
assert "Logic Analysis" in node.instruct_content.model_dump()
assert "Task list" in node.instruct_content.model_dump()
@ -59,7 +59,7 @@ async def test_project_management_an_inc(mocker):
mocker.patch("metagpt.actions.project_management_an.REFINED_PM_NODE.fill", return_value=root)
prompt = NEW_REQ_TEMPLATE.format(old_task=TASK_SAMPLE, context=dict_to_markdown(REFINED_DESIGN_JSON))
node = await REFINED_PM_NODE.fill(prompt, llm)
node = await REFINED_PM_NODE.fill(req=prompt, llm=llm)
assert "Refined Logic Analysis" in node.instruct_content.model_dump()
assert "Refined Task list" in node.instruct_content.model_dump()

View file

@ -39,7 +39,7 @@ async def test_write_prd_an(mocker):
requirements=NEW_REQUIREMENT_SAMPLE,
old_prd=PRD_SAMPLE,
)
node = await REFINED_PRD_NODE.fill(prompt, llm)
node = await REFINED_PRD_NODE.fill(req=prompt, llm=llm)
assert "Refined Requirements" in node.instruct_content.model_dump()
assert "Refined Product Goals" in node.instruct_content.model_dump()

View file

@ -0,0 +1,45 @@
import pytest
from metagpt.exp_pool.context_builders.base import (
EXP_TEMPLATE,
BaseContextBuilder,
Experience,
)
from metagpt.exp_pool.schema import Metric, Score
class TestBaseContextBuilder:
class ConcreteContextBuilder(BaseContextBuilder):
async def build(self, *args, **kwargs):
pass
@pytest.fixture
def context_builder(self):
return self.ConcreteContextBuilder()
def test_format_exps(self, context_builder):
exp1 = Experience(req="req1", resp="resp1", metric=Metric(score=Score(val=8)))
exp2 = Experience(req="req2", resp="resp2", metric=Metric(score=Score(val=9)))
context_builder.exps = [exp1, exp2]
result = context_builder.format_exps()
expected = "\n".join(
[
f"1. {EXP_TEMPLATE.format(req='req1', resp='resp1', score=8)}",
f"2. {EXP_TEMPLATE.format(req='req2', resp='resp2', score=9)}",
]
)
assert result == expected
def test_replace_content_between_markers(self):
text = "Start\n# Example\nOld content\n# Instruction\nEnd"
new_content = "New content"
result = BaseContextBuilder.replace_content_between_markers(text, "# Example", "# Instruction", new_content)
expected = "Start\n# Example\nNew content\n\n# Instruction\nEnd"
assert result == expected
def test_replace_content_between_markers_no_match(self):
text = "Start\nNo markers\nEnd"
new_content = "New content"
result = BaseContextBuilder.replace_content_between_markers(text, "# Example", "# Instruction", new_content)
assert result == text

View file

@ -0,0 +1,38 @@
import pytest
from metagpt.exp_pool.context_builders.base import BaseContextBuilder
from metagpt.exp_pool.context_builders.role_zero import RoleZeroContextBuilder
class TestRoleZeroContextBuilder:
@pytest.fixture
def context_builder(self):
return RoleZeroContextBuilder()
@pytest.mark.asyncio
async def test_build_empty_req(self, context_builder):
result = await context_builder.build(req=[])
assert result == []
@pytest.mark.asyncio
async def test_build_no_experiences(self, context_builder, mocker):
mocker.patch.object(BaseContextBuilder, "format_exps", return_value="")
req = [{"content": "Original content"}]
result = await context_builder.build(req=req)
assert result == req
@pytest.mark.asyncio
async def test_build_with_experiences(self, context_builder, mocker):
mocker.patch.object(BaseContextBuilder, "format_exps", return_value="Formatted experiences")
mocker.patch.object(RoleZeroContextBuilder, "replace_example_content", return_value="Updated content")
req = [{"content": "Original content"}]
result = await context_builder.build(req=req)
assert result == [{"content": "Updated content"}]
def test_replace_example_content(self, context_builder, mocker):
mocker.patch.object(BaseContextBuilder, "replace_content_between_markers", return_value="Replaced content")
result = context_builder.replace_example_content("Original text", "New example content")
assert result == "Replaced content"
context_builder.replace_content_between_markers.assert_called_once_with(
"Original text", "# Example", "# Instruction", "New example content"
)

View file

@ -0,0 +1,46 @@
import pytest
from metagpt.exp_pool.context_builders.base import BaseContextBuilder
from metagpt.exp_pool.context_builders.simple import (
SIMPLE_CONTEXT_TEMPLATE,
SimpleContextBuilder,
)
class TestSimpleContextBuilder:
@pytest.fixture
def context_builder(self):
return SimpleContextBuilder()
@pytest.mark.asyncio
async def test_build_with_experiences(self, context_builder, mocker):
# Mock the format_exps method
mock_exps = "Mocked experiences"
mocker.patch.object(BaseContextBuilder, "format_exps", return_value=mock_exps)
req = "Test request"
result = await context_builder.build(req=req)
expected = SIMPLE_CONTEXT_TEMPLATE.format(req=req, exps=mock_exps)
assert result == expected
@pytest.mark.asyncio
async def test_build_without_experiences(self, context_builder, mocker):
# Mock the format_exps method to return an empty string
mocker.patch.object(BaseContextBuilder, "format_exps", return_value="")
req = "Test request"
result = await context_builder.build(req=req)
assert result == req
@pytest.mark.asyncio
async def test_build_without_req(self, context_builder, mocker):
# Mock the format_exps method
mock_exps = "Mocked experiences"
mocker.patch.object(BaseContextBuilder, "format_exps", return_value=mock_exps)
result = await context_builder.build()
expected = SIMPLE_CONTEXT_TEMPLATE.format(req="", exps=mock_exps)
assert result == expected

View file

@ -1,29 +1,17 @@
import asyncio
import inspect
import pytest
from metagpt.exp_pool.context_builders import SimpleContextBuilder
from metagpt.exp_pool.decorator import ExpCacheHandler, exp_cache
from metagpt.exp_pool.manager import ExperienceManager
from metagpt.exp_pool.perfect_judges import SimplePerfectJudge
from metagpt.exp_pool.schema import Experience, QueryType, Score
from metagpt.exp_pool.scorers import SimpleScorer
from metagpt.rag.engines import SimpleEngine
def for_test_function(a, b, c=None):
return a + b if c is None else a + b + c
class ForTestClass:
def for_test_method(self, x, y):
return x * y
@classmethod
def for_test_class_method(cls, x, y):
return x**y
class TestExpCache:
class TestExpCacheHandler:
@pytest.fixture
def mock_func(self, mocker):
return mocker.AsyncMock()
@ -34,7 +22,6 @@ class TestExpCache:
manager.storage = mocker.MagicMock(spec=SimpleEngine)
manager.query_exps = mocker.AsyncMock()
manager.create_exp = mocker.MagicMock()
manager.extract_one_perfect_exp = mocker.MagicMock()
return manager
@pytest.fixture
@ -44,174 +31,165 @@ class TestExpCache:
return scorer
@pytest.fixture
def exp_cache_handler(self, mock_func, mock_exp_manager, mock_scorer):
def mock_perfect_judge(self, mocker):
return mocker.MagicMock(spec=SimplePerfectJudge)
@pytest.fixture
def mock_context_builder(self, mocker):
return mocker.MagicMock(spec=SimpleContextBuilder)
@pytest.fixture
def exp_cache_handler(self, mock_func, mock_exp_manager, mock_scorer, mock_perfect_judge, mock_context_builder):
return ExpCacheHandler(
func=mock_func, args=(), kwargs={}, exp_manager=mock_exp_manager, exp_scorer=mock_scorer, pass_exps=False
func=mock_func,
args=(),
kwargs={"req": "test_req"},
exp_manager=mock_exp_manager,
exp_scorer=mock_scorer,
exp_perfect_judge=mock_perfect_judge,
context_builder=mock_context_builder,
)
@pytest.mark.asyncio
async def test_fetch_experiences(self, exp_cache_handler, mock_exp_manager):
await exp_cache_handler.fetch_experiences(QueryType.SEMANTIC)
mock_exp_manager.query_exps.assert_called_once()
mock_exp_manager.query_exps.return_value = [Experience(req="test_req", resp="test_resp")]
await exp_cache_handler.fetch_experiences()
mock_exp_manager.query_exps.assert_called_once_with(
"test_req", query_type=QueryType.SEMANTIC, tag=exp_cache_handler.tag
)
assert len(exp_cache_handler._exps) == 1
@pytest.mark.asyncio
async def test_perfect_experience_found(self, exp_cache_handler, mock_exp_manager, mock_func):
# Setup: Assume perfect experience is found
perfect_exp = Experience(req="req", resp="resp")
mock_exp_manager.extract_one_perfect_exp.return_value = perfect_exp
# Exec
exp_cache_handler._exps = [perfect_exp] # Simulate fetched experiences
result = exp_cache_handler.get_one_perfect_experience()
# Assert
assert result.resp == "resp"
mock_func.assert_not_called() # Function should not be called
async def test_get_one_perfect_exp(self, exp_cache_handler, mock_perfect_judge):
exp = Experience(req="test_req", resp="perfect_resp")
exp_cache_handler._exps = [exp]
mock_perfect_judge.is_perfect_exp.return_value = True
result = await exp_cache_handler.get_one_perfect_exp()
assert result == "perfect_resp"
@pytest.mark.asyncio
async def test_execute_function_when_no_perfect_exp(self, exp_cache_handler, mock_exp_manager, mock_func):
# Setup: No perfect experience
mock_exp_manager.extract_one_perfect_exp.return_value = None
mock_func.return_value = "Computed result"
# Exec
async def test_execute_function(self, exp_cache_handler, mock_func, mock_context_builder):
mock_context_builder.build.return_value = "built_context"
mock_func.return_value = "function_result"
await exp_cache_handler.execute_function()
# Assert
assert exp_cache_handler._result == "Computed result"
mock_func.assert_called_once()
mock_context_builder.build.assert_called_once()
mock_func.assert_called_once_with(req="built_context")
assert exp_cache_handler._raw_resp == "function_result"
assert exp_cache_handler._resp == "function_result"
@pytest.mark.asyncio
async def test_evaluate_and_save_experience(self, exp_cache_handler, mock_scorer, mock_exp_manager):
# Setup
mock_scorer.evaluate.return_value = Score(value=100)
exp_cache_handler._result = "Computed result"
# Exec
await exp_cache_handler.evaluate_experience()
exp_cache_handler.save_experience()
# Assert
async def test_process_experience(self, exp_cache_handler, mock_scorer, mock_exp_manager):
exp_cache_handler._resp = "test_resp"
mock_scorer.evaluate.return_value = Score(val=8)
await exp_cache_handler.process_experience()
mock_scorer.evaluate.assert_called_once()
mock_exp_manager.create_exp.assert_called_once()
@pytest.mark.asyncio
async def test_async_function_execution_with_exps(self, exp_cache_handler, mock_exp_manager, mock_func):
# Setup
exp_cache_handler.pass_exps_to_func = True
mock_func.return_value = "Async result with exps"
mock_exp_manager.extract_one_perfect_exp.return_value = None
exp_cache_handler._exps = [Experience(req="req", resp="resp")]
async def test_evaluate_experience(self, exp_cache_handler, mock_scorer):
exp_cache_handler._resp = "test_resp"
mock_scorer.evaluate.return_value = Score(val=9)
await exp_cache_handler.evaluate_experience()
assert exp_cache_handler._score.val == 9
# Exec
await exp_cache_handler.execute_function()
# Assert
mock_func.assert_called_once_with(exps=exp_cache_handler._exps)
assert exp_cache_handler._result == "Async result with exps"
def test_sync_function_execution_with_exps(self, mocker, exp_cache_handler, mock_exp_manager, mock_func):
# Setup
exp_cache_handler.func = mocker.Mock(return_value="Sync result with exps")
exp_cache_handler.pass_exps_to_func = True
mock_exp_manager.extract_one_perfect_exp.return_value = None
exp_cache_handler._exps = [Experience(req="req", resp="resp")]
# Exec
asyncio.get_event_loop().run_until_complete(exp_cache_handler.execute_function())
# Assert
exp_cache_handler.func.assert_called_once_with(exps=exp_cache_handler._exps)
assert exp_cache_handler._result == "Sync result with exps"
def test_wrapper_selection_async(self, mocker, exp_cache_handler, mock_func):
# Setup
mock_func = mocker.AsyncMock()
# Exec
wrapper = ExpCacheHandler.choose_wrapper(mock_func, exp_cache_handler.execute_function)
# Assert
assert asyncio.iscoroutinefunction(wrapper), "Wrapper should be asynchronous"
def test_wrapper_selection_sync(self, exp_cache_handler, mocker):
# Setup
sync_func = mocker.Mock()
# Exec
wrapper = ExpCacheHandler.choose_wrapper(sync_func, exp_cache_handler.execute_function)
# Assert
assert not asyncio.iscoroutinefunction(wrapper), "Wrapper should be synchronous"
@pytest.mark.parametrize(
"func, args, kwargs, expected",
[
(for_test_function, (1, 2), {"c": 3}, 'for_test_function@[1~2]@{"c"!3}'),
(ForTestClass().for_test_method, (4, 5), {}, "ForTestClass.for_test_method@[4~5]@{}"),
(ForTestClass.for_test_class_method, (6, 7), {}, "ForTestClass.for_test_class_method@[6~7]@{}"),
(for_test_function, (), {}, "for_test_function@[]@{}"),
(
for_test_function,
("hello", [1, 2]),
{"key": "value"},
'for_test_function@["hello"~[1~2]]@{"key"!"value"}',
),
],
)
def test_generate_req_identifier(self, func, args, kwargs, expected):
req_identifier = ExpCacheHandler.generate_req_identifier(func, *args, **kwargs)
assert req_identifier == expected
@pytest.mark.asyncio
async def test_exp_cache_with_perfect_experience(self, mocker, mock_exp_manager):
# Mock perfect experience
perfect_exp = Experience(req="test_req", resp="perfect_response")
mock_exp_manager.query_exps = mocker.AsyncMock(return_value=[perfect_exp])
mock_exp_manager.extract_one_perfect_exp = mocker.MagicMock(return_value=perfect_exp)
async_mock_func = mocker.AsyncMock()
# Setup
decorated_func = exp_cache(async_mock_func, manager=mock_exp_manager)
# Exec
result: Experience = await decorated_func()
# Assert
assert result.resp == "perfect_response", "Should return the perfect experience response"
async_mock_func.assert_not_called()
@pytest.mark.asyncio
async def test_exp_cache_without_perfect_experience(self, mocker, mock_exp_manager):
# Mock
mock_exp_manager.query_exps = mocker.AsyncMock(return_value=[])
mock_exp_manager.extract_one_perfect_exp = mocker.MagicMock(return_value=None)
async_mock_func = mocker.AsyncMock(return_value="computed_response")
async_mock_func.__signature__ = inspect.signature(for_test_function)
# Setup
decorated_func = exp_cache(async_mock_func, manager=mock_exp_manager)
# Exec
result = await decorated_func()
# Assert
assert result == "computed_response", "Should execute and return the function's response"
async_mock_func.assert_called_once()
@pytest.mark.asyncio
async def test_exp_cache_saves_new_experience(self, mocker, mock_exp_manager, mock_scorer):
# Mock
mock_exp_manager.query_exps = mocker.AsyncMock(return_value=[])
mock_exp_manager.extract_one_perfect_exp = mocker.MagicMock(return_value=None)
async_mock_func = mocker.AsyncMock(return_value="computed_response")
mock_scorer.evaluate = mocker.AsyncMock(return_value=Score(value=100))
# Setup
decorated_func = exp_cache(async_mock_func, manager=mock_exp_manager, scorer=mock_scorer)
# Exec
await decorated_func()
# Assert
def test_save_experience(self, exp_cache_handler, mock_exp_manager):
exp_cache_handler._req = "test_req"
exp_cache_handler._resp = "test_resp"
exp_cache_handler._score = Score(val=7)
exp_cache_handler.save_experience()
mock_exp_manager.create_exp.assert_called_once()
def test_choose_wrapper_async(self, mocker):
async def async_func():
pass
wrapper = ExpCacheHandler.choose_wrapper(async_func, mocker.AsyncMock())
assert asyncio.iscoroutinefunction(wrapper)
def test_choose_wrapper_sync(self, mocker):
def sync_func():
pass
wrapper = ExpCacheHandler.choose_wrapper(sync_func, mocker.AsyncMock())
assert not asyncio.iscoroutinefunction(wrapper)
def test_validate_params(self):
with pytest.raises(ValueError):
ExpCacheHandler(func=lambda x: x, args=(), kwargs={})
def test_generate_tag(self):
class TestClass:
def test_method(self):
pass
handler = ExpCacheHandler(func=TestClass().test_method, args=(TestClass(),), kwargs={"req": "test"})
assert handler._generate_tag() == "TestClass.test_method"
handler = ExpCacheHandler(func=lambda x: x, args=(), kwargs={"req": "test"})
assert handler._generate_tag() == "<lambda>"
class TestExpCache:
@pytest.fixture
def mock_exp_manager(self, mocker):
manager = mocker.MagicMock(spec=ExperienceManager)
manager.storage = mocker.MagicMock(spec=SimpleEngine)
manager.query_exps = mocker.AsyncMock()
manager.create_exp = mocker.MagicMock()
return manager
@pytest.fixture
def mock_scorer(self, mocker):
scorer = mocker.MagicMock(spec=SimpleScorer)
scorer.evaluate = mocker.AsyncMock(return_value=Score())
return scorer
@pytest.fixture
def mock_perfect_judge(self, mocker):
return mocker.MagicMock(spec=SimplePerfectJudge)
@pytest.fixture
def mock_config(self, mocker):
return mocker.patch("metagpt.exp_pool.decorator.config")
@pytest.mark.asyncio
async def test_exp_cache_disabled(self, mock_config, mock_exp_manager):
mock_config.exp_pool.enable_read = False
@exp_cache(manager=mock_exp_manager)
async def test_func(req):
return "result"
result = await test_func(req="test")
assert result == "result"
mock_exp_manager.query_exps.assert_not_called()
@pytest.mark.asyncio
async def test_exp_cache_enabled_no_perfect_exp(self, mock_config, mock_exp_manager, mock_scorer):
mock_config.exp_pool.enable_read = True
mock_exp_manager.query_exps.return_value = []
@exp_cache(manager=mock_exp_manager, scorer=mock_scorer)
async def test_func(req):
return "computed_result"
result = await test_func(req="test")
assert result == "computed_result"
mock_exp_manager.query_exps.assert_called()
mock_exp_manager.create_exp.assert_called()
@pytest.mark.asyncio
async def test_exp_cache_enabled_with_perfect_exp(self, mock_config, mock_exp_manager, mock_perfect_judge):
mock_config.exp_pool.enable_read = True
perfect_exp = Experience(req="test", resp="perfect_result")
mock_exp_manager.query_exps.return_value = [perfect_exp]
mock_perfect_judge.is_perfect_exp.return_value = True
@exp_cache(manager=mock_exp_manager, perfect_judge=mock_perfect_judge)
async def test_func(req):
return "should_not_be_called"
result = await test_func(req="test")
assert result == "perfect_result"
mock_exp_manager.query_exps.assert_called_once()
mock_exp_manager.create_exp.assert_not_called()

View file

@ -4,20 +4,25 @@ from metagpt.config2 import Config
from metagpt.configs.exp_pool_config import ExperiencePoolConfig
from metagpt.configs.llm_config import LLMConfig
from metagpt.exp_pool.manager import ExperienceManager
from metagpt.exp_pool.schema import MAX_SCORE, Experience, Metric, Score
from metagpt.exp_pool.schema import Experience
from metagpt.rag.engines import SimpleEngine
class TestExperienceManager:
@pytest.fixture
def mock_config(self):
return Config(llm=LLMConfig(), exp_pool=ExperiencePoolConfig(enable_write=True, enable_read=True))
return Config(
llm=LLMConfig(), exp_pool=ExperiencePoolConfig(enable_write=True, enable_read=True, init_exp=False)
)
@pytest.fixture
def mock_storage(self, mocker):
engine = mocker.MagicMock(spec=SimpleEngine)
engine.add_objs = mocker.MagicMock()
engine.aretrieve = mocker.AsyncMock(return_value=[])
engine._retriever = mocker.MagicMock()
engine._retriever._vector_store = mocker.MagicMock()
engine._retriever._vector_store._get = mocker.MagicMock(return_value=mocker.MagicMock(ids=[]))
return engine
@pytest.fixture
@ -33,7 +38,7 @@ class TestExperienceManager:
def test_create_exp(self, mock_experience_manager, mock_experience):
mock_experience_manager.create_exp(mock_experience)
mock_experience_manager.storage.add_objs.assert_called_once_with([mock_experience])
mock_experience_manager.storage.add_objs.assert_called_with([mock_experience])
def test_create_exp_write_disabled(self, mock_experience_manager, mock_experience, mock_config):
mock_config.exp_pool.enable_write = False
@ -60,18 +65,44 @@ class TestExperienceManager:
result = await mock_experience_manager.query_exps("query")
assert result == []
def test_extract_one_perfect_exp(self, mock_experience_manager):
experiences = [
Experience(req="req", resp="resp", metric=Metric(score=Score(val=MAX_SCORE))),
Experience(req="req", resp="resp"),
]
perfect_exp: Experience = mock_experience_manager.extract_one_perfect_exp(experiences)
assert perfect_exp is not None
assert perfect_exp.metric.score.val == MAX_SCORE
def test_init_exp_pool(self, mock_experience_manager, mock_config, mocker):
mock_experience_manager._has_exps = mocker.MagicMock(return_value=False)
mock_experience_manager._init_teamleader_exps = mocker.MagicMock()
mock_experience_manager._init_engineer2_exps = mocker.MagicMock()
def test_is_perfect_exp(self):
exp = Experience(req="req", resp="resp", metric=Metric(score=Score(val=MAX_SCORE)))
assert ExperienceManager.is_perfect_exp(exp) == True
mock_config.exp_pool.init_exp = True
mock_experience_manager.init_exp_pool()
exp = Experience(req="req", resp="resp")
assert ExperienceManager.is_perfect_exp(exp) == False
mock_experience_manager._has_exps.assert_called_once()
mock_experience_manager._init_teamleader_exps.assert_called_once()
mock_experience_manager._init_engineer2_exps.assert_called_once()
def test_init_exp_pool_already_has_exps(self, mock_experience_manager, mock_config, mocker):
mock_experience_manager._has_exps = mocker.MagicMock(return_value=True)
mock_experience_manager._init_teamleader_exps = mocker.MagicMock()
mock_experience_manager._init_engineer2_exps = mocker.MagicMock()
mock_config.exp_pool.init_exp = True
mock_experience_manager.init_exp_pool()
mock_experience_manager._has_exps.assert_called_once()
mock_experience_manager._init_teamleader_exps.assert_not_called()
mock_experience_manager._init_engineer2_exps.assert_not_called()
def test_has_exps(self, mock_experience_manager, mock_storage):
mock_storage._retriever._vector_store._get.return_value.ids = ["id1"]
assert mock_experience_manager._has_exps() is True
mock_storage._retriever._vector_store._get.return_value.ids = []
assert mock_experience_manager._has_exps() is False
def test_init_teamleader_exps(self, mock_experience_manager, mocker):
mock_experience_manager._init_exp = mocker.MagicMock()
mock_experience_manager._init_teamleader_exps()
mock_experience_manager._init_exp.assert_called_once()
def test_init_engineer2_exps(self, mock_experience_manager, mocker):
mock_experience_manager._init_exp = mocker.MagicMock()
mock_experience_manager._init_engineer2_exps()
mock_experience_manager._init_exp.assert_called_once()

View file

@ -0,0 +1,40 @@
import pytest
from metagpt.exp_pool.perfect_judges import SimplePerfectJudge
from metagpt.exp_pool.schema import MAX_SCORE, Experience, Metric, Score
class TestSimplePerfectJudge:
@pytest.fixture
def simple_perfect_judge(self):
return SimplePerfectJudge()
@pytest.mark.asyncio
async def test_is_perfect_exp_perfect_match(self, simple_perfect_judge):
exp = Experience(req="test_request", resp="resp", metric=Metric(score=Score(val=MAX_SCORE)))
result = await simple_perfect_judge.is_perfect_exp(exp, "test_request")
assert result is True
@pytest.mark.asyncio
async def test_is_perfect_exp_imperfect_score(self, simple_perfect_judge):
exp = Experience(req="test_request", resp="resp", metric=Metric(score=Score(val=MAX_SCORE - 1)))
result = await simple_perfect_judge.is_perfect_exp(exp, "test_request")
assert result is False
@pytest.mark.asyncio
async def test_is_perfect_exp_mismatched_request(self, simple_perfect_judge):
exp = Experience(req="test_request", resp="resp", metric=Metric(score=Score(val=MAX_SCORE)))
result = await simple_perfect_judge.is_perfect_exp(exp, "different_request")
assert result is False
@pytest.mark.asyncio
async def test_is_perfect_exp_no_metric(self, simple_perfect_judge):
exp = Experience(req="test_request", resp="resp")
result = await simple_perfect_judge.is_perfect_exp(exp, "test_request")
assert result is False
@pytest.mark.asyncio
async def test_is_perfect_exp_no_score(self, simple_perfect_judge):
exp = Experience(req="test_request", resp="resp", metric=Metric())
result = await simple_perfect_judge.is_perfect_exp(exp, "test_request")
assert result is False

View file

@ -0,0 +1,49 @@
import pytest
from metagpt.exp_pool.schema import Score
from metagpt.exp_pool.scorers.simple import SIMPLE_SCORER_TEMPLATE, SimpleScorer
from metagpt.llm import BaseLLM
class TestSimpleScorer:
@pytest.fixture
def mock_llm(self, mocker):
mock_llm = mocker.MagicMock(spec=BaseLLM)
return mock_llm
@pytest.fixture
def simple_scorer(self, mock_llm):
return SimpleScorer(llm=mock_llm)
def test_init(self, mock_llm):
scorer = SimpleScorer(llm=mock_llm)
assert isinstance(scorer.llm, BaseLLM)
@pytest.mark.asyncio
async def test_evaluate(self, simple_scorer, mock_llm):
# Mock function to evaluate
def mock_func(a, b):
"""This is a mock function."""
return a + b
# Mock LLM response
mock_llm.aask.return_value = '```json\n{"val": 8, "reason": "Good performance"}\n```'
# Test evaluate method
result = await simple_scorer.evaluate(mock_func, 5, args=(2, 3), kwargs={})
# Assert LLM was called with correct prompt
expected_prompt = SIMPLE_SCORER_TEMPLATE.format(
func_name=mock_func.__name__,
func_doc=mock_func.__doc__,
func_signature="(a, b)",
func_args=(2, 3),
func_kwargs={},
func_result=5,
)
mock_llm.aask.assert_called_once_with(expected_prompt)
# Assert the result is correct
assert isinstance(result, Score)
assert result.val == 8
assert result.reason == "Good performance"