mirror of
https://github.com/FoundationAgents/MetaGPT.git
synced 2026-06-29 15:59:42 +02:00
use llm cache to make exp_pool
This commit is contained in:
parent
d902a6f18c
commit
c624c0ffc7
41 changed files with 844 additions and 368 deletions
|
|
@ -90,7 +90,7 @@ class Action(SerializationMixin, ContextMixin, BaseModel):
|
|||
msgs = args[0]
|
||||
context = "## History Messages\n"
|
||||
context += "\n".join([f"{idx}: {i}" for idx, i in enumerate(reversed(msgs))])
|
||||
return await self.node.fill(context=context, llm=self.llm)
|
||||
return await self.node.fill(req=context, llm=self.llm)
|
||||
|
||||
async def run(self, *args, **kwargs):
|
||||
"""Run action"""
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
|
|||
|
||||
from metagpt.actions.action_outcls_registry import register_action_outcls
|
||||
from metagpt.const import MARKDOWN_TITLE_PREFIX, USE_CONFIG_TIMEOUT
|
||||
from metagpt.exp_pool import exp_cache
|
||||
from metagpt.llm import BaseLLM
|
||||
from metagpt.logs import logger
|
||||
from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
|
||||
|
|
@ -465,9 +466,33 @@ class ActionNode:
|
|||
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
def deserialize_to_action_node(cls, serialized_data) -> "ActionNode":
|
||||
"""Customized deserialization, it will be triggered when a perfect experience is found.
|
||||
|
||||
ActionNode cannot be serialized, it throws an error 'cannot pickle 'SSLContext' object'.
|
||||
"""
|
||||
|
||||
class InstructContent:
|
||||
def __init__(self, json_data):
|
||||
self.json_data = json_data
|
||||
|
||||
def model_dump_json(self):
|
||||
return self.json_data
|
||||
|
||||
action_node = cls(key="", expected_type=Type[str], instruction="", example="")
|
||||
action_node.instruct_content = InstructContent(serialized_data)
|
||||
|
||||
return action_node
|
||||
|
||||
@exp_cache(
|
||||
resp_serialize=lambda action_node: action_node.instruct_content.model_dump_json(),
|
||||
resp_deserialize=lambda resp: ActionNode.deserialize_to_action_node(resp),
|
||||
)
|
||||
async def fill(
|
||||
self,
|
||||
context,
|
||||
*,
|
||||
req,
|
||||
llm,
|
||||
schema="json",
|
||||
mode="auto",
|
||||
|
|
@ -478,7 +503,7 @@ class ActionNode:
|
|||
):
|
||||
"""Fill the node(s) with mode.
|
||||
|
||||
:param context: Everything we should know when filling node.
|
||||
:param req: Everything we should know when filling node.
|
||||
:param llm: Large Language Model with pre-defined system message.
|
||||
:param schema: json/markdown, determine example and output format.
|
||||
- raw: free form text
|
||||
|
|
@ -497,7 +522,7 @@ class ActionNode:
|
|||
:return: self
|
||||
"""
|
||||
self.set_llm(llm)
|
||||
self.set_context(context)
|
||||
self.set_context(req)
|
||||
if self.schema:
|
||||
schema = self.schema
|
||||
|
||||
|
|
|
|||
|
|
@ -178,12 +178,12 @@ class WriteDesign(Action):
|
|||
)
|
||||
|
||||
async def _new_system_design(self, context):
|
||||
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=self.prompt_schema)
|
||||
node = await DESIGN_API_NODE.fill(req=context, llm=self.llm, schema=self.prompt_schema)
|
||||
return node
|
||||
|
||||
async def _merge(self, prd_doc, system_design_doc):
|
||||
context = NEW_REQ_TEMPLATE.format(old_design=system_design_doc.content, context=prd_doc.content)
|
||||
node = await REFINED_DESIGN_NODE.fill(context=context, llm=self.llm, schema=self.prompt_schema)
|
||||
node = await REFINED_DESIGN_NODE.fill(req=context, llm=self.llm, schema=self.prompt_schema)
|
||||
system_design_doc.content = node.instruct_content.model_dump_json()
|
||||
return system_design_doc
|
||||
|
||||
|
|
|
|||
|
|
@ -22,4 +22,4 @@ class GenerateQuestions(Action):
|
|||
name: str = "GenerateQuestions"
|
||||
|
||||
async def run(self, context) -> ActionNode:
|
||||
return await QUESTIONS.fill(context=context, llm=self.llm)
|
||||
return await QUESTIONS.fill(req=context, llm=self.llm)
|
||||
|
|
|
|||
|
|
@ -22,4 +22,4 @@ class PrepareInterview(Action):
|
|||
name: str = "PrepareInterview"
|
||||
|
||||
async def run(self, context):
|
||||
return await QUESTIONS.fill(context=context, llm=self.llm)
|
||||
return await QUESTIONS.fill(req=context, llm=self.llm)
|
||||
|
|
|
|||
|
|
@ -151,12 +151,12 @@ class WriteTasks(Action):
|
|||
return task_doc
|
||||
|
||||
async def _run_new_tasks(self, context: str):
|
||||
node = await PM_NODE.fill(context, self.llm, schema=self.prompt_schema)
|
||||
node = await PM_NODE.fill(req=context, llm=self.llm, schema=self.prompt_schema)
|
||||
return node
|
||||
|
||||
async def _merge(self, system_design_doc, task_doc) -> Document:
|
||||
context = NEW_REQ_TEMPLATE.format(context=system_design_doc.content, old_task=task_doc.content)
|
||||
node = await REFINED_PM_NODE.fill(context, self.llm, schema=self.prompt_schema)
|
||||
node = await REFINED_PM_NODE.fill(req=context, llm=self.llm, schema=self.prompt_schema)
|
||||
task_doc.content = node.instruct_content.model_dump_json()
|
||||
return task_doc
|
||||
|
||||
|
|
|
|||
|
|
@ -578,7 +578,7 @@ class WriteCodeAN(Action):
|
|||
|
||||
async def run(self, context):
|
||||
self.llm.system_prompt = "You are an outstanding engineer and can implement any code"
|
||||
return await WRITE_MOVE_NODE.fill(context=context, llm=self.llm, schema="json")
|
||||
return await WRITE_MOVE_NODE.fill(req=context, llm=self.llm, schema="json")
|
||||
|
||||
|
||||
async def main():
|
||||
|
|
|
|||
|
|
@ -229,7 +229,7 @@ class WriteCodePlanAndChange(Action):
|
|||
code=await self.get_old_codes(),
|
||||
)
|
||||
logger.info("Writing code plan and change..")
|
||||
return await WRITE_CODE_PLAN_AND_CHANGE_NODE.fill(context=context, llm=self.llm, schema="json")
|
||||
return await WRITE_CODE_PLAN_AND_CHANGE_NODE.fill(req=context, llm=self.llm, schema="json")
|
||||
|
||||
async def get_old_codes(self) -> str:
|
||||
old_codes = await self.repo.srcs.get_all()
|
||||
|
|
|
|||
|
|
@ -211,7 +211,7 @@ class WritePRD(Action):
|
|||
context = CONTEXT_TEMPLATE.format(requirements=requirement, project_name=project_name)
|
||||
exclude = [PROJECT_NAME.key] if project_name else []
|
||||
node = await WRITE_PRD_NODE.fill(
|
||||
context=context, llm=self.llm, exclude=exclude, schema=self.prompt_schema
|
||||
req=context, llm=self.llm, exclude=exclude, schema=self.prompt_schema
|
||||
) # schema=schema
|
||||
return node
|
||||
|
||||
|
|
@ -238,7 +238,7 @@ class WritePRD(Action):
|
|||
async def _is_bugfix(self, context: str) -> bool:
|
||||
if not self.repo.code_files_exists():
|
||||
return False
|
||||
node = await WP_ISSUE_TYPE_NODE.fill(context, self.llm)
|
||||
node = await WP_ISSUE_TYPE_NODE.fill(req=context, llm=self.llm)
|
||||
return node.get("issue_type") == "BUG"
|
||||
|
||||
async def get_related_docs(self, req: Document, docs: list[Document]) -> list[Document]:
|
||||
|
|
@ -248,14 +248,14 @@ class WritePRD(Action):
|
|||
|
||||
async def _is_related(self, req: Document, old_prd: Document) -> bool:
|
||||
context = NEW_REQ_TEMPLATE.format(old_prd=old_prd.content, requirements=req.content)
|
||||
node = await WP_IS_RELATIVE_NODE.fill(context, self.llm)
|
||||
node = await WP_IS_RELATIVE_NODE.fill(req=context, llm=self.llm)
|
||||
return node.get("is_relative") == "YES"
|
||||
|
||||
async def _merge(self, req: Document, related_doc: Document) -> Document:
|
||||
if not self.project_name:
|
||||
self.project_name = Path(self.project_path).name
|
||||
prompt = NEW_REQ_TEMPLATE.format(requirements=req.content, old_prd=related_doc.content)
|
||||
node = await REFINED_PRD_NODE.fill(context=prompt, llm=self.llm, schema=self.prompt_schema)
|
||||
node = await REFINED_PRD_NODE.fill(req=prompt, llm=self.llm, schema=self.prompt_schema)
|
||||
related_doc.content = node.instruct_content.model_dump_json()
|
||||
await self._rename_workspace(node)
|
||||
return related_doc
|
||||
|
|
|
|||
|
|
@ -36,4 +36,4 @@ class WriteReview(Action):
|
|||
name: str = "WriteReview"
|
||||
|
||||
async def run(self, context):
|
||||
return await WRITE_REVIEW_NODE.fill(context=context, llm=self.llm, schema="json")
|
||||
return await WRITE_REVIEW_NODE.fill(req=context, llm=self.llm, schema="json")
|
||||
|
|
|
|||
|
|
@ -4,5 +4,9 @@ from metagpt.utils.yaml_model import YamlModel
|
|||
|
||||
|
||||
class ExperiencePoolConfig(YamlModel):
|
||||
enable_read: bool = Field(default=True, description="Enable to read from experience pool.")
|
||||
enable_write: bool = Field(default=True, description="Enable to write to experience pool.")
|
||||
enable_read: bool = Field(default=False, description="Enable to read from experience pool.")
|
||||
enable_write: bool = Field(default=False, description="Enable to write to experience pool.")
|
||||
persist_path: str = Field(default=".chroma_exp_data", description="The persist path for experience pool.")
|
||||
init_exp: bool = Field(
|
||||
default=False, description="Put some basic experiences associated with the roles into the experience pool."
|
||||
)
|
||||
|
|
|
|||
7
metagpt/exp_pool/context_builders/__init__.py
Normal file
7
metagpt/exp_pool/context_builders/__init__.py
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
"""Context builders init."""
|
||||
|
||||
from metagpt.exp_pool.context_builders.base import BaseContextBuilder
|
||||
from metagpt.exp_pool.context_builders.simple import SimpleContextBuilder
|
||||
from metagpt.exp_pool.context_builders.role_zero import RoleZeroContextBuilder
|
||||
|
||||
__all__ = ["BaseContextBuilder", "SimpleContextBuilder", "RoleZeroContextBuilder"]
|
||||
52
metagpt/exp_pool/context_builders/base.py
Normal file
52
metagpt/exp_pool/context_builders/base.py
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
"""Base context builder."""
|
||||
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from metagpt.exp_pool.schema import Experience
|
||||
|
||||
EXP_TEMPLATE = """Given the request: {req}, We can get the response: {resp}, Which scored: {score}."""
|
||||
|
||||
|
||||
class BaseContextBuilder(BaseModel, ABC):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
exps: list[Experience] = []
|
||||
|
||||
@abstractmethod
|
||||
async def build(self, *args, **kwargs) -> Any:
|
||||
"""Build context from parameters."""
|
||||
|
||||
def format_exps(self) -> str:
|
||||
"""Format experiences into a numbered list of strings."""
|
||||
|
||||
result = []
|
||||
for i, exp in enumerate(self.exps, start=1):
|
||||
result.append(f"{i}. " + EXP_TEMPLATE.format(req=exp.req, resp=exp.resp, score=exp.metric.score.val))
|
||||
|
||||
return "\n".join(result)
|
||||
|
||||
@staticmethod
|
||||
def replace_content_between_markers(text: str, start_marker: str, end_marker: str, new_content: str) -> str:
|
||||
"""Replace the content between `start_marker` and `end_marker` in the text with `new_content`.
|
||||
|
||||
Args:
|
||||
text (str): The original text.
|
||||
new_content (str): The new content to replace the old content.
|
||||
start_marker (str): The marker indicating the start of the content to be replaced, such as '# Example'.
|
||||
end_marker (str): The marker indicating the end of the content to be replaced, such as '# Instruction'.
|
||||
|
||||
Returns:
|
||||
str: The text with the content replaced.
|
||||
"""
|
||||
|
||||
pattern = re.compile(f"({start_marker}\n)(.*?)(\n{end_marker})", re.DOTALL)
|
||||
|
||||
def replacement(match):
|
||||
return f"{match.group(1)}{new_content}\n{match.group(3)}"
|
||||
|
||||
replaced_text = pattern.sub(replacement, text)
|
||||
return replaced_text
|
||||
26
metagpt/exp_pool/context_builders/role_zero.py
Normal file
26
metagpt/exp_pool/context_builders/role_zero.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
"""RoleZero context builder."""
|
||||
|
||||
from metagpt.exp_pool.context_builders.base import BaseContextBuilder
|
||||
|
||||
|
||||
class RoleZeroContextBuilder(BaseContextBuilder):
|
||||
async def build(self, *args, **kwargs) -> list[dict]:
|
||||
"""Builds the context by updating the req with formatted experiences.
|
||||
|
||||
If there are no experiences, retains the original examples in req, otherwise replaces the examples with the formatted experiences.
|
||||
"""
|
||||
|
||||
req = kwargs.get("req", [])
|
||||
if not req:
|
||||
return req
|
||||
|
||||
exps_str = self.format_exps()
|
||||
if not exps_str:
|
||||
return req
|
||||
|
||||
req[-1]["content"] = self.replace_example_content(req[-1].get("content", ""), exps_str)
|
||||
|
||||
return req
|
||||
|
||||
def replace_example_content(self, text: str, new_example_content: str) -> str:
|
||||
return self.replace_content_between_markers(text, "# Example", "# Instruction", new_example_content)
|
||||
24
metagpt/exp_pool/context_builders/simple.py
Normal file
24
metagpt/exp_pool/context_builders/simple.py
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
"""Simple context builder."""
|
||||
|
||||
|
||||
from metagpt.exp_pool.context_builders.base import BaseContextBuilder
|
||||
|
||||
SIMPLE_CONTEXT_TEMPLATE = """
|
||||
{req}
|
||||
|
||||
### Experiences
|
||||
-----
|
||||
{exps}
|
||||
-----
|
||||
|
||||
## Instruction
|
||||
Consider **Experiences** to generate a better answer.
|
||||
"""
|
||||
|
||||
|
||||
class SimpleContextBuilder(BaseContextBuilder):
|
||||
async def build(self, *args, **kwargs) -> str:
|
||||
req = kwargs.get("req", "")
|
||||
exps = self.format_exps()
|
||||
|
||||
return SIMPLE_CONTEXT_TEMPLATE.format(req=req, exps=exps) if exps else req
|
||||
|
|
@ -2,18 +2,19 @@
|
|||
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
from typing import Any, Callable, Optional, TypeVar
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
from metagpt.config2 import config
|
||||
from metagpt.exp_pool.context_builders import BaseContextBuilder, SimpleContextBuilder
|
||||
from metagpt.exp_pool.manager import ExperienceManager, exp_manager
|
||||
from metagpt.exp_pool.perfect_judges import BasePerfectJudge, SimplePerfectJudge
|
||||
from metagpt.exp_pool.schema import Experience, Metric, QueryType, Score
|
||||
from metagpt.exp_pool.scorers import ExperienceScorer, SimpleScorer
|
||||
from metagpt.exp_pool.scorers import BaseScorer, SimpleScorer
|
||||
from metagpt.logs import logger
|
||||
from metagpt.utils.async_helper import NestAsyncio
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
from metagpt.utils.reflection import get_class_name
|
||||
|
||||
ReturnType = TypeVar("ReturnType")
|
||||
|
||||
|
|
@ -21,42 +22,64 @@ ReturnType = TypeVar("ReturnType")
|
|||
def exp_cache(
|
||||
_func: Optional[Callable[..., ReturnType]] = None,
|
||||
query_type: QueryType = QueryType.SEMANTIC,
|
||||
scorer: Optional[ExperienceScorer] = None,
|
||||
manager: Optional[ExperienceManager] = None,
|
||||
pass_exps_to_func: bool = False,
|
||||
scorer: Optional[BaseScorer] = None,
|
||||
perfect_judge: Optional[BasePerfectJudge] = None,
|
||||
context_builder: Optional[BaseContextBuilder] = None,
|
||||
req_serialize: Optional[Callable[..., str]] = None,
|
||||
resp_serialize: Optional[Callable[..., str]] = None,
|
||||
resp_deserialize: Optional[Callable[[str], Any]] = None,
|
||||
tag: Optional[str] = None,
|
||||
):
|
||||
"""Decorator to get a perfect experience, otherwise, it executes the function, and create a new experience.
|
||||
|
||||
This can be applied to both synchronous and asynchronous functions.
|
||||
1. This can be applied to both synchronous and asynchronous functions.
|
||||
2. The function must have a `req` parameter, and it must be provided as a keyword argument.
|
||||
3. If `config.exp_pool.enable_read` is False, the decorator will just directly execute the function.
|
||||
|
||||
Args:
|
||||
_func: Just to make the decorator more flexible, for example, it can be used directly with @exp_cache by default, without the need for @exp_cache().
|
||||
query_type: The type of query to be used when fetching experiences.
|
||||
scorer: Evaluate experience. Default SimpleScorer.
|
||||
manager: How to fetch, evaluate and save experience, etc. Default exp_manager.
|
||||
pass_exps_to_func: To control whether imperfect experiences are passed to the function, if True, the func must have a parameter named 'exps'.
|
||||
manager: How to fetch, evaluate and save experience, etc. Default to `exp_manager`.
|
||||
scorer: Evaluate experience. Default to `SimpleScorer()`.
|
||||
perfect_judge: Determines if an experience is perfect. Defaults to `SimplePerfectJudge()`.
|
||||
context_builder: Build the context from exps and the function parameters. Default to `SimpleContextBuilder()`.
|
||||
req_serialize: Serializes the request for storage. Defaults to `lambda req: str(req)`.
|
||||
resp_serialize: Serializes the function's return value for storage. Defaults to `lambda resp: str(resp)`.
|
||||
resp_deserialize: Deserializes the stored response back to the function's return value. Defaults to `lambda resp: resp`.
|
||||
tag: An optional tag for the experience. Default to `ClassName.method_name` or `function_name`.
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[..., ReturnType]) -> Callable[..., ReturnType]:
|
||||
if not config.exp_pool.enable_read:
|
||||
return func
|
||||
|
||||
@functools.wraps(func)
|
||||
async def get_or_create(args: Any, kwargs: Any) -> ReturnType:
|
||||
logger.info("exp_cache is enabled.")
|
||||
handler = ExpCacheHandler(
|
||||
func=func,
|
||||
args=args,
|
||||
kwargs=kwargs,
|
||||
query_type=query_type,
|
||||
exp_manager=manager,
|
||||
exp_scorer=scorer,
|
||||
pass_exps_to_func=pass_exps_to_func,
|
||||
exp_perfect_judge=perfect_judge,
|
||||
context_builder=context_builder,
|
||||
req_serialize=req_serialize,
|
||||
resp_serialize=resp_serialize,
|
||||
resp_deserialize=resp_deserialize,
|
||||
tag=tag,
|
||||
)
|
||||
|
||||
await handler.fetch_experiences(query_type)
|
||||
if exp := handler.get_one_perfect_experience():
|
||||
await handler.fetch_experiences()
|
||||
if exp := await handler.get_one_perfect_exp():
|
||||
return exp
|
||||
|
||||
await handler.execute_function()
|
||||
await handler.process_experience()
|
||||
|
||||
return handler._result
|
||||
return handler._raw_resp
|
||||
|
||||
return ExpCacheHandler.choose_wrapper(func, get_or_create)
|
||||
|
||||
|
|
@ -69,39 +92,59 @@ class ExpCacheHandler(BaseModel):
|
|||
func: Callable
|
||||
args: Any
|
||||
kwargs: Any
|
||||
query_type: QueryType = QueryType.SEMANTIC
|
||||
exp_manager: Optional[ExperienceManager] = None
|
||||
exp_scorer: Optional[ExperienceScorer] = None
|
||||
pass_exps_to_func: bool = False
|
||||
exp_scorer: Optional[BaseScorer] = None
|
||||
exp_perfect_judge: Optional[BasePerfectJudge] = None
|
||||
context_builder: Optional[BaseContextBuilder] = None
|
||||
req_serialize: Optional[Callable[..., str]] = None
|
||||
resp_serialize: Optional[Callable[..., str]] = None
|
||||
resp_deserialize: Optional[Callable[[str], Any]] = None
|
||||
tag: Optional[str] = None
|
||||
|
||||
_exps: list[Experience] = None
|
||||
_result: Any = None
|
||||
_req: str = ""
|
||||
_resp: str = ""
|
||||
_raw_resp: Any = None
|
||||
_score: Score = None
|
||||
_req: str = None
|
||||
|
||||
@model_validator(mode="after")
|
||||
def initialize(self):
|
||||
if self.exp_manager is None:
|
||||
self.exp_manager = exp_manager
|
||||
self._validate_params()
|
||||
|
||||
if self.exp_scorer is None:
|
||||
self.exp_scorer = SimpleScorer()
|
||||
self.exp_manager = self.exp_manager or exp_manager
|
||||
self.exp_scorer = self.exp_scorer or SimpleScorer()
|
||||
self.exp_perfect_judge = self.exp_perfect_judge or SimplePerfectJudge()
|
||||
self.context_builder = self.context_builder or SimpleContextBuilder()
|
||||
self.req_serialize = self.req_serialize or (lambda resp: str(resp))
|
||||
self.resp_serialize = self.resp_serialize or (lambda resp: str(resp))
|
||||
self.resp_deserialize = self.resp_deserialize or (lambda resp: resp)
|
||||
self.tag = self.tag or self._generate_tag()
|
||||
|
||||
self._req = self.generate_req_identifier(self.func, *self.args, **self.kwargs)
|
||||
self._req = self.req_serialize(self.kwargs["req"])
|
||||
|
||||
return self
|
||||
|
||||
async def fetch_experiences(self, query_type: QueryType):
|
||||
async def fetch_experiences(self):
|
||||
"""Fetch experiences by query_type."""
|
||||
|
||||
self._exps = await self.exp_manager.query_exps(self._req, query_type=query_type)
|
||||
self._exps = await self.exp_manager.query_exps(self._req, query_type=self.query_type, tag=self.tag)
|
||||
|
||||
def get_one_perfect_experience(self) -> Optional[Experience]:
|
||||
"""Get a potentially perfect experience."""
|
||||
return self.exp_manager.extract_one_perfect_exp(self._exps)
|
||||
async def get_one_perfect_exp(self) -> Optional[Any]:
|
||||
"""Get a potentially perfect experience, and resolve resp."""
|
||||
|
||||
for exp in self._exps:
|
||||
if await self.exp_perfect_judge.is_perfect_exp(exp, self._req, *self.args, **self.kwargs):
|
||||
logger.info(f"Get one perfect experience: {exp.req[:20]}...")
|
||||
return self.resp_deserialize(exp.resp)
|
||||
|
||||
return None
|
||||
|
||||
async def execute_function(self):
|
||||
"""Execute the function, and save the result."""
|
||||
self._result = await self._execute_function()
|
||||
"""Execute the function, and save resp."""
|
||||
|
||||
self._raw_resp = await self._execute_function()
|
||||
self._resp = self.resp_serialize(self._raw_resp)
|
||||
|
||||
@handle_exception
|
||||
async def process_experience(self):
|
||||
|
|
@ -110,41 +153,21 @@ class ExpCacheHandler(BaseModel):
|
|||
Evaluates and saves experience.
|
||||
Use `handle_exception` to ensure robustness, do not stop subsequent operations.
|
||||
"""
|
||||
|
||||
await self.evaluate_experience()
|
||||
self.save_experience()
|
||||
|
||||
async def evaluate_experience(self):
|
||||
"""Evaluate the experience, and save the score."""
|
||||
|
||||
self._score = await self.exp_scorer.evaluate(self.func, self._result, self.args, self.kwargs)
|
||||
self._score = await self.exp_scorer.evaluate(self.func, self._resp, self.args, self.kwargs)
|
||||
|
||||
def save_experience(self):
|
||||
"""Save the new experience."""
|
||||
|
||||
exp = Experience(req=self._req, resp=self._result, metric=Metric(score=self._score))
|
||||
|
||||
exp = Experience(req=self._req, resp=self._resp, tag=self.tag, metric=Metric(score=self._score))
|
||||
self.exp_manager.create_exp(exp)
|
||||
|
||||
@classmethod
|
||||
def generate_req_identifier(cls, func, *args, **kwargs) -> str:
|
||||
"""Generate a unique request identifier for any given function and its arguments.
|
||||
|
||||
Serializing args and kwargs into JSON strings and replacing ',' with '~' and ':' with '!'.
|
||||
|
||||
Return Example:
|
||||
SimpleClass.test_method@[1~2]@{"c"!3}
|
||||
"""
|
||||
cls_name = get_class_name(func)
|
||||
func_name = f"{cls_name}.{func.__name__}" if cls_name else func.__name__
|
||||
|
||||
if cls_name and args and inspect.isfunction(func):
|
||||
args = args[1:]
|
||||
|
||||
args = cls._serialize_and_replace(args)
|
||||
kwargs = cls._serialize_and_replace(kwargs)
|
||||
|
||||
return f"{func_name}@{args}@{kwargs}"
|
||||
|
||||
@staticmethod
|
||||
def choose_wrapper(func, wrapped_func):
|
||||
"""Choose how to run wrapped_func based on whether the function is asynchronous."""
|
||||
|
|
@ -158,25 +181,31 @@ class ExpCacheHandler(BaseModel):
|
|||
|
||||
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
|
||||
|
||||
@classmethod
|
||||
def _serialize_and_replace(cls, data):
|
||||
json_str = json.dumps(data)
|
||||
return json_str.replace(", ", "~").replace(": ", "!")
|
||||
def _validate_params(self):
|
||||
if "req" not in self.kwargs:
|
||||
raise ValueError("`req` must be provided as a keyword argument.")
|
||||
|
||||
def _generate_tag(self) -> str:
|
||||
"""Generates a tag for the self.func.
|
||||
|
||||
"ClassName.method_name" if the first argument is a class instance, otherwise just "function_name".
|
||||
"""
|
||||
|
||||
if self.args and hasattr(self.args[0], "__class__"):
|
||||
cls_name = type(self.args[0]).__name__
|
||||
return f"{cls_name}.{self.func.__name__}"
|
||||
|
||||
return self.func.__name__
|
||||
|
||||
async def _build_context(self) -> str:
|
||||
self.context_builder.exps = self._exps
|
||||
|
||||
return await self.context_builder.build(*self.args, **self.kwargs)
|
||||
|
||||
async def _execute_function(self):
|
||||
if self.pass_exps_to_func:
|
||||
return await self._execute_function_with_exps()
|
||||
self.kwargs["req"] = await self._build_context()
|
||||
|
||||
return await self._execute_function_without_exps()
|
||||
|
||||
async def _execute_function_without_exps(self):
|
||||
if asyncio.iscoroutinefunction(self.func):
|
||||
return await self.func(*self.args, **self.kwargs)
|
||||
|
||||
return self.func(*self.args, **self.kwargs)
|
||||
|
||||
async def _execute_function_with_exps(self):
|
||||
if asyncio.iscoroutinefunction(self.func):
|
||||
return await self.func(*self.args, **self.kwargs, exps=self._exps)
|
||||
|
||||
return self.func(*self.args, **self.kwargs, exps=self._exps)
|
||||
|
|
|
|||
|
|
@ -1,13 +1,22 @@
|
|||
"""Experience Manager."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from llama_index.vector_stores.chroma import ChromaVectorStore
|
||||
from pydantic import BaseModel, ConfigDict, model_validator
|
||||
|
||||
from metagpt.config2 import Config, config
|
||||
from metagpt.exp_pool.schema import MAX_SCORE, Experience, QueryType
|
||||
from metagpt.exp_pool.schema import (
|
||||
DEFAULT_COLLECTION_NAME,
|
||||
DEFAULT_SIMILARITY_TOP_K,
|
||||
EntryType,
|
||||
Experience,
|
||||
Metric,
|
||||
QueryType,
|
||||
Score,
|
||||
)
|
||||
from metagpt.logs import logger
|
||||
from metagpt.rag.engines import SimpleEngine
|
||||
from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig
|
||||
from metagpt.strategy.experience_retriever import ENGINEER_EXAMPLE, TL_EXAMPLE
|
||||
from metagpt.utils.exceptions import handle_exception
|
||||
|
||||
|
||||
|
|
@ -27,14 +36,33 @@ class ExperienceManager(BaseModel):
|
|||
@model_validator(mode="after")
|
||||
def initialize(self):
|
||||
if self.storage is None:
|
||||
self.storage = SimpleEngine.from_objs(
|
||||
retriever_configs=[
|
||||
ChromaRetrieverConfig(collection_name="experience_pool", persist_path=".chroma_exp_data")
|
||||
],
|
||||
ranker_configs=[LLMRankerConfig()],
|
||||
)
|
||||
retriever_configs = [
|
||||
ChromaRetrieverConfig(
|
||||
persist_path=self.config.exp_pool.persist_path,
|
||||
collection_name=DEFAULT_COLLECTION_NAME,
|
||||
similarity_top_k=DEFAULT_SIMILARITY_TOP_K,
|
||||
)
|
||||
]
|
||||
ranker_configs = [LLMRankerConfig()]
|
||||
|
||||
self.storage = SimpleEngine.from_objs(retriever_configs=retriever_configs, ranker_configs=ranker_configs)
|
||||
|
||||
self.init_exp_pool()
|
||||
|
||||
return self
|
||||
|
||||
@handle_exception
|
||||
def init_exp_pool(self):
|
||||
if not self.config.exp_pool.init_exp:
|
||||
return
|
||||
|
||||
if self._has_exps():
|
||||
return
|
||||
|
||||
self._init_teamleader_exps()
|
||||
self._init_engineer2_exps()
|
||||
logger.info("`init_exp_pool` done.")
|
||||
|
||||
@handle_exception
|
||||
def create_exp(self, exp: Experience):
|
||||
"""Adds an experience to the storage if writing is enabled.
|
||||
|
|
@ -74,39 +102,26 @@ class ExperienceManager(BaseModel):
|
|||
|
||||
return exps
|
||||
|
||||
def extract_one_perfect_exp(self, exps: list[Experience]) -> Optional[Experience]:
|
||||
"""Extracts the first 'perfect' experience from a list of experiences.
|
||||
def _has_exps(self) -> bool:
|
||||
vector_store: ChromaVectorStore = self.storage._retriever._vector_store
|
||||
|
||||
Args:
|
||||
exps (list[Experience]): The experiences to evaluate.
|
||||
return bool(vector_store._get(limit=1, where={}).ids)
|
||||
|
||||
Returns:
|
||||
Optional[Experience]: The first perfect experience if found, otherwise None.
|
||||
"""
|
||||
for exp in exps:
|
||||
if self.is_perfect_exp(exp):
|
||||
return exp
|
||||
def _init_exp(self, req: str, resp: str, tag: str, metric: Metric = None):
|
||||
exp = Experience(
|
||||
req=req,
|
||||
resp=resp,
|
||||
entry_type=EntryType.MANUAL,
|
||||
tag=tag,
|
||||
metric=metric or Metric(score=Score(val=9, reason="Manual")),
|
||||
)
|
||||
self.create_exp(exp)
|
||||
|
||||
return None
|
||||
def _init_teamleader_exps(self):
|
||||
self._init_exp(req=TL_EXAMPLE, resp=TL_EXAMPLE, tag="TeamLeader.llm_cached_aask")
|
||||
|
||||
@staticmethod
|
||||
def is_perfect_exp(exp: Experience) -> bool:
|
||||
"""Determines if an experience is considered 'perfect'.
|
||||
|
||||
Args:
|
||||
exp (Experience): The experience to evaluate.
|
||||
|
||||
Returns:
|
||||
bool: True if the experience is manually entered, otherwise False.
|
||||
"""
|
||||
if not exp:
|
||||
return False
|
||||
|
||||
# TODO: need more metrics
|
||||
if exp.metric and exp.metric.score.val == MAX_SCORE:
|
||||
return True
|
||||
|
||||
return False
|
||||
def _init_engineer2_exps(self):
|
||||
self._init_exp(req=ENGINEER_EXAMPLE, resp=ENGINEER_EXAMPLE, tag="Engineer2.llm_cached_aask")
|
||||
|
||||
|
||||
exp_manager = ExperienceManager()
|
||||
|
|
|
|||
6
metagpt/exp_pool/perfect_judges/__init__.py
Normal file
6
metagpt/exp_pool/perfect_judges/__init__.py
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
"""Perfect judges init."""
|
||||
|
||||
from metagpt.exp_pool.perfect_judges.base import BasePerfectJudge
|
||||
from metagpt.exp_pool.perfect_judges.simple import SimplePerfectJudge
|
||||
|
||||
__all__ = ["BasePerfectJudge", "SimplePerfectJudge"]
|
||||
20
metagpt/exp_pool/perfect_judges/base.py
Normal file
20
metagpt/exp_pool/perfect_judges/base.py
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
"""Base perfect judge."""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from metagpt.exp_pool.schema import Experience
|
||||
|
||||
|
||||
class BasePerfectJudge(BaseModel, ABC):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@abstractmethod
|
||||
async def is_perfect_exp(self, exp: Experience, serialized_req: str, *args, **kwargs) -> bool:
|
||||
"""Determine whether the experience is perfect.
|
||||
|
||||
Args:
|
||||
exp (Experience): The experience to evaluate.
|
||||
serialized_req (str): The serialized request to compare against the experience's request.
|
||||
"""
|
||||
27
metagpt/exp_pool/perfect_judges/simple.py
Normal file
27
metagpt/exp_pool/perfect_judges/simple.py
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
"""Simple perfect judge."""
|
||||
|
||||
|
||||
from pydantic import ConfigDict
|
||||
|
||||
from metagpt.exp_pool.perfect_judges.base import BasePerfectJudge
|
||||
from metagpt.exp_pool.schema import MAX_SCORE, Experience
|
||||
|
||||
|
||||
class SimplePerfectJudge(BasePerfectJudge):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
async def is_perfect_exp(self, exp: Experience, serialized_req: str, *args, **kwargs) -> bool:
|
||||
"""Determine whether the experience is perfect.
|
||||
|
||||
Args:
|
||||
exp (Experience): The experience to evaluate.
|
||||
serialized_req (str): The serialized request to compare against the experience's request.
|
||||
|
||||
Returns:
|
||||
bool: True if the serialized request matches the experience's request and the experience's score is perfect, False otherwise.
|
||||
"""
|
||||
|
||||
if not exp.metric or not exp.metric.score:
|
||||
return False
|
||||
|
||||
return serialized_req == exp.req and exp.metric.score.val == MAX_SCORE
|
||||
|
|
@ -1,13 +1,16 @@
|
|||
"""Experience schema."""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
from typing import Optional
|
||||
|
||||
from llama_index.core.schema import TextNode
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
MAX_SCORE = 10
|
||||
|
||||
DEFAULT_COLLECTION_NAME = "experience_pool"
|
||||
DEFAULT_SIMILARITY_TOP_K = 2
|
||||
|
||||
|
||||
class QueryType(str, Enum):
|
||||
"""Type of query experiences."""
|
||||
|
|
@ -59,7 +62,7 @@ class Experience(BaseModel):
|
|||
"""Experience."""
|
||||
|
||||
req: str = Field(..., description="")
|
||||
resp: Any = Field(..., description="The type is string/json/code.")
|
||||
resp: str = Field(..., description="The type is string/json/code.")
|
||||
metric: Optional[Metric] = Field(default=None, description="Metric.")
|
||||
exp_type: ExperienceType = Field(default=ExperienceType.SUCCESS, description="The type of experience.")
|
||||
entry_type: EntryType = Field(default=EntryType.AUTOMATIC, description="Type of entry: Manual or Automatic.")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""Experience scorers init."""
|
||||
"""Scorers init."""
|
||||
|
||||
from metagpt.exp_pool.scorers.base import ExperienceScorer
|
||||
from metagpt.exp_pool.scorers.base import BaseScorer
|
||||
from metagpt.exp_pool.scorers.simple import SimpleScorer
|
||||
|
||||
__all__ = ["ExperienceScorer", "SimpleScorer"]
|
||||
__all__ = ["BaseScorer", "SimpleScorer"]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""Experience Scorers."""
|
||||
"""Base scorer."""
|
||||
|
||||
from abc import abstractmethod
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
|
@ -8,7 +8,7 @@ from pydantic import BaseModel, ConfigDict
|
|||
from metagpt.exp_pool.schema import Score
|
||||
|
||||
|
||||
class ExperienceScorer(BaseModel):
|
||||
class BaseScorer(BaseModel, ABC):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
@abstractmethod
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
"""Simple Scorer."""
|
||||
"""Simple scorer."""
|
||||
|
||||
import inspect
|
||||
import json
|
||||
|
|
@ -7,7 +7,7 @@ from typing import Any, Callable
|
|||
from pydantic import Field
|
||||
|
||||
from metagpt.exp_pool.schema import Score
|
||||
from metagpt.exp_pool.scorers.base import ExperienceScorer
|
||||
from metagpt.exp_pool.scorers.base import BaseScorer
|
||||
from metagpt.llm import LLM
|
||||
from metagpt.provider.base_llm import BaseLLM
|
||||
from metagpt.utils.common import CodeParser
|
||||
|
|
@ -54,7 +54,7 @@ Follow instructions, generate output and make sure it follows the **Constraint**
|
|||
"""
|
||||
|
||||
|
||||
class SimpleScorer(ExperienceScorer):
|
||||
class SimpleScorer(BaseScorer):
|
||||
llm: BaseLLM = Field(default_factory=LLM)
|
||||
|
||||
async def evaluate(self, func: Callable, result: Any, args: tuple = None, kwargs: dict = None) -> Score:
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import inspect
|
||||
import json
|
||||
import re
|
||||
|
|
@ -10,8 +11,14 @@ from pydantic import model_validator
|
|||
|
||||
from metagpt.actions import Action
|
||||
from metagpt.actions.di.run_command import RunCommand
|
||||
from metagpt.exp_pool import exp_cache
|
||||
from metagpt.exp_pool.context_builders import RoleZeroContextBuilder
|
||||
from metagpt.logs import logger
|
||||
from metagpt.prompts.di.role_zero import CMD_PROMPT, ROLE_INSTRUCTION, JSON_REPAIR_PROMPT
|
||||
from metagpt.prompts.di.role_zero import (
|
||||
CMD_PROMPT,
|
||||
JSON_REPAIR_PROMPT,
|
||||
ROLE_INSTRUCTION,
|
||||
)
|
||||
from metagpt.roles import Role
|
||||
from metagpt.schema import AIMessage, Message, UserMessage
|
||||
from metagpt.strategy.experience_retriever import DummyExpRetriever, ExpRetriever
|
||||
|
|
@ -21,8 +28,8 @@ from metagpt.tools.libs.editor import Editor
|
|||
from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender
|
||||
from metagpt.tools.tool_registry import register_tool
|
||||
from metagpt.utils.common import CodeParser
|
||||
from metagpt.utils.repair_llm_raw_output import RepairType, repair_llm_raw_output
|
||||
from metagpt.utils.report import ThoughtReporter
|
||||
from metagpt.utils.repair_llm_raw_output import repair_llm_raw_output, RepairType
|
||||
|
||||
|
||||
@register_tool(include_functions=["ask_human", "reply_to_human"])
|
||||
|
|
@ -154,11 +161,37 @@ class RoleZero(Role):
|
|||
context = self.llm.format_msg(memory + [UserMessage(content=prompt)])
|
||||
# print(*context, sep="\n" + "*" * 5 + "\n")
|
||||
async with ThoughtReporter(enable_llm_stream=True):
|
||||
self.command_rsp = await self.llm.aask(context, system_msgs=self.system_msg)
|
||||
self.command_rsp = await self.llm_cached_aask(req=context, system_msgs=self.system_msg)
|
||||
self.rc.memory.add(AIMessage(content=self.command_rsp))
|
||||
|
||||
return True
|
||||
|
||||
@exp_cache(context_builder=RoleZeroContextBuilder(), req_serialize=lambda req: RoleZero._req_serialize(req))
|
||||
async def llm_cached_aask(self, *, req: list[dict], system_msgs: list[str]) -> str:
|
||||
return await self.llm.aask(req, system_msgs=system_msgs)
|
||||
|
||||
@staticmethod
|
||||
def _req_serialize(req: list[dict]) -> str:
|
||||
"""Serialize the request for database storage, ensuring it is a string.
|
||||
|
||||
This function deep copies the request and modifies the content of the last element
|
||||
to remove unnecessary sections, making the request more concise.
|
||||
"""
|
||||
|
||||
req_copy = copy.deepcopy(req)
|
||||
|
||||
last_content = req_copy[-1]["content"]
|
||||
last_content = RoleZeroContextBuilder.replace_content_between_markers(
|
||||
last_content, "# Data Structure", "# Current Plan", ""
|
||||
)
|
||||
last_content = RoleZeroContextBuilder.replace_content_between_markers(
|
||||
last_content, "# Example", "# Instruction", ""
|
||||
)
|
||||
|
||||
req_copy[-1]["content"] = last_content
|
||||
|
||||
return json.dumps(req_copy)
|
||||
|
||||
async def _act(self) -> Message:
|
||||
if self.use_fixed_sop:
|
||||
return await super()._act()
|
||||
|
|
@ -166,7 +199,7 @@ class RoleZero(Role):
|
|||
try:
|
||||
commands = CodeParser.parse_code(block=None, lang="json", text=self.command_rsp)
|
||||
commands = json.loads(repair_llm_raw_output(output=commands, req_keys=[None], repair_type=RepairType.JSON))
|
||||
except json.JSONDecodeError as e:
|
||||
except json.JSONDecodeError:
|
||||
commands = await self.llm.aask(msg=JSON_REPAIR_PROMPT.format(json_data=self.command_rsp))
|
||||
commands = json.loads(CodeParser.parse_code(block=None, lang="json", text=commands))
|
||||
except Exception as e:
|
||||
|
|
|
|||
|
|
@ -39,7 +39,7 @@ class NaiveSolver(BaseSolver):
|
|||
self.graph.topological_sort()
|
||||
for key in self.graph.execution_order:
|
||||
op = self.graph.nodes[key]
|
||||
await op.fill(self.context, self.llm, mode="root")
|
||||
await op.fill(req=self.context, llm=self.llm, mode="root")
|
||||
|
||||
|
||||
class TOTSolver(BaseSolver):
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
"""class tools, including method inspection, class attributes, inheritance relationships, etc."""
|
||||
import inspect
|
||||
|
||||
|
||||
def check_methods(C, *methods):
|
||||
|
|
@ -17,25 +16,3 @@ def check_methods(C, *methods):
|
|||
else:
|
||||
return NotImplemented
|
||||
return True
|
||||
|
||||
|
||||
def get_class_name(func) -> str:
|
||||
"""Returns the class name of the object that a method belongs to.
|
||||
|
||||
- If `func` is a bound method or a class method, extracts the class name directly from the method.
|
||||
- Returns an empty string if it's a regular function or cannot determine the class.
|
||||
"""
|
||||
if inspect.ismethod(func):
|
||||
if inspect.isclass(func.__self__):
|
||||
return func.__self__.__name__
|
||||
|
||||
return func.__self__.__class__.__name__
|
||||
|
||||
if inspect.isfunction(func):
|
||||
qualname_parts = func.__qualname__.split(".")
|
||||
if len(qualname_parts) > 1:
|
||||
class_name = qualname_parts[-2]
|
||||
if class_name.isidentifier():
|
||||
return class_name
|
||||
|
||||
return ""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue