use llm cache to make exp_pool

This commit is contained in:
seehi 2024-07-08 10:09:36 +08:00
parent d902a6f18c
commit c624c0ffc7
41 changed files with 844 additions and 368 deletions

View file

@ -90,7 +90,7 @@ class Action(SerializationMixin, ContextMixin, BaseModel):
msgs = args[0]
context = "## History Messages\n"
context += "\n".join([f"{idx}: {i}" for idx, i in enumerate(reversed(msgs))])
return await self.node.fill(context=context, llm=self.llm)
return await self.node.fill(req=context, llm=self.llm)
async def run(self, *args, **kwargs):
"""Run action"""

View file

@ -18,6 +18,7 @@ from tenacity import retry, stop_after_attempt, wait_random_exponential
from metagpt.actions.action_outcls_registry import register_action_outcls
from metagpt.const import MARKDOWN_TITLE_PREFIX, USE_CONFIG_TIMEOUT
from metagpt.exp_pool import exp_cache
from metagpt.llm import BaseLLM
from metagpt.logs import logger
from metagpt.provider.postprocess.llm_output_postprocess import llm_output_postprocess
@ -465,9 +466,33 @@ class ActionNode:
return self
@classmethod
def deserialize_to_action_node(cls, serialized_data) -> "ActionNode":
"""Customized deserialization, it will be triggered when a perfect experience is found.
ActionNode cannot be serialized, it throws an error 'cannot pickle 'SSLContext' object'.
"""
class InstructContent:
def __init__(self, json_data):
self.json_data = json_data
def model_dump_json(self):
return self.json_data
action_node = cls(key="", expected_type=Type[str], instruction="", example="")
action_node.instruct_content = InstructContent(serialized_data)
return action_node
@exp_cache(
resp_serialize=lambda action_node: action_node.instruct_content.model_dump_json(),
resp_deserialize=lambda resp: ActionNode.deserialize_to_action_node(resp),
)
async def fill(
self,
context,
*,
req,
llm,
schema="json",
mode="auto",
@ -478,7 +503,7 @@ class ActionNode:
):
"""Fill the node(s) with mode.
:param context: Everything we should know when filling node.
:param req: Everything we should know when filling node.
:param llm: Large Language Model with pre-defined system message.
:param schema: json/markdown, determine example and output format.
- raw: free form text
@ -497,7 +522,7 @@ class ActionNode:
:return: self
"""
self.set_llm(llm)
self.set_context(context)
self.set_context(req)
if self.schema:
schema = self.schema

View file

@ -178,12 +178,12 @@ class WriteDesign(Action):
)
async def _new_system_design(self, context):
node = await DESIGN_API_NODE.fill(context=context, llm=self.llm, schema=self.prompt_schema)
node = await DESIGN_API_NODE.fill(req=context, llm=self.llm, schema=self.prompt_schema)
return node
async def _merge(self, prd_doc, system_design_doc):
context = NEW_REQ_TEMPLATE.format(old_design=system_design_doc.content, context=prd_doc.content)
node = await REFINED_DESIGN_NODE.fill(context=context, llm=self.llm, schema=self.prompt_schema)
node = await REFINED_DESIGN_NODE.fill(req=context, llm=self.llm, schema=self.prompt_schema)
system_design_doc.content = node.instruct_content.model_dump_json()
return system_design_doc

View file

@ -22,4 +22,4 @@ class GenerateQuestions(Action):
name: str = "GenerateQuestions"
async def run(self, context) -> ActionNode:
return await QUESTIONS.fill(context=context, llm=self.llm)
return await QUESTIONS.fill(req=context, llm=self.llm)

View file

@ -22,4 +22,4 @@ class PrepareInterview(Action):
name: str = "PrepareInterview"
async def run(self, context):
return await QUESTIONS.fill(context=context, llm=self.llm)
return await QUESTIONS.fill(req=context, llm=self.llm)

View file

@ -151,12 +151,12 @@ class WriteTasks(Action):
return task_doc
async def _run_new_tasks(self, context: str):
node = await PM_NODE.fill(context, self.llm, schema=self.prompt_schema)
node = await PM_NODE.fill(req=context, llm=self.llm, schema=self.prompt_schema)
return node
async def _merge(self, system_design_doc, task_doc) -> Document:
context = NEW_REQ_TEMPLATE.format(context=system_design_doc.content, old_task=task_doc.content)
node = await REFINED_PM_NODE.fill(context, self.llm, schema=self.prompt_schema)
node = await REFINED_PM_NODE.fill(req=context, llm=self.llm, schema=self.prompt_schema)
task_doc.content = node.instruct_content.model_dump_json()
return task_doc

