diff --git a/config/config2.example.yaml b/config/config2.example.yaml index c7b2cae2c..a3bd5c367 100644 --- a/config/config2.example.yaml +++ b/config/config2.example.yaml @@ -75,8 +75,10 @@ s3: bucket: "test" exp_pool: - enable_read: true - enable_write: true + enable_read: false + enable_write: false + persist_path: .chroma_exp_data # The directory. + init_exp: false # If set to true, basic experiences associated with the roles will be added to the experience pool. azure_tts_subscription_key: "YOUR_SUBSCRIPTION_KEY" azure_tts_region: "eastus" diff --git a/examples/exp_pool/decorator.py b/examples/exp_pool/decorator.py index 3f6093e01..00726a0a8 100644 --- a/examples/exp_pool/decorator.py +++ b/examples/exp_pool/decorator.py @@ -7,16 +7,15 @@ from metagpt.exp_pool import exp_cache, exp_manager from metagpt.logs import logger -@exp_cache(pass_exps_to_func=True) -async def produce(req, exps=None): - logger.info(f"Previous experiences: {exps}") +@exp_cache() +async def produce(req=""): return f"{req} {uuid.uuid4().hex}" async def main(): req = "Water" - resp = await produce(req) + resp = await produce(req=req) logger.info(f"The resp of `produce{req}` is: {resp}") exps = await exp_manager.query_exps(req) diff --git a/examples/write_novel.py b/examples/write_novel.py index a6e9ce05d..f49918fbb 100644 --- a/examples/write_novel.py +++ b/examples/write_novel.py @@ -50,9 +50,9 @@ async def generate_novel(): "Fill the empty nodes with your own ideas. Be creative! Use your own words!" "I will tip you $100,000 if you write a good novel." ) - novel_node = await ActionNode.from_pydantic(Novel).fill(context=instruction, llm=LLM()) + novel_node = await ActionNode.from_pydantic(Novel).fill(req=instruction, llm=LLM()) chap_node = await ActionNode.from_pydantic(Chapters).fill( - context=f"### instruction\n{instruction}\n### novel\n{novel_node.content}", llm=LLM() + req=f"### instruction\n{instruction}\n### novel\n{novel_node.content}", llm=LLM() ) print(chap_node.instruct_content) diff --git a/metagpt/actions/action.py b/metagpt/actions/action.py index b760c96d8..8733947f5 100644 --- a/metagpt/actions/action.py +++ b/metagpt/actions/action.py @@ -90,7 +90,7 @@ class Action(SerializationMixin, ContextMixin, BaseModel): msgs = args[0] context = "## History Messages\n" context += "\n".join([f"{idx}: {i}" for idx, i in enumerate(reversed(msgs))]) - return await self.node.fill(context=context, llm=self.llm) + return await self.node.fill(req=context, llm=self.llm) async def run(self, *args, **kwargs): """Run action""" diff --git a/metagpt/actions/action_node.py b/metagpt/actions/action_node.py index 48372f790..e1e0bddbb 100644 --- a/metagpt/actions/action_node.py +++ b/metagpt/actions/action_node.py @@ -18,6 +18,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential from metagpt.actions.action_outcls_registry import register_action_outcls from metagpt.const import MARKDOWN_TITLE_PREFIX, USE_CONFIG_TIMEOUT +from metagpt.exp_pool import exp_cache from metagpt.llm import BaseLLM from metagpt.logs import logger from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess @@ -465,9 +466,33 @@ class ActionNode: return self + @classmethod + def deserialize_to_action_node(cls, serialized_data) -> "ActionNode": + """Customized deserialization, it will be triggered when a perfect experience is found. + + ActionNode cannot be serialized, it throws an error 'cannot pickle 'SSLContext' object'. + """ + + class InstructContent: + def __init__(self, json_data): + self.json_data = json_data + + def model_dump_json(self): + return self.json_data + + action_node = cls(key="", expected_type=Type[str], instruction="", example="") + action_node.instruct_content = InstructContent(serialized_data) + + return action_node + + @exp_cache( + resp_serialize=lambda action_node: action_node.instruct_content.model_dump_json(), + resp_deserialize=lambda resp: ActionNode.deserialize_to_action_node(resp), + ) async def fill( self, - context, + *, + req, llm, schema="json", mode="auto", @@ -478,7 +503,7 @@ class ActionNode: ): """Fill the node(s) with mode. - :param context: Everything we should know when filling node. + :param req: Everything we should know when filling node. :param llm: Large Language Model with pre-defined system message. :param schema: json/markdown, determine example and output format. - raw: free form text @@ -497,7 +522,7 @@ class ActionNode: :return: self """ self.set_llm(llm) - self.set_context(context) + self.set_context(req) if self.schema: schema = self.schema diff --git a/metagpt/actions/design_api.py b/metagpt/actions/design_api.py index cc88171ff..1bfad20a2 100644 --- a/metagpt/actions/design_api.py +++ b/metagpt/actions/design_api.py @@ -178,12 +178,12 @@ class WriteDesign(Action): ) async def _new_system_design(self, context): - node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=self.prompt_schema) + node = await DESIGN_API_NODE.fill(req=context, llm=self.llm, schema=self.prompt_schema) return node async def _merge(self, prd_doc, system_design_doc): context = NEW_REQ_TEMPLATE.format(old_design=system_design_doc.content, context=prd_doc.content) - node = await REFINED_DESIGN_NODE.fill(context=context, llm=self.llm, schema=self.prompt_schema) + node = await REFINED_DESIGN_NODE.fill(req=context, llm=self.llm, schema=self.prompt_schema) system_design_doc.content = node.instruct_content.model_dump_json() return system_design_doc diff --git a/metagpt/actions/generate_questions.py b/metagpt/actions/generate_questions.py index c96a37649..bf0ba6277 100644 --- a/metagpt/actions/generate_questions.py +++ b/metagpt/actions/generate_questions.py @@ -22,4 +22,4 @@ class GenerateQuestions(Action): name: str = "GenerateQuestions" async def run(self, context) -> ActionNode: - return await QUESTIONS.fill(context=context, llm=self.llm) + return await QUESTIONS.fill(req=context, llm=self.llm) diff --git a/metagpt/actions/prepare_interview.py b/metagpt/actions/prepare_interview.py index 04cc954d2..0a7eb6581 100644 --- a/metagpt/actions/prepare_interview.py +++ b/metagpt/actions/prepare_interview.py @@ -22,4 +22,4 @@ class PrepareInterview(Action): name: str = "PrepareInterview" async def run(self, context): - return await QUESTIONS.fill(context=context, llm=self.llm) + return await QUESTIONS.fill(req=context, llm=self.llm) diff --git a/metagpt/actions/project_management.py b/metagpt/actions/project_management.py index a39840bf1..ca2df2da9 100644 --- a/metagpt/actions/project_management.py +++ b/metagpt/actions/project_management.py @@ -151,12 +151,12 @@ class WriteTasks(Action): return task_doc async def _run_new_tasks(self, context: str): - node = await PM_NODE.fill(context, self.llm, schema=self.prompt_schema) + node = await PM_NODE.fill(req=context, llm=self.llm, schema=self.prompt_schema) return node async def _merge(self, system_design_doc, task_doc) -> Document: context = NEW_REQ_TEMPLATE.format(context=system_design_doc.content, old_task=task_doc.content) - node = await REFINED_PM_NODE.fill(context, self.llm, schema=self.prompt_schema) + node = await REFINED_PM_NODE.fill(req=context, llm=self.llm, schema=self.prompt_schema) task_doc.content = node.instruct_content.model_dump_json() return task_doc diff --git a/metagpt/actions/write_code_an_draft.py b/metagpt/actions/write_code_an_draft.py index ce030b0e9..4c3fd4c19 100644 --- a/metagpt/actions/write_code_an_draft.py +++ b/metagpt/actions/write_code_an_draft.py @@ -578,7 +578,7 @@ class WriteCodeAN(Action): async def run(self, context): self.llm.system_prompt = "You are an outstanding engineer and can implement any code" - return await WRITE_MOVE_NODE.fill(context=context, llm=self.llm, schema="json") + return await WRITE_MOVE_NODE.fill(req=context, llm=self.llm, schema="json") async def main(): diff --git a/metagpt/actions/write_code_plan_and_change_an.py b/metagpt/actions/write_code_plan_and_change_an.py index 31482a94d..989df52f2 100644 --- a/metagpt/actions/write_code_plan_and_change_an.py +++ b/metagpt/actions/write_code_plan_and_change_an.py @@ -229,7 +229,7 @@ class WriteCodePlanAndChange(Action): code=await self.get_old_codes(), ) logger.info("Writing code plan and change..") - return await WRITE_CODE_PLAN_AND_CHANGE_NODE.fill(context=context, llm=self.llm, schema="json") + return await WRITE_CODE_PLAN_AND_CHANGE_NODE.fill(req=context, llm=self.llm, schema="json") async def get_old_codes(self) -> str: old_codes = await self.repo.srcs.get_all() diff --git a/metagpt/actions/write_prd.py b/metagpt/actions/write_prd.py index 7199ec415..810823a24 100644 --- a/metagpt/actions/write_prd.py +++ b/metagpt/actions/write_prd.py @@ -211,7 +211,7 @@ class WritePRD(Action): context = CONTEXT_TEMPLATE.format(requirements=requirement, project_name=project_name) exclude = [PROJECT_NAME.key] if project_name else [] node = await WRITE_PRD_NODE.fill( - context=context, llm=self.llm, exclude=exclude, schema=self.prompt_schema + req=context, llm=self.llm, exclude=exclude, schema=self.prompt_schema ) # schema=schema return node @@ -238,7 +238,7 @@ class WritePRD(Action): async def _is_bugfix(self, context: str) -> bool: if not self.repo.code_files_exists(): return False - node = await WP_ISSUE_TYPE_NODE.fill(context, self.llm) + node = await WP_ISSUE_TYPE_NODE.fill(req=context, llm=self.llm) return node.get("issue_type") == "BUG" async def get_related_docs(self, req: Document, docs: list[Document]) -> list[Document]: @@ -248,14 +248,14 @@ class WritePRD(Action): async def _is_related(self, req: Document, old_prd: Document) -> bool: context = NEW_REQ_TEMPLATE.format(old_prd=old_prd.content, requirements=req.content) - node = await WP_IS_RELATIVE_NODE.fill(context, self.llm) + node = await WP_IS_RELATIVE_NODE.fill(req=context, llm=self.llm) return node.get("is_relative") == "YES" async def _merge(self, req: Document, related_doc: Document) -> Document: if not self.project_name: self.project_name = Path(self.project_path).name prompt = NEW_REQ_TEMPLATE.format(requirements=req.content, old_prd=related_doc.content) - node = await REFINED_PRD_NODE.fill(context=prompt, llm=self.llm, schema=self.prompt_schema) + node = await REFINED_PRD_NODE.fill(req=prompt, llm=self.llm, schema=self.prompt_schema) related_doc.content = node.instruct_content.model_dump_json() await self._rename_workspace(node) return related_doc diff --git a/metagpt/actions/write_review.py b/metagpt/actions/write_review.py index db8512946..907a1e990 100644 --- a/metagpt/actions/write_review.py +++ b/metagpt/actions/write_review.py @@ -36,4 +36,4 @@ class WriteReview(Action): name: str = "WriteReview" async def run(self, context): - return await WRITE_REVIEW_NODE.fill(context=context, llm=self.llm, schema="json") + return await WRITE_REVIEW_NODE.fill(req=context, llm=self.llm, schema="json") diff --git a/metagpt/configs/exp_pool_config.py b/metagpt/configs/exp_pool_config.py index 3f86173c1..0c92312da 100644 --- a/metagpt/configs/exp_pool_config.py +++ b/metagpt/configs/exp_pool_config.py @@ -4,5 +4,9 @@ from metagpt.utils.yaml_model import YamlModel class ExperiencePoolConfig(YamlModel): - enable_read: bool = Field(default=True, description="Enable to read from experience pool.") - enable_write: bool = Field(default=True, description="Enable to write to experience pool.") + enable_read: bool = Field(default=False, description="Enable to read from experience pool.") + enable_write: bool = Field(default=False, description="Enable to write to experience pool.") + persist_path: str = Field(default=".chroma_exp_data", description="The persist path for experience pool.") + init_exp: bool = Field( + default=False, description="Put some basic experiences associated with the roles into the experience pool." + ) diff --git a/metagpt/exp_pool/context_builders/__init__.py b/metagpt/exp_pool/context_builders/__init__.py new file mode 100644 index 000000000..047558be0 --- /dev/null +++ b/metagpt/exp_pool/context_builders/__init__.py @@ -0,0 +1,7 @@ +"""Context builders init.""" + +from metagpt.exp_pool.context_builders.base import BaseContextBuilder +from metagpt.exp_pool.context_builders.simple import SimpleContextBuilder +from metagpt.exp_pool.context_builders.role_zero import RoleZeroContextBuilder + +__all__ = ["BaseContextBuilder", "SimpleContextBuilder", "RoleZeroContextBuilder"] diff --git a/metagpt/exp_pool/context_builders/base.py b/metagpt/exp_pool/context_builders/base.py new file mode 100644 index 000000000..e3fe320a6 --- /dev/null +++ b/metagpt/exp_pool/context_builders/base.py @@ -0,0 +1,52 @@ +"""Base context builder.""" + +import re +from abc import ABC, abstractmethod +from typing import Any + +from pydantic import BaseModel, ConfigDict + +from metagpt.exp_pool.schema import Experience + +EXP_TEMPLATE = """Given the request: {req}, We can get the response: {resp}, Which scored: {score}.""" + + +class BaseContextBuilder(BaseModel, ABC): + model_config = ConfigDict(arbitrary_types_allowed=True) + + exps: list[Experience] = [] + + @abstractmethod + async def build(self, *args, **kwargs) -> Any: + """Build context from parameters.""" + + def format_exps(self) -> str: + """Format experiences into a numbered list of strings.""" + + result = [] + for i, exp in enumerate(self.exps, start=1): + result.append(f"{i}. " + EXP_TEMPLATE.format(req=exp.req, resp=exp.resp, score=exp.metric.score.val)) + + return "\n".join(result) + + @staticmethod + def replace_content_between_markers(text: str, start_marker: str, end_marker: str, new_content: str) -> str: + """Replace the content between `start_marker` and `end_marker` in the text with `new_content`. + + Args: + text (str): The original text. + new_content (str): The new content to replace the old content. + start_marker (str): The marker indicating the start of the content to be replaced, such as '# Example'. + end_marker (str): The marker indicating the end of the content to be replaced, such as '# Instruction'. + + Returns: + str: The text with the content replaced. + """ + + pattern = re.compile(f"({start_marker}\n)(.*?)(\n{end_marker})", re.DOTALL) + + def replacement(match): + return f"{match.group(1)}{new_content}\n{match.group(3)}" + + replaced_text = pattern.sub(replacement, text) + return replaced_text diff --git a/metagpt/exp_pool/context_builders/role_zero.py b/metagpt/exp_pool/context_builders/role_zero.py new file mode 100644 index 000000000..60f71ef59 --- /dev/null +++ b/metagpt/exp_pool/context_builders/role_zero.py @@ -0,0 +1,26 @@ +"""RoleZero context builder.""" + +from metagpt.exp_pool.context_builders.base import BaseContextBuilder + + +class RoleZeroContextBuilder(BaseContextBuilder): + async def build(self, *args, **kwargs) -> list[dict]: + """Builds the context by updating the req with formatted experiences. + + If there are no experiences, retains the original examples in req, otherwise replaces the examples with the formatted experiences. + """ + + req = kwargs.get("req", []) + if not req: + return req + + exps_str = self.format_exps() + if not exps_str: + return req + + req[-1]["content"] = self.replace_example_content(req[-1].get("content", ""), exps_str) + + return req + + def replace_example_content(self, text: str, new_example_content: str) -> str: + return self.replace_content_between_markers(text, "# Example", "# Instruction", new_example_content) diff --git a/metagpt/exp_pool/context_builders/simple.py b/metagpt/exp_pool/context_builders/simple.py new file mode 100644 index 000000000..35e2e1c8a --- /dev/null +++ b/metagpt/exp_pool/context_builders/simple.py @@ -0,0 +1,24 @@ +"""Simple context builder.""" + + +from metagpt.exp_pool.context_builders.base import BaseContextBuilder + +SIMPLE_CONTEXT_TEMPLATE = """ +{req} + +### Experiences +----- +{exps} +----- + +## Instruction +Consider **Experiences** to generate a better answer. +""" + + +class SimpleContextBuilder(BaseContextBuilder): + async def build(self, *args, **kwargs) -> str: + req = kwargs.get("req", "") + exps = self.format_exps() + + return SIMPLE_CONTEXT_TEMPLATE.format(req=req, exps=exps) if exps else req diff --git a/metagpt/exp_pool/decorator.py b/metagpt/exp_pool/decorator.py index 2a3bf2fba..c518bb7ea 100644 --- a/metagpt/exp_pool/decorator.py +++ b/metagpt/exp_pool/decorator.py @@ -2,18 +2,19 @@ import asyncio import functools -import inspect -import json from typing import Any, Callable, Optional, TypeVar from pydantic import BaseModel, ConfigDict, model_validator +from metagpt.config2 import config +from metagpt.exp_pool.context_builders import BaseContextBuilder, SimpleContextBuilder from metagpt.exp_pool.manager import ExperienceManager, exp_manager +from metagpt.exp_pool.perfect_judges import BasePerfectJudge, SimplePerfectJudge from metagpt.exp_pool.schema import Experience, Metric, QueryType, Score -from metagpt.exp_pool.scorers import ExperienceScorer, SimpleScorer +from metagpt.exp_pool.scorers import BaseScorer, SimpleScorer +from metagpt.logs import logger from metagpt.utils.async_helper import NestAsyncio from metagpt.utils.exceptions import handle_exception -from metagpt.utils.reflection import get_class_name ReturnType = TypeVar("ReturnType") @@ -21,42 +22,64 @@ ReturnType = TypeVar("ReturnType") def exp_cache( _func: Optional[Callable[..., ReturnType]] = None, query_type: QueryType = QueryType.SEMANTIC, - scorer: Optional[ExperienceScorer] = None, manager: Optional[ExperienceManager] = None, - pass_exps_to_func: bool = False, + scorer: Optional[BaseScorer] = None, + perfect_judge: Optional[BasePerfectJudge] = None, + context_builder: Optional[BaseContextBuilder] = None, + req_serialize: Optional[Callable[..., str]] = None, + resp_serialize: Optional[Callable[..., str]] = None, + resp_deserialize: Optional[Callable[[str], Any]] = None, + tag: Optional[str] = None, ): """Decorator to get a perfect experience, otherwise, it executes the function, and create a new experience. - This can be applied to both synchronous and asynchronous functions. + 1. This can be applied to both synchronous and asynchronous functions. + 2. The function must have a `req` parameter, and it must be provided as a keyword argument. + 3. If `config.exp_pool.enable_read` is False, the decorator will just directly execute the function. Args: _func: Just to make the decorator more flexible, for example, it can be used directly with @exp_cache by default, without the need for @exp_cache(). query_type: The type of query to be used when fetching experiences. - scorer: Evaluate experience. Default SimpleScorer. - manager: How to fetch, evaluate and save experience, etc. Default exp_manager. - pass_exps_to_func: To control whether imperfect experiences are passed to the function, if True, the func must have a parameter named 'exps'. + manager: How to fetch, evaluate and save experience, etc. Default to `exp_manager`. + scorer: Evaluate experience. Default to `SimpleScorer()`. + perfect_judge: Determines if an experience is perfect. Defaults to `SimplePerfectJudge()`. + context_builder: Build the context from exps and the function parameters. Default to `SimpleContextBuilder()`. + req_serialize: Serializes the request for storage. Defaults to `lambda req: str(req)`. + resp_serialize: Serializes the function's return value for storage. Defaults to `lambda resp: str(resp)`. + resp_deserialize: Deserializes the stored response back to the function's return value. Defaults to `lambda resp: resp`. + tag: An optional tag for the experience. Default to `ClassName.method_name` or `function_name`. """ def decorator(func: Callable[..., ReturnType]) -> Callable[..., ReturnType]: + if not config.exp_pool.enable_read: + return func + @functools.wraps(func) async def get_or_create(args: Any, kwargs: Any) -> ReturnType: + logger.info("exp_cache is enabled.") handler = ExpCacheHandler( func=func, args=args, kwargs=kwargs, + query_type=query_type, exp_manager=manager, exp_scorer=scorer, - pass_exps_to_func=pass_exps_to_func, + exp_perfect_judge=perfect_judge, + context_builder=context_builder, + req_serialize=req_serialize, + resp_serialize=resp_serialize, + resp_deserialize=resp_deserialize, + tag=tag, ) - await handler.fetch_experiences(query_type) - if exp := handler.get_one_perfect_experience(): + await handler.fetch_experiences() + if exp := await handler.get_one_perfect_exp(): return exp await handler.execute_function() await handler.process_experience() - return handler._result + return handler._raw_resp return ExpCacheHandler.choose_wrapper(func, get_or_create) @@ -69,39 +92,59 @@ class ExpCacheHandler(BaseModel): func: Callable args: Any kwargs: Any + query_type: QueryType = QueryType.SEMANTIC exp_manager: Optional[ExperienceManager] = None - exp_scorer: Optional[ExperienceScorer] = None - pass_exps_to_func: bool = False + exp_scorer: Optional[BaseScorer] = None + exp_perfect_judge: Optional[BasePerfectJudge] = None + context_builder: Optional[BaseContextBuilder] = None + req_serialize: Optional[Callable[..., str]] = None + resp_serialize: Optional[Callable[..., str]] = None + resp_deserialize: Optional[Callable[[str], Any]] = None + tag: Optional[str] = None _exps: list[Experience] = None - _result: Any = None + _req: str = "" + _resp: str = "" + _raw_resp: Any = None _score: Score = None - _req: str = None @model_validator(mode="after") def initialize(self): - if self.exp_manager is None: - self.exp_manager = exp_manager + self._validate_params() - if self.exp_scorer is None: - self.exp_scorer = SimpleScorer() + self.exp_manager = self.exp_manager or exp_manager + self.exp_scorer = self.exp_scorer or SimpleScorer() + self.exp_perfect_judge = self.exp_perfect_judge or SimplePerfectJudge() + self.context_builder = self.context_builder or SimpleContextBuilder() + self.req_serialize = self.req_serialize or (lambda resp: str(resp)) + self.resp_serialize = self.resp_serialize or (lambda resp: str(resp)) + self.resp_deserialize = self.resp_deserialize or (lambda resp: resp) + self.tag = self.tag or self._generate_tag() - self._req = self.generate_req_identifier(self.func, *self.args, **self.kwargs) + self._req = self.req_serialize(self.kwargs["req"]) return self - async def fetch_experiences(self, query_type: QueryType): + async def fetch_experiences(self): """Fetch experiences by query_type.""" - self._exps = await self.exp_manager.query_exps(self._req, query_type=query_type) + self._exps = await self.exp_manager.query_exps(self._req, query_type=self.query_type, tag=self.tag) - def get_one_perfect_experience(self) -> Optional[Experience]: - """Get a potentially perfect experience.""" - return self.exp_manager.extract_one_perfect_exp(self._exps) + async def get_one_perfect_exp(self) -> Optional[Any]: + """Get a potentially perfect experience, and resolve resp.""" + + for exp in self._exps: + if await self.exp_perfect_judge.is_perfect_exp(exp, self._req, *self.args, **self.kwargs): + logger.info(f"Get one perfect experience: {exp.req[:20]}...") + return self.resp_deserialize(exp.resp) + + return None async def execute_function(self): - """Execute the function, and save the result.""" - self._result = await self._execute_function() + """Execute the function, and save resp.""" + + self._raw_resp = await self._execute_function() + self._resp = self.resp_serialize(self._raw_resp) @handle_exception async def process_experience(self): @@ -110,41 +153,21 @@ class ExpCacheHandler(BaseModel): Evaluates and saves experience. Use `handle_exception` to ensure robustness, do not stop subsequent operations. """ + await self.evaluate_experience() self.save_experience() async def evaluate_experience(self): """Evaluate the experience, and save the score.""" - self._score = await self.exp_scorer.evaluate(self.func, self._result, self.args, self.kwargs) + self._score = await self.exp_scorer.evaluate(self.func, self._resp, self.args, self.kwargs) def save_experience(self): """Save the new experience.""" - exp = Experience(req=self._req, resp=self._result, metric=Metric(score=self._score)) - + exp = Experience(req=self._req, resp=self._resp, tag=self.tag, metric=Metric(score=self._score)) self.exp_manager.create_exp(exp) - @classmethod - def generate_req_identifier(cls, func, *args, **kwargs) -> str: - """Generate a unique request identifier for any given function and its arguments. - - Serializing args and kwargs into JSON strings and replacing ',' with '~' and ':' with '!'. - - Return Example: - SimpleClass.test_method@[1~2]@{"c"!3} - """ - cls_name = get_class_name(func) - func_name = f"{cls_name}.{func.__name__}" if cls_name else func.__name__ - - if cls_name and args and inspect.isfunction(func): - args = args[1:] - - args = cls._serialize_and_replace(args) - kwargs = cls._serialize_and_replace(kwargs) - - return f"{func_name}@{args}@{kwargs}" - @staticmethod def choose_wrapper(func, wrapped_func): """Choose how to run wrapped_func based on whether the function is asynchronous.""" @@ -158,25 +181,31 @@ class ExpCacheHandler(BaseModel): return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper - @classmethod - def _serialize_and_replace(cls, data): - json_str = json.dumps(data) - return json_str.replace(", ", "~").replace(": ", "!") + def _validate_params(self): + if "req" not in self.kwargs: + raise ValueError("`req` must be provided as a keyword argument.") + + def _generate_tag(self) -> str: + """Generates a tag for the self.func. + + "ClassName.method_name" if the first argument is a class instance, otherwise just "function_name". + """ + + if self.args and hasattr(self.args[0], "__class__"): + cls_name = type(self.args[0]).__name__ + return f"{cls_name}.{self.func.__name__}" + + return self.func.__name__ + + async def _build_context(self) -> str: + self.context_builder.exps = self._exps + + return await self.context_builder.build(*self.args, **self.kwargs) async def _execute_function(self): - if self.pass_exps_to_func: - return await self._execute_function_with_exps() + self.kwargs["req"] = await self._build_context() - return await self._execute_function_without_exps() - - async def _execute_function_without_exps(self): if asyncio.iscoroutinefunction(self.func): return await self.func(*self.args, **self.kwargs) return self.func(*self.args, **self.kwargs) - - async def _execute_function_with_exps(self): - if asyncio.iscoroutinefunction(self.func): - return await self.func(*self.args, **self.kwargs, exps=self._exps) - - return self.func(*self.args, **self.kwargs, exps=self._exps) diff --git a/metagpt/exp_pool/manager.py b/metagpt/exp_pool/manager.py index 35ee5fdac..276b1e8e3 100644 --- a/metagpt/exp_pool/manager.py +++ b/metagpt/exp_pool/manager.py @@ -1,13 +1,22 @@ """Experience Manager.""" -from typing import Optional - +from llama_index.vector_stores.chroma import ChromaVectorStore from pydantic import BaseModel, ConfigDict, model_validator from metagpt.config2 import Config, config -from metagpt.exp_pool.schema import MAX_SCORE, Experience, QueryType +from metagpt.exp_pool.schema import ( + DEFAULT_COLLECTION_NAME, + DEFAULT_SIMILARITY_TOP_K, + EntryType, + Experience, + Metric, + QueryType, + Score, +) +from metagpt.logs import logger from metagpt.rag.engines import SimpleEngine from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig +from metagpt.strategy.experience_retriever import ENGINEER_EXAMPLE, TL_EXAMPLE from metagpt.utils.exceptions import handle_exception @@ -27,14 +36,33 @@ class ExperienceManager(BaseModel): @model_validator(mode="after") def initialize(self): if self.storage is None: - self.storage = SimpleEngine.from_objs( - retriever_configs=[ - ChromaRetrieverConfig(collection_name="experience_pool", persist_path=".chroma_exp_data") - ], - ranker_configs=[LLMRankerConfig()], - ) + retriever_configs = [ + ChromaRetrieverConfig( + persist_path=self.config.exp_pool.persist_path, + collection_name=DEFAULT_COLLECTION_NAME, + similarity_top_k=DEFAULT_SIMILARITY_TOP_K, + ) + ] + ranker_configs = [LLMRankerConfig()] + + self.storage = SimpleEngine.from_objs(retriever_configs=retriever_configs, ranker_configs=ranker_configs) + + self.init_exp_pool() + return self + @handle_exception + def init_exp_pool(self): + if not self.config.exp_pool.init_exp: + return + + if self._has_exps(): + return + + self._init_teamleader_exps() + self._init_engineer2_exps() + logger.info("`init_exp_pool` done.") + @handle_exception def create_exp(self, exp: Experience): """Adds an experience to the storage if writing is enabled. @@ -74,39 +102,26 @@ class ExperienceManager(BaseModel): return exps - def extract_one_perfect_exp(self, exps: list[Experience]) -> Optional[Experience]: - """Extracts the first 'perfect' experience from a list of experiences. + def _has_exps(self) -> bool: + vector_store: ChromaVectorStore = self.storage._retriever._vector_store - Args: - exps (list[Experience]): The experiences to evaluate. + return bool(vector_store._get(limit=1, where={}).ids) - Returns: - Optional[Experience]: The first perfect experience if found, otherwise None. - """ - for exp in exps: - if self.is_perfect_exp(exp): - return exp + def _init_exp(self, req: str, resp: str, tag: str, metric: Metric = None): + exp = Experience( + req=req, + resp=resp, + entry_type=EntryType.MANUAL, + tag=tag, + metric=metric or Metric(score=Score(val=9, reason="Manual")), + ) + self.create_exp(exp) - return None + def _init_teamleader_exps(self): + self._init_exp(req=TL_EXAMPLE, resp=TL_EXAMPLE, tag="TeamLeader.llm_cached_aask") - @staticmethod - def is_perfect_exp(exp: Experience) -> bool: - """Determines if an experience is considered 'perfect'. - - Args: - exp (Experience): The experience to evaluate. - - Returns: - bool: True if the experience is manually entered, otherwise False. - """ - if not exp: - return False - - # TODO: need more metrics - if exp.metric and exp.metric.score.val == MAX_SCORE: - return True - - return False + def _init_engineer2_exps(self): + self._init_exp(req=ENGINEER_EXAMPLE, resp=ENGINEER_EXAMPLE, tag="Engineer2.llm_cached_aask") exp_manager = ExperienceManager() diff --git a/metagpt/exp_pool/perfect_judges/__init__.py b/metagpt/exp_pool/perfect_judges/__init__.py new file mode 100644 index 000000000..d8796c7c8 --- /dev/null +++ b/metagpt/exp_pool/perfect_judges/__init__.py @@ -0,0 +1,6 @@ +"""Perfect judges init.""" + +from metagpt.exp_pool.perfect_judges.base import BasePerfectJudge +from metagpt.exp_pool.perfect_judges.simple import SimplePerfectJudge + +__all__ = ["BasePerfectJudge", "SimplePerfectJudge"] diff --git a/metagpt/exp_pool/perfect_judges/base.py b/metagpt/exp_pool/perfect_judges/base.py new file mode 100644 index 000000000..293522993 --- /dev/null +++ b/metagpt/exp_pool/perfect_judges/base.py @@ -0,0 +1,20 @@ +"""Base perfect judge.""" + +from abc import ABC, abstractmethod + +from pydantic import BaseModel, ConfigDict + +from metagpt.exp_pool.schema import Experience + + +class BasePerfectJudge(BaseModel, ABC): + model_config = ConfigDict(arbitrary_types_allowed=True) + + @abstractmethod + async def is_perfect_exp(self, exp: Experience, serialized_req: str, *args, **kwargs) -> bool: + """Determine whether the experience is perfect. + + Args: + exp (Experience): The experience to evaluate. + serialized_req (str): The serialized request to compare against the experience's request. + """ diff --git a/metagpt/exp_pool/perfect_judges/simple.py b/metagpt/exp_pool/perfect_judges/simple.py new file mode 100644 index 000000000..37ede95c3 --- /dev/null +++ b/metagpt/exp_pool/perfect_judges/simple.py @@ -0,0 +1,27 @@ +"""Simple perfect judge.""" + + +from pydantic import ConfigDict + +from metagpt.exp_pool.perfect_judges.base import BasePerfectJudge +from metagpt.exp_pool.schema import MAX_SCORE, Experience + + +class SimplePerfectJudge(BasePerfectJudge): + model_config = ConfigDict(arbitrary_types_allowed=True) + + async def is_perfect_exp(self, exp: Experience, serialized_req: str, *args, **kwargs) -> bool: + """Determine whether the experience is perfect. + + Args: + exp (Experience): The experience to evaluate. + serialized_req (str): The serialized request to compare against the experience's request. + + Returns: + bool: True if the serialized request matches the experience's request and the experience's score is perfect, False otherwise. + """ + + if not exp.metric or not exp.metric.score: + return False + + return serialized_req == exp.req and exp.metric.score.val == MAX_SCORE diff --git a/metagpt/exp_pool/schema.py b/metagpt/exp_pool/schema.py index 9fc665cca..d59478742 100644 --- a/metagpt/exp_pool/schema.py +++ b/metagpt/exp_pool/schema.py @@ -1,13 +1,16 @@ """Experience schema.""" from enum import Enum -from typing import Any, Optional +from typing import Optional from llama_index.core.schema import TextNode from pydantic import BaseModel, Field MAX_SCORE = 10 +DEFAULT_COLLECTION_NAME = "experience_pool" +DEFAULT_SIMILARITY_TOP_K = 2 + class QueryType(str, Enum): """Type of query experiences.""" @@ -59,7 +62,7 @@ class Experience(BaseModel): """Experience.""" req: str = Field(..., description="") - resp: Any = Field(..., description="The type is string/json/code.") + resp: str = Field(..., description="The type is string/json/code.") metric: Optional[Metric] = Field(default=None, description="Metric.") exp_type: ExperienceType = Field(default=ExperienceType.SUCCESS, description="The type of experience.") entry_type: EntryType = Field(default=EntryType.AUTOMATIC, description="Type of entry: Manual or Automatic.") diff --git a/metagpt/exp_pool/scorers/__init__.py b/metagpt/exp_pool/scorers/__init__.py index 85bea88ff..caa845b14 100644 --- a/metagpt/exp_pool/scorers/__init__.py +++ b/metagpt/exp_pool/scorers/__init__.py @@ -1,6 +1,6 @@ -"""Experience scorers init.""" +"""Scorers init.""" -from metagpt.exp_pool.scorers.base import ExperienceScorer +from metagpt.exp_pool.scorers.base import BaseScorer from metagpt.exp_pool.scorers.simple import SimpleScorer -__all__ = ["ExperienceScorer", "SimpleScorer"] +__all__ = ["BaseScorer", "SimpleScorer"] diff --git a/metagpt/exp_pool/scorers/base.py b/metagpt/exp_pool/scorers/base.py index a9d30cffe..94623c30f 100644 --- a/metagpt/exp_pool/scorers/base.py +++ b/metagpt/exp_pool/scorers/base.py @@ -1,6 +1,6 @@ -"""Experience Scorers.""" +"""Base scorer.""" -from abc import abstractmethod +from abc import ABC, abstractmethod from typing import Any, Callable from pydantic import BaseModel, ConfigDict @@ -8,7 +8,7 @@ from pydantic import BaseModel, ConfigDict from metagpt.exp_pool.schema import Score -class ExperienceScorer(BaseModel): +class BaseScorer(BaseModel, ABC): model_config = ConfigDict(arbitrary_types_allowed=True) @abstractmethod diff --git a/metagpt/exp_pool/scorers/simple.py b/metagpt/exp_pool/scorers/simple.py index 84995b60f..1fda189d1 100644 --- a/metagpt/exp_pool/scorers/simple.py +++ b/metagpt/exp_pool/scorers/simple.py @@ -1,4 +1,4 @@ -"""Simple Scorer.""" +"""Simple scorer.""" import inspect import json @@ -7,7 +7,7 @@ from typing import Any, Callable from pydantic import Field from metagpt.exp_pool.schema import Score -from metagpt.exp_pool.scorers.base import ExperienceScorer +from metagpt.exp_pool.scorers.base import BaseScorer from metagpt.llm import LLM from metagpt.provider.base_llm import BaseLLM from metagpt.utils.common import CodeParser @@ -54,7 +54,7 @@ Follow instructions, generate output and make sure it follows the **Constraint** """ -class SimpleScorer(ExperienceScorer): +class SimpleScorer(BaseScorer): llm: BaseLLM = Field(default_factory=LLM) async def evaluate(self, func: Callable, result: Any, args: tuple = None, kwargs: dict = None) -> Score: diff --git a/metagpt/roles/di/role_zero.py b/metagpt/roles/di/role_zero.py index 906c5583c..e2a4cec78 100644 --- a/metagpt/roles/di/role_zero.py +++ b/metagpt/roles/di/role_zero.py @@ -1,5 +1,6 @@ from __future__ import annotations +import copy import inspect import json import re @@ -10,8 +11,14 @@ from pydantic import model_validator from metagpt.actions import Action from metagpt.actions.di.run_command import RunCommand +from metagpt.exp_pool import exp_cache +from metagpt.exp_pool.context_builders import RoleZeroContextBuilder from metagpt.logs import logger -from metagpt.prompts.di.role_zero import CMD_PROMPT, ROLE_INSTRUCTION, JSON_REPAIR_PROMPT +from metagpt.prompts.di.role_zero import ( + CMD_PROMPT, + JSON_REPAIR_PROMPT, + ROLE_INSTRUCTION, +) from metagpt.roles import Role from metagpt.schema import AIMessage, Message, UserMessage from metagpt.strategy.experience_retriever import DummyExpRetriever, ExpRetriever @@ -21,8 +28,8 @@ from metagpt.tools.libs.editor import Editor from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender from metagpt.tools.tool_registry import register_tool from metagpt.utils.common import CodeParser +from metagpt.utils.repair_llm_raw_output import RepairType, repair_llm_raw_output from metagpt.utils.report import ThoughtReporter -from metagpt.utils.repair_llm_raw_output import repair_llm_raw_output, RepairType @register_tool(include_functions=["ask_human", "reply_to_human"]) @@ -154,11 +161,37 @@ class RoleZero(Role): context = self.llm.format_msg(memory + [UserMessage(content=prompt)]) # print(*context, sep="\n" + "*" * 5 + "\n") async with ThoughtReporter(enable_llm_stream=True): - self.command_rsp = await self.llm.aask(context, system_msgs=self.system_msg) + self.command_rsp = await self.llm_cached_aask(req=context, system_msgs=self.system_msg) self.rc.memory.add(AIMessage(content=self.command_rsp)) return True + @exp_cache(context_builder=RoleZeroContextBuilder(), req_serialize=lambda req: RoleZero._req_serialize(req)) + async def llm_cached_aask(self, *, req: list[dict], system_msgs: list[str]) -> str: + return await self.llm.aask(req, system_msgs=system_msgs) + + @staticmethod + def _req_serialize(req: list[dict]) -> str: + """Serialize the request for database storage, ensuring it is a string. + + This function deep copies the request and modifies the content of the last element + to remove unnecessary sections, making the request more concise. + """ + + req_copy = copy.deepcopy(req) + + last_content = req_copy[-1]["content"] + last_content = RoleZeroContextBuilder.replace_content_between_markers( + last_content, "# Data Structure", "# Current Plan", "" + ) + last_content = RoleZeroContextBuilder.replace_content_between_markers( + last_content, "# Example", "# Instruction", "" + ) + + req_copy[-1]["content"] = last_content + + return json.dumps(req_copy) + async def _act(self) -> Message: if self.use_fixed_sop: return await super()._act() @@ -166,7 +199,7 @@ class RoleZero(Role): try: commands = CodeParser.parse_code(block=None, lang="json", text=self.command_rsp) commands = json.loads(repair_llm_raw_output(output=commands, req_keys=[None], repair_type=RepairType.JSON)) - except json.JSONDecodeError as e: + except json.JSONDecodeError: commands = await self.llm.aask(msg=JSON_REPAIR_PROMPT.format(json_data=self.command_rsp)) commands = json.loads(CodeParser.parse_code(block=None, lang="json", text=commands)) except Exception as e: diff --git a/metagpt/strategy/solver.py b/metagpt/strategy/solver.py index e532f736b..4aedb42aa 100644 --- a/metagpt/strategy/solver.py +++ b/metagpt/strategy/solver.py @@ -39,7 +39,7 @@ class NaiveSolver(BaseSolver): self.graph.topological_sort() for key in self.graph.execution_order: op = self.graph.nodes[key] - await op.fill(self.context, self.llm, mode="root") + await op.fill(req=self.context, llm=self.llm, mode="root") class TOTSolver(BaseSolver): diff --git a/metagpt/utils/reflection.py b/metagpt/utils/reflection.py index fe852635f..8b8237ae7 100644 --- a/metagpt/utils/reflection.py +++ b/metagpt/utils/reflection.py @@ -1,5 +1,4 @@ """class tools, including method inspection, class attributes, inheritance relationships, etc.""" -import inspect def check_methods(C, *methods): @@ -17,25 +16,3 @@ def check_methods(C, *methods): else: return NotImplemented return True - - -def get_class_name(func) -> str: - """Returns the class name of the object that a method belongs to. - - - If `func` is a bound method or a class method, extracts the class name directly from the method. - - Returns an empty string if it's a regular function or cannot determine the class. - """ - if inspect.ismethod(func): - if inspect.isclass(func.__self__): - return func.__self__.__name__ - - return func.__self__.__class__.__name__ - - if inspect.isfunction(func): - qualname_parts = func.__qualname__.split(".") - if len(qualname_parts) > 1: - class_name = qualname_parts[-2] - if class_name.isidentifier(): - return class_name - - return "" diff --git a/tests/metagpt/actions/test_action_node.py b/tests/metagpt/actions/test_action_node.py index bc85925a8..23779c984 100644 --- a/tests/metagpt/actions/test_action_node.py +++ b/tests/metagpt/actions/test_action_node.py @@ -91,10 +91,10 @@ async def test_action_node_two_layer(): assert node_b in root.children.values() # FIXME: ADD MARKDOWN SUPPORT. NEED TO TUNE MARKDOWN SYMBOL FIRST. - answer1 = await root.fill(context="what's the answer to 123+456?", schema="json", strgy="simple", llm=LLM()) + answer1 = await root.fill(req="what's the answer to 123+456?", schema="json", strgy="simple", llm=LLM()) assert "579" in answer1.content - answer2 = await root.fill(context="what's the answer to 123+456?", schema="json", strgy="complex", llm=LLM()) + answer2 = await root.fill(req="what's the answer to 123+456?", schema="json", strgy="complex", llm=LLM()) assert "579" in answer2.content @@ -112,7 +112,7 @@ async def test_action_node_review(): with pytest.raises(RuntimeError): _ = await node_a.review() - _ = await node_a.fill(context=None, llm=LLM()) + _ = await node_a.fill(req=None, llm=LLM()) setattr(node_a.instruct_content, key, "game snake") # wrong content to review review_comments = await node_a.review(review_mode=ReviewMode.AUTO) @@ -126,7 +126,7 @@ async def test_action_node_review(): with pytest.raises(RuntimeError): _ = await node.review() - _ = await node.fill(context=None, llm=LLM()) + _ = await node.fill(req=None, llm=LLM()) review_comments = await node.review(review_mode=ReviewMode.AUTO) assert len(review_comments) == 1 @@ -151,7 +151,7 @@ async def test_action_node_revise(): with pytest.raises(RuntimeError): _ = await node_a.review() - _ = await node_a.fill(context=None, llm=LLM()) + _ = await node_a.fill(req=None, llm=LLM()) setattr(node_a.instruct_content, key, "game snake") # wrong content to revise revise_contents = await node_a.revise(revise_mode=ReviseMode.AUTO) assert len(revise_contents) == 1 @@ -164,7 +164,7 @@ async def test_action_node_revise(): with pytest.raises(RuntimeError): _ = await node.revise() - _ = await node.fill(context=None, llm=LLM()) + _ = await node.fill(req=None, llm=LLM()) setattr(node.instruct_content, key, "game snake") revise_contents = await node.revise(revise_mode=ReviseMode.AUTO) assert len(revise_contents) == 1 @@ -257,7 +257,7 @@ async def test_action_node_with_image(mocker): invoice_path = Path(__file__).parent.joinpath("..", "..", "data", "invoices", "invoice-2.png") img_base64 = encode_image(invoice_path) mocker.patch("metagpt.provider.openai_api.OpenAILLM._cons_kwargs", _cons_kwargs) - node = await invoice.fill(context="", llm=LLM(), images=[img_base64]) + node = await invoice.fill(req="", llm=LLM(), images=[img_base64]) assert node.instruct_content.invoice diff --git a/tests/metagpt/actions/test_design_api_an.py b/tests/metagpt/actions/test_design_api_an.py index 3d11f200d..4ed3cb362 100644 --- a/tests/metagpt/actions/test_design_api_an.py +++ b/tests/metagpt/actions/test_design_api_an.py @@ -38,7 +38,7 @@ async def test_write_design_an(mocker): mocker.patch("metagpt.actions.design_api_an.REFINED_DESIGN_NODE.fill", return_value=root) prompt = NEW_REQ_TEMPLATE.format(old_design=DESIGN_SAMPLE, context=dict_to_markdown(REFINED_PRD_JSON)) - node = await REFINED_DESIGN_NODE.fill(prompt, llm) + node = await REFINED_DESIGN_NODE.fill(req=prompt, llm=llm) assert "Refined Implementation Approach" in node.instruct_content.model_dump() assert "Refined File list" in node.instruct_content.model_dump() diff --git a/tests/metagpt/actions/test_project_management_an.py b/tests/metagpt/actions/test_project_management_an.py index 5a65e50c9..6d41109c9 100644 --- a/tests/metagpt/actions/test_project_management_an.py +++ b/tests/metagpt/actions/test_project_management_an.py @@ -42,7 +42,7 @@ async def test_project_management_an(mocker): root.instruct_content.model_dump = mock_task_json mocker.patch("metagpt.actions.project_management_an.PM_NODE.fill", return_value=root) - node = await PM_NODE.fill(dict_to_markdown(REFINED_DESIGN_JSON), llm) + node = await PM_NODE.fill(req=dict_to_markdown(REFINED_DESIGN_JSON), llm=llm) assert "Logic Analysis" in node.instruct_content.model_dump() assert "Task list" in node.instruct_content.model_dump() @@ -59,7 +59,7 @@ async def test_project_management_an_inc(mocker): mocker.patch("metagpt.actions.project_management_an.REFINED_PM_NODE.fill", return_value=root) prompt = NEW_REQ_TEMPLATE.format(old_task=TASK_SAMPLE, context=dict_to_markdown(REFINED_DESIGN_JSON)) - node = await REFINED_PM_NODE.fill(prompt, llm) + node = await REFINED_PM_NODE.fill(req=prompt, llm=llm) assert "Refined Logic Analysis" in node.instruct_content.model_dump() assert "Refined Task list" in node.instruct_content.model_dump() diff --git a/tests/metagpt/actions/test_write_prd_an.py b/tests/metagpt/actions/test_write_prd_an.py index 378ce42c3..b6e92d3d6 100644 --- a/tests/metagpt/actions/test_write_prd_an.py +++ b/tests/metagpt/actions/test_write_prd_an.py @@ -39,7 +39,7 @@ async def test_write_prd_an(mocker): requirements=NEW_REQUIREMENT_SAMPLE, old_prd=PRD_SAMPLE, ) - node = await REFINED_PRD_NODE.fill(prompt, llm) + node = await REFINED_PRD_NODE.fill(req=prompt, llm=llm) assert "Refined Requirements" in node.instruct_content.model_dump() assert "Refined Product Goals" in node.instruct_content.model_dump() diff --git a/tests/metagpt/exp_pool/test_context_builders/test_base_context_builder.py b/tests/metagpt/exp_pool/test_context_builders/test_base_context_builder.py new file mode 100644 index 000000000..17696e1b4 --- /dev/null +++ b/tests/metagpt/exp_pool/test_context_builders/test_base_context_builder.py @@ -0,0 +1,45 @@ +import pytest + +from metagpt.exp_pool.context_builders.base import ( + EXP_TEMPLATE, + BaseContextBuilder, + Experience, +) +from metagpt.exp_pool.schema import Metric, Score + + +class TestBaseContextBuilder: + class ConcreteContextBuilder(BaseContextBuilder): + async def build(self, *args, **kwargs): + pass + + @pytest.fixture + def context_builder(self): + return self.ConcreteContextBuilder() + + def test_format_exps(self, context_builder): + exp1 = Experience(req="req1", resp="resp1", metric=Metric(score=Score(val=8))) + exp2 = Experience(req="req2", resp="resp2", metric=Metric(score=Score(val=9))) + context_builder.exps = [exp1, exp2] + + result = context_builder.format_exps() + expected = "\n".join( + [ + f"1. {EXP_TEMPLATE.format(req='req1', resp='resp1', score=8)}", + f"2. {EXP_TEMPLATE.format(req='req2', resp='resp2', score=9)}", + ] + ) + assert result == expected + + def test_replace_content_between_markers(self): + text = "Start\n# Example\nOld content\n# Instruction\nEnd" + new_content = "New content" + result = BaseContextBuilder.replace_content_between_markers(text, "# Example", "# Instruction", new_content) + expected = "Start\n# Example\nNew content\n\n# Instruction\nEnd" + assert result == expected + + def test_replace_content_between_markers_no_match(self): + text = "Start\nNo markers\nEnd" + new_content = "New content" + result = BaseContextBuilder.replace_content_between_markers(text, "# Example", "# Instruction", new_content) + assert result == text diff --git a/tests/metagpt/exp_pool/test_context_builders/test_rolezero_context_builder.py b/tests/metagpt/exp_pool/test_context_builders/test_rolezero_context_builder.py new file mode 100644 index 000000000..0ea04432d --- /dev/null +++ b/tests/metagpt/exp_pool/test_context_builders/test_rolezero_context_builder.py @@ -0,0 +1,38 @@ +import pytest + +from metagpt.exp_pool.context_builders.base import BaseContextBuilder +from metagpt.exp_pool.context_builders.role_zero import RoleZeroContextBuilder + + +class TestRoleZeroContextBuilder: + @pytest.fixture + def context_builder(self): + return RoleZeroContextBuilder() + + @pytest.mark.asyncio + async def test_build_empty_req(self, context_builder): + result = await context_builder.build(req=[]) + assert result == [] + + @pytest.mark.asyncio + async def test_build_no_experiences(self, context_builder, mocker): + mocker.patch.object(BaseContextBuilder, "format_exps", return_value="") + req = [{"content": "Original content"}] + result = await context_builder.build(req=req) + assert result == req + + @pytest.mark.asyncio + async def test_build_with_experiences(self, context_builder, mocker): + mocker.patch.object(BaseContextBuilder, "format_exps", return_value="Formatted experiences") + mocker.patch.object(RoleZeroContextBuilder, "replace_example_content", return_value="Updated content") + req = [{"content": "Original content"}] + result = await context_builder.build(req=req) + assert result == [{"content": "Updated content"}] + + def test_replace_example_content(self, context_builder, mocker): + mocker.patch.object(BaseContextBuilder, "replace_content_between_markers", return_value="Replaced content") + result = context_builder.replace_example_content("Original text", "New example content") + assert result == "Replaced content" + context_builder.replace_content_between_markers.assert_called_once_with( + "Original text", "# Example", "# Instruction", "New example content" + ) diff --git a/tests/metagpt/exp_pool/test_context_builders/test_simple_context_builder.py b/tests/metagpt/exp_pool/test_context_builders/test_simple_context_builder.py new file mode 100644 index 000000000..e96addab9 --- /dev/null +++ b/tests/metagpt/exp_pool/test_context_builders/test_simple_context_builder.py @@ -0,0 +1,46 @@ +import pytest + +from metagpt.exp_pool.context_builders.base import BaseContextBuilder +from metagpt.exp_pool.context_builders.simple import ( + SIMPLE_CONTEXT_TEMPLATE, + SimpleContextBuilder, +) + + +class TestSimpleContextBuilder: + @pytest.fixture + def context_builder(self): + return SimpleContextBuilder() + + @pytest.mark.asyncio + async def test_build_with_experiences(self, context_builder, mocker): + # Mock the format_exps method + mock_exps = "Mocked experiences" + mocker.patch.object(BaseContextBuilder, "format_exps", return_value=mock_exps) + + req = "Test request" + result = await context_builder.build(req=req) + + expected = SIMPLE_CONTEXT_TEMPLATE.format(req=req, exps=mock_exps) + assert result == expected + + @pytest.mark.asyncio + async def test_build_without_experiences(self, context_builder, mocker): + # Mock the format_exps method to return an empty string + mocker.patch.object(BaseContextBuilder, "format_exps", return_value="") + + req = "Test request" + result = await context_builder.build(req=req) + + assert result == req + + @pytest.mark.asyncio + async def test_build_without_req(self, context_builder, mocker): + # Mock the format_exps method + mock_exps = "Mocked experiences" + mocker.patch.object(BaseContextBuilder, "format_exps", return_value=mock_exps) + + result = await context_builder.build() + + expected = SIMPLE_CONTEXT_TEMPLATE.format(req="", exps=mock_exps) + assert result == expected diff --git a/tests/metagpt/exp_pool/test_decorator.py b/tests/metagpt/exp_pool/test_decorator.py index bedc4e391..c0b3fe36d 100644 --- a/tests/metagpt/exp_pool/test_decorator.py +++ b/tests/metagpt/exp_pool/test_decorator.py @@ -1,29 +1,17 @@ import asyncio -import inspect import pytest +from metagpt.exp_pool.context_builders import SimpleContextBuilder from metagpt.exp_pool.decorator import ExpCacheHandler, exp_cache from metagpt.exp_pool.manager import ExperienceManager +from metagpt.exp_pool.perfect_judges import SimplePerfectJudge from metagpt.exp_pool.schema import Experience, QueryType, Score from metagpt.exp_pool.scorers import SimpleScorer from metagpt.rag.engines import SimpleEngine -def for_test_function(a, b, c=None): - return a + b if c is None else a + b + c - - -class ForTestClass: - def for_test_method(self, x, y): - return x * y - - @classmethod - def for_test_class_method(cls, x, y): - return x**y - - -class TestExpCache: +class TestExpCacheHandler: @pytest.fixture def mock_func(self, mocker): return mocker.AsyncMock() @@ -34,7 +22,6 @@ class TestExpCache: manager.storage = mocker.MagicMock(spec=SimpleEngine) manager.query_exps = mocker.AsyncMock() manager.create_exp = mocker.MagicMock() - manager.extract_one_perfect_exp = mocker.MagicMock() return manager @pytest.fixture @@ -44,174 +31,165 @@ class TestExpCache: return scorer @pytest.fixture - def exp_cache_handler(self, mock_func, mock_exp_manager, mock_scorer): + def mock_perfect_judge(self, mocker): + return mocker.MagicMock(spec=SimplePerfectJudge) + + @pytest.fixture + def mock_context_builder(self, mocker): + return mocker.MagicMock(spec=SimpleContextBuilder) + + @pytest.fixture + def exp_cache_handler(self, mock_func, mock_exp_manager, mock_scorer, mock_perfect_judge, mock_context_builder): return ExpCacheHandler( - func=mock_func, args=(), kwargs={}, exp_manager=mock_exp_manager, exp_scorer=mock_scorer, pass_exps=False + func=mock_func, + args=(), + kwargs={"req": "test_req"}, + exp_manager=mock_exp_manager, + exp_scorer=mock_scorer, + exp_perfect_judge=mock_perfect_judge, + context_builder=mock_context_builder, ) @pytest.mark.asyncio async def test_fetch_experiences(self, exp_cache_handler, mock_exp_manager): - await exp_cache_handler.fetch_experiences(QueryType.SEMANTIC) - mock_exp_manager.query_exps.assert_called_once() + mock_exp_manager.query_exps.return_value = [Experience(req="test_req", resp="test_resp")] + await exp_cache_handler.fetch_experiences() + mock_exp_manager.query_exps.assert_called_once_with( + "test_req", query_type=QueryType.SEMANTIC, tag=exp_cache_handler.tag + ) + assert len(exp_cache_handler._exps) == 1 @pytest.mark.asyncio - async def test_perfect_experience_found(self, exp_cache_handler, mock_exp_manager, mock_func): - # Setup: Assume perfect experience is found - perfect_exp = Experience(req="req", resp="resp") - mock_exp_manager.extract_one_perfect_exp.return_value = perfect_exp - - # Exec - exp_cache_handler._exps = [perfect_exp] # Simulate fetched experiences - result = exp_cache_handler.get_one_perfect_experience() - - # Assert - assert result.resp == "resp" - mock_func.assert_not_called() # Function should not be called + async def test_get_one_perfect_exp(self, exp_cache_handler, mock_perfect_judge): + exp = Experience(req="test_req", resp="perfect_resp") + exp_cache_handler._exps = [exp] + mock_perfect_judge.is_perfect_exp.return_value = True + result = await exp_cache_handler.get_one_perfect_exp() + assert result == "perfect_resp" @pytest.mark.asyncio - async def test_execute_function_when_no_perfect_exp(self, exp_cache_handler, mock_exp_manager, mock_func): - # Setup: No perfect experience - mock_exp_manager.extract_one_perfect_exp.return_value = None - mock_func.return_value = "Computed result" - - # Exec + async def test_execute_function(self, exp_cache_handler, mock_func, mock_context_builder): + mock_context_builder.build.return_value = "built_context" + mock_func.return_value = "function_result" await exp_cache_handler.execute_function() - - # Assert - assert exp_cache_handler._result == "Computed result" - mock_func.assert_called_once() + mock_context_builder.build.assert_called_once() + mock_func.assert_called_once_with(req="built_context") + assert exp_cache_handler._raw_resp == "function_result" + assert exp_cache_handler._resp == "function_result" @pytest.mark.asyncio - async def test_evaluate_and_save_experience(self, exp_cache_handler, mock_scorer, mock_exp_manager): - # Setup - mock_scorer.evaluate.return_value = Score(value=100) - exp_cache_handler._result = "Computed result" - - # Exec - await exp_cache_handler.evaluate_experience() - exp_cache_handler.save_experience() - - # Assert + async def test_process_experience(self, exp_cache_handler, mock_scorer, mock_exp_manager): + exp_cache_handler._resp = "test_resp" + mock_scorer.evaluate.return_value = Score(val=8) + await exp_cache_handler.process_experience() mock_scorer.evaluate.assert_called_once() mock_exp_manager.create_exp.assert_called_once() @pytest.mark.asyncio - async def test_async_function_execution_with_exps(self, exp_cache_handler, mock_exp_manager, mock_func): - # Setup - exp_cache_handler.pass_exps_to_func = True - mock_func.return_value = "Async result with exps" - mock_exp_manager.extract_one_perfect_exp.return_value = None - exp_cache_handler._exps = [Experience(req="req", resp="resp")] + async def test_evaluate_experience(self, exp_cache_handler, mock_scorer): + exp_cache_handler._resp = "test_resp" + mock_scorer.evaluate.return_value = Score(val=9) + await exp_cache_handler.evaluate_experience() + assert exp_cache_handler._score.val == 9 - # Exec - await exp_cache_handler.execute_function() - - # Assert - mock_func.assert_called_once_with(exps=exp_cache_handler._exps) - assert exp_cache_handler._result == "Async result with exps" - - def test_sync_function_execution_with_exps(self, mocker, exp_cache_handler, mock_exp_manager, mock_func): - # Setup - exp_cache_handler.func = mocker.Mock(return_value="Sync result with exps") - exp_cache_handler.pass_exps_to_func = True - mock_exp_manager.extract_one_perfect_exp.return_value = None - exp_cache_handler._exps = [Experience(req="req", resp="resp")] - - # Exec - asyncio.get_event_loop().run_until_complete(exp_cache_handler.execute_function()) - - # Assert - exp_cache_handler.func.assert_called_once_with(exps=exp_cache_handler._exps) - assert exp_cache_handler._result == "Sync result with exps" - - def test_wrapper_selection_async(self, mocker, exp_cache_handler, mock_func): - # Setup - mock_func = mocker.AsyncMock() - - # Exec - wrapper = ExpCacheHandler.choose_wrapper(mock_func, exp_cache_handler.execute_function) - - # Assert - assert asyncio.iscoroutinefunction(wrapper), "Wrapper should be asynchronous" - - def test_wrapper_selection_sync(self, exp_cache_handler, mocker): - # Setup - sync_func = mocker.Mock() - - # Exec - wrapper = ExpCacheHandler.choose_wrapper(sync_func, exp_cache_handler.execute_function) - - # Assert - assert not asyncio.iscoroutinefunction(wrapper), "Wrapper should be synchronous" - - @pytest.mark.parametrize( - "func, args, kwargs, expected", - [ - (for_test_function, (1, 2), {"c": 3}, 'for_test_function@[1~2]@{"c"!3}'), - (ForTestClass().for_test_method, (4, 5), {}, "ForTestClass.for_test_method@[4~5]@{}"), - (ForTestClass.for_test_class_method, (6, 7), {}, "ForTestClass.for_test_class_method@[6~7]@{}"), - (for_test_function, (), {}, "for_test_function@[]@{}"), - ( - for_test_function, - ("hello", [1, 2]), - {"key": "value"}, - 'for_test_function@["hello"~[1~2]]@{"key"!"value"}', - ), - ], - ) - def test_generate_req_identifier(self, func, args, kwargs, expected): - req_identifier = ExpCacheHandler.generate_req_identifier(func, *args, **kwargs) - assert req_identifier == expected - - @pytest.mark.asyncio - async def test_exp_cache_with_perfect_experience(self, mocker, mock_exp_manager): - # Mock perfect experience - perfect_exp = Experience(req="test_req", resp="perfect_response") - mock_exp_manager.query_exps = mocker.AsyncMock(return_value=[perfect_exp]) - mock_exp_manager.extract_one_perfect_exp = mocker.MagicMock(return_value=perfect_exp) - async_mock_func = mocker.AsyncMock() - - # Setup - decorated_func = exp_cache(async_mock_func, manager=mock_exp_manager) - - # Exec - result: Experience = await decorated_func() - - # Assert - assert result.resp == "perfect_response", "Should return the perfect experience response" - async_mock_func.assert_not_called() - - @pytest.mark.asyncio - async def test_exp_cache_without_perfect_experience(self, mocker, mock_exp_manager): - # Mock - mock_exp_manager.query_exps = mocker.AsyncMock(return_value=[]) - mock_exp_manager.extract_one_perfect_exp = mocker.MagicMock(return_value=None) - async_mock_func = mocker.AsyncMock(return_value="computed_response") - async_mock_func.__signature__ = inspect.signature(for_test_function) - - # Setup - decorated_func = exp_cache(async_mock_func, manager=mock_exp_manager) - - # Exec - result = await decorated_func() - - # Assert - assert result == "computed_response", "Should execute and return the function's response" - async_mock_func.assert_called_once() - - @pytest.mark.asyncio - async def test_exp_cache_saves_new_experience(self, mocker, mock_exp_manager, mock_scorer): - # Mock - mock_exp_manager.query_exps = mocker.AsyncMock(return_value=[]) - mock_exp_manager.extract_one_perfect_exp = mocker.MagicMock(return_value=None) - async_mock_func = mocker.AsyncMock(return_value="computed_response") - mock_scorer.evaluate = mocker.AsyncMock(return_value=Score(value=100)) - - # Setup - decorated_func = exp_cache(async_mock_func, manager=mock_exp_manager, scorer=mock_scorer) - - # Exec - await decorated_func() - - # Assert + def test_save_experience(self, exp_cache_handler, mock_exp_manager): + exp_cache_handler._req = "test_req" + exp_cache_handler._resp = "test_resp" + exp_cache_handler._score = Score(val=7) + exp_cache_handler.save_experience() mock_exp_manager.create_exp.assert_called_once() + + def test_choose_wrapper_async(self, mocker): + async def async_func(): + pass + + wrapper = ExpCacheHandler.choose_wrapper(async_func, mocker.AsyncMock()) + assert asyncio.iscoroutinefunction(wrapper) + + def test_choose_wrapper_sync(self, mocker): + def sync_func(): + pass + + wrapper = ExpCacheHandler.choose_wrapper(sync_func, mocker.AsyncMock()) + assert not asyncio.iscoroutinefunction(wrapper) + + def test_validate_params(self): + with pytest.raises(ValueError): + ExpCacheHandler(func=lambda x: x, args=(), kwargs={}) + + def test_generate_tag(self): + class TestClass: + def test_method(self): + pass + + handler = ExpCacheHandler(func=TestClass().test_method, args=(TestClass(),), kwargs={"req": "test"}) + assert handler._generate_tag() == "TestClass.test_method" + + handler = ExpCacheHandler(func=lambda x: x, args=(), kwargs={"req": "test"}) + assert handler._generate_tag() == "" + + +class TestExpCache: + @pytest.fixture + def mock_exp_manager(self, mocker): + manager = mocker.MagicMock(spec=ExperienceManager) + manager.storage = mocker.MagicMock(spec=SimpleEngine) + manager.query_exps = mocker.AsyncMock() + manager.create_exp = mocker.MagicMock() + return manager + + @pytest.fixture + def mock_scorer(self, mocker): + scorer = mocker.MagicMock(spec=SimpleScorer) + scorer.evaluate = mocker.AsyncMock(return_value=Score()) + return scorer + + @pytest.fixture + def mock_perfect_judge(self, mocker): + return mocker.MagicMock(spec=SimplePerfectJudge) + + @pytest.fixture + def mock_config(self, mocker): + return mocker.patch("metagpt.exp_pool.decorator.config") + + @pytest.mark.asyncio + async def test_exp_cache_disabled(self, mock_config, mock_exp_manager): + mock_config.exp_pool.enable_read = False + + @exp_cache(manager=mock_exp_manager) + async def test_func(req): + return "result" + + result = await test_func(req="test") + assert result == "result" + mock_exp_manager.query_exps.assert_not_called() + + @pytest.mark.asyncio + async def test_exp_cache_enabled_no_perfect_exp(self, mock_config, mock_exp_manager, mock_scorer): + mock_config.exp_pool.enable_read = True + mock_exp_manager.query_exps.return_value = [] + + @exp_cache(manager=mock_exp_manager, scorer=mock_scorer) + async def test_func(req): + return "computed_result" + + result = await test_func(req="test") + assert result == "computed_result" + mock_exp_manager.query_exps.assert_called() + mock_exp_manager.create_exp.assert_called() + + @pytest.mark.asyncio + async def test_exp_cache_enabled_with_perfect_exp(self, mock_config, mock_exp_manager, mock_perfect_judge): + mock_config.exp_pool.enable_read = True + perfect_exp = Experience(req="test", resp="perfect_result") + mock_exp_manager.query_exps.return_value = [perfect_exp] + mock_perfect_judge.is_perfect_exp.return_value = True + + @exp_cache(manager=mock_exp_manager, perfect_judge=mock_perfect_judge) + async def test_func(req): + return "should_not_be_called" + + result = await test_func(req="test") + assert result == "perfect_result" + mock_exp_manager.query_exps.assert_called_once() + mock_exp_manager.create_exp.assert_not_called() diff --git a/tests/metagpt/exp_pool/test_manager.py b/tests/metagpt/exp_pool/test_manager.py index 3e8f47417..c12fc7e8c 100644 --- a/tests/metagpt/exp_pool/test_manager.py +++ b/tests/metagpt/exp_pool/test_manager.py @@ -4,20 +4,25 @@ from metagpt.config2 import Config from metagpt.configs.exp_pool_config import ExperiencePoolConfig from metagpt.configs.llm_config import LLMConfig from metagpt.exp_pool.manager import ExperienceManager -from metagpt.exp_pool.schema import MAX_SCORE, Experience, Metric, Score +from metagpt.exp_pool.schema import Experience from metagpt.rag.engines import SimpleEngine class TestExperienceManager: @pytest.fixture def mock_config(self): - return Config(llm=LLMConfig(), exp_pool=ExperiencePoolConfig(enable_write=True, enable_read=True)) + return Config( + llm=LLMConfig(), exp_pool=ExperiencePoolConfig(enable_write=True, enable_read=True, init_exp=False) + ) @pytest.fixture def mock_storage(self, mocker): engine = mocker.MagicMock(spec=SimpleEngine) engine.add_objs = mocker.MagicMock() engine.aretrieve = mocker.AsyncMock(return_value=[]) + engine._retriever = mocker.MagicMock() + engine._retriever._vector_store = mocker.MagicMock() + engine._retriever._vector_store._get = mocker.MagicMock(return_value=mocker.MagicMock(ids=[])) return engine @pytest.fixture @@ -33,7 +38,7 @@ class TestExperienceManager: def test_create_exp(self, mock_experience_manager, mock_experience): mock_experience_manager.create_exp(mock_experience) - mock_experience_manager.storage.add_objs.assert_called_once_with([mock_experience]) + mock_experience_manager.storage.add_objs.assert_called_with([mock_experience]) def test_create_exp_write_disabled(self, mock_experience_manager, mock_experience, mock_config): mock_config.exp_pool.enable_write = False @@ -60,18 +65,44 @@ class TestExperienceManager: result = await mock_experience_manager.query_exps("query") assert result == [] - def test_extract_one_perfect_exp(self, mock_experience_manager): - experiences = [ - Experience(req="req", resp="resp", metric=Metric(score=Score(val=MAX_SCORE))), - Experience(req="req", resp="resp"), - ] - perfect_exp: Experience = mock_experience_manager.extract_one_perfect_exp(experiences) - assert perfect_exp is not None - assert perfect_exp.metric.score.val == MAX_SCORE + def test_init_exp_pool(self, mock_experience_manager, mock_config, mocker): + mock_experience_manager._has_exps = mocker.MagicMock(return_value=False) + mock_experience_manager._init_teamleader_exps = mocker.MagicMock() + mock_experience_manager._init_engineer2_exps = mocker.MagicMock() - def test_is_perfect_exp(self): - exp = Experience(req="req", resp="resp", metric=Metric(score=Score(val=MAX_SCORE))) - assert ExperienceManager.is_perfect_exp(exp) == True + mock_config.exp_pool.init_exp = True + mock_experience_manager.init_exp_pool() - exp = Experience(req="req", resp="resp") - assert ExperienceManager.is_perfect_exp(exp) == False + mock_experience_manager._has_exps.assert_called_once() + mock_experience_manager._init_teamleader_exps.assert_called_once() + mock_experience_manager._init_engineer2_exps.assert_called_once() + + def test_init_exp_pool_already_has_exps(self, mock_experience_manager, mock_config, mocker): + mock_experience_manager._has_exps = mocker.MagicMock(return_value=True) + mock_experience_manager._init_teamleader_exps = mocker.MagicMock() + mock_experience_manager._init_engineer2_exps = mocker.MagicMock() + + mock_config.exp_pool.init_exp = True + mock_experience_manager.init_exp_pool() + + mock_experience_manager._has_exps.assert_called_once() + mock_experience_manager._init_teamleader_exps.assert_not_called() + mock_experience_manager._init_engineer2_exps.assert_not_called() + + def test_has_exps(self, mock_experience_manager, mock_storage): + mock_storage._retriever._vector_store._get.return_value.ids = ["id1"] + + assert mock_experience_manager._has_exps() is True + + mock_storage._retriever._vector_store._get.return_value.ids = [] + assert mock_experience_manager._has_exps() is False + + def test_init_teamleader_exps(self, mock_experience_manager, mocker): + mock_experience_manager._init_exp = mocker.MagicMock() + mock_experience_manager._init_teamleader_exps() + mock_experience_manager._init_exp.assert_called_once() + + def test_init_engineer2_exps(self, mock_experience_manager, mocker): + mock_experience_manager._init_exp = mocker.MagicMock() + mock_experience_manager._init_engineer2_exps() + mock_experience_manager._init_exp.assert_called_once() diff --git a/tests/metagpt/exp_pool/test_perfect_judges/test_simple_perfect_judge.py b/tests/metagpt/exp_pool/test_perfect_judges/test_simple_perfect_judge.py new file mode 100644 index 000000000..5abd04f0d --- /dev/null +++ b/tests/metagpt/exp_pool/test_perfect_judges/test_simple_perfect_judge.py @@ -0,0 +1,40 @@ +import pytest + +from metagpt.exp_pool.perfect_judges import SimplePerfectJudge +from metagpt.exp_pool.schema import MAX_SCORE, Experience, Metric, Score + + +class TestSimplePerfectJudge: + @pytest.fixture + def simple_perfect_judge(self): + return SimplePerfectJudge() + + @pytest.mark.asyncio + async def test_is_perfect_exp_perfect_match(self, simple_perfect_judge): + exp = Experience(req="test_request", resp="resp", metric=Metric(score=Score(val=MAX_SCORE))) + result = await simple_perfect_judge.is_perfect_exp(exp, "test_request") + assert result is True + + @pytest.mark.asyncio + async def test_is_perfect_exp_imperfect_score(self, simple_perfect_judge): + exp = Experience(req="test_request", resp="resp", metric=Metric(score=Score(val=MAX_SCORE - 1))) + result = await simple_perfect_judge.is_perfect_exp(exp, "test_request") + assert result is False + + @pytest.mark.asyncio + async def test_is_perfect_exp_mismatched_request(self, simple_perfect_judge): + exp = Experience(req="test_request", resp="resp", metric=Metric(score=Score(val=MAX_SCORE))) + result = await simple_perfect_judge.is_perfect_exp(exp, "different_request") + assert result is False + + @pytest.mark.asyncio + async def test_is_perfect_exp_no_metric(self, simple_perfect_judge): + exp = Experience(req="test_request", resp="resp") + result = await simple_perfect_judge.is_perfect_exp(exp, "test_request") + assert result is False + + @pytest.mark.asyncio + async def test_is_perfect_exp_no_score(self, simple_perfect_judge): + exp = Experience(req="test_request", resp="resp", metric=Metric()) + result = await simple_perfect_judge.is_perfect_exp(exp, "test_request") + assert result is False diff --git a/tests/metagpt/exp_pool/test_scorers/test_simple_scorer.py b/tests/metagpt/exp_pool/test_scorers/test_simple_scorer.py new file mode 100644 index 000000000..043f105d0 --- /dev/null +++ b/tests/metagpt/exp_pool/test_scorers/test_simple_scorer.py @@ -0,0 +1,49 @@ +import pytest + +from metagpt.exp_pool.schema import Score +from metagpt.exp_pool.scorers.simple import SIMPLE_SCORER_TEMPLATE, SimpleScorer +from metagpt.llm import BaseLLM + + +class TestSimpleScorer: + @pytest.fixture + def mock_llm(self, mocker): + mock_llm = mocker.MagicMock(spec=BaseLLM) + return mock_llm + + @pytest.fixture + def simple_scorer(self, mock_llm): + return SimpleScorer(llm=mock_llm) + + def test_init(self, mock_llm): + scorer = SimpleScorer(llm=mock_llm) + assert isinstance(scorer.llm, BaseLLM) + + @pytest.mark.asyncio + async def test_evaluate(self, simple_scorer, mock_llm): + # Mock function to evaluate + def mock_func(a, b): + """This is a mock function.""" + return a + b + + # Mock LLM response + mock_llm.aask.return_value = '```json\n{"val": 8, "reason": "Good performance"}\n```' + + # Test evaluate method + result = await simple_scorer.evaluate(mock_func, 5, args=(2, 3), kwargs={}) + + # Assert LLM was called with correct prompt + expected_prompt = SIMPLE_SCORER_TEMPLATE.format( + func_name=mock_func.__name__, + func_doc=mock_func.__doc__, + func_signature="(a, b)", + func_args=(2, 3), + func_kwargs={}, + func_result=5, + ) + mock_llm.aask.assert_called_once_with(expected_prompt) + + # Assert the result is correct + assert isinstance(result, Score) + assert result.val == 8 + assert result.reason == "Good performance"