View file

@ -578,7 +578,7 @@ class WriteCodeAN(Action):
async def run(self, context):
self.llm.system_prompt = "You are an outstanding engineer and can implement any code"
return await WRITE_MOVE_NODE.fill(context=context, llm=self.llm, schema="json")
return await WRITE_MOVE_NODE.fill(req=context, llm=self.llm, schema="json")
async def main():

View file

@ -229,7 +229,7 @@ class WriteCodePlanAndChange(Action):
code=await self.get_old_codes(),
)
logger.info("Writing code plan and change..")
return await WRITE_CODE_PLAN_AND_CHANGE_NODE.fill(context=context, llm=self.llm, schema="json")
return await WRITE_CODE_PLAN_AND_CHANGE_NODE.fill(req=context, llm=self.llm, schema="json")
async def get_old_codes(self) -> str:
old_codes = await self.repo.srcs.get_all()

View file

@ -211,7 +211,7 @@ class WritePRD(Action):
context = CONTEXT_TEMPLATE.format(requirements=requirement, project_name=project_name)
exclude = [PROJECT_NAME.key] if project_name else []
node = await WRITE_PRD_NODE.fill(
context=context, llm=self.llm, exclude=exclude, schema=self.prompt_schema
req=context, llm=self.llm, exclude=exclude, schema=self.prompt_schema
) # schema=schema
return node
@ -238,7 +238,7 @@ class WritePRD(Action):
async def _is_bugfix(self, context: str) -> bool:
if not self.repo.code_files_exists():
return False
node = await WP_ISSUE_TYPE_NODE.fill(context, self.llm)
node = await WP_ISSUE_TYPE_NODE.fill(req=context, llm=self.llm)
return node.get("issue_type") == "BUG"
async def get_related_docs(self, req: Document, docs: list[Document]) -> list[Document]:
@ -248,14 +248,14 @@ class WritePRD(Action):
async def _is_related(self, req: Document, old_prd: Document) -> bool:
context = NEW_REQ_TEMPLATE.format(old_prd=old_prd.content, requirements=req.content)
node = await WP_IS_RELATIVE_NODE.fill(context, self.llm)
node = await WP_IS_RELATIVE_NODE.fill(req=context, llm=self.llm)
return node.get("is_relative") == "YES"
async def _merge(self, req: Document, related_doc: Document) -> Document:
if not self.project_name:
self.project_name = Path(self.project_path).name
prompt = NEW_REQ_TEMPLATE.format(requirements=req.content, old_prd=related_doc.content)
node = await REFINED_PRD_NODE.fill(context=prompt, llm=self.llm, schema=self.prompt_schema)
node = await REFINED_PRD_NODE.fill(req=prompt, llm=self.llm, schema=self.prompt_schema)
related_doc.content = node.instruct_content.model_dump_json()
await self._rename_workspace(node)
return related_doc

View file

@ -36,4 +36,4 @@ class WriteReview(Action):
name: str = "WriteReview"
async def run(self, context):
return await WRITE_REVIEW_NODE.fill(context=context, llm=self.llm, schema="json")
return await WRITE_REVIEW_NODE.fill(req=context, llm=self.llm, schema="json")

View file

@ -4,5 +4,9 @@ from metagpt.utils.yaml_model import YamlModel
class ExperiencePoolConfig(YamlModel):
enable_read: bool = Field(default=True, description="Enable to read from experience pool.")
enable_write: bool = Field(default=True, description="Enable to write to experience pool.")
enable_read: bool = Field(default=False, description="Enable to read from experience pool.")
enable_write: bool = Field(default=False, description="Enable to write to experience pool.")
persist_path: str = Field(default=".chroma_exp_data", description="The persist path for experience pool.")
init_exp: bool = Field(
default=False, description="Put some basic experiences associated with the roles into the experience pool."
)

View file

@ -0,0 +1,7 @@
"""Context builders init."""
from metagpt.exp_pool.context_builders.base import BaseContextBuilder
from metagpt.exp_pool.context_builders.simple import SimpleContextBuilder
from metagpt.exp_pool.context_builders.role_zero import RoleZeroContextBuilder
__all__ = ["BaseContextBuilder", "SimpleContextBuilder", "RoleZeroContextBuilder"]

View file

@ -0,0 +1,52 @@
"""Base context builder."""
import re
from abc import ABC, abstractmethod
from typing import Any
from pydantic import BaseModel, ConfigDict
from metagpt.exp_pool.schema import Experience
EXP_TEMPLATE = """Given the request: {req}, We can get the response: {resp}, Which scored: {score}."""
class BaseContextBuilder(BaseModel, ABC):
model_config = ConfigDict(arbitrary_types_allowed=True)
exps: list[Experience] = []
@abstractmethod
async def build(self, *args, **kwargs) -> Any:
"""Build context from parameters."""
def format_exps(self) -> str:
"""Format experiences into a numbered list of strings."""
result = []
for i, exp in enumerate(self.exps, start=1):
result.append(f"{i}. " + EXP_TEMPLATE.format(req=exp.req, resp=exp.resp, score=exp.metric.score.val))
return "\n".join(result)
@staticmethod
def replace_content_between_markers(text: str, start_marker: str, end_marker: str, new_content: str) -> str:
"""Replace the content between `start_marker` and `end_marker` in the text with `new_content`.
Args:
text (str): The original text.
new_content (str): The new content to replace the old content.
start_marker (str): The marker indicating the start of the content to be replaced, such as '# Example'.
end_marker (str): The marker indicating the end of the content to be replaced, such as '# Instruction'.
Returns:
str: The text with the content replaced.
"""
pattern = re.compile(f"({start_marker}\n)(.*?)(\n{end_marker})", re.DOTALL)
def replacement(match):
return f"{match.group(1)}{new_content}\n{match.group(3)}"
replaced_text = pattern.sub(replacement, text)
return replaced_text

View file

@ -0,0 +1,26 @@
"""RoleZero context builder."""
from metagpt.exp_pool.context_builders.base import BaseContextBuilder
class RoleZeroContextBuilder(BaseContextBuilder):
async def build(self, *args, **kwargs) -> list[dict]:
"""Builds the context by updating the req with formatted experiences.
If there are no experiences, retains the original examples in req, otherwise replaces the examples with the formatted experiences.
"""
req = kwargs.get("req", [])
if not req:
return req
exps_str = self.format_exps()
if not exps_str:
return req
req[-1]["content"] = self.replace_example_content(req[-1].get("content", ""), exps_str)
return req
def replace_example_content(self, text: str, new_example_content: str) -> str:
return self.replace_content_between_markers(text, "# Example", "# Instruction", new_example_content)

View file

@ -0,0 +1,24 @@
"""Simple context builder."""
from metagpt.exp_pool.context_builders.base import BaseContextBuilder
SIMPLE_CONTEXT_TEMPLATE = """
{req}
### Experiences
-----
{exps}
-----
## Instruction
Consider **Experiences** to generate a better answer.
"""
class SimpleContextBuilder(BaseContextBuilder):
async def build(self, *args, **kwargs) -> str:
req = kwargs.get("req", "")
exps = self.format_exps()
return SIMPLE_CONTEXT_TEMPLATE.format(req=req, exps=exps) if exps else req

View file

@ -2,18 +2,19 @@
import asyncio
import functools
import inspect
import json
from typing import Any, Callable, Optional, TypeVar
from pydantic import BaseModel, ConfigDict, model_validator
from metagpt.config2 import config
from metagpt.exp_pool.context_builders import BaseContextBuilder, SimpleContextBuilder
from metagpt.exp_pool.manager import ExperienceManager, exp_manager
from metagpt.exp_pool.perfect_judges import BasePerfectJudge, SimplePerfectJudge
from metagpt.exp_pool.schema import Experience, Metric, QueryType, Score
from metagpt.exp_pool.scorers import ExperienceScorer, SimpleScorer
from metagpt.exp_pool.scorers import BaseScorer, SimpleScorer
from metagpt.logs import logger
from metagpt.utils.async_helper import NestAsyncio
from metagpt.utils.exceptions import handle_exception
from metagpt.utils.reflection import get_class_name
ReturnType = TypeVar("ReturnType")
@ -21,42 +22,64 @@ ReturnType = TypeVar("ReturnType")
def exp_cache(
_func: Optional[Callable[..., ReturnType]] = None,
query_type: QueryType = QueryType.SEMANTIC,
scorer: Optional[ExperienceScorer] = None,
manager: Optional[ExperienceManager] = None,
pass_exps_to_func: bool = False,
scorer: Optional[BaseScorer] = None,
perfect_judge: Optional[BasePerfectJudge] = None,
context_builder: Optional[BaseContextBuilder] = None,
req_serialize: Optional[Callable[..., str]] = None,
resp_serialize: Optional[Callable[..., str]] = None,
resp_deserialize: Optional[Callable[[str], Any]] = None,
tag: Optional[str] = None,
):
"""Decorator to get a perfect experience, otherwise, it executes the function, and create a new experience.
This can be applied to both synchronous and asynchronous functions.
1. This can be applied to both synchronous and asynchronous functions.
2. The function must have a `req` parameter, and it must be provided as a keyword argument.
3. If `config.exp_pool.enable_read` is False, the decorator will just directly execute the function.
Args:
_func: Just to make the decorator more flexible, for example, it can be used directly with @exp_cache by default, without the need for @exp_cache().
query_type: The type of query to be used when fetching experiences.
scorer: Evaluate experience. Default SimpleScorer.
manager: How to fetch, evaluate and save experience, etc. Default exp_manager.
pass_exps_to_func: To control whether imperfect experiences are passed to the function, if True, the func must have a parameter named 'exps'.
manager: How to fetch, evaluate and save experience, etc. Default to `exp_manager`.
scorer: Evaluate experience. Default to `SimpleScorer()`.
perfect_judge: Determines if an experience is perfect. Defaults to `SimplePerfectJudge()`.
context_builder: Build the context from exps and the function parameters. Default to `SimpleContextBuilder()`.
req_serialize: Serializes the request for storage. Defaults to `lambda req: str(req)`.
resp_serialize: Serializes the function's return value for storage. Defaults to `lambda resp: str(resp)`.
resp_deserialize: Deserializes the stored response back to the function's return value. Defaults to `lambda resp: resp`.
tag: An optional tag for the experience. Default to `ClassName.method_name` or `function_name`.
"""
def decorator(func: Callable[..., ReturnType]) -> Callable[..., ReturnType]:
if not config.exp_pool.enable_read:
return func
@functools.wraps(func)
async def get_or_create(args: Any, kwargs: Any) -> ReturnType:
logger.info("exp_cache is enabled.")
handler = ExpCacheHandler(
func=func,
args=args,
kwargs=kwargs,
query_type=query_type,
exp_manager=manager,
exp_scorer=scorer,
pass_exps_to_func=pass_exps_to_func,
exp_perfect_judge=perfect_judge,
context_builder=context_builder,
req_serialize=req_serialize,
resp_serialize=resp_serialize,
resp_deserialize=resp_deserialize,
tag=tag,
)
await handler.fetch_experiences(query_type)
if exp := handler.get_one_perfect_experience():
await handler.fetch_experiences()
if exp := await handler.get_one_perfect_exp():
return exp
await handler.execute_function()
await handler.process_experience()
return handler._result
return handler._raw_resp
return ExpCacheHandler.choose_wrapper(func, get_or_create)
@ -69,39 +92,59 @@ class ExpCacheHandler(BaseModel):
func: Callable
args: Any
kwargs: Any
query_type: QueryType = QueryType.SEMANTIC
exp_manager: Optional[ExperienceManager] = None
exp_scorer: Optional[ExperienceScorer] = None
pass_exps_to_func: bool = False
exp_scorer: Optional[BaseScorer] = None
exp_perfect_judge: Optional[BasePerfectJudge] = None
context_builder: Optional[BaseContextBuilder] = None
req_serialize: Optional[Callable[..., str]] = None
resp_serialize: Optional[Callable[..., str]] = None
resp_deserialize: Optional[Callable[[str], Any]] = None
tag: Optional[str] = None
_exps: list[Experience] = None
_result: Any = None
_req: str = ""
_resp: str = ""
_raw_resp: Any = None
_score: Score = None
_req: str = None
@model_validator(mode="after")
def initialize(self):
if self.exp_manager is None:
self.exp_manager = exp_manager
self._validate_params()
if self.exp_scorer is None:
self.exp_scorer = SimpleScorer()
self.exp_manager = self.exp_manager or exp_manager
self.exp_scorer = self.exp_scorer or SimpleScorer()
self.exp_perfect_judge = self.exp_perfect_judge or SimplePerfectJudge()
self.context_builder = self.context_builder or SimpleContextBuilder()
self.req_serialize = self.req_serialize or (lambda resp: str(resp))
self.resp_serialize = self.resp_serialize or (lambda resp: str(resp))
self.resp_deserialize = self.resp_deserialize or (lambda resp: resp)
self.tag = self.tag or self._generate_tag()
self._req = self.generate_req_identifier(self.func, *self.args, **self.kwargs)
self._req = self.req_serialize(self.kwargs["req"])
return self
async def fetch_experiences(self, query_type: QueryType):
async def fetch_experiences(self):
"""Fetch experiences by query_type."""
self._exps = await self.exp_manager.query_exps(self._req, query_type=query_type)
self._exps = await self.exp_manager.query_exps(self._req, query_type=self.query_type, tag=self.tag)
def get_one_perfect_experience(self) -> Optional[Experience]:
"""Get a potentially perfect experience."""
return self.exp_manager.extract_one_perfect_exp(self._exps)
async def get_one_perfect_exp(self) -> Optional[Any]:
"""Get a potentially perfect experience, and resolve resp."""
for exp in self._exps:
if await self.exp_perfect_judge.is_perfect_exp(exp, self._req, *self.args, **self.kwargs):
logger.info(f"Get one perfect experience: {exp.req[:20]}...")
return self.resp_deserialize(exp.resp)
return None
async def execute_function(self):
"""Execute the function, and save the result."""
self._result = await self._execute_function()
"""Execute the function, and save resp."""
self._raw_resp = await self._execute_function()
self._resp = self.resp_serialize(self._raw_resp)
@handle_exception
async def process_experience(self):
@ -110,41 +153,21 @@ class ExpCacheHandler(BaseModel):
Evaluates and saves experience.
Use `handle_exception` to ensure robustness, do not stop subsequent operations.
"""
await self.evaluate_experience()
self.save_experience()
async def evaluate_experience(self):
"""Evaluate the experience, and save the score."""
self._score = await self.exp_scorer.evaluate(self.func, self._result, self.args, self.kwargs)
self._score = await self.exp_scorer.evaluate(self.func, self._resp, self.args, self.kwargs)
def save_experience(self):
"""Save the new experience."""
exp = Experience(req=self._req, resp=self._result, metric=Metric(score=self._score))
exp = Experience(req=self._req, resp=self._resp, tag=self.tag, metric=Metric(score=self._score))
self.exp_manager.create_exp(exp)
@classmethod
def generate_req_identifier(cls, func, *args, **kwargs) -> str:
"""Generate a unique request identifier for any given function and its arguments.
Serializing args and kwargs into JSON strings and replacing ',' with '~' and ':' with '!'.
Return Example:
SimpleClass.test_method@[1~2]@{"c"!3}
"""
cls_name = get_class_name(func)
func_name = f"{cls_name}.{func.__name__}" if cls_name else func.__name__
if cls_name and args and inspect.isfunction(func):
args = args[1:]
args = cls._serialize_and_replace(args)
kwargs = cls._serialize_and_replace(kwargs)
return f"{func_name}@{args}@{kwargs}"
@staticmethod
def choose_wrapper(func, wrapped_func):
"""Choose how to run wrapped_func based on whether the function is asynchronous."""
@ -158,25 +181,31 @@ class ExpCacheHandler(BaseModel):
return async_wrapper if asyncio.iscoroutinefunction(func) else sync_wrapper
@classmethod
def _serialize_and_replace(cls, data):
json_str = json.dumps(data)
return json_str.replace(", ", "~").replace(": ", "!")
def _validate_params(self):
if "req" not in self.kwargs:
raise ValueError("`req` must be provided as a keyword argument.")
def _generate_tag(self) -> str:
"""Generates a tag for the self.func.
"ClassName.method_name" if the first argument is a class instance, otherwise just "function_name".
"""
if self.args and hasattr(self.args[0], "__class__"):
cls_name = type(self.args[0]).__name__
return f"{cls_name}.{self.func.__name__}"
return self.func.__name__
async def _build_context(self) -> str:
self.context_builder.exps = self._exps
return await self.context_builder.build(*self.args, **self.kwargs)
async def _execute_function(self):
if self.pass_exps_to_func:
return await self._execute_function_with_exps()
self.kwargs["req"] = await self._build_context()
return await self._execute_function_without_exps()
async def _execute_function_without_exps(self):
if asyncio.iscoroutinefunction(self.func):
return await self.func(*self.args, **self.kwargs)
return self.func(*self.args, **self.kwargs)
async def _execute_function_with_exps(self):
if asyncio.iscoroutinefunction(self.func):
return await self.func(*self.args, **self.kwargs, exps=self._exps)
return self.func(*self.args, **self.kwargs, exps=self._exps)

View file

@ -1,13 +1,22 @@
"""Experience Manager."""
from typing import Optional
from llama_index.vector_stores.chroma import ChromaVectorStore
from pydantic import BaseModel, ConfigDict, model_validator
from metagpt.config2 import Config, config
from metagpt.exp_pool.schema import MAX_SCORE, Experience, QueryType
from metagpt.exp_pool.schema import (
DEFAULT_COLLECTION_NAME,
DEFAULT_SIMILARITY_TOP_K,
EntryType,
Experience,
Metric,
QueryType,
Score,
)
from metagpt.logs import logger
from metagpt.rag.engines import SimpleEngine
from metagpt.rag.schema import ChromaRetrieverConfig, LLMRankerConfig
from metagpt.strategy.experience_retriever import ENGINEER_EXAMPLE, TL_EXAMPLE
from metagpt.utils.exceptions import handle_exception
@ -27,14 +36,33 @@ class ExperienceManager(BaseModel):
@model_validator(mode="after")
def initialize(self):
if self.storage is None:
self.storage = SimpleEngine.from_objs(
retriever_configs=[
ChromaRetrieverConfig(collection_name="experience_pool", persist_path=".chroma_exp_data")
],
ranker_configs=[LLMRankerConfig()],
)
retriever_configs = [
ChromaRetrieverConfig(
persist_path=self.config.exp_pool.persist_path,
collection_name=DEFAULT_COLLECTION_NAME,
similarity_top_k=DEFAULT_SIMILARITY_TOP_K,
)
]
ranker_configs = [LLMRankerConfig()]
self.storage = SimpleEngine.from_objs(retriever_configs=retriever_configs, ranker_configs=ranker_configs)
self.init_exp_pool()
return self
@handle_exception
def init_exp_pool(self):
if not self.config.exp_pool.init_exp:
return
if self._has_exps():
return
self._init_teamleader_exps()
self._init_engineer2_exps()
logger.info("`init_exp_pool` done.")
@handle_exception
def create_exp(self, exp: Experience):
"""Adds an experience to the storage if writing is enabled.
@ -74,39 +102,26 @@ class ExperienceManager(BaseModel):
return exps
def extract_one_perfect_exp(self, exps: list[Experience]) -> Optional[Experience]:
"""Extracts the first 'perfect' experience from a list of experiences.
def _has_exps(self) -> bool:
vector_store: ChromaVectorStore = self.storage._retriever._vector_store
Args:
exps (list[Experience]): The experiences to evaluate.
return bool(vector_store._get(limit=1, where={}).ids)
Returns:
Optional[Experience]: The first perfect experience if found, otherwise None.
"""
for exp in exps:
if self.is_perfect_exp(exp):
return exp
def _init_exp(self, req: str, resp: str, tag: str, metric: Metric = None):
exp = Experience(
req=req,
resp=resp,
entry_type=EntryType.MANUAL,
tag=tag,
metric=metric or Metric(score=Score(val=9, reason="Manual")),
)
self.create_exp(exp)
return None
def _init_teamleader_exps(self):
self._init_exp(req=TL_EXAMPLE, resp=TL_EXAMPLE, tag="TeamLeader.llm_cached_aask")
@staticmethod
def is_perfect_exp(exp: Experience) -> bool:
"""Determines if an experience is considered 'perfect'.
Args:
exp (Experience): The experience to evaluate.
Returns:
bool: True if the experience is manually entered, otherwise False.
"""
if not exp:
return False
# TODO: need more metrics
if exp.metric and exp.metric.score.val == MAX_SCORE:
return True
return False
def _init_engineer2_exps(self):
self._init_exp(req=ENGINEER_EXAMPLE, resp=ENGINEER_EXAMPLE, tag="Engineer2.llm_cached_aask")
exp_manager = ExperienceManager()

View file

@ -0,0 +1,6 @@
"""Perfect judges init."""
from metagpt.exp_pool.perfect_judges.base import BasePerfectJudge
from metagpt.exp_pool.perfect_judges.simple import SimplePerfectJudge
__all__ = ["BasePerfectJudge", "SimplePerfectJudge"]

View file

@ -0,0 +1,20 @@
"""Base perfect judge."""
from abc import ABC, abstractmethod
from pydantic import BaseModel, ConfigDict
from metagpt.exp_pool.schema import Experience
class BasePerfectJudge(BaseModel, ABC):
model_config = ConfigDict(arbitrary_types_allowed=True)
@abstractmethod
async def is_perfect_exp(self, exp: Experience, serialized_req: str, *args, **kwargs) -> bool:
"""Determine whether the experience is perfect.
Args:
exp (Experience): The experience to evaluate.
serialized_req (str): The serialized request to compare against the experience's request.
"""

View file

@ -0,0 +1,27 @@
"""Simple perfect judge."""
from pydantic import ConfigDict
from metagpt.exp_pool.perfect_judges.base import BasePerfectJudge
from metagpt.exp_pool.schema import MAX_SCORE, Experience
class SimplePerfectJudge(BasePerfectJudge):
model_config = ConfigDict(arbitrary_types_allowed=True)
async def is_perfect_exp(self, exp: Experience, serialized_req: str, *args, **kwargs) -> bool:
"""Determine whether the experience is perfect.
Args:
exp (Experience): The experience to evaluate.
serialized_req (str): The serialized request to compare against the experience's request.
Returns:
bool: True if the serialized request matches the experience's request and the experience's score is perfect, False otherwise.
"""
if not exp.metric or not exp.metric.score:
return False
return serialized_req == exp.req and exp.metric.score.val == MAX_SCORE

View file

@ -1,13 +1,16 @@
"""Experience schema."""
from enum import Enum
from typing import Any, Optional
from typing import Optional
from llama_index.core.schema import TextNode
from pydantic import BaseModel, Field
MAX_SCORE = 10
DEFAULT_COLLECTION_NAME = "experience_pool"
DEFAULT_SIMILARITY_TOP_K = 2
class QueryType(str, Enum):
"""Type of query experiences."""
@ -59,7 +62,7 @@ class Experience(BaseModel):
"""Experience."""
req: str = Field(..., description="")
resp: Any = Field(..., description="The type is string/json/code.")
resp: str = Field(..., description="The type is string/json/code.")
metric: Optional[Metric] = Field(default=None, description="Metric.")
exp_type: ExperienceType = Field(default=ExperienceType.SUCCESS, description="The type of experience.")
entry_type: EntryType = Field(default=EntryType.AUTOMATIC, description="Type of entry: Manual or Automatic.")

View file

@ -1,6 +1,6 @@
"""Experience scorers init."""
"""Scorers init."""
from metagpt.exp_pool.scorers.base import ExperienceScorer
from metagpt.exp_pool.scorers.base import BaseScorer
from metagpt.exp_pool.scorers.simple import SimpleScorer
__all__ = ["ExperienceScorer", "SimpleScorer"]
__all__ = ["BaseScorer", "SimpleScorer"]

View file

@ -1,6 +1,6 @@
"""Experience Scorers."""
"""Base scorer."""
from abc import abstractmethod
from abc import ABC, abstractmethod
from typing import Any, Callable
from pydantic import BaseModel, ConfigDict
@ -8,7 +8,7 @@ from pydantic import BaseModel, ConfigDict
from metagpt.exp_pool.schema import Score
class ExperienceScorer(BaseModel):
class BaseScorer(BaseModel, ABC):
model_config = ConfigDict(arbitrary_types_allowed=True)
@abstractmethod

View file

@ -1,4 +1,4 @@
"""Simple Scorer."""
"""Simple scorer."""
import inspect
import json
@ -7,7 +7,7 @@ from typing import Any, Callable
from pydantic import Field
from metagpt.exp_pool.schema import Score
from metagpt.exp_pool.scorers.base import ExperienceScorer
from metagpt.exp_pool.scorers.base import BaseScorer
from metagpt.llm import LLM
from metagpt.provider.base_llm import BaseLLM
from metagpt.utils.common import CodeParser
@ -54,7 +54,7 @@ Follow instructions, generate output and make sure it follows the **Constraint**
"""
class SimpleScorer(ExperienceScorer):
class SimpleScorer(BaseScorer):
llm: BaseLLM = Field(default_factory=LLM)
async def evaluate(self, func: Callable, result: Any, args: tuple = None, kwargs: dict = None) -> Score:

View file

@ -1,5 +1,6 @@
from __future__ import annotations
import copy
import inspect
import json
import re
@ -10,8 +11,14 @@ from pydantic import model_validator
from metagpt.actions import Action
from metagpt.actions.di.run_command import RunCommand
from metagpt.exp_pool import exp_cache
from metagpt.exp_pool.context_builders import RoleZeroContextBuilder
from metagpt.logs import logger
from metagpt.prompts.di.role_zero import CMD_PROMPT, ROLE_INSTRUCTION, JSON_REPAIR_PROMPT
from metagpt.prompts.di.role_zero import (
CMD_PROMPT,
JSON_REPAIR_PROMPT,
ROLE_INSTRUCTION,
)
from metagpt.roles import Role
from metagpt.schema import AIMessage, Message, UserMessage
from metagpt.strategy.experience_retriever import DummyExpRetriever, ExpRetriever
@ -21,8 +28,8 @@ from metagpt.tools.libs.editor import Editor
from metagpt.tools.tool_recommend import BM25ToolRecommender, ToolRecommender
from metagpt.tools.tool_registry import register_tool
from metagpt.utils.common import CodeParser
from metagpt.utils.repair_llm_raw_output import RepairType, repair_llm_raw_output
from metagpt.utils.report import ThoughtReporter
from metagpt.utils.repair_llm_raw_output import repair_llm_raw_output, RepairType
@register_tool(include_functions=["ask_human", "reply_to_human"])
@ -154,11 +161,37 @@ class RoleZero(Role):
context = self.llm.format_msg(memory + [UserMessage(content=prompt)])
# print(*context, sep="\n" + "*" * 5 + "\n")
async with ThoughtReporter(enable_llm_stream=True):
self.command_rsp = await self.llm.aask(context, system_msgs=self.system_msg)
self.command_rsp = await self.llm_cached_aask(req=context, system_msgs=self.system_msg)
self.rc.memory.add(AIMessage(content=self.command_rsp))
return True
@exp_cache(context_builder=RoleZeroContextBuilder(), req_serialize=lambda req: RoleZero._req_serialize(req))
async def llm_cached_aask(self, *, req: list[dict], system_msgs: list[str]) -> str:
return await self.llm.aask(req, system_msgs=system_msgs)
@staticmethod
def _req_serialize(req: list[dict]) -> str:
"""Serialize the request for database storage, ensuring it is a string.
This function deep copies the request and modifies the content of the last element
to remove unnecessary sections, making the request more concise.
"""
req_copy = copy.deepcopy(req)
last_content = req_copy[-1]["content"]
last_content = RoleZeroContextBuilder.replace_content_between_markers(
last_content, "# Data Structure", "# Current Plan", ""
)
last_content = RoleZeroContextBuilder.replace_content_between_markers(
last_content, "# Example", "# Instruction", ""
)
req_copy[-1]["content"] = last_content
return json.dumps(req_copy)
async def _act(self) -> Message:
if self.use_fixed_sop:
return await super()._act()
@ -166,7 +199,7 @@ class RoleZero(Role):
try:
commands = CodeParser.parse_code(block=None, lang="json", text=self.command_rsp)
commands = json.loads(repair_llm_raw_output(output=commands, req_keys=[None], repair_type=RepairType.JSON))
except json.JSONDecodeError as e:
except json.JSONDecodeError:
commands = await self.llm.aask(msg=JSON_REPAIR_PROMPT.format(json_data=self.command_rsp))
commands = json.loads(CodeParser.parse_code(block=None, lang="json", text=commands))
except Exception as e:

View file

@ -39,7 +39,7 @@ class NaiveSolver(BaseSolver):
self.graph.topological_sort()
for key in self.graph.execution_order:
op = self.graph.nodes[key]
await op.fill(self.context, self.llm, mode="root")
await op.fill(req=self.context, llm=self.llm, mode="root")
class TOTSolver(BaseSolver):

View file

@ -1,5 +1,4 @@
"""class tools, including method inspection, class attributes, inheritance relationships, etc."""
import inspect
def check_methods(C, *methods):
@ -17,25 +16,3 @@ def check_methods(C, *methods):
else:
return NotImplemented
return True
def get_class_name(func) -> str:
"""Returns the class name of the object that a method belongs to.
- If `func` is a bound method or a class method, extracts the class name directly from the method.
- Returns an empty string if it's a regular function or cannot determine the class.
"""
if inspect.ismethod(func):
if inspect.isclass(func.__self__):
return func.__self__.__name__
return func.__self__.__class__.__name__
if inspect.isfunction(func):
qualname_parts = func.__qualname__.split(".")
if len(qualname_parts) > 1:
class_name = qualname_parts[-2]
if class_name.isidentifier():
return class_name
return ""